RNN 梯度消失&梯度爆炸
RNN 梯度消失&梯度爆炸
参考:https://zhuanlan.zhihu.com/p/33006526?from_voters_page=true
梯度消失和梯度爆炸本质是同一种情况。梯度消失经常出现的原因:一是使用深层网络;二是采用不合适的损失函数,如Sigmoid。梯度爆炸一般出现的场景:一是深层网络;二是权值初始化太大。
1. 深层网络角度解释梯度消失和梯度爆炸
深层网络由许多非线性层堆叠而来,每一层网络激活后的输出为\(f_{i}(x)\),其中\(i\)为第\(i\)层,\(x\)是第\(i\)层的输入,即第\(i-1\)层的输出,\(f\)是激活函数,整个深层网络可视为一个复合的非线性多元函数:
目的是多元函数\(F(x)\)完成输入到输出的映射,假设不同的输入,输出的最优解是g(x),则优化深层网络就是为了找到合适的权值,满足\(Loss=L(g(x),F(x))\)取得极小值。
BP 算法基于梯度下降策略,以负梯度方向对参数进行调整,参数更新:
\(\frac{\partial f_n}{\partial f_{n-1}}\)即对激活函数求导,如果此部分大于1,随着层数增加,梯度更新将以指数形式增加,即发生梯度爆炸;如果此部分小于1,随着层数增加,梯度更新将以指数形式衰减,即发生梯度消失。
梯度消失、爆炸,其根本原因在于反向传播训练法则,链式求导次数太多。
2. 激活函数角度解释梯度消失和梯度爆炸
计算权值更新信息,需要计算前层偏导信息,因此激活函数选择不合适,比如Sigmoid,梯度消失会更明显。
如果使用sigmoid作为损失函数,其梯度是不可能超过0.25的,这样经过链式求导之后,很容易发生梯度消失。
tanh作为损失函数,它的导数图如下,可以看出,tanh比sigmoid要好一些,但是它的导数仍然是小于1的。
由于sigmoid和tanh存在上述的缺点,因此relu激活函数成为了大多数神经网络的默认选择。relu函数的导数在正数部分是恒等于1,因此在深层网络中就不存在梯度消失/爆炸的问题,每层网络都可以得到相同的更新速度。另外计算方便,计算速度快,加速网络的训练。
但是relu也存在缺点:即在\(x\)小于0时,导数为0,导致一些神经元无法激活。输出不是以0为中心的。因此引申出下面的leaky relu函数,但是实际上leaky relu使用的并不多。
3. RNN中的梯度消失和CNN的梯度消失有区别
RNN中的梯度消失/爆炸和MLP/CNN中的梯度消失/爆炸含义不同:MLP/CNN中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度 g 等于各个时间步的梯度 \(g_t\) 的和。
-
RNN中的总的梯度不会消失。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和并不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。
RNN前向传导过程:
\[\begin{split} t &= 1 \\ s_1 &= g(Ux_1 + Ws_{0})\\ o_1 &= f(Vg(Ux_1 + Ws_{0}))\\ t &= 2 \\ s_2 &= g(Ux_2 + Ws_{1})\\ o_2 &= f(Vg(Ux_2 + Ws_{1})) =f(Vg(Ux_2 + Wg(Ux_1 + Ws_{0})))\\ t &= 3 \\ s_3 &= g(Ux_3 + Ws_{2})\\ o_3 &= f(Vg(Ux_3 + Ws_{2})) = f(Vg(Ux_3 + Wg(Ux_2 + W(Ux_1 + Ws_{0}))))\\ ...\\ t &= m \\ ...\\ Loss &= L(o_m,y)\\ \frac{\partial L}{\partial U} &= \frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\frac{\partial s_m}{\partial U} + \frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\frac{\partial s_m}{\partial s_{m-1}}\frac{\partial s_{m-1}}{\partial U}+...+ \frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\frac{\partial s_m}{\partial s_{m-1}}...\frac{\partial s_{2}}{\partial s_{1}}\frac{\partial s_{1}}{\partial U}\\ &= \sum_{t=1}^{m}\frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\left(\prod_{j=t+1}^{m}\frac{\partial s_{j}}{\partial s_{j-1}}\right)\frac{\partial s_t}{\partial U} \end{split} \]当激活函数为tanh,\(s_t = tanh(Ux_t + Ws_{t-1})\),
权值梯度:
\[\begin{split} Loss &= \sum_{t=1}^{m}\frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\left(\prod_{j=t+1}^{m}\frac{\partial s_{j}}{\partial s_{j-1}}\right)\frac{\partial s_t}{\partial U}\\ &= \sum_{t=1}^{m}\frac{\partial L}{\partial o_m}\frac{\partial o_m}{\partial s_m}\left(\prod_{j=t+1}^{m}tanh'W \right)\frac{\partial s_t}{\partial U} \end{split} \] -
MLP/CNN 的梯度消失:主要是随着网络加深,浅层网络的梯度越来越小,导致参数无法更新迭代。
4. 梯度消失、爆炸的解决方案
在深度神经网络中,往往是梯度消失出现的更多一些。
4.1 梯度爆炸的解决方案:
-
梯度裁剪:主要思想是设置一个梯度剪切阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内。这可以防止梯度爆炸。
-
权值正则化(weithts regularization):正则化是通过对网络权重做正则限制过拟合,如下正则项在损失函数中的形式:
\[Loss = (y-W^Tx)^2+\alpha||W||^2 \]常见的是L1正则和L2正则,在各个深度框架中都有相应的API可以使用正则化。
其中,\(\alpha\)是指正则项系数,因此,如果发生梯度爆炸,权值的范数就会变的非常大,通过正则化项,可以部分限制梯度爆炸的发生。
4.2 梯度消失的解决方案:
4.2.1 选择relu、leakrelu、elu等激活函数:
-
relu函数的导数在正数部分是恒等于1的,因此在深层网络中不会导致梯度消失和爆炸的问题。relu优点:解决了梯度消失、爆炸的问题,计算速度快,加速网络训练。relu缺点:导数的负数部分恒为0,会导致一些神经元无法激活(可通过设置小学习率部分解决),输出不是以0为中心的。
-
leakrelu就是为了解决relu的0区间带来的影响。数学表达:\(leakrelu=max(x*k,x)\)
其中k是leak系数,一般选择0.1或者0.2,或者通过学习而来。leakrelu解决了0区间带来的影响,而且包含了relu的所有优点。
-
elu也是为了解决relu的0区间带来的影响,其数学表达为:
但是elu相对于leakrelu来说,计算要更耗时间一些。
4.2.2 使用Batchnorm(batch normalization,简称BN):
目前已经被广泛的应用到了各大网络中,具有加速网络收敛速度,提升训练稳定性的效果,Batchnorm本质上是解决反向传播过程中的梯度问题。通过规范化操作将输出信号x规范化到均值为0,方差为1保证网络的稳定性。 具体来说就是反向传播中,经过每一层的梯度会乘以该层的权重,举个简单例子: 正向传播中\(f_3=f_2(w^Tx+b)\) ,那么反向传播中,\(\frac{\partial f_2}{\partial x}=\frac{\partial f_2}{\partial f_1}w\), 反向传播式子中有\(w\) 的存在,所以 \(w\) 的大小影响了梯度的消失和爆炸,batchnorm就是通过对每一层的输出做scale和shift的方法,通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到接近均值为0方差为1的标准正太分布,即严重偏离的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,使得让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
4.2.3 残差结构:
残差单元里的shortcut(捷径)部分可以保证在反向传播中梯度不会消失。
式子的第一个因子\(\frac{\partial Loss}{\partial x_L}\)表示的损失函数到达 L 层的梯度,小括号中的1表明短路机制可以无损地传播梯度,而另外一项残差梯度则需要经过带有weights的层,梯度不是直接传递过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习会更容易。
4.2.4 LSTM:
使用LSTM(long-short term memory networks,长短期记忆网络),就不那么容易发生梯度消失,主要原因在于LSTM内部复杂的“门”(gates),如下图,LSTM通过它内部的“门”可以在更新的时候“记住”前几次训练的“残留记忆”,因此,经常用于生成文本中。
5. 参考
RNN
- https://zybuluo.com/hanbingtao/note/541458
- https://colab.research.google.com/drive/1Zfvt9Vfs3PrJwSDF8jMvomz7CzU36RXk
LSTM
- http://colah.github.io/posts/2015-08-Understanding-LSTMs/
- https://www.youtube.com/watch?v=9zhrxE5PQgY&feature=youtu.be
- https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-stepby-step-explanation-44e9eb85bf21
- https://medium.com/datadriveninvestor/how-do-lstm-networks-solve-theproblem-of-vanishing-gradients-a6784971a577