Understanding transform.Normalize( )
如果自己的数据集和 imageNet 数据集相差较大,需要计算自己数据集的 mean 和 std。
Normalize 让模型可以更快的收敛。
参考:https://discuss.pytorch.org/t/understanding-transform-normalize/21730/17
该文章解释了 怎么计算 mean 和 std,并且 给出了 为什么 需要 Normalize?
https://blog.csdn.net/u013685264/article/details/126764095
教给你怎么计算自己数据集的mean 和 std。当我用该 代码 计算 10w + 的图像的 mean + std 时候,直接 ImageFolder 读取 卡死,需要优化该代码。
优化的代码:放弃 ImageFolder 读取 ,使用自定义的 dataset。
#!/usr/bin/env python # -*- encoding: utf-8 -*- ''' @File : get_mean_std.py @Time : 2023/05/19 17:16:53 @Author : @Version : 1.0 @Contact : @License : @Desc : ''' # here put the import lib from pathlib import Path import sys import tqdm FILE = Path(__file__).resolve() ROOT = FILE.parents[1] if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) # add ROOT to PATH import torch from torchvision import transforms from dataset import ClsDataset def tensor_transform(): return transforms.Compose([ transforms.ToTensor(), ]) def getStat(train_data): ''' Compute mean and variance for training data :param train_data: 自定义类Dataset(或ImageFolder即可) :return: (mean, std) ''' print('Compute mean and variance for training data.') print(len(train_data)) train_loader = torch.utils.data.DataLoader( train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True) mean = torch.zeros(3) std = torch.zeros(3) for X, _, _ in tqdm.tqdm(train_loader): for d in range(3): mean[d] += X[:, d, :, :].mean() std[d] += X[:, d, :, :].std() mean.div_(len(train_data)) std.div_(len(train_data)) return list(mean.numpy()), list(std.numpy()) if __name__ == '__main__': train_list = 'xxx/train_list.txt' train_dataset = ClsDataset( list_file = train_list, transform = tensor_transform() ) print(getStat(train_dataset))
输出:
Compute mean and variance for training data. 101378 100%|██████████████████████████████████████████████████████████████████████████████████████| 101378/101378 [05:58<00:00, 282.67it/s] ([0.43321776, 0.3833695, 0.36299735], [0.24641703, 0.2344137, 0.22908022])
这里有个问题: 如果我们把数据集划分为了 训练集、验证集,此外我们还有自己的test集,那么 这个 mean + std 是在哪个数据集上算出来的呢?
目前我是从 训练集算出来的,至于是否需要从 训练集+验证集 的所有数据集上算出来,我还不不知道。我看了个回答:数据预处理的归一化手段应该如何应用到训练集,测试集和验证集中? - StefanChou的回答 - 知乎
https://www.zhihu.com/question/60490799/answer/214685372
他说是要从 训练集算出来。
其他参考链接:
https://forums.fast.ai/t/image-normalization-in-pytorch/7534?u=laochanlam
注意:如果训练时候使用了 Normalize ,测试时候一定也要用,不然精度会大量下降。反之亦然。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现