lstm pytorch梳理之 batch_first 参数 和torch.nn.utils.rnn.pack_padded_sequence

小萌新在看pytorch官网 LSTM代码时 对batch_first 参数 和torch.nn.utils.rnn.pack_padded_sequence 不太理解,

在回去苦学了一番 ,将自己消化过的记录在这,希望能帮到跟我有同样迷惑的伙伴

 

官方API:

  • 参数
    – input_size
    – hidden_size
    – num_layers
    – bias
    – batch_first
    – dropout
    – bidirectional
  • 特别说下batch_first ,参数默认为False,也就是它鼓励我们第一维不是batch,这与我们常规输入想悖,毕竟我们习惯的输入是(batch, seq_len, hidden_size),那么官方为啥会 这样子设置呢?

 

先不考虑hiddem_dim,左边图矩阵维度为batch_size * max_length, 6个序列,没个序列填充到最大长度4,经过转置后得到max_length *batch_size , 右图标蓝的一列 对应的就是 左图第二列,而左图第二列表示的是 每个序列里面第二个token,这样子有什么好处呢?相当于可以并行处理 每个句子在time step下时刻的计算,这样就 可以并行过LSTM,从而一定程度上提高处理速度。因为官网放的图例子 里面数字都是句子token 索引化之后的,反而让人容易看晕,因而小萌新自己画了个好理解的图。一起看下图呀。

一共有3个句子,最大长度为6,我们之前习惯的是 按行看,我们现在按一列一列来看(就相当于转置啦)

time step 0接受的是[ZHAOJIAN girls eat];

time step1接收的是[and are apple];

time step2接受的是[YUQIN beautiful PAD];

time step3接收的是[are angles PAD];

以此类推。 现在这3个句子就可以并行过LSTM

 pad_sequence

  我们知道一个batch里的序列长度是不一致的,而LSTM是无法处理长度不同的序列的,需要pad操作用0把它们都填充成max_length长度。下图有3个句子,以最长的句子长度 6 作为max_length,其余句子都填充到max_length 。这是PAD的作用,很好理解。

from torch.nn.utils.rnn import pack_padded_sequence ,pad_sequence ,pack_sequence                                                                                           
inputs = ["LIHUA went to The TsinghUA University",                                                                                                                         
          "Liping went to technical school ",                                                                                                                              
          "I work in the mall ",                                                                                                                                           
          "we both have bright future"]                                                                                                                                    
inputs.sort(key=lambda x:len(x.split()),reverse=True)                                                                                                                      
batch_size=len(inputs)                                                                                                                                                     
max_length=len(inputs[0].split())                                                                                                                                          
lengths=[len(s.split()) for s in inputs]                                                                                                                                   
                                                                                                                                                                           
word_to_idx={}                                                                                                                                                             
for sen in inputs:                                                                                                                                                         
    for word in sen.split():                                                                                                                                               
        if word not in word_to_idx:                                                                                                                                        
            word_to_idx[word]=len(word_to_idx)                                                                                                                             
idx=[]                                                                                                                                                                     
for sentence in inputs:                                                                                                                                                    
    a=[word_to_idx[w] for w in sentence.split()]                                                                                                                           
    idx.append(a)                                                                                                                                                          
pprint(idx)                                                                                                                                                                
padded_sequence = pad_sequence([torch.FloatTensor(id) for id in idx], batch_first=True)                                                                                    
print(padded_sequence)                                                                                                                                                     
packed_sequence = pack_sequence([torch.FloatTensor(id) for id in idx])  # packed_sequence是PackedSequence的实例                                                                

pack_sequence

  但带来一个问题 ,什么问题呢? 对于长度小于MAX_LENGTH ,经过PAD填充操作后的句子,会导致LSTM对它的表示多了很多无用的字符,如下图所示,我们希望的是在最后一个有用token 就输入句子的向量表示,而不是在很多PAD后才输入句子表示,这是pack就派上场了,可以理解成 将一个填充过的变长序列压紧.压缩的对象就是 padded suquence, 压缩后的输入将不含 0

看图更好理解 哦

  那么,聪明如你肯定会觉得不对劲,这先 填充又 压紧, 这不是做无用功?其实不是哦,因为 pack 后可并不是一个简单的 Tensor 类型的数据,而是一个 ”PackedSequence“ 类型的 object,可以直接传给RNN。小萌新在苦逼地看RNN源码时,发现forward 函数 里上来就是判断 输入是否是 PackedSequence 的实例,进而采取不同的操作。如果输入是 PackedSequence,输出也是该类型。这里的输出类型都指的是 forward 函数的第一个返回值(每个time step 对应的hidden_state),第二个返回值(最后一个time step对应的hidden_state)的类型不管输入是不是 PackedSequence 类型,都是一样的。

pack_padded_sequence

pytorch里 有封装的更好的 :torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)

接下来说说这些参数作用

 lengths :该参数中各句子长度值的顺序要和对应的输入中的序列顺序一致

enforce_sorted: 默认值是 True,表示输入已经按句子长度降序排好序。如果输入在 pad 时没有顺序,那么此时在此处需要设置该值为 False,那么函数会再去排序

返回的对象是PackedSequence object。该类型的变量便可以直接喂给 RNN/LSTM等。

torch.nn.utils.rnn.pad_packed_sequence():之前的pack_padded_sequence 是先补齐到相同长度 再压紧,这个当然就是反过来,对压紧后的序列 进行扩充补齐操作。

注意:inputs是否排好序和 lengths参数和enforce_sorted 一定要对应起来。小萌新习惯将 inputs 按照长度先排好序,再将length 排好序enforce_sorted参数不去动它。

inputs.sort(key=lambda x:len(x.split()),reverse=True) 
lengths=[len(s.split()) for s in inputs] 

 

 

 

 

 

------------恢复内容结束------------

posted @ 2020-12-08 11:29  打了鸡血的女汉子  阅读(4201)  评论(3编辑  收藏  举报