PyTorch Dataset Loading Framework

Dataset Loading for Image Classification

Take loading the MNIST handwritten digit dataset as an example.
First, save the MNIST data from the official format as image files with the following code.

# -*- coding: UTF-8 -*-
# Convert the MNIST dataset into image files for testing; plain images are more general-purpose.
import cv2
import os
from keras.datasets import mnist
import numpy as np
import shutil

str_1 = r'D:\DL_Code\dataset\mnist\mnist_train'
str_2 = r'D:\DL_Code\dataset\mnist\mnist_test'

# if os.path.exists(str_1):
#     shutil.rmtree(str_1)
# if os.path.exists(str_2):
#     shutil.rmtree(str_2)

# Create the output roots and one sub-folder per digit (0-9);
# cv2.imwrite fails silently if the target directory does not exist.
for root in (str_1, str_2):
    for digit in range(10):
        os.makedirs(os.path.join(root, str(digit)), exist_ok=True)

# Download the MNIST dataset automatically.
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
print('save training data start')
for i in range(len(X_train)):  # iterate over all 60000 training images
    fileName = os.path.join(str_1, str(Y_train[i]), str(Y_train[i]) + "_" + str(i) + ".jpg")
    cv2.imwrite(fileName, X_train[i])
    print(i, end='\r')
print('save training data finish!')

print('save testing data start')
for i in range(len(X_test)):  # iterate over all 10000 test images
    fileName = os.path.join(str_2, str(Y_test[i]), str(Y_test[i]) + "_" + str(i) + ".jpg")
    cv2.imwrite(fileName, X_test[i])
    print(i, end='\r')
print('save testing data finish!')

After the conversion finishes, folders 0–9 are created under each output directory, and each folder holds the images of the corresponding digit; the folder name serves as its label.
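
As a quick check (not in the original post), you can count the images in each class folder to confirm the conversion worked, for example:

import os

root = r'D:\DL_Code\dataset\mnist\mnist_train'
for digit in sorted(os.listdir(root)):
    count = len(os.listdir(os.path.join(root, digit)))
    print(digit, count)  # each digit folder should contain several thousand images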

Next, the images are read in through PyTorch's dataset loading framework. First define a custom Dataset class that implements three methods: __init__(self), __getitem__(self, index), and __len__(self). Then pass an object of this class to the DataLoader class to load it.
Directory structure:

  • Root directory: mnist/
    • mnist_train/
      • 0......9
    • mnist_test/
      • 0......9
  • Tip: think of the Dataset as a deck of cards, and the DataLoader as the act of drawing cards from it (a minimal sketch follows below).
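To make the three required methods concrete, here is a minimal sketch of my own (a toy Dataset over a plain Python list, standing in for the "deck of cards"):

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self):
        # the "deck of cards": ten (feature, label) pairs
        self.samples = [(torch.tensor([float(i)]), i % 2) for i in range(10)]

    def __getitem__(self, index):
        # the DataLoader "draws a card" by calling this method with an index
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True)
for data, label in loader:
    print(data.shape, label)  # e.g. torch.Size([4, 1]) tensor([1, 0, 0, 1])
    break

The full implementation for the MNIST folders generated above is as follows: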
from torchvision import transforms as T
from torchvision.transforms import functional as func
import os
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import pandas as pd
import glob

# A general-purpose dataset generator framework
class Dataset_x(Dataset):
    def __init__(self, root, transforms=None, train=True, val=False):
        """
        get images and execute transforms.
        """
        self.val = val
        imgs = glob.glob(os.path.join(root, '*/*.jpg'))
        imgs = sorted(imgs, key=lambda x: x.split('.')[-2])  # sort the images by their index in the filename
        self.imgs = imgs
        if transforms is None:
            # normalize
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
            # The training set and the validation set use different transforms:
            # the training set needs data augmentation but the validation set doesn't.

            # valset
            if self.val:
                # T.Compose chains the transforms below into one pipeline
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            # trainset
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),  # crop a random area and resize it to 224
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        """
        return data and label
        """
        img_path = self.imgs[index]
        # the parent folder name is the digit label; int() lets the DataLoader collate labels into a tensor
        label = int(os.path.basename(os.path.dirname(img_path)))
        data = Image.open(img_path)
        # the composed transforms above are bypassed here; the image is simply converted to a tensor
        # data = self.transforms(data)
        data = func.to_tensor(data)
        return data, label

    def __len__(self):
        """
        return the number of images.
        """
        return len(self.imgs)

if __name__ == "__main__":
    train_dataset = Dataset_x(r'D:\DL_Code\dataset\mnist\mnist_train', train=True)
    test_dataset = Dataset_x(r'D:\DL_Code\dataset\mnist\mnist_test', train=False)
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=10, shuffle=False)
    # for data, label in train_data_loader:
    #     print('label', label)
    for data, label in test_data_loader:
        print('label', label)
        break

Dataset Loading for Image Segmentation

In image segmentation problems the label changes from a string to an image, so special care is needed to keep each img matched with its corresponding mask.

The root directory is split into an img folder and a mask folder: img stores the tampered images, and mask stores the mask images of their tampered regions.

If the downloaded dataset puts the img and mask files into a single folder, the code below 👇 can be used to separate them.

# The code is on the server; I'll paste it here later.
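
In the meantime, here is a rough sketch of what that separation could look like. It assumes the mixed folder holds PNG pairs whose filenames contain 'rgb' for images and 'mask' for masks (the same naming the loader below relies on); the paths and the pattern are placeholders to adapt to your own dataset.

import glob
import os
import shutil

mixed_dir = r'C:\path\to\mixed_folder'   # placeholder: folder holding both images and masks
root_dir = r'C:\path\to\dataset_root'    # placeholder: output root that will contain img/ and mask/

os.makedirs(os.path.join(root_dir, 'img'), exist_ok=True)
os.makedirs(os.path.join(root_dir, 'mask'), exist_ok=True)

for path in glob.glob(os.path.join(mixed_dir, '*.png')):
    name = os.path.basename(path)
    # assumption: mask files carry 'mask' in their name, image files carry 'rgb'
    sub = 'mask' if 'mask' in name else 'img'
    shutil.copy(path, os.path.join(root_dir, sub, name))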

Once they are separated, the PyTorch dataset loading framework below 👇 can load them into a DataLoader:

from torchvision import transforms as T
from torchvision.transforms import functional as func
import os
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import pandas as pd
import tqdm
import glob

# A dataset generator for image segmentation
class Dataset_x(Dataset):
    def __init__(self, root, transforms=None, train=True, val=False):
        """
        get images and execute transforms.
        """
        self.val = val
        imgs_path_list = glob.glob(os.path.join(root, 'img/*.png'))  # use the root argument rather than a global path
        imgs_path_list = sorted(imgs_path_list, key=lambda x: x.split('.')[-2])  # sort the images by their index in the filename
        self.imgs_path_list,self.masks_path_list = self.return_realpath(imgs_path_list)
        if transforms is None:
            # normalize
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
            # The training set and the validation set use different transforms:
            # the training set needs data augmentation but the validation set doesn't.

            # valset
            if self.val:
                # T.Compose chains the transforms below into one pipeline
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            # trainset
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),  # crop a random area and resize it to 224
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        img_path = self.imgs_path_list[index]
        mask_path = self.masks_path_list[index]
        data = Image.open(img_path)
        #data = self.transforms(data)
        data = func.to_tensor(data)

        label = Image.open(mask_path)
        label = func.to_tensor(label)
        return data, label

    def __len__(self):
        return len(self.imgs_path_list)

    def return_realpath(self, imgs_path_list):
        """Keep only the image paths whose corresponding mask file exists."""
        real_img_path_list = []
        real_mask_path_list = []
        for imgs_path in imgs_path_list:
            # masks live in the mask/ folder and use 'mask' instead of 'rgb' in the filename
            mask_path = imgs_path.replace('img', 'mask').replace('rgb', 'mask')
            if not os.path.exists(mask_path):
                print("No corresponding mask found for {}".format(imgs_path))
                continue
            real_img_path_list.append(imgs_path)
            real_mask_path_list.append(mask_path)
        return real_img_path_list, real_mask_path_list

def display_data_label(data,label):
    img = data[0]
    img = img.numpy()
    img = np.transpose(img, (1, 2, 0))
    mask = label[0]
    mask = mask.numpy()
    mask = np.transpose(mask, (1, 2, 0))
    # display the image and its mask side by side
    plt.figure(figsize=(8, 8))
    plt.subplot(121)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.subplot(122)
    plt.imshow(mask)
    plt.xticks([])
    plt.yticks([])
    # plt.imshow(mask, cmap='gray')  # show the mask as grayscale / binary
    plt.show()

if __name__ == "__main__":
    train_data_path = r'C:\Users\Liang\dataset\spliced_copymove_NIST_part_two'
    test_data_path = r'C:\Users\Liang\dataset\spliced_copymove_NIST_part_three'
    train_dataset = Dataset_x(train_data_path, train=True)
    # test_dataset = Dataset_x(test_data_path, train=False)
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)
    # test_data_loader = DataLoader(dataset=test_dataset, batch_size=10, shuffle=True)
    cnt = 0
    for t in range(80):
        for data,label in tqdm.tqdm(train_data_loader, desc='Epoch %2d' % (t + 1)):
            cnt+=1
            #display_data_label(data,label)
            #print("\r"+str(cnt),end='')
  • A display_data_label() helper is also defined to visualize an img together with its corresponding mask image.

For Personal Use

from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import os
import random
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import Compose, CenterCrop, ToTensor, Pad,Resize
import torch

def load_img(filepath,colordim):
    if colordim==1:
        img = Image.open(filepath).convert('L')
    else:
        img = Image.open(filepath).convert('RGB')
    return img

def transforms_lr(img_pil,padding_size):
    transforms_lr_ = Compose([
            Pad(padding=(padding_size, 0, 0, 0), fill=0),
            Resize((512,512)),
            T.RandomHorizontalFlip(p=0.5),
            #T.RandomRotation((-45,45))
            ToTensor(),
        ])
    img_pil = transforms_lr_(img_pil)
    return img_pil


def transforms_tb(img_pil,padding_size):
    transforms_tb_ = Compose([
            Pad(padding=(0, 0, 0, padding_size), fill=0),
            Resize((512,512)),
            T.RandomHorizontalFlip(p=0.5),
            #T.RandomRotation((-45,45))
            ToTensor(),
        ])
    img_pil = transforms_tb_(img_pil)
    return img_pil
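
# Illustration (my own addition, not in the original code): a 300x512 portrait
# image has w < h, so transforms_lr pads its left side by h - w = 212 px to make
# it square before resizing, which preserves the aspect ratio of the content:
#   >>> img = Image.new('RGB', (300, 512))
#   >>> transforms_lr(img, 512 - 300).shape   # torch.Size([3, 512, 512])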

class UNetDataset(Dataset):
    def __init__(self, dir_train, dir_mask,colordim=1,out_dim=1):
        self.dirTrain = dir_train
        self.dirMask = dir_mask
        self.colordim = colordim
        self.out_dim = out_dim
        self.image_w = 0
        self.image_h = 0
        # Sort both lists so that image i and mask i always refer to the same sample.
        self.dataTrain = sorted(os.path.join(self.dirTrain, filename)
                                for filename in os.listdir(self.dirTrain)
                                if filename.endswith('.jpg') or filename.endswith('.png'))
        self.dataMask = sorted(os.path.join(self.dirMask, filename)
                               for filename in os.listdir(self.dirMask)
                               if filename.endswith('.jpg') or filename.endswith('.png'))
        self.trainDataSize = len(self.dataTrain)
        self.maskDataSize = len(self.dataMask)

        self.transforms_ = Compose([
            T.RandomCrop((256,256)),
            ToTensor(),
        ])

    def __getitem__(self, index):
        assert self.trainDataSize == self.maskDataSize
        image = load_img(self.dataTrain[index], self.colordim)
        label = load_img(self.dataMask[index], self.out_dim)
        w, h = image.size
        self.image_w = w
        self.image_h = h
        # Re-seed torch with the same value before transforming image and label so that
        # random operations (e.g. the horizontal flip) are applied identically to both.
        seed = np.random.randint(0, 2 ** 31)
        if w < h:
            # portrait image: pad the left side until it is square, then resize
            torch.manual_seed(seed)
            image = transforms_lr(image, h - w)
            torch.manual_seed(seed)
            label = transforms_lr(label, h - w)
        elif w > h:
            # landscape image: pad the bottom until it is square, then resize
            torch.manual_seed(seed)
            image = transforms_tb(image, w - h)
            torch.manual_seed(seed)
            label = transforms_tb(label, w - h)
        else:
            # already square: no padding needed, just resize and convert to tensors
            square = Compose([Resize((512, 512)), ToTensor()])
            image = square(image)
            label = square(label)

        # torch.manual_seed(seed)
        # image = self.transforms_(image)
        # torch.manual_seed(seed)
        # label = self.transforms_(label)

        return image, label

    def __len__(self):
        return self.trainDataSize

# Usage
input_dim = 3
out_dim = 3
BATCH_SIZE = 8  # example value; set this to your own training batch size
img_dir = r'E:\Liang\Dataset\IDM2020_total\IDM2020_img'
gt_mask_dir = r'E:\Liang\Dataset\IDM2020_total\IDM2020_mask'
dataset = UNetDataset(img_dir, gt_mask_dir,colordim=input_dim,out_dim=out_dim)

dataLoader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,pin_memory=True,drop_last=True)
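
As a quick sanity check (my own addition), draw one batch and confirm that the image and mask tensors come out with matching shapes:

for image, label in dataLoader:
    print(image.shape, label.shape)  # expect torch.Size([BATCH_SIZE, 3, 512, 512]) for both
    break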
posted @ 2021-10-20 10:44  梁君牧