Python 之懒惰的生成器

问题起源于为了完成一个多任务的训练,自己写了个 BatchSampler,如下:

class MultiTaskBatchSampler(BatchSampler):

    def __init__(self, datasets: MultiTaskDataset, batch_size: int, shuffle=True):
        super(MultiTaskBatchSampler, self).__init__(sampler=Sampler(datasets), batch_size=batch_size,
                                                    drop_last=False)
        self.datasets_length = {task_id: len(dataset) for
                                task_id, dataset in datasets.datasets.items()}
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.task_indexes = []
        self.batch_indexes = {}
        self.task_num_batches = {}
        self.total_batches = 0
        self.init()

    def init(self):
        for task_id, dataset_len in self.datasets_length.items():
            num_batches = (dataset_len - 1) // self.batch_size + 1
            self.batch_indexes[task_id] = list(range(dataset_len))
            self.task_num_batches[task_id] = num_batches
            self.total_batches += num_batches
            self.task_indexes.extend([task_id] * num_batches)

    def __len__(self) -> int:
        return self.total_batches

    def __iter__(self) -> Iterable:
        batch_generator = self.get_batch_generator()
        for task_id in self.task_indexes:
            current_indexes_gen = batch_generator[task_id]
            batch = next(current_indexes_gen)
            yield [(task_id, index) for index in batch]

    def get_batch_generator(self) -> Dict[str, Iterable]:
        if self.shuffle:
            random.shuffle(self.task_indexes)
        batch_generator = {}
        for task_id, batch_indexes in self.batch_indexes.items():
            if self.shuffle:
                random.shuffle(batch_indexes)
            # 错误的写法:
            # batch_generator[task_id] = (batch_indexes[i * self.batch_size: (i + 1) * self.batch_size]
            #                                  for i in range(self.task_num_batches[task_id]))
            batch_generator[task_id] = iter([batch_indexes[i * self.batch_size: (i + 1) * self.batch_size]
                                            for i in range(self.task_num_batches[task_id])])
        return batch_generator

由于错误的写法中使用的是生成器,在构造的时候并没有立即使用 batch_indexes 来为当前任务生成批量索引,在 __iter__ 方法中使用 next 调用的时候才会激活生成器产生批量索引,而此时的 batch_indexes 并不是我认为的当前 task 对应的 batch_indexes,它使用的是 get_batch_generator 中循环完成后最后一个任务对应的 batch_indexes,这时如果每个任务的数据量不一致就会导致在 Dataset 中根据索引获取数据的时候出错。

折腾了两个多小时 (还是基础不够啊 😭),即便是自己深信不疑的东西也要动手去 debug,不能光凭已有的认知空想,即便知道生成器懒惰的特性,但是想想自己好像却从未真正使用过。

posted @ 2020-12-11 17:42  GNEPIUHUX  阅读(129)  评论(0编辑  收藏  举报