PyTorch自定义数据集
数据传递机制
我们首先回顾识别手写数字的程序:
1 2 3 4 5 6 7 | ... Dataset = torchvision.datasets.MNIST(root = './mnist/' , train = True , transform = transform, download = True ,) dataloader = torch.utils.data.DataLoader(dataset = Dataset, batch_size = 64 , shuffle = True ) ... for epoch in range (EPOCH): for i, (image, label) in enumerate (dataloader): ... |
从上面的程序,我们可以知道,在PyTorch中,数据传递机制是这样的:
- 创建Dataset
- Dataset传递给DataLoader
- DataLoader迭代产生训练数据提供给模型
总结这个数据传递机制就是,Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset的getitem方法(下文中将介绍该方法)。关于Dataloder和Dataset的关系,具体可参考博客PyTorch中Dataset, DataLoader, Sampler的关系
在上面的识别手写数字的例子中,数据集是直接下载的,但如果我们自己收集了一些数据,存在电脑文件夹里,我们该如何把这些数据变为可以在PyTorch框架下进行神经网络训练的数据集呢,即如何自定义数据集呢?
自定义数据集
torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。Pytorch提供两种数据集: Map式数据集 Iterable式数据集。这里我们只介绍前者。
一个Map式的数据集必须要重写getitem(self, index)、 len(self) 两个内建方法,用来表示从索引到样本的映射(Map)。这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取数据集中第idx张图片以及其标签(如果有的话); len(dataset)则会返回这个数据集的容量。
自定义数据集类的范式大致是这样的:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | class CustomDataset(torch.utils.data.Dataset): #需要继承torch.utils.data.Dataset def __init__( self ): # TODO # 1. Initialize file path or list of file names. pass def __getitem__( self , index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #这里需要注意的是,第一步:read one data,是一个data point pass def __len__( self ): # You should change 0 to the total size of your dataset. return 0 |
根据这个范式,我们举一个例子。
实例
从kaggle官网下载dogsVScats的数据集(百度网盘下载链接见文末),该数据集包含test1文件夹和train文件夹,train文件夹中包含12500张猫的图片和12500张狗的图片,图片的文件名中带序号:
cat. 0.jpg cat. 1.jpg cat. 2.jpg ... cat. 12499.jpg dog. 0.jpg dog. 1.jpg dog. 2.jpg ... dog. 12499.jpg |
我们把其中前10000张猫的图片和10000张狗的图片作为训练集,把后面的2500张猫的图片和2500张狗的图片作为验证集。猫的label记为0,狗的label记为1。因为图片大小不一,所以,我们需要对图像进行transform。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | import matplotlib.pyplot as plt import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import os image_transform = transforms.Compose([ transforms.Resize( 256 ), # 把图片resize为256*256 transforms.RandomCrop( 224 ), # 随机裁剪224*224 transforms.RandomHorizontalFlip(), # 水平翻转 transforms.ToTensor(), # 将图像转为Tensor transforms.Normalize(mean = [ 0.485 , 0.456 , 0.406 ], std = [ 0.229 , 0.224 , 0.225 ]) # 标准化 ]) class DogVsCatDataset(Dataset): # 创建一个叫做DogVsCatDataset的Dataset,继承自父类torch.utils.data.Dataset def __init__( self , root_dir, train = True , transform = None ): """ Args: root_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on a sample. """ self .root_dir = root_dir self .img_path = os.listdir( self .root_dir) if train: self .img_path = list ( filter ( lambda x: int (x.split( '.' )[ 1 ]) < 10000 , self .img_path)) # 划分训练集和验证集 else : self .img_path = list ( filter ( lambda x: int (x.split( '.' )[ 1 ]) > = 10000 , self .img_path)) self .transform = transform def __len__( self ): return len ( self .img_path) def __getitem__( self , idx): image = Image. open (os.path.join( self .root_dir, self .img_path[idx])) label = 0 if self .img_path[idx].split( '.' )[ 0 ] = = 'cat' else 1 # label, 猫为0,狗为1 if self .transform: image = self .transform(image) label = torch.from_numpy(np.array([label])) return image, label |
我们来测试一下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | if __name__ = = '__main__' : catanddog_dataset = DogVsCatDataset(root_dir = '/Users/wangpeng/Desktop/train' , train = False , transform = image_transform) train_loader = DataLoader(catanddog_dataset, batch_size = 8 , shuffle = True , num_workers = 4 ) # num_workers=4表示用4个线程读取数据 image, label = iter (train_loader). next () # iter()函数把train_loader变为迭代器,然后调用迭代器的next()方法 sample = image[ 0 ].squeeze() sample = sample.permute(( 1 , 2 , 0 )).numpy() sample * = [ 0.229 , 0.224 , 0.225 ] sample + = [ 0.485 , 0.456 , 0.406 ] sample = np.clip(sample, 0 , 1 ) plt.imshow(sample) plt.show() print ( 'Label is: {}' . format (label[ 0 ].numpy())) |
运行结果:
Label is: [0]
dogsVScats数据下载链接:链接:https://pan.baidu.com/s/17768gqeaX9NrdURV_tR_ow 提取密码:478x
参考文献
[1] Pytorch之Dataset与DataLoader,打造你自己的数据集
[2] 基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 单元测试从入门到精通