Batch Normalization and Batch Renormalization: Detailed Derivation of the Forward and Backward Formulas


I. BN Forward Propagation

Following the derivation in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", there are four main formulas:

\[\mu_B=\frac{1}{m}\sum_i^mx_i\tag{1}\label{1} \]

\[\delta_B^2=\frac{1}{m}\sum_i^m(x_i-\mu_B)^2\tag{2}\label{2} \]

\[\widehat{x_i}=\frac{x_i-\mu_B}{\sqrt{\delta_B^2+\epsilon}}\tag{3}\label{3} \]

\[y_i=\gamma\widehat{x_i}+\beta\tag{4}\label{4} \]

Take an MLP as an example. Suppose the input mini-batch contains \(m\) samples; then \(x_i,\ i=1,2,...,m\) here is a single activation value in some layer's activations for the \(i\)-th sample. In other words, when \(m\) samples are fed through the network in one training step and the \(i\)-th sample produces \(N\) activation units at layer \(l\), \(x_i\) denotes any one of those units. Strictly speaking, writing it as \(x_i^l(n)\) would be more explicit.

So BN simply takes the \(n\)-th activation unit of layer \(l\), \(x_i^l(n)\), computes its mean and variance over a batch, and standardizes it to obtain \(\widehat{x_i^l(n)}\). The \(m\) normalized activations then have zero mean and unit variance, which to some extent removes the Internal Covariate Shift, i.e. it reduces how much the marginal distribution of each layer's activations over the training samples changes during training.
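To make the per-unit view concrete, here is a minimal NumPy sketch of the forward pass over a mini-batch (a sketch only: the name `bn_forward` and the `(m, N)` layout are assumptions for illustration, not code from the paper):

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    """x: (m, N) mini-batch, one row per sample, one column per activation unit."""
    mu = x.mean(axis=0)                      # (1): per-unit batch mean
    var = x.var(axis=0)                      # (2): per-unit batch variance (divides by m)
    x_hat = (x - mu) / np.sqrt(var + eps)    # (3): standardize
    y = gamma * x_hat + beta                 # (4): scale and shift
    cache = (x_hat, mu, var, gamma, eps)     # kept for the backward pass
    return y, cache
```

Each column gets its own batch mean and variance, so after step (3) every activation unit has approximately zero mean and unit variance within the batch.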

II. BN Backward Propagation

  • Let the upstream gradient (flowing back from the following layer) be \(\frac{\partial{L}}{\partial{y_i}}\).
  • We need to compute \(\frac{\partial{L}}{\partial{x_i}}\), \(\frac{\partial{L}}{\partial{\gamma}}\), and \(\frac{\partial{L}}{\partial{\beta}}\).

By the chain rule and equation \(\eqref{4}\):

\[\frac{\partial{L}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\gamma}}=\frac{\partial{L}}{\partial{y_i}}\widehat{x_i} \tag{5} \]

Since \(\frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\) contributes to \(\frac{\partial{L}}{\partial{\gamma}}\) for every \(i=1,2,...,m\), over one batch of training \(\frac{\partial{L}}{\partial{\gamma}}\) is defined as:

\[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{6}\label{6} \]

Similarly:

\[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{7}\label{7} \]

Computing \(\frac{\partial{L}}{\partial{x_i}}\), on the other hand, is more involved. By the chain rule and equation \(\eqref{3}\), viewing \(\widehat{x_i}\) as \(g(x_i,\delta_B^2,\mu_B)\), we have:

\[\frac{\partial{L}}{\partial{x_i}}=\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}(\frac{\partial{\widehat{x_i}}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\delta_B^2}}\frac{\partial{\delta_B^2}}{\partial{x_i}}+\frac{\partial{\widehat{x_i}}}{\partial{\mu_B}}\frac{\partial{\mu_B}}{\partial{x_i}}) =\frac{\partial{L}}{\partial{y_i}}\frac{\partial{y_i}}{\partial{\widehat{x_i}}}(g_1'+g_2'\frac{\partial{\delta_B^2}}{\partial{x_i}}+g_3'\frac{\partial{\mu_B}}{\partial{x_i}}) \tag{8}\label{8} \]

By equation \(\eqref{2}\), the second partial derivative inside the parentheses above can be expanded further (viewing \(\delta_B^2\) as \(f(x_i,\mu_B)\)):

\[\frac{\partial{\delta_B^2}}{\partial{x_i}}= \frac{\partial{\delta_B^2}}{\partial{x_i}}+ \frac{\partial{\delta_B^2}}{\partial{\mu_B}} \frac{\partial{\mu_B}}{\partial{x_i}}= f_1'+f_2'\frac{\partial{\mu_B}}{\partial{x_i}} \tag{9}\label{9} \]

Note that the two occurrences of \(\frac{\partial{\delta_B^2}}{\partial{x_i}}\) in equation \(\eqref{9}\) mean different things. From equations \(\eqref{8}\) and \(\eqref{9}\), once we have \(f_1',f_2',g_1',g_2',g_3',\frac{\partial{\mu_B}}{\partial{x_i}},\frac{\partial{y_i}}{\partial{\widehat{x_i}}}\), we can compute \(\frac{\partial{L}}{\partial{x_i}}\).

The original paper splits equation \(\eqref{8}\) into the following terms:

\[\frac{\partial{L}}{\partial{x_i}}= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\partial{\widehat{x_i}}}{\partial{x_i}}+ \frac{\partial{L}}{\partial{\delta_B^2}} \frac{\partial{\delta_B^2}}{\partial{x_i}}+ \frac{\partial{L}}{\partial{\mu_B}} \frac{\partial{\mu_B}}{\partial{x_i}} \tag{10}\label{10} \]

where:

\[\frac{\partial{L}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \frac{\partial{y_i}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \gamma\tag{10.1}\label{10.1} \]

\[\frac{\partial{\widehat{x_i}}}{\partial{x_i}}=g'_1=\frac{1}{\sqrt{\delta_B^2+\epsilon}} \tag{10.2}\label{10.2} \]

\[\frac{\partial{L}}{\partial{\delta_B^2}}= \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \longrightarrow \sum_{i=1}^m\frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{\mu_B-x_i}{2}(\delta_B^2+\epsilon)^{-\frac{3}{2}} \tag{10.3}\label{10.3} \]

\[\frac{\partial{\delta_B^2}}{\partial{x_i}}=f'_1=\frac{2(x_i-\mu_B)}{m} \tag{10.4}\label{10.4} \]

\[\frac{\partial{L}}{\partial{\mu_B}}= \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_3+ \frac{\partial{L}}{\partial{\widehat{x_i}}}g'_2f'_2 \longrightarrow \sum_{i=1}^m( \frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-1}{\sqrt{\delta_B^2+\epsilon}} +\frac{\partial{L}}{\partial{\delta_B^2}}\frac{2(\mu_B-x_i)}{m}) \tag{10.5}\label{10.5} \]

\[\frac{\partial{\mu_B}}{\partial{x_i}}=\frac{1}{m} \tag{10.6}\label{10.6} \]

The BN backward pass is thus fully specified by equations \(\eqref{6}\), \(\eqref{7}\), and \(\eqref{10}\).
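Continuing the sketch above, equations \(\eqref{6}\), \(\eqref{7}\) and \(\eqref{10}\) translate into NumPy roughly as follows (again a sketch; `bn_backward` is a hypothetical name and the cache layout is the one assumed in the forward sketch):

```python
def bn_backward(dy, cache):
    """dy: upstream gradient dL/dy of shape (m, N)."""
    x_hat, mu, var, gamma, eps = cache
    m = dy.shape[0]
    inv_std = 1.0 / np.sqrt(var + eps)
    x_minus_mu = x_hat / inv_std                          # recover x_i - mu_B

    dgamma = np.sum(dy * x_hat, axis=0)                   # (6)
    dbeta = np.sum(dy, axis=0)                            # (7)

    dx_hat = dy * gamma                                   # (10.1)
    dvar = np.sum(dx_hat * -0.5 * x_minus_mu * inv_std**3, axis=0)    # (10.3)
    dmu = (np.sum(-dx_hat * inv_std, axis=0)
           + dvar * np.mean(-2.0 * x_minus_mu, axis=0))   # (10.5)
    dx = dx_hat * inv_std + dvar * 2.0 * x_minus_mu / m + dmu / m     # (10)
    return dx, dgamma, dbeta
```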

III. Batch Renormalization

Reference: the paper "Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models".

Batch Renormalization is an improvement over standard BN: it makes the train and inference computations equivalent and alleviates the problems caused by non-i.i.d. data and small minibatches.

1. Forward

The formulas are similar to BN's, with two additional non-trainable parameters \(r\) and \(d\):

\[\mu_B=\frac{1}{m}\sum_i^mx_i\tag{1.1}\label{1.1} \]

\[\sigma_B=\sqrt{\epsilon+\frac{1}{m}\sum_i^m(x_i-\mu_B)^2}\tag{2.1}\label{2.1} \]

\[\widehat{x_i}=\frac{x_i-\mu_B}{\sigma_B}r+d\tag{3.1}\label{3.1} \]

\[y_i=\gamma\widehat{x_i}+\beta\tag{4.1}\label{4.1} \]

\[r=Stop\_Gradient(Clip_{[1/r_{max} ,r_{max}]}(\frac{\sigma_B}{\sigma}))\tag{5.1}\label{5.1} \]

\[d=Stop\_Gradient(Clip_{[-d_{max} ,d_{max}]}(\frac{\mu_B-\mu}{\sigma}))\tag{6.1}\label{6.1} \]


Update moving averages:

\[\mu:=\mu+\alpha(\mu_B-\mu)\tag{7.1}\label{7.1} \]

\[\sigma:=\sigma+\alpha(\sigma_B-\sigma)\tag{8.1}\label{8.1} \]

Inference:

\[y=\gamma\frac{x-\mu}{\sigma}+\beta\tag{9.1}\label{9.1} \]

Whereas the earlier BN only accumulates the moving mean and variance during training and uses them only at inference, BRN uses the moving mean and variance during both training and inference.
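As a rough sketch of how these training-time steps fit together in NumPy (the function name, the use of `np.clip`, and the update rate `alpha` are assumptions; in a framework with autograd, `r` and `d` would be wrapped in a stop-gradient, while in plain NumPy they are simply treated as constants in the backward pass):

```python
def brn_forward_train(x, gamma, beta, mu_run, sigma_run,
                      r_max=3.0, d_max=5.0, alpha=0.01, eps=1e-5):
    """x: (m, N). mu_run / sigma_run: moving mean and moving std, updated in place."""
    mu_b = x.mean(axis=0)                                      # (1.1)
    sigma_b = np.sqrt(x.var(axis=0) + eps)                     # (2.1)

    # (5.1), (6.1): r and d are constants w.r.t. the gradient (stop-gradient)
    r = np.clip(sigma_b / sigma_run, 1.0 / r_max, r_max)
    d = np.clip((mu_b - mu_run) / sigma_run, -d_max, d_max)

    x_hat = (x - mu_b) / sigma_b * r + d                       # (3.1)
    y = gamma * x_hat + beta                                    # (4.1)

    mu_run += alpha * (mu_b - mu_run)                           # (7.1)
    sigma_run += alpha * (sigma_b - sigma_run)                  # (8.1)

    cache = (x_hat, x - mu_b, sigma_b, r, gamma)                # kept for the backward pass
    return y, cache
```

At inference, equation \(\eqref{9.1}\) is applied directly with the moving statistics: `y = gamma * (x - mu_run) / sigma_run + beta`.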

2. Backward

The backward derivation is analogous to that of BN:

\[\frac{\partial{L}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \frac{\partial{y_i}}{\partial{\widehat{x_i}}}= \frac{\partial{L}}{\partial{y_i}} \gamma\tag{10.11}\label{10.11} \]

\[\frac{\partial{L}}{\partial{\sigma_B}} \longrightarrow\sum_{i=1}^m \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{-r(x_i-\mu_B)}{\sigma_B^2} \tag{10.22}\label{10.22} \]

\[\frac{\partial{L}}{\partial{\mu_B}}\longrightarrow\sum_{i=1}^{m}\frac{\partial{L}}{\partial{\widehat{x_i}}}\frac{-r}{\sigma_B} \tag{10.33}\label{10.33} \]

\[\frac{\partial{L}}{\partial{x_i}}= \frac{\partial{L}}{\partial{\widehat{x_i}}} \frac{r}{\sigma_B}+ \frac{\partial{L}}{\partial{\sigma_B}} \frac{x_i-\mu_B}{m\sigma_B}+ \frac{\partial{L}}{\partial{\mu_B}} \frac{1}{m} \tag{10.44}\label{10.44} \]

\[\frac{\partial{L}}{\partial{\gamma}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\widehat{x_i}\tag{10.55}\label{10.55} \]

\[\frac{\partial{L}}{\partial{\beta}}=\sum_{i=1}^m \frac{\partial{L}}{\partial{y_i}}\tag{10.66}\label{10.66} \]
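These formulas map almost one-to-one onto code; a sketch reusing the cache returned by the BRN forward sketch above (`brn_backward` is a hypothetical name):

```python
def brn_backward(dy, cache):
    """dy: upstream gradient dL/dy of shape (m, N)."""
    x_hat, x_minus_mu, sigma_b, r, gamma = cache
    m = dy.shape[0]

    dx_hat = dy * gamma                                              # (10.11)
    dsigma = np.sum(dx_hat * -r * x_minus_mu / sigma_b**2, axis=0)   # (10.22)
    dmu = np.sum(dx_hat * -r / sigma_b, axis=0)                      # (10.33)
    dx = (dx_hat * r / sigma_b
          + dsigma * x_minus_mu / (m * sigma_b)
          + dmu / m)                                                 # (10.44)
    dgamma = np.sum(dy * x_hat, axis=0)                              # (10.55)
    dbeta = np.sum(dy, axis=0)                                       # (10.66)
    return dx, dgamma, dbeta
```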

IV. BN in Convolutional Networks

The derivations above are all for an MLP. In a convolutional network, the \(m\) activation units in the BN procedure are generalized to values of the feature maps. Suppose the feature map after some convolution is a tensor of shape \([N,H,W,C]\), where N is the batch size, H and W are the height and width, and C is the number of channels. For BN in a convolutional network we set \(m = N\times H\times W\): every feature-map pixel on a given channel \(c\), over all spatial positions and all samples in the batch, is treated as an \(x_i\), and the BN formulas above apply unchanged. Consequently, each channel of an intermediate activation layer has its own pair of BN parameters \(\gamma,\beta\).
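A sketch of the per-channel statistics for an NHWC tensor (layout and names assumed for illustration):

```python
def bn_forward_conv(x, gamma, beta, eps=1e-5):
    """x: (N, H, W, C); gamma, beta: (C,). Statistics are per channel, over m = N*H*W values."""
    mu = x.mean(axis=(0, 1, 2))                  # shape (C,)
    var = x.var(axis=(0, 1, 2))                  # shape (C,)
    x_hat = (x - mu) / np.sqrt(var + eps)        # broadcasts over the channel axis
    return gamma * x_hat + beta
```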

posted @ 2020-04-10 22:09 love小酒窝