Pytorch——Dataset类和DataLoader类
这篇文章主要探讨一下,Dataset类以及DataLoader类的使用以及注意事项。Dataset类主要是用于原始数据的读取或者基本的数据处理(比如在NLP任务中常常需要把文字转化为对应字典ids,这个步骤就可以放在Dataset中执行)。DataLoader,是进一步对Dataset的处理,Dataset得到的数据集你可以理解为是个"列表"(可以根据index取出某个特定位置的数据),而DataLoder就是把这个数据集(Dataset)根据你设定的batch_size划分成很多个“子数据集”,每个“子数据集”中的元素数量就是batch_size。
DataLoader为什么要把Dataset划分成多个”子数据集“呢?因为一次性把所有的数据放进模型会导致内存溢出,而且模型的迭代会很慢。下面我们就深度解析下Dataset和DataLoader的使用方式。
一、Dataset的使用
这里说到的Dataset其实就是,torch.utils.data.Dataset类 ,换句话说我们需要创建一个Dataset类,使用类的继承就可以了。既然是继承类,那么肯定会修改一些父类(torch.utils.data.Dataset类 )的方法来适应我们的真实数据和逻辑。而我们主要要重写的就是,__init__(),__len__(),__getitem__(),这三个方法分别是以下作用:
__init__
方法:进行类的初始化,一般是用来读取原始数据。__getitem__
方法:根据下标对每一个数据进行进一步的处理。return:希望通过dataset[index]在数据集中取出的元素__len__
方法:return:数据集的数量(int)
下面用一个例子来大致说明下Dataset该怎么构建,并且如何使用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | from torch.utils.data import Dataset import torch def MyTokenizer(sentence): src_vocab = { '度' : 0 , '上' : 1 , '世' : 2 , '中' : 3 , '为' : 4 , '人' : 5 , '伟' : 6 , '你' : 7 , '务' : 8 , '国' : 9 , '大' : 10 , '我' : 11 , '是' : 12 , '最' : 13 , '服' : 14 , '民' : 15 , '爱' : 16 , '界' : 17 , '的' : 18 } enc_input = [src_vocab[n] for n in sentence] if len (enc_input) < 12 : ## 如果enc_input的长度小于12,则用100来补足,使得enc_input长度为12. enc_input = enc_input + ( 12 - len (enc_input)) * [ 100 ] return enc_input class MyDataset(Dataset): def __init__( self , data): self .data = data self .tokenizer = MyTokenizer def __len__( self ): return len ( self .data) def __getitem__( self , index): sentence = self .data[index] return torch.tensor( self .tokenizer(sentence)) data = [ '我爱你中国' , '中国是世界上最伟大的国度' , '为人民服务' , '你爱我' ] dataset = MyDataset(data) |
简单介绍一下,函数MyTokenizer(sentence):把一个句子根据字典,转化为一个列表[...],例如”我爱你“——>[11, 16, 7, 100,100,...100]。这里为了简便,我是用的数据就是四句话。我在初始化(__init__)我的MyDataset类时,把数据储存下来,并且定义了我的编码器Mytokenizer。(这里为什么输出列表末尾会有100呢,主要是使得每一个数据长度是一样的,为后面进入DataLoader做准备,其实这个操作就叫做padding)
__len__(self):这里返回了我传入数据集的大小。而__getitem__(self,index):中index指的是数据下标,根据这个下标提取出原始数据(self.data中的一句话),并且把这句话传入到self.tokenizer进行编码,最后返回编码的结果(一个列表[....]),可以看到__getitem__这个函数就是根据index来处理每一个拿出来的原始数据的,你对原始数据的所有处理都可以放在这里。我们最后一行代码是完成了MyDataset的实例化。看一看这个实例化之后的结果。
1 2 | print ( len (dataset)) print (dataset[ 0 ]) 3 [ 11 , 16 , 7 , 3 , 9 ] |
这里返回的tensor([11,16,7,3,9]),其实就是”我爱你中国“经过编码之后的结果(可以对照上面的字典看看)。
其实说白了Dataset就是一个数据处理器,把数据收集起来,并且进行对每一个index的数据进行处理,最后输出。有人会问为啥不先处理好这些数据呢?其实是因为DataLoader只能接受torch.utils.data.Dataset类作为传入参数,因此用其他任意的数据结构都没办法放到DataLoader里面,这样就没法自动根据batch_size拆分成”子数据集“。因此Dataset是我们必须构建的,就算是我的数据不想进一步处理,也必须写一个以上的最简单的MyDataset类(直接传入啥,输出啥的类)。
二、DataLoader的使用
先放官方文档:官方文档
刚才说到Dataset的构建是为了放进到DataLoader里,为啥非要放到这里面呢?其根本原因是DataLoader中有很多好用的设置可以让我们更好的处理数据,比如参数shuffle,可以让Dataset中的数据打乱重新排列再进行分批次,num_workers参数可以设定安排多少个进程来加载数据(加速)。一般情况下我们不需要重写DataLoader类,只需要实例化就可以了。例如我们把上面创建好的Dataset实例——dataset传入到DataLoader中构建实例。
这里一定要注意,每个batch(子集)里的长度一定要一致,不然会报错“RuntimeError: each element in list of batch should be of equal size”。(这也就是为什么,在建立Dataset的时候我会用100来吧不足12长度的句子填充成统一长度,因为我举的例子中没有超过12的句子,所以不存在切割句子,真实情况需要按你自己的数据需求,但是一定要保证出来的数据要一样长,至于为什么一会后面说)。
1 2 | from torch.utils.data import DataLoader myDataloader = DataLoader(dataset, shuffle = True , batch_size = 2 ) |
这个myDataloader就是DataLoader的实例,已经被分为了2个数据为一个batch,接下来我们打印一下每个batch(由于我们只有4句话,2个样本为一个batch那么其实就只有2个batch,所以可以打印来看看)。
1 2 3 | for batch in myDataloader: print (batch) print ( '===============================' ) |
tensor([[ 7, 16, 11, 100, 100, 100, 100, 100, 100, 100, 100, 100], [ 4, 5, 15, 14, 8, 100, 100, 100, 100, 100, 100, 100]]) =============================== tensor([[ 3, 9, 12, 2, 17, 1, 13, 6, 10, 18, 9, 0], [ 11, 16, 7, 3, 9, 100, 100, 100, 100, 100, 100, 100]]) ===============================
可以看到每个batch其实是一个tensor,维度是(2,12)。每个tensor的每一行其实就是一个dataset里的一个样本。并且要注意每个样本已经不是按照原本的顺序排列了。
三、collate_fn参数的使用
在DataLoader里,除了上面提到的shuffle参数和batch_size参数以外,还有一个非常重要的传入参数collate_fn,这个参数传入的是一个函数,这个函数主要是对每个batch进行处理,最终输出一个batch的返回值,换句话说collate_fn函数的返回值,就是遍历DataLoader的时候每个“batch”的返回值了(类似于上面例子中的二维tensor)。下面我写一个函数,让大家看看到底是怎么处理的。
1 2 3 4 5 6 | def mycollate(item): sample1, sample2 = item return { '第一个样本' :sample1, '第二个样本' :sample2} from torch.utils.data import DataLoader myDataloader = DataLoader(dataset, shuffle = True , batch_size = 2 , collate_fn = mycollate) |
我们现在再来打印一下myDataloader的每个元素。
1 2 3 | for batch in myDataloader: print (batch) print ( '===============================' ) |
{'第一个样本': tensor([ 11, 16, 7, 3, 9, 100, 100, 100, 100, 100, 100, 100]), '第二个样本': tensor([ 7, 16, 11, 100, 100, 100, 100, 100, 100, 100, 100, 100])} =============================== {'第一个样本': tensor([ 3, 9, 12, 2, 17, 1, 13, 6, 10, 18, 9, 0]), '第二个样本': tensor([ 4, 5, 15, 14, 8, 100, 100, 100, 100, 100, 100, 100])} ===============================
可以看到,这个时候打印myDataloader的每个元素,就变成我在mycollate()函数中的返回值了。或许会不明白,我在mycollate()函数中这个传入的item是什么?其实这个item是个元组,元组的每个元素就是dataset的每个元素(tensor([3,9,12,....])),item的元素个数其实就是batch_size,这里的batch_size是2,所以我在mycollate()中用了两个变量来接收(换句话说要是我把batch_size换成2以外的其他数字,就会报错了)。
可以看到其实我们在DataLoader的时候依然可以使用函数来处理我们的数据,换句话说我们完全可以把tokenizer函数放到mycollate()函数中。
好了现在我们可以来解释为什么我在第二节的时候要求每个batch的数据要一样长了,那是因为当你不给定collate_fn这个参数的时候,会自动调用一个函数叫做default_collate(),大家可以粗略的看看这个内置函数的源码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | def default_collate(batch): r """Puts each data field into a tensor with outer dimension batch size""" elem = batch[ 0 ] elem_type = type (elem) if isinstance (elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None : # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum ([x.numel() for x in batch]) storage = elem.storage()._new_shared(numel) out = elem.new(storage) return torch.stack(batch, 0 , out = out) elif elem_type.__module__ = = 'numpy' and elem_type.__name__ ! = 'str_' \ and elem_type.__name__ ! = 'string_' : if elem_type.__name__ = = 'ndarray' or elem_type.__name__ = = 'memmap' : # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype. str ) is not None : raise TypeError(default_collate_err_msg_format. format (elem.dtype)) return default_collate([torch.as_tensor(b) for b in batch]) elif elem.shape = = (): # scalars return torch.as_tensor(batch) elif isinstance (elem, float ): return torch.tensor(batch, dtype = torch.float64) elif isinstance (elem, int_classes): return torch.tensor(batch) elif isinstance (elem, string_classes): return batch elif isinstance (elem, container_abcs.Mapping): return {key: default_collate([d[key] for d in batch]) for key in elem} elif isinstance (elem, tuple ) and hasattr (elem, '_fields' ): # namedtuple return elem_type( * (default_collate(samples) for samples in zip ( * batch))) elif isinstance (elem, container_abcs.Sequence): # check to make sure that the elements in batch have consistent size it = iter (batch) elem_size = len ( next (it)) if not all ( len (elem) = = elem_size for elem in it): raise RuntimeError( 'each element in list of batch should be of equal size' ) transposed = zip ( * batch) return [default_collate(samples) for samples in transposed] raise TypeError(default_collate_err_msg_format. format (elem_type)) |
看到倒数第三行了么?这就是为什么会报错的原因了。所以如果可以,我建议还是自己设定mycollate()函数,因为源码里如果你的dataset输出的元素不是tensor类型,那么将会按照它的方式来重新组织来返回,不同类别返回的东西是不一样的,大家可以看看源码。
参考网站:
Pytorch的第一步:(1) Dataset类的使用 - 简书 (jianshu.com)
Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解_Never-Giveup的博客-CSDN博客_dataloader参数
RuntimeError: each element in list of batch should be of equal size_NLP新手村成员的博客-CSDN博客
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 一个奇形怪状的面试题:Bean中的CHM要不要加volatile?
· [.NET]调用本地 Deepseek 模型
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· PowerShell开发游戏 · 打蜜蜂
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性