Python 之懒惰的生成器
问题起源于为了完成一个多任务的训练,自己写了个 BatchSampler,如下:
class MultiTaskBatchSampler(BatchSampler):
def __init__(self, datasets: MultiTaskDataset, batch_size: int, shuffle=True):
super(MultiTaskBatchSampler, self).__init__(sampler=Sampler(datasets), batch_size=batch_size,
drop_last=False)
self.datasets_length = {task_id: len(dataset) for
task_id, dataset in datasets.datasets.items()}
self.shuffle = shuffle
self.batch_size = batch_size
self.task_indexes = []
self.batch_indexes = {}
self.task_num_batches = {}
self.total_batches = 0
self.init()
def init(self):
for task_id, dataset_len in self.datasets_length.items():
num_batches = (dataset_len - 1) // self.batch_size + 1
self.batch_indexes[task_id] = list(range(dataset_len))
self.task_num_batches[task_id] = num_batches
self.total_batches += num_batches
self.task_indexes.extend([task_id] * num_batches)
def __len__(self) -> int:
return self.total_batches
def __iter__(self) -> Iterable:
batch_generator = self.get_batch_generator()
for task_id in self.task_indexes:
current_indexes_gen = batch_generator[task_id]
batch = next(current_indexes_gen)
yield [(task_id, index) for index in batch]
def get_batch_generator(self) -> Dict[str, Iterable]:
if self.shuffle:
random.shuffle(self.task_indexes)
batch_generator = {}
for task_id, batch_indexes in self.batch_indexes.items():
if self.shuffle:
random.shuffle(batch_indexes)
# 错误的写法:
# batch_generator[task_id] = (batch_indexes[i * self.batch_size: (i + 1) * self.batch_size]
# for i in range(self.task_num_batches[task_id]))
batch_generator[task_id] = iter([batch_indexes[i * self.batch_size: (i + 1) * self.batch_size]
for i in range(self.task_num_batches[task_id])])
return batch_generator
由于错误的写法中使用的是生成器,在构造的时候并没有立即使用 batch_indexes 来为当前任务生成批量索引,在 __iter__
方法中使用 next 调用的时候才会激活生成器产生批量索引,而此时的 batch_indexes 并不是我认为的当前 task 对应的 batch_indexes,它使用的是 get_batch_generator
中循环完成后最后一个任务对应的 batch_indexes,这时如果每个任务的数据量不一致就会导致在 Dataset 中根据索引获取数据的时候出错。
折腾了两个多小时 (还是基础不够啊 😭),即便是自己深信不疑的东西也要动手去 debug,不能光凭已有的认知空想,即便知道生成器懒惰的特性,但是想想自己好像却从未真正使用过。