[基础]BatchNormalization及其在mmdet中的使用
Batch Normlization
NOTES: 下面关于BN概念的介绍基本转载自 (1)
什么是 BN ?
作用:类似于在输入的时候对输入数据进行零均值化和方差归一化的操作,只是 BN 是发生在网络的中间层
算法流程:
- 注意流程中的最后一步也称之为仿射(affine),引入这一步的目的主要是设计一个通道,使得输出output至少能够回到输入input的状态(当γ = 1 , β = 0 时)使得BN的引入至少不至于降低模型的表现,这是深度网络设计的一个套路。
- $\gamma $ 和 $\beta $ 是网络学习的参数,在 pytorch 中用
weights
和bias
表示 - \(\mu_{\mathcal{B}}\) 和 \(\sigma^2_{\mathcal{B}}\) 是 batch 的统计特性,严格来说不算“学习” 到的参数。在 pytorch 中,这两个统计参数,用
running_mean
和running_var
表示,这里的running
指的是当前的统计参数不一定只是由当前输入的batch决定,还可能和历史输入的batch有关。
在 pytorch 中使用
torch.nn.BatchNorm1d(num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)
一般来说pytorch中的模型都是继承 nn.Module
类的,都有一个属性trainning
指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用Module.train()
指定当前模型为训练状态,Module.eval()
指定当前模型为测试状态。
- 延伸:如果模型中有BN层 或是 Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train() 是保证BN层用每一批数据的均值和方差,而model.eval() 是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而model.eval() 是利用到了所有网络连接。 (5)
BN的API中比较重要的参数,一个是affine
指定是否需要仿射,还有个是track_running_stats
指定是否跟踪当前batch的统计特性。
容易出现问题也正好是这三个参数:trainning
,affine
,track_running_stats
。
一般来说,trainning
和track_running_stats
有四种组合[7]
trainning=True
,track_running_stats=True
。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。trainning=True
,track_running_stats=False
。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。trainning=False
,track_running_stats=True
。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_mean
和running_var
并且不会对其进行更新。一般来说,只需要设置model.eval()
其中model
中含有BN层,即可实现这个功能。trainning=False
,track_running_stats=False
效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。
同时,我们要注意到,BN层中的running_mean
和running_var
的更新是在forward()
操作中进行的,而不是optimizer.step()
中进行的,因此如果处于训练状态,就算你不进行手动step()
,BN的统计特性也会变化的。
实际应用
实际应用中,BN 层可能学习到的参数非常不稳定,不利于最终的性能。
RetinaNet 在训练过程冻结了BN层 (2), (3),因为 batch size 太小了,这时候使用 BN 学习到的参数难以稳定
在 mmdetection
中,大部分 detector 的配置文件中,通过设置 norm_eval=True
使得在训练过程冻结 BN 层,而在推理过程时使用 BN 层 。
在训练过程冻结 BN 层是什么意思?
- 网络不会更新 \(\gamma\) 和 \(\beta\) 参数。
- 同时,根据冻结的程度,决定是否会计算
running_mean
和running_var
- 在
mmdetection
中,冻结是通过训练时挑出BN层,将这些层设置为eval
状态。这样算是 fully freeze,即 mean\var 也不会计算。 (4) - 也有的冻结方式只是将 BN 层的 learnable parameters 设置
requires_grad = False
- 在
在推理过程时使用 BN 层 是什么意思?
- 应该是用训练时计算的
running_mean
和running_var
,而 weights, bias 是 None
但是如果用的是多GPU训练,且 BN 层使用的是 SyncBN
而不是 常规的 BN
, 就会使 norm_eval=False
,即训练过程不会冻结 BN 层。
参考资料
[1] https://blog.csdn.net/LoseInVain/article/details/86476010
[2] https://github.com/yhenon/pytorch-retinanet/issues/24
[3] https://github.com/kuangliu/pytorch-retinanet/issues/18
[4] https://www.jianshu.com/p/142e2ab879d3
[5] http://blog.sciencenet.cn/blog-3428464-1252019.html
[6] https://yichengsu.github.io/2019/12/pytorch-batchnorm-freeze/