Datawhale 组队学习 fun-transformer💡Task 3 Encoder
Datawhale 组队学习 fun-transformer
Datawhale项目链接:https://www.datawhale.cn/learn/summary/87
笔记作者:博客园-岁月月宝贝💫
微信名:有你在就不需要🔮给的勇气
Task 03 Encoder
本节主要从Encoder通向Attention!
一、什么是编码器?
🐣先回顾下“输入嵌入”和“位置编码”:
输入嵌入:将输入的原始数据(比如文本中的单词等)转化为向量表示。
位置编码:让模型能够捕捉到输入序列中元素的位置信息,因为标准的向量表示本身没有位置概念。
编码器工作流程
💯俗话说“实践是检验真理的唯一标准”,我们先从编码器的工作流程了解下它是否可以起到编码的作用(⊙o⊙)?
1.输入的信息传递
Input Embedding ────┐
│
Position Embedding ─┘
↓
Combined Input ────> Encoder 1 ────> Encoder 2 ────> ... ────> Encoder 6 ────> Output
整个 Encoder 部分由 6 个相同的子模块按顺序连接构成。第一个 Encoder 子模块接收来自嵌入(Input Embedding)和位置编码(Position Embedding)组合后的输入(inputs),其他 Encoder 子模块从前一个 Encoder 接收相应的输入(inputs),这样就形成了一个顺序传递信息的链路。
2.子模块核心处理
💐多头自注意力层(Multi-Head Self-Attention layer)
每个 Encoder 子模块在接收到输入后,首先会将其传递到多头自注意力层,每个输入元素会被映射为三个向量:Query(Q)、Key(K)和Value(V),接着计算输入序列不同位置之间的关联关系,生成相应的自注意力输出。
🌟这里讲的是多头自注意力机制(Multi-Head Self-Attention layer),是对Self-Attention(自注意力机制)的拓展,多头自注意力机制(Multi-Head Self-Attention layer)是由缩放点积注意力(Scaled Dot-product Attention) 和 多头注意力机制(Multi-Head Attention) 组成的,整个模块的输出会经过残差连接和层归一化(Add & Norm)。
🌺前馈层(Feedforward layer)
自注意力层的输出被传递到前馈层(Feedforward layer,也叫Position-wise Feed-Forward Networks,也叫FFN)。
前馈层一般是由全连接网络等构成,对自注意力层输出的特征做进一步的非线性变换(By激活函数),以提取更复杂、高层次的特征,然后将其输出向上发送到下一个编码器(如果不是最后一个 Encoder 的话),以便后续 Encoder 子模块继续进行处理。
数学结构解释:
FFN层是一个顺序结构,包括一个全连接层(FC) + ReLU 激活层 + 第二个全连接层(FC),通过在两个 FC 中间添加非线性变换,增加模型的表达能力,使模型能够捕捉到复杂的特征和模式。
\[FFN(x)=max(0,xW_1+b_1)W_2+b_2 \]
- \(xW_1+b_1\) 为第一个全连接层的计算公式
- \(max(0,xW_1+b_1)W_2\) 为 relu 的计算公式
- \(max(0,xW_1+b_1)W_2+b_2\)为第二个全连接层的计算公式
全连接层线性变换的主要作用为数据的升维和降维。\(W_1\) 的维度是(2048,512),\(W_2\) 是 (512,2048)。 即先升维,后降维,这是为了扩充中间层的表示能力,从而抵抗 ReLU 带来的模型表达能力的下降。
🔗它们都包含的残差连接(Residual connection)
数学上,残差连接可以表示为:\(\mathrm{Residual}=x+\mathrm{SubLayer}(x)\)
其中 \(x\) 是子层的输入;\({SubLayer}(x)\)是子层的输出。
每个 Encoder 子模块内部的自注意力层和前馈层均配备了残差快捷链路,这种连接方式使得输入信号能够以残差的形式参与到每一层的输出计算中。
🛑在Transformer架构中,残差连接是每个编码器单元内部的自注意力层和前馈层分别独立应用的,而不是整个编码器的自注意力层或前馈层共享一个残差连接。
- 对于自注意力层,其输入会与自注意力层的输出进行相加操作(假设自注意力层输入为 x,输出为 y,经过残差连接后变为 x + y )。
- 前馈层的输入也会和前馈层的输出进行相加。
残差连接(Residual Connection)有助于缓解深度网络训练过程中的梯度消失或梯度爆炸问题,使得信息更顺畅地在网络中传递,网络能够更容易地训练深层模型。
在残差连接之后,紧跟着会进行层归一化操作。
⏬它们都包含的层归一化(Normalisation)
数学公式:\(\mathrm{Output}=\text{Layer}\mathrm{Norm}(\text{Residual})\)
层归一化是对每一层的神经元的输出进行归一化处理。经过层归一化后的结果就是当前 Encoder 子模块最终的输出,然后传递给下一个 Encoder 子模块或者后续的其他模块(比如在 Encoder-Decoder 架构中传递给 Decoder 部分等情况)。
层归一化可以加速网络的收敛速度、提高模型的泛化能力等,使得模型训练更加稳定、高效。
💡层归一化(Norm)特征 【独立归一化】与批量归一化(Batch Normalization)不同,层归一化不依赖于批次中的其他样本。这意味着即使在处理小批量数据或者在线学习场景时,层归一化也能保持稳定和有效。 【稳定训练过程】层归一化通过将每个特征的均值变为 0,标准差变为 1,有助于减少内部协变量偏移(Internal Covariate Shift)。这种偏移是指神经网络在训练过程中,由于参数更新导致的每层输入分布的变化。通过归一化,可以使得每一层的输入分布更加稳定,从而加速训练过程。 【提高模型稳定性】由于层归一化减少了特征之间的尺度差异,这有助于避免某些特征在学习过程中占据主导地位,从而提高了模型的泛化能力和稳定性。 【适应不同类型的网络】层归一化特别适用于循环神经网络(RNN)和Transformer模型,因为这些网络结构在处理序列数据时,每个时间步或位置的状态是相互依赖的,而批量归一化在这些情况下可能不太适用。 【减少梯度消失和爆炸】通过归一化处理,可以减少梯度在传播过程中的消失或爆炸问题,尤其是在深层网络中。这有助于更有效地进行反向传播,从而提高训练效率。 【不受批量大小限制】层归一化不依赖于批次大小,因此在处理不同大小的批次时,不需要调整超参数,这使得层归一化更加灵活。 |
二、里面最难的多头自注意力(Multi-Head Self-Attention)
Transformer模型中,多头注意力机制允许模型在不同的子空间中学习到不同的关系😲
多头自注意力(Multi-Head Self-Attention)中的每个头都有自己的Q(查询)、K(键)和V(值),最后所有头的输出会经一个线性层拼接起来。
自注意力机制中的Q、K和V:
Query ≈ 搜索查询 —— 每个序列元素都有一个对应的查询,代表当前元素想要关注的信息,即“我应该关注什么?”它用于与其他元素的Key进行匹配,以确定哪些信息与当前元素最相关。
Key ≈ 数据库索引 —— 每个序列元素都有一个对应的键 ,代表其他元素可以被关注的信息,即“我这里有什么可以被关注?”,用于与Query向量进行匹配。
Value ≈ 实际的数据库条目 —— 当查询匹配到一个特定的键时,其对应的值就会被选中并返回。代表每个元素的实际内容,即“关注之后能得到的信息”。
计算过程概述:
在自注意力机制中,模型会根据Query和Key的匹配结果(即注意力分数)对Value进行加权求和,从而生成当前元素的上下文表示。每个元素的上下文表示组合起来,就是最终的输出序列。
前文有提到“多头自注意力机制(Multi-Head Self-Attention layer)是由缩放点积注意力(Scaled Dot-product Attention) 和 多头注意力机制(Multi-Head Attention) 组成的,整个模块的输出会经过残差连接和层归一化(Add & Norm)”,我们先来了解下缩放点积注意力(Scaled Dot-product Attention) 和 多头注意力机制(Multi-Head Attention) 👍
缩放点积注意力
self-attention 的输入是序列词向量,此处记为\(x\)。
1️⃣生成Query、Key和Value
输入序列 x 经过三个独立的线性变换(即三个不同的权重矩阵),分别生成Query(Q)、Key(K)和Value(V):
具体:
这些变换是独立的,意味着Q、K和V的生成互不影响。
eg.分别生成Query(Q)、Key(K)和Value(V):
2️⃣计算注意力分数
对于每个Query向量\(Q_i\),模型会计算其与所有Key向量 \(K_j\) 的点积,得到一个分数矩阵。这个分数矩阵表示每个Query与所有Key之间的相似度或相关性:
其中,\(K^T\) 是Key矩阵的转置。
eg.计算点积分数:
WHY使用点乘?
特性 | 点乘注意力 | 加法注意力 |
---|---|---|
计算效率 | 高,可以通过矩阵乘法并行优化 | 低,计算更为复杂 |
计算复杂度 | 理论上为 O(d),但实际硬件中通过并行化显著提升速度 | 理论上为 O(d),但非线性操作导致计算复杂 |
相似性衡量 | 有效,尤其在高维度向量时 | 非线性操作引入,效果上并无显著提升 |
数值稳定性 | 通过缩放避免数值不稳定问题 | 可能存在数值不稳定问题 |
硬件适应性 | 适合大规模模型训练和推理 | 适应性较差 |
实际效果 | 在高维度向量时表现良好 | 效果上没有显著提升,计算复杂 |
3️⃣缩放分数
首先,你要知道在多头注意力机制中,模型的总维度\(d_{\text{model}}\) 被平均分配到每个头。如果有 \(h\) 个头,那么每个头的维度 \(d_{k}\)(即每个头的Key向量的维度)为:
为了防止点积结果过大导致Softmax函数的梯度消失问题,分数会除以一个缩放因子\(\sqrt{d_k}\),其中 \(d_k\) 是Key向量的维度。缩放后的分数为:
💡如此设置缩放因子的优势 参数平衡:确保每个头的参数数量相同,总参数量与单头注意力模型一致,提升模型扩展性和可管理性。 计算效率:每个头的维度dk作为模型总维度的因子,有助于在矩阵运算时利用硬件加速,提高效率。 多样性:多个头在不同子空间操作,捕获输入间更丰富的关系。 可解释性和调试:合适的dk值使每个头更易于解释和调试,特定场景下可手动设置dk 值。 |
4️⃣计算注意力权重
使用Softmax函数对缩放后的分数进行归一化,得到每个Query与Key之间的注意力权重。这些权重表示每个Query应该从哪些Key中获取信息,权重值介于0到1之间,且每个Query的权重之和为1:
5️⃣加权求和Value
最后,使用注意力权重对Value向量进行加权求和,得到最终的注意力输出:
这样,每个Query都会根据注意力权重从Value中提取相关信息,最终生成包含上下文信息的输出。
另:输出矩阵维度为\(n·d_{v}\),\(n\) 表示输入序列的长度,\(d_{v}\)是Value向量的维度。
😎缩放点积注意力的掩码机制:
mask机制介绍:在注意力机制中,掩码是一种技术,用于防止模型在计算注意力时看到未来的信息。例如,在机器翻译任务中,当模型预测一个句子中的某个词时,它不应该看到该词之后的句子部分。
使用方法:训练时关闭,测试或推理时打开。在训练阶段,模型需要学习整个序列的依赖关系,因此掩码通常不使用。然而,在测试或推理阶段,模型需要逐个预测序列中的词,这时掩码就会被应用,以确保每个预测只依赖于已经生成的词,而不能看到未来的词。
😋输出矩阵的另一种表达方式:
\(head=Attention(QW_i^Q,KW_i^K,VW_i^V)\)(注意:本节默认只有一个头)
其中\(W_i^Q\text{、}W_i^K\text{、}W_i^V\) 的权重矩阵的维度分别为\((d_{k}\times\tilde{d}_{k})\text{、}(d_{k}\times\tilde{d}_{k})\text{、}(d_{v}\times\tilde{d}_{v})\)
多头注意力机制
把Attention过程独立地(参数不共享)重复做 head 次,再将结果拼接起来通过一个线性层进行变换\(MultiHead(Q,K,V)=Concat(head_1,\ldots,head_h)\),生成最终的输出——一个 \(n\times(h\tilde{d}_v)\text{}\) 维度的序列。
∵Attention过程独立(参数不共享) ∴不同的 head 的矩阵是不同的
∵Attention过程独立(参数不共享) ∴multi-head-Attention可以并行计算(Google论文里\(h=8\text{, }d_k=d_v=d_{model}/4=64\))
∵多头 attention可以形成多个子空间,让模型关注不同方面的信息(差异性) ∴一般多头 attention 的效果要优于单个 attention
∵头之间的差距随着所在层数变大而减小(可能意味着模型在深层并没有充分利用多头机制的潜力,或者模型可能在某种程度上变得过于复杂)
∴在评估模型时,需要考虑头之间的差异性是否达到了预期的效果,以及这种差异性是否真的对模型性能有正面影响
∵头数 h 的设置增大到某一个数就没效果了 ∴头数 h 不是越大越好
特性 | 多头注意力机制 | 传统(单头)注意力机制 |
---|---|---|
并行化 | 高,可以同时处理多个注意力头 | 低,通常一次只能处理一个注意力头 |
计算效率 | 高,显著加快计算速度 | 低,计算速度较慢 |
表示能力 | 强,每个头关注不同方面,表示更丰富 | 弱,只有一个注意力头,表示能力有限 |
模式捕捉 | 优秀,能捕捉各种模式和关系 | 有限,只能捕捉单一模式或关系 |
关系理解 | 强,增强了理解序列内复杂关系的能力 | 弱,理解复杂关系的能力有限 |
文本生成 | 优秀,增强了生成文本的能力 | 一般,生成文本的能力有限 |
泛化性 | 改进,关注局部和全局依赖,提高泛化性 | 一般,泛化性有限 |
任务适应性 | 高,提高了跨任务和领域的适应性 | 低,适应性有限 |
前面有讲“这里讲的是多头自注意力机制(Multi-Head Self-Attention layer),是对Self-Attention(自注意力机制)的拓展”,其实“在序列内部进行 Attention 操作,旨在寻找序列内部的联系”,叫 “自注意力”,亦称为内部注意力的自注意力机制的这节前面已经介绍完毕啦!
但是,我们还可以了解下“自注意力机制”的类别吧!
自注意力机制
以往关于 Seq2Seq 的研究大多仅将注意力机制应用于解码端,而 Google 的创新之处在于使用 Self Multi-Head Attention 进行序列编码。
Google版的Self-Attention: 🔅对同一个输入序列\(X\),分别进行三种独立的线性变换得到 \(Q_x、K_x、V_x\) 后,将其输入 Attention,体现在公式上即 \(Attention (Q_x, K_x, V_x)\)。
🐖注意:本文只把“Google版的Self-Attention”视为“Self-Attention”,其他的注意力方法(含以下三种)均视为普通“Attention”机制!
在Transformer模型中,自注意力(Self-Attention)机制是其核心组件之一,它允许模型在序列的每个位置计算注意力,从而捕捉序列内部的依赖关系。三种注意力机制——Encoder Self-Attention、Masked Decoder Self-Attention和Encoder-Decoder Attention——都是自注意力机制的不同应用形式,它们在Transformer模型中扮演着不同的角色:
- Encoder Self-Attention:
- 这是自注意力机制在编码器(Encoder)阶段的应用。
- 在这个阶段,每个输入词(或称为token)都会计算与其他所有输入词的关联。
- 这种机制使得模型能够捕捉输入序列中不同位置之间的依赖关系,无论这些位置之间的距离有多远。
- 它是自注意力机制的直接应用,因为每个词都在关注序列中的其他词。
- Masked Decoder Self-Attention:
- 这是自注意力机制在解码器(Decoder)阶段的应用,但加入了掩码(Masking)技术。
- 在解码器中,模型在生成序列时应该只依赖于已经生成的词,而不能“看到”未来的词。
- 为了实现这一点,Masked Decoder Self-Attention使用一个掩码来确保在计算当前词的注意力时,只考虑之前的词,从而形成一个三角矩阵。
- 这种掩码机制是自注意力机制的一个变体,它通过限制注意力的范围来保证序列生成的自洽性。
- Encoder-Decoder Attention:
- 这是自注意力机制在编码器和解码器之间的应用。
- 在这个阶段,解码器中的每个词都会计算与编码器输出的关联,即整个输入序列。
- 这种机制允许解码器在生成每个词时,都能关注到输入序列中的所有相关信息。
- 它是自注意力机制的一个扩展,因为它涉及到不同阶段(编码器和解码器)之间的交互(亮点也仅仅在于加了交互)。
总结来说,这三种注意力机制都是自注意力机制的特定应用,它们在Transformer模型中协同工作,以实现序列到序列的转换任务。自注意力机制提供了一种灵活的方式来捕捉序列内部的依赖关系,而通过不同的应用和扩展(如掩码和跨阶段注意力),Transformer模型能够处理复杂的序列生成和转换任务。
💚Self-Attention核心概念理解以这节以上“缩放点积注意力”为主!
Self-Attention优点
特性 | Self-Attention | RNN | CNN |
---|---|---|---|
参数数量 | \(O(n^2d)\) | \(O(nd^2)\) | \(O(knd^2)\) |
参数效率 | 高,当 n 远小于 d 时参数更少 | 低 | 中等 |
并行化能力 | 高,可以并行处理 | 低,需要逐步递推 | 中等,可以通过层叠并行 |
全局信息捕捉 | 优秀,挑重点,一步获取全局信息 | 较差,难以捕捉长时依赖 | 中等,需要增加层数来扩大感受野 |
解决长时依赖 | 优秀,一步矩阵计算即可 | 较差,需要逐步计算 | 中等,需要增加卷积层数 |
计算路径长度 | 一步矩阵计算 | 从1到n逐个计算 | 需要增加卷积层数来扩大视野 |
*计算路径长度:计算一个序列长度为n的信息要经过的路径长度。
表格总结了Self-Attention、RNN和CNN在参数数量、参数效率、并行化能力、全局信息捕捉、解决长时依赖和计算路径长度等方面的特点。Self-Attention在这些方面通常表现更好,尤其是在处理长序列和捕捉全局信息时。但是,对于非常长的序列,比如序列长度 N 大于序列维度 D 这种情况,可能需要使用窗口限制Self-Attention的计算量,以控制计算成本。
Self-Attention缺点
缺点 | 描述 | 解决方案或备注 |
---|---|---|
计算量大 | Self-Attention中包含三次线性映射\(O(n)\)和两次序列自身的矩阵乘法\(O(n^2d)\),对于长序列更难接受。 | 三次线性映射计算量相当于卷积核大小为3的一维卷积,但另外的矩阵乘法计算量更大。 |
无法捕捉位置信息 | 无法学习序列中的顺序关系。 | 可以通过加入位置信息(如位置向量)来改善,参考BERT模型。 |
实践上的局限性 | 某些任务RNN能轻松应对,而Transformer表现不佳,如复制字符串任务或处理超出训练时最大长度的序列(由于遇到了未曾见过的位置嵌入)。 | 论文《Universal Transformers》指出,Transformer在这些任务上不如RNN。 |
理论上的局限性 | Transformer不具备计算上的通用性(非图灵完备),在处理计算密集型问题时存在局限性。 | 与RNN不同,Transformer无法独立完成某些复杂的计算任务。 |
表格总结了Self-Attention机制的主要缺点,包括计算量大、无法捕捉位置信息、实践上的局限性和理论上的局限性,并提供了可能的解决方案或备注。
注意力机制与自注意力机制的区别
(🌟“加多头”涉及的操作都一样,都是多个并行的注意力头(Attention Head),每个头都有自己的线性变换矩阵用于计算查询(Query)、键(Key)和值(Value);除了下面这一点不一样,其他操作内容和流程如缩放点积注意力、“对输入进行线性变换以得到 Q、K、V,然后通过注意力机制计算每个头的输出,最后将各个头的输出进行拼接和线性变换得到最终的输出” 均相同!)
- 本质区别:查询、键、值的来源不同
🍎Attention:查询(Query)、键(Key)和值(Value)可以来自不同的输入源。
eg. Decoder部分: 查询(Query)来自解码器当前的输入,而键(Key)和值(Value)通常来自编码器的输出。这种机制使得模型能够将解码器当前的信息与编码器已经处理好的信息进行关联,从而更好地生成输出序列。
🍏Self - Attention:查询(Query)、键(Key)和值(Value)都来自同一个输入序列。
这意味着模型关注的是输入序列自身不同位置之间的关系。它可以让模型自己发现句子中不同单词之间的相互关联,比如在句子 “The dog chased the cat” 中,单词 “dog” 与 “chased”、“chased” 与 “cat” 之间的关系可以通过 Multi - Head Self - Attention 来挖掘。
- 本质区别导致1:功能重点有所差异
🍎Attention:主要用于融合不同来源的信息。
eg.在机器翻译的任务中,它用于将元文本经过 Encoder编码后的信息(作为 K 和 V )与解码器当前生成的部分目标语言句子(作为 Q )相结合,帮助解码器在生成目标语言句子时更好地参考源语言句子的语义和结构,从而生成更准确的翻译。
🍏Self - Attention:更侧重于挖掘输入序列自身的内在结构和关系。
eg.在文本生成任务中,它可以帮助模型理解当前正在生成的文本自身的语义连贯和语法结构。例如,在续写一个故事时,通过 Multi - Head Self - Attention 可以让模型把握已经生成的部分文本的主题、情节发展等内部关系,以便更好地续写。
- 本质区别导致2:输出信息性质不同
🍎Attention:输出的结果往往包含了两个或多个不同输入序列之间相互作用后的特征。
eg.在跨模态任务(如将文本和图像信息相结合)中,输出会包含文本和图像相互关联后的综合特征,用于后续的分类或生成等任务。
🍏Self - Attention:输出的是输入序列自身内部关系的一种特征表示。
eg.在对一个文本序列进行词性标注任务时,输出的特征能够反映出句子内部单词之间的语法和语义关联,用于确定每个单词的词性;也可用于图像处理中的自相关特征提取等。
三、用于区分的Cross Attention
Cross Attention的Q、K、V介绍和计算过程,同Self-Attention,可以跳到“二”的开头看。
Cross attention概念
定义:一种注意力机制,它允许一个序列(称为“查询”序列,提供Query向量)中的元素关注另一个序列(称为“键-值”序列,提供Key和Value向量)中的元素,从而在两个序列之间建立联系。
序列维度要求:为了进行交叉注意力计算,两个序列必须具有相同的维度(因为注意力机制的计算涉及到Q、K和V向量的点积操作,但是下文说明未必成立😥)。
两个序列可以是不同模态的数据,例如:
- 文本序列:一系列单词或子词的嵌入表示。
- 声音序列:音频信号的时序特征表示。
- 图像序列:图像的像素或特征图的嵌入表示。
作用:
- 跨模态学习:Cross Attention在多模态学习中尤为重要,因为它能够将不同模态(如文本、图像、声音)的信息进行有效融合。
- 交互式解码:在序列到序列的任务中,Cross Attention使得解码器能够利用编码器提供的上下文信息[eg.编码器生成的隐藏状态(作为键和值)],从而更准确地生成目标序列。
- 灵活性:Cross Attention提供了灵活性,因为它允许模型动态地关注另一个序列中与当前任务最相关的部分,适合那些那些需要建立跨序列依赖关系的任务。
Self Attention 和 Cross Attention 的对比
类别 | Self Attention | Cross Attention |
---|---|---|
输入来源 | 在Self Attention中,Query、Key和Value均来源于同一个序列。这意味着模型是在内部进行信息的自我比较和关联。 | 在Cross Attention中,Query来自于一个序列,而Key和Value则来自于另一个不同的序列。这种配置允许模型在不同的数据源之间建立联系。 |
信息交互对象 | Self Attention使得序列中的每个元素都能够关注序列中的所有其他元素,并基于这种关注来更新自己的表示。 | Cross Attention则允许来自一个序列的元素(通过Query)关注另一个序列中的所有元素(通过Key和Value),从而实现跨序列的信息融合。 |
应用场景 | Self Attention广泛应用于需要理解序列内部复杂依赖关系的场景,例如在自然语言处理中,用于捕捉句子中单词之间的相互作用。 | Cross Attention适用于那些需要在不同序列之间建立联系的场合,如机器翻译中的编码器和解码器之间的交互,或者在多模态学习中,将文本信息与图像特征对齐。 |
特征捕捉 | Self Attention能够捕捉并编码序列内部的全局依赖关系,使得每个位置的表示都融入了序列中其他位置的信息。 | Cross Attention则专注于捕捉并编码不同序列之间的全局依赖关系,使得一个序列的表示能够反映另一个序列中的相关信息。 |
*Cross Attention感觉就是前面的对比时的Attention!Attention应该包含Cross Attention~
代码实现
Cross Attention:一个序列(称为“查询”序列,提供Query向量)中的元素关注另一个序列(称为“键-值”序列,提供Key和Value向量)。
import torch
from torch import nn
import torch.nn.functional as F
import math
# 定义CrossAttention类,继承自nn.Module
class CrossAttention(nn.Module):
#初始化
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads#头数
head_dim = dim // num_heads#维数
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
# 定义Q, K, V的线性变换层
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
# 定义偏置参数
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
# 定义注意力权重的Dropout层和投影层
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
#Dropout通过在训练过程中随机“丢弃”(即设置为零)网络中的一部分神经元的输出,来减少神经元之间复杂的共适应性。
#投影层通常指的是一个全连接层(Fully Connected Layer),它用于将输入数据从一个维度映射到另一个维度。这里定义了一个投影层,它将自注意力的输出从all_head_dim维度映射回模型的原始维度dim。
#前向传播
def forward(self, x, bool_masked_pos=None, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
v_bias = self.v_bias
# 计算Q, K, V
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
#Q、K、V线性变换,并通过重塑和置换将其转换为多头注意力的格式
#计算缩放点积注意力,得到注意力权重
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
#对注意力权重进行softmax归一化和Dropout处理。
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)#使用注意力权重对值(Value)进行加权求和。
x = self.proj(x)#将加权求和的结果通过投影层,恢复到原始维度。
x = self.proj_drop(x)#对投影层的输出进行Dropout处理。
return x
# 设置相关的维度参数和输入张量示例
batch_size = 2 # 批次大小
dim = 64 # 特征维度
num_heads = 4 # 头的数量
seq_len_query = 10 # 查询序列长度
seq_len_key_value = 8 # 键值对序列长度
# 随机生成输入张量,模拟查询、键、值
query = torch.rand(batch_size, seq_len_query, dim)
key = torch.rand(batch_size, seq_len_key_value, dim)
value = torch.rand(batch_size, seq_len_key_value, dim)
# 实例化CrossAttention模块
cross_attention_module = CrossAttention(dim=dim, num_heads=num_heads)
# 进行前向传播计算
output = cross_attention_module(query, k=key, v=value)
print("输出结果的形状:", output.shape)
我感觉与Self attention最大的差别就在于“# 随机生成输入张量,模拟查询、键、值
query = torch.rand(batch_size, seq_len_query, dim)
key = torch.rand(batch_size, seq_len_key_value, dim)
value = torch.rand(batch_size, seq_len_key_value, dim)”的Q、K、V不是一个序列了。
上述代码输出:
输出结果的形状: torch.Size([2, 10, 64])
我们来逐步分析代码的逻辑,以解释为什么输出结果的形状是 torch.Size([2, 10, 64])
。
1. 输入张量的形状
- 查询(query):
torch.Size([2, 10, 64])
,表示有 2 个样本,每个样本有 10 个查询向量,每个向量的维度是 64。 - 键(key):
torch.Size([2, 8, 64])
,表示有 2 个样本,每个样本有 8 个键向量,每个向量的维度是 64。 - 值(value):
torch.Size([2, 8, 64])
,表示有 2 个样本,每个样本有 8 个值向量,每个向量的维度是 64。
2. CrossAttention 模块的逻辑
(1) 查询(Q)、键(K)、值(V)的线性变换
在 forward
方法中,首先对输入的 query
、key
和 value
进行线性变换:
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
这些线性变换的目的是将输入的特征向量映射到一个新的空间,以便计算注意力权重。
(2) 多头注意力的维度变换
接下来,代码将 q
、k
和 v
的形状重新排列,以支持多头注意力机制:
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
经过这些操作后:
q
的形状变为(B, num_heads, seq_len_query, head_dim)
,即(2, 4, 10, 16)
。k
的形状变为(B, num_heads, seq_len_key_value, head_dim)
,即(2, 4, 8, 16)
。v
的形状变为(B, num_heads, seq_len_key_value, head_dim)
,即(2, 4, 8, 16)
。
(3) 计算注意力权重
注意力权重通过 q
和 k
的点积计算:
attn = (q @ k.transpose(-2, -1)) # (B, num_heads, seq_len_query, seq_len_key_value)
q
的形状是(2, 4, 10, 16)
。k.transpose(-2, -1)
的形状是(2, 4, 16, 8)
(transpose
的主要作用是交换张量的两个维度)。- 点积后的
attn
的形状是(2, 4, 10, 8)
,表示每个头的注意力权重矩阵。
(4) 应用 Softmax 和 Dropout
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
- Softmax 沿着最后一个维度(
seq_len_key_value
,即 8)归一化注意力权重。 - Dropout 用于正则化,防止过拟合。
(5) 加权求和
最后,通过注意力权重对 v
进行加权求和:
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
attn @ v
的形状是(2, 4, 10, 16)
。transpose(1, 2)
将形状变为(2, 10, 4, 16)
。reshape(B, N, -1)
将形状变为(2, 10, 64)
——reshape(B, N, -1)
将最后两个维度合并,最终形状为(B, seq_len, dim)
,其中dim = num_heads * head_dim
——-1
的作用是让 PyTorch 自动推断该维度的大小,以确保张量中元素的总数保持不变。
(6) 线性投影和 Dropout
x = self.proj(x)
x = self.proj_drop(x)
self.proj
是一个线性层,将x
的形状从(2, 10, 64)
映射回(2, 10, 64)
。self.proj_drop
是一个 Dropout 层,用于正则化。
3. 输出结果的形状
最终,输出的形状是 torch.Size([2, 10, 64])
,其中:
- 2 表示批次大小(batch size),因此输出包含 2 个样本。
- 10 表示查询序列的长度(
seq_len_query
),即每个查询向量经过注意力机制后生成一个输出向量。 - 64 表示每个输出向量的特征维度(
dim
)。
这种形状符合交叉注意力机制的预期:每个查询向量通过与键值对的交互生成一个输出向量,输出的序列长度与查询序列长度一致。