SwinUNet2022

1. 概述

本文提出了一种以\(Swin\)变压器层为基本块的\(SUNet\)恢复模型,并将其应用于\(UNet\)架构中进行图像去噪。

2. 背景

图像恢复是一种重要的低级图像处理方法,可以提高其在目标检测、图像分割和图像分类等高级视觉任务中的性能。在一般的恢复任务中,一个被损坏的图像Y可以表示为:

\[Y=D(X)+n \tag 1 \]

其中\(X\)是一个干净的图像,\(D(\cdot)\)表示退化函数,\(n\)表示加性噪声。一些常见的恢复任务是去噪、去模糊和去阻塞。

2.1 CNN局限性

虽然大多数基于卷积神经网络(CNN)的方法都取得了良好的性能,但卷积层存在几个问题。首先,卷积核与图像的内容无关(无法与图像内容相适应)。使用相同的卷积核来恢复不同的图像区域可能不是最好的解决方案。其次,由于卷积核可以看作是一个小块,其中获取的特征是局部信息,换句话说,当我们进行长期依赖建模时,全局信息就会丢失。

3. 结构

3.1 UNet

目前,UNet由于具有层次特征映射来获得丰富的多尺度上下文特征,是许多图像处理应用中著名的架构。此外,它利用编码器和解码器之间的跳跃连接来增强图像的重建过程。UNet被广泛应用于许多计算机视觉任务,如分割、恢复[。此外,它还有各种改进的版本,如Res-UNet,Dense-UNet,Attention-UNet[和Non-local-UNet。由于具有较强的自适应骨干网,UNet可以很容易地应用于不同的提取块,以提高性能。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MpLyfXEs-1647427670075)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316140830395.png)]

3.2 Swin Transformer

Transformer模型在自然语言处理(NLP)领域取得了成功,并具有良好的竞争性能,特别是在图像分类方面。然而,直接使用Transformer到视觉任务的两个主要问题是:

(1)图像和序列之间的尺度差异很大。由于Transformer需要参数量为一维序列参数的平方倍,所以存在长序列建模的缺陷。

(2)Transformer不擅长解决实例分割等密集预测任务,即像素级任务。然而,Swin Transfomer通过滑动窗口解决了上述问题,降低了参数,并在许多像素级视觉任务中实现了最先进的性能。

3.3 SUNet

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mK2P8qBG-1647427670077)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316141320770.png)]

所提出的Swin Transformer UNet(SUNet)的架构是基于图像分割模型,如上图所示。SUNet由三个模块组成:

(1)浅层特征提取;

(2)UNet特征提取;

(3)重建模块

浅层特征提取模块:

对于有噪声的输入图像\(Y∈R^{H×W×3}\),其中H,W为失真图像的分辨率。我们使用单个3×3卷积层\(M_{SFE}(\cdot)\)获取输入图像的颜色或纹理等低频信息。浅特征\(F_{shallow}∈R^{H×W×C}\)可以表示为:

\[F_{shallow}=M_{SFE}(Y) \tag 2 \]

其中,C是浅层特征的通道数,在后一个实验部分中,我们都设置为96.

UNet 特征提取网络:

然后,将浅层特征\(F_{shallow}\)输入UNet特征提取\(M_{UFE}(\cdot)\),UNet用来提取高级、多尺度深度特征\(F_{deep}∈R^{H×W×C}\)

\[F_{deep}=M_{UFE}(F_{shallow}) \tag 3 \]

其中,\(M_{UFE}(\cdot)\)是带有Swin变压器块的UNet架构,它在单个块中包含8个Swin Transformer层来代替卷积。Swin Transformer Block(STB)和Swin Transformer Layer(STL)将在下一小节中进行详细说明。

重建层:

最后,我们仍然使用3×3卷积\(M_{R}(\cdot)\)从深度特征\(F_{deep}\)中生成无噪声图像\(\hat{X}∈R^{H×W×3}\),其公式为:

\[\hat{X}=M_{R}(F_{deep}) \tag 4 \]

注意,\(\hat{X}\)是以噪声图像\(Y\)作为SUNet的输入得到的,其中\(X\)是(1)中Y图像的原高分率图像。

3.4 Loss function

我们优化了我们的SUNet端到端与规则的\(L1\)像素损失的图像去噪:

\[L_{denoise}=||\hat{X}-X||_1 \tag 5 \]

3.5 Swin Transformer Block

在UNet提取模块中,我们使用STB来代替传统的卷积层,如下图所示。STL是基于NLP中的原始Transformer Layer。STL的数量总是2的倍数,其中一个是window multi-head-self-attention(W-MSA),另一个是shifted-window multi-head self-attention(SW-MSA)。

STL的公式描述:

