transformer中解码器的实现细节
1. 前言
17年google团队发表l了论文《Attention Is All You Need》,transformer横空出世,并引领了AI学术圈的研发风向,以Transformer为基础模型的新模型层出不穷,无论是NLP还是CV或者是多模态,attention遍地开花。
这篇文章遵循encoder-decoder架构,并在其中使用了self-attention和cross-attention,如下图所示:
其中,encoder的行为还是非常好理解的,至于decoder,则相关细节在原文中都只草草提过,令人留下很多疑问,譬如,
decoder第一个attention为什么需要使用masked?
decoder在训练阶段和测试阶段有什么区别?
decoder在测试阶段,decoder的query输入是将目前所有的预测输入,还是只输入上一次decoder的输出?
2. 问题探讨
decoder第一个attention为什么需要使用masked?
Transformer模型属于自回归模型,也就是说后面的token的推断是基于前面的token的。Decoder端的Mask的功能是为了保证训练阶段和推理阶段的一致性。
在推理阶段,token是按照从左往右的顺序推理的。也就是说,在推理timestep=T的token时,decoder只能“看到”timestep < T的 T-1 个Token, 不能和timestep大于它自身的token做attention(因为根本还不知道后面的token是什么)。为了保证训练时和推理时的一致性,所以,训练时要同样防止token与它之后的token去做attention。
decoder在训练阶段和测试阶段有什么区别?
在训练阶段,预测序列是直接全部喂到decoder的输入的,只是在算self-attention的时候加了一个mask,前面时间步的不能看到后面时间步的词,decoder的预测也是一次就全部出来了,也就是Teacher Forcing机制,如下图所示,在训练的时候,需要预测一段语音,decoder端的input,就直接把gt喂进去了,当然加进去前还需要shift right,在序列最左边增加一个Begin的特殊字符(为了和预测阶段保持一致),然后这些gt作为query,进行进入第一层mask multi-head attention层(根据时间步增加mask,以免在self-attention阶段前面的词可以看到后面的),然后以这层的输出为query,来自encoder的输出为key-value pair输入第二个子层multi-head attention,输出作为下层的输入,继续前面的过程,重复N次。
如果是测试阶段,则就不一样,首先decoder会先输入Begin,预测出下一个词,然后再以已经预测的词作为输入,再进入decoder预测下一个词,直到遇到预测出的词是表示结束的特殊次元,才结束这个过程,参考以下视频:
https://www.zhihu.com/zvideo/1330559583777939456
decoder在测试阶段,decoder的query输入是将目前所有的预测输入,还是只输入上一次decoder的输出?
两种实现都有,具体来说,分别是:
a. 每次都将当前预测全部输入,在self-attention和cross-attention中均进行全量计算,优点是实现简单,缺点是计算量大,如下面的代码实现:
class DecoderLayer(nn.Module): def __init__(self, size, self_attn, src_attn, feed_forward, dropout): super(DecoderLayer, self).__init__() self.size = size self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward self.sublayer = clones(SublayerConnection(size, dropout), 3) def forward(self, x, memory, src_mask, tgt_mask): m = memory x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) return self.sublayer[2](x, self.feed_forward)
可以看到每次的做self-attention的时候query,key,value都是目前所有的词(query 做了mask操作)。
完全版可以查看:https://zhuanlan.zhihu.com/p/398039366
b. 还有另外一种实现就是增量进行计算,李沐在《动手学深度学习》中就用了这种实现,优点是每次只需要计算一个query,但是因为在self-attention中需要与其他的词进行attention操作,因此需要在每层中保存之前的词作为key和value,如下面代码所示:
class DecoderBlock(nn.Module): """解码器中第i个块""" def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs): super(DecoderBlock, self).__init__(**kwargs) self.i = i self.attention1 = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm1 = AddNorm(norm_shape, dropout) self.attention2 = d2l.MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm2 = AddNorm(norm_shape, dropout) self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens) self.addnorm3 = AddNorm(norm_shape, dropout) def forward(self, X, state): enc_outputs, enc_valid_lens = state[0], state[1] # 训练阶段,输出序列的所有词元都在同一时间处理, # 因此state[2][self.i]初始化为None。 # 预测阶段,输出序列是通过词元一个接着一个解码的, # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示 if state[2][self.i] is None: key_values = X else: key_values = torch.cat((state[2][self.i], X), axis=1) state[2][self.i] = key_values if self.training: batch_size, num_steps, _ = X.shape # dec_valid_lens的开头:(batch_size,num_steps), # 其中每一行是[1,2,...,num_steps] dec_valid_lens = torch.arange( 1, num_steps + 1, device=X.device).repeat(batch_size, 1) else: dec_valid_lens = None # 自注意力 X2 = self.attention1(X, key_values, key_values, dec_valid_lens) Y = self.addnorm1(X, X2) # 编码器-解码器注意力。 # enc_outputs的开头:(batch_size,num_steps,num_hiddens) Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens) Z = self.addnorm2(Y, Y2) return self.addnorm3(Z, self.ffn(Z)), state
其中state[2][self.i]就存储了目前为止所有预测到的词。
完整版可以查看:https://zh-v2.d2l.ai/chapter_attention-mechanisms/transformer.html
3. 参考
[1] Transformer源码详解(Pytorch版本)
(完)