Processing math: 100%

Hiroki

大部分笔记已经转移到 https://github.com/hschen0712/machine_learning_notes ,QQ:357033150, 欢迎交流

BPTT算法推导

随时间反向传播 (BackPropagation Through Time,BPTT)

符号注解:

  • K:词汇表的大小
  • T:句子的长度
  • H:隐藏层单元数
  • Et:第t个时刻(第t个word)的损失函数,定义为交叉熵误差Et=yTtlog(ˆyt)
  • E:一个句子的损失函数,由各个时刻(即每个word)的损失函数组成,E=TtEt
    注: 由于我们要推倒的是SGD算法, 更新梯度是相对于一个训练样例而言的, 因此我们一次只考虑一个句子的误差,而不是整个训练集的误差(对应BGD算法)
  • xtRK×1:第t个时刻RNN的输入,为one-hot vector,1表示一个单词的出现,0表示不出现
  • stRH×1:第t个时刻RNN隐藏层的输入
  • htRH×1:第t个时刻RNN隐藏层的输出
  • ztRK×1:输出层的汇集输入
  • ˆytRK×1:输出层的输出,激活函数为softmax
  • ytRK×1:第t个时刻的监督信息,为一个one-hot vector
  • rt=ˆytyt:残差向量
  • WRH×K:从输入层到隐藏层的权值
  • URH×H:隐藏层上一个时刻到当前时刻的权值
  • VRK×H:隐藏层到输出层的权值

他们之间的关系:

{st=Uht1+Wxtht=σ(st)zt=Vhtˆyt=softmax(zt)

其中,σ()是sigmoid函数。由于xt是one-hot向量,假设第j个词出现,则Wxt相当于把W的第j列选出来,因此这一步是不用进行任何矩阵运算的,直接做下标操作即可,在matlab里就是W(:,xt)

BPTT与BP类似,是在时间上反传的梯度下降算法。RNN中,我们的目的是求得EU,EW,EV,根据这三个变化率来优化三个参数U,V,W
注意到EU=tEtU,因此我们只要对每个时刻的损失函数求偏导数再加起来即可。
1.计算EtV

EtVij=tr((Etzt)TztVij)=tr((ˆytyt)T[0z(i)tVij0])=r(i)th(j)t

注:推导中用到了之前推导用到的结论。其中r(i)t=(ˆytyt)(i)表示残差向量第i个分量,h(j)t表示ht的第j个分量。
上述结果可以改写为:

EtV=(ˆytyt)ht

EV=tk=0(ˆykyk)hk

其中表示向量外积。
2.计算EtU
由于U是各个时刻共享的,所以t之前每个时刻U的变化都对Et有贡献,反过来求偏导时,也要考虑之前每个时刻U对E的影响。我们以sk为中间变量,应用链式法则:

EtU=tk=0skUEtsk

但由于skU(分子向量,分母矩阵)以目前的数学发展水平是没办法求的,因此我们要求这个偏导,可以拆解为EtUij的偏导数:

EtUij=tk=0tr[(Etsk)TskUij]=tk=0tr[(δk)TskUij]

其中,δk=Etsk,遵循

skhksk+1...Et

的传递关系,应用链式法则有:

δk=hksksk+1hkEtsk+1=diag(1hkhk)UTδk+1=(UTδk+1)(1hkhk)

其中,表示向量点乘。于是,我们得到了关于δ 的递推关系式。由δt出发,我们可以往前推出每一个δ,现在计算δt
\begin{equation}\delta_t=\frac{\partial E_t}{\partial s_t}=\frac{\partial h_t}{\partial s_t}\frac{\partial z_t}{\partial h_t}\frac{\partial E_t}{\partial z_t}=diag(1-h_t\odot h_t)\cdot VT\cdot(\hat{y}_t-y_t)=(VT(\hat{y}t-y_t))\odot (1-h_t\odot h_t)\end{equation}
δ0,...,δt代入$ \frac{\partial E_t}{\partial U
{ij}} $有:

EtUij=tk=0δ(i)kh(j)k1

将上式写成矩阵形式:

EtU=tk=0δkhk1

不失严谨性,定义h1为全0的向量。

3.计算EtW
按照上述思路,我们可以得到

EtW=tk=0δkxk

由于xk是个one-hot vector,假设其第m个位置为1,那么我们在更新W时只需要更新W的第m列即可,计算EtW的伪代码如下:

delta_t = V.T.dot(residual[T]) * (1-h[T]**2)
for t from T to 0
    dEdW[ :,x[t] ] += delta_t
    #delta_t = W.T.dot(delta_t) * (1 - h[t-1]**2)
    delta_t = U.T.dot(delta_t) * (1 - h[t-1]**2)

posted on   Hiroki  阅读(28282)  评论(6编辑  收藏  举报

编辑推荐:
· 为什么说在企业级应用开发中,后端往往是效率杀手?
· 用 C# 插值字符串处理器写一个 sscanf
· Java 中堆内存和栈内存上的数据分布和特点
· 开发中对象命名的一点思考
· .NET Core内存结构体系(Windows环境)底层原理浅谈
阅读排行:
· 为什么说在企业级应用开发中,后端往往是效率杀手?
· 本地部署DeepSeek后,没有好看的交互界面怎么行!
· DeepSeek 解答了困扰我五年的技术问题。时代确实变了!
· 趁着过年的时候手搓了一个低代码框架
· 推荐一个DeepSeek 大模型的免费 API 项目!兼容OpenAI接口!
< 2025年2月 >
26 27 28 29 30 31 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 1
2 3 4 5 6 7 8

导航

统计

点击右上角即可分享
微信分享提示