\[\hat{f}^L=W-MSA(LN(f^{L-1}))+f^{L-1} \\ f^L=MLP(LN(\hat{f}^L))+\hat{f}^L \\ \hat{f}^{L+1}=SW-MSA(LN(f^{L}))+f^{L} \\ f^{L+1}=MLP(LN(\hat{f}^{L+1}))+\hat{f}^{L+1} \tag 6 \]

其中,\(LN(\cdot)\)表示为层归一化,\(MLP\)是多层感知器,它具有两个完全连接的层,同时后面跟一个线性单位(GELU)激活函数。

3.6 Resizing module

由于UNet具有不同的特征图尺度,因此调整大小的模块(例如,下样本和上样本)是必要的。在我们的SUNet中,我们使用\(patch\ merging\),并提出\(dual\ up-sample\)分别作为下样本和上样本模块。

3.6.1 patch merging

对于降采样模块,该文将每一组2×2相邻斑块的输入特征连接起来,然后使用线性层获得指定的输出通道特征。我们也可以把这看作是做卷积操作的第一步,也就是展开输入的特征映射。

3.6.2 Dual up-sample

对于上采样,原始的Swin-UNet采用patch expanding方法,等价于上采样模块中的转置卷积。然而,转置卷积很容易面对块效应。在这里,我们提出了一个新的模块,称为双上样本,它包括两种现有的上样本方法(即Bilinear和PixelShuffle),以防止棋盘式的artifacts。所提出的上采样模块的体系结构如下图所示。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-joxathTz-1647427670077)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316145653720.png)]

4. 结果

评估指标:

为了进行定量比较,我们考虑了峰值信噪比(PSNR)和结构相似度(SSIM)指数度量。

训练集:

采用DIV2K作为训练集,一共有900张高清图片。我们对每个训练图像随机裁剪100个大小为\(256×256\)的斑块,并对\(800\)张训练图像从\(σ=5\)\(σ=50\)\(patch\)中随机添加AWGN噪声。至于验证集,我们直接使用包含100张图像的测试集,并添加具有三种不同噪声水平的AWGN,\(σ=10、σ=30和σ=50\)

测试集:

对于评估,我们选择了CBSD68数据集,它有68张彩色图像,分辨率为768×512,以及Kodak24张数据集,由24张图像组成,图像大小为321×481。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WdJUyCFm-1647427670078)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316170249843.png)]

在表1中,我们对去噪图像进行了客观的质量评价,并观察到以下三件事:

(1)该文的SUNet具有竞争性的SSIM值,因为Swin-Transformer是基于全局信息(q,k,v可以提取全局信息),使得去噪图像拥有更多的视觉效果。

(2)与基于unet的方法(DHDN、RDUNet)相比,该文所提出的SUNet模型中参数(↓60%)和FLOPs(↓3%)较少,在PSNR和SSIM上仍保持良好的得分

(3)与基于cnn的方法(DnCNN,IrCNN,FFDNet)相比,该文得到了其中最好的PSNR和SSIM结果,以及几乎相同的FLOPs。虽然该文的模型的参数最多(99M),但它是由于自注意操作不能共享核的权值造成的。

4. 总结

  • 提出了一种基于图像分割的双unet模型的双变换网络进行图像去噪。
  • 该文提出了一种双上样本块结构,它包括亚像素方法和双线性上样本方法,以防止棋盘伪影。实验结果表明,该方法优于转置卷积的原始上样本。
  • 该文的模型是第一个结合Swin变压器和UNet进行去噪的模型。

Reference: Swin-unet: Unet-like pure transformer for medical image segmentation

5. 某些代码的理解

5.1 window attention中的相对位置编码

代码位置位于: ./model/SUNet_detail.py 中的89行左右

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KhTSbrYb-1647503570456)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220316205927221.png)]

我这里展示了一个例子:

>>> import torch
>>> coords_h=torch.arange(3)
>>> coords_w=torch.arange(3)
>>> coords=torch.stack(torch.meshgrid([coords_h,coords_w]))
>>> coords
tensor([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])
>>> coords_flatten=torch.flatten(coords,1)
>>> coords_flatten
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
>>> relative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]
>>> relative_coords
tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])

可以看到\(relative\_coords\)的第一维是2,分别对应x轴和y轴方向(或者高,宽的方向)。剩下两维呢是9*9。SUNet中间使用了一些SwinIR的结构,在SwinIR中是有shift-window的,在这里,我设置的window size为3。又因为再做attention的时候,我们把每一个window中的像素点当作一个token,那么最终的attention map(\(q * v\))的最后两维就是\(window\_width \cdot window\_height\)

下面再来看具体的物理意义,\(3 \times 3\)的window一共有9个数值,第一维度分别代表这两个轴;第一个矩阵中,第一行分别代表着第一个数值(一共有9个)在某个轴上相对于其他位置的距离(在第一个数值的右边为负,左边为正),第二个矩阵类似。不同window的相对位置是一样的。

