LaViT:这也行,微软提出直接用上一层的注意力权重生成当前层的注意力权重 | CVPR 2024
Less-Attention Vision Transformer
利用了在多头自注意力(MHSA
)块中计算的依赖关系,通过重复使用先前MSA
块的注意力来绕过注意力计算,还额外增加了一个简单的保持对角性的损失函数,旨在促进注意力矩阵在表示标记之间关系方面的预期行为。该架构你能有效地捕捉了跨标记的关联,超越了基线的性能,同时在参数数量和每秒浮点运算操作(FLOPs
)方面保持了计算效率。来源:晓飞的算法工程笔记 公众号
论文: You Only Need Less Attention at Each Stage in Vision Transformers
Introduction
近年来,计算机视觉经历了快速的增长和发展,主要得益于深度学习的进步以及大规模数据集的可获得性。在杰出的深度学习技术中,卷积神经网络(Convolutional Neural Networks
, CNNs
)被证明特别有效,在包括图像分类、目标检测和语义分割等广泛应用中展现了卓越的性能。
受到Transformer
在自然语言处理领域巨大成功的启发,ViT
(Vision Transformers
)将每幅图像划分为一组标记。这些标记随后被编码以生成一个注意力矩阵,作为自注意力机制的基础组成部分。自注意力机制的计算复杂度随着标记数量的增加而呈平方增长,且随着图像分辨率的提高,计算负担变得更加沉重。一些研究人员尝试通过动态选择或标记修剪来减少标记冗余,以减轻注意力计算的计算负担。这些方法在性能上已证明与标准ViT
相当。然而,涉及标记减少和修剪的方法需要对标记选择模块进行细致设计,可能导致关键标记的意外丢失。在本研究中,作者探索了不同的方向,并重新思考自注意力的机制。发现在注意力饱和问题中,随着ViTs
层数的逐渐加深,注意力矩阵往往保持大部分不变,重复前面层中观察到的权重分配。考虑到这些因素,作者提出以下问题:
在网络的每个阶段,从开始到结束,是否真的有必要始终一致地应用自注意力机制?
在本文中,作者提出通过引入少注意力ViT
(Less-Attention Vision Transformer
)来修改标准ViT
的基本架构。框架由原始注意力(Vanilla Attention
, VA
)层和少注意力(Less Attention
, LA
)层组成,以捕捉长范围的关系。在每个阶段,专门计算传统的自注意力,并将注意力分数存储在几个初始的原始注意力(VA
)层中。在后续的层中,通过利用先前计算的注意力矩阵高效地生成注意力分数,从而减轻与自注意力机制相关的平方计算开销。此外,在跨阶段的降采样过程中,在注意力层内集成了残差连接,允许保留在早期阶段学习到的重要语义信息,同时通过替代路径传输全局上下文信息。最后,作者仔细设计了一种新颖的损失函数,从而在变换过程中保持注意力矩阵的对角性。这些关键组件使作者提出的ViT
模型能够减少计算复杂性和注意力饱和,从而实现显著的性能提升,同时降低每秒浮点运算次数(FLOPs
)和显著的吞吐量。
为验证作者提出的方法的有效性,在各种基准数据集上进行了全面的实验,将模型的性能与现有最先进的ViT
变种(以及最近的高效ViT
)进行了比较。实验结果表明,作者的方法在解决注意力饱和并在视觉识别任务中取得优越性能方面非常有效。
论文的主要贡献总结如下:
-
提出了一种新颖的
ViT
架构,通过重新参数化前面层计算的注意力矩阵生成注意力分数,这种方法同时解决了注意力饱和和相关的计算负担。 -
此外,提出了一种新颖的损失函数,旨在在注意力重新参数化的过程中保持注意力矩阵的对角性。作者认为这一点对维护注意力的语义完整性至关重要,确保注意力矩阵准确反映输入标记之间的相对重要性。
-
论文的架构在包括分类、检测和分割在内的多个视觉任务中,始终表现优异,同时在计算复杂度和内存消耗方面具有类似甚至更低的特点,胜过几种最先进的
ViTs
。
Methodology
Vision Transformer
令 \(\mathbf{x} \in \mathbb{R}^{H \times W \times C}\) 表示一个输入图像,其中 \(H \times W\) 表示空间分辨率, \(C\) 表示通道数。首先通过将图像划分为 $N = \frac{HW}{p^{2}} $ 个块来对图像进行分块,其中每个块 \(P_i \in \mathbb{R}^{p \times p \times C}\left(i \in \{1, \ldots, N\} \right)\) 的大小为 \(p \times p\) 像素和 \(C\) 通道。块大小 \(p\) 是一个超参数,用于确定标记的粒度。块嵌入可以通过使用步幅和卷积核大小均等于块大小的卷积操作提取。然后,每个块通过不重叠的卷积投影到嵌入空间 \(\boldsymbol{Z} \in \mathbb{R}^{N\times{D}}\) ,其中 \(D\) 表示每个块的维度。
-
Multi-Head Self-Attention
首先提供一个关于处理块嵌入的经典自注意力机制的简要概述,该机制在多头自注意力块(MHSAs
)的框架内工作。在第 \(l\) 个MHSA
块中,输入 \(\boldsymbol{Z}_{l-1}, l \in \{1,\cdots, L\}\) 被投影为三个可学习的嵌入 \(\{\mathbf{Q,K,V}\} \in \mathbb{R}^{N \times D}\) 。多头注意力旨在从不同的视角捕捉注意力;为简单起见,选择 \(H\) 个头,每个头都是一个维度为 \(N \times \frac{D}{H}\) 的矩阵。第 \(h\) 个头的注意力矩阵 \(\mathbf{A}_h\) 可以通过以下方式计算:
\(\mathbf{A}_h, \mathbf{Q}_h\) 和 \(\mathbf{K}_h\) 分别是第 \(h\) 个头的注意力矩阵、查询和键。还将值 \(\mathbf{V}\) 分割成 \(H\) 个头。为了避免由于概率分布的锐性导致的梯度消失,将 \(\mathbf{Q}_h\) 和 \(\mathbf{K}_h\) 的内积除以 \(\sqrt{d}\) ( \(d = D/H\) )。注意力矩阵被拼接为:
在空间分割的标记之间计算的注意力,可能会引导模型关注视觉数据中最有价值的标记。随后,将加权线性聚合应用于相应的值 \(\mathbf{V}\) :
-
Downsampling Operation
受到CNN
中层次架构成功的启发,一些研究将层次结构引入到ViTs
中。这些工作将Transformer
块划分为 \(M\) 个阶段,并在每个Transformer
阶段之前应用下采样操作,从而减少序列长度。在论文的研究中,作者采用了一个卷积层进行下采样操作,卷积核的大小和步幅都设置为 \(2\) 。该方法允许在每个阶段灵活调整特征图的尺度,从而建立一个与人类视觉系统的组织相一致的Transformer
层次结构。
The Less-Attention Framework
整体框架如图1
所示。在每个阶段,分两步提取特征表示。在最初的几个Vanilla Attention
(VA
) 层中,进行标准的多头自注意力(MHSA
)操作,以捕捉整体的长距离依赖关系。随后,通过对存储的注意力分数应用线性变换,模拟注意力矩阵,以减少平方计算并解决接下来的低注意力(LA
)层中的注意力饱和问题。在这里,将第 \(m\) 个阶段的初始 \(l\) -th VA 层的 \(\textrm{Softmax}\) 函数之前的注意力分数表示为 \(\mathbf{A}^{\text{VA},l}_m\) ,它是通过以下标准程序计算的:
这里, \(\mathbf{Q}_m^l\) 和 \(\mathbf{K}_m^l\) 分别表示来自第 \(m\) 个阶段第 \(l\) 层的查询和键,遵循来自前一阶段的下采样。而 \(L^{\text{VA}}_m\) 用于表示VA
层的数量。在最初的原始注意力阶段之后,丢弃传统的平方MHSA
,并对 \(\mathbf{A}^\textrm{VA}_m\) 应用变换,以减少注意力计算的数量。这个过程包括进行两次线性变换,中间夹一个矩阵转置操作。为了说明,对于该阶段的第 \(l\) 层( \(l > L^{\text{VA}}_m\) ,即LA
层)的注意力矩阵:
在这个上下文中, \(\Psi\) 和 \(\Theta\) 表示维度为 \(\mathbb{R}^{N\times{N}}\) 的线性变换层。这里, \(L_m\) 和 \(L_m^{\text{VA}}\) 分别表示第 \(m\) 个阶段的层数和VA
层的数量。在这两个线性层之间插入转置操作的目的是保持矩阵的相似性行为。这个步骤是必需的,因为单层中的线性变换是逐行进行的,这可能导致对角特性丧失。
Residual-based Attention Downsampling
当计算在分层ViT
(ViTs
)中跨阶段进行时,通常会对特征图进行下采样操作。虽然该技术减少了标记数量,但可能会导致重要上下文信息的丧失。因此,论文认为来自前一阶段学习的注意力亲和度对于当前阶段在捕捉更复杂的全局关系方面可能是有利的。受到ResNet
的启发,后者引入了快捷连接以减轻特征饱和问题,作者在架构的下采样注意力计算中采用了类似的概念。通过引入一个短路连接,可以将固有的偏差引入当前的多头自注意力(MHSA
)块。这使得前一阶段的注意力矩阵能够有效引导当前阶段的注意力计算,从而保留重要的上下文信息。
然而,直接将短路连接应用于注意力矩阵可能在这种情况下面临挑战,主要是由于当前阶段和前一阶段之间注意力维度的不同。为此,作者设计了一个注意力残差(AR
)模块,该模块由深度卷积(DWConv
)和一个 \(\textrm{Conv}_{1\times1}\) 层构成,用以在保持语义信息的同时对前一阶段的注意力图进行下采样。将前一阶段(第 \(m-1\) 阶段)的最后一个注意力矩阵(在 \(L_{m-1}\) 层)表示为 \(\textbf{A}_{m-1}^{\text{last}}\) ,将当前阶段(第 \(m\) 阶段)的下采样初始注意力矩阵表示为 \(\textbf{A}_m^\text{init}\) 。 \(\textbf{A}_{m-1}^{\text{last}}\) 的维度为 \(\mathbb{R}^{B\times{H}\times{N_{m-1}}\times{N_{m-1}}}\) ( \(N_{m-1}\) 表示第 \(m-1\) 阶段的标记数量)。将多头维度 \(H\) 视为常规图像空间中的通道维度,因此通过 \(\textrm{DWConv}\) 操作符( \(\textrm{stride}=2,\ \textrm{kernel size}=2\) ),可以在注意力下采样过程中捕获标记之间的空间依赖关系。经过 \(\textrm{DWConv}\) 变换后的输出矩阵适合当前阶段的注意力矩阵的尺寸,即 \(\mathbb{R}^{B\times{H}\times{N_m}\times{N_m}} (N_m = \frac{N_{m-1}}{2})\) 。在对注意力矩阵进行深度卷积后,再执行 \(\text{Conv}_{1\times1}\) ,以便在不同头之间交换信息。
论文的注意力下采样过程如图2
所示,从 \(\textbf{A}_{m-1}^\text{last}\) 到 \(\textbf{A}_{m}^\text{init}\) 的变换可以表示为:
其中 \(\textrm{LS}\) 是在CaiT
中引入的层缩放操作符,用以缓解注意力饱和现象。 \(\mathbf{A}^{\text{VA}}_m\) 是第 \(m\) 阶段第一层的注意力得分,它是通过将标准多头自注意力(MHSA
)与公式4
和由公式6
计算的残差相加得出的。
论文的注意力下采样模块受两个基本设计原则的指导。首先,利用 \(\text{DWConv}\) 在下采样过程中捕获空间局部关系,从而实现对注意力关系的高效压缩。其次,采用 \(\textrm{Conv}_{1\times1}\) 操作在不同头之间交换注意力信息。这一设计至关重要,因为它促进了注意力从前一阶段有效传播到后续阶段。引入残差注意力机制只需进行少量调整,通常只需在现有的ViT
主干中添加几行代码。值得强调的是,这项技术可以无缝应用于各种版本的Transformer
架构。唯一的前提是存储来自上一层的注意力得分,并相应地建立到该层的跳跃连接。通过综合的消融研究,该模块的重要性将得到进一步阐明。
Diagonality Preserving Loss
作者通过融入注意力变换算子,精心设计了Transformer
模块,旨在减轻计算成本和注意力饱和的问题。然而,仍然存在一个紧迫的挑战——确保变换后的注意力保留跨Token
之间的关系。众所周知,对注意力矩阵应用变换可能会妨碍其捕捉相似性的能力,这在很大程度上是因为线性变换以行的方式处理注意力矩阵。因此,作者设计了一种替代方法,以确保变换后的注意力矩阵保留传达Token
之间关联所需的基本属性。一个常规的注意力矩阵应该具备以下两个属性,即对角性和对称性:
因此,设计了第 \(l\) 层的对角性保持损失,以保持这两个基本属性如下所示:
在这里, \(\mathcal{L}_\textrm{DP}\) 是对角性保持损失,旨在维护公式8
中注意力矩阵的属性。在所有变换层上将对角性保持损失与普通的交叉熵 (CE
) 损失相结合,因此训练中的总损失可以表示为:
其中, \(Z_\texttt{Cls}\) 是最后一层表示中的分类标记。
Complexity Analysis
论文的架构由四个阶段组成,每个阶段包含 \(L_m\) 层。下采样层应用于每个连续阶段之间。因此,传统自注意力的计算复杂度为 \(\mathcal{O}(N_m^2{D})\) ,而相关的K-Q-V
转换则带来了 \(\mathcal{O}(3N_mD^2)\) 的复杂度。相比之下,论文的方法在变换层内利用了 \(N_m\times N_m\) 的线性变换,从而避免了计算内积的需要。因此,变换层中注意力机制的计算复杂度降至 \(\mathcal{O}(N_m^2)\) ,实现了 \(D\) 的减少因子。此外,由于论文的方法在 Less-Attention
中只计算查询嵌入,因此K-Q-V
转换复杂度也减少了3
倍。
在连续阶段之间的下采样层中,以下采样率2
为例,注意力下采样层中DWConv
的计算复杂度可以计算为 \(\textrm{Complexity} = 2 \times 2 \times \frac{N_m}{2} \times \frac{N_m}{2} \times D = \mathcal{O}(N_m^2D)\) 。同样,注意力残差模块中 \(\textrm{Conv}_{1\times1}\) 操作的复杂度也是 \(\mathcal{O}(N_m^2D)\) 。然而,重要的是,注意力下采样在每个阶段仅发生一次。因此,对比Less-Attention
方法所实现的复杂度减少,这些操作引入的额外复杂度可以忽略不计。
Experiments
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】