MindSpore易点通·精讲系列--模型训练之GPU分布式并行训练

Dive Into MindSpore–Lstm Operator For Network Construction

MindSpore易点通·精讲系列–网络构建之LSTM算子

MindSpore易点通·精讲系列–网络构建之LSTM算子–上篇
MindSpore易点通·精讲系列–网络构建之LSTM算子–中篇
MindSpore易点通·精讲系列–网络构建之LSTM算子–下篇

本文开发环境

  • MindSpore 1.7.0

本文内容提要

  • 原理介绍
  • 文档说明
  • 案例解说
  • 本文总结
  • 本文参考

1. 原理介绍

LSTM,Long Short Term Memory,又称长短时记忆网络。原始RNN存在一个严重的缺陷:训练过程中经常会出现梯度爆炸和梯度消失的问题,以至于原始的RNN很难处理长距离的依赖,为了解决(缓解)这个问题,研究人员提出了LSTM。

1.1 LSTM公式

LSTM的公式表示如下:

\begin{split}\begin{array}{ll} \\ i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\ f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\ \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\ o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\ c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\ h_t = o_t * \tanh(c_t) \\ \end{array}\end{split}

其中σ是sigmoid激活函数, *是乘积。 W, b 是公式中输出和输入之间的可学习权重。

1.2 LSTM结构

为方便理解,1.1中的公式的结构示意图如下:

2022_05_06_lstm_inner_structure

1.3 LSTM门控

1.3.1 遗忘门

遗忘门公式为:

f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh})ft=σ(Wfxxt+bfx+Wfhh(t1)+bfh)

解读:

“遗忘门”决定之前状态中的信息有多少应该舍弃。它会读取 h_{t-1}ht1 和 x_txt的内容,\sigmaσ符号代表Sigmoid函数,它会输出一个0到1之间的值。其中0代表舍弃之前细胞状态C_{t-1}Ct1中的内容,1代表完全保留之前细胞状态C_{t-1}Ct1中的内容。0、1之间的值代表部分保留之前细胞状态C_{t-1}Ct1中的内容。

1.3.2 输入门

输入门公式为:

i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\ \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch})it=σ(Wixxt+bix+Wihh(t1)+bih)c~t=tanh(Wcxxt+bcx+Wchh(t1)+bch)

解读:

“输入门”决定什么样的信息保留在细胞状态C_tCt中,它会读取 h_{t-1}ht1 和 x_txt的内容,\sigmaσ符号代表Sigmoid函数,它会输出一个0到1之间的值。

和“输入门”配合的还有另外一部分,这部分输入也是h_{t-1}ht1 和 x_txt,不过采用tanh激活函数,将这部分标记为\tilde c^{(t)}c~(t),称作为“候选状态”。

1.3.3 细胞状态

细胞状态公式为:

c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_tct=ftc(t1)+itc~t

解读:

C_{t-1}Ct1 计算得到C_tCt

旧“细胞状态”C_{t-1}Ct1和“遗忘门”的结果进行计算,决定旧的“细胞状态”保留多少,忘记多少。接着“输入门”i^{(t)}i(t)和候选状态\tilde c^{(t)}c~(t)进行计算,将所得到的结果加入到“细胞状态”中,这表示新的输入信息有多少加入到“细胞状态中”。

1.3.4 输出门

输出门公式为:

o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\ h_t = o_t * \tanh(c_t)ot=σ(Woxxt+box+Wohh(t1)+boh)ht=ottanh(ct)

解读:

和其他门计算一样,它会读取 h_{t-1}ht1 和 x_txt的内容,然后计算Sigmoid函数,得到“输出门”的值。接着把“细胞状态”通过tanh进行处理(得到一个在-1到1之间的值),并将它和输出门的结果相乘,最终得到确定输出的部分h_tht,即新的隐藏状态。

特别说明:

在上述公式中,xt为当前的输入,h(t-1)为上一步的隐藏状态,c(t-1)为上一步的细胞状态。

当t=1时,可知h(t-1)为h0,c(t-1)为c0。

一般来说,h0/c0设置为0或1,或固定的随机值。

