Swin Transformer结构梳理

Swim Transformer是特为视觉领域设计的一种分层Transformer结构。Swin Transformer的两大特性是滑动窗口和层级式结构。
1.滑动窗口使相邻的窗口之间进行交互,从而达到全局建模的能力。
2.层级式结构的好处在于不仅灵活的提供各种尺度的信息,同时还因为自注意力是在窗口内计算的,所以它的计算复杂度随着图片大小线性增长而不是平方级增长,这就使Swin Transformer能够在特别大的分辨率上进行预训练模型,并且通过多尺度的划分,使得Swin Transformer能够提取到多尺度的特征。也因此被人成为披着transformer皮的CNN。

模型图如下:

整体网络架构图:

其中Transformer Blocks详细结构如下图:

1.得到各Pathch特征构建序列

  • 输入图像数据为(224,224,3),通过卷积得到特征图,特征图分块转成向量,得到每个patch,每个patch带编码。
    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C,通过卷积(3,96,(4,4),(4,4))(颜色通道数,得到向量维度,卷积核大小,步长)得到特征图,特征图分块转成向量,得到每个patch,每个patch带编码
        print(x.shape)#4,3136,96,其中4表示batch,3136就是224/4*224/4,相当于有这么长的序列,其中每个元素是96维向量
        if self.norm is not None:
            x = self.norm(x)
        return x

2.window_partition窗口划分

(1)判断需不需要做窗口移动

  • 刚开始shift_size为0,不做偏移
        # cyclic shift
        if self.shift_size > 0:#做不做窗口滑动,刚开始shift_size为0,不做偏移
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))#进行偏移
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x

(2)window_partition窗口划分

  • 划分的窗口大小7*7,个数8*8
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape#输入为4.3136.96
    print(x.shape)#4.8.7.8.7.96窗口大小7*7,个数8*8
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    print(windows.shape)#256.7.7.96(256表示窗口数4个batch*56)
    return windows

3.W-MSA(Window Multi-head Self Attention)

  • 对得到的窗口,计算各个窗口自己的自注意力得分
    def forward(self, x, mask=None):
        """注意力机制计算
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)#qkv矩阵一起做
        print(qkv.shape)#3.256.3.49.32(3个矩阵,256个窗口,3头,一个窗口49个元素,96/3=32每一头得到32维向量)下采样后:3.64.6.49.32
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        print(q.shape)#256.3.49.32
        print(k.shape)#256.3.49.32
        print(v.shape)#256.3.49.32
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        print(attn.shape)#256.3.49.49(3头,49个都要与49个计算注意力)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # 加上49个位置的不同特征Wh*Ww,Wh*Ww,nH
        print(relative_position_bias.shape)  # 49.49.3(256个窗口都是相同的49个位置,只需做一个49,49就行,3头)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        print(relative_position_bias.shape)  # 3.49.49
        attn = attn + relative_position_bias.unsqueeze(0)#位置编码+注意力机制
        print(attn.shape)  # 256.3.49.49

        if mask is not None:#W-MSA不执行mask
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        print(attn.shape)  # 256.3.49.49
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        print(x.shape)  # 256.49.96
        x = self.proj(x)
        print(x.shape)  # 256.49.96
        x = self.proj_drop(x)
        print(x.shape)  # 256.49.96
        return x

4.还原操作window_reverse

  • 还原成跟输入特征图一样的大小,便于进行下一个Block
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    返回窗口大小"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    print(x.shape)#4.8.8.7.7.96下采样一次:4.4.7.4.7.192
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    print(x.shape)  #4.56.56.96;下采样一次:64.7.7.192((56/2/7)*(56/2/7))=16*4=64
    return x

5.SW-MSA(Shifted Window)

  • 原来的window都是算自己内部的,没有它们之间的关系,容易上模型局限在自己的小领地,于是执行SW-MSA(Shifted Window)。
  • 代码的执行与W-MSA不同的三点如下:

(1)做窗口滑动

  • 通过窗口的滑动,划分成新的窗口,计算新窗口内部的MSA
        # cyclic shift
        if self.shift_size > 0:#做不做窗口滑动,刚开始shift_size为0,不做偏移
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))#进行偏移
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x

(2)mask

  • 原来算窗口自注意机制只用算4个,移动后需要算9个,为了让移动后窗户依然保持4个且每个窗口中的patch数量也保持一致,于是提出了mask。对于移动后的拼接在一起的新窗口,其中包含了不是挨着的地方移动过来的部分,他们之间不需要做自注意力机制,于是使用mask掩掉。
    示意图如下:
        if mask is not None:#W-MSA不执行mask
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

(3)还原shift

  • 计算完特征后需要对图像进行还原,也就是还原平移
         # reverse cyclic shift
        if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))#还原shift-3变为3
                print(x.shape)#
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        print(x.shape)  #4.3136.96

        # FFN残差连接
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

6.PatchMerging

  • 类似于卷积神经网络中的池化操作,增大了感受野,PatchMerging把相邻的4小patch合成一个大patch,从而实现增大感受野,获取多尺寸的特征。如图所示:
  • 对于具体的把相邻的4小patch合成一个大patch是间隔取,对H和W维度进行间隔采样后拼接在一起,得到H/2,W/2,C*4。示意图如下:
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  #切片 B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  #拼接 B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x

7.分层计算(执行后续的Block)

  • 一次下采样后(3136->784也就是5656->2828),然后继续走W-MSA和SW-MSA,也就是整体网络架构图中的各个stage的流程

8.输出层

        x = self.norm(x)  # B L C
        print(x.shape)#4.49.768
        x = self.avgpool(x.transpose(1, 2))  #平均池化 B C 1
        print(x.shape)#4.768.1
        x = torch.flatten(x, 1)
        print(x.shape)#4.768
        return x

    def forward(self, x):#把768个向量转换成1000个类别
        x = self.forward_features(x)
        x = self.head(x)
        return x
posted @ 2023-07-13 17:16  Frommoon  阅读(825)  评论(0编辑  收藏  举报