(Fig. 1)
Fig.1 是一张展开的LSTM模型的示意图,绿色的模块表示隐藏层记忆单元,每个记忆单元都有三个输入,两个输出(虽然图中有三个输出箭头,但其中两个输出都是\(h(t)\)),因此,\(t\) 时刻记忆单元的输入、输出分别为 \(x(t), h(t-1), s(t-1)\) 和 \(h(t), s(t)\)
![differciate_chain_sh](http://obmpvqs90.bkt.clouddn.com/differciate_chain_sh.png)
模型的计算公式如下:
\[\begin{align}
g(t) &= \phi(W_{gx}x(t) + W_{gh}h(t-1) + b_g) & (\text{Eq. 1.1}) \\
i(t) &= \sigma(W_{ix}x(t) + W_{ih}h(t-1) + b_i) & (\text{Eq. 1.2}) \\
f(t) &= \sigma(W_{fx}x(t) + W_{fh}h(t-1) + b_f) & (\text{Eq. 1.3}) \\
o(t) &= \sigma(W_{ox}x(t) + W_{oh}h(t-1) + b_o) & (\text{Eq. 1.4}) \\
s(t) &= g(t)*i(t) + s(t-1)*f(t) & (\text{Eq. 1.5}) \\
h(t) &= s(t)*o(t) & (\text{Eq. 1.6})
\end{align}
\]
成本函数的定义为:
\[L = \sum_{t=1}^Tl(t)=\frac{1}{2}\sum_{t=1}^T||y(t)-h(t)||^2 \qquad (\text{Eq. 2})
\]
其中\(y(t), h(t)\) 分别是目标结果和模型输出。
我们的目标是计算
\[\frac{dL}{dw}=\sum_{t=1}^T\frac{dL}{dh(t)}\frac{dh(t)}{dw} \qquad (\text{Eq. 3})
\]
其中, \(\frac{dL}{dh(t)}\)表示成本函\(L\)数对变量\(h(t)\)的全微分,可表示成如下形式
\[\frac{dL}{dh(t)}=\frac{d}{dh(t)}\sum_{\tau=1}^T l(\tau)=\frac{d}{dh(t)}\sum_{\tau=t}^T l(\tau)=\frac{dL(t)}{dh(t)}
\]
我们接下来将推导出以下四个全微分公式,进而计算\(\frac{dL}{dw}\):
\[\bbox[yellow]
{
\begin{align}
\frac{dL(t)}{dh(t)} \qquad (\text{Eq. 4.1}) \\ \\
\frac{dL(t+1)}{ds(t)} \qquad (\text{Eq. 4.2}) \\ \\
\frac{dL(t)}{dh(t-1)} \qquad (\text{Eq. 4.3}) \\ \\
\frac{dL(t)}{ds(t-1)} \qquad (\text{Eq. 4.4})
\end{align}
}
\]
其中\(h(t), s(t)\) 代表\(t\)时刻记忆单元的输出值,\(h(t-1), s(t-1)\) 则代表\(t\)时刻记忆单元的输入值。
- 根据 \(L(t) = l(t) + L(t+1)\),可得
\[\frac{dL(t)}{dh(t)}=\frac{dl(t)}{dh(t)}+\frac{dL(t+1)}{dh(t)} \qquad (\text{Eq. 5})
\]
- 成本函数 \(L\) 对 \(s(t)\) 的全微分形式 \(\frac{dL(t)}{ds(t)}\):
根据 Eq. 1.6, \(s(t)\) 的值会影响 \(h(t)\), 根据 Eq. 1.5 \(s(t)\) 值会影响 \(s(t+1)\),进而影响 \(h(t+1)\),因此\(\frac{dL}{ds(t)}\)可分为两部分计算
\[ \begin{align}
\frac{dL(t)}{ds(t)} & =\frac{dL(t)}{dh(t)}\cdot\frac{dh(t)}{ds(t)}+\frac{dL(t+1)}{dh(t+1)}\cdot\frac{dh(t+1)}{ds(t)} \\ \\
& = \frac{dL(t)}{dh(t)}\cdot\frac{dh(t)}{ds(t)}+\frac{dL(t+1)}{ds(t)} \qquad \qquad (\text{Eq. 6})
\end{align}
\]
\[\begin{align}
&\frac{dL(t)}{dh(t)} = \frac{dl(t)}{dh(t)}=h(t) - y(t) \\ \\
&\frac{dL(t+1)}{ds(t)} = 0
\end{align}
\]
进而可以根据 Eq. 6 求得 \(\frac{dL(t)}{ds(t)}\),且根据 Eq. 1.5 有
\[\frac{dL(t)}{ds(t-1)}=\frac{dL(t)}{ds(t)}\cdot f(t)
\]
通过 Eq. 1.1 - 1.6,我们可计算出 \(\frac{dh(t)}{dh(t-1)}\),进而求得
\[\frac{dL(t)}{dh(t-1)} = \frac{dL(t)}{dh(t)}\cdot \frac{dh(t)}{dh(t-1)}
\]
至此,我们求得 \(t=T\) 时刻 Eq. 4.1 - Eq. 4.4 四个全微分的值。
\[\begin{align}
\frac{dL(t)}{dh(t)} &= \frac{dl(t)}{dh(t)} + \frac{dL(t+1)}{dh(t)} \\ \\
& = h(t) - y(t) + \frac{dL(T)}{dh(T-1)} \\ \\
\frac{dL(t+1)}{ds(t)} &= \frac{dL(T)}{ds(T-1)}
\end{align}
\]
此时,与 \(t=T\) 时刻的情况完全一样,可依次求出 \(t=T-2, T-3, ... 2, 1\) 时刻的微分方程 Eq. 4.1-Eq. 4.4,从而求出 Eq. 3
![cell](http://obmpvqs90.bkt.clouddn.com/lstm_cell2.png)
![lstm](http://obmpvqs90.bkt.clouddn.com/lstm.png)
![rnn_types](http://obmpvqs90.bkt.clouddn.com/rnn_types.png)
参考:
[1] http://colah.github.io/posts/2015-08-Understanding-LSTMs/
[2] http://nicodjimenez.github.io/2014/08/08/lstm.html
[3] http://meta.math.stackexchange.com/questions/5020/mathjax-basic-tutorial-and-quick-reference
[4] A Critical Review of Recurrent Neural Networks for Sequence Learning. Zachary C. Lipton John Berkowitz