Understanding transform.Normalize( )
如果自己的数据集和 imageNet 数据集相差较大,需要计算自己数据集的 mean 和 std。
Normalize 让模型可以更快的收敛。
参考:https://discuss.pytorch.org/t/understanding-transform-normalize/21730/17
该文章解释了 怎么计算 mean 和 std,并且 给出了 为什么 需要 Normalize?
https://blog.csdn.net/u013685264/article/details/126764095
教给你怎么计算自己数据集的mean 和 std。当我用该 代码 计算 10w + 的图像的 mean + std 时候,直接 ImageFolder 读取 卡死,需要优化该代码。
优化的代码:放弃 ImageFolder 读取 ,使用自定义的 dataset。
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : get_mean_std.py
@Time : 2023/05/19 17:16:53
@Author :
@Version : 1.0
@Contact :
@License :
@Desc :
'''
# here put the import lib
from pathlib import Path
import sys
import tqdm
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
import torch
from torchvision import transforms
from dataset import ClsDataset
def tensor_transform():
return transforms.Compose([
transforms.ToTensor(),
])
def getStat(train_data):
'''
Compute mean and variance for training data
:param train_data: 自定义类Dataset(或ImageFolder即可)
:return: (mean, std)
'''
print('Compute mean and variance for training data.')
print(len(train_data))
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)
for X, _, _ in tqdm.tqdm(train_loader):
for d in range(3):
mean[d] += X[:, d, :, :].mean()
std[d] += X[:, d, :, :].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
train_list = 'xxx/train_list.txt'
train_dataset = ClsDataset(
list_file = train_list,
transform = tensor_transform()
)
print(getStat(train_dataset))
输出:
Compute mean and variance for training data.
101378
100%|██████████████████████████████████████████████████████████████████████████████████████| 101378/101378 [05:58<00:00, 282.67it/s]
([0.43321776, 0.3833695, 0.36299735], [0.24641703, 0.2344137, 0.22908022])
这里有个问题: 如果我们把数据集划分为了 训练集、验证集,此外我们还有自己的test集,那么 这个 mean + std 是在哪个数据集上算出来的呢?
目前我是从 训练集算出来的,至于是否需要从 训练集+验证集 的所有数据集上算出来,我还不不知道。我看了个回答:数据预处理的归一化手段应该如何应用到训练集,测试集和验证集中? - StefanChou的回答 - 知乎
https://www.zhihu.com/question/60490799/answer/214685372
他说是要从 训练集算出来。
其他参考链接:
https://forums.fast.ai/t/image-normalization-in-pytorch/7534?u=laochanlam
注意:如果训练时候使用了 Normalize ,测试时候一定也要用,不然精度会大量下降。反之亦然。