【469】RNN, LSTM参考博客
参考:【推荐】ML Lecture 21-1: Recurrent Neural Network (Part I)
参考:Illustrated Guide to Recurrent Neural Networks
参考:Illustrated Guide to LSTM’s and GRU’s: A step by step explanation
参考:Understanding LSTM Networks
参考:Recurrent Neural Networks | MIT 6.S191 —— YouTube
参考:The Unreasonable Effectiveness of Recurrent Neural Networks
参考:Anyone Can Learn To Code an LSTM-RNN in Python (Part 1: RNN)
参考:The Unreasonable Effectiveness of Recurrent Neural Networks
RNN
feed forward
- input1 -> hidden layer1 -> output1
- (hidden layer1 + input2) -> hidden layer2 -> output2
- (hidden layer2 + input3) -> hidden layer3 -> output3
- (hidden layer3 + input4) -> hidden layer4 -> output4
back propagation(橙色是权重更新)
- W_hy4 -> W_xh4
- W_hy3 -> W_xh3
- W_hy2 -> W_xh2
- W_hy1 -> W_xh1
RNN形式
下面以 many to one RNN 举例说明
参考:https://victorzhou.com/blog/intro-to-rnns/
The Plan
- 可以理解为一段文本,然后判断文本的sentiment
- 每一个 $x_i$ 表示文本中每个单词的 one-hot 编码向量
- 输出 $y$ 是一个二维向量,分别为 positive 和 negative
The Forward Phase
公式如下:
$h_t = tanh(W_{xh}x_t + W_{hh}h_{t-1} + b_h)$
$y_t = W_{hy}h_t + b_y$
- 初始化 3 个权重矩阵和 2 个偏置向量
- $W_{hh}$: [hidden_size, hidden_size],公式里面权重值在前
- $W_{xh}$: [hidden_size, input_size],公式里面权重值在前
- $W_{hy}$: [output_size, hidden_size],公式里面权重值在前
- $b_h$: [hidden_size, 1]
- $b_y$: [output_size, 1]
- 初始化 $h$,赋值为全 0 的矩阵,直接用来计算 $h_0$,因为它前面是没有 $h$ 的
- 循环,将 $x_i$ 的值循环传递到 $h_i$,再传递到 $y_i$
- 整个过程中,$W_{hh}$、$W_{xh}$、$W_{hy}$、$b_h$、$b_y$ 的值是保持不变的
- 最后 $y$ 通过 softmax 函数获得概率值
The Backward Phase
- loss function: cross-entropy loss
- 记录 $h_i$ 的所有值
- 【479】cross-entropy与softmax的求导
- 求出了$L$ 对于 $y_i$ 的偏导数
- 再求 $y_i$ 对于 $W_{hy}$ 的偏导数
- $y = W_{hy}h_n + b_y$
LSTM
一个 layer,用来获取 forget gate 的比例,激活函数是 sigmoid。用来计算前一个 cell 有多少部分被保留了。
两个 layer
- 左边为 input gate 的比例(sigmoid),说明有多少比例可以被输入
- 右边为从 $x_t$ 输入数据的部分(tanh)
- 两者相乘,表示有多少 $x_t$ 被输入进去
- 左边为上面所说的 forget gate 与前一个 cell state ($C_{t-1}$) 的乘积,表示保留的部分
- 右边为 input gate 与输入的信息的乘积,表示输入的部分
- 两者再相加,表示新的 cell state ($C_t$) 的结果
- 一个 output gate 的 layer(sigmoid),用来计算多少部分可以被输出
- 将前面算到的 cell state 做个 tanh 转黄,再与 output gate 的比例相乘
- 最终输出结果为 $h_t$
gif