【PyTorch】Normalization
一图胜千言
以四维 N x C x H x W (批量大小 x 通道数 x 高 x 宽)为例,更形象一点,有 N 张图片,每张图片有 C 个通道,每个通道的大小为 H x W。
相同的颜色代表同一个计算均值和方差的区域。
以 Batch Norm 为例,粉色区域:[1, 1, *, *](即第一张图片的第一个通道的大小为H x W 的区域),[2, 1, *, *],…… ,[N, 1, *, *]。利用这 N 个粉色区域的数去计算均值和方差,然后将计算得到的均值和方差作用到在每个粉色区域的数上,就完成了标准化。
再理解
如果上面的图看懂了,那么下面这个常见(但难以理解)的图也就懂了。
代码
1 import torch
2 import math
3
4
5 def manual_fun(tensor):
6 mean = tensor.sum() / tensor.numel()
7 var = ((tensor - mean) * (tensor - mean)).sum() / tensor.numel()
8 new_tensor = (tensor - mean) / (math.sqrt(var + 1e-5))
9 return new_tensor
10
11
12 l = torch.tensor(
13 [[[[11, 2, 3], [4, 57, 6], [7, 8, 9]], [[1, 2, -3], [-4, 5, 6], [7, 8, 9]], [[1, 92, 3], [4, -95, 6], [7, 18, 9]]],
14 [[[46, 7, 8], [4, 66, 7], [7, 8, 9]], [[100, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 6, 6], [7, 8, 9]]]],
15 dtype=torch.float)
16 print(l.shape, l, sep='\n')
17 auto_tensor = torch.nn.InstanceNorm2d(3, momentum=1)(l)
18 for i in range(0, 2):
19 for j in range(0, 3):
20 print('manual: ', manual_fun(l[i, j, :, :]))
21 print('auto: ', auto_tensor[i, j, :, :])
最后
关于 Normalization ,思考了很久,网上的资料也乱七八糟(我菜,所以就自己整理下吧。
把第一张图看懂了,然后结合 PyTorch 官方代码和我的样例,还不会你来扇我(:。
第一张图片参考于NJU1healer 的博客