Dive Into MindSpore: LSTM Operator for Network Construction (Part 1)
Dive Into MindSpore: LSTM Operator for Network Construction
Dive Into MindSpore: LSTM Operator for Network Construction (Part 1)
Dive Into MindSpore: LSTM Operator for Network Construction (Part 2)
Dive Into MindSpore: LSTM Operator for Network Construction (Part 3)
Development Environment
- MindSpore 1.7.0
Contents Overview
- Principle Introduction
- Documentation Notes
- Case Walkthrough
- Summary
- References
3. Case Walkthrough
3.2 Single-Layer Bidirectional LSTM
In this example we randomly generate data of shape [4, 8, 4]: the batch_size is 4, the (padded) seq_length is fixed at 8, and the input dimension is 4.
The example uses a single-layer bidirectional LSTM with a hidden size of 8.
The LSTM is called twice for comparison: once with seq_length left at its default value of None, and once with the effective lengths input_seq_length.
The sample code is as follows:
import numpy as np

from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM


def single_layer_bi_lstm():
    # Random input: batch_size=4, seq_length=8, input_size=4
    random_data = np.random.rand(4, 8, 4)
    # Effective length of each sequence in the batch
    seq_length = [3, 8, 5, 1]

    input_seq_data = Tensor(random_data, dtype=dtype.float32)
    input_seq_length = Tensor(seq_length, dtype=dtype.int32)

    batch_size = 4
    input_size = 4
    hidden_size = 8
    num_layers = 1
    bidirectional = True
    num_bi = 2 if bidirectional else 1

    lstm = LSTM(
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)

    # Initial hidden and cell states: [num_directions * num_layers, batch_size, hidden_size]
    h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
    c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))

    # Call once with the default seq_length (None), once with the effective lengths
    output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
    output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)

    print("====== single layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
    print("====== single layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
    print("====== single layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)
    print("====== single layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
    print("====== single layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
    print("====== single layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)


if __name__ == "__main__":
    single_layer_bi_lstm()
The output of the sample code is reproduced at the end of this section.
Analysis of the output:
- output_0 and output_1 both have shape [4, 8, 16], i.e. batch_size, seq_length, and hidden_size * 2; the factor of 2 comes from concatenating the forward and backward outputs.
- output_0 corresponds to the call with seq_length left as None, i.e. every sequence is treated as having the full effective length of 8; accordingly, output_0 is non-zero at every time step.
- output_1 corresponds to the call with seq_length set to [3, 8, 5, 1]; the portions of output_1 beyond each sequence's effective length are all zeros.
- hn and cn are the final hidden state and cell state, respectively. The following takes hn_1 and cn_1 as examples.
- hn_1 has shape [2, 4, 8]: 2 is num_directions * num_layers (2 * 1), 4 is batch_size, and 8 is hidden_size.
- A close look shows that the forward part, hn_1[0], matches the first hidden_size values of output_1 at each sequence's last valid time step.
- Likewise, the backward part, hn_1[1], matches the last hidden_size values of output_1 at time step 0.
- cn_1 is the cell state at the last valid time step.
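The three behaviors above can also be checked without MindSpore. The following is a minimal NumPy sketch of a bidirectional LSTM with per-sequence effective lengths; it is not MindSpore's implementation, and the helper names `lstm_cell` and `bi_lstm` as well as the random weight initialization are purely illustrative. It reproduces the same layout: output width hidden_size * 2, zeros past the effective length, forward final state at the last valid step, and backward final state at step 0.

```python
import numpy as np

def lstm_cell(x, h, c, W, U, b):
    # One LSTM step; gates packed along the last axis as [i, f, g, o].
    z = x @ W + h @ U + b
    H = h.shape[-1]
    sig = lambda v: 1.0 / (1.0 + np.exp(-v))
    i, f = sig(z[:H]), sig(z[H:2 * H])
    g, o = np.tanh(z[2 * H:3 * H]), sig(z[3 * H:])
    c = f * c + i * g
    h = o * np.tanh(c)
    return h, c

def bi_lstm(x, seq_len, hidden, rng):
    # x: (batch, seq, input) -> output (batch, seq, 2*hidden), hn (2, batch, hidden)
    batch, T, D = x.shape
    # Independent random weights for the forward and backward directions.
    params = [(rng.standard_normal((D, 4 * hidden)) * 0.1,
               rng.standard_normal((hidden, 4 * hidden)) * 0.1,
               np.zeros(4 * hidden)) for _ in range(2)]
    out = np.zeros((batch, T, 2 * hidden))
    hn = np.zeros((2, batch, hidden))
    for b in range(batch):
        L = seq_len[b]
        h = c = np.zeros(hidden)   # forward pass over the valid steps only
        for t in range(L):
            h, c = lstm_cell(x[b, t], h, c, *params[0])
            out[b, t, :hidden] = h
        hn[0, b] = h               # forward state at the last valid step
        h = c = np.zeros(hidden)   # backward pass over the valid steps only
        for t in reversed(range(L)):
            h, c = lstm_cell(x[b, t], h, c, *params[1])
            out[b, t, hidden:] = h
        hn[1, b] = h               # backward state after reading step 0
    return out, hn

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 4))
out, hn = bi_lstm(x, seq_len=[3, 8, 5, 1], hidden=8, rng=rng)
print(out.shape)  # (4, 8, 16)
```

With these shapes, `out[0, 3:]` is all zeros (effective length 3), `hn[0, 0]` equals `out[0, 2, :8]`, and `hn[1, 0]` equals `out[0, 0, 8:]`, mirroring the analysis of output_1, hn_1, and cn_1 above.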
====== single layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 0.11591419 0.29961097 0.3425573 0.4287143 0.17212108
0.07444338 0.43271446 0.15715674 0.08194006 0.11577142
-0.09744498 -0.02763127 0.09280778 0.08716499 0.02522062
0.33181873]
[-0.01308823 0.13623668 0.19448121 0.37028143 0.22777143
0.00628781 0.39128026 0.15501572 0.08111142 0.11017906
-0.12316822 -0.00816909 0.09567513 0.05021677 0.08249568
0.33742255]
[-0.05627449 0.04682723 0.15380071 0.3137156 0.26430035
-0.046514 0.35723254 0.16584632 0.10204285 0.10223756
-0.13232729 -0.00190703 0.11279006 0.07007243 0.07809626
0.36085904]
[-0.09489179 -0.00705127 0.1340199 0.24711385 0.27097055
-0.05539801 0.29088783 0.180727 0.13702057 0.07165765
-0.15263684 -0.02301912 0.14440101 0.09643525 0.04434848
0.32824463]
[-0.13192342 -0.09842218 0.13483751 0.2363211 0.2714419
-0.06301905 0.23002718 0.12190706 0.1600955 0.0820565
-0.13324322 0.00847512 0.15308659 0.12757084 0.06873622
0.3726861 ]
[-0.16037701 -0.12437794 0.12642992 0.23676534 0.29797453
-0.04277696 0.24219972 0.16359471 0.16195399 0.07269616
-0.1250204 -0.0185749 0.19040069 0.12709007 0.12064856
0.30454746]
[-0.1353235 -0.12385159 0.1025193 0.23867385 0.30110353
-0.03195428 0.2832907 0.18136714 0.19130123 0.09153596
-0.05207976 0.02430173 0.2524703 0.22256352 0.17788586
0.3196903 ]
[-0.15227936 -0.16710246 0.11279354 0.2324703 0.3158889
-0.05391366 0.28967926 0.21905534 0.34464788 0.06061291
0.10662059 0.08228769 0.38103724 0.44488934 0.22631703
0.38864976]]
[[ 0.07946795 0.30921736 0.35205007 0.37194842 0.2058839
0.09482588 0.4332572 0.2775039 0.10343523 0.07151344
-0.13616626 -0.04245609 0.10985457 0.06919786 0.0364913
0.31924048]
[-0.04591701 0.14795585 0.20307627 0.35713255 0.21074952
0.03478044 0.36047992 0.15351431 0.11235587 0.07168273
-0.11715946 -0.02380875 0.11772131 0.11803672 0.00387634
0.33266184]
[-0.09412251 0.02499678 0.17255405 0.3178058 0.23692454
-0.03471331 0.26576498 0.10732022 0.14581609 0.07355653
-0.12852795 0.01927058 0.13053373 0.14796041 0.01590303
0.3854578 ]
[-0.09348419 0.00631614 0.1466178 0.22848201 0.22966608
-0.05388562 0.14963126 0.08823045 0.15729474 0.0657778
-0.15222837 -0.01835432 0.15758416 0.17561477 -0.03188463
0.3511778 ]
[-0.15382743 -0.04836275 0.14573918 0.22835778 0.2532363
-0.03674607 0.1401736 0.09852327 0.17570393 0.04582136
-0.13850203 0.00081276 0.16863164 0.14211492 0.04397457
0.33833435]
[-0.14028388 -0.08847751 0.13194019 0.21878807 0.28851762
-0.06432837 0.15592363 0.16226491 0.20294866 0.04400881
-0.11535563 0.04870296 0.22049154 0.17808373 0.09339966
0.34441146]
[-0.1683049 -0.16189072 0.1318028 0.22591397 0.3027075
-0.07447627 0.15145044 0.1329806 0.2544369 0.06014252
-0.01793557 0.11026148 0.2146467 0.3118566 0.12141219
0.39812002]
[-0.19805393 -0.17752953 0.12876241 0.21628919 0.3038769
-0.036511 0.1357605 0.10460708 0.3527281 0.07156999
0.1540587 0.09252883 0.35960466 0.54258245 0.16377062