torchvision.transforms模块使用(自定义增强及调用库增强模块)

针对深度学习,基本会有一个数据增强环节,而该环节要不自己手写处理方法、要不调用已有的库,而对于已有库有很多。

本文仅仅使用torchvision中自带的transforms库,进行图像增强使用介绍,主要内容如下:

① 简单介绍下背景

②调用重点函数介绍

③使用简单代码实现数据增强,主要使用PIL读图方式

④在③的基础上使用如何使用cv读图方式实现数据增强模块

⑤在④的基础上如何使用自定义的函数实现数据增强

注:④ ⑤为更改版本

 

 

一.简单介绍背景

torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。

transforms是torchvision的一个功能,主要集成数据增强的变化,针对torchvision的其它功能本文将不在介绍。 

 二.transforms的重点函数

transforms.Compose将集成所有数据增强模块的功能,而数据增强模块除了调用transforms中的函数模块外,还可以自定义函数功能模块,实现任意自由的方法,将在后文介绍。
transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(std=(0.5,0.5,0.5),mean=(0.5,0.5,0.5))]
) 

函数一:transforms.ToTensor()。shape(H, W, C)nump.ndarrayimg转为shape(C, H, W)tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可

 函数二:transforms.ToTensor()。则其作用就是先将输入归一化到(0,1),再使用公式"(x-mean)/std",将每个元素分布到(-1,1)

 

三.使用PIL库读图调用数据增强模块

from torchvision import transformsfrom PIL import Image
import matplotlib.pyplot as plt
if __name__ == '__main__':
    root = r'D:\Users\User\Desktop\ligth_0824\123\a\1.jpg'
    img = Image.open(root)
    transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.5)
    ])
    img_new=transform(img)
  img_new=np.asarray(img_new)
  plt.imshow(img_new)
   plt.show()
    

说明:以上代码简单实现此库的调用方法,值得说明一点,transforms内部增强模块使用call函数的类,其参数为PIL的,因此传出来也为PIL格式,需转换为CV格式。

 

 

 

 

四.在③的基础上使用如何使用cv读图方式实现数据增强模块

from PIL import Image
from torchvision import transforms
transform = transforms.Compose([
    # transforms.RandomResizedCrop(224),
    # transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5),
    # transforms.ColorJitter(contrast=0.8),
    # transforms.ColorJitter(hue=0.5),
    # transforms.ColorJitter(saturation=(0.6,1.2))
])


if __name__ == '__main__':
    root=r'C:\Users\Administrator\Desktop\company_data\a\50.jpg'
    img=cv2.imread(root)
    # img=np.array(img)
    img = Image.fromarray(img)
    img_new = transform(img)
    img_new=np.asarray(img_new)

    show_img(img_new)

说明:有时使用CV读图,传进来的参数为CV格式,那么需将CV格式转为PIL格式,才能实现transforms库的使用,详情如上代码。

结果如下图:

 五 在④的基础上如何使用自定义的函数实现数据增强

有时需要自定义功能的数据增强方法,此时需编写一个增强方法的类,使用call函数实现自定义的功能,以下代码为实现方法。

from PIL import Image
from torchvision import transforms
import random
class noisy():
    def __call__(self, img):
        if random.random()<1:
            #************ *转为CV格式****************
            image = np.asarray(img)
            #***************自定义功能方法****************
            s_vs_p = 0.6 # 设置添加噪声图像像素的数目
            amount = 0.04
            noisy_img = np.copy(image)# 添加salt噪声
            num_salt = np.ceil(amount * image.size * s_vs_p)# 设置添加噪声的坐标位置
            coords = [np.random.randint(0, i - 1, int(num_salt)) for i in image.shape]
            noisy_img[coords] = 255# 添加pepper噪声
            num_pepper = np.ceil(amount * image.size * (1. - s_vs_p))# 设置添加噪声的坐标位置
            coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in image.shape]
            noisy_img[coords] = 0
            #*************** 以下转换PIL格式输出**********
            noisy_img=np.array(noisy_img,dtype=np.uint8)
            img = Image.fromarray(noisy_img)
        return img
transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.3),
    transforms.ColorJitter(contrast=0.3),
    transforms.ColorJitter(hue=0.3),
    transforms.ColorJitter(saturation=(0.6,1.2)),
    noisy()
])
if __name__ == '__main__':
    root=r'C:\Users\Administrator\Desktop\company_data\a\50.jpg'
    img=cv2.imread(root)
    img = Image.fromarray(img)
    img_new = transform(img)
    img_new=np.asarray(img_new)
    show_img(img_new)

 

posted @ 2022-04-05 18:51  tangjunjun  阅读(1172)  评论(0编辑  收藏  举报
https://rpc.cnblogs.com/metaweblog/tangjunjun