Pytorch——Dataset类和DataLoader类
这篇文章主要探讨一下,Dataset类以及DataLoader类的使用以及注意事项。Dataset类主要是用于原始数据的读取或者基本的数据处理(比如在NLP任务中常常需要把文字转化为对应字典ids,这个步骤就可以放在Dataset中执行)。DataLoader,是进一步对Dataset的处理,Dataset得到的数据集你可以理解为是个"列表"(可以根据index取出某个特定位置的数据),而DataLoder就是把这个数据集(Dataset)根据你设定的batch_size划分成很多个“子数据集”,每个“子数据集”中的元素数量就是batch_size。
DataLoader为什么要把Dataset划分成多个”子数据集“呢?因为一次性把所有的数据放进模型会导致内存溢出,而且模型的迭代会很慢。下面我们就深度解析下Dataset和DataLoader的使用方式。
一、Dataset的使用
这里说到的Dataset其实就是,torch.utils.data.Dataset类 ,换句话说我们需要创建一个Dataset类,使用类的继承就可以了。既然是继承类,那么肯定会修改一些父类(torch.utils.data.Dataset类 )的方法来适应我们的真实数据和逻辑。而我们主要要重写的就是,__init__(),__len__(),__getitem__(),这三个方法分别是以下作用:
__init__
方法:进行类的初始化,一般是用来读取原始数据。__getitem__
方法:根据下标对每一个数据进行进一步的处理。return:希望通过dataset[index]在数据集中取出的元素__len__
方法:return:数据集的数量(int)
下面用一个例子来大致说明下Dataset该怎么构建,并且如何使用。
from torch.utils.data import Dataset import torch def MyTokenizer(sentence): src_vocab = {'度':0,'上':1,'世':2,'中':3, '为':4,'人':5,'伟':6,'你':7,'务':8,'国':9, '大':10,'我':11,'是':12,'最':13,'服':14, '民':15,'爱':16,'界':17,'的':18} enc_input = [src_vocab[n] for n in sentence] if len(enc_input) < 12: ## 如果enc_input的长度小于12,则用100来补足,使得enc_input长度为12. enc_input = enc_input+(12-len(enc_input))*[100] return enc_input class MyDataset(Dataset): def __init__(self, data): self.data = data self.tokenizer = MyTokenizer def __len__(self): return len(self.data) def __getitem__(self, index): sentence = self.data[index] return torch.tensor(self.tokenizer(sentence)) data = ['我爱你中国', '中国是世界上最伟大的国度', '为人民服务','你爱我'] dataset = MyDataset(data)
简单介绍一下,函数MyTokenizer(sentence):把一个句子根据字典,转化为一个列表[...],例如”我爱你“——>[11, 16, 7, 100,100,...100]。这里为了简便,我是用的数据就是四句话。我在初始化(__init__)我的MyDataset类时,把数据储存下来,并且定义了我的编码器Mytokenizer。(这里为什么输出列表末尾会有100呢,主要是使得每一个数据长度是一样的,为后面进入DataLoader做准备,其实这个操作就叫做padding)
__len__(self):这里返回了我传入数据集的大小。而__getitem__(self,index):中index指的是数据下标,根据这个下标提取出原始数据(self.data中的一句话),并且把这句话传入到self.tokenizer进行编码,最后返回编码的结果(一个列表[....]),可以看到__getitem__这个函数就是根据index来处理每一个拿出来的原始数据的,你对原始数据的所有处理都可以放在这里。我们最后一行代码是完成了MyDataset的实例化。看一看这个实例化之后的结果。
print(len(dataset)) print(dataset[0])
3
[11, 16, 7, 3, 9]
这里返回的tensor([11,16,7,3,9]),其实就是”我爱你中国“经过编码之后的结果(可以对照上面的字典看看)。
其实说白了Dataset就是一个数据处理器,把数据收集起来,并且进行对每一个index的数据进行处理,最后输出。有人会问为啥不先处理好这些数据呢?其实是因为DataLoader只能接受torch.utils.data.Dataset类作为传入参数,因此用其他任意的数据结构都没办法放到DataLoader里面,这样就没法自动根据batch_size拆分成”子数据集“。因此Dataset是我们必须构建的,就算是我的数据不想进一步处理,也必须写一个以上的最简单的MyDataset类(直接传入啥,输出啥的类)。
二、DataLoader的使用
先放官方文档:官方文档
刚才说到Dataset的构建是为了放进到DataLoader里,为啥非要放到这里面呢?其根本原因是DataLoader中有很多好用的设置可以让我们更好的处理数据,比如参数shuffle,可以让Dataset中的数据打乱重新排列再进行分批次,num_workers参数可以设定安排多少个进程来加载数据(加速)。一般情况下我们不需要重写DataLoader类,只需要实例化就可以了。例如我们把上面创建好的Dataset实例——dataset传入到DataLoader中构建实例。
这里一定要注意,每个batch(子集)里的长度一定要一致,不然会报错“RuntimeError: each element in list of batch should be of equal size”。(这也就是为什么,在建立Dataset的时候我会用100来吧不足12长度的句子填充成统一长度,因为我举的例子中没有超过12的句子,所以不存在切割句子,真实情况需要按你自己的数据需求,但是一定要保证出来的数据要一样长,至于为什么一会后面说)。
from torch.utils.data import DataLoader myDataloader = DataLoader(dataset, shuffle=True, batch_size=2)
这个myDataloader就是DataLoader的实例,已经被分为了2个数据为一个batch,接下来我们打印一下每个batch(由于我们只有4句话,2个样本为一个batch那么其实就只有2个batch,所以可以打印来看看)。
for batch in myDataloader: print(batch) print('===============================')
tensor([[ 7, 16, 11, 100, 100, 100, 100, 100, 100, 100, 100, 100], [ 4, 5, 15, 14, 8, 100, 100, 100, 100, 100, 100, 100]]) =============================== tensor([[ 3, 9, 12, 2, 17, 1, 13, 6, 10, 18, 9, 0], [ 11, 16, 7, 3, 9, 100, 100, 100, 100, 100, 100, 100]]) ===============================
可以看到每个batch其实是一个tensor,维度是(2,12)。每个tensor的每一行其实就是一个dataset里的一个样本。并且要注意每个样本已经不是按照原本的顺序排列了。
三、collate_fn参数的使用
在DataLoader里,除了上面提到的shuffle参数和batch_size参数以外,还有一个非常重要的传入参数collate_fn,这个参数传入的是一个函数,这个函数主要是对每个batch进行处理,最终输出一个batch的返回值,换句话说collate_fn函数的返回值,就是遍历DataLoader的时候每个“batch”的返回值了(类似于上面例子中的二维tensor)。下面我写一个函数,让大家看看到底是怎么处理的。
def mycollate(item): sample1, sample2 = item return {'第一个样本':sample1,'第二个样本':sample2} from torch.utils.data import DataLoader myDataloader = DataLoader(dataset, shuffle=True, batch_size=2, collate_fn=mycollate)
我们现在再来打印一下myDataloader的每个元素。
for batch in myDataloader: print(batch) print('===============================')
{'第一个样本': tensor([ 11, 16, 7, 3, 9, 100, 100, 100, 100, 100, 100, 100]), '第二个样本': tensor([ 7, 16, 11, 100, 100, 100, 100, 100, 100, 100, 100, 100])} =============================== {'第一个样本': tensor([ 3, 9, 12, 2, 17, 1, 13, 6, 10, 18, 9, 0]), '第二个样本': tensor([ 4, 5, 15, 14, 8, 100, 100, 100, 100, 100, 100, 100])} ===============================
可以看到,这个时候打印myDataloader的每个元素,就变成我在mycollate()函数中的返回值了。或许会不明白,我在mycollate()函数中这个传入的item是什么?其实这个item是个元组,元组的每个元素就是dataset的每个元素(tensor([3,9,12,....])),item的元素个数其实就是batch_size,这里的batch_size是2,所以我在mycollate()中用了两个变量来接收(换句话说要是我把batch_size换成2以外的其他数字,就会报错了)。
可以看到其实我们在DataLoader的时候依然可以使用函数来处理我们的数据,换句话说我们完全可以把tokenizer函数放到mycollate()函数中。
好了现在我们可以来解释为什么我在第二节的时候要求每个batch的数据要一样长了,那是因为当你不给定collate_fn这个参数的时候,会自动调用一个函数叫做default_collate(),大家可以粗略的看看这个内置函数的源码:
def default_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" elem = batch[0] elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum([x.numel() for x in batch]) storage = elem.storage()._new_shared(numel) out = elem.new(storage) return torch.stack(batch, 0, out=out) elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError(default_collate_err_msg_format.format(elem.dtype)) return default_collate([torch.as_tensor(b) for b in batch]) elif elem.shape == (): # scalars return torch.as_tensor(batch) elif isinstance(elem, float): return torch.tensor(batch, dtype=torch.float64) elif isinstance(elem, int_classes): return torch.tensor(batch) elif isinstance(elem, string_classes): return batch elif isinstance(elem, container_abcs.Mapping): return {key: default_collate([d[key] for d in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple return elem_type(*(default_collate(samples) for samples in zip(*batch))) elif isinstance(elem, container_abcs.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): raise RuntimeError('each element in list of batch should be of equal size') transposed = zip(*batch) return [default_collate(samples) for samples in transposed] raise TypeError(default_collate_err_msg_format.format(elem_type))
看到倒数第三行了么?这就是为什么会报错的原因了。所以如果可以,我建议还是自己设定mycollate()函数,因为源码里如果你的dataset输出的元素不是tensor类型,那么将会按照它的方式来重新组织来返回,不同类别返回的东西是不一样的,大家可以看看源码。
参考网站:
Pytorch的第一步:(1) Dataset类的使用 - 简书 (jianshu.com)
Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解_Never-Giveup的博客-CSDN博客_dataloader参数
RuntimeError: each element in list of batch should be of equal size_NLP新手村成员的博客-CSDN博客