torch.utils.data.Dataset 和 torch.utils.data.DataLoader
torch.utils.data
是PyTorch
中用于数据加载和预处理的模块。通常结合使用其中的Dataset
和DataLoader
两个类来加载和处理数据。
Dataset
torch.utils.data.Dataset
是一个抽象类,用于表示数据集。
需要用户自己实现两个方法:__len__
和__getitem__
。
__len__
方法返回数据集的大小,__getitem__
方法用于根据给定的索引返回一个数据样本。
import torch.utils.data as data
class MyDataset(data.Dataset):
def __init__(self, data_list):
self.data_list = data_list
def __len__(self):
return len(self.data_list)
def __getitem__(self, index):
return self.data_list[index]
DataLoader
torch.utils.data.Dataset
用于表示数据集,torch.utils.data.DataLoader
用于加载数据,并对数据进行批量处理和随机化。
dataset
要加载的数据集对象,必须是实现了len()和getitem()方法的对象。
batch_size
每个批次的数据量大小,默认为1。batch size
的大小会直接影响到模型的训练速度和效果。如果batch size
过大,可能会导致内存不足或者训练速度变慢;如果batch size
过小,则可能会降低模型的泛化能力。因此,我们需要根据实际情况来选择合适的batch size
。
num_workers
num_workers
参数用于指定使用多少个进程来加载数据。默认值为0,表示使用主进程加载数据。如果设置为正数,则会使用多个子进程来加载数据,从而提高数据加载的速度。
通过设置num_workers
参数为正数来启用多个子进程加载数据,并利用PyTorch的自动混合精度训练(Automatic Mixed Precision, AMP)功能来提高数据加载和处理的速度。
pin_memory
pin_memory
参数用于指定是否将数据加载到CUDA主机内存中的固定位置(pinned memory
),以提高数据传输效率。默认值为False
。
collate_fn
collate_fn
参数用于指定如何将样本组合成一个批次。默认情况下,DataLoader
将每个样本作为一个单独的元素传递给模型,但在某些情况下,需要将样本组合成一个批次,以便一次性对整个批次进行处理。 默认为None,表示使用默认的方式进行组合。
在某些特殊情况下,我们可能需要自定义collate_fn
函数来按照特定的方式组合多个数据样本。例如,在处理图像数据时,可能需要将多个图像拼接成一个大的图像作为输入;在处理文本数据时,可能需要将多个文本序列拼接成一个长的文本序列作为输入。通过自定义collate_fn函数,我们可以轻松实现这些需求。
def my_collate_fn(batch):
# 将样本组合成一个批次
data = [item[0] for item in batch]
target = [item[1] for item in batch]
return [data, target]
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
DataLoader
将每个样本作为一个元素传递给my_collate_fn
函数,函数将样本组合成一个批次,并返回一个包含数据和目标的列表。
shuffle
是否对数据进行随机洗牌操作,默认为False。通过启用shuffle功能来打乱数据的顺序,可以有效防止模型过拟合。但是需要注意的是,在每个epoch开始时都需要重新打乱数据的顺序,否则会导致模型训练效果不佳。
Sampler
Sampler
是一个用于指定数据集采样方式的类,它控制DataLoader
如何从数据集中选取样本。PyTorch
提供了多种Sampler
类,例如RandomSampler
和SequentialSampler
,分别用于随机采样和顺序采样。如果指定了Sampler
,则shuffle参数将被忽略。
from torch.utils.data.sampler import RandomSampler
my_sampler = RandomSampler(my_dataset)
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=my_sampler)
自定义Sampler
class MySampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
MySampler
类继承自torch.utils.data.sampler.Sampler
类,实现了__iter__
和__len__
方法。MySampler类的构造函数接受一个数据集作为参数,__iter__
方法返回一个迭代器,用于遍历数据集中的样本索引,__len__
方法返回数据集中样本的数量。
drop_last
如果数据集大小不能被batch size整除,设置为True可以删除最后一个不完整的批次,默认为False。