Pytorch框架学习---(2)输入数据操作
本节讲述Data如何利用Pytorch提供的DataLoader进行读取,以及Transforms的图片处理方式。 【文中思维导图采用MindMaster软件】 注意:笼统总结Transforms,目前仅具体介绍裁剪、翻转、标准化,后续随着代码需要,再逐步更新。 |
一. 数据读取(DataLoader和Dataset)
1.DataLoader
我们采用Pytorch提供的DataLoader进行数据Batch封装,其中需要定义dataset类。
自定义的dataset类需要复写def getitem(self, index):函数!!!
train_loader = DataLoader(dataset=train_dataset,
batch_size=Batch_Size,
shuffle=True,
num_workers=4, # num_worker = 4 * GPU个数 为了数据进来更快一些
pin_memory=True, # 也是为了数据输入更快,但是会对增加显存负担 !!!
drop_last=True) # droplast:最后一个批次不满足设定数目Batch_Size,则舍弃
for epoch in range(Max_Epoch):
for i, (inputs, labels) in enumerate(train_loader): # 每次调用一个batch,后台索引
# 也可以采用next(iter(train_loader)), 读取一个批次
在网络运行时,我们采用enumerate函数,进行迭代,这里会:
-
进入DataLoader数据装载器;
-
判断参数,是否采用多进程处理;
-
调用Sampler函数,根据输入数据个数(由Dataset类中def len()函数得到),随机获取index索引值;
-
进入我们定义的Dataset类,调用def getitem(),根据index获取数据,返回;
-
调用collate_fn()函数整理数据,最终得到Batch。
2.代码(如何将电脑中的数据送入网络?)
注意:这里数据集已经分类好,文件夹已经各自建立,不包含划分数据的函数!!
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
category = {"0": 0, "1": 1, "1_enhanced": 2, "1_enhanced_2": 3, "0_enhanced_1":4} # 定义标签,"文件夹名":标签
class my_dataset(Dataset):
'''根据自己的数据,进行读取,Dataset类创建Pytorch数据集类型'''
'''
Args:
data_dir: 数据地址(训练集、验证集、测试集)
transform: torchvision.transforms(各种变换、以及Totensor)
Return:
read_data 根据dataloader的索引获取数据
len(self.data_info) 数据个数
'''
def __init__(self, data_dir, transform=None):
self.transforms = transform
self.data_info = self.get_dataset_info(data_dir) # 获取所有数据路径和对应的标签,方便dataloader 用index批量处理
def __getitem__(self, index): # 当dataloader sampler得到index,根据该index索引dataset中数据
path_data, label = self.data_info[index]
read_data = Image.open(path_data).convert("RGB") # PIL-->RGB(0-256)
if self.transforms is not None:
read_data = self.transforms(read_data)
return read_data, label
def __len__(self):
return len(self.data_info)
@staticmethod # 定义该函数为静态类型,不用实例化类也可调用
def get_dataset_info(data_dir):
data_info = list() # 最终包含所有图片、标签(每一行)
for root, dirs, files in os.walk(data_dir): # 获取当前文件夹的父目录、当前文件夹下所有文件名、所有内部文件
for sub_dir in dirs: # 遍历所有类别
each_cate = os.listdir(os.path.join(root, sub_dir)) # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
for i in range(len(each_cate)): # 遍历每一个类别下的图片数据,将标签一同嵌入
each_data_name = each_cate[i]
each_data_path = os.path.join(root, sub_dir, each_data_name)
each_label = category[sub_dir]
data_info.append((each_data_path, int(each_label)))
return data_info
二.数据预处理(torchvision.transforms)
1.torchvision
2.transforms.Compose([......])组合
计算机将按照Compose中定义的transforms操作,依次进行数据处理。
train_transforms = transforms.Compose([
transforms.Resize((75, 75)),
transforms.ToTensor(), # (H x W x C) [0, 255] to a torch.FloatTensor (C x H x W) [0.0, 1.0]
transforms.Normalize(mean=norm_mean,std=norm_std) # 逐通道归一化,注意通道数
])
3.各种transforms处理方式
本节目前仅介绍:标准化Normalize、图像裁剪Crop、旋转翻转。
(1)数据标准化
transforms.Normalize(mean, std, inplace=False) #逐通道对图像进行标准化,mean:(M1,...,Mn) and std: (S1,..,Sn) for n channels
# input[channel] = (input[channel] - mean[channel]) / std[channel]
(2)裁剪
a)从中心进行裁剪
transforms.CenterCrop(size=32) # 由图像中心进行裁剪,size=32*32
b)随机裁剪
transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
# 先填充再随机裁剪
# padding:设置填充大小,数值a --> 上下左右填充a个像素,(a,b)--> 左右a上下b, (a,b,c,d) --> 左a上b右c下d
# padding_mode:填充模式:
# constant:像素值由fill参数设定;
# edge:由图像边缘像素决定;
# reflect:镜像填充,最后一个像素不镜像;
# symmetric:镜像填充,最后一个像素镜像。
c)随机面积、随机长宽比裁剪图片
transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)
# 先选择scale,再ratio,再判断size,是否需要interpolation进行resized
# scale=(0.08, 1.0):随机裁剪面积比例,范围内随机选
# ratio=(3. / 4., 4. / 3.):随机长宽比
# interpolation:插值方法
d)上下左右中心随机裁剪5张图片
transforms.FiveCrop(size) # 从上下左右中心各裁剪出五张图片
transforms.TenCrop(size, vertical_flip=False) # 先进行FiveCrop(),再对五张图片进行水平/垂直镜像,获得10张图片
注意:这里返回的是tuple()类型,需要按行拼接起来,送入下游transforms处理。
>>> transform = Compose([
>>> TenCrop(size), # this is a list of PIL Images
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
>>> input, target = batch # input is a 5d tensor, target is 2d
>>> bs, ncrops, c, h, w = input.size()
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
有问题:当采用数据增强时,一方面采用TenCrop形式,另一方面采用其他数据变换,一同送入Dataloader时会产生错误,因为维度不一致,其他数据变换在dataset中为三维(channel,H,W),而TenCrop却是四维(ncrops,channel,H,W),于是当迭代获取Batch时会由于维度不匹配程序报错。
解决方法:【等后续找到再来写,手动狗头微笑】
(3)翻转、旋转
transforms.RandomHorizontalFlip(p=0.5) # 依概率进行水平(左右)翻转
transforms.RandomVerticalFlip(p=0.5) # 依概率进行垂直(上下)翻转
transforms.RandomRotation(degrees, resample=False, expand=False, center=None) # 随机旋转图片
# degrees:旋转角度,若为a,则在(-a,a)之间二选一,若为(a, b),则二选一
# expand:是否扩大图片(因为旋转过后可能会丢失图片某一块),仅针对中心点旋转
# center:旋转点设置,默认中心点
(4)对各种变换的组合--》选择操作(如RandomChoice)
transforms.RandomChoice([transforms1, transforms2, ......]) # 随机挑选一个
transforms.RandomApply([transforms1, transforms2, ......], p=0.5) # 依概率执行整个一组(要么执行,要么不执行)
transforms.RandomOrder([transforms1, transforms2, ......]) # 对一组操作进行打乱顺序,再去执行这一组
4.自定义Transforms方法
class YourTransforms(object):
def __init__(self,Arg1,Arg2):
'''传参数'''
def __call__(self, x):
'''定义该Transforms方法'''
return x