PyTorch_dataloader

最近了解torch时,发现dataloader不懂,所以写一篇博客,以此记录一下。

官方链接:https://pytorch.org/docs/stable/data.html?highlight=torch%20utils%20data%20dataloader#torch.utils.data.DataLoader

    1. Dataloader将Dataset或其子类封装成一个迭代器
    2. 这个迭代器可以迭代输出Dataset的内容
    3. 同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
其中dataset就是我们封装好的数据集
  • dataset数据) -数据从其中加载所述数据集。

  • batch_size ( int , optional ) – 每批要加载多少样本(默认值:)1

  • shuffle(布尔可选) -设置为True数据在每个时间段改组(默认:False)。

  • sampler ( Sampler or Iterable optional ) – 定义从数据集中抽取样本的策略可以是任何Iterable__len__ 实施。如果指定,则shuffle不得指定。

  • batch_sampler ( Sampler or Iterable optional ) – 类似sampler,但一次返回一批索引。互斥有 batch_sizeshufflesampler,和drop_last

  • num_workers ( int , optional ) – 用于数据 加载的进程数0意味着数据将在主进程中加载​​。(默认值:0

  • collat​​e_fn ( callable optional ) – 合并一个样本列表以形成一个小批量的 Tensor(s)。在使用地图样式数据集的批量加载时使用

  • pin_memory ( bool , optional ) – 如果True数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您collate_fn返回的批次是自定义类型,请参见下面的示例。

  • drop_last ( bool , optional ) –True如果数据集大小不能被批处理大小整除,则设置为删除最后一个不完整的批处理。如果False并且数据的大小不能被批大小整除,那么最后一批将更小。(默认值:False

  • timeout ( numeric optional ) – 如果为正,则为从工作人员收集批次的超时值。应该总是非负的。(默认值:0

  • worker_init_fn ( callable optional ) – 如果不是None,这将在每个工作子进程上调用在播种之后和数据加载之前,将工作人员 id(一个 int in )作为输入(默认值:[0, num_workers 1]None

  • generator ( torch .Generator , optional ) – 如果没有,RandomSampler 将使用这个 RNG 来生成随机索引和多处理来为 worker生成 base_seed(默认值:NoneNone

  • prefetch_factor ( int , optional keyword-only arg ) – 每个工作人员提前加载的样本数。2意味着将在所有工作人员中预取总共 2 * num_workers 个样本。(默认值:2

  • persistent_workers ( bool , optional ) – 如果True数据加载器在数据集被消费一次后不会关闭工作进程这允许保持工作人员数据集实例处于活动状态。(默认值:False

 
其中dataset也稍微记录一下
Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
pytorch给出的官方代码,其中__getitem__和__len__是子类必须继承的。
首先,__getitem__就是获取样本对,模型直接通过这一函数获得一对样本对{x:y}。__len__是指数据集长度。
自定义数据集我还没有尝试过,有时间在补一下吧。感觉太难了呀。。。。。。。。
 

posted @ 2021-10-07 23:31  微草wd  阅读(103)  评论(0编辑  收藏  举报