Dataset and DataLoader

In PyTorch, Dataset and DataLoader are the two classes used to load data.

The dataset

We use the ants-and-bees dataset (hymenoptera_data). Its directory structure is as follows:

hymenoptera_data/
|-- train/
|   |-- ants/
|   |   |-- 0013035.jpg
|   |   |-- 24335309_c5ea483bb8.jpg
|   |   |-- ... ...
|   |-- bees/
|   |   |-- 16838648_415acd9e3f.jpg
|   |   |-- ... ...
|-- val/
|   |-- ants/
|   |   |-- 8124241_36b290d372.jpg
|   |   |-- ... ...
|   |-- bees/
|   |   |-- 26589803_5ba7000313.jpg
|   |   |-- ... ...
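
To check that a local copy matches this layout, one can count the files per split and class. The following is a minimal sketch using only the standard library. The helper name count_images is hypothetical, and it assumes that every file inside a class folder is an image.

```python
import os

def count_images(root):
    """Return {(split, class_name): file_count} for a root/split/class/ layout."""
    counts = {}
    for split in sorted(os.listdir(root)):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            continue  # skip stray files that sit next to train/ and val/
        for class_name in sorted(os.listdir(split_dir)):
            class_dir = os.path.join(split_dir, class_name)
            if os.path.isdir(class_dir):
                # assumption: everything inside a class folder is an image
                counts[(split, class_name)] = len(os.listdir(class_dir))
    return counts
```

For the layout above, the result should contain one entry each for ('train', 'ants'), ('train', 'bees'), ('val', 'ants') and ('val', 'bees').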

Dataset

First, let's look at the documentation for Dataset:

from torch.utils.data import Dataset
help(Dataset)
---------------------------------
class Dataset(typing.Generic)
 |  Dataset(*args, **kwds)
 |  
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.

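As the documentation shows, Dataset is an abstract class (an interface). Subclasses must override __getitem__. Overriding __len__ is optional, but many Sampler implementations and the default options of DataLoader expect it, so in practice a subclass usually implements both.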
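
These two methods can be tried out on a toy example before touching real images. The sketch below is for illustration only: SquaresDataset is a hypothetical in-memory dataset, not part of PyTorch, in which the sample for key i is the pair (i, i * i).

```python
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        # Raise IndexError for keys outside 0..n-1, as a list would
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx * idx

    def __len__(self):
        return self.n

squares = SquaresDataset(5)
print(len(squares))  # 5
print(squares[3])    # (3, 9)
```

Indexing with squares[3] calls __getitem__, and len(squares) calls __len__. The image dataset later in this post relies on the same two methods.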

Next, we load the ants-and-bees dataset by subclassing Dataset:

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root, is_ants=True, is_train=True):
        # The class folder name doubles as the label
        if is_ants:
            self.label = 'ants'
        else:
            self.label = 'bees'
        if is_train:
            tv = 'train'
        else:
            tv = 'val'
        self.path = os.path.join(root, tv, self.label)  # e.g. root/train/ants
        self.img_list = os.listdir(self.path)  # file names of all images in that folder

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        img_path = os.path.join(self.path, img_name)
        img = Image.open(img_path)
        return img, self.label  # return the idx-th image and its label

    def __len__(self):
        return len(self.img_list)

Use MyData to load and read the dataset:

val_ants = MyData(r'D:\dataset\hymenoptera_data', True, False)  # load the ants in the validation set
print("%d samples in total" % len(val_ants))
print(val_ants[0])  # read the first image and its label
val_bees = MyData(r'D:\dataset\hymenoptera_data', False, False)  # load the bees in the validation set
print("%d samples in total" % len(val_bees))
print(val_bees[0])  # read the first image and its label
---------------------------------
70 samples in total
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375 at 0x1D58C7C0550>, 'ants')
83 samples in total
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=366x500 at 0x1D58C7C08E0>, 'bees')

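Because Dataset already defines the __add__ method, we can merge two datasets directly with the + operator. The result is a ConcatDataset in which the samples of the first dataset come before those of the second.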

val_data = val_ants + val_bees
print("%d samples in total" % len(val_data))
print(val_data[len(val_ants)-1][1], val_data[len(val_ants)][1])
---------------------------------
153 samples in total
ants bees
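
DataLoader

The introduction also mentions DataLoader, so here is a minimal sketch of how a merged dataset could be batched. It uses a hypothetical in-memory ToyData class with fake tensor "images" instead of the image folders, so that it runs without the dataset on disk. The batch_size value is arbitrary. With MyData itself, the PIL images would first have to be converted to equally sized tensors (for example with a torchvision transform), because the default batching cannot stack PIL images.

```python
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset

class ToyData(Dataset):
    """Hypothetical stand-in for MyData: n fake 'images' sharing one label."""

    def __init__(self, n, label):
        # each fake image is a 3x4x4 tensor filled with its own index
        self.samples = [torch.full((3, 4, 4), float(i)) for i in range(n)]
        self.label = label

    def __getitem__(self, idx):
        return self.samples[idx], self.label

    def __len__(self):
        return len(self.samples)

ants = ToyData(3, 'ants')
bees = ToyData(2, 'bees')
data = ants + bees  # Dataset.__add__ builds a ConcatDataset
print(type(data).__name__, len(data))  # ConcatDataset 5

# shuffle=False keeps the ants-then-bees order, which makes the batches easy to inspect
loader = DataLoader(data, batch_size=2, shuffle=False)
for imgs, labels in loader:
    # imgs is a stacked tensor of shape (batch, 3, 4, 4); labels holds the strings
    print(imgs.shape, labels)
```

With five samples and batch_size=2, the loop runs three times, and the last batch holds a single sample. For training one would normally pass shuffle=True.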
posted @ 2022-03-06 19:04  Bill_H