torch.utils.data.Dataset 和 torch.utils.data.DataLoader

torch.utils.dataPyTorch中用于数据加载和预处理的模块。通常结合使用其中的DatasetDataLoader两个类来加载和处理数据。

Dataset

torch.utils.data.Dataset是一个抽象类,用于表示数据集。

需要用户自己实现两个方法:__len____getitem__

__len__方法返回数据集的大小,__getitem__方法用于根据给定的索引返回一个数据样本。

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

DataLoader

torch.utils.data.Dataset用于表示数据集,torch.utils.data.DataLoader用于加载数据,并对数据进行批量处理和随机化。

dataset

要加载的数据集对象,必须是实现了len()和getitem()方法的对象。

batch_size

每个批次的数据量大小,默认为1。batch size的大小会直接影响到模型的训练速度和效果。如果batch size过大,可能会导致内存不足或者训练速度变慢;如果batch size过小,则可能会降低模型的泛化能力。因此,我们需要根据实际情况来选择合适的batch size

num_workers

num_workers参数用于指定使用多少个进程来加载数据。默认值为0,表示使用主进程加载数据。如果设置为正数,则会使用多个子进程来加载数据,从而提高数据加载的速度。

通过设置num_workers参数为正数来启用多个子进程加载数据,并利用PyTorch的自动混合精度训练(Automatic Mixed Precision, AMP)功能来提高数据加载和处理的速度。

pin_memory

pin_memory参数用于指定是否将数据加载到CUDA主机内存中的固定位置(pinned memory),以提高数据传输效率。默认值为False

collate_fn

collate_fn参数用于指定如何将样本组合成一个批次。默认情况下,DataLoader将每个样本作为一个单独的元素传递给模型,但在某些情况下,需要将样本组合成一个批次,以便一次性对整个批次进行处理。 默认为None,表示使用默认的方式进行组合。

在某些特殊情况下,我们可能需要自定义collate_fn函数来按照特定的方式组合多个数据样本。例如,在处理图像数据时,可能需要将多个图像拼接成一个大的图像作为输入;在处理文本数据时,可能需要将多个文本序列拼接成一个长的文本序列作为输入。通过自定义collate_fn函数,我们可以轻松实现这些需求。

def my_collate_fn(batch):
    # 将样本组合成一个批次
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    return [data, target]

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)

DataLoader将每个样本作为一个元素传递给my_collate_fn函数,函数将样本组合成一个批次,并返回一个包含数据和目标的列表。

shuffle

是否对数据进行随机洗牌操作,默认为False。通过启用shuffle功能来打乱数据的顺序,可以有效防止模型过拟合。但是需要注意的是,在每个epoch开始时都需要重新打乱数据的顺序,否则会导致模型训练效果不佳。

Sampler

Sampler是一个用于指定数据集采样方式的类,它控制DataLoader如何从数据集中选取样本。PyTorch提供了多种Sampler类,例如RandomSamplerSequentialSampler,分别用于随机采样和顺序采样。如果指定了Sampler,则shuffle参数将被忽略。

from torch.utils.data.sampler import RandomSampler

my_sampler = RandomSampler(my_dataset)
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=my_sampler)

自定义Sampler

class MySampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        
    def __iter__(self):
        return iter(range(len(self.data_source)))
    
    def __len__(self):
        return len(self.data_source)

MySampler类继承自torch.utils.data.sampler.Sampler类,实现了__iter____len__方法。MySampler类的构造函数接受一个数据集作为参数,__iter__方法返回一个迭代器,用于遍历数据集中的样本索引,__len__方法返回数据集中样本的数量。

drop_last

如果数据集大小不能被batch size整除,设置为True可以删除最后一个不完整的批次,默认为False。

posted @ 2024-08-02 21:23  华小电  阅读(66)  评论(0编辑  收藏  举报