炙手可热的LSTM
引言
上一讲说到RNN. RNN可说是目前处理时间序列的大杀器,相比于传统的时间序列算法,使用起来更方便,不需要太多的前提假设,也不需太多的参数调节,更重要的是有学习能力,因此是一种'智能'算法.前面也说到, 不只时间序列,在很多领域,特别是涉及序列数据的,RNN的表现总是那么的'抢眼'.不过,在这抢眼的过程中, 冲锋最前面的可不是简单的RNN(或者说最原始的RNN), 真正把RNN这个牌子''做大做强''的是: 长短时记忆循环神经网络(Long-Short Term Memory RNN, LSTM).
RNN的问题
传统的(或称原始的)RNN理论上是可以记忆任意长度的时间序列,比如你把一本<红楼梦>给她, 理论上她也是可以记忆的. 但, 理论和实际是有差距滴~.
在应用过程中,发现RNN对长时记忆的能力比较弱, 也就是RNN的记性不太好,对于长时间的东西她就有点记不清了,而几乎只会关注最近一段时间的信息. 也就是说, 当你给她<红楼梦>之后,梦想着她如何给你讲讲''大观园''的趣事, 她却不知所云地来了句' 说到辛酸处,荒唐愈可悲.由来同一梦,休笑世人痴.' —— 前面的完全忘光了! 为什么?辛辛苦苦地训练她,可她竟这样地不念旧情?! 请不要怪她, 她的''内在逻辑'',上她无法控制的.这个所谓的'内在逻辑'就是指数函数(Exponential function).
恐怖的指数函数
指数函数大家应该都记得:
\[y = a^x \qquad x \in R
\]
这就是指数函数,是不是其貌不扬? 是不是纯天然无公害? 你错了, 它展现的实力可是爆炸性的(explosive). 大家听过'棋盘放米'的故事吧: 第一格放1粒米(注意是一粒可不是一斤哦~),第二格放 2 粒米,第三个格放 4 粒米,…, 最后国王付不起了... 对于8x8的棋盘(国际象棋), 一共要放 (\(2^{64} - 1\)) 粒米, 即18,446,744,073,709,551,615粒. 一千克米的粒数大约在40000-60000之间, 就算最多的60000 粒, 18,446,744,073,709,551,615 / 60000 约 等于 307445734561826千克. 2016年,大宗粮油全球总产量27.84亿吨,取整算28亿吨, \(307445734561826 \div (28*10^{11}) \approx 110\) 年. 哪个国王支付得起?
不够直观? 好,再举个例子: 拿一张A4纸,对折,对折,再对折,…尝试一下,你能折几次? 世界记录只有13次! what?! 这么少?! 对的,这就是指数的威力.
\(10^{100}\), 10 的100次幂叫做Googol(哈哈,被你发现了, Google名字的由来...) ,这个Googol 可以说是宇宙极限了. 想象一下, 宇宙万物都是由基本粒子组成,你所知道的最小粒子是什么? 电子 还是夸克? 科学家估算全宇宙的基本粒子数(都算上) 最多不超过\(10^{80}\)个!
最后一个例子 \(0.9^{20}\approx 0.12\).
通过以上,就好理解为什么RNN健忘了.回忆下上讲的公式.
\[\begin{array}
\\
s_t &=& f(s_{t-1},x_t)\\
&=& f(W s_{t-1}+Ux_{t})\\
&= &f(Wf(s_{t-2},x_{t-1})+Ux_t)\\
&=& ...\\
&= &f(W(f(W(...(f(Ws_0+Ux1))))))\\
\end{array}
\]
看! 有多个W(矩阵)相乘, 般权重范围在(-1,1)之间,来个指数幂,一下子就没了...也就是之前的信息不会对当前或未来的信息产生影响了. 也就是说RNN失去了记忆的能力了.
专业一点的说法叫做梯度消失(gradient vanishing); 如果权重大于1就会产生梯度爆炸(gradient explosion)(比较少见).
梯度消失*
上面是较通俗的说法, 其实从这个问题的名字就可看出,其切入点是梯度(gradient).
回顾上阶公式:
\[\nabla_W L = \sum_{t}^{T} \boldsymbol{\delta_t\odot s_{t-1}}
\]
其中
\[\delta_t = \frac{\partial L}{\partial z_t}
= \delta_T \Pi_{i = t}^{T-1} W\odot f'(z_i)
= \delta_T * W*f'(z_t)...*W*f'(z_{T-1})
\]
用模的上界表示:
\[||\delta_t||\le ||\delta_T||. (||W||.||\bar{f}||)^{T-t}
\]
其中 \(\bar{f}\)表示的是\(f'(z_t)\)的模上界. 这样就了然了, 指数形式的出现表明前面(T-t 比较大时)的信息对于梯度的更新没有贡献,也就是无论之前信息怎样,最终的权重更新(学习)都不会受到影响. 因为前面信息的梯度由于指数逻辑的存在,使梯度趋近于0 — 消失了.也就是说,RNN忘记了.比如 某个\(w_{i,j}\)的值在t时刻是的梯度0.3(实际中w的一般量级,甚至更小), t-1时 大约为0.3* 0.3 = 0.09了,t-2时约为0.09* 0.3 = 0.027,t-3; 0.027* 0.3 = 0.0081,t-4: 0.0081 *0.3 = 0.00243,…不用写了,在反向传播4次就已经很小了,而RNN时间深度再深一些,比如15,比如20,比如30,…, 可以想见,之前的信息RNN直接忽略了.
不只在RNN中,其实梯度消失及梯度爆炸在深度学习领域一直是一个比较头疼的问题,这也是深度网络难以训练的主要原因.只不过在深度网络中,是因为层数的增多导致产生类似指数形式的连续乘积.
解决方案
出现梯度消失(主要)与爆炸问题后,有很多解决方法提出来,比如设计更好的初始化权重,限制权重范围等等. 这种''通用''的方法的作用有限. 在RNN中有人提出设计隐藏单元用来储存信息,称为储层计算(Reservoir Computing),比如回声状态网络(Echo State Network, ESN).也有人提出在不同的时间粒度上处理数据,不同的时间处理单元称为渗透单元(Leaky Unit).但目前效果最好的,通用性最高的还是门限(Gated) RNN.其中最火的就是LSTM(Long-Short Term Memory)及GRU(Gated Recurrent Unit).
LSTM
设计初衷
首先,抛开恐怖的指数函数不谈,咱们先想象一个场景:假设你很喜欢古龙的小说,他的小说你都看了好多遍.现在给一篇他的小说,比如<七种武器>里的<霸王枪>, 篇幅不长,故事也不太复杂,让你阅读. 几个小时后,或者大方点,第二天,我来找你,让你一字不落地背出第一章<落日照大旗>. 你一定会问我我是不是凯丁蜜(Are you kidding me?),然后我会说我是斯尔瑞尔斯(I'm serious.). 最后你会承认你背不出. 但我要问你:谁是丁喜?百里长青与丁喜是什么关系?这本小说讲了一个什么故事?你一定滔滔不绝.
背不出一章内容,但却能说出整本小说的故事梗概, 是因为我们会提取主要信息, 不会对信息'一视同仁',懂得取舍.有些信息比如环境描写看看就过去了,一般不会刻意去记忆.但有些重要线索,比如谁杀了谁等等这样的信息我们会记住.
回过头来再看RNN,继续忽略恐怖的指数函数,直观的理解一下: RNN读取的信息,对信息一视同仁:经过处理的信息,RNN认为这些信息的任何一部分都对接下来的信息有影响,全部都抛给接下来处理的程序.对这些信息,RNN进行同样的处理.,造成大量无用信息冗余,浪费大量记忆空间,导致关键信息无法突出,更多的信息又无法存储.从而产生较前面的信息RNN记不住的问题.这才是''本质原因". 神马指数函数只是'刽子手'而已.
其他方法都只是从表象处理问题(针对梯度消失,或指数函数的连续乘法),或者虽针对本质原因但方法不对头. 而门限RNN正是针对信息的重要性设计的.
LSTM原理
考虑重要性,那就自然而然的产生两种时态信息. 一种就是长时态(Long term state)信息. 此信息包含'趋势'信息或'主旨'信息,是剔除冗余信息后,对未来信息真正产生作用的信息.比如小说中的主旨大意,新闻要点等等.另一种短时态信息(Short term state). 此类信息是最直接地,对未来信息产生影响的信息. 比如'今天真热啊, 我得吹吹(空调)', ''吹吹'直接导致'空调''或'风扇'的产生,而不是可乐,'凉水澡'等等.
相比传统RNN的'一视同仁', 两种时态信息的区分,致使长时态信息不会被短时信息所淹没.
图1: 信息分态
对于两种时态信息, LSTM是如何提取重要信息的呢? 顾名思义,通过门(gate)来'提取'的.
图2:信息流门限控制
上图中, \(C_t\) 代表长时态信息 \(C_{t-1}\) 为前一个时刻的长时态信息),而 \(C'_t\) 则代表短时态信 息, \( h_t\) 为经过LSTM单元后的输出信息,三条线上的开关,即为门限.图中展示的三种门分别为:
- 前一时刻长态信息与当前时刻长态信息之间控制门: 遗忘门(Forget gate);
- 当前时刻即短态信息与长态信息之间控制门: 输入门(Input gate);
- 当前信息(长,短汇总后)与输出态信息之间控制门:输出门(Output gate).
遗忘门控制的是历史信息有多少对现在,对未来有影响,即有多少是可以继续保留在长态信息的; 输入门控制的是输入信息有多少可以加入到长态信息中去;输出门控制的是汇总后的信息有多少是可以作为当前输出的信息.
门限控制*
门的设计根据以上信息也就不难设计:
\[gate(x) = \sigma(Wx + b)
\]
其中 x 表示的是门的输入, 而 \(\sigma\) 表示的是门限激活(控制)函数,一般为(或当前比较流行)sigmoid函数.设 t 时刻的遗忘门,输入门及输出门分别用 \(f_t, i_t,o_t\) 表示. 则三种的表示方式:
\[\begin{array}
\\
f_t = \sigma(h_{t-1},x_t) = \sigma(W_{f,h}h_{t-1} + W_{f,x}x+b_f)\\
i_t = \sigma(h_{t-1},x_t) = \sigma(W_{i,h}h_{t-1} + W_{i ,x}x+b_i) \\
o_t = \sigma(h_{t-1},x_t) = \sigma(W_{o,h}h_{t-1} + W_{o,x}x+b_o)\\
\end{array}
\]
其中 W 为权重,如 \(W_{f,h}\) 为遗忘门对应的上一时态输出信息的权重,其他同理. b 为偏置.
门的输入 (x) 又是什么呢?即门的开关取决于什么呢?没错,是单元的输入信息,当前时刻的输入信息包括前一时刻的输出(\(h_{t-1}\))以及当前时刻的外部信息输入 (\(x_t\)). 用 \(C'_t\) 表示当前输入则:
\[C'_t = tanh(W_{C,h} h_{t-1}+W_{C,x}x + b_C)
\]
试下吧.其中 tanh() 为双曲正切函数,它可以理解成为激活函数,当然其激活函数也是可以的,只不过当前流行(或者说当前效果好)tanh(),下同.
以上,门与输入都有了,那 t 时刻的状态信息(Ct)就可以写出来了(观察图2):
\[C_t = f_t\odot C_{t-1} + i_t \odot C'_t
\]
可见,当门的值为1时,门属于完全开放状态,所有信息都可以通过, 而门的值为0 则表示关闭状态,所有信息都不能通过, 而正常情况下则是(0,1)之间,即对信息是有取舍的.
t 时刻的状态信息产生,那么 t 时刻的输出 (ht) 就可以得出了:
\[h_t = o_t \odot tanh(C_t)
\]
至此.LSTM单元就构建完成了.
LSTM 的 BPTT
对模型训练,要更新的参数即为权重(与偏置),其中权重的设置有四处,三个门与输入的端.
设加权输入为 z, 则:
\[\begin{array}
\\
z_{f,t} = W_{f,h} h_{t-1} + W_{f,x} x_t + b_f\\
z_{i,t} = W_{i,h} h_{t-1} + W_{i,x} x_t + b_i\\
z_{o,t} = W_{o,h} h_{t-1} + W_{o,x} x_t + b_o\\
z_{f,t} = W_{c',h} h_{t-1} + W_{C',x} x_t + b_{C'}\\
\end{array}
\]
设其对应的误差项为 \(\delta\), 则:
\[\begin{array}\\
\delta_{f,t} = \frac{\partial L}{\partial{ z_{f,t}}}\\
\delta_{i,t} = \frac{\partial L}{\partial{ z_{i,t}}}\\
\delta_{o,t} = \frac{\partial L}{\partial{ z_{o,t}}}\\
\delta_{C',t} = \frac{\partial L}{\partial{ z_{C',t}}}\\
\end{array}
\]
其中 L 为损失函数.
设 t 时刻的误差项为 \(\delta_t\):
\[\delta_t = \frac{\partial L}{\partial h_t}
\]
则 t-1 时刻的误差项为:
\[\delta_{t-1} = \frac{\partial L}{\partial h_{t-1}} =\frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial h_{t-1}} = \delta_t \frac{\partial h_{t}}{\partial h_{t-1}}
\]
回顾下(7-10)四式:
\[\begin{array}\\
\frac{\partial h_{t}}{\partial h_{t-1}} & = & \frac{\partial h_t}{\partial o_t} \frac{\partial o_t}{\partial z_{o,t}} \frac{\partial z_{o,t}}{\partial h_{t-1}}\\
&&+\frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial f_t} \frac{\partial f_t}{\partial z_{f,t}} \frac{\partial z_{f,t}}{\partial h_{t-1}}\\
&& + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial i_t} \frac{\partial i_t}{\partial z_{i,t}} \frac{\partial z_{i,t}}{\partial h_{t-1}} \\
&&+ \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial X'_t} \frac{\partial C'_t}{\partial z_{C',t}} \frac{\partial z_{C',t}}{\partial h_{t-1}}
\end{array}
\]
把(15)式代入(14)式, 可得:
\[\delta_{t-1} = \delta_{o,t}\frac{\partial z_{o,t}}{\partial h_{t-1}}+\delta_{f,t}\frac{\partial z_{f,t}}{\partial h_{t-1}}+\delta_{i,t}\frac{\partial z_{i,t}}{\partial h_{t-1}}+\delta_{C',t}\frac{\partial z_{C',t}}{\partial h_{t-1}}
\]
其中用到了:
\[\begin{array}\\
\delta_{f,t} = \frac{\partial L}{\partial{ z_{f,t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial f_t} \frac{\partial f_t}{\partial z_{f,t}} \\
\delta_{i,t} = \frac{\partial L}{\partial{ z_{i,t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial i_t} \frac{\partial i_t}{\partial z_{i,t}} \\
\delta_{o,t} = \frac{\partial L}{\partial{ z_{o,t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial o_t} \frac{\partial o_t}{\partial z_{o,t}} \\
\delta_{C',t} = \frac{\partial L}{\partial{ z_{C',t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial C'_t} \frac{\partial C'_t}{\partial z_{C',t}} \\
\end{array}
\]
于是:
\[\delta_{t-1} = \delta_{o,t} W_{o,h}+\delta_{f,t} W_{f,h}+\delta_{i,t} W_{i,h}+\delta_{C',t} W_{C',h}
\]
其中用到了:
\[\begin{array}\\
\frac{\partial z_{f,t}}{\partial h_{t-1}} =\frac{\partial (W_{f,h} h_{t-1}+W_{C',x} x_t + b_f)}{\partial h_{t-1}} = W_{f,h} \\
\frac{\partial z_{i,t}}{\partial h_{t-1}} =\frac{\partial (W_{i,h} h_{t-1}+W_{C',x} x_t + b_{i})}{\partial h_{t-1}} = W_{i,h} \\
\frac{\partial z_{o,t}}{\partial h_{t-1}} =\frac{\partial (W_{o,h} h_{t-1}+W_{o,x} x_t + b_{o})}{\partial h_{t-1}} = W_{o,h} \\
\frac{\partial z_{C',t}}{\partial h_{t-1}} =\frac{\partial (W_{C',h} h_{t-1}+W_{C',x} x_t + b_{C'})}{\partial h_{t-1}} = W_{C',h} \\
\end{array}
\]
以上是误差在时域上的传播.
接下来探讨传播到上一层(l-1):
\[\delta^{l-1}_t = \frac{\partial L }{\partial z_t^{l-1}}
\]
t 时刻的输入 \(x_t\) :
\[x_t^l = f^{l-1}(z_{t}^{l-1})
\]
在 l 层, $z_{f,t}l,z_{i,t}l,z_{o,t}l,z_{C',t}l $ 均为\(x_t\) 的函数,则
\[\begin{array}\\
\delta^{l-1}_t & = & \frac{\partial L }{\partial z_t^{l-1}} \\
&=& \frac{\partial L}{\partial z_{f,t}^l} \frac{\partial z_{f,t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\
&&+ \frac{\partial L}{\partial z_{i,t}^l} \frac{\partial z_{i,t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\
&&+ \frac{\partial L}{\partial z_{o,t}^l} \frac{\partial z_{o,t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\
&&+ \frac{\partial L}{\partial z_{C',t}^l} \frac{\partial z_{C',t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\
&=& \delta_{f,t}W_{f,x}\odot f'(z_{t}^{l-1}) \\
&&+ \delta_{i,t}W_{i,x}\odot f'(z_{t}^{l-1}) \\
&&+ \delta_{o,t}W_{o,x}\odot f'(z_{t}^{l-1}) \\
&& +\delta_{C',t}W_{C',x}\odot f'(z_{t}^{l-1}) \\
& = & ( \delta_{f,t} W_{f,x}+ \delta_{i,t}W_{i,x}+ \delta_{o,t}W_{o,x}+ \delta_{C',t}W_{C',x})\odot f'(z_{t}^{l-1})
\end{array}
\]
有以上误差项, 梯度求解就简单多了,t 时刻的 \(W_{f,h},W_{i,h},W_{o,h},W_{C',h}\) 分别为:
\[\begin{array}\\
\frac{\partial L}{\partial W_{f,h,t}} = \frac{\partial L}{\partial z_{f,t}}\frac{\partial z_{f,t}}{\partial W_{f,h,t}}= \delta_{f,t} h_{t-1}\\
\frac{\partial L}{\partial W_{i,h,t}} = \frac{\partial L}{\partial z_{i,t}}\frac{\partial z_{i,t}}{\partial W_{i,h,t}}= \delta_{i,t} h_{t-1}\\
\frac{\partial L}{\partial W_{o,h,t}} = \frac{\partial L}{\partial z_{o,t}}\frac{\partial z_{o,t}}{\partial W_{o,h,t}}= \delta_{o,t} h_{t-1}\\
\frac{\partial L}{\partial W_{C',h,t}} = \frac{\partial L}{\partial z_{C',t}}\frac{\partial z_{C',t}}{\partial W_{C',h,t}}= \delta_{C',t} h_{t-1}\\
\end{array}
\]
各个时刻的梯度之和即最终梯度:
\[\begin{array}\\
\frac{\partial L}{\partial W_{f,h}} = \sum_{t =1}^T \delta_{f,t}h_{t-1} \\
\frac{\partial L}{\partial W_{i,h}} = \sum_{t =1}^T \delta_{i,t}h_{t-1} \\
\frac{\partial L}{\partial W_{o,h}} = \sum_{t =1}^T \delta_{o,t}h_{t-1} \\
\frac{\partial L}{\partial W_{C',h}} = \sum_{t =1}^T \delta_{C',t}h_{t-1} \\
\end{array}
\]
对于偏置:
\[\begin{array}\\
\frac{\partial L}{\partial b_{f,t}} = \frac{\partial L}{\partial z_{f,t}} \frac{\partial z_{f,t}}{\partial b_{f,t}} = \delta_{f,t} \\
\frac{\partial L}{\partial b_{i,t}} = \frac{\partial L}{\partial z_{i,t}} \frac{\partial z_{i,t}}{\partial b_{i,t}} = \delta_{i,t} \\
\frac{\partial L}{\partial b_{o,t}} = \frac{\partial L}{\partial z_{o,t}} \frac{\partial z_{o,t}}{\partial b_{o,t}} = \delta_{o,t} \\
\frac{\partial L}{\partial b_{C',t}} = \frac{\partial L}{\partial z_{C',t}} \frac{\partial z_{C',t}}{\partial b_{C',t}} = \delta_{C',t}
\end{array}
\]
最终梯度:
\[\begin{array}\\
\frac{\partial L}{\partial b_{f,t}} = \sum_{t =1}^T \delta_{f,t}\\
\frac{\partial L}{\partial b_{i,t}} = \sum_{t =1}^T \delta_{i,t}\\
\frac{\partial L}{\partial b_{o,t}} = \sum_{t =1}^T \delta_{o,t}\\
\frac{\partial L}{\partial b_{C',t}} = \sum_{t =1}^T \delta_{C',t}\\
\end{array}
\]
最后:
\[\begin{array}\\
\frac{\partial L}{\partial W_{f,x}} = \frac{\partial L}{\partial z_{f,t}}\frac{\partial z_{f,t}}{\partial W_{f,x}} = \delta_{f,x}x_t\\
\frac{\partial L}{\partial W_{i,x}} = \frac{\partial L}{\partial z_{i,t}}\frac{\partial z_{i,t}}{\partial W_{i,x}} = \delta_{i,x}x_t\\
\frac{\partial L}{\partial W_{o,x}} = \frac{\partial L}{\partial z_{o,t}}\frac{\partial z_{o,t}}{\partial W_{o,x}} = \delta_{o,x}x_t\\
\frac{\partial L}{\partial W_{C,x}} = \frac{\partial L}{\partial z_{C',t}}\frac{\partial z_{C',t}}{\partial W_{C',x}} = \delta_{C',x}x_t\\
\end{array}
\]
以上就是 LSTM 的 BPTT, 似乎很多公式,但其实四种模式都是一样的,怕大家混淆就都写上了,只不过这样看着会很多的样子.
参考文献:
1: Deep Learning, 2016, Ian Goodfellow, Yoshua Bengio, Aaron Courville.
2: Neural Networks and Deep Learning, 2016, Michael Nielsen.
3: Understanding LSTM, 2015, Colah's blog.
4: A Critical Review of Recurrent Neural Networks for Sequence Learning, 2015, Zachary C. Lipton et al.
5: 零基础入门深度学习(6)- 长短时记忆网络(LSTM) 2017, hanbingtao.