图像数据集的均值与方差
使用Pytorch进行预处理时,通常使用torchvision.transforms.Normalize(mean, std)方法进行数据标准化,其中参数mean和std分别表示图像集每个通道的均值和标准差序列。
首先,给出mean和std的定义,数学表示如下:
假设有一组数据集X i , i ∈ { 1 , 2 , ⋯ , n },则这组数据集的均值为:
这组数据集的标准差为
下面给出计算图像数据集每个通道的均值和标准差的函数代码:
import torch from torchvision import transforms,datasets from torch.utils.data import DataLoader batch_size = 64 # 训练集(以CIFAR-10数据集为例) train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10',train=True,download=False,transform=transforms.ToTensor()) train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size) def get_mean_std_value(loader): ''' 求数据集的均值和标准差 :param loader: :return: ''' data_sum,data_squared_sum,num_batches = 0,0,0 for data,_ in loader: # data: [batch_size,channels,height,width] # 计算dim=0,2,3维度的均值和,dim=1为通道数量,不用参与计算 data_sum += torch.mean(data,dim=[0,2,3]) # [batch_size,channels,height,width] # 计算dim=0,2,3维度的平方均值和,dim=1为通道数量,不用参与计算 data_squared_sum += torch.mean(data**2,dim=[0,2,3]) # [batch_size,channels,height,width] # 统计batch的数量 num_batches += 1 # 计算均值 mean = data_sum/num_batches # 计算标准差 std = (data_squared_sum/num_batches - mean**2)**0.5 return mean,std mean,std = get_mean_std_value(train_loader) print('mean = {},std = {}'.format(mean,std))
原文来自:https://blog.csdn.net/weixin_43821559/article/details/123459085
但行好事 莫问前程