bp算法中为什么会产生梯度消失?

作者:维吉特伯
链接:https://www.zhihu.com/question/49812013/answer/148825073
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

简单地说,根据链式法则,如果每一层神经元对上一层的输出的偏导乘上权重结果都小于1的话(w_{ij}y_{i}'<1.0 ),那么即使这个结果是0.99,在经过足够多层传播之后,误差对输入层的偏导会趋于0(\lim_{n\to\infty}0.99^n=0 )。下面是数学推导推导。

假设网络输出层中的第k 个神经元输出为y_{k}(t),而要学习的目标为d_{k}(t) 。这里的t 表示时序,与输入无关,可以理解为网络的第t 层。


若采用平方误差作为损失函数,第k 个输出神经元对应的损失为 L=\frac{1}{2}(d_{k}(t)-y_{k}(t))^{2}

将损失L 对输出y_{k}(t)求偏导 \vartheta_{k}(t)=\frac{\partial{L}}{\partial{y_{k}(t)}}=y_{k}'(t)(d_{k}(t)-y_{k}(t))

根据链式法则,我们知道,第t-1 层的梯度可以根据第t 层的梯度求出来

\vartheta_{i}(t-1)=y_{i}'(t-1)\sum_{j}w_{ij}\vartheta_{j}(t)

这里用i 表示第t-1 层的第i 个神经元,j 表示第t 层的第j 个神经元。

进一步,第t-q 层的梯度可以由第t-q+1 层的梯度计算出来

\vartheta_{i}(t-q)=y_{i}'(t-q)\sum_{j}w_{ij}\vartheta_{j}(t-q+1)

这实际上是一个递归嵌套的式子,如果我们对\vartheta_{j}(t-q+1) 做进一步展开,可以得到式子

\vartheta_{i}(t-q)=y_{i}'(t-q)\sum_{j}w_{ij}[y_{j}'(t-q+1)\sum_{k}w_{jk}\vartheta_{k}(t-q+2))]

最终,可以一直展开到第t 层。

把所有的加法都移到最外层,可以得到

\vartheta_{i}(t-q)=\sum_{l_{t-q+1}=1}^{n}\cdot\cdot\cdot\sum_{l_{t}=1}^{n}\prod_{m=0}^{q}w_{l_{m}l_{m-1}}\vartheta_{lm}(t-m)

l_{t-q+1} 表示的是第t-q+1 层中神经元的下标(即第t-q+1 层第l_{t-q+1} 个神经元),l_{t} 表示第t 层的下标。m=0 对应输出层,m=q 对应第t-q 层。实际上展开式就是从网络的第t 层到t-q 层,每一层都取出一个神经元来进行排列组合的结果。这个式子并不准确,因为m=0 时实际是损失L 对输出层的偏导,即

\vartheta_{k}(t)=y_{k}'(t)(d_{k}(t)-y_{k}(t))

并没有应用权重w_{l_{m}l_{m-1}},把它修正一下

\vartheta_{i}(t-q)=\sum_{l_{t-q+1}=1}^{n}\cdot\cdot\cdot\sum_{l_{t}=1}^{n}\prod_{m=1}^{q}w_{l_{m}l_{m-1}}y_{lm}'(t-m)\cdot\vartheta_{k}(t)

这样,我们就得到了第t-q 层和第t 层的梯度之间的关系

\frac{\vartheta_{i}(t-q)}{\vartheta_{k}(t)}=\sum_{l_{t-q+1}=1}^{n}\cdot\cdot\cdot\sum_{l_{t}=1}^{n}\prod_{m=1}^{q}w_{l_{m}l_{m-1}}y_{lm}'(t-m)

在上面的式子中,由于加法项正负号之间可能互相抵消。因此,比值的量级主要受最后的乘法项影响。如果对于所有的m

|w_{l_{m}l_{m-1}}y_{lm}'(t-m)|>1.0

则梯度会随着反向传播层数的增加而呈指数增长,导致梯度爆炸。

如果对于所有的m

|w_{l_{m}l_{m-1}}y_{lm}'(t-m)|<1.0

则在经过多层的传播后,梯度会趋向于0,导致梯度消失。

LSTM就是为了解决以上两个问题提出的方法之一,它强制令w_{l_{m}l_{m-1}}y_{lm}'(t-m)=1.0LSTM如何来避免梯度弥撒和梯度爆炸? - 知乎

有兴趣可以参考Long Short Term Memory 一文 。上面的推导过程大体上也参考自这篇论文。

Reference:

Graves, Alex. Long Short-Term Memory. Supervised Sequence Labelling with Recurrent Neural Networks. Springer Berlin Heidelberg, 2012:1735-1780. 

posted @ 2017-10-20 15:05  Django's blog  阅读(611)  评论(0编辑  收藏  举报