【pytorch】土堆pytorch教程学习(六)DataLoader 的使用
DataLoader
用于读取数据到网络。
DataLoader
将数据集(Dataset)和采样器(Sampler)组合在一起,并在给定数据集上提供迭代。Sampler 的功能是生成索引,Dataset 根据生成的索引读取样本以及标签。
DataLoader
支持 map 式和 iterable 式的数据集,可进行单进程或多进程加载、自定义加载顺序和可选的自动批处理和内存固定。
在学习之前我们先来了解一些英文的含义:
- Epoch:所有训练样本都已经输入到模型中,称为一个 Epoch(阶段)
- Iteration: 一批样本输入到模型中,称为一个 Iteration(迭代)
- batch_size: 批大小,决定一个 Iteration 的样本数,也决定了一个 Epoch 有多少个 Iteration
假设样本总数为100,分批处理,设置 batch_size 为10,则一个 Epoch 有 100 / 10 = 10 个 Iteration。
DataLoader
先看下实例化一个 DataLoader
所需的参数:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
multiprocessing_context=None, generator=None, *,
prefetch_factor=None, persistent_workers=False, pin_memory_device='')
只需关注几个常用的参数即可,剩下的可以慢慢了解:
dataset
:Dataset类,要从中加载数据的数据集batch_size
:每批要加载的样本大小,默认为1shuffle
:每个 epoch 是否乱序。设为 True 则训练样本在每个 epoch 都打乱顺序,进行重组。num_workers
:用于数据加载的子进程数。0表示将在主进程中加载数据(默认值: 0)。决定了是否多进程读取数据。drop_last
:当样本数不能被batch_size
整除时,设置为True
可舍弃最后一批不完整的数据。sampler
:定义从数据集中提取样本的策略。可以是实现了__len__
的任何 Iterable。如果指定了,则shuffle
必须为False
。
一般 PyTorch 中深度学习训练的流程如下:
- 创建Dateset
- Dataset传递给DataLoader
- DataLoader迭代产生训练数据提供给模型
# 创建 Dateset
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
# Dataset 传递给 DataLoader
dataloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# DataLoader 迭代产生训练数据提供给模型
for i in range(epoch):
for index,(imgs,targets) in enumerate(dataloader):
pass
Dataset 负责建立索引到样本的映射,DataLoader 负责以特定的方式从数据集中迭代的产生一个个 batch 的样本集合。在 enumerate
过程中实际上是 dataloader 按照其参数 sampler
规定的策略调用了其 dataset 的 __getitem__
方法。
本文来自博客园,作者:hzyuan,转载请注明原文链接:https://www.cnblogs.com/hzyuan/p/17369963.html
分类:
ai / pytorch
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)