g
y
7
7
7
7

Pytorch Dataset入门

Dataset入门

Pytorch Dataset code:torch/utils/data/dataset.py#L17

Pytorch Dataset tutorial: tutorials/beginner/basics/data_tutorial.html

 
理论:

PyTorch中的Dataset是一个抽象类,用来表示数据集的接口,所有其他数据集都需要继承这个类,并且覆写以下三个方法:

  1. __init__:初始化数据集的一些配置,例如加载所有的数据标签。

  2. __len__:以便len(dataset)可以返回数据集的大小,例如n。如果n小于数据集长度,则只会取前n个的数据。

  3. __getitem__:输入是数据的索引,以便可以使用dataset[i]来获取第i个样本,数据增强一般会在这里做。

代码:

下面是一个自定义的Dataset样例(不可执行):

import cv2
import json
import torch.utils.Dataset as Dataset

class CustomDataset(Dataset):
    def __init__(self, imgs_path, labels_path, img_transform=None, label_transform=None):
        self.imgs_path = imgs_path  # 输入图像的路径,list
        self.labels_path = labels_path  # 输入图像对应的标签路径,list
        self.img_transform = img_transform  # 图像的数据增强
        self.label_transform = label_transform  # 标签的数据增强

    def __len__(self):
        return len(self.imgs_path)  # 返回数据集的长度

    def __getitem__(self, idx):
        img_path = self.imgs_path[idx]
        label_path = self.labels_path[idx]
        img = cv2.imread(img_path)  # 读取图像
        label = json.load(open(label_path))  # 读取标签
        if self.img_transform:  # 图像的数据增强
            img = self.img_transform(img)
        if self.label_transform:  # 标签的数据增强
            label = self.label_transform(label)
        return img, label  # 返回图像和标签,用于训练

 

总结:

值得注意的是,Dataset只负责数据的加载和预处理,对于如何训练数据(例如:是否进行shuffle,是否进行并行加速等)这部分的逻辑是由DataLoader实现的。通常情况下,我们会将DatasetDataLoader一起使用。

另外,PyTorch还提供了一些常用的数据集,如:ImageFolderCIFAR10MNIST等,这些数据集都是继承Dataset类,同时在init方法中进行数据的下载,以及在getitem方法中进行数据的加载和预处理。

Dataset是单线程读取数据,每次只能读取一个样本,不能一次性读取一个mini-batch的数据。

Dataset的主要特性包含:

  • 抽象接口:PyTorch通过定义一个抽象Dataset类,让用户可以使用统一的方式来加载各种不同的数据,提供了很好的扩展性。

  • 懒加载:实际的数据载入并不发生在构造数据集实例时,而是发生在用到这些数据时,这样可以提高内存利用率,并且可以实现对大规模数据的处理。

  • 预处理:Dataset的一个重要应用就是数据预处理,你可以在getitem函数中进行任何你的数据预处理过程。

 


嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~

posted @ 2024-04-16 10:17  gy77  阅读(58)  评论(0编辑  收藏  举报