68 Multi-Head Attention

import math
import torch
from torch import nn
from d2l import torch as d2l

# Use scaled dot-product attention for every attention head
# p_q = p_k = p_v = p_o / h
# If the output dimension of the linear transforms for the queries, keys and values
# is set to p_q*h = p_k*h = p_v*h = p_o, then the h heads can be computed in parallel
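# Worked example (added for illustration, using the numbers from the demo at
# the end of this post): with p_o = num_hiddens = 100 and h = num_heads = 5,
# each head gets queries, keys and values of dimension p_o / h = 20, and the
# 5 heads are evaluated as a single batched dot-product attention call.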
#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    # 多个头一同计算
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # Fully connected layers for the queries, keys, values and output
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens:
        # (batch_size,) or (batch_size, no. of queries)
        # After the transform below, the shape of queries, keys, values becomes:
        # (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (a scalar or a vector) num_heads
            # times, then copy the second item, and so on.
            # print('valid_lens : ', valid_lens)
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
            # print('valid_lens : ', valid_lens)
            """
            valid_lens :  tensor([3, 2])
            valid_lens :  tensor([3, 3, 3, 3, 3, 2, 2, 2, 2, 2])
            """

        # Shape of output: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

# The transpose_output function reverses the operation of transpose_qkv
#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Output X shape: (batch_size, num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # Final output shape: (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])
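
# Quick shape check (a minimal sketch added for illustration; the numbers are
# borrowed from the demo at the end of this post): with batch_size=2, 4 queries,
# num_hiddens=100 and num_heads=5, transpose_qkv maps (2, 4, 100) to
# (2*5, 4, 100/5) = (10, 4, 20).
assert transpose_qkv(torch.ones(2, 4, 100), 5).shape == (10, 4, 20)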


#@save
def transpose_output(X, num_heads):
    """Reverse the operation of transpose_qkv"""
    # Input X shape: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # X shape: (batch_size, num_heads, no. of queries, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # X shape: (batch_size, no. of queries, num_heads, num_hiddens/num_heads)
    # Final output shape: (batch_size, no. of queries, num_heads*(num_hiddens/num_heads))
    #                   = (batch_size, no. of queries, num_hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)
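
# Sanity check (sketch added for illustration, not part of the original post):
# transpose_output should exactly invert transpose_qkv.
_X = torch.randn(2, 4, 100)
assert torch.equal(transpose_output(transpose_qkv(_X, 5), 5), _X)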

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
# X shape: (2, 4, 100)
X = torch.ones((batch_size, num_queries, num_hiddens))
# Y shape: (2, 6, 100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# Output shape: (2, 4, 100)
print(attention(X, Y, Y, valid_lens).shape)
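# Expected result: the output keeps the query layout and the model width,
# i.e. torch.Size([2, 4, 100]), matching the (2, 4, 100) noted above.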