MindSpore易点通·精讲系列--网络构建之LSTM算子--下篇
Dive Into MindSpore–Lstm Operator For Network Construction
MindSpore易点通·精讲系列–网络构建之LSTM算子
MindSpore易点通·精讲系列–网络构建之LSTM算子–上篇
MindSpore易点通·精讲系列–网络构建之LSTM算子–中篇
MindSpore易点通·精讲系列–网络构建之LSTM算子–下篇
本文开发环境
- MindSpore 1.7.0
本文内容提要
- 原理介绍
- 文档说明
- 案例解说
- 本文总结
- 本文参考
3. 案例解说
3.3 双层双向LSTM
本示例中随机生成了[4, 8, 4]数据,该数据batch_size为4,固定seq_length为8,输入维度为4。
本示例采用双层双向LSTM,隐层大小为8。
本示例中LSTM调用时进行对比测试,一个
seq_length
为默认值None,一个为有效长度input_seq_length
。
示例代码如下:
import numpy as np
from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM
def double_layer_bi_lstm():
random_data = np.random.rand(4, 8, 4)
seq_length = [3, 8, 5, 1]
input_seq_data = Tensor(random_data, dtype=dtype.float32)
input_seq_length = Tensor(seq_length, dtype=dtype.int32)
batch_size = 4
input_size = 4
hidden_size = 8
num_layers = 2
bidirectional = True
num_bi = 2 if bidirectional else 1
lstm = LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)
h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)
print("====== double layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
print("====== double layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
print("====== double layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)
print("====== double layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
print("====== double layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
print("====== double layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)
示例代码输出内容如下:
对输出内容进行分析:
- output_0和output_1维度都是[4, 8, 16],即batch_size, seq_length和hidden_size * 2,这里乘2是因为是双向输出。
- output_0和output_1皆是第二层(最后一层)的输出,中间层(本例为第一层)输出没有显示给出。
- output_0对应的是调用时seq_length为None的情况,即默认有效seq_length为8,可以看到output_0各个长度输出数值皆非全零。
- output_1对应的是调用时seq_length为设定值[3, 8, 5, 1],可以看到output_1超过有效长度的输出部分皆为全零。
- hn和cn分别为隐层状态和细胞状态输出。下面以hn_1和cn_1为例进行讲解。
- hn_1维度为[4, 4, 8],4代表双向双层(2*2),4代表batch_size,8代表hidden_size。
- 6中说明4代表双向双层(2*2),hn_1包含各层的最终有效隐层状态输出,这里同output_1只包含最后一层的输出不同。
- 仔细观察可以看出,hn_1中第一维度第2索引位置(即最后一层)的正向输出部分与output_1最后一维输出前hidden_size数值一致,即与有效长度内最后一个的输出的前hidden_size数值保持一致。
- 仔细观察可以看出,hn_1中第一维度第3索引位置(即最后一层)的反向输出部分与output_1开始一维输出后hidden_size数值一致。
- cn_1为有效最后一步的细胞状态。
====== double layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 3.70550364e-01 2.17652053e-01 3.79816592e-01 5.39002419e-01
2.28588611e-01 3.83301824e-02 2.20795229e-01 2.44438455e-01
2.06572518e-01 -3.78293954e-02 2.60271341e-01 -4.60247397e-02
-3.78369205e-02 -1.90976545e-01 -1.01466656e-01 1.76680252e-01]
...
...
...
...
...
...
[-8.48584175e-02 -4.15292941e-02 4.26153004e-01 -1.12198450e-01
2.93441713e-01 4.73045520e-02 7.22456872e-02 -1.52661309e-01
6.08003795e-01 1.02589525e-01 2.28410736e-01 3.57809156e-01
2.30974391e-01 7.29562640e-02 1.54908523e-01 1.37615114e-01]]
[[ 3.73128176e-01 2.24487275e-01 3.83654892e-01 5.39644539e-01
2.24863932e-01 3.69703583e-02 2.22563371e-01 2.47377262e-01
2.09958509e-01 -3.67934220e-02 2.55294740e-01 -5.44558465e-02
-3.49954516e-02 -1.88630879e-01 -9.97974724e-02 1.72440261e-01]
...
...
...
...
...
...
[-9.71160829e-02 -4.43801992e-02 4.20233607e-01 -1.02356419e-01
3.03063601e-01 3.99401113e-02 8.28935355e-02 -1.43912748e-01
6.09543681e-01 1.04935512e-01 2.27933496e-01 3.57850134e-01
2.31336534e-01 7.57181123e-02 1.55172557e-01 1.39436752e-01]]
[[ 3.74232024e-01 2.23312378e-01 3.80826175e-01 5.25748074e-01
2.30494052e-01 3.75359394e-02 2.19325155e-01 2.45338157e-01
1.90327644e-01 -9.49237868e-03 2.51282185e-01 -4.07305919e-02
-7.68693071e-03 -1.96041882e-01 -9.43402052e-02 1.52500823e-01]
...
...
...
...
...
...
[-1.07369550e-01 -7.64680207e-02 4.24612671e-01 -8.88631567e-02
3.25147092e-01 5.22605665e-02 7.02133700e-02 -1.30118832e-01
6.03053808e-01 1.08490229e-01 2.35621274e-01 3.42306137e-01
2.33348757e-01 7.23976195e-02 1.51835442e-01 1.38724014e-01]]
[[ 3.68833274e-01 2.19720796e-01 3.75712991e-01 5.39344609e-01
2.32777387e-01 3.75517495e-02 2.15990663e-01 2.38119900e-01
2.03846872e-01 -3.31601547e-03 2.63746709e-01 -5.33154309e-02
-1.53900171e-02 -1.96350247e-01 -9.86721516e-02 1.51238605e-01]
...
...
...
...
...
...
[-9.11041871e-02 -4.77942340e-02 4.29545075e-01 -1.14117011e-01
3.04611683e-01 5.14086746e-02 7.33837485e-02 -1.44734517e-01
6.06585741e-01 9.89784896e-02 2.24559098e-01 3.55441421e-01
2.28052005e-01 7.30600879e-02 1.55306384e-01 1.37683451e-01]]]
====== double layer bi lstm hn0 shape: (4, 4, 8) ======
[[[ 0.25934413 -0.07461581 0.19370164 0.11095355 0.02041678
0.29797387 0.03047622 0.19640712]
[ 0.2874061 -0.08844143 0.22119689 0.1251989 -0.01900517
0.29294112 0.05027778 0.2071664 ]
[ 0.2596095 0.03271259 0.26155 0.10348854 0.08536521
0.28197888 -0.08929807 0.18018515]
[ 0.2509837 -0.07010224 0.20813467 0.10349585 0.04007874
0.27277622 0.01278557 0.18474495]]
...
...
[[ 0.20657252 -0.0378294 0.26027134 -0.04602474 -0.03783692
-0.19097655 -0.10146666 0.17668025]
[ 0.20995851 -0.03679342 0.25529474 -0.05445585 -0.03499545
-0.18863088 -0.09979747 0.17244026]
[ 0.19032764 -0.00949238 0.2512822 -0.04073059 -0.00768693
-0.19604188 -0.09434021 0.15250082]
[ 0.20384687 -0.00331602 0.2637467 -0.05331543 -0.01539002
-0.19635025 -0.09867215 0.1512386 ]]]
====== double layer bi lstm cn0 shape: (4, 4, 8) ======
[[[ 0.5770398 -0.16899881 0.40028483