MindSpore Yidiantong · In-Depth Series -- LSTM Operator for Network Construction -- Part 1

Dive Into MindSpore–LSTM Operator For Network Construction

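MindSpore Yidiantong · In-Depth Series – LSTM Operator for Network Construction

MindSpore Yidiantong · In-Depth Series – LSTM Operator for Network Construction – Part 1
MindSpore Yidiantong · In-Depth Series – LSTM Operator for Network Construction – Part 2
MindSpore Yidiantong · In-Depth Series – LSTM Operator for Network Construction – Part 3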

Development environment for this article

  • MindSpore 1.7.0

Outline of this article

  • Principles
  • Documentation notes
  • Case walkthrough
  • Summary
  • References

3. Case Walkthrough

3.2 Single-Layer Bidirectional LSTM

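This example randomly generates data of shape [4, 8, 4]: batch_size is 4, the fixed seq_length is 8, and the input dimension (input_size) is 4.

The example uses a single-layer bidirectional LSTM with a hidden size of 8.

The LSTM is called twice for comparison: once with seq_length left at its default of None, and once with the valid lengths input_seq_length passed in.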
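As a quick sanity check on the configuration described above, the expected tensor shapes can be worked out by hand. The following is a minimal sketch in plain Python. The helper name `expected_lstm_shapes` is hypothetical and not part of MindSpore, and it assumes batch_first=True as in this example.

```python
def expected_lstm_shapes(batch_size, seq_length, input_size,
                         hidden_size, num_layers, bidirectional):
    """Return the shapes expected for a batch_first LSTM call.

    Hypothetical helper for illustration only; not a MindSpore API.
    """
    num_directions = 2 if bidirectional else 1
    return {
        "input": (batch_size, seq_length, input_size),
        # h0/c0 and hn/cn all share this shape.
        "state": (num_directions * num_layers, batch_size, hidden_size),
        # Forward and backward outputs are concatenated on the last axis.
        "output": (batch_size, seq_length, num_directions * hidden_size),
    }


shapes = expected_lstm_shapes(batch_size=4, seq_length=8, input_size=4,
                              hidden_size=8, num_layers=1, bidirectional=True)
for name, shape in shapes.items():
    print(name, shape)
```

For the settings used in this section, this gives an input of (4, 8, 4), states of (2, 4, 8), and an output of (4, 8, 16).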

The example code is as follows:

import numpy as np

from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM


def single_layer_bi_lstm():
    # Random input of shape [batch_size, seq_length, input_size] = [4, 8, 4].
    random_data = np.random.rand(4, 8, 4)
    # Valid (unpadded) length of each of the 4 sequences in the batch.
    seq_length = [3, 8, 5, 1]
    input_seq_data = Tensor(random_data, dtype=dtype.float32)
    input_seq_length = Tensor(seq_length, dtype=dtype.int32)

    batch_size = 4
    input_size = 4
    hidden_size = 8
    num_layers = 1
    bidirectional = True
    num_bi = 2 if bidirectional else 1

    lstm = LSTM(
        input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
        has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)

    # Initial hidden and cell states: [num_bi * num_layers, batch_size, hidden_size].
    h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
    c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))

    # Call 0: seq_length defaults to None, so all 8 time steps count as valid.
    output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
    # Call 1: pass the valid lengths so later steps are treated as padding.
    output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)

    print("====== single layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
    print("====== single layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
    print("====== single layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)

    print("====== single layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
    print("====== single layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
    print("====== single layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)


if __name__ == "__main__":
    single_layer_bi_lstm()

The example code's output is reproduced in full after the following analysis.

Analysis of the output:

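  1. output_0 and output_1 both have shape [4, 8, 16], i.e. batch_size, seq_length, and hidden_size * 2; the factor of 2 comes from the bidirectional output.
  2. output_0 corresponds to the call with seq_length left as None, so the valid length defaults to 8 for every sequence; none of its time steps is all zeros.
  3. output_1 corresponds to the call with seq_length set to [3, 8, 5, 1]; every time step beyond a sequence's valid length is all zeros.
  4. hn and cn are the hidden-state and cell-state outputs. hn_1 and cn_1 are used as the example below.
  5. hn_1 has shape [2, 4, 8]: 2 stands for bidirectional times single layer (2 * 1), 4 is batch_size, and 8 is hidden_size.
  6. On close inspection, the forward part of hn_1 (index 0 of the first dimension) matches the first hidden_size values of output_1 at the last time step within each sequence's valid length.
  7. Likewise, the backward part of hn_1 (index 1 of the first dimension) matches the last hidden_size values of output_1 at the first time step.
  8. cn_1 is the cell state at the last valid step.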
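Observations 6 and 7 can be checked programmatically rather than by eye. The following is a minimal NumPy sketch. The helper name `check_bi_lstm_states` is hypothetical (not a MindSpore API), and it assumes a single-layer, batch_first bidirectional LSTM as in this example. The sample values are copied from the printed output_1 and hn_1 for the last sequence in the batch (valid length 1).

```python
import numpy as np


def check_bi_lstm_states(output, hn, seq_length, hidden_size):
    """Compare hn with the matching slices of a batch_first bi-LSTM output.

    output:     [batch_size, seq_length, 2 * hidden_size]
    hn:         [2, batch_size, hidden_size], index 0 forward, index 1 backward
    seq_length: valid length of each sequence, shape [batch_size]
    Hypothetical helper for illustration only; not a MindSpore API.
    """
    batch_index = np.arange(output.shape[0])
    last_valid_step = np.asarray(seq_length) - 1
    # Forward direction: first hidden_size values at the last valid time step.
    forward_last = output[batch_index, last_valid_step, :hidden_size]
    # Backward direction: last hidden_size values at time step 0.
    backward_first = output[:, 0, hidden_size:]
    forward_ok = np.allclose(hn[0], forward_last, atol=1e-6)
    backward_ok = np.allclose(hn[1], backward_first, atol=1e-6)
    return forward_ok, backward_ok


# output_1 for the last sequence: only time step 0 is valid, the rest is zero.
output_1_sample = np.zeros((1, 8, 16), dtype=np.float32)
output_1_sample[0, 0] = [
    0.09268619, 0.35032618, 0.34263822, 0.33635783, 0.19130397,
    0.089779, 0.3541034, 0.26252666, 0.34620598, 0.06714007,
    0.13512857, 0.04233981, 0.42014182, 0.5216394, 0.18838547,
    0.3683127,
]
# hn_1 for the same sequence: forward row, then backward row.
hn_1_sample = np.array([
    [[0.09268619, 0.35032618, 0.34263822, 0.33635783,
      0.19130397, 0.089779, 0.3541034, 0.26252666]],
    [[0.34620598, 0.06714007, 0.13512857, 0.04233981,
      0.42014182, 0.5216394, 0.18838547, 0.3683127]],
], dtype=np.float32)

forward_ok, backward_ok = check_bi_lstm_states(
    output_1_sample, hn_1_sample, seq_length=[1], hidden_size=8)
print(forward_ok, backward_ok)
```

With the real tensors from the example, the same check could be run on `output_1.asnumpy()` and `hn_1.asnumpy()` with seq_length [3, 8, 5, 1].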
====== single layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 0.11591419  0.29961097  0.3425573   0.4287143   0.17212108
    0.07444338  0.43271446  0.15715674  0.08194006  0.11577142
   -0.09744498 -0.02763127  0.09280778  0.08716499  0.02522062
    0.33181873]
  [-0.01308823  0.13623668  0.19448121  0.37028143  0.22777143
    0.00628781  0.39128026  0.15501572  0.08111142  0.11017906
   -0.12316822 -0.00816909  0.09567513  0.05021677  0.08249568
    0.33742255]
  [-0.05627449  0.04682723  0.15380071  0.3137156   0.26430035
   -0.046514    0.35723254  0.16584632  0.10204285  0.10223756
   -0.13232729 -0.00190703  0.11279006  0.07007243  0.07809626
    0.36085904]
  [-0.09489179 -0.00705127  0.1340199   0.24711385  0.27097055
   -0.05539801  0.29088783  0.180727    0.13702057  0.07165765
   -0.15263684 -0.02301912  0.14440101  0.09643525  0.04434848
    0.32824463]
  [-0.13192342 -0.09842218  0.13483751  0.2363211   0.2714419
   -0.06301905  0.23002718  0.12190706  0.1600955   0.0820565
   -0.13324322  0.00847512  0.15308659  0.12757084  0.06873622
    0.3726861 ]
  [-0.16037701 -0.12437794  0.12642992  0.23676534  0.29797453
   -0.04277696  0.24219972  0.16359471  0.16195399  0.07269616
   -0.1250204  -0.0185749   0.19040069  0.12709007  0.12064856
    0.30454746]
  [-0.1353235  -0.12385159  0.1025193   0.23867385  0.30110353
   -0.03195428  0.2832907   0.18136714  0.19130123  0.09153596
   -0.05207976  0.02430173  0.2524703   0.22256352  0.17788586
    0.3196903 ]
  [-0.15227936 -0.16710246  0.11279354  0.2324703   0.3158889
   -0.05391366  0.28967926  0.21905534  0.34464788  0.06061291
    0.10662059  0.08228769  0.38103724  0.44488934  0.22631703
    0.38864976]]

 [[ 0.07946795  0.30921736  0.35205007  0.37194842  0.2058839
    0.09482588  0.4332572   0.2775039   0.10343523  0.07151344
   -0.13616626 -0.04245609  0.10985457  0.06919786  0.0364913
    0.31924048]
  [-0.04591701  0.14795585  0.20307627  0.35713255  0.21074952
    0.03478044  0.36047992  0.15351431  0.11235587  0.07168273
   -0.11715946 -0.02380875  0.11772131  0.11803672  0.00387634
    0.33266184]
  [-0.09412251  0.02499678  0.17255405  0.3178058   0.23692454
   -0.03471331  0.26576498  0.10732022  0.14581609  0.07355653
   -0.12852795  0.01927058  0.13053373  0.14796041  0.01590303
    0.3854578 ]
  [-0.09348419  0.00631614  0.1466178   0.22848201  0.22966608
   -0.05388562  0.14963126  0.08823045  0.15729474  0.0657778
   -0.15222837 -0.01835432  0.15758416  0.17561477 -0.03188463
    0.3511778 ]
  [-0.15382743 -0.04836275  0.14573918  0.22835778  0.2532363
   -0.03674607  0.1401736   0.09852327  0.17570393  0.04582136
   -0.13850203  0.00081276  0.16863164  0.14211492  0.04397457
    0.33833435]
  [-0.14028388 -0.08847751  0.13194019  0.21878807  0.28851762
   -0.06432837  0.15592363  0.16226491  0.20294866  0.04400881
   -0.11535563  0.04870296  0.22049154  0.17808373  0.09339966
    0.34441146]
  [-0.1683049  -0.16189072  0.1318028   0.22591397  0.3027075
   -0.07447627  0.15145044  0.1329806   0.2544369   0.06014252
   -0.01793557  0.11026148  0.2146467   0.3118566   0.12141219
    0.39812002]
  [-0.19805393 -0.17752953  0.12876241  0.21628919  0.3038769
   -0.036511    0.1357605   0.10460708  0.3527281   0.07156999
    0.1540587   0.09252883  0.35960466  0.54258245  0.16377062
    0.40849966]]

 [[ 0.08452003  0.3159105   0.3420099   0.3319746   0.20285761
    0.08632328  0.3581056   0.27760154  0.14828831  0.04973472
   -0.18127252 -0.02664946  0.11601479  0.06740937  0.0379785
    0.342705  ]
  [-0.0266434   0.16035607  0.18312001  0.31999707  0.22840345
    0.01311543  0.3133277   0.20360778  0.12191478  0.06214391
   -0.16598006 -0.03916245  0.10791545  0.06448431  0.03113508
    0.33138022]
  [-0.10794992  0.03787376  0.16952753  0.2500641   0.24685495
   -0.05109966  0.20483223  0.18794663  0.16794644  0.03811646
   -0.17785533  0.00866746  0.13491729  0.06493596  0.055873
    0.3487326 ]
  [-0.11205798 -0.04663825  0.13637729  0.2688466   0.2944545
   -0.06623676  0.24580626  0.1894824   0.12357055  0.08545923
   -0.13890322  0.02125055  0.12671538  0.05041068  0.10938939
    0.37651145]
  [-0.14464049 -0.11277611  0.12929943  0.2506328   0.32429394
   -0.06989705  0.26676533  0.22626272  0.14871088  0.06151669
   -0.14160013  0.01764496  0.15616798  0.06309532  0.11477884
    0.3533678 ]
  [-0.1919359  -0.14934857  0.12687694  0.2482472   0.30332044
   -0.02129422  0.24142255  0.19039477  0.1872613   0.05607529
   -0.10981983  0.02655923  0.19725962  0.15991098  0.08460074
    0.32532936]
  [-0.15997384 -0.16905244  0.12601317  0.24978957  0.3109707
   -0.05129525  0.25644392  0.18721735  0.23115595  0.07164647
   -0.04363466  0.09616573  0.23608637  0.23462081  0.16639999
    0.36137852]
  [-0.17784727 -0.19330868  0.12555353  0.25036657  0.3237954
   -0.05024423  0.27374345  0.16953917  0.3444527   0.074378
    0.12866443  0.11058272  0.34053382  0.47292238  0.20279881
    0.42136478]]

 [[ 0.09268619  0.35032618  0.34263822  0.33635783  0.19130397
    0.089779    0.3541034   0.26252666  0.15370639  0.05593391
   -0.16430146 -0.00316385  0.14068598  0.13546935 -0.01566708
    0.32892445]
  [ 0.00249528  0.16723414  0.19037648  0.32905748  0.20670214
   -0.01093364  0.22814633  0.10346357  0.14574584  0.08942283
   -0.13508694  0.02989143  0.13283192  0.155128   -0.00928066
    0.38435996]
  [-0.09191902  0.02066077  0.1762495   0.2693505   0.2615397
   -0.07361222  0.17539641  0.12341685  0.14845897  0.06833903
   -0.15054268  0.02503714  0.12414654  0.08736143  0.07049443
    0.35888508]
  [-0.08116069 -0.0288023   0.12298302  0.24174306  0.3107592
   -0.07053182  0.23929915  0.17529318  0.09909797  0.10476568
   -0.13906275 -0.0065798   0.12028767  0.09093229  0.08531829
    0.33838242]
  [-0.08996075 -0.04482763  0.10432535  0.18569301  0.29469466
   -0.064595    0.21119419  0.19096416  0.15567164  0.06260847
   -0.15861334 -0.01660161  0.17961282  0.14018227  0.05389842
    0.32480207]
  [-0.13079894 -0.12208281  0.11661161  0.20262218  0.31364897
   -0.09002802  0.23725566  0.21705934  0.20321131  0.03772969
   -0.12727125  0.04301733  0.21097985  0.16362298  0.12457186
    0.3570657 ]
  [-0.14077222 -0.14493458  0.10797977  0.20154148  0.32082993
   -0.06558356  0.24276899  0.20433648  0.23955566  0.04574178
   -0.03365875  0.05299059  0.26905897  0.3059458   0.11437013
    0.3523326 ]
  [-0.20353709 -0.20380074  0.12652008  0.19772139  0.28259847
   -0.04320877  0.1549557   0.12743628  0.37037018  0.04201189
    0.16136979  0.10812846  0.3535916   0.573114    0.14248823
    0.42301312]]]
====== single layer bi lstm hn0 shape: (2, 4, 8) ======
[[[-0.15227936 -0.16710246  0.11279354  0.2324703   0.3158889
   -0.05391366  0.28967926  0.21905534]
  [-0.19805393 -0.17752953  0.12876241  0.21628919  0.3038769
   -0.036511    0.1357605   0.10460708]
  [-0.17784727 -0.19330868  0.12555353  0.25036657  0.3237954
   -0.05024423  0.27374345  0.16953917]
  [-0.20353709 -0.20380074  0.12652008  0.19772139  0.28259847
   -0.04320877  0.1549557   0.12743628]]

 [[ 0.08194006  0.11577142 -0.09744498 -0.02763127  0.09280778
    0.08716499  0.02522062  0.33181873]
  [ 0.10343523  0.07151344 -0.13616626 -0.04245609  0.10985457
    0.06919786  0.0364913   0.31924048]
  [ 0.14828831  0.04973472 -0.18127252 -0.02664946  0.11601479
    0.06740937  0.0379785   0.342705  ]
  [ 0.15370639  0.05593391 -0.16430146 -0.00316385  0.14068598
    0.13546935 -0.01566708  0.32892445]]]
====== single layer bi lstm cn0 shape: (2, 4, 8) ======
[[[-0.48307976 -0.40690032  0.24048738  0.49366224  0.5961513
   -0.13565473  0.5191028   0.48418468]
  [-0.55306923 -0.41890883  0.31527558  0.4081013   0.5560535
   -0.10868378  0.22270739  0.224445  ]
  [-0.5595058  -0.5172409   0.28816614  0.4680259   0.6353333
   -0.1406159   0.45408633  0.39424264]
  [-0.55914015 -0.42366728  0.29431793  0.42468843  0.5133875
   -0.11134674  0.27713037  0.2564772 ]]

 [[ 0.13141792  0.26979685 -0.20174497 -0.06629345  0.16831748
    0.14618596  0.05280813  0.84774   ]
  [ 0.16957031  0.19068424 -0.28012666 -0.10653219  0.1932735
    0.12457087  0.07286038  0.91865647]
  [ 0.25553685  0.1275407  -0.37673476 -0.06495219  0.21608156
    0.11330918  0.07597075  0.97954106]
  [ 0.2739099   0.14198926 -0.342751   -0.00778307  0.25392675
    0.23573248 -0.03052862  0.89955646]]]
====== single layer bi lstm output 1 shape: (4, 8, 16) ======
[[[ 0.11591419  0.299611    0.3425573   0.4287143   0.17212108
    0.07444337  0.43271446  0.15715674  0.14267941  0.11772849
   -0.08396029 -0.0199183   0.17602898  0.19761203  0.06850712
    0.30409858]
  [-0.01308823  0.1362367   0.19448121  0.3702814   0.22777143
    0.00628781  0.39128026  0.1550157   0.19404428  0.11392959
   -0.04281732  0.02546077  0.24461909  0.24037687  0.16997418
    0.30728906]
  [-0.05627449  0.04682725  0.15380071  0.3137156   0.26430035
   -0.04651401  0.3572325   0.1658463   0.32523182  0.10201547
    0.12631407  0.07232428  0.37344953  0.46444228  0.22052252
    0.38782993]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.07946795  0.30921736  0.35205007  0.37194842  0.2058839
    0.09482589  0.4332572   0.27750388  0.10343523  0.07151344
   -0.13616627 -0.04245608  0.10985459  0.06919786  0.0364913
    0.31924048]
  [-0.04591701  0.14795585  0.20307627  0.35713255  0.21074952
    0.03478044  0.36047992  0.1535143   0.11235587  0.07168273
   -0.11715946 -0.02380875  0.11772133  0.11803672  0.00387635
    0.33266184]
  [-0.09412251  0.02499679  0.17255405  0.31780577  0.23692457
   -0.03471331  0.265765    0.10732021  0.14581607  0.07355653
   -0.12852795  0.01927058  0.13053373  0.14796041  0.01590303
    0.38545772]
  [-0.09348419  0.00631614  0.14661779  0.228482    0.2296661
   -0.05388563  0.14963126  0.08823042  0.15729474  0.0657778
   -0.15222837 -0.01835432  0.15758416  0.17561477 -0.03188463
    0.35117778]
  [-0.15382743 -0.04836275  0.14573918  0.22835778  0.25323635
   -0.03674608  0.14017357  0.09852324  0.17570391  0.04582136
   -0.13850203  0.00081274  0.16863164  0.14211491  0.04397457
    0.33833435]
  [-0.14028388 -0.08847751  0.13194019  0.21878807  0.28851762
   -0.06432837  0.15592363  0.16226488  0.20294866  0.04400881
   -0.11535563  0.04870294  0.22049154  0.17808372  0.09339967
    0.34441146]
  [-0.1683049  -0.16189072  0.1318028   0.22591396  0.30270752
   -0.07447628  0.15145041  0.13298061  0.2544369   0.06014251
   -0.01793558  0.11026147  0.2146467   0.31185657  0.1214122
    0.39812005]
  [-0.19805394 -0.17752953  0.12876241  0.21628918  0.30387694
   -0.036511    0.1357605   0.10460708  0.3527281   0.07156998
    0.1540587   0.09252883  0.35960466  0.54258245  0.16377063
    0.40849966]]

 [[ 0.08452003  0.31591052  0.3420099   0.3319746   0.2028576
    0.08632328  0.3581056   0.2776015   0.16127887  0.05090985
   -0.18798977 -0.03278283  0.14869703  0.09618111  0.05077953
    0.32884052]
  [-0.0266434   0.16035606  0.18312001  0.31999707  0.22840345
    0.01311543  0.31332764  0.20360778  0.14828573  0.06162609
   -0.16532603 -0.04184524  0.17109753  0.11741111  0.05272176
    0.31123316]
  [-0.10794992  0.03787376  0.16952753  0.2500641   0.24685495
   -0.05109966  0.2048322   0.18794663  0.21637706  0.03754523
   -0.15342048  0.0159312   0.2186653   0.17495207  0.09126361
    0.32591543]
  [-0.11205798 -0.04663826  0.13637729  0.2688466   0.2944545
   -0.06623676  0.24580622  0.1894824   0.21777555  0.08560579
   -0.0555483   0.0522357   0.2504716   0.23061936  0.18061498
    0.34555358]
  [-0.14464049 -0.11277609  0.12929943  0.2506328   0.32429394
   -0.06989705  0.26676533  0.22626273  0.34267974  0.06394035
    0.10800922  0.07929072  0.38286424  0.44688055  0.22619261
    0.38621217]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]]

 [[ 0.09268619  0.35032618  0.34263822  0.33635783  0.19130397
    0.089779    0.3541034   0.26252666  0.34620598  0.06714007
    0.13512857  0.04233981  0.42014182  0.5216394   0.18838547
    0.3683127 ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]
  [ 0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.          0.          0.          0.          0.
    0.        ]]]
====== single layer bi lstm hn1 shape: (2, 4, 8) ======
[[[-0.05627449  0.04682725  0.15380071  0.3137156   0.26430035
   -0.04651401  0.3572325   0.1658463 ]
  [-0.19805394 -0.17752953  0.12876241  0.21628918  0.30387694
   -0.036511    0.1357605   0.10460708]
  [-0.14464049 -0.11277609  0.12929943  0.2506328   0.32429394
   -0.06989705  0.26676533  0.22626273]
  [ 0.09268619  0.35032618  0.34263822  0.33635783  0.19130397
    0.089779    0.3541034   0.26252666]]

 [[ 0.14267941  0.11772849 -0.08396029 -0.0199183   0.17602898
    0.19761203  0.06850712  0.30409858]
  [ 0.10343523  0.07151344 -0.13616627 -0.04245608  0.10985459
    0.06919786  0.0364913   0.31924048]
  [ 0.16127887  0.05090985 -0.18798977 -0.03278283  0.14869703
    0.09618111  0.05077953  0.32884052]
  [ 0.34620598  0.06714007  0.13512857  0.04233981  0.42014182
    0.5216394   0.18838547  0.3683127 ]]]
====== single layer bi lstm cn1 shape: (2, 4, 8) ======
[[[-0.16340391  0.12338591  0.36321753  0.60983956  0.4963916
   -0.14528881  0.61422133  0.37583172]
  [-0.5530693  -0.41890883  0.31527558  0.40810126  0.5560536
   -0.10868377  0.22270739  0.22444502]
  [-0.46137562 -0.27004397  0.27595642  0.5348579   0.62363803
   -0.18086377  0.46610427  0.4973321 ]
  [ 0.23746979  0.6868869   0.56339467  0.96855223  0.39346337
    0.32335475  0.7259624   0.4185825 ]]

 [[ 0.22938183  0.2952913  -0.17549752 -0.05000385  0.33509728
    0.3336044   0.14473113  0.7370499 ]
  [ 0.16957031  0.19068426 -0.2801267  -0.10653219  0.19327351
    0.12457087  0.07286038  0.91865647]
  [ 0.27940926  0.13317151 -0.39137632 -0.081429    0.28198367
    0.16170114  0.10146889  0.91004795]
  [ 0.6180897   0.28882137  0.28748003  0.15160248  0.7991137
    0.90929043  0.45457762  0.8128108 ]]]
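The all-zero rows in output 1 above line up with the valid lengths [3, 8, 5, 1]: the first sequence keeps 3 non-zero steps, the second all 8, the third 5, and the fourth only 1. The following is a minimal NumPy sketch for building the corresponding mask and checking the padding. The helper names are hypothetical, and the `demo` array is a synthetic stand-in rather than real LSTM output.

```python
import numpy as np


def valid_step_mask(seq_length, max_steps):
    """Boolean mask of shape [batch_size, max_steps]; True marks a valid step."""
    steps = np.arange(max_steps)[None, :]
    return steps < np.asarray(seq_length)[:, None]


def padded_steps_are_zero(output, seq_length):
    """Return True if every time step beyond the valid length is all zeros."""
    mask = valid_step_mask(seq_length, output.shape[1])
    return bool(np.all(output[~mask] == 0.0))


seq_length = [3, 8, 5, 1]
mask = valid_step_mask(seq_length, max_steps=8)
print(mask.sum(axis=1))  # number of valid steps per sequence

# Synthetic stand-in with the same shape as output_1: ones on valid steps,
# zeros on padded steps.
demo = np.ones((4, 8, 16), dtype=np.float32)
demo[~mask] = 0.0
print(padded_steps_are_zero(demo, seq_length))
```

With the real tensor, `padded_steps_are_zero(output_1.asnumpy(), seq_length)` would run the same check on the LSTM output.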


This is an original article. Copyright belongs to the author; reproduction without authorization is prohibited.

posted @ 2022-08-11 18:29  Skytier