Batch Normalization and Batch Renormalization: Detailed Derivation of the Forward and Backward Passes
I. BN Forward Propagation
Following the derivation in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the forward pass consists of four formulas:
\[\mu_B=\frac{1}{m}\sum_i^mx_i\tag{1}\label{1}
\]
\[\sigma_B^2=\frac{1}{m}\sum_i^m(x_i-\mu_B)^2\tag{2}\label{2}
\]
\[\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}\tag{3}\label{3}
\]
\[y_i=\gamma\widehat{x_i}+\beta\tag{4}\label{4}
\]
Take an MLP as an example and suppose the mini-batch contains \(m\) samples. Here \(x_i, i=1,2,\ldots,m\) denotes one particular activation value at some layer for the \(i\)-th sample. In other words, if the \(i\)-th of the \(m\) input samples produces \(N\) activation units at layer \(l\), then \(x_i\) stands for any one of those units; writing it as \(x_i^l(n)\) would be more explicit.
So BN takes the \(n\)-th activation unit \(x_i^l(n)\) of layer \(l\), computes its mean and variance across the batch, and standardizes it to obtain \(\widehat{x_i^l(n)}\). The \(m\) normalized activations then have zero mean and unit variance, which mitigates internal covariate shift to some extent by reducing how much the marginal distribution of each layer's activations shifts over the training samples.
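To make formulas \eqref{1} through \eqref{4} concrete, here is a minimal NumPy sketch of the forward pass for one MLP layer; the function name, the \((m, N)\) input layout, and the cache contents are illustrative choices for this note rather than anything prescribed by the paper:

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """BN forward over a mini-batch x of shape (m, N): m samples, N activation units."""
    mu = x.mean(axis=0)                      # eq (1): per-unit batch mean
    var = x.var(axis=0)                      # eq (2): per-unit batch variance (biased, 1/m)
    x_hat = (x - mu) / np.sqrt(var + eps)    # eq (3): standardize each unit
    y = gamma * x_hat + beta                 # eq (4): learned scale and shift
    cache = (x, x_hat, mu, var, gamma, eps)  # saved for the backward pass
    return y, cache
```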
II. BN Backward Propagation
- Let the upstream gradient (flowing back from the next layer) be \(\frac{\partial{L}}{\partial{y_i}}\).
- We need to compute \(\frac{\partial{L}}{\partial{x_i}}\), \(\frac{\partial{L}}{\partial{\gamma}}\), and \(\frac{\partial{L}}{\partial{\beta}}\).
By the chain rule and formula \eqref{4}:
\[\frac{\partial{L}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{5}
\]
Since \(\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\) contributes to \(\frac{\partial{L}}{\partial{\gamma}}\) for every \(i=1,2,\ldots,m\), over one batch \(\frac{\partial{L}}{\partial{\gamma}}\) is defined as:
\[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{6}\label{6}
\]
Similarly:
\[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{7}\label{7}
\]
Computing \(\frac{\partial{L}}{\partial{x_i}}\) is more involved. By the chain rule and formula \eqref{3}, viewing \(\widehat{x_i}\) as \(g(x_i,\sigma_B^2,\mu_B)\):
\[\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\sigma_B^2}}\frac{\partial{\sigma_B^2}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}\right)
=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\left(g_1'+g_2'\frac{\partial{\sigma_B^2}}{\partial{x_i}}+g_3'\frac{\partial{\mu_B}}{\partial{x_i}}\right)
\tag{8}\label{8}
\]
By formula \eqref{2}, the partial derivative in the second term in parentheses can be decomposed further, viewing \(\sigma_B^2\) as \(f(x_i,\mu_B)\):
\[\frac{\partial{\sigma_B^2}}{\partial{x_i}}=
\frac{\partial{\sigma_B^2}}{\partial{x_i}}+
\frac{\partial{\sigma_B^2}}{\partial{\mu_B}}
\frac{\partial{\mu_B}}{\partial{x_i}}=
f_1'+f_2'\frac{\partial{\mu_B}}{\partial{x_i}}
\tag{9}\label{9}
\]
Note that the two occurrences of \(\frac{\partial{\sigma_B^2}}{\partial{x_i}}\) in formula \eqref{9} mean different things: the left-hand side is the total derivative through \(\mu_B\), while the first term on the right holds \(\mu_B\) fixed. By formulas \eqref{8} and \eqref{9}, once we know \(f_1',f_2',g_1',g_2',g_3',\frac{\partial{\mu_B}}{\partial{x_i}},\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\), we can compute \(\frac{\partial{L}}{\partial{x_i}}\).
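Since these pieces are easy to get wrong by hand, the partials \(g_1', g_2', g_3'\) can be checked symbolically; a quick sketch with sympy, treating the batch variance as a free symbol \(v\) (an assumption made purely for the check):

```python
import sympy as sp

# treat x, mu, v (= sigma_B^2) and epsilon as independent symbols
x, mu, v, eps = sp.symbols('x mu v epsilon', positive=True)
x_hat = (x - mu) / sp.sqrt(v + eps)

g1 = sp.diff(x_hat, x)   # expected: (v + eps)**(-1/2)
g2 = sp.diff(x_hat, v)   # expected: (mu - x)/2 * (v + eps)**(-3/2)
g3 = sp.diff(x_hat, mu)  # expected: -(v + eps)**(-1/2)

assert sp.simplify(g1 - (v + eps)**sp.Rational(-1, 2)) == 0
assert sp.simplify(g2 - (mu - x)/2 * (v + eps)**sp.Rational(-3, 2)) == 0
assert sp.simplify(g3 + (v + eps)**sp.Rational(-1, 2)) == 0
```

These match the expressions \(g_1', g_2', g_3'\) used in formulas \eqref{10.2}, \eqref{10.3}, and \eqref{10.5} below.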
The original paper splits formula \eqref{8} into the following terms:
\[\frac{\partial{L}}{\partial{x_i}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+
\frac{\partial{L}}{\partial{\sigma_B^2}}
\frac{\partial{\sigma_B^2}}{\partial{x_i}}+
\frac{\partial{L}}{\partial{\mu_B}}
\frac{\partial{\mu_B}}{\partial{x_i}}
\tag{10}\label{10}
\]
where (the arrow \(\longrightarrow\) denotes summing the per-sample contributions over the batch):
\[\frac{\partial{L}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}
\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}
\gamma\tag{10.1}\label{10.1}
\]
\[\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=g'_1=\frac{1}{\sqrt{\sigma_B^2+\epsilon}}
\tag{10.2}\label{10.2}
\]
\[\frac{\partial{L}}{\partial{\sigma_B^2}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2=
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{\mu_B-x_i}{2}(\sigma_B^2+\epsilon)^{-\frac{3}{2}}
\longrightarrow
\sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{\mu_B-x_i}{2}(\sigma_B^2+\epsilon)^{-\frac{3}{2}}
\tag{10.3}\label{10.3}
\]
\[\frac{\partial{\sigma_B^2}}{\partial{x_i}}=f'_1=\frac{2(x_i-\mu_B)}{m}
\tag{10.4}\label{10.4}
\]
\[\frac{\partial{L}}{\partial{\mu_B}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}g'_3+
\frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2f'_2
\longrightarrow
\sum_{i=1}^m\left(
\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-1}{\sqrt{\sigma_B^2+\epsilon}}
+\frac{\partial{L}}{\partial{\sigma_B^2}}\frac{2(\mu_B-x_i)}{m}\right)
\tag{10.5}\label{10.5}
\]
\[\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m}
\tag{10.6}\label{10.6}
\]
The complete BN backward pass is thus given by formulas \eqref{6}, \eqref{7}, and \eqref{10}.
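Collecting everything, a minimal NumPy sketch of the backward pass, paired with the hypothetical bn_forward above:

```python
def bn_backward(dy, cache):
    """BN backward; dy is dL/dy with the same (m, N) shape as the layer input."""
    x, x_hat, mu, var, gamma, eps = cache
    m = x.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)

    dgamma = np.sum(dy * x_hat, axis=0)                        # eq (6)
    dbeta = np.sum(dy, axis=0)                                 # eq (7)

    dx_hat = dy * gamma                                        # eq (10.1)
    dvar = np.sum(dx_hat * (mu - x) / 2 * inv_std**3, axis=0)  # eq (10.3)
    # eq (10.5); the second term sums to zero since sum_i (x_i - mu) = 0,
    # but it is kept here to mirror the formula
    dmu = np.sum(dx_hat * -inv_std, axis=0) + dvar * np.sum(2 * (mu - x) / m, axis=0)
    dx = dx_hat * inv_std + dvar * 2 * (x - mu) / m + dmu / m  # eq (10) via (10.2), (10.4), (10.6)
    return dx, dgamma, dbeta
```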
III. Batch Renormalization
Reference: "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models".
Batch Renormalization is a refinement of standard BN: it keeps the training-time and inference-time transformations equivalent, addressing the problems caused by non-i.i.d. minibatches and small minibatch sizes.
1. Forward
The formulas resemble BN's, with two additional non-trainable parameters \(r\) and \(d\), where \(\mu\) and \(\sigma\) are the moving averages updated below:
\[\mu_B=\frac{1}{m}\sum_i^mx_i\tag{1.1}\label{1.1}
\]
\[\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_i^m(x_i-\mu_B)^2}\tag{2.1}\label{2.1}
\]
\[\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d\tag{3.1}\label{3.1}
\]
\[y_i=\gamma\widehat{x_i}+\beta\tag{4.1}\label{4.1}
\]
\[r=Stop\_Gradient(Clip_{[1/r_{max},r_{max}]}(\frac{\sigma_B}{\sigma}))\tag{5.1}\label{5.1}
\]
\[d=Stop\_Gradient(Clip_{[-d_{max},d_{max}]}(\frac{\mu_B-\mu}{\sigma}))\tag{6.1}\label{6.1}
\]
Update moving averages:
\[\mu:=\mu+\alpha(\mu_B-\mu)\tag{7.1}\label{7.1}
\]
\[\sigma:=\sigma+\alpha(\sigma_B-\sigma)\tag{8.1}\label{8.1}
\]
Inference:
\[y=\gamma\frac{x-\mu}{\sigma}+\beta\tag{9.1}\label{9.1}
\]
Whereas BN only maintains the moving mean and standard deviation during training and uses them solely at inference, BRN uses the moving statistics during both training and inference.
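Below is a minimal NumPy sketch of the training-time forward pass and the inference rule. The stop-gradient of formulas \eqref{5.1} and \eqref{6.1} is implicit here: r and d are plain arrays that the manual backward pass further down simply treats as constants. The function names and the state layout are assumptions for this example:

```python
import numpy as np

def brn_forward_train(x, gamma, beta, state, r_max, d_max, alpha, eps=1e-5):
    """BRN training-time forward per eqs (1.1)-(8.1); state = [mu, sigma] moving averages."""
    mu_run, sigma_run = state
    mu_b = x.mean(axis=0)                        # eq (1.1)
    sigma_b = np.sqrt(x.var(axis=0) + eps)       # eq (2.1)

    # eqs (5.1), (6.1): clipped correction factors; no gradient flows through them
    r = np.clip(sigma_b / sigma_run, 1.0 / r_max, r_max)
    d = np.clip((mu_b - mu_run) / sigma_run, -d_max, d_max)

    x_hat = (x - mu_b) / sigma_b * r + d         # eq (3.1)
    y = gamma * x_hat + beta                     # eq (4.1)

    state[0] = mu_run + alpha * (mu_b - mu_run)           # eq (7.1)
    state[1] = sigma_run + alpha * (sigma_b - sigma_run)  # eq (8.1)
    return y, (x, x_hat, mu_b, sigma_b, r, gamma)

def brn_inference(x, gamma, beta, state):
    """Eq (9.1): inference uses the moving statistics only."""
    mu_run, sigma_run = state
    return gamma * (x - mu_run) / sigma_run + beta
```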
2. Backward
The backward derivation parallels that of BN:
\[\frac{\partial{L}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}
\frac{\partial{y_i}}{\partial{\widehat{x_i}}}=
\frac{\partial{L}}{\partial{y_i}}
\gamma\tag{10.11}\label{10.11}
\]
\[\frac{\partial{L}}{\partial{\sigma_B}}
\longrightarrow\sum_{i=1}^m
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{-r(x_i-\mu_B)}{\sigma_B^2}
\tag{10.22}\label{10.22}
\]
\[\frac{\partial{L}}{\partial{\mu_B}}\longrightarrow\sum_{i=1}^{m}\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r}{\sigma_B}
\tag{10.33}\label{10.33}
\]
\[\frac{\partial{L}}{\partial{x_i}}=
\frac{\partial{L}}{\partial{\widehat{x_i}}}
\frac{r}{\sigma_B}+
\frac{\partial{L}}{\partial{\sigma_B}}
\frac{x_i-\mu_B}{m\sigma_B}+
\frac{\partial{L}}{\partial{\mu_B}}
\frac{1}{m}
\tag{10.44}\label{10.44}
\]
\[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{10.55}\label{10.55}
\]
\[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{10.66}\label{10.66}
\]
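Collecting formulas \eqref{10.11} through \eqref{10.66}, a minimal NumPy sketch paired with the hypothetical brn_forward_train above:

```python
def brn_backward(dy, cache):
    """BRN backward; r enters as a constant, per the stop-gradient in eqs (5.1), (6.1)."""
    x, x_hat, mu_b, sigma_b, r, gamma = cache
    m = x.shape[0]

    dx_hat = dy * gamma                                               # eq (10.11)
    dsigma_b = np.sum(dx_hat * -r * (x - mu_b) / sigma_b**2, axis=0)  # eq (10.22)
    dmu_b = np.sum(dx_hat * -r / sigma_b, axis=0)                     # eq (10.33)
    dx = (dx_hat * r / sigma_b                                        # eq (10.44)
          + dsigma_b * (x - mu_b) / (m * sigma_b)
          + dmu_b / m)
    dgamma = np.sum(dy * x_hat, axis=0)                               # eq (10.55)
    dbeta = np.sum(dy, axis=0)                                        # eq (10.66)
    return dx, dgamma, dbeta
```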
IV. BN in Convolutional Networks
The derivations above assume an MLP. For a convolutional network, the \(m\) activation units in BN generalize to all pixels of the batch's feature maps. Suppose a layer's output feature map is a tensor of shape \([N,H,W,C]\), where \(N\) is the batch size, \(H\) and \(W\) are the height and width, and \(C\) is the number of channels. BN for a convolutional layer sets \(m = N\times H\times W\): every pixel of a given channel \(c\), across all samples and spatial positions, is treated as one \(x_i\), and the BN formulas above apply unchanged. Consequently, each channel of an activation layer gets its own pair of BN parameters \(\gamma,\beta\).
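A minimal sketch of this per-channel reduction for an NHWC tensor (the function name and explicit layout are illustrative; real frameworks fuse and optimize this):

```python
import numpy as np

def bn_forward_conv(x, gamma, beta, eps=1e-5):
    """x: (N, H, W, C); gamma, beta: (C,). Each channel is normalized over its N*H*W values."""
    mu = x.mean(axis=(0, 1, 2))               # per-channel mean, m = N*H*W samples
    var = x.var(axis=(0, 1, 2))               # per-channel variance
    x_hat = (x - mu) / np.sqrt(var + eps)     # broadcasts across (N, H, W)
    return gamma * x_hat + beta               # one (gamma, beta) pair per channel
```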