自注意力机制(2)-多头自注意层

多头自注意层

上一篇描述了单头多注意层,但在实际应用中,通常使用的是多头自注意层,多头自注意层是由多个单头的组合。

1. 数学形式

输入:\(X=\{x_1, x_2,...,x_m\}\)\(x_i\)\(d_{in}\times1\)的向量。

参数:每个单头自注意层都有三个参数矩阵,\(W_q:d_q*d_{in}\); \(W_k:d_q*d_{in}\); \(W_v:d_{out}*d_{in}\),多头自注意层总共有3\(l\)个参数矩阵,\(l\)表示自注意层的个数。

输出:每个单头自注意层的输出为\(C=\{c_1, c_2,...,c_m\}\)\(c_i\)\(d_{out}\times1\)的向量,多头的输出就是有\(l\)个C矩阵,然后将所有单头自注意层对应位置的输出做连接。最终的每个输出\(c_i=[c_i^1; c_i^2; c_i^3;...;c_i^l]\)

2.Pytorch代码实现(多头自注意层)

使用一个大矩阵,将所有参数矩阵并行起来计算。计算过程和单层自注意层相同,最后将多头注意力的输出连接起来。

import torch
import torch.nn as nn
from math import sqrt

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_in, d_k, d_out, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_k % num_heads == 0 and d_out % num_heads == 0   # dk和dout必须是多头数量的倍数,因为dk和dout表示所有头的总参数量
        self.din = d_in
        self.dq = d_k
        self.dout = d_out
        self.num_heads = num_heads

        self.Wq = nn.Linear(self.din, self.dq, bias=False)
        self.Wk = nn.Linear(self.din, self.dq, bias=False)
        self.Wv = nn.Linear(self.din, self.dout, bias=False)

        self._norm_fact = 1/sqrt(self.dq//num_heads)  # "//" 为整除运算符


    def forward(self, x):
        m, din = x.shape
        assert din == self.din

        nh = self.num_heads
        dk = self.dq // nh         # 每一个头的dq大小
        dv = self.dout // nh       # 每一个头的dout大小

        # 第一步
        Q = self.Wq(x).reshape(m, nh, dk).transpose(0, 1)   # nh*m*dk
        K = self.Wk(x).reshape(m, nh, dk).transpose(0, 1)
        V = self.Wv(x).reshape(m, nh, dv).transpose(0, 1)

        # 第二步
        A = torch.softmax(torch.matmul(Q, K.transpose(1, 2))*self._norm_fact, dim=-1)

        # 第三步
        C = torch.matmul(A, V)   # nh, m, dv

        # 将输出进行连接
        C = C.transpose(0, 1).reshape(m, self.dout)
        
        return C



posted @ 2024-09-23 17:23  吃瓜的哲学  阅读(36)  评论(0编辑  收藏  举报