pytorch-Dataset-Dataloader

pytorch-Dataset-Dataloader

pyTorch为我们提供的两个Dataset和DataLoader类分别负责可被Pytorh使用的数据集的创建以及向训练传递数据的任务。

data.Dataset

torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。

只负责数据的抽象,一次只是返回一个数据

Dataset是用来解决数据从哪里读取以及如何读取的问题。pytorch给定的Dataset是一个抽象类,所有自定义的Dataset都要继承它,

并且复写__getitem__()__len__()类方法,__getitem__()的作用是接受一个索引,返回一个样本或者标签。

__len__ 前者提供了数据集的大小,

_getitem__ 后者支持整数索引,范围从0到len(self)

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)
    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))  # 标签是0或1

print(data_tensor.shape)
print(target_tensor.shape)

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
print(my_dataset)
print(my_dataset[0])    

# 执行结果
torch.Size([10, 3])
torch.Size([10])
<__main__.MyDataset object at 0x000002496FF46820>
(tensor([1.0655, 1.4536, 1.0800]), tensor(1))

MyDataset 的特点
1、继承于torch.utils.data.Dataset。
2、通过读取任意格式的数据、预处理、数据增强、以及数据转换、将数据以tensor输出
3、输出的结果有两个。tensor格式的数据和数据标签
4、主要是实现了三个函数__init__,__len__,__getitem__

from torch.utils.data import Dataset
from PIL import Image
import os
import json


class GetData(Dataset):

    def __init__(self, img_dir, labelfile):
        # self
        self.img_dir = img_dir
        self.img_list = os.listdir(self.img_dir)

        with open(str(labelfile)) as f:
            label = json.load(f)
        self.label = label

    def __getitem__(self, idx):
        imgname = self.img_list[idx]  # 只获取了文件名
        img_path = os.path.join(self.img_dir, imgname)  # 每个图片的位置
        img = Image.open(img_path)

        label = self.label[imgname]
        return img, label

    def __len__(self):
        return len(self.img_list)


root_dir = "../../assets/datasets"
ants_label_dir = "../../assets/classes.json"
ants_dataset = GetData(root_dir, ants_label_dir)
print(len(ants_dataset))
img, lable = ants_dataset[0]  # 返回一个元组,返回值就是__getitem__的返回值
print(img.size)
print(lable)


3
(473, 266)
dog
===================================
数据目录结构
|---datasets/
	|----/1.jpg
	|----/2.jpg   
---classes.json

classes.json标注结构
{
  "1.jpg": "dog",
  "2.jpg": "dog",
  "3.jpg": "dog"
}

该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要

data.DataLoader

Dataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签,在训练的过程中,一般是需要将多个数据组成batch。所以PyTorch中存在DataLoader这个迭代器。

形成batch数据,并且可以使用shuffe和加速

数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。 返回迭代器。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')


"""输入参数

dataset:      		数据集的储存的路径位置等信息,定义好的Map式或者Iterable式数据集
batch_size: 		每次取数据的数量,比如batchi_size=2
shuffle 			default: False  打乱数据
sampler: 			如果指定,"shuffle"必须为false,提取样本的策略
batch_sampler 		None,和batch_size、shuffle 、sampler and drop_last参数
num_workers 		加载数据的进程,多进程会更快
collate_fn 		    如何将多个样本数据拼接成一个batch,自定义数据读取方式,可以用来过滤数据	
pin_memory 			张量复制到CUDA内存,能够加快内存访问速度
drop_last 			如何处理数据集长度除于batch_size余下的数据。True就抛弃
timeout 			default:0	   读取超时,超时报错


在数据处理中,有时会出现某个样本无法读取等问题,如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。

返回参数
迭代器 tensor_loader
for data, target in tensor_dataloader: 
    print(data, target)
"""


# 定义加载器
tensor_loader=DataLoader(mydataset, batch_size=64,)

# 训练或者测试加载器
for i,(data, target) in enumerate(tensor_dataloader): 
    print(data, target)
    
import torch
from torch.utils.data import Dataset,DataLoader

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]


data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))  # 标签是0或1

print(data_tensor.shape)
print(target_tensor.shape)

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
tensor_loader=DataLoader(my_dataset, batch_size=3,)

# 训练或者测试加载器
for i,(data, target) in enumerate(tensor_loader):
    print(len(data), len(target))
    
torch.Size([10, 3])
torch.Size([10])
3 3
3 3
3 3
1 1

自定义collate_fn 过滤失效数据

'''
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,
此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,
然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
'''
import os, json
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset


class NewDogCat(Dataset):  # 继承前面实现的DogCat数据集
    # 构造函数
    def __init__(self, img_dir, labelfile, transform):
        # self
        self.img_dir = img_dir
        self.img_list = os.listdir(self.img_dir)
        self.transform = transform

        with open(str(labelfile)) as f:
            label = json.load(f)
        self.label = label

    def __getitem__(self, idx):
        try:
            imgname = self.img_list[idx]  # 只获取了文件名
            img_path = os.path.join(self.img_dir, imgname)  # 每个图片的位置

            label = self.label[imgname]
            img = Image.open(img_path)
            img = self.transform(img)

            return img, label
        except Exception as e:
            print(e,"数据读取错误")
            return None,None

    def __len__(self):
        return len(self.img_list)


from torch.utils.data.dataloader import default_collate  # 导入默认的拼接方式
from torch.utils.data import DataLoader
from torchvision import transforms


def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x: x[0] is not None, batch))
    if len(batch) == 0: return torch.Tensor()
    return default_collate(batch)  # 用默认方式拼接过滤后的batch数据

transform = transforms.Compose([
    transforms.Resize(224),  # 缩放图片,保持长宽比不变,最短边的长为224像素,
    transforms.CenterCrop(224),  # 从中间切出 224*224的图片
    transforms.ToTensor(),  # 将图片转换为Tensor,归一化至[0,1]
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 标准化至[-1,1]
])
root_dir = "../../assets/datasets"
ants_label_dir = "../../assets/classes.json"
dataset = NewDogCat(root_dir, ants_label_dir, transform=transform)

dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, shuffle=True)
for batch_datas, batch_labels in dataloader:
    print(batch_datas[0].shape, len(batch_labels))
    
'5.jpg' 数据读取错误
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 2

总结

  1. 首先我们要去构建自己继承Dataset的MyDataSet

  2. 传入到Dataloader中,最后进行enumerate遍历每个batchsize

  3. Dataset通过index输出的最好是tensor

  4. 整体的Dataset和Dataloader中,基本上是Dataloader每次给你返回一个shuffle过的index

参考资料

https://zhuanlan.zhihu.com/p/340465632

posted @ 2023-07-15 17:21  贝壳里的星海  阅读(109)  评论(0编辑  收藏  举报