图像数据集的均值与方差

使用Pytorch进行预处理时,通常使用torchvision.transforms.Normalize(mean, std)方法进行数据标准化,其中参数mean和std分别表示图像集每个通道的均值和标准差序列。
首先,给出mean和std的定义,数学表示如下:
假设有一组数据集X i ,    i ∈ { 1 , 2 , ⋯   , n },则这组数据集的均值为:

这组数据集的标准差为

 

下面给出计算图像数据集每个通道的均值和标准差的函数代码:

import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader

batch_size = 64

# 训练集(以CIFAR-10数据集为例)
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

def get_mean_std_value(loader):
    '''
    求数据集的均值和标准差
    :param loader:
    :return:
    '''
    data_sum,data_squared_sum,num_batches = 0,0,0

    for data,_ in loader:
        # data: [batch_size,channels,height,width]
        # 计算dim=0,2,3维度的均值和,dim=1为通道数量,不用参与计算
        data_sum += torch.mean(data,dim=[0,2,3])    # [batch_size,channels,height,width]
        # 计算dim=0,2,3维度的平方均值和,dim=1为通道数量,不用参与计算
        data_squared_sum += torch.mean(data**2,dim=[0,2,3])  # [batch_size,channels,height,width]
        # 统计batch的数量
        num_batches += 1
    # 计算均值
    mean = data_sum/num_batches
    # 计算标准差
    std = (data_squared_sum/num_batches - mean**2)**0.5
    return mean,std

mean,std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean,std))

 原文来自:https://blog.csdn.net/weixin_43821559/article/details/123459085

posted @ 2023-03-14 22:04  抚琴尘世客  阅读(657)  评论(0编辑  收藏  举报