代码片段
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | #pack_padded_sequence,pad_sequence的代码import torch from torch.utils.data import Dataset,DataLoader from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence,pack_sequence,pad_sequence import torch.nn as nn class MyData(Dataset): def __init__( self , data): self .data = data def __len__( self ): return len ( self .data) def __getitem__( self , idx): return self .data[idx] def collate_fn(data): data.sort(key = lambda x: len (x), reverse = True ) seq_len = [s.size( 0 ) for s in data] data = pad_sequence(data, batch_first = True ). float () data = data.unsqueeze( - 1 ) data = pack_padded_sequence(data, seq_len, batch_first = True ) return data a = torch.tensor([ 1 , 2 , 3 , 4 ]) b = torch.tensor([ 5 , 6 , 7 ]) c = torch.tensor([ 7 , 8 ]) d = torch.tensor([ 9 ]) train_x = [a, b, c, d] data = MyData(train_x) data_loader = DataLoader(data, batch_size = 2 , shuffle = True ,collate_fn = collate_fn) batch_x = iter (data_loader). next () rnn = nn.LSTM( 1 , 4 , 1 , batch_first = True ) h0 = torch.rand( 1 , 2 , 4 ). float () c0 = torch.rand( 1 , 2 , 4 ). float () out, (h1, c1) = rnn(batch_x, (h0, c0)) print (out) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· Vue3状态管理终极指南:Pinia保姆级教程