torchvision 之 transforms 模块详解

torchvision 是独立于 PyTorch 的关于图像操作的一个工具库,目前包括六个模块:

   1)torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的 Dataset。

   2)torchvision.models:经典模型,例如 AlexNet、VGG、ResNet 等,以及训练好的参数。

   3)torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor 与 numpy 和 PIL Image 的互换等。

   4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。

   5)torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。

   6)torchvision.utils:其他工具,比如产生一个图像网格等。

这里主要介绍下 torchvision.transforms 模块。torchvision.transforms 是 pytorch 中的图像预处理包。一般用 Compose 把多个步骤整合到一起。

"""
transforms: list of Transform objects, 是一个列表
"""
class torchvision.transforms.Compose(transforms)

 事实上,Compose()类会对 transforms 列表里面的 transform 操作进行遍历。实现的代码很简单,截取部分源码如下:

def __call__(self, img):
    for t in self.transforms:   
        img = t(img)
    return img

transforms 中的常见图像操作:

1. transforms.ToTensor 

   Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]。

   这个变换改变了图像的参数顺序,最终得到的图像形状为 $(C,H,W)$,并转换为 Tensor 类型,归一化至 [0,1] 是直接除以 255,每个像素变成一个 32 位

   的浮点类型。

 

2. transforms.Normalize

"""
mean (sequence) – Sequence of means for each channel.
std (sequence) – Sequence of standard deviations for each channel.
"""
torchvision.transforms.Normalize(mean, std)

   当数据量很大的时候,每个通道的数据都可以看成正态分布(大数定律),求出每个通道数据对应的均值和标准差,然后利用这两个值将每个通道数据的分布

   转换为标准正态分布。

posted @ 2020-12-05 13:57  _yanghh  阅读(2758)  评论(0编辑  收藏  举报