自定义 数据集对象

在 pytorch 中,数据加载可以通过自定义数据集对象实现;

数据集对象被抽象为 DataSet 类;

自定义数据集对象,需要继承该类,并且实现 __getitem__ 和 __len__ 两个方法

 

示例

class DogCat(data.Dataset):
    def __init__(self, path):
        images = os.listdir(path)
        self.imgs = [os.path.join(path, i) for i in images]

    def __getitem__(self, index):
        ### 只是简单读取图片,并未做任何处理
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('\\')[-1] else 0
        pil_img = Image.open(img_path)

        array = np.asarray(pil_img)
        data = t.from_numpy(array)
        return data, label

    def __len__(self):
        return len(self.imgs)

dataSet = DogCat(r'F:\dl_dataset\cat_dog\train')
img, label = dataSet[0]
print(img.size(), img.float().mean())
print(label)

 

数据预处理

通常在加载数据时需要进行数据处理;

torchvision 是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transforms 提供了对 PIL Image 和 Tensor 对象的常用操作

 

PIL Image 常见操作我后续会专门写一篇博客;

Tensor 常见操作

  • Normalize:标准化,减去均值除以标准差
  • ToPILImage:将 Tensor 转成 PILImage 对象

 

这些操作定义后都以对象的形式存在,真正使用时调用它的 __call__ 方法

# 正确写法
show = transforms.ToPILImage()
image = show(img)
# 错误写法
show = transforms.ToPILImage(img)

另外需注意:

1. Compose 将多个操作拼接起来,类似于 nn.Sequential

2. transforms.Lambda 方法支持用户自定义数据处理策略

 

示例

import os
import numpy as np
import torch as t
from torch.utils import data
from PIL import Image
from torchvision import transforms
import matplotlib.pylab as plt

transform = transforms.Compose([
    transforms.Resize(224),     ### 缩放图片,长宽比不变,最短边为 224 像素
    transforms.CenterCrop(224), ### 从中间裁剪出 224x224
    transforms.ToTensor(),      ### image to tensor,并归一化至 [0, 1]
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])   ### 标准化至 [-1, 1]
])

class DogCat(data.Dataset):
    def __init__(self, path, transforms=None):
        images = os.listdir(path)
        self.imgs = [os.path.join(path, i) for i in images]
        self.transforms = transforms

    def __getitem__(self, index):
        self.img = self.imgs[index]
        data = Image.open(self.img)
        label = 0 if 'dog' in self.img.split('\\')[-1] else 1

        if self.transforms:
            ### 自定义转换操作
            mytrans = transforms.Lambda(lambda img: img.rotate(np.random.rand() * 360))
            data = mytrans(data)
            data = self.transforms(data)

        return data, label

    def __len__(self):
        return len(self.imgs)

dataSet = DogCat(r'F:\dl_dataset\cat_dog\train', transforms=transform)
img, label = dataSet[0]
print(img.size())
print(label)

show = transforms.ToPILImage()
image = show((img+1)/2)
image.show()

 

批量加载 batch

DataSet 负责数据集的抽象,其 __getitem__ 方法每次获取一个样本,这不利于网络的训练,我们需要 batch、shuffle、甚至并行;

在 pytorch 中使用 DataLoader 实现上述需求

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):

参数解释:

  • dataset、batch_size、shuffle 不解释
  • sampler:样本抽样,如随机采样,RandomSampler,当 shuffle 为 True 时,系统自动调用该采样方式,实现打乱数据;默认采样器为 SequentialSampler,逐个采样;WeightRandomSampler,在类别不均衡问题中,这种方式可以实现重采样
  • batch_sampler:
  • num_workers:使用多进程加载数据, 0 代表不使用多进程
  • collate_fn:如何将多个样本拼接成一个 batch,一般使用默认的方式即可
  • pin_memory:是否将数据保存在 pin memory 区,pin memory 区的数据转到 GPU 会更快一些
  • drop_last:如果 dataset 中最后的数据不足一个 batch,弃掉

 

DataLoader 生成的数据类似于一个迭代器,有两种方式读取该数据

from torch.utils.data import DataLoader

dataloader = DataLoader(dataSet, batch_size=3, shuffle=True, num_workers=0, drop_last=False)    ### 类似于迭代器

### 批量获取数据有两种方式
# 方式1
dataiter = iter(dataloader)
imgs, label = next(dataiter)
print(imgs.size())

show = transforms.ToPILImage()
image = show((imgs[0]+1)/2)
image.show()

# 方式2
for batch_data, batch_label in dataloader:
    print(batch_label)

 

WeightRandomSampler

 

加载异常

在数据处理中,有时会遇到某个样本无法读取的情况,如图片已损坏,此时 __getitem__ 函数将出现异常,处理方式有几种:

1. 剔除错误样本

2. __getitem__ 返回 None,然后自定义 collate_fn,将空过滤掉,但是这种情况获取的 batch 会少于 batch_size

3. 随机找一张代替

4. 提前进行数据清洗

 

DataSet 和 DataLoader 使用建议

1. 高负载的操作放在 __getitem__ 中,如图片读取

  // 多进程会并行的调用 __getitem__ 方法,高负载的操作并行执行提高效率

2. dataSet 中尽量只包含只读对象,避免修改任何可变对象

  // 线程安全问题

 

 

 

 

参考资料: