transformer中的attention机制详解

transformer中用到的注意力机制包括self-attention(intra-attention)和传统的attention(cross-attention),本篇文章将在第一节简述这两者的差别,第二节详述self-attention机制,第三节介绍其实现

self-attention和attention的区别

传统attention机制

发生在decoder和encoder之间,decoder可以更多的参考encoder中相关的信息,以便指导其输出。attention机制可以分为以下三步

  • 计算algnment score

其中hj时encoder输出的隐状态, si是decoder 输出的隐状态, eij描述的时输入位置j和输出位置i的匹配度

  • 匹配度归一化,这里使用softmax进行计算

  • 计算context vector

从表达式可看出,decoder 计算所需的context vector 实际上就是输入隐状态的加权和。

具体到RNN中,每个时间步应用attention机制的计算步骤如下

  • decoder RNN 接收 token 的嵌入和初始解码器隐藏状态。
  • RNN 处理其输入,产生输出和新的隐藏状态向量 si。输出被丢弃。
  • attention计算:我们使用encoder输出的所有的隐藏状态和 decoder 输出的si 向量来计算此时间步骤的context vector ci。
  • 我们将 si 和 ci 连接成一个向量。
  • 我们将此向量传递给前馈神经网络(与模型联合训练)。
  • 前馈神经网络的输出表示此时间步骤的输出词。
  • 对下一个时间步骤重复此操作

self-attention

发生在decoder或者encoder内部,将输出或者输入序列内部不同位置关联起来,以计算序列表征

self-attention机制

self-attention的实现步骤和attention类似,在attention中计算align score时用到了输入和输出的hidden state,但是对于self-attention只需要用到一种,即在encoder中的self-attention只用到encoder层输出的hidden state, decoder中的self-attention只用到decoder层的hidden state

我们将self-attention拆解为两部分,1. self-attention计算 2. multi-head attention

self attention计算:scaled dot-product attention

  • 获取encoder输入的embeding,并计算每个embedding 的query,key,value,下文简写为q,k,v。

其中WQ, WK, WV为去要学习的权重矩阵

  • 接下来我们要计算不同位置之间的关联度。例如我们要计算位置0处的embedding和其他位置embedding的关联度,参考传统attention机制align score的计算方法,我们要用位置0处计算得到的hidden states即query 和其他位置处计算的key进行计算。在self-attention中计算过程如下

相较于传统attention计算align score,self-attention中多了一步scale,即用key维度的开方对qk结果进行缩放。论文中提出这样做的理由是

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1/sqrt(dk)

即缩放的目的是为了保证softmax有更加稳定的梯度

  • 得到其他位置对于位置0的关联度/权重之后,我们就可以计算位置0处包含有上下文信息的context vector:

  • 利用矩阵运算可以同步求出其他位置的context vector
    获取输入向量的Q,K,V矩阵

    运用矩阵计算得到每个位置的结果

multi-head attention

相较于单头注意力,使用多头注意力的目的在于

  1. 从不同表征空间挖掘不同位置之间的关联。

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

  1. 单头注意力在计算不同位置间关联时用到了加权平均,这在一定程度上影响了特征计算的准确性,因此要用多头注意力来抵消这种影响

In these models, the number of operations required to relate signals from two arbitrary input or output positions grows in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions [12]. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.

单头注意力包含WQ,WK,WV, 产生一个output,多头注意力则包含n个WQ,WK,WV,这些参数的权重不共享,产生n个output

这n个output被拼接到一起,并对拼接结果再次进行projection得到最终结果

self-attention实现

  1. 首先是single attention
def attention(query, key, value, mask=None, dropout=None):
    # 获取维度, query, key, value 的size 均为(batch_size,  n_head, seq_length, hidden_state_length)
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(d_k)
    if mask:
        scores = score.masked_fill(mask==0, -1e9)
    p_atten = scores.softmax(dim = -1)
    if dropout:
        p_atten = dropout(p_atten)
    return torch.matmul(p_atten, vakue), p_atten

在transformer decoder中会在self-attention中使用mask,在encoder中不会用到。因为本篇文章主要讲解self-attention因此没有讲解mask的使用,下一篇讲解transformer的文章中会具体分析self-attention在decoder和encoder中的区别。

  1. multi head attention实现
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model / h
        self.h = h 
        self.wq =  nn.Linear(d_model, d_model)
        self.wk =  nn.Linear(d_model, d_model)
        self.wv =  nn.Linear(d_model, d_model)
        self.wo =  nn.Linear(d_model, d_model)
        self.atten = None 
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
        if mask:
            # same mask applied to all heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query = self.wq(query).view(nbatches, -1, self.h, self.dk).transpose(1, 2)
        key = self.wk(key).view(nbatches, -1, self.h, self.dk).transpose(1, 2)
        value = self.wv(value).view(nbatches, -1, self.h, self.dk).transpose(1, 2)
        
        x, self.atten = attention(query, key, value, mask, self.dropout)
        # concat n heads outputs 
        x = (
              x.transpose(1,2)
              .contiguous()
              .view(nbatches, -1, self.h*self.d_k))
        del query
        del key
        del value
        return self.wo(x)

从multihead attention的实现中看出,实际上是将维度为(nbatches, seq_length, d_model)的矩阵,利用矩阵变换,得到了一个 (nbatches,h, seq_length,dk)的矩阵,且h*dk = d_model。

在attention中计算的时候所有head并行计算,得到一个(nbatches,h, seq_length,dk)的输出,对这个输出结果在进行矩阵变换得到 (nbatches, seq_length, d_model)的矩阵。完成了所谓的‘矩阵拼接’

拼接后的矩阵经过wo计算得到最终结果

ref:
Attention is all you need
The Illustrated Transformer
Visualizing A Neural Machine Translation Model (Mechanics of Seq2seq Models With Attention)
The Annotated Transformer

posted @ 2024-07-02 18:26  老张哈哈哈  阅读(78)  评论(0编辑  收藏  举报