pytorch-多头注意力(维度分析)重要

 

多头注意力

在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依
赖关系)
。因此,允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的h组不同的线性投影(linear projections)来变换查询、键和值。然后,这h组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这h个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)(Vaswani et al., 2017)。对于h个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。下图展示了使用全连接层来实现可学习的线性变换的多头注意力。
image

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询qRdq、键kRdk和值vRdv,每个注意力头hii=1,...,h的计算方法为:

hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv

其中,可学习的参数包括Wi(q)Rpq×dqWi(k)Rpk×dkWi(v)Rpv×dv,以及代表注意力汇聚的函数f。f可以的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着h个头连结后的结果,因此其可学习参数是 WoRpo×hpv
image

总结

多头注意力机制现在的使用是非常广泛的。为什么需要比较多的head呢?可以想成相关这件事情在做Self-attention的时候,就是用q去找相关的k,但是相关这件事情有很多种不同的形式,有很多种不同的定义,所以我们不能只有一个q,应该要有多个q,不同的q负责不同种类的相关性。
我们应在怎么做呢?首先对于这个qi我们分别乘两个矩阵变成qi,1qi,2。这个可以理解为两种不同的相关性。之后q,k,v都要有两个:
image
用第一个head:
image
用第二个head:
image
将这两个接起来,然后通过一个trannsform,也就是乘上一个矩阵,得到bi传到下一层去。
image

Attention、Self-attention、Multi-headed Self-attenion

image
如上图所示,最底层的输入x1,s2,x3.....,xT,表示输入的序列数据,比如,x1可以代表某个句子的第一个词所对应的向量。首先,通过嵌入层(可选)将它们进行初步的embedding,得到a1,a2,a3.....,aT。然后,使用三个矩阵WQ,WK,Wv分别与之相乘,得到qi,ki,vi,i(1,2,3...T)。上图显示了与输入x1所对应的输出b1是如何得到的。即:

  • 利用q1分别与k1,k2,k3...kT计算向量点积,得到a1,1,a1,2,a1,3....a1,T(从数值上看,ai,i还不一定是0-1之间的数,还需经过softmax处理);
  • a1,1,a1,2,a1,3....a1,T输入softmax层,从而得到均在0-1之间的注意力权重值:a1,1^,a1,2^,a1,3^....a1,T^。分别于对应位置上的v1,v2,v3....vT相乘。然后求和,这样便得到了与输入的x1所对应的输出b1
    同样地,与输入的x2所对应的输出b2,也根据类似过程获得,只是此时是利用与b2对应的q2分别与k1,k2,kT计算向量点积,主要过程如下图所示:
    image
    其他输入的计算过程以此类推,如下图所示:
    image
    image
    对于输入的序列x1,x2,x3....xT来说,与RNN/LSTM的处理过程不同,Self-attention机制能够并行对x1,x2,x3....xT进行计算,这大大提高了对x1,x2,x3....xT特征进行提取(即获得b1b2b3....bT)的速度。结合上述Self-attention的计算过程,并行计算的原理如下图所示:
    image
    由上图可以看到,通过对输入I分别乘以矩阵WQ,WK,WV,我们便得到了三个矩阵Q,K,W,然后通过后续计算得到注意力矩阵α^,进而得到输出O

对于在Transformer及BERT模型中用到的Multi-headed Self-attention结构与之略有差异,具体体现在:如果将前文中得到的整体看做一个“头”,则“多头”即指对于特定的 来说,需要用多组与之相乘,进而得到多组。如下图所示:
image
如上图所示,以右侧示意图中输入的a1为例,通过多头(这里取head=3)机制得到了三个输出bhead1,bhead2,bhead3,为了获得与对应的输出,在Multi-headed Self-attention中,我们会将这里得到的bhead1,bhead2,bhead3进行拼接(向量首尾相连),然后通过线性转换(即不含非线性激活层的单层全连接神经网络)得到。对于序列中的其他输入也是同样的处理过程,且它们共享这些网络的参数。
注意其中Wq,Wk,Wv是可以学习的参数。

实现

在实现过程中通常选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我
们设定pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为pqh=pkh=pvh=po,则可以并行计算h个头。在下面的实现中,po是通过参数num_hiddens指定的。

import math
import torch
from torch import nn
from d2l import torch as d2l

注意这里的维度变化是为了增加并行性。

#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
    #    print(queries.shape,keys.shape,values.shape)
        
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        
      #  print(queries.shape,keys.shape,values.shape)
        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

为了能够使多个头并行计算,上面的MultiHeadAttention类将使用下面定义的两个转置函数。具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。多头注意力输出的形状是(batch_size,num_queries,num_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
queries,keys,values的维度
torch.Size([2, 4, 100]) torch.Size([2, 6, 100]) torch.Size([2, 6, 100])
queries,keys,values的维度
torch.Size([10, 4, 20]) torch.Size([10, 6, 20]) torch.Size([10, 6, 20])
torch.Size([2, 4, 100])

这个最后输出的维度为batch_size×tar_len×hidden_size

• 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
• 基于适当的张量操作,可以实现多头注意力的并行计算。

Multihead Attention中维度变化分析

image
1.Input: Encoder Multihead Attention 输入的 query, key, value 是相同的,都是经过了word embedding和pos embedding之后的source sentence,其维度为batch_size×sr_len×hidden_size。由于有num_heads个头需要并行计算,首先query, key, value分别经过一个线性变换,再将数据split给num_heads个头分别做注意力查询,即:
query:
batch_size×sr_len_q×hidden_sizereshapebatch_size×num_heads×sr_len_q×hidden_sizenum_heads
key:
batch_size×sr_len_q×hidden_sizereshapebatch_size×num_heads×sr_len_q×hidden_sizenum_heads
value:
batch_size×sr_len_q×hidden_sizereshapebatch_size×num_heads×sr_len_q×hidden_sizenum_heads

由于query, key, value 是相同的,因此有 sr_len_q = sr_len_k = sr_len_v
2.DotProductAttention: num_heads 个头的计算是并行的,即:
image

Encoder Multihead Attention中在计算softmax之前对 key 进行了 mask,目的是消除 padding 的影响。事实上,padding不仅对key有影响,对query也有影响,但在实际代码中mask仅针对key,而没有针对query。其实最原始代码是既有key mask,也有query mask的,但后来作者将query mask删去了,因为在最后计算 loss 的时候对 padding 位置的 loss 进行mask,也可达到相同的效果。

假设 batch_size = num_heads = 1,sr_len_q = sr_len_k = 6,source sentence 的最后两个位置是padding,那么Encoder Multihead Attention 中的 mask 为:

(111100111100111100111100111100)

即只对 key 的 padding 位置进行了 mask

3.Output: 需要将上面输出的num_heads个头的结果堆叠之后,再做一个线性变换:

batch_size×num_heads×sr_len_q×hiddensizenumheads

reshape

batch_size×sr_len_q×hidden_size

线

batch_size×sr_len_q×hidden_size

可以看一下这个图:
image

posted @   lipu123  阅读(1258)  评论(0编辑  收藏  举报
(评论功能已被禁用)
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示