pytorch的dataset与dataloader解析
整理一下pytorch获取的流程:
- 创建Dataset对象
- 创建DataLoader对象,装载有dataset对象
- 循环DataLoader对象,DataLoader.__iter__返回的是DataLoaderIter对象
dataset = MyDataset() dataloader = DataLoader(dataset) num_epoches = 100 for epoch in range(num_epoches): for data in dataloader: ....
根据源码分析:torch.utils.data
1 - Dataset:
class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])
Dataset这是一个抽象类,不能实例化,需要重写类方法,关键点有两个:
-
__getitem__ 这个很重要,规定了如何读数据,比如常用的transform
-
__len__ 这个就是返回数据集的长度,比如:return len(self.data)
2 - DataLoader:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
先看一下主要参数:
- dataset:就是 torch.utils.data.Dataset 类的实例。也就是说为了使用 DataLoader 类,需要先定义一个 torch.utils.data.Dataset 类的实例。
- batch_size:每一个批次需要加载的训练样本个数。
- shuffle:如果设置为 True 表示训练样本数据会被随机打乱,默认值为 False。一般会设置为 True 。
- sampler:自定义从数据集中取样本的策略,如果指定这个参数,那么 shuffle 必须为 False 。从源码中可以看到,如果指定了该参数,同时 shuffle 设定为 True,DataLoader 的 __init__ 函数就会抛出一个异常 。
- batch_sampler:与 sampler 类似,但是一次只返回一个 batch 的 indices(索引),需要注意的是,一旦指定了这个参数,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源码中同样做了限制。
- num_workers:表示会使用多少个线程来加载训练数据;默认值为 0,表示数据加载直接在主线程中进行。
- collate_fn:对每一个 batch 的数据做一些你想要的操作。一个例子,https://zhuanlan.zhihu.com/p/346332974
- pin_memory:把数据转移到和 GPU 相关联的 CPU 内存,加速 GPU 载入数据的速度。
- drop_last:比如你的batch_size设置为 32,而一个 epoch 只有 100 个样本;如果设置为 True,那么训练的时候后面的 4 个就被扔掉了。如果为 False(默认),那么会继续正常执行,只是最后的 batch_size 会小一点。
- timeout:加载一个 batch 数据的超时时间。
- worker_init_fn:指定每个数据加载线程的入口函数。
源码分析:
class DataLoader(object): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False): self.dataset = dataset self.batch_size = batch_size self.num_workers = num_workers self.collate_fn = collate_fn self.pin_memory = pin_memory self.drop_last = drop_last if batch_sampler is not None: if batch_size > 1 or shuffle or sampler is not None or drop_last: raise ValueError('batch_sampler is mutually exclusive with ' 'batch_size, shuffle, sampler, and drop_last') if sampler is not None and shuffle: raise ValueError('sampler is mutually exclusive with shuffle') if batch_sampler is None: if sampler is None: if shuffle: # dataset.__len__() 在 Sampler 中被使用。 # 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。 sampler = RandomSampler(dataset) else: # dataset.__len__() 在 Sampler 中被使用。 # 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。 sampler = SequentialSampler(dataset) # Sampler 是个迭代器,一次之只返回一个 索引 # BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引 batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler def __iter__(self): return DataLoaderIter(self) def __len__(self): return len(self.batch_sampler)
可以发现__iter__返回的是DataLoaderIter
3 - DataLoaderIter
先看init初始化:
if self.num_workers > 0: self.worker_init_fn = loader.worker_init_fn
# 定义了workers相同数量个Queue并放置在index_queues这个list中, # 这些Queue与worker一一对应,用来给worker传递“工作内容” self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
# worker_queue_idx用于下一个工作的workre序号,主进程轮询使用不同workers self.worker_queue_idx = 0
# 各个workre将自己所取得的数据传递给wokrker_result_queue,供主进程fetch self.worker_result_queue = multiprocessing.SimpleQueue() # 记录当前时刻分配了多少个任务(可能有处于等待状态的任务) self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False # 发送出去数据的编号 self.send_idx = 0 # 接受到数据的编号 self.rcvd_idx = 0 # 缓存区 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] # 初始化相应的进程,目标函数为_worker_loop # 参数:dataset(用于数据读取),index_queues[i]为worker对应的index_queue # 以及用于输出的queue # 此处主要用于数据读取后的pin_memory操作,不影响多进程主逻辑,暂不展开 if self.pin_memory or self.timeout > 0: ... else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit # 将父进程设置为守护进程,保证父进程结束后,worker进程也结束,必须设置在start之前 w.start() # 下面是一些系统信号处理逻辑,对这方面我还不太熟悉就不介绍了。 _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # 初始化后生成2*num_workers数量个prefetch的数据,使dataloader提前工作,提升整体效率。 # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()
init过程有两个函数,一个是worker_loop,另个是put_indices
a. 先看worker_loop:
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): global _use_shared_memory _use_shared_memory = True # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal # module's handlers are executed after Python returns from C low-level # handlers, likely when the same fatal signal happened again already. # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 _set_worker_signal_handlers() torch.set_num_threads(1) random.seed(seed) torch.manual_seed(seed) if init_fn is not None: init_fn(worker_id) # 父进程状态监测 watchdog = ManagerWatchdog() # 死循环查询是否有任务传进来 while True: try: # 从index_queue获取相应数据 r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL) except queue.Empty: if watchdog.is_alive(): continue else: break if r is None: break idx, batch_indices = r try: # 获得以后for循环进行读取数据读取,此处和单进程的工作原理是一样的 # 因此时间花费和batchsize数量呈线性关系 samples = collate_fn([dataset[i] for i in batch_indices]) # 经过collate_fn后变成torch.Tensor except Exception: # 异常处理 data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) else: # 通过data_queue传回处理好的batch数据 data_queue.put((idx, samples)) # 显示删除中间变量,降低内存消耗 del samples
这里就是不停地轮询,从index_queues队列里获得索引,然后通过collate_fn函数和索引获取tensor,然后塞入data_queue。
b. 再看put_indices
def _put_indices(self): assert self.batches_outstanding < 2 * self.num_workers # 默认设定是只允许分配2*num_workers个任务,保证内存等资源不被耗尽 indices = next(self.sample_iter, None) # 从sample_iter中拿到dataset中下一轮次的索引,用于fetch数据 if indices is None: return self.index_queues[self.worker_queue_idx].put((self.send_idx, indices)) # 轮询选择worker,找到其对应的队列,向其中发送工作内容(数据编号,数据索引) self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers # worker_queue_idx自增 self.batches_outstanding += 1 # 任务分配数+1 self.send_idx += 1 # 已发送任务总数+1(下批数据编号)
这个就是把索引塞进队列index_queues
以上就是init,当for循环时,会调用next:
c. __next__返回一个batch
def __next__(self): if self.num_workers == 0: # same-process loading (主进程阻塞式读取数据) indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: batch = pin_memory_batch(batch) return batch # check if the next sample has already been generated # 先查看数据是否在缓存dict中 if self.rcvd_idx in self.reorder_dict: batch = self.reorder_dict.pop(self.rcvd_idx) return self._process_next_batch(batch) # 异常处理 if self.batches_outstanding == 0: self._shutdown_workers() raise StopIteration while True: assert (not self.shutdown and self.batches_outstanding > 0) # 阻塞式的从data_queue里面获取处理好的批数据 idx, batch = self._get_batch() # 任务数减一 self.batches_outstanding -= 1 # 这一步可能会造成的周期阻塞现象 # 每次获取data以后,要校验和rcvd_idx是否一致 # 若不一致,则先把获取到的数据放到reorder_dict这个缓存dict中,继续死循环 # 直到获取到相应的idx编号于rcvd_idx可以对应上,并将数据返回 if idx != self.rcvd_idx: # store out-of-order samples self.reorder_dict[idx] = batch continue return self._process_next_batch(batch)
__next__里的while True,要从data_queue里面读到的数据idx和rcvd_idx一致才将数据返回。因此可能会存在如下这种情况:
假设num_workers=8,现在发送了8个数据给相应的worker,此时send_idx=8,rcvd_idx=0。过了一段时间以后,{1,2,3,5,6,7}进程数据准备完毕,此时主进程从data_queue读取到相关的数据,但由于和rcvd_idx不匹配,只能将其放在缓存里。直到send_idx=0数据准备齐以后,才能将数据返回出去,随后从缓存中弹出2,3的数据,之后又阻塞等待idx=4的数据。即输出的数据必须保持顺序性!因此在worker变多,出现这种逆序现象可能性会更大,这种现象也会出现在非num_workrers次迭代,只要相应的rcvd_idx没有得到相关数据,则主进程就会一直等待。
d. process_next_batch
def _process_next_batch(self, batch): # 序号对上以后,rcvd_idx自加1 self.rcvd_idx += 1 # 添加一个fetchdata任务给worker self._put_indices() if isinstance(batch, ExceptionWrapper): raise batch.exc_type(batch.exc_msg) return batch
这个函数注意的是,只有在__next__中,idx == self.rcvd_idx时才会调用,也就是可能出现多个worker已经准备好了,但是只能放在缓存区,并且无法向index_queues塞入索引,使worker无法保持活跃状态。
最后对于for循环从dataloader获取data总体流程:
for epoch in range(num_epoches): for data in dataloader:
对于这个for,其实就是调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter,如果是num_worker>0,init里就会创建多线程,并且有两个队列,一个是存放dataset的索引index_queues,一个是从index_queues里拿到索引,调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch,放到data_queue队列里,反复调用DataLoaderIter 的__next__,从data_queue中获取batch。
参考:
Pytorch数据读取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236
PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188
PyTorch36.DataLoader源代码剖析 https://zhuanlan.zhihu.com/p/169497395
PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 https://zhuanlan.zhihu.com/p/76893455