浅析注意力(Attention)机制(一)-- 基本概念与核心思想
Attention顾名思义,说明这项机制是模仿人脑的注意力机制建立的,我们不妨从这个角度展开理解
2.1 人脑的注意力机制
人脑的注意力机制,就是将有限的注意力资源分配到当前关注的任务,或关注的目标之上,暂时忽略其他不重要的因素,这是人类利用有限的注意力资源从大量信息中快速筛选出高价值信息的手段,是人类在长期进化中形成的一种生存机制,极大地提高了信息处理的效率与准确性。
举个栗子,就以上班为例,今天本该又是摸鱼的一天,但你的“恩人”突然交给你一项任务——查找关于“注意力机制”的资料并总结,并于下班之前向她汇报。于是你不得不放下手上的娱乐节目,转而应付恩人派下的工作。你选定了“注意力机制”作为关键词开始搜索,在搜索引擎的推送下阴差阳错的看到了这篇博文(这是不可能的),又因为这篇博文关键信息太少而选择忽略了它,努力一番后又查到了一些资料,汇总的大量初步结果并提交恩人,按时下班,happy ending!
上面的例子中其实出现了多次“识别关键要素”或“筛选重要信息”的动作,这便是注意力机制的体现。而深度学习中的注意力机制从本质上讲和人类的选择性注意力机制类似,核心目标也是从众多信息中选择出对当前任务目标更关键的信息。
2.2 为什么需要Attention
在之前的博文《理解LSTM》中提到过,LSTM通过引入逻辑门,从结构层面上有效解决了序列长距离依赖问题(梯度消失)。然而,面对超长序列时(例如一段500多词的文本),LSTM也可能失效。而 Attention 机制可以更好地解决序列长距离依赖问题,并且具有并行计算能力。
我们还是以文本问题举例, 看一看RNN或LSTM处理超长文本序列时会发生什么?
可以看到, 为了理解当前文本,我们有时需要获得很久之前的历史状态下的某些信息。而RNNs从结构层面上无形中添加了一种假设,那就是当前的文本只和临近区域的文本具有较强的关联性,而和距离较远的上下文关联不大或没有关联。很明显,这样的假设是不恰当的,这就限制了RNNs处理文本的长度和理解文本的精度,而Attention的出现则几乎打破了模型对于文本长度的限制。
采用RNN架构的网络均具有这种局限, 包括LSTM, GRU等等
为了进一步理解,让我们从循环神经网络的老大难问题——机器翻译问题入手。
在翻译任务中,源语言和目标语言的单词数和语序往往不是一一对应的,这种输入和输出都是不定长序列的任务,称为 Seq2Seq,以英语和德语为例,如下图所示。
为了解决这个问题,我们创造了Encoder-Decoder结构的循环神经网络。
- 它先通过一个Encoder循环神经网络读入所有的待翻译句子中的单词,得到一个包含原文所有信息的中间隐藏层,接着把中间隐藏层状态输入Decoder网络,一个词一个词的输出翻译句子。
- 这样子,无论输入中的关键词语有着怎样的先后次序,由于都被打包到中间层一起输入后方网络,我们的Encoder-Decoder网络都可以很好地处理这些词的输出位置和形式了。
问题在于,由于中间状态\(C\)来自输入网络最后的隐藏层,一般来说它是一个大小固定的向量。既然是大小固定的向量,那么它能储存的信息就是有限的,当句子长度不断变长,由于后方的decoder网络的所有信息都来自中间状态,中间状态需要表达的信息就越来越多。在语句信息量过大时,中间状态就作为一个信息的瓶颈阻碍翻译了。这时我们很容易联想到,如果网络能够在处理长文本时懂得筛选关键信息, 而不是将全部文本都作为都作为中间状态储存,是不是就可以突破文本长度的限制了?这便是注意力机制的由来。
Encoder-Decoder(编码-解码)是深度学习中非常常见的一个模型框架,比如无监督算法的auto-encoding就是用编码-解码的结构设计并训练的;比如这两年比较热的image caption的应用,就是CNN-RNN的编码-解码框架;再比如神经网络机器翻译NMT模型,往往就是LSTM-LSTM的编码-解码框架。因此,准确的说,Encoder-Decoder并不是一个具体的模型,而是一类框架。Encoder和Decoder部分可以是任意的文字,语音,图像,视频数据,模型可以采用CNN,RNN,BiRNN、LSTM、GRU等等。所以基于Encoder-Decoder架构,我们可以设计出各种各样的应用算法。
2.3 Attention的核心思想
在正式介绍注意力机制之前,我们先要明确以下几个概念:
- 查询(Query):用于记录模型当前关注的任务信息,向量形式
- 键(Key):用于记录输入序列中每个信息单元的标识符或标签, 用于与Query进行比较,以决定哪些信息是相关的, 在机器翻译任务中,Key可能是源语言的每个单词或短语的特征向量
- 值(Value):Value通常包含输入序列的实际信息,当Query和Key匹配时,相应的Value值被用于计算输出
- 分数(Score): Score又称为注意力分数,用于表示Query和Key的匹配程度,Score越高,模型对当前信息单元的关注度越高
我们仍以机器翻译为例,通过引入注意力机制,让生成词不是只能关注全局的语义编码向量c,而是增加了一个“注意力范围”,表示接下来输出词时候要重点关注输入序列中的哪些部分,然后根据关注的区域来产生下一个输出,如下图所示。
在理解了注意力机制的作用之后,我们就可以对其具体步骤加以描述了(正片开始)。
如上图所示,Attention 通常可以进行如下描述,表示为将 Query(Q) 和 key-value pairs(把 Values 拆分成了键值对的形式) 映射到输出上,其中 query、每个 key、每个 value 都是向量,输出是 \(V\) 中所有 values 的加权,其中权重是由 Query 和每个 key 计算出来的,计算方法分为三步:
- 第一步:计算并比较 Q 和 K 的相似度,用 f 来表示:\(f(Q,K_i)\quad i=1,2,\cdots,m\), 一般第一步计算方法包括四种
- 点乘(transformer使用):\(f(Q,K_i)=Q^TK_i\)
- 加权:\(f(Q,K_i)=Q^TWK_i\)
- 拼接权重:\(f(Q,K_i)=W[Q^T;K_i]\)
- 感知器:\(f(Q,K_i)=V^T\tanh(WQ+UK_i)\)
- 将得到的相似度进行 softmax 操作,进行归一化,得到注意力分数:\(\alpha_i=softmax(\frac{f(Q,K_i)}{\sqrt{d}_k})\)
- 针对计算出来的权重 \(\alpha_{i}\),对 \(V\) 中的所有 values 进行加权求和计算,得到 Attention 向量:\(Attention=\sum_{i=1}^m\alpha_iV_i\)
2.4 Attention代码实现
最后附一个Attention机制的代码示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleAttention(nn.Module):
def __init__(self, input_dim):
super(SimpleAttention, self).__init__()
self.input_dim = input_dim
self.query = nn.Linear(input_dim, input_dim)
self.key = nn.Linear(input_dim, input_dim)
self.value = nn.Linear(input_dim, input_dim)
def forward(self, x):
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# Compute attention scores (dot product of queries and keys)
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.input_dim ** 0.5
# Apply softmax to get attention weights
attention_weights = F.softmax(attention_scores, dim=-1)
# Weighted sum of values
output = torch.matmul(attention_weights, V)
return output, attention_weights
# Example usage
input_dim = 64
seq_length = 10
batch_size = 5
# Dummy input tensor (batch_size, seq_length, input_dim)
x = torch.rand(batch_size, seq_length, input_dim)
# Initialize the attention module
attention = SimpleAttention(input_dim)
# Forward pass
output, attention_weights = attention(x)
print("Output shape:", output.shape) # Expected: (batch_size, seq_length, input_dim)
print("Attention weights shape:", attention_weights.shape) # Expected: (batch_size, seq_length, seq_length)