Notes on PyTorch's DataLoader
The Input Data Pipeline
In PyTorch, data reaches the model in the following order:
① Create a Dataset object
② Create a DataLoader object
③ Loop over the DataLoader object, feeding each img, label pair to the model for training
    from torch.utils.data import DataLoader

    dataset = MyDataset()  # MyDataset: your own Dataset subclass
    dataloader = DataLoader(dataset)
    num_epoches = 100
    for epoch in range(num_epoches):
        for img, label in dataloader:
            ....
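As the step that feeds data directly into the model, DataLoader is therefore a key component.
First, a brief introduction. DataLoader is an important data-reading interface in PyTorch, defined in dataloader.py. Almost every model trained with PyTorch uses it (unless the user writes their own…). Its purpose is to wrap a custom Dataset, according to the batch size, whether to shuffle, and so on, into Tensors holding batch-size samples each, ready for the training that follows.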
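The official description of DataLoader is:
"Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset."
If you are interested in the concepts of iterator and iterable and how they differ, please look them up. In terms of implementation, an iterator has both __iter__ and __next__ methods, while an iterable only needs __iter__.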
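The iterable/iterator distinction can be checked directly in plain Python. The snippet below is a minimal sketch that uses only the standard library; the variable names are illustrative and not taken from PyTorch.

```python
from collections.abc import Iterable, Iterator

numbers = [1, 2, 3]            # a list is an iterable: it defines __iter__ only
numbers_iter = iter(numbers)   # iter() hands back an iterator, which also has __next__

list_is_iterable = isinstance(numbers, Iterable)
list_is_iterator = isinstance(numbers, Iterator)
iter_is_iterator = isinstance(numbers_iter, Iterator)

first = next(numbers_iter)     # consumes the first element
rest = list(numbers_iter)      # consumes whatever is left
```

A DataLoader plays the role of `numbers` here, and the object returned by its `__iter__` plays the role of `numbers_iter`.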
1. DataLoader source code walkthrough
Before analyzing the source, here are the parameters of DataLoader(object):
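- dataset (Dataset): The dataset to load from.
- batch_size (int, optional): How many samples each batch contains.
- shuffle (bool, optional): Whether to reshuffle the data at the start of every epoch.
- sampler (Sampler, optional): A custom strategy for drawing samples from the dataset. If this is specified, shuffle must be False.
- batch_sampler (Sampler, optional): Like sampler, but returns the indices of one whole batch at a time. Note that once this is specified, batch_size, shuffle, sampler and drop_last can no longer be specified (they are mutually exclusive).
- num_workers (int, optional): How many subprocesses are used for data loading. 0 means the data is loaded in the main process. (default: 0)
- collate_fn (callable, optional): A function that merges a list of samples into a mini-batch.
- pin_memory (bool, optional): If True, the data loader copies tensors into CUDA pinned memory (page-locked host memory) before returning them.
- drop_last (bool, optional): If True, the last incomplete batch is dropped. For example, with batch_size set to 64 and 100 samples per epoch, the trailing 36 samples are discarded during training. If False (the default), execution continues normally and the last batch is simply smaller.
- timeout (numeric, optional): If positive, the maximum time to wait when collecting a batch from the worker processes. In the source shown below, exceeding it raises a RuntimeError ('DataLoader timed out after {} seconds'). The value must always be greater than or equal to 0. (default: 0)
- worker_init_fn (callable, optional): Initialization function for each worker. If not None, it is called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)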
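The drop_last arithmetic from the parameter list can be worked through in a few lines. This is a sketch in plain Python rather than PyTorch code; `count_batches` and `dropped_samples` are hypothetical helper names, and the batch-count formula mirrors the one in `BatchSampler.__len__` quoted later in this article.

```python
def count_batches(num_samples, batch_size, drop_last):
    """Number of batches produced in one epoch."""
    if drop_last:
        return num_samples // batch_size
    # Ceiling division: a final, smaller batch is kept.
    return (num_samples + batch_size - 1) // batch_size


def dropped_samples(num_samples, batch_size, drop_last):
    """Number of samples discarded at the end of one epoch."""
    return num_samples % batch_size if drop_last else 0


batches_dropping = count_batches(100, 64, drop_last=True)
batches_keeping = count_batches(100, 64, drop_last=False)
samples_lost = dropped_samples(100, 64, drop_last=True)
```

With 100 samples and a batch size of 64, dropping the last batch leaves one batch and discards 36 samples, while keeping it yields two batches of 64 and 36 samples.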
Judging from the parameters above, the DataLoader class is clearly responsible for setting up how the data is loaded. Now let's look at the source:
    class DataLoader(object):

        __initialized = False

        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=default_collate,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn

            if timeout < 0:
                raise ValueError('timeout option should be non-negative')

            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and '
                                     'drop_last')
                self.batch_size = None
                self.drop_last = None

            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')

            if self.num_workers < 0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')

            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                        sampler = RandomSampler(dataset)
                    else:
                        sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)

            self.sampler = sampler
            self.batch_sampler = batch_sampler
            self.__initialized = True

        def __setattr__(self, attr, val):
            if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))

            super(DataLoader, self).__setattr__(attr, val)

        def __iter__(self):
            return _DataLoaderIter(self)

        def __len__(self):
            return len(self.batch_sampler)
The parts to focus on here are __init__() and __iter__():
① Shuffling and batching of the data
- RandomSampler(dataset)
- SequentialSampler(dataset)
- BatchSampler(sampler, batch_size, drop_last)
② Because DataLoader only defines __iter__() and does not implement __next__(), DataLoader is an iterable, not an iterator. The iterator itself is implemented in _DataLoaderIter.
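The split between an iterable loader and a separate iterator object can be mimicked with two small classes. This is a sketch in plain Python; `ToyLoader` and `ToyLoaderIter` are hypothetical stand-ins and not the PyTorch classes.

```python
class ToyLoaderIter:
    """Iterator: walks once over the batches of its loader."""

    def __init__(self, loader):
        self.batches = iter(loader.batches)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.batches)  # raises StopIteration when exhausted


class ToyLoader:
    """Iterable: every call to __iter__ creates a fresh iterator."""

    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        return ToyLoaderIter(self)

    def __len__(self):
        return len(self.batches)


loader = ToyLoader([[0, 1], [2, 3], [4]])
epoch_1 = list(loader)        # first pass
epoch_2 = list(loader)        # second pass works because a new iterator is built

one_shot = iter(loader)
consumed = list(one_shot)
leftover = list(one_shot)     # the iterator itself is exhausted after one pass
```

Keeping the loader as an iterable is what lets the same object be looped over once per epoch.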
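1.1 DataLoader: RandomSampler(dataset) and SequentialSampler(dataset)
These are implemented in torch/utils/data/sampler.py, in the same directory as dataloader.py. sampler.py defines a parent class, Sampler, and five subclasses that inherit from it, including SequentialSampler, RandomSampler and BatchSampler. Every sampler must provide an __iter__ method that defines how the data is traversed and a __len__ method that returns the length of the data.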
    import torch


    class Sampler(object):
        r"""Base class for all Samplers.

        Every Sampler subclass has to provide an __iter__ method, providing a way
        to iterate over indices of dataset elements, and a __len__ method that
        returns the length of the returned iterators.
        """

        def __init__(self, data_source):
            pass

        def __iter__(self):
            raise NotImplementedError

        def __len__(self):
            raise NotImplementedError


    class SequentialSampler(Sampler):
        r"""Samples elements sequentially, always in the same order.

        Arguments:
            data_source (Dataset): dataset to sample from
        """

        def __init__(self, data_source):
            self.data_source = data_source

        def __iter__(self):
            return iter(range(len(self.data_source)))

        def __len__(self):
            return len(self.data_source)


    class RandomSampler(Sampler):
        r"""Samples elements randomly, without replacement.

        Arguments:
            data_source (Dataset): dataset to sample from
        """

        def __init__(self, data_source):
            self.data_source = data_source

        def __iter__(self):
            return iter(torch.randperm(len(self.data_source)).tolist())

        def __len__(self):
            return len(self.data_source)


    if __name__ == "__main__":
        print(list(RandomSampler(range(10))))
        # e.g. [2, 8, 3, 5, 9, 4, 6, 0, 1, 7] (the order changes on every run)

        print(list(SequentialSampler(range(10))))
        # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
As you can see, samplers such as RandomSampler return index positions (indices) into the Dataset. In each subclass, the __iter__ method has to return something of the form iter(xxx), i.e. an iterator:
    #### The following two snippets are equivalent
    for data in dataloader:
        ...
    #### is equivalent to
    iters = iter(dataloader)
    while 1:
        try:
            next(iters)
        except StopIteration:
            break
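To make the "a sampler is just an iterable of indices" idea concrete, here is a sketch of a custom sampler written in plain Python. `EvenIndexSampler` is a hypothetical class invented for illustration. It follows the same `__iter__`/`__len__` contract described above but does not inherit from `torch.utils.data.Sampler`, so that the snippet runs without PyTorch.

```python
class EvenIndexSampler:
    """Hypothetical sampler: yields only the even indices of a data source."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Return an iterator over indices, exactly like the samplers above.
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        # Number of even indices in range(len(data_source)).
        return (len(self.data_source) + 1) // 2


data = ["a", "b", "c", "d", "e"]
sampler = EvenIndexSampler(data)
indices = list(sampler)
samples = [data[i] for i in indices]  # look the samples up by index
```

The sampler never touches the data itself; it only decides which indices are visited and in what order.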
1.2 DataLoader: BatchSampler(Sampler)
BatchSampler wraps a sampler and yields the indices of one mini-batch at a time.
The method to look at here is __iter__; the code shows clearly how the batch indices are drawn.
    class BatchSampler(Sampler):
        r"""Wraps another sampler to yield a mini-batch of indices.

        Args:
            sampler (Sampler): Base sampler.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``

        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """

        def __init__(self, sampler, batch_size, drop_last):
            if not isinstance(sampler, Sampler):
                raise ValueError("sampler should be an instance of "
                                 "torch.utils.data.Sampler, but got sampler={}"
                                 .format(sampler))
            if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                    batch_size <= 0:
                raise ValueError("batch_size should be a positive integeral value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            batch = []
            # Once the batch reaches batch_size it is full and can be yielded
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch

        def __len__(self):
            # E.g. with 100 samples per epoch and batch_size 64, the result is 1
            # with drop_last and 2 without drop_last
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size


    if __name__ == "__main__":
        print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)))
        # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)))
        # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
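When shuffle=True and no sampler is given, the DataLoader constructor shown earlier wraps a RandomSampler in a BatchSampler. The sketch below imitates that composition in plain Python. `shuffled_indices` and `batch_indices` are hypothetical helpers, and `random.Random(seed).shuffle` stands in for `torch.randperm`, so this is an illustration of the idea rather than the library's implementation.

```python
import random


def shuffled_indices(n, seed):
    """Stand-in for RandomSampler: a random permutation of range(n)."""
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)
    return indices


def batch_indices(indices, batch_size, drop_last):
    """Stand-in for BatchSampler.__iter__: group indices into batches."""
    batch = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch  # the final, smaller batch


order = shuffled_indices(10, seed=0)
batches = list(batch_indices(order, batch_size=3, drop_last=False))
batches_dropped = list(batch_indices(order, batch_size=3, drop_last=True))
batch_sizes = [len(b) for b in batches]
```

Every index still appears exactly once per epoch; shuffling only changes which batch it lands in.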
1.3 _DataLoaderIter
_DataLoaderIter is simply what the __iter__() method of the DataLoader class returns:
    class _DataLoaderIter(object):
        r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

        # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
        #
        # Preliminary:
        #
        # Our data model looks like this (queues are indicated with curly brackets):
        #
        #                main process                              ||
        #                     |                                    ||
        #               {index_queue}                              ||
        #                     |                                    ||
        #              worker processes                            ||     DATA
        #                     |                                    ||
        #            {worker_result_queue}                         ||     FLOW
        #                     |                                    ||
        #      pin_memory_thread of main process                   ||   DIRECTION
        #                     |                                    ||
        #               {data_queue}                               ||
        #                     |                                    ||
        #                data output                               \/
        #
        # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
        #      `pin_memory=False`.
        #
        #
        # Terminating multiprocessing logic requires very careful design. In
        # particular, we need to make sure that
        #
        #   1. The iterator gracefully exits the workers when its last reference is
        #      gone or it is depleted.
        #
        #      In this case, the workers should be gracefully exited because the
        #      main process may still need to continue to run, and we want cleaning
        #      up code in the workers to be executed (e.g., releasing GPU memory).
        #      Naturally, we implement the shutdown logic in `__del__` of
        #      DataLoaderIterator.
        #
        #      We delay the discussion on the logic in this case until later.
        #
        #   2. The iterator exits the workers when the loader process and/or worker
        #      processes exits normally or with error.
        #
        #      We set all workers and `pin_memory_thread` to have `daemon=True`.
        #
        #      You may ask, why can't we make the workers non-daemonic, and
        #      gracefully exit using the same logic as we have in `__del__` when the
        #      iterator gets deleted (see 1 above)?
        #
        #      First of all, `__del__` is **not** guaranteed to be called when
        #      interpreter exits. Even if it is called, by the time it executes,
        #      many Python core library resources may alreay be freed, and even
        #      simple things like acquiring an internal lock of a queue may hang.
        #      Therefore, in this case, we actually need to prevent `__del__` from
        #      being executed, and rely on the automatic termination of daemonic
        #      children. Thus, we register an `atexit` hook that sets a global flag
        #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
        #      reverse order of registration, we are guaranteed that this flag is
        #      set before library resources we use are freed. (Hooks freeing those
        #      resources are registered at importing the Python core libraries at
        #      the top of this file.) So in `__del__`, we check if
        #      `_utils.python_exit_status` is set or `None` (freed), and perform
        #      no-op if so.
        #
        #      Another problem with `__del__` is also related to the library cleanup
        #      calls. When a process ends, it shuts the all its daemonic children
        #      down with a SIGTERM (instead of joining them without a timeout).
        #      Simiarly for threads, but by a different mechanism. This fact,
        #      together with a few implementation details of multiprocessing, forces
        #      us to make workers daemonic. All of our problems arise when a
        #      DataLoader is used in a subprocess, and are caused by multiprocessing
        #      code which looks more or less like this:
        #
        #          try:
        #              your_function_using_a_dataloader()
        #          finally:
        #              multiprocessing.util._exit_function()
        #
        #      The joining/termination mentioned above happens inside
        #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
        #      throws, the stack trace stored in the exception will prevent the
        #      frame which uses `DataLoaderIter` to be freed. If the frame has any
        #      reference to the `DataLoaderIter` (e.g., in a method of the iter),
        #      its `__del__`, which starts the shutdown procedure, will not be
        #      called. That, in turn, means that workers aren't notified. Attempting
        #      to join in `_exit_function` will then result in a hang.
        #
        #      For context, `_exit_function` is also registered as an `atexit` call.
        #      So it is unclear to me (@ssnl) why this is needed in a finally block.
        #      The code dates back to 2008 and there is no comment on the original
        #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
        #      the finally block and the `atexit` registration) that explains this.
        #
        #      Another choice is to just shutdown workers with logic in 1 above
        #      whenever we see an error in `next`. This isn't ideal because
        #        a. It prevents users from using try-catch to resume data loading.
        #        b. It doesn't prevent hanging if users have references to the
        #           iterator.
        #
        #   3. All processes exit if any of them die unexpectedly by fatal signals.
        #
        #      As shown above, the workers are set as daemonic children of the main
        #      process. However, automatic cleaning-up of such child processes only
        #      happens if the parent process exits gracefully (e.g., not via fatal
        #      signals like SIGKILL). So we must ensure that each process will exit
        #      even the process that should send/receive data to/from it were
        #      killed, i.e.,
        #
        #        a. A process won't hang when getting from a queue.
        #
        #           Even with carefully designed data dependencies (i.e., a `put()`
        #           always corresponding to a `get()`), hanging on `get()` can still
        #           happen when data in queue is corrupted (e.g., due to
        #           `cancel_join_thread` or unexpected exit).
        #
        #           For child exit, we set a timeout whenever we try to get data
        #           from `data_queue`, and check the workers' status on each timeout
        #           and error.
        #           See `_DataLoaderiter._get_batch()` and
        #           `_DataLoaderiter._try_get_batch()` for details.
        #
        #           Additionally, for child exit on non-Windows platforms, we also
        #           register a SIGCHLD handler (which is supported on Windows) on
        #           the main process, which checks if any of the workers fail in the
        #           (Python) handler. This is more efficient and faster in detecting
        #           worker failures, compared to only using the above mechanism.
        #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
        #
        #           For `.get()` calls where the sender(s) is not the workers, we
        #           guard them with timeouts, and check the status of the sender
        #           when timeout happens:
        #             + in the workers, the `_utils.worker.ManagerWatchdog` class
        #               checks the status of the main process.
        #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
        #               check `pin_memory_thread` status periodically until `.get()`
        #               returns or see that `pin_memory_thread` died.
        #
        #        b. A process won't hang when putting into a queue;
        #
        #           We use `mp.Queue` which has a separate background thread to put
        #           objects from an unbounded buffer array. The background thread is
        #           daemonic and usually automatically joined when the process
        #           exits.
        #
        #           However, in case that the receiver has ended abruptly while
        #           reading from the pipe, the join will hang forever. Therefore,
        #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
        #           and each `index_queue` (main process -> worker), we use
        #           `q.cancel_join_thread()` in sender process before any `q.put` to
        #           prevent this automatic join.
        #
        #           Moreover, having all queues called `cancel_join_thread` makes
        #           implementing graceful shutdown logic in `__del__` much easier.
        #           It won't need to get from any queue, which would also need to be
        #           guarded by periodic status checks.
        #
        #           Note that this may leave corrupted data in the queue, but we
        #           don't care about the data anyways once we are shutting down.
        #
        #
        # Now let's get back to 1:
        #   how we gracefully exit the workers when the last reference to the
        #   iterator is gone.
        #
        # To achieve this, we implement the following logic along with the design
        # choices mentioned above:
        #
        # [worker processes]
        #   While loader process is alive:
        #     Get from index_queue.
        #       If got a `None`, exit.
        #       If get anything else,
        #          Check `done_event`.
        #            If set, continue to next iteration
        #                    i.e., keep getting until see the `None`, then exit.
        #            Otherwise, process data.
        #       If timed out,
        #          No matter `done_event` is set (still need to see `None`) or not,
        #          must continue to next iteration .
        #
        # [pin_memory_thread]
        #   # No need to check main thread. If this thread is alive, the main loader
        #   # thread must be alive, because this thread is set as daemonic.
        #   While True:
        #     Get from index_queue.
        #       If got a `None`, exit.
        #       If get anything else,
        #          Check `done_event`.
        #            If set, continue to next iteration
        #                    i.e., keep getting until see the `None`, then exit.
        #            Otherwise, process data.
        #
        #   NOTE: we don't check the status of the main thread because
        #           1. if the process is killed by fatal signal, `pin_memory_thread`
        #              ends.
        #           2. in other cases, either the cleaning-up in __del__ or the
        #              automatic exit of daemonic thread will take care of it.
        #              This won't busy-wait either because `.get(timeout)` does not
        #              busy-wait.
        #
        # [main process]
        #   In the DataLoader Iter's `__del__`
        #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
        #
        #        Note: from here on, the workers & `pin_memory_thread` may exit at
        #              any time after they receive `None`.
        #
        #     b. Exit `pin_memory_thread`
        #          i.   Put `None` in `worker_result_queue`.
        #          ii.  Join the `pin_memory_thread`.
        #
        #     c. Exit the workers.
        #          i.   Put `None` in each worker's `index_queue`.
        #          ii.  Join the workers.
        #
        #        NOTE: This has to be after (b) because it may leave corrupted data
        #              in `worker_result_queue`, which `pin_memory_thread` reads
        #              from.
        #
        #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
        #         can be omitted
        #
        # NB: `done_event`s isn't strictly needed. E.g., we can just check for
        #     `None` from `index_queue`, but it allows us to skip wasting resources
        #     processing indices already in `index_queue` if we are already shutting
        #     down.

        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout

            self.sample_iter = iter(self.batch_sampler)

            base_seed = torch.LongTensor(1).random_().item()

            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.worker_queue_idx = 0
                self.worker_result_queue = multiprocessing.Queue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}
                self.done_event = multiprocessing.Event()

                self.index_queues = []
                self.workers = []
                for i in range(self.num_workers):
                    index_queue = multiprocessing.Queue()
                    index_queue.cancel_join_thread()
                    w = multiprocessing.Process(
                        target=_utils.worker._worker_loop,
                        args=(self.dataset, index_queue,
                              self.worker_result_queue, self.done_event,
                              self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    w.daemon = True
                    # NB: Process.start() actually take some time as it needs to
                    #     start a process and pass the arguments over via a pipe.
                    #     Therefore, we only add a worker to self.workers list after
                    #     it started, so that we do not call .join() if program dies
                    #     before it starts, and __del__ tries to join but will get:
                    #     AssertionError: can only join a started process.
                    w.start()
                    self.index_queues.append(index_queue)
                    self.workers.append(w)

                if self.pin_memory:
                    self.data_queue = queue.Queue()
                    pin_memory_thread = threading.Thread(
                        target=_utils.pin_memory._pin_memory_loop,
                        args=(self.worker_result_queue, self.data_queue,
                              torch.cuda.current_device(), self.done_event))
                    pin_memory_thread.daemon = True
                    pin_memory_thread.start()
                    # Similar to workers (see comment above), we only register
                    # pin_memory_thread once it is started.
                    self.pin_memory_thread = pin_memory_thread
                else:
                    self.data_queue = self.worker_result_queue

                _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _utils.signal_handling._set_SIGCHLD_handler()
                self.worker_pids_set = True

                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()

        def __len__(self):
            return len(self.batch_sampler)

        def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
            # Tries to fetch data from `data_queue` for a given timeout. This can
            # also be used as inner loop of fetching without timeout, with the
            # sender status as the loop condition.
            #
            # This raises a `RuntimeError` if any worker died expectedly. This error
            # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
            # (only for non-Windows platforms), or the manual check below on errors
            # and timeouts.
            #
            # Returns a 2-tuple:
            #   (bool: whether successfully get data, any: data if successful else None)
            try:
                data = self.data_queue.get(timeout=timeout)
                return (True, data)
            except Exception as e:
                # At timeout and error, we manually check whether any worker has
                # failed. Note that this is the only mechanism for Windows to detect
                # worker failures.
                if not all(w.is_alive() for w in self.workers):
                    pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
                    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
                if isinstance(e, queue.Empty):
                    return (False, None)
                raise

        def _get_batch(self):
            # Fetches data from `self.data_queue`.
            #
            # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
            # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
            # in a loop. This is the only mechanism to detect worker failures for
            # Windows. For other platforms, a SIGCHLD handler is also used for
            # worker failure detection.
            #
            # If `pin_memory=True`, we also need check if `pin_memory_thread` had
            # died at timeouts.
            if self.timeout > 0:
                success, data = self._try_get_batch(self.timeout)
                if success:
                    return data
                else:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            elif self.pin_memory:
                while self.pin_memory_thread.is_alive():
                    success, data = self._try_get_batch()
                    if success:
                        return data
                else:
                    # while condition is false, i.e., pin_memory_thread died.
                    raise RuntimeError('Pin memory thread exited unexpectedly')
                # In this case, `self.data_queue` is a `queue.Queue`,. But we don't
                # need to call `.task_done()` because we don't use `.join()`.
            else:
                while True:
                    success, data = self._try_get_batch()
                    if success:
                        return data

        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = _utils.pin_memory.pin_memory_batch(batch)
                return batch

            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)

            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration

            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)

        next = __next__  # Python 2 compatibility

        def __iter__(self):
            return self

        def _put_indices(self):
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1

        def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, _utils.ExceptionWrapper):
                # make multiline KeyError msg readable by working around
                # a python bug https://bugs.python.org/issue2651
                if batch.exc_type == KeyError and "\n" in batch.exc_msg:
                    raise Exception("KeyError:" + batch.exc_msg)
                else:
                    raise batch.exc_type(batch.exc_msg)
            return batch

        def __getstate__(self):
            # TODO: add limited pickling support for sharing an iterator
            # across multiple threads for HOGWILD.
            # Probably the best way to do this is by moving the sample pushing
            # to a separate thread and then just sharing the data queue
            # but signalling the end is tricky without a non-blocking API
            raise NotImplementedError("_DataLoaderIter cannot be pickled")

        def _shutdown_workers(self):
            # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
            # the logic of this function.
            python_exit_status = _utils.python_exit_status
            if python_exit_status is True or python_exit_status is None:
                # See (2) of the note. If Python is shutting down, do no-op.
                return
            # Normal exit when last reference is gone / iterator is depleted.
            # See (1) and the second half of the note.
            if not self.shutdown:
                self.shutdown = True
                try:
                    self.done_event.set()

                    # Exit `pin_memory_thread` first because exiting workers may leave
                    # corrupted data in `worker_result_queue` which `pin_memory_thread`
                    # reads from.
                    if hasattr(self, 'pin_memory_thread'):
                        # Use hasattr in case error happens before we set the attribute.
                        # First time do `worker_result_queue.put` in this process.

                        # `cancel_join_thread` in case that `pin_memory_thread` exited.
                        self.worker_result_queue.cancel_join_thread()
                        self.worker_result_queue.put(None)
                        self.pin_memory_thread.join()

                        # Indicate that no more data will be put on this queue by the
                        # current process. This **must** be called after
                        # `pin_memory_thread` is joined because that thread shares the
                        # same pipe handles with this loader thread. If the handle is
                        # closed, Py3 will error in this case, but Py2 will just time
                        # out even if there is data in the queue.
                        self.worker_result_queue.close()

                    # Exit workers now.
                    for q in self.index_queues:
                        q.put(None)
                        # Indicate that no more data will be put on this queue by the
                        # current process.
                        q.close()
                    for w in self.workers:
                        w.join()
                finally:
                    # Even though all this function does is putting into queues that
                    # we have called `cancel_join_thread` on, weird things can
                    # happen when a worker is killed by a signal, e.g., hanging in
                    # `Event.set()`. So we need to guard this with SIGCHLD handler,
                    # and remove pids from the C side data structure only at the
                    # end.
                    #
                    # FIXME: Unfortunately, for Windows, we are missing a worker
                    #        error detection mechanism here in this function, as it
                    #        doesn't provide a SIGCHLD handler.
                    if self.worker_pids_set:
                        _utils.signal_handling._remove_worker_pids(id(self))
                        self.worker_pids_set = False

        def __del__(self):
            if self.num_workers > 0:
                self._shutdown_workers()
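Because several workers run in parallel, batches can come back in a different order from the one in which their indices were sent. The `reorder_dict` and `rcvd_idx` bookkeeping in `__next__` restores the original order. The sketch below isolates that idea in plain Python. `in_order` is a hypothetical helper, and it simplifies the real control flow, which checks the dictionary first and only then blocks on the queue.

```python
def in_order(arrivals):
    """Yield batches in send order, given (idx, batch) pairs in arrival order."""
    reorder_dict = {}   # holds batches that arrived too early
    rcvd_idx = 0        # index of the next batch the caller should receive
    for idx, batch in arrivals:
        reorder_dict[idx] = batch
        # Release every batch that is now next in line.
        while rcvd_idx in reorder_dict:
            yield reorder_dict.pop(rcvd_idx)
            rcvd_idx += 1


# Batch 1 finishes before batch 0, and batch 3 before batch 2.
arrivals = [(1, "b1"), (0, "b0"), (3, "b3"), (2, "b2")]
ordered = list(in_order(arrivals))
```

The training loop therefore sees batches in sampler order regardless of which worker happens to finish first.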
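In index_queue = multiprocessing.Queue(), multiprocessing is Python's package for managing multiple processes, while threading is Python's package for managing multiple threads; a large part of their interfaces is used in a similar way.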
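In __init__, the first part is plain assignment. The notable line is self.sample_iter = iter(self.batch_sampler): calling next(self.sample_iter) on the result returns the indices of one batch (batch_size of them). self.rcvd_idx is the index of the next batch to be returned; it is initialized to 0 and is used when iterating over the data.
The if self.num_workers > 0 statement handles initialization for the multi-process case. When data is not read with multiple processes, none of this setup is needed; single-process reading is covered later. Inside the if, multiprocessing.Queue() creates the queue objects: one worker_result_queue shared by all workers and one index_queue per worker. multiprocessing.Process is the class that constructs a process. The for loop creates one process per configured worker, starts it with start(), and only then appends it to self.workers.
Next comes the check on self.pin_memory, whose body sets up a thread. The meaning of self.pin_memory was covered above: when it is True (and CUDA is available), batches are copied into pinned memory. self.data_queue = queue.Queue() uses Python's queue module to create a first-in-first-out queue (the module can also create a last-in-first-out queue via queue.LifoQueue()); the queue module is mainly used to pass data between threads. In the args of threading.Thread, the first argument, self.worker_result_queue, carries the data produced by the worker processes, and the thread hands the pinned batches on through self.data_queue. Once initialization is done, iterating over the loader calls __next__.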
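The queue-based setup described above can be imitated at small scale. The sketch below is an assumption-laden illustration rather than PyTorch's code: it uses `threading.Thread` and `queue.Queue` instead of `multiprocessing` so that it runs anywhere, `worker_loop` is a hypothetical simplification of the real worker loop, and there is no `done_event`, collate step or prefetch limit. It does keep the two ideas from the source: indices are handed out round-robin over per-worker index queues, and a `None` sentinel tells each worker to exit.

```python
import queue
import threading


def worker_loop(dataset, index_queue, result_queue):
    """Fetch (send_idx, indices) jobs until a None sentinel arrives."""
    while True:
        item = index_queue.get()
        if item is None:          # sentinel: shut down this worker
            break
        send_idx, indices = item
        result_queue.put((send_idx, [dataset[i] for i in indices]))


dataset = [x * x for x in range(10)]
batches_of_indices = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
num_workers = 2

result_queue = queue.Queue()
index_queues, workers = [], []
for _ in range(num_workers):
    q = queue.Queue()
    t = threading.Thread(target=worker_loop, args=(dataset, q, result_queue))
    t.daemon = True
    t.start()
    # Register the worker only after it has started, as the source does.
    index_queues.append(q)
    workers.append(t)

# Hand out the batches of indices round-robin.
for send_idx, indices in enumerate(batches_of_indices):
    index_queues[send_idx % num_workers].put((send_idx, indices))

# Collect one result per batch; arrival order is not guaranteed.
results = dict(result_queue.get(timeout=5) for _ in batches_of_indices)

# Shut down: one sentinel per worker, then join.
for q in index_queues:
    q.put(None)
for t in workers:
    t.join(timeout=5)

ordered = [results[i] for i in range(len(batches_of_indices))]
```

Tagging every job with `send_idx` is what makes it possible to put the results back in order afterwards.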
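In short, when multi-process loading is enabled (num_workers > 0), batches are read through queues; otherwise (num_workers == 0), __next__ reads them directly in the main process.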
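The single-process branch is short enough to reproduce as a sketch. The class below mirrors the `num_workers == 0` path of `__next__` in plain Python. `SingleProcessIter` and `collate_pairs` are hypothetical names, lists stand in for tensors, and pinning is left out.

```python
class SingleProcessIter:
    """Sketch of the num_workers == 0 path: fetch and collate in-process."""

    def __init__(self, dataset, batch_sampler, collate_fn):
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.sample_iter = iter(batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        indices = next(self.sample_iter)  # may raise StopIteration
        return self.collate_fn([self.dataset[i] for i in indices])


def collate_pairs(samples):
    """Hypothetical collate_fn: [(x, y), ...] -> ([x, ...], [y, ...])."""
    xs, ys = zip(*samples)
    return list(xs), list(ys)


dataset = [(i, i % 2) for i in range(5)]   # (feature, label) pairs
batch_sampler = [[0, 1], [2, 3], [4]]      # any iterable of index lists works
batches = list(SingleProcessIter(dataset, batch_sampler, collate_pairs))
```

No queues or worker processes are involved here: the StopIteration raised by the batch sampler's iterator is what ends the epoch.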