美熙科技说
大模型

做业务时,往往要加载自己的数据集,本文将结合torch.utils.data.Dataset 和 Dataloader 介绍如何自定义一个数据集加载模块。

 

本文转自:公众号:写bug的程旭源​​​​​​​

原文链接:pytorch构建自定义数据集,流程和代码模板

pytorch中的数据pipeline设计:

生产者消费者模式,分为sampler、dataset、dataloaderlter、dataloader四个抽象层次:

1、sampler:(采样器)

负责生成读取index序列采样(可以自定义控制采样顺序)

2、dataset:

负责根据index读取相应数据并执行预处理(负责处理索引index到样本sample映射的一个类)

3、dataloaderlter:

负责协调多进程执行dataset

4、dataloader:最顶层的抽象

通过index找出一条数据出来 index——>record

本文主要讲解的是dataset和 dataloader。

深度学习中使用Dataset和Dataloader类的流程:

定义Dataset并实例化;
使用Dataloader加载数据;
循环迭代使用Dataloader加载的数据进行训练或者验证;

(一)构建自定义数据集

import torch.utils.data as data


# 自定义Dataset的基本模板
class ExampleDataset(data.Dataset):  

    # 自定义一个类
    def __init__(self, data):

    # 初始化,把数据作为一个参数传递给类;
       self.data = data

    def __len__(self):

    # 返回数据的长度
       return len(self.data)  

    def __getitem__(self, idx):
        x = ...
        y = ...
    return x, y

    #根据索引返回数据
      # return self.data[idx]

自定义数据集一般要有这三个函魔法函数:

def __init__ :初始化,把数据作为一个参数传给类

def __getitem__:根据索引获取样本对(x,y) 索引为(0,len(dataset)-1),根据数据集长度从0开始的索引序列;模型通过这个函数获取一对样本对

def __len__:表示数据集的长度,最终训练时用到的数据集的样本个数

示例

第一个示例,参考[2]:

import torch.utils.data as data
import torchvision.transforms as transforms
import os
import torch


class MyDataset(data.Dataset):
    def __init__(self, data_folder):
        self.data_folder = data_folder
        self.filenames = []
        self.labels = []

        per_classes = os.listdir(data_folder)
        for per_class in per_classes:
            per_class_paths = os.path.join(data_folder, per_class)
            label = torch.tensor(int(per_class))

            per_datas = os.listdir(per_class_paths)
            for per_data in per_datas:
                self.filenames.append(os.path.join(per_class_paths, per_data))
                self.labels.append(label)

    def __getitem__(self, index):
        image = Image.open(self.filenames[index])
        label = self.labels[index]
        data = self.preprocess(image)
        return data, label

    def __len__(self):
        return len(self.filenames)

    def preprocess(self, data):
        transform_train_list = [
            transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
            transforms.Pad(self.opt.pad, padding_mode='edge'),
            transforms.RandomCrop((self.opt.h, self.opt.w)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
        return transforms.Compose(transform_train_list)(data)

第二个示例,参考[3]:

class Dataset_name(Dataset):

   def __init__(self, flag='train'):
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.__load_data__()

    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

    def __load_data__(self, csv_paths: list):
        pass
        # print( "train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
        #     .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape,             # self.valid_Y.shape))

train_dataset = Dataset_name(flag='train')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataset = Dataset_name(flag='valid')
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

第三个示例,参考[4]:

import os
import pandas as pd
from torchvision.io import read_image


class CustomImageDataset(data.Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

(二)读取数据

创建DataLoader迭代器

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=64, shuffle=True)

使用enumerate访问可遍历的数组对象,参考[5]

for step, (data, label) in enumerate(dataloader):
    print('step is :', step)
    # data, label = item
    print('data is {}, label is {}'.format(data, label))

for i, item in enumerate(dataloader):
    print('i:', i)
    data, label = item
    print('data:', data)
    print('label:', label)

参考链接:

[1.] https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
[2.] https://mp.weixin.qq.com/s/EcqsWhissIr1woerYQwEcQ
[3.] https://zhuanlan.zhihu.com/p/396666255
[4.] https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
[5.] https://blog.csdn.net/qq_38607066/article/details/98474121

posted on 2022-11-18 13:00  写bug的程旭源  阅读(605)  评论(0编辑  收藏  举报  来源