Swin Transformer 时间复杂度的分析
from https://blog.csdn.net/weixin_45943887/article/details/127881179
Swin Transformer 时间复杂度的分析
Swin Transformer 的论文中涉及到了两个关于时间复杂度的计算公式,在此梳理一下推导过程。
1. 前置知识
神经网络的运算过程中涉及大量矩阵运算,因此在分析时间复杂度之前,需要对矩阵运算的复杂度有一个基本的认识,假设有三个矩阵\(A\in \mathbb{R}^{m\times n}, B\in \mathbb{R}^{n\times l}, C\in \mathbb{R}^{l\times m}\) :
可以理解为:第一个矩阵的行维(第一维) × 第二个矩阵的列维(第二维) × 两个矩阵的相等维度。三个矩阵的情况需要先计算前两个,根据计算结果和第三个矩阵的维度就可以计算整体的复杂度。
2. Transformer 的时间复杂度
Transformer 是 2017 由 Google 提出的用于 NLP 领域的自注意力模型,其核心模块则是 Multi-Head Self-Attention(MSA):

假设序列长度为 L,词向量维度为C,所以输入的形状是 \([ b a t c h s i z e , L , C ]\) 。在计算时间复杂度时暂时忽略 batch size,而多头各自计算并不影响结果,所以也可以忽略。
MSA 可以分为四个阶段:
- Q, K, V 分别进行了 Linear 变换,每个都可以看成是\([L,C]\times[C,C]\),复杂度\(LC^2 \times 3 = 3LC^2\)
- dot-product 的 \(QK^T\),\([L, C]\times[C, L]\),复杂度\(L^2C\)
- Softmax 操作后与 V 相乘, \([L, L]\times[L, C]\),复杂度\(L^2 C\)
- Attention 最后的 Linear 层, \([L, C]\times[C, C]\),复杂度\(LC^2\)
四个阶段相加,得到最终的时间复杂度 \(4LC^2+2L^2 C\)
3. Vision Transformer 的时间复杂度
Vision Transformer 提出了 Patch Embedding 的思想:

Transformer 的时间复杂度为 \(4LC^2+2L^2 C\)。如上图所示,在 ViT 中, L = Patch 的个数 = 9, C = 每个 Patch 的 Depth = Embedding 的维度,这个 Depth 类似 CNN 中的 output channel。假设图像在 Patch 后的宽度为 w,高度为 h,则:
因此,ViT 的时间复杂度可以表示为:
这与 Swin Transformer 论文中所列的结果一致,时间复杂度与 h w 呈平方相关。
4. Swin Transformer 的时间复杂度
Swin Transformer 沿用了 Patch 的设定,但为了进一步降低时间复杂度,在此基础上提出了 Window 的思想。

如下图所示,Swin Transformer Block 的时间复杂度集中于 W-MSA 与 SW-MSA。SW-MSA 比 W-MSA 多了一来一回两步平移操作,和一步 Mask 操作,但是二者的计算量依然是同一个量级。

假设 Window 的边长为 M,则大小为\(M\times M\)。如第一张图中的 Layer 1 所示,在 W-MSA 中,所有的 patch 被划分为\(\frac{h}{M} \times \frac{w}{M}\) 个 Windows,每个 Window 单独做 self-attention 的 Q, K, V 运算。把 M 带入,每个 Window 的时间复杂度为:
因此,整个 W-MSA 的时间复杂度可以表示为:
推导结果与论文中的保持一致,时间复杂度降到与 h w 呈线性相关,至此推导完毕。
在论文后半部的实验也证明,Swin 相比于 ViT 很大幅度地降低了计算时间。

浙公网安备 33010602011771号