Pytorch学习笔记05---- pack_padded_sequence和pad_packed_sequence理解
1.为什么要用pack_padded_sequence
在使用深度学习特别是RNN(LSTM/GRU)进行序列分析时,经常会遇到序列长度不一样的情况,此时就需要对同一个batch中的不同序列使用padding的方式进行序列长度对齐(可以都填充为batch中最长序列的长度,也可以设置一个统一的长度,对所有序列长截短填),方便将训练数据输入到LSTM模型进行训练,填充后一个batch的序列可以统一处理,加快速度。但是此时会有一个问题,LSTM会对序列中非填充部分和填充部分同等看待,这样会影响模型训练的精度,应该告诉LSTM相关序列的padding情况,让LSTM只对非填充部分进行运算。此时,pytorch中的pack_padded_sequence就有了用武之地。
其实有时候,可以填充后直接做,影响有时也不是很大,使用pack_padded_sequence后效果可能会更好。
结合例子分析:
如果不用pack和pad操作会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了多余的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:
那么我们正确的做法应该是怎么样呢?
在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示:如下图:
torch.nn.utils.rnn.pack_padded_sequence()
这里的pack
,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)
其中pack的过程为:(注意pack的形式,不是按行压,而是按列压)
pack之后,原来填充的 PAD(一般初始化为0)占位符被删掉了。
输入的形状可以是(T×B×* )。T
是最长序列长度,B
是batch size
,*
代表任意维度(可以是0)。如果batch_first=True
的话,那么相应的 input size
就是 (B×T×*)
。
Variable
中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]
代表的是最长的序列,input[:, B-1]
保存的是最短的序列。
NOTE:
只要是维度大于等于2的input
都可以作为这个函数的参数。你可以用它来打包labels
,然后用RNN
的输出和打包后的labels
来计算loss
。通过PackedSequence
对象的.data
属性可以获取 Variable
。
参数说明:
- input (Variable) – 变长序列 被填充后的 batch
- lengths (list[int]) –
Variable
中 每个序列的有效长度(即去掉pad的真实长度)。 - batch_first (bool, optional) – 如果是
True
,input的形状应该是B*T*size
。
返回值:
一个PackedSequence
对象。
torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence
。
上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。填充时会初始化为0。
返回的Varaible的值的size
是 T×B×*
, T
是最长序列的长度,B
是 batch_size,如果 batch_first=True
,那么返回值是B×T×*
。
Batch中的元素将会以它们长度的逆序排列。
参数说明:
- sequence (PackedSequence) – 将要被填充的 batch
- batch_first (bool, optional) – 如果为True,返回的数据的格式为
B×T×*
。
返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表
2.小案例:
假设有demo.txt文件,包含下面5段文本/序列:
Some people like to choose those who are different from themselves while others prefer those who are similar to themselves. People choose friends in differrent ways. For instance, if an active and energetic guy proposes to his equally active and energetic friends that they should have some activities, it is more likely that his will agree at once. When people have friends similar to themselves, they and their friends chat, play, and do thing together natually and harmoniously. The result is that they all can feel relaxed and can trully enjoy each other's company.
使用下面的脚本将单词转换为索引,并填充为统一的长度:
import numpy as np import torch import torch.nn as nn vocab = {} #词到索引的映射字典 token_id = 1 #token_id=0 预留给填充符号 lengths = [] #存储每个文本的实际长度 with open('demo.txt', 'r') as f: for l in f: tokens = l.strip().split() #这里对英文分词 简单的按空格切分。(当然可以使用一些效果更好的分词工具,可以把标点分出来) print(tokens) lengths.append(len(tokens)) for t in tokens: if t not in vocab: vocab[t] = token_id token_id += 1 x = np.zeros((len(lengths), max(lengths))) #所有文本填充为最大的长度 l_no = 0 with open('demo.txt', 'r') as f: for l in f: tokens = l.strip().split() for i in range(len(tokens)): x[l_no, i] = vocab[tokens[i]] l_no += 1 print(x) print(x.shape) x = torch.tensor(x,requires_grad=True) lengths = torch.Tensor(lengths) print("lenghts:",lengths) #所有文本长度按从大到小排序 (降序),返回排序后的索引idx_sort _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True) print("idx_sort:",idx_sort) #对索引idx_sort进行从小到大排序 ,返回排序后的索引 idx_unsort _, idx_unsort = torch.sort(idx_sort, dim=0) print("idx_unsort:",idx_unsort) x1 = x[idx_sort]#x中的各个文本 随着排序 即最长的文本在第一行... lengths1 = list(lengths[idx_sort])#此时各个文本对应的长度(从大到小排序后) print("lenghts1:",lengths1) print("x1的形状与内容:") print(x1) print(x1.shape) x2=x1[idx_unsort] print("x2的形状与内容:") print(x2) print(x2.shape)
控制台输出:
D:\softwaretools\anaconda\python.exe D:/pycharmprojects/hoteltest01/hoteltest01/testpy/test07_pack_pad.py ['Some', 'people', 'like', 'to', 'choose', 'those', 'who', 'are', 'different', 'from', 'themselves', 'while', 'others', 'prefer', 'those', 'who', 'are', 'similar', 'to', 'themselves.'] ['People', 'choose', 'friends', 'in', 'differrent', 'ways.'] ['For', 'instance,', 'if', 'an', 'active', 'and', 'energetic', 'guy', 'proposes', 'to', 'his', 'equally', 'active', 'and', 'energetic', 'friends', 'that', 'they', 'should', 'have', 'some', 'activities,', 'it', 'is', 'more', 'likely', 'that', 'his', 'will', 'agree', 'at', 'once.'] ['When', 'people', 'have', 'friends', 'similar', 'to', 'themselves,', 'they', 'and', 'their', 'friends', 'chat,', 'play,', 'and', 'do', 'thing', 'together', 'natually', 'and', 'harmoniously.'] ['The', 'result', 'is', 'that', 'they', 'all', 'can', 'feel', 'relaxed', 'and', 'can', 'trully', 'enjoy', 'each', "other's", 'company.'] [[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 6. 7. 8. 15. 4. 16. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [17. 5. 18. 19. 20. 21. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [22. 23. 24. 25. 26. 27. 28. 29. 30. 4. 31. 32. 26. 27. 28. 18. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 33. 31. 43. 44. 45. 46.] [47. 2. 36. 18. 15. 4. 48. 34. 27. 49. 18. 50. 51. 27. 52. 53. 54. 55. 27. 56. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [57. 58. 40. 33. 34. 59. 60. 61. 62. 27. 60. 63. 64. 65. 66. 67. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] (5, 32) lenghts: tensor([20., 6., 32., 20., 16.]) idx_sort: tensor([2, 0, 3, 4, 1]) idx_unsort: tensor([1, 4, 0, 2, 3]) lenghts1: [tensor(32.), tensor(20.), tensor(20.), tensor(16.), tensor(6.)] x1的形状与内容: tensor([[22., 23., 24., 25., 26., 27., 28., 29., 30., 4., 31., 32., 26., 27., 28., 18., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 33., 31., 43., 44., 45., 46.], [ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 6., 7., 8., 15., 4., 16., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [47., 2., 36., 18., 15., 4., 48., 34., 27., 49., 18., 50., 51., 27., 52., 53., 54., 55., 27., 56., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [57., 58., 40., 33., 34., 59., 60., 61., 62., 27., 60., 63., 64., 65., 66., 67., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [17., 5., 18., 19., 20., 21., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64, grad_fn=<IndexBackward>) torch.Size([5, 32]) x2的形状与内容: tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 6., 7., 8., 15., 4., 16., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [17., 5., 18., 19., 20., 21., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [22., 23., 24., 25., 26., 27., 28., 29., 30., 4., 31., 32., 26., 27., 28., 18., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 33., 31., 43., 44., 45., 46.], [47., 2., 36., 18., 15., 4., 48., 34., 27., 49., 18., 50., 51., 27., 52., 53., 54., 55., 27., 56., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [57., 58., 40., 33., 34., 59., 60., 61., 62., 27., 60., 63., 64., 65., 66., 67., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64, grad_fn=<IndexBackward>) torch.Size([5, 32]) Process finished with exit code 0
由x2与原始x的形状是一样的,主要是因为下面两行
idx_sort: tensor([2, 0, 3, 4, 1])
idx_unsort: tensor([1, 4, 0, 2, 3])
x_packed = nn.utils.rnn.pack_padded_sequence(input=x1, lengths=lengths1, batch_first=True) print(x_packed)
需要注意的是,pack_padded_sequence函数的参数,lengths需要从大到小排序(length1),x1已根据长度大小排好序(最长的序列在第一行…),batch_first如果设置为true,则x的第一维为batch_size,第二维为seq_length,否则相反。
打印x_packed如下:
PackedSequence(data=tensor([22., 1., 47., 57., 17., 23., 2., 2., 58., 5., 24., 3., 36., 40., 18., 25., 4., 18., 33., 19., 26., 5., 15., 34., 20., 27., 6., 4., 59., 21., 28., 7., 48., 60., 29., 8., 34., 61., 30., 9., 27., 62., 4., 10., 49., 27., 31., 11., 18., 60., 32., 12., 50., 63., 26., 13., 51., 64., 27., 14., 27., 65., 28., 6., 52., 66., 18., 7., 53., 67., 33., 8., 54., 34., 15., 55., 35., 4., 27., 36., 16., 56., 37., 38., 39., 40., 41., 42., 33., 31., 43., 44., 45., 46.], dtype=torch.float64, grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)
他把x1的两个维度合并成了一个维度,原本x1(batch_size,max_seq_len)=(5,32),x_packed相当于对x1按列进行访问,并且忽略掉其中的填充值0;下面多出的batch_size有max_seq_len=32个数字,可以理解为对x1进行按列访问时,每一列非填充值的个数,可以看到刚开始的几列没有填充值(每个序列的开始部分),值为batch_size=5,后面由于有的序列不够长,逐渐出现填充值0,所以batch_size的大小逐渐变小<5,直到最后等于1,也就是只有那个batch中最长的序列还有非填充值,其余序列都是填充值0.
参考文献:
https://blog.csdn.net/sdu_hao/article/details/105408552
https://www.cnblogs.com/sbj123456789/p/9834018.html