python 3.6 生成器

最近在做bert文本分类,有一个生成器,记录一下使用,跟我网上查到的不太一样,主要在.iter()这个地方,很多代码都是没有这个,不知道是不是版本原因
另外需要注意,自定义的生成器需要注意什么时候结束,不然会一直产生数据

datalist, labellist = get_data_from_excel(r'data/test.xlsx')
data = data_generator(datalist).__iter__() # 注意这个.__iter__()
# 获取一批数据
print(next(data))
# 或者
for x in data:
   print(x)
点击查看代码
class data_generator:
    """
    data_generator只是一种为了节约内存的数据方式
    """
    def __init__(self, data, batch_size=Batch_size, shuffle=True):
        """
        :param data: 训练的文本列表
        :param batch_size:  每次训练的个数
        :param shuffle: 文本是否打乱
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))  # 生成一个序列
            if self.shuffle:
                np.random.shuffle(idxs)  # 打乱序列
            X1, X2, Y = [], [], []
            for i in idxs:
                d = self.data[i]
                text = d[0][:maxlen]
                x1, x2 = tokenizer.encode(first=text)  # 添加[CLS]和[SEP]
                y = d[1]
                X1.append(x1)
                X2.append(x2)
                Y.append([y])
                if len(X1) == self.batch_size or i == idxs[-1]:
                    # 对一批数据(最后一批不满batch_size)进行padding
                    X1 = seq_padding(X1)  # 内部转为了np.array
                    X2 = seq_padding(X2)
                    Y = seq_padding(Y)
                    yield [X1, X2], Y[:, 0, :]
                    [X1, X2, Y] = [], [], []
posted @ 2022-06-05 21:36  ecnu_lxz  阅读(44)  评论(0编辑  收藏  举报