>>> relative_coords
tensor([[[ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 0,  0,  0, -1, -1, -1, -2, -2, -2],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 1,  1,  1,  0,  0,  0, -1, -1, -1],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0],
         [ 2,  2,  2,  1,  1,  1,  0,  0,  0]],

        [[ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0],
         [ 0, -1, -2,  0, -1, -2,  0, -1, -2],
         [ 1,  0, -1,  1,  0, -1,  1,  0, -1],
         [ 2,  1,  0,  2,  1,  0,  2,  1,  0]]])
>>> relative_coords = relative_coords.permute(1, 2, 0).contiguous()
>>> relative_coords
tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2]],

        [[ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1]],

        [[ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0]],

        [[ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2]],

        [[ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1]],

        [[ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0]]])

下面的代码就是将相对位置坐标全部加上\(window\_size-1\),使得全部为正值:

>>> relative_coords[:, :, 0] += window_size[0] - 1
>>> relative_coords[:, :, 1] += window_size[1] - 1
>>> relative_coords
tensor([[[2, 2],
         [2, 1],
         [2, 0],
         [1, 2],
         [1, 1],
         [1, 0],
         [0, 2],
         [0, 1],
         [0, 0]],

        [[2, 3],
         [2, 2],
         [2, 1],
         [1, 3],
         [1, 2],
         [1, 1],
         [0, 3],
         [0, 2],
         [0, 1]],

        [[2, 4],
         [2, 3],
         [2, 2],
         [1, 4],
         [1, 3],
         [1, 2],
         [0, 4],
         [0, 3],
         [0, 2]],

        [[3, 2],
         [3, 1],
         [3, 0],
         [2, 2],
         [2, 1],
         [2, 0],
         [1, 2],
         [1, 1],
         [1, 0]],

        [[3, 3],
         [3, 2],
         [3, 1],
         [2, 3],
         [2, 2],
         [2, 1],
         [1, 3],
         [1, 2],
         [1, 1]],

        [[3, 4],
         [3, 3],
         [3, 2],
         [2, 4],
         [2, 3],
         [2, 2],
         [1, 4],
         [1, 3],
         [1, 2]],

        [[4, 2],
         [4, 1],
         [4, 0],
         [3, 2],
         [3, 1],
         [3, 0],
         [2, 2],
         [2, 1],
         [2, 0]],

        [[4, 3],
         [4, 2],
         [4, 1],
         [3, 3],
         [3, 2],
         [3, 1],
         [2, 3],
         [2, 2],
         [2, 1]],

        [[4, 4],
         [4, 3],
         [4, 2],
         [3, 4],
         [3, 3],
         [3, 2],
         [2, 4],
         [2, 3],
         [2, 2]]])

下面是将横纵坐标的相对位置加起来:

>>> relative_position_index = relative_coords.sum(-1)
>>> relative_position_index
tensor([[4, 3, 2, 3, 2, 1, 2, 1, 0],
        [5, 4, 3, 4, 3, 2, 3, 2, 1],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [5, 4, 3, 4, 3, 2, 3, 2, 1],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [7, 6, 5, 6, 5, 4, 5, 4, 3],
        [6, 5, 4, 5, 4, 3, 4, 3, 2],
        [7, 6, 5, 6, 5, 4, 5, 4, 3],
        [8, 7, 6, 7, 6, 5, 6, 5, 4]])

下面为定义bias,随机初始化,但是在网络的迭代训练中,是会被反向传播的

 # define a parameter table of relative position bias
 self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] 			- 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
 trunc_normal_(self.relative_position_bias_table, std=.02)

一个window里面明明只有9个数值,为什么定义bias时,矩阵的维度为\(25, num_heads\)呢?**

这是因为上面我们加了\(window\_size -1\):不加之前最大值为\(window\_size-1\),后面在加上\(window\_size-1\),此时,最大值为\(2\times window\_size -2\),再算上零,一共有\(2\times window\_size-1\)。所以再初始化bias的时候,我觉得维度为\(2\times window\_size-1\)

就够了,不知道为什么要定义\((2\times window\_size-1)\times 2\times window\_size-1\)呢?

每个window是独立attention的,所以每个window的relative_position_bias都是一样的。

下面就是加MASK操作了,不再赘述。

5.2 Shifted-window-attention

代码位于\(SwinTransformerBlock\)中。

\(shift\_size>0\)时:

# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
# nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size)  
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, 			 	float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> img_mask = torch.zeros((1, 3, 3, 1))
>>> h_slices=(slice(0,-3),slice(-3,1),slice(-1,None))
>>> h_slices
(slice(0, -3, None), slice(-3, 1, None), slice(-1, None, None))
>>> w_slices=(slice(0,-3),slice(-3,-1),slice(-1,None))
>>> w_slices
(slice(0, -3, None), slice(-3, -1, None), slice(-1, None, None))
>>> cnt = 0
>>> for h in h_slices:
...     for w in w_slices:
...             img_mask[:,h,w,:]=cnt
...             cnt+=1
...
>>> img_mask
tensor([[[[4.],
          [4.],
          [5.]],

         [[4.],
          [4.],
          [5.]],

         [[7.],
          [7.],
          [8.]]]])
# nW, window_size, window_size, 1    [1,3,3,1]
mask_windows = window_partition(img_mask, self.window_size)
# 经过window_partition后,没有发生变化
>>> mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
tensor([[4., 4., 5., 4., 4., 5., 7., 7., 8.]])
>>> attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
>>> attn_mask.size()
torch.Size([1, 9, 9])
>>> attn_mask
tensor([[[ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [-1., -1.,  0., -1., -1.,  0.,  2.,  2.,  3.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [ 0.,  0.,  1.,  0.,  0.,  1.,  3.,  3.,  4.],
         [-1., -1.,  0., -1., -1.,  0.,  2.,  2.,  3.],
         [-3., -3., -2., -3., -3., -2.,  0.,  0.,  1.],
         [-3., -3., -2., -3., -3., -2.,  0.,  0.,  1.],
         [-4., -4., -3., -4., -4., -3., -1., -1.,  0.]]])
>>> attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
>>> attn_mask
tensor([[[   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100., -100., -100.,    0.]]])

我们再来看forward函数对x的shift操作:

>>> x=torch.arange(0,9)
>>> x
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> x=x.unsqueeze(0)+x.unsqueeze(1)
>>> x
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 5,  6,  7,  8,  9, 10, 11, 12, 13],
        [ 6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 7,  8,  9, 10, 11, 12, 13, 14, 15],
        [ 8,  9, 10, 11, 12, 13, 14, 15, 16]])
>>> x=x.unsqueeze(2).unsqueeze(0)
>>> x.size()
torch.Size([1, 9, 9, 1])
>>> shifted_x = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
>>> xx=shifted_x.squeeze(3).squeeze(0)
>>> xx
tensor([[ 2,  3,  4,  5,  6,  7,  8,  9,  1],
        [ 3,  4,  5,  6,  7,  8,  9, 10,  2],
        [ 4,  5,  6,  7,  8,  9, 10, 11,  3],
        [ 5,  6,  7,  8,  9, 10, 11, 12,  4],
        [ 6,  7,  8,  9, 10, 11, 12, 13,  5],
        [ 7,  8,  9, 10, 11, 12, 13, 14,  6],
        [ 8,  9, 10, 11, 12, 13, 14, 15,  7],
        [ 9, 10, 11, 12, 13, 14, 15, 16,  8],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  0]])

那为什么要加mask呢,它是由一个假设的,假设各个window之间不相关,各个window单独做attention。其中我们以\(H=9,W=9, window\_size=3, shift\_size=1\)为例。

图片参考:SWin Transformer

没有进行shift的window划分图:
在这里插入图片描述

再forward中,是会对输入的x进行shift操作的:
在这里插入图片描述

shift后的操作:

在这里插入图片描述

其中,上图黑线代表原来的边界。每个彩色框代表经过shift操作后的window划分,可以看到每个彩色框内部黑线位置是一样的;黑线是window的边界。

>>> attn_mask
tensor([[[   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [   0.,    0., -100.,    0.,    0., -100., -100., -100., -100.],
         [-100., -100.,    0., -100., -100.,    0., -100., -100., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100.,    0.,    0., -100.],
         [-100., -100., -100., -100., -100., -100., -100., -100.,    0.]]])

我们举个例子说明:

以第一个彩色框为例,第一行代表第一个元素是否可以看到对应位置的元素(0代表看得到,-100代表看不到)。当两个像素点位于不同的window之中(黑线)就是看不到,就赋给一个负数,后面再做softmax,对应权重就会非常小。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-COD9Kl1R-1647503570459)(C:\Users\Liujiawang\AppData\Roaming\Typora\typora-user-images\image-20220317132729985.png)]

5.3 Dual Up Sample

使用PixelShuffle和bilinear合起来的特征作为输出。

5.4 Absolute position embedding

# absolute position embedding
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, 			                embed_dim)) 
    trunc_normal_(self.absolute_pos_embed, std=.02)

大概位置再SUNet的635行左右。

那为什么要加Absolute position embedding呢?

是因为再5.2 window attention中呢,是有一个relative position的,但是relative position作用范围仅仅是在一个window里面,即每个window相同位置上的relative position都是一样的。所以需要absolute position。

5.5 DownSampling

下采样是通过Patch Embedding实现的。

posted @ 2022-03-16 18:50  为红颜  阅读(1708)  评论(0编辑  收藏  举报