自注意力机制(2)-多头自注意层
多头自注意层
上一篇描述了单头多注意层,但在实际应用中,通常使用的是多头自注意层,多头自注意层是由多个单头的组合。
1. 数学形式
输入:\(X=\{x_1, x_2,...,x_m\}\),\(x_i\)是\(d_{in}\times1\)的向量。
参数:每个单头自注意层都有三个参数矩阵,\(W_q:d_q*d_{in}\); \(W_k:d_q*d_{in}\); \(W_v:d_{out}*d_{in}\),多头自注意层总共有3\(l\)个参数矩阵,\(l\)表示自注意层的个数。
输出:每个单头自注意层的输出为\(C=\{c_1, c_2,...,c_m\}\),\(c_i\)是\(d_{out}\times1\)的向量,多头的输出就是有\(l\)个C矩阵,然后将所有单头自注意层对应位置的输出做连接。最终的每个输出\(c_i=[c_i^1; c_i^2; c_i^3;...;c_i^l]\)。
2.Pytorch代码实现(多头自注意层)
使用一个大矩阵,将所有参数矩阵并行起来计算。计算过程和单层自注意层相同,最后将多头注意力的输出连接起来。
import torch
import torch.nn as nn
from math import sqrt
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_in, d_k, d_out, num_heads=8):
super(MultiHeadSelfAttention, self).__init__()
assert d_k % num_heads == 0 and d_out % num_heads == 0 # dk和dout必须是多头数量的倍数,因为dk和dout表示所有头的总参数量
self.din = d_in
self.dq = d_k
self.dout = d_out
self.num_heads = num_heads
self.Wq = nn.Linear(self.din, self.dq, bias=False)
self.Wk = nn.Linear(self.din, self.dq, bias=False)
self.Wv = nn.Linear(self.din, self.dout, bias=False)
self._norm_fact = 1/sqrt(self.dq//num_heads) # "//" 为整除运算符
def forward(self, x):
m, din = x.shape
assert din == self.din
nh = self.num_heads
dk = self.dq // nh # 每一个头的dq大小
dv = self.dout // nh # 每一个头的dout大小
# 第一步
Q = self.Wq(x).reshape(m, nh, dk).transpose(0, 1) # nh*m*dk
K = self.Wk(x).reshape(m, nh, dk).transpose(0, 1)
V = self.Wv(x).reshape(m, nh, dv).transpose(0, 1)
# 第二步
A = torch.softmax(torch.matmul(Q, K.transpose(1, 2))*self._norm_fact, dim=-1)
# 第三步
C = torch.matmul(A, V) # nh, m, dv
# 将输出进行连接
C = C.transpose(0, 1).reshape(m, self.dout)
return C