torchvision.transforms模块使用(自定义增强及调用库增强模块)
针对深度学习,基本会有一个数据增强环节,而该环节要不自己手写处理方法、要不调用已有的库,而对于已有库有很多。
本文仅仅使用torchvision中自带的transforms库,进行图像增强使用介绍,主要内容如下:
① 简单介绍下背景
②调用重点函数介绍
③使用简单代码实现数据增强,主要使用PIL读图方式
④在③的基础上使用如何使用cv读图方式实现数据增强模块
⑤在④的基础上如何使用自定义的函数实现数据增强
注:④ ⑤为更改版本
一.简单介绍背景
torchvision
是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。
transforms是torchvision的一个功能,主要集成数据增强的变化,针对torchvision的其它功能本文将不在介绍
。
二.transforms的重点函数
transforms.Compose将集成所有数据增强模块的功能,而数据增强模块除了调用transforms中的函数模块外,还可以自定义函数功能模块,实现任意自由的方法,将在后文介绍。
transforms.Compose( [transforms.ToTensor(), transforms.Normalize(std=(0.5,0.5,0.5),mean=(0.5,0.5,0.5))] )
函数一:transforms.ToTensor()。
将shape
为(H, W, C)
的nump.ndarray
或img
转为shape
为(C, H, W)
的tensor
,其将每一个数值归一化到[0,1]
,其归一化方法比较简单,直接除以255即可
函数二:transforms.ToTensor()。则其作用就是先将输入归一化到(0,1)
,再使用公式"(x-mean)/std"
,将每个元素分布到(-1,1)
三.使用PIL库读图调用数据增强模块
from torchvision import transformsfrom PIL import Image import matplotlib.pyplot as plt if __name__ == '__main__': root = r'D:\Users\User\Desktop\ligth_0824\123\a\1.jpg' img = Image.open(root) transform = transforms.Compose([ transforms.ColorJitter(brightness=0.5) ]) img_new=transform(img) img_new=np.asarray(img_new) plt.imshow(img_new) plt.show()
说明:以上代码简单实现此库的调用方法,值得说明一点,transforms内部增强模块使用call函数的类,其参数为PIL的,因此传出来也为PIL格式,需转换为CV格式。
四.在③的基础上使用如何使用cv读图方式实现数据增强模块
from PIL import Image from torchvision import transforms transform = transforms.Compose([ # transforms.RandomResizedCrop(224), # transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.5), # transforms.ColorJitter(contrast=0.8), # transforms.ColorJitter(hue=0.5), # transforms.ColorJitter(saturation=(0.6,1.2)) ]) if __name__ == '__main__': root=r'C:\Users\Administrator\Desktop\company_data\a\50.jpg' img=cv2.imread(root) # img=np.array(img) img = Image.fromarray(img) img_new = transform(img) img_new=np.asarray(img_new) show_img(img_new)
说明:有时使用CV读图,传进来的参数为CV格式,那么需将CV格式转为PIL格式,才能实现transforms库的使用,详情如上代码。
结果如下图:
五 在④的基础上如何使用自定义的函数实现数据增强
有时需要自定义功能的数据增强方法,此时需编写一个增强方法的类,使用call函数实现自定义的功能,以下代码为实现方法。
from PIL import Image from torchvision import transforms import random class noisy(): def __call__(self, img): if random.random()<1: #************ *转为CV格式**************** image = np.asarray(img) #***************自定义功能方法**************** s_vs_p = 0.6 # 设置添加噪声图像像素的数目 amount = 0.04 noisy_img = np.copy(image)# 添加salt噪声 num_salt = np.ceil(amount * image.size * s_vs_p)# 设置添加噪声的坐标位置 coords = [np.random.randint(0, i - 1, int(num_salt)) for i in image.shape] noisy_img[coords] = 255# 添加pepper噪声 num_pepper = np.ceil(amount * image.size * (1. - s_vs_p))# 设置添加噪声的坐标位置 coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in image.shape] noisy_img[coords] = 0 #*************** 以下转换PIL格式输出********** noisy_img=np.array(noisy_img,dtype=np.uint8) img = Image.fromarray(noisy_img) return img transform = transforms.Compose([ transforms.ColorJitter(brightness=0.3), transforms.ColorJitter(contrast=0.3), transforms.ColorJitter(hue=0.3), transforms.ColorJitter(saturation=(0.6,1.2)), noisy() ]) if __name__ == '__main__': root=r'C:\Users\Administrator\Desktop\company_data\a\50.jpg' img=cv2.imread(root) img = Image.fromarray(img) img_new = transform(img) img_new=np.asarray(img_new) show_img(img_new)