深度学习(计算数据集均值标准差)
深度学习中，有些数据集的像素分布与 ImageNet 上统计出的均值和标准差并不相符，这时需要根据自己的数据集单独计算。
下面这个脚本可以计算当前数据集各通道的均值和标准差。
import os
import sys

import torch
from PIL import Image
from torchvision import transforms

# Per-channel statistics for an image folder.
#
# Every regular file directly under `root` is opened as an image, converted to
# RGB, turned into a float tensor by ToTensor, and its per-channel mean and
# standard deviation are accumulated. The script prints the averages of those
# per-image values.
#
# NOTE: the reported std is the *average of per-image standard deviations*,
# not the standard deviation of all pixels pooled together. Likewise the mean
# is the average of per-image means, which equals the pooled pixel mean only
# when all images have the same size.

# Reference usage of the resulting statistics (ImageNet values shown):
# trans = transforms.Compose([
#     transforms.Resize((256, 256)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])

toTensor = transforms.ToTensor()
root = './imgs/'

# Print enough digits that the values can be copied into Normalize(...).
torch.set_printoptions(precision=10)


def get_file_names(directory):
    """Return the names of the regular files directly inside *directory*.

    Sub-directories are skipped and the listing is not recursive. Names are
    returned without the directory prefix, in the order ``os.listdir`` yields
    them.
    """
    return [
        file_name
        for file_name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, file_name))
    ]


filenames = get_file_names(root)

# Running sums of the per-image, per-channel statistics (R, G, B).
mean = torch.zeros(3)
std = torch.zeros(3)
# Presumably the output of an earlier run on the author's dataset:
# tensor([0.4526, 0.4316, 0.3995]) tensor([0.2419, 0.2364, 0.2406])
count = 0

for file in filenames:
    imgname = os.path.join(root, file)
    try:
        # The context manager closes the underlying file handle; convert()
        # forces the pixel data to load and guarantees exactly three channels,
        # so grayscale / palette / RGBA inputs no longer break the indexing.
        with Image.open(imgname) as image:
            tensor = toTensor(image.convert('RGB'))
    except OSError as err:
        # Pillow signals unreadable or non-image files with OSError
        # subclasses; skip them instead of aborting the whole run.
        print(f'skipping {imgname}: {err}', file=sys.stderr)
        continue

    # Reduce over height and width in one call; result has one value per channel.
    mean += tensor.mean(dim=(1, 2))
    std += tensor.std(dim=(1, 2))
    count += 1

if count == 0:
    # Without this guard the division below would silently print NaNs.
    raise SystemExit(f'no readable images found in {root!r}')

print(mean / count, std / count)