参考
Attention Is All You Need
A General Survey on Attention Mechanisms in Deep Learning
注意力足矣(Attention Is All You Need)
一般注意力模型
这个模型接受一个输入,执行指定的任务,然后产生所需的输出
输入 X d x × n x = [ x 1 , … , x n x ] X d x × n x = [ x 1 , … , x n x ] ,每个向量 x ∈ R d x x ∈ R d x 为词、像素、声音序列特征等
特征模型 feature model:输入 X X ,输出特征 F d f × n f = [ f 1 , ⋯ , f n f ] F d f × n f = [ f 1 , ⋯ , f n f ] ,特征向量 f ∈ R d f f ∈ R d f
查询模型 query model:得到查询向量 q ∈ R d q q ∈ R d q ,用于提示注意力模型需要注意哪些特征部分
注意力模型 attention model,由单个或多个注意力模块 (Fig.2) 构成:输入特征向量 F F 和查询向量 q q ,对特征向量用权重矩阵 W K , W V W K , W V 线性变换出键矩阵 K d k × n f = W K F = [ k 1 , ⋯ , k n f ] K d k × n f = W K F = [ k 1 , ⋯ , k n f ] 和值矩阵 V d v × n f = W V F = [ v 1 , ⋯ , v n f ] V d v × n f = W V F = [ v 1 , ⋯ , v n f ] ;W K , W V W K , W V 可以是可训练的或者预先指定的
注意力模块为了得到 V V 中的值向量的加权平均值——特征 F F 中对查询 q q 重要的信息。为每个键向量 k k ,通过某个打分函数计算其和查询向量 q q 的注意力得分 e l = score ( q , k l ) e l = score ( q , k l ) (通常在 [ 0 , 1 ] [ 0 , 1 ] ),组成得分向量 e = [ e 1 , ⋯ , e n f ] e = [ e 1 , ⋯ , e n f ] ;
通过对齐层 (alignment) 归一化,比如 Softmax a l = exp e l ∑ j exp e j a l = exp e l ∑ j exp e j ,得到注意力权重向量 a a
对 V V 加权平均,得到上下文向量 (context vector) c = ∑ l a l v l c = ∑ l a l v l
输出模型 output model:输入上下文向量 c ∈ R d v c ∈ R d v ,训练模型输出预测值 ^ y ∈ R d ^ y y ^ ∈ R d y ^
自注意力 self att
若注意力模型 attention model 完全通过特征 F F 得到
例如,查询向量也由 F F 得到:Q d q × n f = W Q F = [ q 1 , ⋯ , q n f ] Q d q × n f = W Q F = [ q 1 , ⋯ , q n f ] ;当使用 q l q l 查询时,生成 c l c l ,即 C = self-att ( Q , K , V ) C = self-att ( Q , K , V )
多头注意力 multi-head att
如图 Fig.9 有 d d 个并行的注意力模块,思想是使用不同的权重矩阵对查询 q q 进行线性变换得到多个查询,每个查询期望专注于不同类型的信息,从而使得注意模型在上下文向量计算中引入更多信息。
每个 att head 都有自己训练的矩阵:W ( l ) q , W ( l ) K , W ( l ) V W q ( l ) , W K ( l ) , W V ( l ) ,得到查询向量、键矩阵和值矩阵 q ( l ) , K ( l ) , V ( l ) q ( l ) , K ( l ) , V ( l ) ,过一遍注意力模型得到上下文向量 c ( l ) c ( l ) ;将它们连接然后线性变换 W O W O ,得到最终的上下文向量 c = W O concat ( c ( 1 ) , ⋯ , c ( d ) ) c = W O concat ( c ( 1 ) , ⋯ , c ( d ) )
多头自注意力
Transformer 所使用的;后面就不强调自注意力了
Transformer 是一种序列转录模型 (sequence transduction models);序列转录模型输入、输出都为序列,通常由编码器 encoder 和解码器 decoder 组成,并用一种注意力机制 attention mechanism 连接它们
Transformer 模型在提出时仅作为机器翻译用,后来用于图像等其他领域
传统翻译通常是顺着序列跑,根据上一个位置对应的隐藏状态 h t − 1 h t − 1 和位置 t t 的输入,确定 h t h t ——这是一个难以并行的过程,且会面临历史信息存储过大或者丢失的问题
而 Transformer 能更好地支持并行,而且只使用注意力机制(在此之前通常使用 RNN 做)
概述
一般编码器/解码器结构:对于输入序列(比如单词序列)X = [ x 1 , ⋯ , x n ] X = [ x 1 , ⋯ , x n ] ,编码器输出 Z = [ z 1 , ⋯ , z n ] Z = [ z 1 , ⋯ , z n ] (序列长度不变);Z Z 丢进解码器输出 Y = [ y 1 , ⋯ , y m ] Y = [ y 1 , ⋯ , y m ] (序列长度可能改变)
如上图,Transformer 由左侧堆叠的 N N 个编码器、右侧堆叠的 N N 个解码器构成,N = 6 N = 6
对于解码器,它是自回归 auto-regressive 的:依次输出 y j y j ,且 y j y j 需要通过引入 y 1 , ⋯ , y j − 1 y 1 , ⋯ , y j − 1 的信息而得到;因此解码器也以 Outputs 作为输入,其每次右移一位 (shifted right)
位置编码
输入特征向量序列 X = [ x 1 , ⋯ , x n ] , x ∈ R d X = [ x 1 , ⋯ , x n ] , x ∈ R d (这里的 n n 就是之后的 n f n f )
在之后的 att-head 中我们只考虑键和查询的距离,并没有引入序列自带的时序信息;所以我们先为每个向量 x t x t 添加它所在位置 t t 的信息:构造关于位置 t t 的位置编码 positional encodings p t ∈ R d p t ∈ R d ,并直接加给 x t x t :
F (old) d × n f = [ x 1 + p 1 , ⋯ , x n + p n ] F d × n f (old) = [ x 1 + p 1 , ⋯ , x n + p n ]
p t = [ sin ( ω 0 t ) cos ( ω 0 t ) sin ( ω 1 t ) cos ( ω 1 t ) ⋯ sin ( ω d 2 − 1 t ) cos ( ω d 2 − 1 t ) ] T p t = [ sin ( ω 0 t ) cos ( ω 0 t ) sin ( ω 1 t ) cos ( ω 1 t ) ⋯ sin ( ω d 2 − 1 t ) cos ( ω d 2 − 1 t ) ] T
其中 ω k = 1 10000 2 k / d ω k = 1 10000 2 k / d
下图中🔗 ,从上往下每一行依次为一个 p t p t
使用三角函数有一个好处,就是可以用线性变换刻画相对位置:对于 x t x t 使用 ω k ω k 编码的片段 [ sin ( ω k t ) cos ( ω k t ) ] [ sin ( ω k t ) cos ( ω k t ) ] ,和相对位置为 ϕ ϕ 的 x t + ϕ x t + ϕ 的编码片段 [ sin ( ω k ( t + ϕ ) ) cos ( ω k ( t + ϕ ) ) ] [ sin ( ω k ( t + ϕ ) ) cos ( ω k ( t + ϕ ) ) ] ,可以用线性变换得到:
[ sin ( ω k ( t + ϕ ) ) cos ( ω k ( t + ϕ ) ) ] = [ cos ( ω k ϕ ) sin ( ω k ϕ ) − sin ( ω k ϕ ) cos ( ω k ϕ ) ] [ sin ( ω k t ) cos ( ω k t ) ] [ sin ( ω k ( t + ϕ ) ) cos ( ω k ( t + ϕ ) ) ] = [ cos ( ω k ϕ ) sin ( ω k ϕ ) − sin ( ω k ϕ ) cos ( ω k ϕ ) ] [ sin ( ω k t ) cos ( ω k t ) ]
编码器
多头自注意力
如上图所示,Multi-Head Att 亦如图,其为多头注意力,h = 8 h = 8 (相当于 8 8 个通道,学习出不同的距离空间)
对第 l l 个 att-head 先使用这个头自己的线性变换,将 F (old) F (old) 变换出 K ( l ) d k × n f , V ( l ) d v × n f , Q ( l ) d q × n f K d k × n f ( l ) , V d v × n f ( l ) , Q d q × n f ( l )
(给我:K , V , Q K , V , Q 作为长度仍为 n f n f 的序列,保持了一定的 F F 的序列信息,就是说“作为单词序列看还是有意义的”——不知道怎么更好地表述这个感觉;而之后通过 Multi-Head Att 的输出,也是同样长度的序列,其中每一项都是输入的加权和、而权重来自于这一项——类似卷积里的自相关给人的感觉)
以缩放点乘 (Scaled Dot-Product) 计算注意力得分,e ( l ) j e j ( l ) 为第 j ( 1 ≤ j ≤ n f ) j ( 1 ≤ j ≤ n f ) 个查询向量 q j q j 依次与所有键向量 { k } n f i = 1 { k } i = 1 n f 的点积、列出的向量
此外还要 Scale:除以 √ d k d k 防止某些向量因为长度较大导致得分占据优势
E ( l ) = [ e ( l ) 1 , e ( l ) 2 , ⋯ , e ( l ) n f ] = K ( l ) T Q ( l ) √ d k e ( l ) j = 1 √ d k ⎡ ⎢
⎢
⎢
⎢
⎢
⎢
⎢
⎢ ⎣ k ( l ) 1 T k ( l ) 2 T ⋮ k ( l ) n f T ⎤ ⎥
⎥
⎥
⎥
⎥
⎥
⎥
⎥ ⎦ q ( l ) j E ( l ) = [ e 1 ( l ) , e 2 ( l ) , ⋯ , e n f ( l ) ] = K ( l ) T Q ( l ) d k e j ( l ) = 1 d k [ k 1 ( l ) T k 2 ( l ) T ⋮ k n f ( l ) T ] q j ( l )
以 Softmax 对齐得 A ( l ) A ( l ) ,给 V ( l ) V ( l ) 加权得 C ( l ) C ( l )
第 j j 个上下文向量 c j c j ,为在查询向量 q j q j 的引导下、计算出的得分 e j e j 归一化后的 a j a j 作为权重、将值向量 { v } n f i = 1 { v } i = 1 n f 加权和后的结果
A ( l ) = [ a ( l ) 1 , ⋯ , a ( l ) n f ] = [ softmax ( e ( l ) 1 ) , ⋯ , softmax ( e ( l ) n f ) ] C ( l ) = [ c ( l ) 1 , ⋯ , c ( l ) n f ] = V ( l ) A ( l ) , c ( l ) j = ∑ i v ( l ) i a ( l ) j i A ( l ) = [ a 1 ( l ) , ⋯ , a n f ( l ) ] = [ softmax ( e 1 ( l ) ) , ⋯ , softmax ( e n f ( l ) ) ] C ( l ) = [ c 1 ( l ) , ⋯ , c n f ( l ) ] = V ( l ) A ( l ) , c j ( l ) = ∑ i v i ( l ) a j i ( l )
创建一个上下文向量作为注意力模型的输出:对于第 j j 个查询,将所有 att-head 的输出直接连接起来:concat ( C ( 1 ) , ⋯ , C ( h ) ) ∈ R ( d v h ) × n f concat ( C ( 1 ) , ⋯ , C ( h ) ) ∈ R ( d v h ) × n f ,然后用权重矩阵 W O ∈ R d c × ( d v h ) W O ∈ R d c × ( d v h ) 线性变换,得到最终 Multi-Head Attention 的输出 C C
C d c × n f = W O [ concat ( c ( 1 ) 1 , ⋯ , c ( h ) 1 ) , concat ( c ( 1 ) 2 , ⋯ , c ( h ) 2 ) , ⋯ , concat ( c ( 1 ) n f , ⋯ , c ( h ) n f ) ] = W O ⎡ ⎢
⎢
⎢
⎢
⎢
⎢
⎢ ⎣ ⎡ ⎢
⎢
⎢
⎢
⎢
⎢
⎢ ⎣ c ( 1 ) 1 c ( 2 ) 1 ⋮ c ( h ) 1 ⎤ ⎥
⎥
⎥
⎥
⎥
⎥
⎥ ⎦ ⎡ ⎢
⎢
⎢
⎢
⎢
⎢
⎢ ⎣ c ( 1 ) 2 c ( 2 ) 2 ⋮ c ( h ) 2 ⎤ ⎥
⎥
⎥
⎥
⎥
⎥
⎥ ⎦ ⋯ ⎡ ⎢
⎢
⎢
⎢
⎢
⎢
⎢ ⎣ c ( 1 ) n f c ( 2 ) n f ⋮ c ( h ) n f ⎤ ⎥
⎥
⎥
⎥
⎥
⎥
⎥ ⎦ ⎤ ⎥
⎥
⎥
⎥
⎥
⎥
⎥ ⎦ C d c × n f = W O [ concat ( c 1 ( 1 ) , ⋯ , c 1 ( h ) ) , concat ( c 2 ( 1 ) , ⋯ , c 2 ( h ) ) , ⋯ , concat ( c n f ( 1 ) , ⋯ , c n f ( h ) ) ] = W O [ [ c 1 ( 1 ) c 1 ( 2 ) ⋮ c 1 ( h ) ] [ c 2 ( 1 ) c 2 ( 2 ) ⋮ c 2 ( h ) ] ⋯ [ c n f ( 1 ) c n f ( 2 ) ⋮ c n f ( h ) ] ]
对于上述的维度可以令 d k = d v = d q = d c / h d k = d v = d q = d c / h
最后过一遍残差连接和层归一化 Add & Norm :
F (new) = LayerNorm ( F (old) + C ) F (new) = LayerNorm ( F (old) + C )
LayerNorm 是一个与 BatchNorm 有所异同的东西
BatchNorm 批标准化,输入 B N B N 层一批数据 batch × x batch × x :
实际上是 x x 的每一维 feature 对整个 batch 标准化
而对于 文章集合-单词序列-词特征向量 的情况:batch × seq × x batch × seq × x ,就要 x x 的每一维 feature 对整个 batch×seq 标准化
LayerNorm 对于 batch × seq × x batch × seq × x ,它是 batch 的每个样本在自己的 seq×feature 上标准化
这么选择的一个考量是,一个 batch 内,不同样本的 seq 不同,使用同一个均值/方差会估计不准确
前馈网络
输入之前的 F (new) = [ f 1 , ⋯ , f n f ] F (new) = [ f 1 , ⋯ , f n f ] ,通过前馈网络(简单 MLP + Add&Norm),输出 Z = [ z 1 , ⋯ , z n f ] Z = [ z 1 , ⋯ , z n f ]
特别注意,我们将序列的每个向量分别丢进同一个 MLP,而不是将整个序列矩阵丢进 MLP
Z = LayerNorm ( F (new) + FFN ( F (new) ) ) FFN ( F (new) ) = [ FFN ( f (new) 1 ) , FFN ( f (new) 2 ) , ⋯ , FFN ( f (new) n f ) ] FFN ( f ) = W 1 ReLU ( W 0 f + b 0 ) + b 1 Z = LayerNorm ( F (new) + FFN ( F (new) ) ) FFN ( F (new) ) = [ FFN ( f 1 (new) ) , FFN ( f 2 (new) ) , ⋯ , FFN ( f n f (new) ) ] FFN ( f ) = W 1 ReLU ( W 0 f + b 0 ) + b 1
解码器
带掩码的多头自注意力
注意力层计算加权和为 c j = ∑ i v i a j i c j = ∑ i v i a j i ,但是解码器在预测过程是自回归的,对于某个时刻 j j ,你输入给解码器的特征向量序列是之前的输出,即要求 c j c j 只与 v 1 , ⋯ , v j v 1 , ⋯ , v j 有关;所以训练 时不能把当前位置之后的信息也放进去了,这可以在打分函数的 Scale 之后、Softmax 之前增加一个掩码 Mask 实现
E ′ = Mask ( E ) = ⎡ ⎢
⎢
⎢
⎢
⎢
⎢ ⎣ e 1 , 1 e 1 , 2 ⋯ e 1 , n f − ∞ e 2 , 2 ⋯ e 2 , n f ⋮ ⋮ ⋱ ⋮ − ∞ − ∞ ⋯ e n f , n f ⎤ ⎥
⎥
⎥
⎥
⎥
⎥ ⎦ E ′ = Mask ( E ) = [ e 1 , 1 e 1 , 2 ⋯ e 1 , n f − ∞ e 2 , 2 ⋯ e 2 , n f ⋮ ⋮ ⋱ ⋮ − ∞ − ∞ ⋯ e n f , n f ]
这样在之后过 Softmax 时,取 ln ln 会使得 − ∞ − ∞ 变为 0 0 ,也就是对每个 e e 的前缀的归一化、而后缀都变为 0 0
剩下同理:加权、连接、过一遍残差连接和层归一化,得到 F (dec) F (dec)
以编码器输出作为输入的多头自注意力
如图,由编码器的输出 Z (enc) Z (enc) 得到键和值 K , V K , V 、由解码器第一部分的输出 F (dec) F (dec) 得到查询 Q Q :
K ( l ) = W ( l ) K Z (enc) V ( l ) = W ( l ) V Z (enc) Q ( l ) = W ( l ) Q F (dec) K ( l ) = W K ( l ) Z (enc) V ( l ) = W V ( l ) Z (enc) Q ( l ) = W Q ( l ) F (dec)
也就是说,你从 V V 中,根据来自解码器给的查询序列 Q Q 的每个位置 j j 的查询 q j q j ,用 K , q j K , q j 计算分数、从 V V 中提取出 q j q j 感兴趣的信息,作为输出序列中这个位置的值
前馈网络
与上文同理
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律