Swin Transformer Architecture Walkthrough
Swin Transformer is a hierarchical Transformer architecture designed specifically for vision tasks. Its two key features are shifted windows and a hierarchical structure.
1. Shifted windows let neighboring windows interact, which gives the model global modeling capability.
2. The hierarchical structure has two benefits. First, it flexibly provides features at multiple scales. Second, because self-attention is computed within local windows, the computational complexity grows linearly with image size rather than quadratically, so Swin Transformer can be pre-trained at very large resolutions, and the multi-scale partitioning lets it extract multi-scale features. For this reason it is often described as a CNN wearing a Transformer skin.
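To make the complexity claim concrete, here is a back-of-the-envelope sketch (not code from the Swin repository) based on the complexity formulas reported in the Swin Transformer paper, for an H x W feature map with C channels and window size M:

def msa_flops(H, W, C):
    # Global multi-head self-attention: 4*H*W*C^2 + 2*(H*W)^2*C, quadratic in the token count H*W
    return 4 * H * W * C ** 2 + 2 * (H * W) ** 2 * C

def wmsa_flops(H, W, C, M=7):
    # Window multi-head self-attention: 4*H*W*C^2 + 2*M^2*H*W*C, linear in the token count H*W
    return 4 * H * W * C ** 2 + 2 * M ** 2 * H * W * C

# At the 56x56, C=96 resolution used below, windowed attention is already an order of magnitude cheaper.
print(msa_flops(56, 56, 96) / wmsa_flops(56, 56, 96))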
The model diagram is shown below:
Overall network architecture:
The detailed structure of the Transformer Blocks is shown in the following figure:
1. Build the token sequence from patch features
- The input image is (224, 224, 3). A convolution produces a feature map, which is split into patches and flattened into vectors, so that each patch becomes one token with its own embedding.
def forward(self, x):
    B, C, H, W = x.shape
    # FIXME look at relaxing size constraints
    assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
    x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C: a Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4)) (input channels, embedding dim, kernel size, stride) produces the feature map, which is split into patches and flattened into vectors, one embedding per patch
    print(x.shape)  # (4, 3136, 96): 4 is the batch size, 3136 = (224/4)*(224/4) is the sequence length, and each token is a 96-dim vector
    if self.norm is not None:
        x = self.norm(x)
    return x
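For reference, a minimal standalone version of this patch-embedding step (a sketch with hypothetical variable names, matching the shapes in the walkthrough above) could look like this:

import torch
import torch.nn as nn

# A strided 4x4 convolution turns a (B, 3, 224, 224) image into a (B, 3136, 96) token sequence.
proj = nn.Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
img = torch.randn(4, 3, 224, 224)
tokens = proj(img).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([4, 3136, 96])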
2. Window partitioning with window_partition
(1) Decide whether the window needs to be shifted
- At the beginning, shift_size is 0, so no shift is applied.
# cyclic shift
if self.shift_size > 0:  # whether to shift the window; shift_size is 0 at first, so no shift is applied
    if not self.fused_window_process:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # apply the cyclic shift
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    else:
        x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
    shifted_x = x
(2) Partition into windows with window_partition
- Each window is 7*7, and there are 8*8 windows per image.
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape  # input is (4, 56, 56, 96)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    print(x.shape)  # (4, 8, 7, 8, 7, 96): window size 7*7, 8*8 windows per image
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    print(windows.shape)  # (256, 7, 7, 96): 256 windows = 4 batches * 64 windows each
    return windows
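A quick sanity check of the shapes (a standalone sketch using the window_partition function above):

import torch

feat = torch.randn(4, 56, 56, 96)   # feature map reshaped to (B, H, W, C)
wins = window_partition(feat, 7)    # split into non-overlapping 7x7 windows
print(wins.shape)                   # torch.Size([256, 7, 7, 96]): 4 * 8 * 8 windows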
3. W-MSA (Window Multi-head Self-Attention)
- For each window, compute self-attention only among the tokens inside that window.
def forward(self, x, mask=None):
    """Compute (window) self-attention.
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # Q, K and V are produced by a single projection
    print(qkv.shape)  # (3, 256, 3, 49, 32): 3 matrices (Q, K, V), 256 windows, 3 heads, 49 tokens per window, 96/3 = 32 dims per head; after one downsampling: (3, 64, 6, 49, 32)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    print(q.shape)  # (256, 3, 49, 32)
    print(k.shape)  # (256, 3, 49, 32)
    print(v.shape)  # (256, 3, 49, 32)
    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))
    print(attn.shape)  # (256, 3, 49, 49): for each of the 3 heads, each of the 49 tokens attends to all 49 tokens
    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # a learned bias for every pair of the 49 positions: Wh*Ww, Wh*Ww, nH
    print(relative_position_bias.shape)  # (49, 49, 3): all 256 windows share the same 49 positions, so one (49, 49) table per head is enough, with 3 heads
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    print(relative_position_bias.shape)  # (3, 49, 49)
    attn = attn + relative_position_bias.unsqueeze(0)  # add the relative position bias to the attention scores
    print(attn.shape)  # (256, 3, 49, 49)
    if mask is not None:  # W-MSA does not use a mask; this branch only runs for SW-MSA
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
    else:
        attn = self.softmax(attn)
    attn = self.attn_drop(attn)
    print(attn.shape)  # (256, 3, 49, 49)
    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    print(x.shape)  # (256, 49, 96)
    x = self.proj(x)
    print(x.shape)  # (256, 49, 96)
    x = self.proj_drop(x)
    print(x.shape)  # (256, 49, 96)
    return x
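The relative_position_index and relative_position_bias_table used above are built once in the module constructor. A condensed sketch in the spirit of the official implementation (for a square window of side window_size; the indexing="ij" argument to torch.meshgrid is assumed available):

import torch

window_size = 7
# Each pair of positions inside a 7x7 window gets an index into a (2*7-1)*(2*7-1) x num_heads bias table.
coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size - 1                                # shift offsets so they start at 0
relative_coords[:, :, 1] += window_size - 1
relative_coords[:, :, 0] *= 2 * window_size - 1                            # flatten the 2-D offset into a single index
relative_position_index = relative_coords.sum(-1)                          # Wh*Ww, Wh*Ww, values in [0, 168]
print(relative_position_index.shape)  # torch.Size([49, 49])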
4. Restoring the layout with window_reverse
- Restore the windows to the same shape as the input feature map so that the next block can proceed.
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    print(x.shape)  # (4, 8, 8, 7, 7, 96); after one downsampling: (4, 4, 4, 7, 7, 192)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    print(x.shape)  # (4, 56, 56, 96); after one downsampling the 64 windows ((28/7)*(28/7)*4 = 64) are merged back into (4, 28, 28, 192)
    return x
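As a quick check that window_partition and window_reverse are exact inverses of each other (a standalone sketch using the two functions above):

import torch

feat = torch.randn(4, 56, 56, 96)
wins = window_partition(feat, 7)            # (256, 7, 7, 96)
restored = window_reverse(wins, 7, 56, 56)  # back to (4, 56, 56, 96)
print(torch.equal(feat, restored))          # True: partitioning followed by reversing is lossless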
5. SW-MSA (Shifted Window Multi-head Self-Attention)
- The windows above only attend within themselves and never to each other, so the model would stay confined to its own local region. SW-MSA (shifted windows) fixes this.
- The code differs from W-MSA in three places:
(1) Shift the window
- The window is shifted, new windows are formed, and MSA is computed inside these new windows.
# cyclic shift
if self.shift_size > 0:  # for SW-MSA blocks shift_size > 0, so the cyclic shift is applied
    if not self.fused_window_process:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # apply the cyclic shift
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    else:
        x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
    shifted_x = x
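To see what the cyclic shift does, here is a tiny standalone example with torch.roll on a toy 8x8 grid (shift_size = 3, matching window_size // 2 = 7 // 2 = 3 in the default configuration):

import torch

grid = torch.arange(64).view(1, 8, 8)                      # a toy 8x8 "feature map", channel dim omitted
shifted = torch.roll(grid, shifts=(-3, -3), dims=(1, 2))   # rows/columns pushed off the top-left reappear at the bottom-right
print(shifted[0, :3, :3])    # the block that used to sit at rows/cols 3..5
print(shifted[0, -3:, -3:])  # wrapped-around content from the original top-left corner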
(2) Masking
- Originally self-attention has to be computed for only 4 windows; after the shift there would be 9 windows of unequal size. To keep the number of windows at 4 after the shift, and keep the number of patches per window the same, a mask is used: the cyclic shift stitches together regions that were not adjacent in the original image, and those regions should not attend to each other, so the mask blocks them.
The diagram is shown below:
if mask is not None:  # this branch runs for SW-MSA; W-MSA passes mask=None and skips it
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)
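The mask itself is precomputed once per resolution. A condensed sketch in the spirit of the official implementation (assuming H = W = 56, window_size = 7, shift_size = 3): label each region of the shifted image with an id, partition the labels into windows, and set pairs with different ids to -100 so the softmax zeroes them out.

import torch

H = W = 56
window_size, shift_size = 7, 3

# Label the 9 regions produced by the cyclic shift with ids 0..8.
img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)             # (64, 7, 7, 1)
mask_windows = mask_windows.view(-1, window_size * window_size)    # (64, 49)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # pairwise region-id differences
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)  # torch.Size([64, 49, 49]): one mask per window, added to the attention scores above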
(3) Reverse the shift
- After the attention has been computed, the feature map must be restored, i.e. the cyclic shift is rolled back.
# reverse cyclic shift
if self.shift_size > 0:
    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))  # undo the shift: roll by +3 instead of -3
    print(x.shape)
else:
    x = shifted_x
x = x.view(B, H * W, C)
print(x.shape)  # (4, 3136, 96)
# FFN with residual connections
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
6. PatchMerging
- Similar to pooling in a CNN, PatchMerging enlarges the receptive field: it merges 4 neighboring small patches into one larger patch, which is how the multi-scale features are obtained. As shown below:
- Concretely, the 4 neighboring patches are gathered by interleaved (stride-2) sampling along H and W and then concatenated, giving H/2, W/2, 4*C. The diagram is shown below:
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # take every other row and column starting at (0, 0): B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # concatenate along channels: B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        x = self.norm(x)
        x = self.reduction(x)
        return x
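A quick shape check (a standalone sketch matching the first downsampling step of the walkthrough):

import torch
import torch.nn as nn

merge = PatchMerging(input_resolution=(56, 56), dim=96)
tokens = torch.randn(4, 56 * 56, 96)   # (4, 3136, 96): the stage-1 token sequence
out = merge(tokens)
print(out.shape)                       # torch.Size([4, 784, 192]): half the resolution, double the channels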
7. Hierarchical computation (running the subsequent Blocks)
- After one downsampling (3136 -> 784 tokens, i.e. 56*56 -> 28*28), the same W-MSA and SW-MSA blocks are applied again; this is the per-stage flow shown in the overall architecture diagram.
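Putting the stages together, the token count and channel width evolve as follows for the configuration used in this walkthrough (a small sketch just to print the progression; the numbers follow from the 4x4 patch size, embedding dim 96, and one 2x PatchMerging between stages):

resolution, dim = 56, 96    # after the 4x4 patch embedding of a 224x224 input
for stage in range(1, 5):
    print(f"stage {stage}: {resolution}x{resolution} = {resolution * resolution} tokens, dim = {dim}")
    if stage < 4:           # PatchMerging between stages halves the resolution and doubles the channels
        resolution //= 2
        dim *= 2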
8. Output head
x = self.norm(x)  # B L C
print(x.shape)  # (4, 49, 768)
x = self.avgpool(x.transpose(1, 2))  # global average pooling over the 49 tokens: B C 1
print(x.shape)  # (4, 768, 1)
x = torch.flatten(x, 1)
print(x.shape)  # (4, 768)
return x

def forward(self, x):  # map the 768-dim feature to the 1000 classes
    x = self.forward_features(x)
    x = self.head(x)
    return x
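For completeness, the pooling and classification head used above are standard layers; a minimal standalone sketch of the equivalent computation, assuming self.avgpool is an nn.AdaptiveAvgPool1d(1) and self.head is an nn.Linear(768, 1000):

import torch
import torch.nn as nn

avgpool = nn.AdaptiveAvgPool1d(1)   # average over the 49 remaining tokens
head = nn.Linear(768, 1000)         # 768-dim pooled feature -> 1000 class logits

feats = torch.randn(4, 49, 768)     # stage-4 output: (B, L, C)
pooled = torch.flatten(avgpool(feats.transpose(1, 2)), 1)  # (4, 768)
logits = head(pooled)
print(logits.shape)                 # torch.Size([4, 1000])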