随时间的反向传播算法 BPTT

本文转自:https://www.cntofu.com/book/85/dl/rnn/bptt.md

随时间反向传播(BPTT)算法


先简单回顾一下RNN的基本公式:

 

st=tanh(Uxt+Wst1)st=tanh⁡(Uxt+Wst−1)

 

 

y^t=softmax(Vst)y^t=softmax(Vst)

 

RNN的损失函数定义为交叉熵损失:

 

Et(yt,y^t)=ytlogy^tEt(yt,y^t)=−ytlog⁡y^t

 

 

E(y,y^)=tEt(yt,y^t)=tytlogy^tE(y,y^)=∑tEt(yt,y^t)=−∑tytlog⁡y^t

 

 

ytyt

是时刻t的样本实际值, 

y^ty^t

是预测值,我们通常把整个序列作为一个训练样本,所以总的误差就是每一步的误差的加和。我们的目标是计算损失函数的梯度,然后通过梯度下降方法学习出所有的参数U, V, W。比如:

EW=tEtW∂E∂W=∑t∂Et∂W

 

为了更好理解BPTT我们来推导一下公式:

前向 前向传播1:

 

a0=x0ua0=x0∗u

 

 

b0=s1wb0=s−1∗w

 

 

z0=a0+b0+kz0=a0+b0+k

 

 

s0=func(z0)s0=func(z0)

 (

funcfunc

 是 sig或者tanh)

 

前向 前向传播2:

 

a1=x1ua1=x1∗u

 

 

b1=s0wb1=s0∗w

 

 

z1=a1+b1+kz1=a1+b1+k

 

 

s1=func(z1)s1=func(z1)

(

funcfunc

 是 sig 或者tanh)

 

 

q=s1v1q=s1∗v1

 

$$z_t = ux_t + ws_{t-1} + k$$

 

st=func(zt)st=func(zt)

 

输出 层:

 

o=func(q)o=func(q)

(

funcfunc

 是 softmax)

 

 

E=func(o)E=func(o)

(

funcfunc

 是 x-entropy)

 

下面 是U的推导

 

E/u=E/u1+E/u0∂E/∂u=∂E/∂u1+∂E/∂u0

 

 

E/u1=E/oo/qq/s1s1/z1z1/a1a1/u1∂E/∂u1=∂E/∂o∗∂o/∂q∗∂q/∂s1∗∂s1/∂z1∗∂z1/∂a1∗∂a1/∂u1

 

 

E/u0=E/oo/qq/s1s1/z1z1/b1b1/s0s0/dz0z0/a0a0/u0∂E/∂u0=∂E/∂o∗∂o/∂q∗∂q/∂s1∗∂s1/∂z1∗∂z1/∂b1∗∂b1/∂s0∗∂s0/dz0∗∂z0/∂a0∗∂a0/∂u0

 

 

E/u=E/oo/qv1s1/z1((1x1)+(1w1s0/z01x0))∂E/∂u=∂E/∂o∗∂o/∂q∗v1∗∂s1/∂z1∗((1∗x1)+(1∗w1∗∂s0/∂z0∗1∗x0))

 

 

E/u=E/oo/qv1s1/z1(x1+w1s0/z0x0)∂E/∂u=∂E/∂o∗∂o/∂q∗v1∗∂s1/∂z1∗(x1+w1∗∂s0/∂z0∗x0)

 

W参数的推导如下

 

E/w=E/oo/qv1s1/z1(s0+w1s0/z0s1)∂E/∂w=∂E/∂o∗∂o/∂q∗v1∗∂s1/∂z1∗(s0+w1∗∂s0/∂z0∗s−1)

 

总结

 

Lu=tLut=Loos1s1u1+Loos1s1s0s0u0∂L∂u=∑t∂L∂ut=∂L∂o∂o∂s1∂s1∂u1+∂L∂o∂o∂s1∂s1∂s0∂s0∂u0

 

 

Lw=tLwt=Loos1s1w1+Loos1s1s0s0w0∂L∂w=∑t∂L∂wt=∂L∂o∂o∂s1∂s1∂w1+∂L∂o∂o∂s1∂s1∂s0∂s0∂w0

 

 

xtxt

是时间t的输入

 

posted @ 2019-06-25 19:35  RamboBai  阅读(879)  评论(0编辑  收藏  举报