自注意力机制(1)-单头自注意层
自注意机制
1. 自注意机制的特点
考虑这样一个问题,输入长度为m的序列\(\{x_1, x_2,...,x_m\}\),序列中的元素都是向量,要求输出长度同样为m的序列\(\{c_1, c_2,...,c_m\}\),另外还有两个要求:
- 序列的长度m是不确定的,可以动态变化,但是神经网络的参数数量不能变。
- 输出的向量\(c_i\)不仅仅和\(x_i\)有关,而是依赖于所有新的输入向量\(\{x_1, x_2,...,x_m\}\)。
传统的RNN不能解决上述问题,因此传统RNN的输出\(c_i\)只依赖于\(\{x_1, x_2,...,x_i\}\),而不依赖于\(\{x_{i+1},...,x_m\}\)。自注意机制就能很好的解决上述问题。
2. 数学形式
输入:\(X=\{x_1, x_2,...,x_m\}\),\(x_i\)是\(d_{in}\times1\)的向量。
三个参数矩阵:\(W_q:d_q*d_{in}\); \(W_k:d_q*d_{in}\); \(W_v:d_{out}*d_{in}\)。
无论输入序列有多长,参数矩阵不需要发生改变,这三个参数矩阵需要从训练数据中进行学习
输出:\(C=\{c_1, c_2,...,c_m\}\),\(c_i\)是\(d_{out}\times1\)的向量。
计算步骤:
-
第一步将输入\(x_i\)映射为三元组\(\{q_i,k_i,v_i\}\):
- \(q_i=W_q*x_i\),\(q_i\)的大小是\(d_q\times1\)。
- \(k_i=W_k*x_i\),\(k_i\)的大小为\(d_q*1\)。
- \(v_i=W_v*x_i\),\(v_i\)的大小为\(d_{out}*1\)。
第一步将输出映射为三元组,上述是每个元素的计算过程。在实际计算中,会得到三个矩阵,\(Q=\{q_1, q_2,...,q_m\}\)大小为\(d_q\times m\),\(K=\{k_1,k_2,...,k_m\}\)大小为\(d_q\times m\),\(V=\{v_i, v_2,...,v_m\}\),大小为\(d_{out}\times m\)。
-
第二步利用\(q_i\)和\(K\)计算权重向量\(a_i\):
- \(a_i=\text{softmax}(<q_i,k_1>,<q_i, k_2>,...,<q_i, k_m>), i=1,..,m\)
上述的<,>表示内积,\(\text{softmax}\)函数导致\(a_i\)中所有元素的和为1,每个元素对应着与\(\{x_1, x_2,...,x_m\}\)的重要程度,权重矩阵\(A=\{a_1,a_2,...,a_m\}\),大小为\(m \times m\) 。
-
第三步利用权重矩阵\(A\)和\(V\)矩阵得到最终的输出矩阵\(C=\{c_1, c_2,...,c_m\}\),第\(i\)个输出向量\(c_i\)依赖于\(a_i\)和\(\{v_1, v_2,..., v_m\}\):
- \(c_i=[v_1, v_2,..,v_m]*a_i=\sum_{j=1}^m a_i^j*v_j, i=1,..,m\)
\(c_i\)是向量\(\{v_1, v_2,..., v_m\}\)的加权平均,权重是\(a_i=[a_i^1, a_i^2,...,a_i^m]\)。\(c_i\)的大小是\(d_{out}\times 1\)。整个输出矩阵\(C\)大小为\(d_{out}\times m\)。
为什么要叫“注意力”呢,我们看最后的输出\(c_i=a_i^1v_1+a_i^2v_2+\cdot \cdot+a_i^mv_m\),权重\(a_i=[a_i^1, a_i^2,...,a_i^m]\)反映出\(c_i\)最关注那些输入的\(v_i=W_v*x_i\),如果权重\(a_i^j\)大,说明\(x_j\)对\(c_i\)的影响较大,应当重点关注。
3. Pytorch代码实现(单头自注意层)
import torch
import torch.nn as nn
from math import sqrt
class Self_attention(nn.Module):
def __init__(self, d_in, d_q, d_out):
super(Self_attention, self).__init__()
self.din = d_in
self.dq = d_q
self.dout = d_out
self.Wq = nn.Linear(self.din, self.dq, bias=False)
self.Wk = nn.Linear(self.din, self.dq, bias=False)
self.Wv = nn.Linear(self.din, self.dout, bias=False)
self._norm_fact = 1/sqrt(self.dq) # 归一化层
def forward(self, x):
m, din = x.shape
assert din == self.din # 判断输入数据维度是否正确
# 第一步
Q = self.Wq(x) # m*dq
K = self.Wk(x) # m*dq
V = self.Wv(x) # m*dout
# 第二步
A = torch.softmax(torch.matmul(Q, K.T)*self._norm_fact, dim=-1) # m*m
# 第三步
C = torch.matmul(A, V) # m*dout
return C