PyTorch_dataloader
最近了解torch时,发现dataloader不懂,所以写一篇博客,以此记录一下。
官方链接:https://pytorch.org/docs/stable/data.html?highlight=torch%20utils%20data%20dataloader#torch.utils.data.DataLoader
- Dataloader将Dataset或其子类封装成一个迭代器
- 这个迭代器可以迭代输出Dataset的内容
- 同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程
Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
其中dataset就是我们封装好的数据集
-
batch_size ( int , optional ) – 每批要加载多少样本(默认值:)
1
。 -
shuffle(布尔,可选) -设置为
True
有数据在每个时间段改组(默认:False
)。 -
sampler ( Sampler or Iterable , optional ) – 定义从数据集中抽取样本的策略。可以是任何
Iterable
与__len__
实施。如果指定,则shuffle
不得指定。 -
batch_sampler ( Sampler or Iterable , optional ) – 类似
sampler
,但一次返回一批索引。互斥有batch_size
,shuffle
,sampler
,和drop_last
。 -
num_workers ( int , optional ) – 用于数据 加载的子进程数。
0
意味着数据将在主进程中加载。(默认值:0
) -
collate_fn ( callable , optional ) – 合并一个样本列表以形成一个小批量的 Tensor(s)。在使用地图样式数据集的批量加载时使用。
-
pin_memory ( bool , optional ) – 如果
True
,数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您collate_fn
返回的批次是自定义类型,请参见下面的示例。 -
drop_last ( bool , optional ) –
True
如果数据集大小不能被批处理大小整除,则设置为删除最后一个不完整的批处理。如果False
并且数据集的大小不能被批大小整除,那么最后一批将更小。(默认值:False
) -
timeout ( numeric , optional ) – 如果为正,则为从工作人员收集批次的超时值。应该总是非负的。(默认值:
0
) -
worker_init_fn ( callable , optional ) – 如果不是
None
,这将在每个工作子进程上调用,在播种之后和数据加载之前,将工作人员 id(一个 int in )作为输入。(默认值:)[0, num_workers - 1]
None
-
generator ( torch .Generator , optional ) – 如果没有,RandomSampler 将使用这个 RNG 来生成随机索引和多处理来为 worker生成 base_seed。(默认值:)
None
None
-
prefetch_factor ( int , optional , keyword-only arg ) – 每个工作人员提前加载的样本数。
2
意味着将在所有工作人员中预取总共 2 * num_workers 个样本。(默认值:2
) -
persistent_workers ( bool , optional ) – 如果
True
,数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作人员数据集实例处于活动状态。(默认值:False
)
首先,__getitem__就是获取样本对,模型直接通过这一函数获得一对样本对{x:y}。__len__是指数据集长度。