LSTM计算分析
普通的RNN,中间循环的就是一个tanh激活函数
LSTM也具有这种链式结构,但重复模块具有不同的结构。有四个神经网络层,它们以一种非常特殊的方式相互作用,而不是只有一个单独的神经网络层。
LSTM (Long Short-Term Memory)也称长短时记忆结构,它是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联缓解梯度消失或爆炸现象.
核心结构可以分为四个部分:
- 遗忘门 (forget gate),它决定了上一时刻的单元状态ct-1有多少保留到当前时刻;
- 输入门 (input gate),它决定了当前时刻网络的输入c’t有多少保存到新的单元状态ct中。
- 细胞状态
- 输出门 (output gate),它利用当前时刻单元状态cn对hn的输出进行控制。
遗忘门就是控制LSTM的长时记忆部分,把以前的记忆和现在发生的事情做一次融合,记住重要的信息,忘记不重要的信息。
上面也提到了短时记忆就是把上一步的输出h(t-1)和当前输入Xt做一次拼接 [Xt,h(t-1)],有了拼接的输入信息后,我们就做一次全连接层FC,然后对结果做sigmod激活处理,最后与长时记忆C(t-1) 做一次乘积
从上面的描述不难看出,参数就在全连接产生的参数 Wf 和 bf,这样说吧,整个LSTM的参数是由4个全连接层产生的
输入也就是词向量为200维,每个token用一个200的向量表示 即 Xt
输出ht为100维的向量, 则 [Xt,h(t-1)]=300维向量,即1×300
则 Wf 为 300 × 100的矩阵 Wf · [Xt,h(t-1)] = 1 ×300 · 300 ×100 = 1 × 100 然后进行向量的sigmoid。也就是上次来的细胞状态遗忘分数
输入门:当前的信息是不是重要的,如果重要的,就多记住一些,如果是不重要的就少记住一些
输入们和我们的理解东西也类似,这里引入了一个细胞状态的概念C'(t),就是有多少信息保存到细胞状态中,输入的和遗忘门的信息和输出们都一样,都是把上一步的输出h(t-1)和当前输入Xt做一次拼接 [Xt,h(t-1)],然后做一次全连接,所不同的是细胞状态的激活函数是tanh。其他的都是sigmod.
输入门还有另外一块输入信息i(t),这一部分就是拼接后的输入信息做一次全连接,然后经过sigmod激活函数 本次细胞状态记忆的分数
根据上面遗忘门的计算
i(t) C'(t) 的结构也是 1 × 100
长时记忆 C(t)
- 经过了遗忘门和输入门后,长时记忆C(t)就可以计算出来了,计算比较简单,就是把遗忘门和输入门的结果做一次加法
我们用公式说明就是: 注意不是矩阵乘,是对应位置乘
C(t) = ft*Ct-1 + it * C`t
对位置乘,不是矩阵乘
Ct =
1 × 100 * 1 × 100 + 1 × 100 * 1 × 100 = 1 × 100 + 1 × 100 = 1 × 100
输出门
ot 也是 1 × 100
Ct经过tanh后也是 1 × 100
ht 为两者的对应位置乘
即 ht = 1 × 100 * 1 × 100 = 1 × 100
计算完成后进行下一个字的计算
参考资料:
Understanding LSTM Networks -- colah's blog
LSTM单元结构及参数计算 - 知乎 (zhihu.com)