How to feed variable-length sequences into an LSTM in PyTorch
The two articles above explain very clearly the three functions needed to train an LSTM on variable-length sequences, but neither gives complete training code, and neither covers the case where the sequences come with labels. This post fills that gap with a complete, labeled training example.
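For readers who haven't read those articles: the three functions are torch.nn.utils.rnn.pad_sequence, pack_padded_sequence, and pad_packed_sequence. Here is a minimal sketch of how they chain together (the toy sequences are invented for illustration; shapes are in the comments):

import torch
import torch.nn.utils.rnn as rnn_utils

# three variable-length sequences, already sorted by length (descending)
seqs = [torch.tensor([1., 1., 1.]), torch.tensor([2., 2.]), torch.tensor([3.])]
lengths = [len(s) for s in seqs]

# 1) pad_sequence: stack into one tensor, padding short sequences with zeros
padded = rnn_utils.pad_sequence(seqs, batch_first=True)    # shape (3, 3)

# 2) pack_padded_sequence: keep only the real (non-pad) timesteps so the RNN
#    never sees padding; expects descending lengths unless enforce_sorted=False
packed = rnn_utils.pack_padded_sequence(padded, lengths, batch_first=True)
print(packed.data.shape)                                   # torch.Size([6]), i.e. 3+2+1

# 3) pad_packed_sequence: invert the packing back into a padded tensor
unpacked, unpacked_lens = rnn_utils.pad_packed_sequence(packed, batch_first=True)
print(unpacked.shape, unpacked_lens)                       # torch.Size([3, 3]) tensor([3, 2, 1])

With those building blocks, here is the complete training script with labels: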
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data_


class MyData(data_.Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tuple_ = (self.data[idx], self.label[idx])
        return tuple_


def collate_fn(data_tuple):
    # data_tuple is a list of batchsize tuples, each containing (data, label)
    data_tuple.sort(key=lambda x: len(x[0]), reverse=True)
    data = [sq[0] for sq in data_tuple]
    label = [sq[1] for sq in data_tuple]
    data_length = [len(sq) for sq in data]
    data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0.0)    # pad with zeros so all sequences have equal length
    label = rnn_utils.pad_sequence(label, batch_first=True, padding_value=0.0)  # this line just turns the list of labels into one tensor
    return data.unsqueeze(-1), label, data_length


if __name__ == '__main__':
    EPOCH = 2
    batchsize = 3
    hiddensize = 4
    num_layers = 2
    learning_rate = 0.001

    # training data
    train_x = [torch.FloatTensor([1, 1, 1, 1, 1, 1, 1]),
               torch.FloatTensor([2, 2, 2, 2, 2, 2]),
               torch.FloatTensor([3, 3, 3, 3, 3]),
               torch.FloatTensor([4, 4, 4, 4]),
               torch.FloatTensor([5, 5, 5]),
               torch.FloatTensor([6, 6]),
               torch.FloatTensor([7])]
    # labels
    train_y = [torch.rand(7, hiddensize),
               torch.rand(6, hiddensize),
               torch.rand(5, hiddensize),
               torch.rand(4, hiddensize),
               torch.rand(3, hiddensize),
               torch.rand(2, hiddensize),
               torch.rand(1, hiddensize)]

    data_ = MyData(train_x, train_y)
    data_loader = DataLoader(data_, batch_size=batchsize, shuffle=True, collate_fn=collate_fn)
    net = nn.LSTM(input_size=1, hidden_size=hiddensize, num_layers=num_layers, batch_first=True)
    criteria = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # training method 1: unpack the output, then compare against the padded labels
    for epoch in range(EPOCH):
        for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
            batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
            out, _ = net(batch_x_pack)  # out.data's shape: (total length of all sequences, hiddensize)
            out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)
            loss = criteria(out_pad, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss))

    # training method 2: pack the labels too, then compare the packed data directly
    for epoch in range(EPOCH):
        for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
            batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
            batch_y_pack = rnn_utils.pack_padded_sequence(batch_y, batch_x_len, batch_first=True)
            out, _ = net(batch_x_pack)  # out.data's shape: (total length of all sequences, hiddensize)
            loss = criteria(out.data, batch_y_pack.data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss))

    print('Training done!')
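A note on the two training methods: method 1 unpacks the LSTM output with pad_packed_sequence and compares it against the padded labels, so nn.MSELoss also averages over the padded positions (both tensors are zero there, which adds nothing to the numerator but still inflates the element count); method 2 packs the labels as well and compares only the real timesteps through the PackedSequence .data fields. The two losses therefore differ by a scale factor whenever a batch mixes different lengths. If you prefer the padded layout of method 1 but the semantics of method 2, a masked loss is one option; the sketch below is my own addition, not part of the original script:

import torch

def masked_mse(out_pad, batch_y, lengths):
    # MSE over only the real (unpadded) timesteps; out_pad and batch_y are
    # (batch, max_len, hiddensize), lengths is the list of true lengths
    mask = torch.arange(out_pad.size(1))[None, :] < torch.tensor(lengths)[:, None]
    mask = mask.unsqueeze(-1).float()   # (batch, max_len, 1), broadcasts over hiddensize
    return ((out_pad - batch_y) ** 2 * mask).sum() / (mask.sum() * out_pad.size(2))

Also note that collate_fn sorts each batch by length in descending order, which is exactly what pack_padded_sequence expects by default (pass enforce_sorted=False if you would rather skip the sort).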
Output: running the script prints the loss for every batch in each epoch, followed by 'Training done!'. The exact loss values vary from run to run, since train_y is random and the DataLoader shuffles.