【DL基础】梯度爆炸和梯度消失的理解
前言
一、名词解释
目前优化神经网络的方法都是基于BP,即根据损失函数计算的误差通过梯度反向传播的方式,指导深度网络权值的更新优化。其中将误差从末层往前传递的过程需要链式法则(Chain Rule)的帮助,因此反向传播算法可以说是梯度下降在链式法则中的应用。
而链式法则是一个连乘的形式,所以当层数越深的时候,梯度将以指数形式传播。梯度消失问题和梯度爆炸问题一般随着网络层数的增加会变得越来越明显。在根据损失函数计算的误差通过梯度反向传播的方式对深度网络权值进行更新时,得到的梯度值接近0或特别大,也就是梯度消失或爆炸。梯度消失或梯度爆炸在本质原理上其实是一样的。
在反向传播过程中需要对激活函数进行求导,如果导数大于1,那么随着网络层数的增加梯度更新将会朝着指数爆炸的方式增加这就是梯度爆炸。同样如果导数小于1,那么随着网络层数的增加梯度更新信息会朝着指数衰减的方式减少这就是梯度消失(弥散)。因此,梯度消失、爆炸,其根本原因在于反向传播训练法则,属于先天不足。
二、表现形式
梯度消失的表现:
模型无法从训练数据中获得更新,损失几乎保持不变。
梯度爆炸的表现:
1) 模型不稳定,更新过程中损失变化明显。
2) 训练中,模型损失为NaN。
3) 模型无法在训练数据上收敛。
三、产生原因
问:现象描述,怎么确定是否出现了梯度爆炸或梯度消失呢?
答:可以根据loss的值看出来,在训练过程中loss突然变成inf,之后就保持nan了。
问:工程研发过程中哪些会引起梯度爆炸?
答:学习率过大,损失函数,脏数据(需要进行数据清洗)。
【梯度消失】经常出现,产生的原因有:一是在深层网络中,二是采用了不合适的损失函数,比如sigmoid。当梯度消失发生时,接近于输出层的隐藏层由于其梯度相对正常,所以权值更新时也就相对正常,但是当越靠近输入层时,由于梯度消失现象,会导致靠近输入层的隐藏层权值更新缓慢或者更新停滞。这就导致在训练时,只等价于后面几层的浅层网络的学习。
【梯度爆炸】一般出现在深层网络和权值初始化值太大的情况下。在深层神经网络或循环神经网络中,误差的梯度可在更新中累积相乘。如果网络层之间的梯度值大于 1.0,那么重复相乘会导致梯度呈指数级增长,梯度变的非常大,然后导致网络权重的大幅更新,并因此使网络变得不稳定。
梯度爆炸会伴随一些细微的信号,如:①模型不稳定,导致更新过程中的损失出现显著变化;②训练过程中,在极端情况下,权重的值变得非常大,以至于溢出,导致模型损失变成 NaN等等。
根本原因:1)隐藏层的层数过多;2)激活函数不合适;3)初始权重过大;
3.1 从深层网络的BP(反向传播原理)解释梯度消失和梯度爆炸
3.2 从激活函数角度分析梯度消失
3.3 初始化权重参数的数值过大
四、解决方法
梯度消失和梯度爆炸问题都是因为网络太深,网络权值更新不稳定造成的,本质上是因为梯度反向传播中的连乘效应。解决梯度消失、爆炸主要有以下几种方法:
- 预训练加微调 - 梯度剪切、权重正则(针对梯度爆炸) - 使用不同的激活函数 - 使用batchnorm - 使用残差结构 - 使用LSTM网络
问答
问:每次训练过程梯度爆炸都会存在吗?
答:不一定,和产生梯度爆炸的原因有关。博主某次训练之后发现梯度爆炸问题,但是之前和之后的训练没有这个问题。
参考文章梯度为NAN的一些实践总结之后,猜想遇到的问题原因应该是一样的,即学习率过大导致的,可以跟踪训练过程梯度值的更新来确定。在一些简单的线性拟合优化过程中,学习率只要小于1就能够逐步优化,但稍微有些复杂的曲线的拟合,对学习率设置的要求就已经比较高了。当增大数据量后,发现学习率要调的更低才能避免nan的出现,所以,回想起当时在训练目标检测模型时,初始设置为0.001数值有点大,学习率初始值可能需要设置到1e-5才可以。
参考
2. 梯度消失和梯度爆炸及解决方法;
3. 【机器学习】梯度消失和梯度爆炸的原因分析、表现及解决方案;
5. 梯度为NAN的一些实践总结;
完
心正意诚,做自己该做的事情,做自己喜欢做的事情,安静做一枚有思想的技术媛。
版权声明,转载请注明出处:https://www.cnblogs.com/happyamyhope/
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
2017-08-10 OpenCV Error: Insufficient memory问题解析