torchvision 之 transforms 模块详解
torchvision 是独立于 PyTorch 的关于图像操作的一个工具库,目前包括六个模块:
1)torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的 Dataset。
2)torchvision.models:经典模型,例如 AlexNet、VGG、ResNet 等,以及训练好的参数。
3)torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor 与 numpy 和 PIL Image 的互换等。
4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。
5)torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。
6)torchvision.utils:其他工具,比如产生一个图像网格等。
这里主要介绍下 torchvision.transforms 模块。torchvision.transforms 是 pytorch 中的图像预处理包。一般用 Compose 把多个步骤整合到一起。
""" transforms: list of Transform objects, 是一个列表 """ class torchvision.transforms.Compose(transforms)
事实上,Compose()类会对 transforms 列表里面的 transform 操作进行遍历。实现的代码很简单,截取部分源码如下:
def __call__(self, img): for t in self.transforms: img = t(img) return img
transforms 中的常见图像操作:
1. transforms.ToTensor
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]。
这个变换改变了图像的参数顺序,最终得到的图像形状为 $(C,H,W)$,并转换为 Tensor 类型,归一化至 [0,1] 是直接除以 255,每个像素变成一个 32 位
的浮点类型。
2. transforms.Normalize
""" mean (sequence) – Sequence of means for each channel. std (sequence) – Sequence of standard deviations for each channel. """ torchvision.transforms.Normalize(mean, std)
当数据量很大的时候,每个通道的数据都可以看成正态分布(大数定律),求出每个通道数据对应的均值和标准差,然后利用这两个值将每个通道数据的分布
转换为标准正态分布。