[基础]BatchNormalization及其在mmdet中的使用

Batch Normlization

NOTES: 下面关于BN概念的介绍基本转载自 (1)

什么是 BN ?

作用:类似于在输入的时候对输入数据进行零均值化和方差归一化的操作,只是 BN 是发生在网络的中间层

算法流程:

  • 注意流程中的最后一步也称之为仿射(affine),引入这一步的目的主要是设计一个通道,使得输出output至少能够回到输入input的状态(当γ = 1 , β = 0 时)使得BN的引入至少不至于降低模型的表现,这是深度网络设计的一个套路。
  • $\gamma $ 和 $\beta $ 是网络学习的参数,在 pytorch 中用 weightsbias 表示
  • \(\mu_{\mathcal{B}}\)\(\sigma^2_{\mathcal{B}}\) 是 batch 的统计特性,严格来说不算“学习” 到的参数。在 pytorch 中,这两个统计参数,用running_meanrunning_var表示,这里的 running 指的是当前的统计参数不一定只是由当前输入的batch决定,还可能和历史输入的batch有关。

在 pytorch 中使用

torch.nn.BatchNorm1d(num_features, 
                     eps=1e-05, 
                     momentum=0.1, 
                     affine=True, 
                     track_running_stats=True)

一般来说pytorch中的模型都是继承 nn.Module 类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用Module.train()指定当前模型为训练状态,Module.eval()指定当前模型为测试状态。

  • 延伸:如果模型中有BN层 或是 Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train() 是保证BN层用每一批数据的均值和方差,而model.eval() 是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而model.eval() 是利用到了所有网络连接。 (5)

BN的API中比较重要的参数,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。

容易出现问题也正好是这三个参数:trainningaffinetrack_running_stats

一般来说,trainningtrack_running_stats有四种组合[7]

  1. trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
  2. trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
  3. trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_meanrunning_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。
  4. trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。

同时,我们要注意到,BN层中的running_meanrunning_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。

实际应用

实际应用中,BN 层可能学习到的参数非常不稳定,不利于最终的性能。

RetinaNet 在训练过程冻结了BN层 (2), (3),因为 batch size 太小了,这时候使用 BN 学习到的参数难以稳定

mmdetection 中,大部分 detector 的配置文件中,通过设置 norm_eval=True 使得在训练过程冻结 BN 层,而在推理过程时使用 BN 层 。

在训练过程冻结 BN 层是什么意思?

  • 网络不会更新 \(\gamma\)\(\beta\) 参数。
  • 同时,根据冻结的程度,决定是否会计算running_meanrunning_var
    • mmdetection 中,冻结是通过训练时挑出BN层,将这些层设置为 eval 状态。这样算是 fully freeze,即 mean\var 也不会计算。 (4)
    • 也有的冻结方式只是将 BN 层的 learnable parameters 设置 requires_grad = False

在推理过程时使用 BN 层 是什么意思?

  • 应该是用训练时计算的 running_meanrunning_var ,而 weights, bias 是 None

但是如果用的是多GPU训练,且 BN 层使用的是 SyncBN 而不是 常规的 BN, 就会使 norm_eval=False ,即训练过程不会冻结 BN 层。

参考资料

[1] https://blog.csdn.net/LoseInVain/article/details/86476010

[2] https://github.com/yhenon/pytorch-retinanet/issues/24

[3] https://github.com/kuangliu/pytorch-retinanet/issues/18

[4] https://www.jianshu.com/p/142e2ab879d3

[5] http://blog.sciencenet.cn/blog-3428464-1252019.html

[6] https://yichengsu.github.io/2019/12/pytorch-batchnorm-freeze/

posted @ 2021-02-04 17:33  lunaY  阅读(1101)  评论(0编辑  收藏  举报