李宏毅《机器学习》总结 - RNN & LSTM

在 slot-filling 问题(如给一个句子,自己分析出时间、地点等)
能解决的问题如给若干个向量,输出相同数量的向量
image
如果只连着不同的 FC,那么会导致无法读出是 arrive 还是 leave 的情况,导致错误
因此,需要 NN 来考虑到整个句子的信息,也就是需要有 memory,这就是 RNN

RNN 原理

image
有了 memory,就可以初步解决同一个信息由于句子不同导致的意义不同的问题了
更进一步的,有这样的结构:
image
黄色代表输入,即每一个单词。绿色代表隐藏层,注意事实上只有一个网络,不同单词对应的网络的参数是相同的,橙色代表输出,是一个概率向量
image
在不同的单词的网络之间传递的 memory 除了可以从隐藏层得到之外,还可以从输出层得到,这对应了另外一种架构:
image

Bidirectional RNN

普通的 RNN 处理每个单词的时候只能获取其前面的信息,但是通过双向 RNN,可以得到其后面的信息。
简单来说就是训练两个 RNN,一个正向一个反向,再将两个 RNN 的同一个对应的隐藏层扔给输出层,就得到了输出。
image

LSTM

是 RNN 的改良。每一个位置有 4 个输入,除了数据的输入(input)之外,还有是否将 input gate 打开(输入进网络中)、是否要遗忘 memory 里的数据,是否要打开 output gate(输出到输出层中)
image
在 RNN 中,memory 只能记着上一层的隐藏层,但是 LSTM 中,能记着更久远的事情了,因此叫“long short-term ..”
具体的,在实作中,关于 input/output gate 是否打开,可以利用 sigmoid function 来实现,如果输入是一个负值,那么就可以认为相应的 gate 是关闭的(因为 \(f(z_i) \rightarrow 0\)
另外,当 \(z_f>0\) 时在 memory 中的数据是要保存的,反之则是遗忘,因此应该叫做“keep gate” 而非 "forget gate" 更为恰当
image
关于如何进行输入的问题,也就是 \(z_i, z_o, z_f\) 是怎么得到的问题,可以这么看:
对于每一个 LSTM 的 cell(即上图) 而言,都是有好几个原始输入变量的,原始输入变量线性组合得到 \(z_{xx}\),而得到 \(z_{xx}\) 时的线性组合函数不同,因此得到了不同的输入。
以输入的几个原始变量 \((x_1,x_2,x_3)=(2,0,0)\) 为例:
image
绿色框代表 bias,也就是默认情况下 input gate 和 output gate 是关的,而 forget gate 是开的,不同的 \(z\) 对原始变量做 weighted sum(实际上这就是需要训练的网络),就得到了 \(z\) 的值
可以发现,当 \((2,0,0)\) 时,input/output gate 是关的,forget gate 是开的。
\((1,0,1)\) 时,经过计算可以得到 input 是关的,output gate 和 forget gate 是开的。
也就是说,如果有 2 个初始输入的话,每个初始输入需要 4 倍的参数来 transform 得到所需要的 input
image
那么,LSTM 是怎么建立在 RNN 上的呢?这就需要多个 cell 串起来了,具体地说:
设当前处理的向量是 \(x^t\)\(x^t\) 通过和某一个矩阵做矩阵乘法(这个矩阵也是需要学习的),得到 \(z\) 向量,而将 \(z\) 的每一个 dimension 拆出来,作为每一个对应的 cell 的 input,同理得到 \(z^i, z^o, z^f\) 向量,也是同理拆出每一个 dimension 作为 input/output/forget gate 的输入
image
可以简单记为 “z向量” 作为 cell 的 input,(实际上还是指将 \(z\) 的每一个 dimension 拿出来分别作为对应 cell 的 input)易得(相当于是把 \(|z|\) 个 cell 串起来了):
image
可以将一个 cell 做的事情用线性代数表示:
image
(其中,\(\times\) 表示向量对应位置相乘所组成的向量),可以发现这样的写法和 \(|z|\) 个 cell 串起来的效果一样
除此之外,作为原始输入的 \(x^t\) 还应该考虑 memory 和 上一个输出层的影响。将他们串起来,就得到了 LSTM 的最终原理图:
image

posted @ 2024-01-29 21:04  SkyRainWind  阅读(86)  评论(0编辑  收藏  举报