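How to feed variable-length sequences into an LSTM in PyTorch

What to do when sequence lengths vary while training an RNN in PyTorch?

How to feed variable-length sequences into an LSTM in PyTorch

The two articles above are well written and clearly explain the three functions needed to train an LSTM on variable-length sequences (pad_sequence, pack_padded_sequence, and pad_packed_sequence). However, neither gives complete training code, and neither covers the case where the data comes with labels. This post therefore provides a complete training script that includes labels.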
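Before the full script, a minimal sketch of how the three functions fit together may help. The toy sequences and variable names below are made up for illustration and are not taken from the two referenced articles; the sequences are listed longest-first, which is what pack_padded_sequence expects by default.

```python
import torch
import torch.nn.utils.rnn as rnn_utils

# Three toy sequences of different lengths, already sorted longest-first.
seqs = [torch.tensor([1.0, 1.0, 1.0]),
        torch.tensor([2.0, 2.0]),
        torch.tensor([3.0])]
lengths = [len(s) for s in seqs]

# 1) pad_sequence: stack into one (batch, max_len) tensor, filling with zeros.
padded = rnn_utils.pad_sequence(seqs, batch_first=True, padding_value=0.0)

# 2) pack_padded_sequence: drop the padding so an RNN only sees real time steps.
#    packed.data is stored time step by time step across the batch.
packed = rnn_utils.pack_padded_sequence(padded, lengths, batch_first=True)

# 3) pad_packed_sequence: restore the padded layout and recover the lengths.
unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed, batch_first=True)

print(padded)
print(packed.data, packed.batch_sizes)
print(unpacked_len)
```

In the training script, step 2 is applied to the input before it enters the LSTM, and step 3 (or a second pack of the labels) is applied afterwards so that output and labels can be compared.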

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader, Dataset


class MyData(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
 
    def __len__(self):
        return len(self.data)
 
    def __getitem__(self, idx):
        tuple_ = (self.data[idx], self.label[idx])
        return tuple_
 
 
def collate_fn(data_tuple):
    # data_tuple is a list of batchsize tuples, each holding (data, label).
    # Sort longest-first, which pack_padded_sequence expects by default.
    data_tuple.sort(key=lambda x: len(x[0]), reverse=True)
    data = [sq[0] for sq in data_tuple]
    label = [sq[1] for sq in data_tuple]
    data_length = [len(sq) for sq in data]
    data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0.0)     # zero-pad to a common length
    label = rnn_utils.pad_sequence(label, batch_first=True, padding_value=0.0)   # zero-pad the labels too, which also turns the list into one tensor
    return data.unsqueeze(-1), label, data_length   # data: (batch, max_len, 1)
 
 
if __name__ == '__main__':
 
    EPOCH = 2
    batchsize = 3
    hiddensize = 4
    num_layers = 2
    learning_rate = 0.001
 
    # Training data: seven sequences with lengths 7 down to 1
    train_x = [torch.FloatTensor([1, 1, 1, 1, 1, 1, 1]),
               torch.FloatTensor([2, 2, 2, 2, 2, 2]),
               torch.FloatTensor([3, 3, 3, 3, 3]),
               torch.FloatTensor([4, 4, 4, 4]),
               torch.FloatTensor([5, 5, 5]),
               torch.FloatTensor([6, 6]),
               torch.FloatTensor([7])]
    # Labels: one hiddensize-dimensional target per time step
    train_y = [torch.rand(7, hiddensize),
               torch.rand(6, hiddensize),
               torch.rand(5, hiddensize),
               torch.rand(4, hiddensize),
               torch.rand(3, hiddensize),
               torch.rand(2, hiddensize),
               torch.rand(1, hiddensize)]
 
    dataset = MyData(train_x, train_y)
    data_loader = DataLoader(dataset, batch_size=batchsize, shuffle=True, collate_fn=collate_fn)
    net = nn.LSTM(input_size=1, hidden_size=hiddensize, num_layers=num_layers, batch_first=True)
    criteria = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
 
    # Training method 1: unpack the output and compare the padded tensors
    for epoch in range(EPOCH):
        for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
            batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
            out, _ = net(batch_x_pack)   # out.data has shape (total length of all sequences, hiddensize)
            out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)   # back to (batch, max_len, hiddensize)
            loss = criteria(out_pad, batch_y)   # padded positions are zero in both tensors but still count toward the mean
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss.item()))
 
    # Training method 2: pack the labels as well and compare only the valid time steps
    for epoch in range(EPOCH):
        for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
            batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
            batch_y_pack = rnn_utils.pack_padded_sequence(batch_y, batch_x_len, batch_first=True)
            out, _ = net(batch_x_pack)   # out.data has shape (total length of all sequences, hiddensize)
            loss = criteria(out.data, batch_y_pack.data)   # both packed with the same lengths, so rows line up
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss.item()))
 
    print('Training done!')
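
The two training methods do not report the same loss value: method 1 averages over every position of the padded tensors, while method 2 averages only over real time steps. The sketch below illustrates the difference using random stand-in tensors instead of an actual LSTM output; the boolean mask and all variable names are assumptions introduced here for illustration, not part of the original script.

```python
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils

torch.manual_seed(0)
lengths = [3, 1]                   # valid time steps per sequence, longest first
out_pad = torch.rand(2, 3, 4)      # stand-in for the padded LSTM output
target = torch.rand(2, 3, 4)       # stand-in for the padded labels

# Boolean mask of shape (batch, max_len): True where a time step is real.
mask = torch.arange(3)[None, :] < torch.tensor(lengths)[:, None]
# Zero the padded positions, as pad_sequence / pad_packed_sequence would.
out_pad = out_pad * mask[..., None]
target = target * mask[..., None]

mse = nn.MSELoss()
loss_padded = mse(out_pad, target)               # like method 1: mean over all positions
loss_masked = mse(out_pad[mask], target[mask])   # mean over valid positions only

# Like method 2: pack both tensors and compare the packed data.
out_pack = rnn_utils.pack_padded_sequence(out_pad, lengths, batch_first=True)
tgt_pack = rnn_utils.pack_padded_sequence(target, lengths, batch_first=True)
loss_packed = mse(out_pack.data, tgt_pack.data)

print(loss_padded.item(), loss_masked.item(), loss_packed.item())
```

Under these assumptions the masked and packed losses agree, and the padded loss is smaller by the fraction of valid positions (here 4 of 6). With the default mean reduction, method 1 therefore scales the loss, and hence the gradient, differently from batch to batch depending on how much padding the batch contains.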

Output: [screenshot of the console output printed by the two training loops]

posted @ Picassooo