机器学习 —— 基础整理(八)循环神经网络的BPTT算法步骤整理;梯度消失与梯度爆炸
网上有很多Simple RNN的BPTT(Backpropagation through time,随时间反向传播)算法推导。下面用自己的记号整理一下。
我之前有个习惯是用下标表示样本序号,这里不能再这样表示了,因为下标需要用做表示时刻。
典型的Simple RNN结构如下:
图片来源:[3]
约定一下记号:
输入序列 x(1:T)=(x1,x2,...,xT)x(1:T)=(x1,x2,...,xT) ;
标记序列 y(1:T)=(y1,y2,...,yT)y(1:T)=(y1,y2,...,yT) ;
输出序列 ˆy(1:T)=(ˆy1,ˆy2,...,ˆyT)^y(1:T)=(^y1,^y2,...,^yT) ;
隐层输出 ht∈RH ;
隐层输入 st∈RH ;
过softmax之前输出层的输出 zt 。
(一)Simple RNN的BPTT
那么对于Simple RNN来说,前向传播过程如下(省略了偏置):
st=Uht−1+Wxt
ht=f(st)
zt=Vht
ˆyt=softmax(zt)
其中 f 是激活函数。注意,三个权重矩阵在时间维度上是共享的。这可以理解为:每个时刻都在执行相同的任务,所以是共享的。
既然每个时刻都有输出 ˆyt ,那么相应地,每个时刻都会有损失。记 t 时刻的损失为 Lt ,那么对于样本 x(1:T) 来说,损失 L 为
L=T∑t=1Lt
使用交叉熵损失函数,那么
Lt=−y⊤tlogˆyt
一、 L 对 V 的梯度
下面首先求取 L 对 V 的梯度。根据chain rule:∂z∂x=∂y∂x∂z∂y 、∂z∂Xij=(∂z∂y)⊤∂y∂Xij ,有
∂Lt∂Vij=(∂Lt∂zt)⊤∂zt∂Vij
这里其实和BP是一样的,前一项相当于是误差项 δ ,后一项等于
∂zt∂Vij=∂Vht∂Vij=(0,...,[ht]j,...,0)⊤
只有第 i 行非零,[ht]j 是指 ht 的第 j 个元素。参考上一篇博客的结尾部分,可知前一项等于
∂Lt∂zt=ˆyt−yt
所以有
∂Lt∂Vij=[ˆyt−yt]i[ht]j
从而有
∂Lt∂V=(ˆyt−yt)h⊤t=(ˆyt−yt)⊗ht
向量外积是矩阵的Kronecker积在向量下的特殊情况。因此,
∂L∂V=T∑t=1(ˆyt−yt)⊗ht
二、 L 对 U 的梯度
继续求取 L 对 U 的梯度。在求 ∂Lt∂U 时,需要注意到一个事实,那就是不光 t 时刻的隐状态与 U 有关,之前所有时刻的隐状态都与 U 有关。
图片来源:[1]
所以,根据chain rule:
∂Lt∂U=t∑k=1∂sk∂U∂Lt∂sk
下面使用和之前类似的套路求解:先求对一个矩阵一个元素的梯度。
∂Lt∂Uij=t∑k=1(∂Lt∂sk)⊤∂sk∂Uij
前一项先定义为 δt,k=∂Lt∂sk ,对于后一项:
∂sk∂Uij=∂(Uhk−1+Wxk)∂Uij=(0,...,[hk−1]j,...,0)⊤
只有第 i 行非零,[hk−1]j 是指 hk−1 的第 j 个元素。现在来求解 δt,k=∂Lt∂sk ,使用上篇文章求 δ(l) 的套路:
δt,k=∂Lt∂sk=∂hk∂sk∂sk+1∂hk∂Lt∂sk+1=diag(f′(sk))U⊤δt,k+1=f′(sk)⊙(U⊤δt,k+1)
一种特殊情况是当 δt,t ,有
δt,t=∂Lt∂st=∂ht∂st∂zt∂ht∂Lt∂zt=diag(f′(st))V⊤(ˆyt−yt)=f′(st)⊙(V⊤(ˆyt−yt))
所以,
∂Lt∂Uij=t∑k=1[δt,k]i[hk−1]j
∂Lt∂U=t∑k=1δt,kh⊤k−1=t∑k=1δt,k⊗hk−1
因此,
∂L∂U=T∑t=1t∑k=1δt,k⊗hk−1
三、L 对 W 的梯度
观察 st=Uht−1+Wxt 这个式子,不难发现只要把刚刚推导的结果做一下简单的替换就可以直接得到新的结果:
∂Lt∂W=t∑k=1δt,k⊗xk
∂L∂W=T∑t=1t∑k=1δt,k⊗xk
总的来说,没有写什么insightful的东西,就是记录一下而已。使用的套路都是BP中使用的(其实就是很基本的chain rule)。但是需要注意的是,这里实际上是在时间维度上的展开。如果是跟普通的神经网络那样构造多个隐层,则需要在“纵向”上继续扩展,形成所谓的深度RNN。因为Theano等自动求导工具的存在,所以如果只是为了编程的话,很多情况下其实也不太需要手推了。
深度双向RNN。图片来源:[2]
(二)梯度消失(gradient vanishing)
我们考察一下下面这个梯度:
∂Lt∂U=∂ht∂U∂ˆyt∂ht∂Lt∂ˆyt
这里的 ∂ht∂U 比较麻烦,是因为各个时刻共享了参数:ht 这个参数是和 ht−1 、U 有关的,而 ht−1 又和 ht−2 、U 有关。所以参照 [5] ,可以写成以下形式(读 [5] 的时候需要注意其前向传播过程和 [4] 一样,与本文是有区别的,但在这里不妨碍理解):
∂Lt∂U=t∑k=1∂hk∂U∂ht∂hk∂ˆyt∂ht∂Lt∂ˆyt
其中,
∂ht∂hk=t∏i=k+1∂hi∂hi−1=t∏i=k+1∂si∂hi−1∂f(si)∂si=t∏i=k+1U⊤diagf′(si)
从这个式子可以看出,当使用tanh或logistic激活函数时,由于导数值分别在0到1之间、0到1/4之间,所以如果权重矩阵 U 的范数也不很大,那么经过 t−k 次传播后,∂ht∂hk 的范数会趋于0,也就导致了梯度消失问题。其实从上面误差项的表达式也可以看出,δt,k 与 δt,k+1 是乘一个导函数的关系,这个导函数值域在0到1之间(tanh)、0到1/4之间(logistic),那么随着时间的累积,当然会造成梯度消失问题。
为了缓解梯度消失,可以使用ReLU、PReLU来作为激活函数,以及将 U 初始化为单位矩阵(而不是用随机初始化)等方式。
(普通的前馈深层神经网络也会存在梯度消失,只不过那里是“纵向”上的。)
也就是说,虽然Simple RNN从理论上可以保持长时间间隔的状态之间的依赖关系,但是实际上只能学习到短期依赖关系。这就造成了“长期依赖”问题。打个比方,你对着模型说了一大段话,“你好,我叫小明,balabala……,很高兴认识你”。模型听完之后回答你:“很高兴认识你,你叫什么?我叫小红。”——模型已经忘了你叫什么了。
需要通过带LSTM单元的RNN来缓解梯度消失问题,现在一般把使用LSTM单元的RNN就直接叫LSTM了。LSTM单元引入了门机制(Gate),通过遗忘门、输入门和输出门来控制流过单元的信息。我们知道,Simple RNN之所以有梯度消失是因为误差项之间的相乘关系;如果用LSTM推导,会发现这个相乘关系变成了相加关系,所以可以缓解梯度消失。
(三)梯度爆炸(gradient exploding)
而对于梯度爆炸问题,通常就是使用比较简单的策略,也就是gradient clipping梯度裁剪:如果在一次迭代中各个权重的梯度平方和大于某个阈值,那么为避免权重的变化值太大,求一个缩放因子(阈值除以平方和),将所有的梯度乘以这个因子。TensorFlow里提供了很多种梯度裁剪的函数,直接看API吧。
参考:
[1] 《神经网络与深度学习讲义》
[2] Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients
[3] BPTT算法推导
[4] On the difficulty of training RNN
[6] 知乎:deep bidirectional RNN +LSTM 用于癫痫检测的疑问?
[7] caffe里的clip gradient是什么意思?
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix