图像数据集的均值与方差

使用Pytorch进行预处理时,通常使用torchvision.transforms.Normalize(mean, std)方法进行数据标准化,其中参数mean和std分别表示图像集每个通道的均值和标准差序列。
首先,给出mean和std的定义,数学表示如下:
假设有一组数据集X i ,    i ∈ { 1 , 2 , ⋯   , n },则这组数据集的均值为:

这组数据集的标准差为

 

下面给出计算图像数据集每个通道的均值和标准差的函数代码:

复制代码
import torch
from torchvision import transforms,datasets
from torch.utils.data import DataLoader

batch_size = 64

# 训练集(以CIFAR-10数据集为例)
train_dataset = datasets.CIFAR10(root='G:/datasets/cifar10',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)

def get_mean_std_value(loader):
    '''
    求数据集的均值和标准差
    :param loader:
    :return:
    '''
    data_sum,data_squared_sum,num_batches = 0,0,0

    for data,_ in loader:
        # data: [batch_size,channels,height,width]
        # 计算dim=0,2,3维度的均值和,dim=1为通道数量,不用参与计算
        data_sum += torch.mean(data,dim=[0,2,3])    # [batch_size,channels,height,width]
        # 计算dim=0,2,3维度的平方均值和,dim=1为通道数量,不用参与计算
        data_squared_sum += torch.mean(data**2,dim=[0,2,3])  # [batch_size,channels,height,width]
        # 统计batch的数量
        num_batches += 1
    # 计算均值
    mean = data_sum/num_batches
    # 计算标准差
    std = (data_squared_sum/num_batches - mean**2)**0.5
    return mean,std

mean,std = get_mean_std_value(train_loader)
print('mean = {},std = {}'.format(mean,std))
复制代码

 原文来自:https://blog.csdn.net/weixin_43821559/article/details/123459085

posted @   抚琴尘世客  阅读(726)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示