# Code snippet
# Example code for pack_padded_sequence and pad_sequence
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (
    pack_padded_sequence,
    pack_sequence,
    pad_packed_sequence,
    pad_sequence,
)
from torch.utils.data import DataLoader, Dataset


class MyData(Dataset):
    """Minimal map-style dataset wrapping an indexable collection of samples."""

    def __init__(self, data):
        """Store ``data``; samples are read from it by index."""
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]


def collate_fn(data):
    """Collate variable-length sequences into a single ``PackedSequence``.

    Args:
        data: List of tensors whose first dimension is the sequence length
            (1-D tensors in the demo below). The list is not modified.

    Returns:
        The ``PackedSequence`` produced by ``pack_padded_sequence`` from the
        zero-padded batch, converted to float and given a trailing feature
        dimension of size 1.
    """
    # pack_padded_sequence expects sequences ordered longest-first (its
    # default is enforce_sorted=True). sorted() builds a new list, so the
    # caller's batch list keeps its original order.
    ordered = sorted(data, key=len, reverse=True)
    seq_len = [s.size(0) for s in ordered]
    padded = pad_sequence(ordered, batch_first=True).float()
    # Add a feature dimension so each time step is a vector of size 1.
    padded = padded.unsqueeze(-1)
    return pack_padded_sequence(padded, seq_len, batch_first=True)


# Toy sequences of lengths 4, 3, 2 and 1.
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([5, 6, 7])
c = torch.tensor([7, 8])
d = torch.tensor([9])
train_x = [a, b, c, d]

data = MyData(train_x)
data_loader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=collate_fn)
# Use the built-in next(): the Python 2 style ``.next()`` method is not
# available on current DataLoader iterators.
batch_x = next(iter(data_loader))

# LSTM with input_size=1, hidden_size=4, num_layers=1.
rnn = nn.LSTM(1, 4, 1, batch_first=True)
# Initial states are (num_layers, batch, hidden_size); the batch dimension
# of 2 matches the DataLoader's batch_size above.
h0 = torch.rand(1, 2, 4)
c0 = torch.rand(1, 2, 4)
out, (h1, c1) = rnn(batch_x, (h0, c0))
print(out)