Pytorch:Dataloader和Dataset以及搭建数据部分的步骤
接下来几篇博文开始,介绍pytorch五大模块中的数据模块,所有概念都会以第四代人民币1元和100元纸币的二分类问题为例来具体介绍,在实例中明白相关知识。
数据模块的结构体系
数据模块分为数据的收集、划分、读取、预处理四部分,其中收集和划分是人工可以设定,而读取部分和预处理部分,pytorch有相应的函数和运行机制来实现。读取部分中pytorch靠dataloader这个数据读取机制来读取数据。
Dataloader
Dataloader涉及两个部分,一是sampler部分,用于生成数据的索引(即序号),二是dataset,根据索引来读取相应的数据和标签。
torch.utils.data.Dataloader
功能:构建可迭代的数据装载器
主要属性:
dataset:Dataset类,决定数据从哪里读取以及如何读取
batchsize:批大小
num_works:是否以多进程读取数据
shuffle:每个epoch是否乱序
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
epoch:所有训练样本都已输入到模型中,称为一个epoch
iteration:一批样本输入到模型中,称之为一个iteration
batchsize:批大小,决定一个epoch有多少个iteration
举例:样本总数:80,batchsize:8,则 1 epoch = 10 iteration
样本总数:85,batchsize:8,则 1 epoch = {设定drop_last:10 iteration;不设定:11 iteration}
torch.utils.data.Dataset
功能:抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
实例体现
下面介绍一下代码构建的流程,主要涉及数据模块
1.数据收集(img,label)
由于是二分类,所以可以构建两个文件夹进行简单区分
并划分训练、验证和数据集
以8:1:1的比例划分train valid test三个数据集,接下来设置好各数据路径
以及数据各通道的均值和标准差(这个需要自己计算得出)
下面就是数据模块中预处理中transform方法的建立,这个会在下一篇博文中展开
接下来为构建自定义Dataset实例
以及构建Dataloader
其中Dataset必须是用户自己写的
接下来便是模型模块、损失函数模块、优化器模块、迭代训练模块
在迭代训练中,数据的获取为 for i, data in enumerate(train_loader)
主要探究enumerate(train_loader)其中的机制
阅读Dataloader源码可知:
- 迭代dataloader首先会进入是否多线程运行的判断(比如单进程singleprocess)
- 然后进入_SingleProcessDataloaderIter.__next__中获取index和通过index获取data
- index列表由sampler生成,长度为一个batch_size
- 再由self.dataset_fetcher.fetch(index)去获取data的路径和标签,fetch会一步步跳转到自定义dataset中的__getitem__(self, index)
- 采用Image.open读取路径中的数据,如果有transform方法,则进行transform后再返回img及label
- 当fetch进行return时,会采用collate_fn(data)方法将所有单个数据整理成一个batch(字典样式:label - img.tensor)的形式并返回
可以归纳得到:
另附CNY二分类模型代码示例:(为节省篇幅,以下只展示与数据模块有关的步骤)