代码改变世界

常用Python函数

2018-12-12 17:35  JohnRain  阅读(153)  评论(0)  编辑  收藏  举报

自动缓存文件

读取网络硬盘上的文件时,常常因为网速问题导致大量时间浪费在IO操作上。这个方法在第一次调用时会自动将网络文件缓存到本地临时文件夹,
第二次运行时就会直接读取本地的缓存文件,免去网络IO的限制。

import os
import shutil
import time
class auto_cache(object):
    """Context manager that serves a local cached copy of a slow-to-read file.

    On construction the source file is copied into ``cache_path`` unless a
    file with the same base name already exists there. The ``with``
    statement then yields the path of the cached copy.

    The cache is keyed by base name only: two source files that share a
    name map to the same cache entry, and an existing entry is reused
    without checking whether the source has changed since it was copied.

    Args:
        myfile: Path of the file to cache.
        cache_path: Directory that holds the cached copies.

    Raises:
        AssertionError: If ``myfile`` is not an existing regular file.
    """

    def __init__(self, myfile, cache_path="/tmp"):
        # Raised explicitly (instead of via ``assert``) so the check still
        # runs under ``python -O``; the exception type callers see is the same.
        if not os.path.isfile(myfile):
            raise AssertionError("{} is not an existing file".format(myfile))

        path, filename = os.path.split(myfile)
        self.cache_file = os.path.join(cache_path, filename)

        if os.path.isfile(self.cache_file):
            print("Cache file {} been detected!".format(self.cache_file))
            return

        print("Caching {} from {} to {}".format(filename, path, cache_path))
        start_time = time.perf_counter()
        shutil.copy(myfile, self.cache_file)
        end_time = time.perf_counter()
        print("Elapsed Time: {:.2f}s".format(end_time - start_time))

    def __enter__(self):
        """Return the path of the cached copy."""
        return self.cache_file

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the cache file in place and let any exception propagate."""
        # A truthy return value would suppress every exception raised inside
        # the ``with`` body, hiding real failures from the caller.
        return False

使用方法举例:

# Original code: read the file directly from its (slow) source location.
data = pd.read_csv(filename)

# New code: `auto_cache` yields the path of the local cached copy, which is
# then read in place of the original path.
with auto_cache(filename) as filename:
    data = pd.read_csv(filename)

自动计时方法

class tic_toc:
    """Context manager that prints how long its ``with`` body took.

    Args:
        comment: Text printed in front of the elapsed time.
    """

    def __init__(self, comment):
        self.comment = comment

    def __enter__(self):
        # Remember the wall-clock start time; nothing is bound by ``as``.
        self.st_time = time.time()

    def __exit__(self, a, b, c):
        # a, b, c are the exception type, value and traceback. They are
        # ignored, so an exception from the body still propagates.
        elapsed = time.time() - self.st_time
        report = self.comment + ", 耗时:{:.2f}s\n".format(elapsed)
        print(report)
import sys

class print_and_save(object):
    """Tee everything written to ``sys.stdout`` into a log file as well.

    Usable as a context manager (``with print_and_save(path): ...``) or as
    a decorator (``@print_and_save(path)``). Creating the object opens
    ``filepath`` in append mode and installs the object as ``sys.stdout``.
    Leaving the ``with`` block, or returning from a decorated call, closes
    the file and restores the previous ``sys.stdout``.

    A decorated function can be called repeatedly: each call re-opens the
    log file if a previous call closed it, and the previous ``sys.stdout``
    is restored even when the function raises.

    Args:
        filepath: Path of the log file; output is appended to it.
    """

    def __init__(self, filepath):
        self.filepath = filepath
        self.file = None
        self.old = None  # stream that was sys.stdout before redirection
        self._start()

    def _start(self):
        """Open the log file and take over ``sys.stdout`` if not yet active."""
        if self.file is None or self.file.closed:
            self.file = open(self.filepath, 'a')
        if sys.stdout is not self:
            self.old = sys.stdout
            sys.stdout = self

    def __enter__(self):
        self._start()
        return self

    def __call__(self, func):
        """Decorate ``func`` so that its printed output is also logged."""
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Re-activate on every call: an earlier call has already closed
            # the file and restored sys.stdout.
            self._start()
            try:
                return func(*args, **kwargs)
            finally:
                self._exit()
        return wrapper

    def write(self, message):
        self.old.write(message)
        self.file.write(message)

    def flush(self):
        self.old.flush()
        self.file.flush()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._exit()

    def _exit(self):
        """Close the log file and restore ``sys.stdout``; safe to call twice."""
        try:
            if self.file is not None and not self.file.closed:
                self.file.flush()
                self.file.close()
        finally:
            # Only undo the redirection if this object is still installed.
            if sys.stdout is self:
                sys.stdout = self.old
def zeroPadding(seqs, fillvalue=0, max_seq_length=None):
    """Pad sequences to a common length and return them transposed.

    Row ``t`` of the result holds position ``t`` of every sequence, so each
    column corresponds to one input sequence. Sequences shorter than the
    longest one are filled with ``fillvalue``.

    Args:
        seqs: Iterable of sliceable sequences.
        fillvalue: Value used for the missing positions.
        max_seq_length: When truthy, every sequence is first truncated to
            at most this many items; it must be positive.

    Returns:
        A NumPy array built from the padded, transposed rows.
    """
    if max_seq_length:
        assert max_seq_length > 0

        def clip(seq):
            keep = min(len(seq), max_seq_length)
            return seq[:keep]

        seqs = [clip(seq) for seq in seqs]

    padded_rows = itertools.zip_longest(*seqs, fillvalue=fillvalue)
    return np.array(list(padded_rows))
def batch_itr(data, batch_size, shuffle=True):
    """Yield batches of ``data``, each sorted by descending sample length.

    Samples are assigned to batches through a list of indices that is
    shuffled first, so the batch membership changes between passes. Within
    a batch the samples are ordered by ``len(sample[0])``, longest first.
    Only full batches are produced: a trailing remainder smaller than
    ``batch_size`` is dropped. ``data`` itself is never modified.

    Args:
        data: Indexable collection of samples; ``sample[0]`` must support
            ``len()``.
        batch_size: Number of samples per batch.
        shuffle: Shuffle the sample order before batching. Pass ``False``
            to batch the samples in their original order.

    Yields:
        ``np.array`` built from the sorted samples of one batch.
    """
    data_size = len(data)
    ids = list(range(data_size))
    if shuffle:
        random.shuffle(ids)

    num_batch = data_size // batch_size
    for i in range(num_batch):
        batch_ids = ids[i * batch_size:(i + 1) * batch_size]
        # Select through the shuffled indices; sorted() builds a new list,
        # so this also works when ``data`` is not a plain list.
        gen_data = sorted(
            (data[j] for j in batch_ids),
            key=lambda x: len(x[0]),
            reverse=True,
        )
        yield np.array(gen_data)
def model_load(resultpath):
    """Interactively pick a saved model under ``resultpath`` and load it.

    The entries of ``resultpath`` are listed in reverse-sorted order and
    the user is asked for the index of the one to load. The chosen
    directory must contain a ``checkpoint`` file holding
    ``"<checkpoint>,<model file name>"``; the named file is unpickled.

    Args:
        resultpath: Directory that holds one sub-directory per saved
            model. It is created if it does not exist yet.

    Returns:
        ``(checkpoint, model)`` on success, where ``checkpoint`` is an
        ``int``. ``(False, 0)`` when ``resultpath`` did not exist, contains
        no entries, or anything goes wrong while selecting or loading.
    """
    if not os.path.exists(resultpath):
        os.makedirs(resultpath, exist_ok=True)
        return False, 0

    dirlist = sorted(os.listdir(resultpath), reverse=True)
    if not dirlist:
        # Nothing to choose from, so do not prompt for an index.
        print("模型加载失败:{}".format(
            "no saved models found in {}".format(resultpath)))
        return False, 0

    print("============== 模型列表 ======================")
    for i, p in enumerate(dirlist):
        print("\t[{}]: {}".format(i, p))
    print("==============================================")
    try:
        # Parsed inside the ``try`` so a mistyped index is reported like
        # any other load failure instead of raising ValueError.
        model_id = int(input("请选择模型:"))
        path = os.path.join(resultpath, dirlist[model_id])
        with open(os.path.join(path, "checkpoint"), "r") as f:
            # Split once only, so a comma inside the file name survives.
            checkpoint, model_save_name = f.read().split(",", 1)
        checkpoint = int(checkpoint)
        # Drop the trailing newline most editors and writers leave behind.
        model_save_name = model_save_name.strip()

        # The model file lives in the selected directory. Fall back to the
        # raw name for checkpoints that stored a path usable as-is.
        model_file = os.path.join(path, model_save_name)
        if not os.path.isfile(model_file) and os.path.isfile(model_save_name):
            model_file = model_save_name

        # SECURITY: pickle.load can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(model_file, "rb") as f:
            model = pickle.load(f)
        print("成功加载模型:{}".format(model_file))
        return checkpoint, model
    except Exception as e:
        # Best-effort by contract: every failure is reported and mapped to
        # the (False, 0) result.
        print("模型加载失败:{}".format(e))
        return False, 0