【PyTorch】Normalization

一图胜千言

  以四维 N x C x H x W (批量大小 x 通道数 x 高 x 宽)为例,更形象一点,有 N 张图片,每张图片有 C 个通道,每个通道的大小为 H x W。

  相同的颜色代表同一个计算均值和方差的区域。

  以 Batch Norm 为例,粉色区域:[1, 1, *, *](即第一张图片的第一个通道的大小为H x W 的区域),[2, 1, *, *],…… ,[N, 1, *, *]。利用这 N 个粉色区域的数去计算均值和方差,然后将计算得到的均值和方差作用到在每个粉色区域的数上,就完成了标准化。

 

再理解

  如果上面的图看懂了,那么下面这个常见(但难以理解)的图也就懂了。

代码

 1 import torch
 2 import math
 3 
 4 
 5 def manual_fun(tensor):
 6     mean = tensor.sum() / tensor.numel()
 7     var = ((tensor - mean) * (tensor - mean)).sum() / tensor.numel()
 8     new_tensor = (tensor - mean) / (math.sqrt(var + 1e-5))
 9     return new_tensor
10 
11 
12 l = torch.tensor(
13     [[[[11, 2, 3], [4, 57, 6], [7, 8, 9]], [[1, 2, -3], [-4, 5, 6], [7, 8, 9]], [[1, 92, 3], [4, -95, 6], [7, 18, 9]]],
14      [[[46, 7, 8], [4, 66, 7], [7, 8, 9]], [[100, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 6, 6], [7, 8, 9]]]],
15     dtype=torch.float)
16 print(l.shape, l, sep='\n')
17 auto_tensor = torch.nn.InstanceNorm2d(3, momentum=1)(l)
18 for i in range(0, 2):
19     for j in range(0, 3):
20         print('manual: ', manual_fun(l[i, j, :, :]))
21         print('auto: ', auto_tensor[i, j, :, :])

最后

  关于 Normalization ,思考了很久,网上的资料也乱七八糟(我菜,所以就自己整理下吧。

  把第一张图看懂了,然后结合 PyTorch 官方代码和我的样例,还不会你来扇我(:。

  第一张图片参考于NJU1healer 的博客

posted @ 2022-03-06 16:04  Vivid-BinGo  阅读(81)  评论(0编辑  收藏  举报