Batch Normalization
前文讲到 Batch Normalization 可以有效的降低梯度消失和梯度爆炸的发生。本文就对 Batch Normalization 进行详细的介绍。首先从两个角度介绍 Batch Normalization 的作用。
a、Internal Covariate Shift
在机器学习中要求"应用数据集的分布"需要与"训练的数据集分布一致", 如果不一致的话便会出现 Covariate Shift 现象,体现在预测的结果不是很准。为了降低 Covariate Shift 的影响,模型一般会对输入数据进行标准化处理。我们想另外一个问题:对于神经网络来讲,每一层都是下一层的输入,如果当前层的输出改变了,就意味这下一层的输入分布发生变化了。这叫做Internal Covariate Shift。为了让每一层每次都是在学习同样的分布,可以进行 Batch Normalization,这样可以加速训练的速度。
b、梯度消失梯度爆炸
Batch Normalization 主要是针对梯度消失的。以 sigmoid 函数为例,输入只在[-1,1]间梯度比较明显,而在这个区间之外梯度非常非常小,这样很容易导致梯度消失。为了减轻这种效果的影响,可以进行 Batch Normalization。
Batch Normalization 过程如以下图:
整个算法流程很好理解。对于每一次 min-batch 的输入首先进行标准化,这一步的目的是让每次的 min-batch 符合同一分布。需要注意的是在标准化之后增加了一个放缩和偏移调整。我的理解是对标准的化的输入进行微调来更好的拟合激活函数和目标 label。