[TORCH] pack_padded_sequence 和 pad_packed_sequence 的使用 (2020版本)
内容简介
本文主要是通过代码的方式展示pytorch的pack和pad函数。
找到的两个可以参考的靠谱网站(不是CSDN的奇怪东西):
理论链接,建议直接看图
实践链接,直接看代码
使用的代码
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn import utils as nn_utils
import torch.nn.functional as F
import torch
# seq example
# batch的尺寸是5,假设我们有五句话,每句话有不定长的词汇
# 这里只假设每个词汇的feature是一维的
batch_size = 5
a = torch.tensor([1,2])
b = torch.tensor([1,2,3])
c = torch.tensor([1,2,3,4])
d = torch.tensor([1])
e = torch.tensor([1,2,3,4,5,6])
# general setting
# 提取五个句子的有效内容的长度
# 并且提取最大句子的长度
seq_lens = []
for i in [a,b,c,d,e]:
seq_lens.append(len(i))
max_len = max(seq_lens)
# Zero padding
# 通过加入0pad,让他们的长度相等,这个长度是最长句子的长度
a = F.pad(a,(0,max_len-len(a))) # 最低维度,前面增加0个,后面增加max-len(a)个
b = F.pad(b,(0,max_len-len(b)))
c = F.pad(c,(0,max_len-len(c)))
d = F.pad(d,(0,max_len-len(d)))
e = F.pad(e,(0,max_len-len(e)))
print("在a句子经过pad填充以后:\n{}\n".format(a))
# merge the seq
seq = torch.cat((a,b,c,d,e),0).view(-1,max_len)
print("所有句子融合以后可以获得整个矩阵:\n{}\n".format(seq))
# Pack
# 1. input size 可以是(T×B×* ) = (最长序列长度T,batch size B,任意维度*)
# 2. input size 可以是(B×T×*), 如果batch_first=True的话
# 这里我们选择 batch 在前,所以是2
packed_seq = pack_padded_sequence(seq, seq_lens, batch_first=True, enforce_sorted=False)
print('经过了 pack_padded_sequence 处理:\n{}\n'.format(packed_seq))
# Unpack
unpacked_seq, unpacked_lens = pad_packed_sequence(packed_seq, batch_first=True)
print('Unpack还原的结果:\n{}\n'.format(unpacked_seq))
print('同时返回seq的length:\n{}\n'.format(unpacked_lens))
代码运行的结果
-
在a句子经过pad填充以后:
tensor([1, 2, 0, 0, 0, 0]) -
所有句子融合以后可以获得整个矩阵:
tensor([[1, 2, 0, 0, 0, 0],
[1, 2, 3, 0, 0, 0],
[1, 2, 3, 4, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6]]) -
经过了 pack_padded_sequence 处理:
PackedSequence(
data=tensor([1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 6]),
batch_sizes=tensor([5, 4, 3, 2, 1, 1]),
sorted_indices=tensor([4, 2, 1, 0, 3]),
unsorted_indices=tensor([3, 2, 1, 4, 0])) -
Unpack还原的结果:
tensor([[1, 2, 0, 0, 0, 0],
[1, 2, 3, 0, 0, 0],
[1, 2, 3, 4, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6]]) -
同时返回seq的length:
tensor([2, 3, 4, 1, 6])