2. 文档说明

下面来看看官网文档说明,主要看参数部分:

2022_05_06_ms_lstm_api

从官方文档可知,MindSpore中的LSTM算子支持多层双向设置,同时可接受输入数据第一维为非batch_size的情况,而且自带dropout。

下面通过案例来对该算子的输入和输出进行讲解。

3. 案例解说

3.1 单层正向LSTM

本示例中随机生成了[4, 8, 4]数据,该数据batch_size为4,固定seq_length为8,输入维度为4。

本示例采用单层单向LSTM,隐层大小为8。

本示例中LSTM调用时进行对比测试,一个seq_length为默认值None,一个为有效长度input_seq_length

示例代码如下:

import numpy as np

from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM


def single_layer_lstm():
    random_data = np.random.rand(4, 8, 4)
    seq_length = [3, 8, 5, 1]
    input_seq_data = Tensor(random_data, dtype=dtype.float32)
    input_seq_length = Tensor(seq_length, dtype=dtype.int32)

    batch_size = 4
    input_size = 4
    hidden_size = 8
    num_layers = 1
    bidirectional = False
    num_bi = 2 if bidirectional else 1

    lstm = LSTM(
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)

    h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
    c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))

    output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
    output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)

    print("====== single layer lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
    print("====== single layer lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
    print("====== single layer lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)

    print("====== single layer lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
    print("====== single layer lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
    print("====== single layer lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)

示例代码输出内容如下:

对输出内容进行分析:

  1. output_0和output_1维度都是[4, 8, 8],即batch_size, seq_length和hidden_size
  2. output_0对应的是调用时seq_length为None的情况,即默认有效seq_length为8,可以看到output_0各个长度输出数值皆非全零。
  3. output_1对应的是调用时seq_length为设定值[3, 8, 5, 1],可以看到output_1超过有效长度的输出部分皆为全零。
  4. hn和cn分别为隐层状态和细胞状态输出。下面以hn_1和cn_1为例进行讲解。
  5. hn_1维度为[1, 4, 8],1代表单向单层(1*1),4代表batch_size,8代表hidden_size。
  6. 仔细观察可以看出,hn_1的输出与output_1最后一维的输出一致,即与有效长度内最后一个的输出保持一致。
  7. cn_1为有效最后一步的细胞状态。
====== single layer lstm output 0 shape: (4, 8, 8) ======
[[[ 0.13193643  0.31574252  0.21773982  0.359429    0.23590101
    0.28213733  0.24443595  0.37388077]
  [-0.02988351  0.1415896   0.15356182  0.2834958  -0.00328176
    0.3491612   0.12643641  0.142024  ]
  [-0.09670443  0.03373189  0.1445203   0.19673887  0.06278481
    0.33509392 -0.02579015  0.07650157]
  [-0.15380219 -0.04781847  0.07795938  0.15893918  0.01305779
    0.33979264 -0.00364386  0.04361304]
  [-0.16254447 -0.06737433  0.05285644  0.10944269  0.01782622
    0.34567034 -0.04204851  0.01285298]
  [-0.21082401 -0.09526701  0.0265205   0.10617667 -0.03112434
    0.33731762 -0.02207689 -0.00955394]
  [-0.23450094 -0.09586379  0.02365175  0.09352495 -0.03744857
    0.33376914 -0.04699665 -0.03528202]
  [-0.24089803 -0.06166056  0.02839395  0.09916345 -0.04156012
    0.31369895 -0.08876226 -0.0487675 ]]

 [[ 0.10673305  0.30631748  0.22279048  0.35392687  0.270858
    0.2800686   0.21576329  0.37215734]
  [ 0.07373721  0.07924869  0.20754944  0.2059646   0.12672944
    0.35556036  0.05576535  0.2124105 ]
  [-0.09233213  0.02507205  0.11608997  0.23507075  0.0269099
    0.3196378   0.00475359  0.05898073]
  [-0.14939436 -0.04166775  0.07941992  0.15797664  0.02167228
    0.34059638 -0.02956495  0.00525782]
  [-0.18659307 -0.08790994  0.04543061  0.12085741  0.01649844
    0.33063915 -0.03531799 -0.01156766]
  [-0.22867033 -0.10603286  0.03872797  0.11688479  0.01904946
    0.3056394  -0.05695718 -0.01623933]
  [-0.21695574 -0.11095987  0.03115554  0.08672465  0.04249544
    0.3152427  -0.07418983 -0.02036544]
  [-0.21967101 -0.10076816  0.01712734  0.08198812  0.02862469
    0.31535396 -0.09173042 -0.05647325]]

 [[ 0.1493079   0.28768584  0.2575181   0.3199168   0.30599245
    0.28865623  0.16678075  0.41237575]
  [ 0.01445133  0.13631815  0.18265024  0.2577204   0.09361918
    0.3227448   0.04080902  0.17163058]
  [-0.1164555   0.05409181  0.1229048   0.24406306  0.02090637
    0.31171325 -0.02868806  0.06015658]
  [-0.12215493 -0.04073931  0.09229688  0.13461691  0.05322267
    0.34697118 -0.04028781  0.05017967]
  [-0.16058712 -0.02990636  0.06711683  0.13881728  0.04944531
    0.30471358 -0.08764775  0.01227296]
  [-0.17542893 -0.04518626  0.06441598  0.12666796  0.1039256
    0.29512212 -0.12625514 -0.01764686]
  [-0.18198647 -0.06205402  0.05437353  0.12312049  0.11571115
    0.27589387 -0.13898477 -0.00659172]
  [-0.18840623 -0.03089028  0.02871101  0.13332503  0.02779378
    0.2934873  -0.12758468 -0.02508291]]

 [[ 0.16055782  0.28248906  0.24979302  0.3381475   0.28849283
    0.3085897   0.21882199  0.3911534 ]
  [ 0.03212452  0.10363571  0.18571742  0.25555134  0.11808199
    0.33315352  0.0612903   0.16566488]
  [-0.09707587  0.08886775  0.130165    0.23324937  0.0596167
    0.28433815 -0.05993269  0.06611289]
  [-0.15705962 -0.00274712  0.09360209  0.18597823  0.04157853
    0.32279128 -0.07580574  0.01155218]
  [-0.15376413 -0.07929687  0.06302985  0.11465057  0.07184268
    0.3261627  -0.05871713  0.04223134]
  [-0.18791473 -0.07859336  0.02364462  0.12526496 -0.02513029
    0.33071572 -0.03542359 -0.00976665]
  [-0.23625109 -0.03007499  0.03267653  0.15940045 -0.08530897
    0.30445266 -0.0852924  -0.04507463]
  [-0.23499809 -0.07687293  0.03790941  0.08663946 -0.00264841
    0.33423126 -0.06512782  0.01413365]]]
====== single layer lstm hn0 shape: (1, 4, 8) ======
[[[-0.24089803 -0.06166056  0.02839395  0.09916345 -0.04156012
    0.31369895 -0.08876226 -0.0487675 ]
  [-0.21967101 -0.10076816  0.01712734  0.08198812  0.02862469
    0.31535396 -0.09173042 -0.05647325]
  [-0.18840623 -0.03089028  0.02871101  0.13332503  0.02779378
    0.2934873  -0.12758468 -0.02508291]
  [-0.23499809 -0.07687293  0.03790941  0.08663946 -0.00264841
    0.33423126 -0.06512782  0.01413365]]]
====== single layer lstm cn0 shape: (1, 4, 8) ======
[[[-0.72842515 -0.10623126  0.07748945  0.23840414 -0.0663506
    0.82394135 -0.20612013 -0.11983471]
  [-0.6431069  -0.17861958  0.04168103  0.20188545  0.0463764
    0.73273325 -0.21914008 -0.13169488]
  [-0.61163914 -0.05123866  0.07892742  0.32583922  0.04181815
    0.79872614 -0.2969701  -0.0625343 ]
  [-0.58037984 -0.15040846  0.09998614  0.24211554 -0.0044073
    0.8616534  -0.1546249   0.03137078]]]
====== single layer lstm output 1 shape: (4, 8, 8) ======
[[[ 0.13193643  0.31574252  0.21773985  0.35942894  0.23590101
    0.28213733  0.24443595  0.37388077]
  [-0.02988352  0.1415896   0.15356182  0.28349578 -0.00328175
    0.34916118  0.12643641  0.142024  ]
  [-0.09670443  0.0337319   0.14452031  0.19673884  0.06278481
    0.33509392 -0.02579015  0.07650157]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]

 [[ 0.10673306  0.30631748  0.22279048  0.35392687  0.27085796
    0.2800686   0.21576326  0.37215734]
  [ 0.07373722  0.0792487   0.20754944  0.2059646   0.12672943
    0.35556036  0.05576536  0.2124105 ]
  [-0.09233214  0.02507207  0.11608997  0.23507075  0.02690989
    0.3196378   0.00475359  0.05898073]
  [-0.14939436 -0.04166774  0.07941992  0.15797664  0.02167228
    0.34059638 -0.02956495  0.00525782]
  [-0.18659307 -0.08790994  0.04543061  0.12085741  0.01649844
    0.33063915 -0.03531799 -0.01156766]
  [-0.22867033 -0.10603285  0.03872797  0.11688479  0.01904945
    0.3056394  -0.05695718 -0.01623933]
  [-0.21695574 -0.11095986  0.03115554  0.08672465  0.04249543
    0.3152427  -0.07418983 -0.02036544]
  [-0.21967097 -0.10076815  0.01712734  0.08198812  0.02862468
    0.31535396 -0.09173042 -0.05647324]]

 [[ 0.1493079   0.28768584  0.25751814  0.3199168   0.30599245
    0.28865623  0.16678077  0.41237575]
  [ 0.01445133  0.13631816  0.18265024  0.25772038  0.09361918
    0.3227448   0.04080902  0.17163058]
  [-0.1164555   0.05409183  0.1229048   0.24406303  0.02090637
    0.31171325 -0.02868806  0.06015658]
  [-0.12215493 -0.0407393   0.09229688  0.1346169   0.05322267
    0.3469712  -0.0402878   0.05017967]
  [-0.16058712 -0.02990635  0.06711683  0.13881728  0.0494453
    0.30471358 -0.08764775  0.01227296]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]

 [[ 0.16055782  0.2824891   0.24979301  0.33814746  0.28849283
    0.30858967  0.21882202  0.3911534 ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.        ]]]
====== single layer lstm hn1 shape: (1, 4, 8) ======
[[[-0.09670443  0.0337319   0.14452031  0.19673884  0.06278481
    0.33509392 -0.02579015  0.07650157]
  [-0.21967097 -0.10076815  0.01712734  0.08198812  0.02862468
    0.31535396 -0.09173042 -0.05647324]
  [-0.16058712 -0.02990635  0.06711683  0.13881728  0.0494453
    0.30471358 -0.08764775  0.01227296]
  [ 0.16055782  0.2824891   0.24979301  0.33814746  0.28849283
    0.30858967  0.21882202  0.3911534 ]]]
====== single layer lstm cn1 shape: (1, 4, 8) ======
[[[-0.22198828  0.05788375  0.38487202  0.5277796   0.10692163
    0.88817626 -0.06333658  0.15489307]
  [-0.6431068  -0.17861956  0.04168103  0.20188545  0.04637639
    0.73273325 -0.21914008 -0.13169487]
  [-0.44337854 -0.05043292  0.17615467  0.36942852  0.0769525
    0.8138213  -0.22219141  0.02737183]
  [ 0.50136805  0.47527558  0.8696786   0.7511291   0.37594885
    0.9162327   0.5345433   0.6333548 ]]]

MindSpore易点通·精讲系列–网络构建之LSTM算子–上篇
MindSpore易点通·精讲系列–网络构建之LSTM算子–中篇
MindSpore易点通·精讲系列–网络构建之LSTM算子–下篇

本文为原创文章,版权归作者所有,未经授权不得转载!

posted @ 2022-08-11 18:30  Skytier  阅读(133)  评论(0编辑  收藏  举报