PyTorch LSTM 简单形式 (minimal nn.LSTM example)

import  torch
from torch import nn

# Two stacked LSTM layers: each time step carries a 100-dim feature vector,
# and every layer keeps a 20-dim hidden/cell state.
lstm = nn.LSTM(100, 20, 2)  # positional order: input_size, hidden_size, num_layers
print(lstm)

# Random input of shape (seq_len=10, batch=3, input_size=100); nn.LSTM reads
# dim 0 as the time axis because batch_first defaults to False.
x = torch.randn((10, 3, 100))

# Forward pass without an explicit initial state. The result is the per-step
# output sequence plus the final (hidden, cell) pair.
out, state = lstm(x)
h, c = state

# Expected: out  -> [10, 3, 20]  (seq_len, batch, hidden_size)
#           h, c -> [2, 3, 20]   (num_layers, batch, hidden_size)
for label, tensor in (('out shape:', out), ('h shape:', h), ('c shape:', c)):
    print(label, tensor.shape)
posted @ 2020-08-08 15:51 kpwong 阅读(188) 评论(0) 编辑 收藏 举报