【pytorch】土堆pytorch教程学习(六)DataLoader 的使用

DataLoader 用于读取数据到网络。

DataLoader 将数据集(Dataset)和采样器(Sampler)组合在一起,并在给定数据集上提供迭代。Sampler 的功能是生成索引,Dataset 根据生成的索引读取样本以及标签。

DataLoader 支持 map 式和 iterable 式的数据集,可进行单进程或多进程加载、自定义加载顺序和可选的自动批处理和内存固定。

在学习之前我们先来了解一些英文的含义:

  • Epoch:所有训练样本都已经输入到模型中,称为一个 Epoch(阶段)
  • Iteration: 一批样本输入到模型中,称为一个 Iteration(迭代)
  • batch_size: 批大小,决定一个 Iteration 的样本数,也决定了一个 Epoch 有多少个 Iteration

假设样本总数为100,分批处理,设置 batch_size 为10,则一个 Epoch 有 100 / 10 = 10 个 Iteration。

DataLoader

先看下实例化一个 DataLoader 所需的参数:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None,
                            batch_sampler=None, num_workers=0, collate_fn=None, 
                            pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                            multiprocessing_context=None, generator=None, *, 
                            prefetch_factor=None, persistent_workers=False, pin_memory_device='')

只需关注几个常用的参数即可,剩下的可以慢慢了解:

  • dataset:Dataset类,要从中加载数据的数据集
  • batch_size:每批要加载的样本大小,默认为1
  • shuffle:每个 epoch 是否乱序。设为 True 则训练样本在每个 epoch 都打乱顺序,进行重组。
  • num_workers:用于数据加载的子进程数。0表示将在主进程中加载数据(默认值: 0)。决定了是否多进程读取数据。
  • drop_last:当样本数不能被 batch_size 整除时,设置为 True 可舍弃最后一批不完整的数据。
  • sampler:定义从数据集中提取样本的策略。可以是实现了 __len__ 的任何 Iterable。如果指定了,则 shuffle 必须为 False

一般 PyTorch 中深度学习训练的流程如下:

  1. 创建Dateset
  2. Dataset传递给DataLoader
  3. DataLoader迭代产生训练数据提供给模型
# 创建 Dateset
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
# Dataset 传递给 DataLoader
dataloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# DataLoader 迭代产生训练数据提供给模型
for i in range(epoch):
    for index,(imgs,targets) in enumerate(dataloader):
        pass

Dataset 负责建立索引到样本的映射,DataLoader 负责以特定的方式从数据集中迭代的产生一个个 batch 的样本集合。在 enumerate 过程中实际上是 dataloader 按照其参数 sampler 规定的策略调用了其 dataset 的 __getitem__方法。

posted @   hzyuan  阅读(134)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)

喜欢请打赏

扫描二维码打赏

支付宝打赏

点击右上角即可分享
微信分享提示