Swin Transformer from Scratch

1. How do we generate a patch_embedding from an image?

Method 1:
- Split the image into patches with PyTorch's unfold API. This mimics how a convolution slides over its input: setting kernel_size = stride = patch_size makes the windows tile the image without overlap.
- The result is a tensor of shape [bs, num_patch, patch_depth].
- Multiply this tensor by a weight matrix of shape [patch_depth, model_dim_C] to obtain the patch_embedding of shape [bs, num_patch, model_dim_C].
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""Difficulty 1: patch_embedding"""

def image2emb_naive(image, patch_size, weight):
    """Straightforward patch embedding via unfold."""
    # image shape: bs * channel * h * w
    patch = F.unfold(image, kernel_size=(patch_size, patch_size),
                     stride=(patch_size, patch_size)).transpose(-1, -2)  # [bs, num_patch, patch_depth]
    patch_embedding = patch @ weight  # [bs, num_patch, model_dim_C]
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    """Patch embedding via 2D convolution; the embedding dimension is the conv's output-channel count."""
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs * oc * oh * ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh * ow)).transpose(-1, -2)  # [bs, num_patch, oc]
    return patch_embedding
```
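
Since the conv kernel is just the unfold weight reshaped to [model_dim_C, channel, patch_size, patch_size], the two methods produce identical embeddings; the conv route simply fuses the unfold and the matrix multiply into one op. A minimal check of that equivalence (the sizes below are illustrative choices of mine, not values from the text):

```python
bs, ic, h, w = 1, 3, 8, 8       # illustrative image size
patch_size, model_dim_C = 4, 8  # illustrative patch size and embedding dim
patch_depth = ic * patch_size * patch_size  # 48

image = torch.randn(bs, ic, h, w)
weight = torch.randn(patch_depth, model_dim_C)
# view the [patch_depth, model_dim_C] weight as a conv kernel [oc, ic, kh, kw]
kernel = weight.transpose(0, 1).reshape(model_dim_C, ic, patch_size, patch_size)

emb1 = image2emb_naive(image, patch_size, weight)
emb2 = image2emb_conv(image, kernel, stride=patch_size)
print(emb1.shape, emb2.shape)                 # both torch.Size([1, 4, 8])
print(torch.allclose(emb1, emb2, atol=1e-5))  # True: unfold + matmul == conv2d
```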
2. How do we build MHSA and compute its complexity?

- Apply three projections to the input x to obtain Q, K, and V.
  - Complexity of this step: $3LC^2$, where L is the sequence length and C is the feature size.
- Split Q, K, and V into multiple heads. The heads are computed independently of one another, so the head dimension can be folded into the bs dimension and treated the same way.
- Compute $QK^T$ and apply the optional mask, i.e. set the energy between invalid position pairs to negative infinity. The mask is needed in shifted-window MHSA, but not in window MHSA.
  - Complexity of this step: $L^2C$.
- Multiply the attention probabilities by V.
  - Complexity of this step: $L^2C$.
- Project the output once more.
  - Complexity of this step: $LC^2$.
- Total complexity: $4LC^2 + 2L^2C$ (a numeric check and the full module follow below).
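
To make these counts concrete, here is a quick numeric check under sizes I picked for illustration (a 7×7 window, so L = 49, with C = 96):

```python
L, C = 49, 96  # illustrative: a 7x7 window flattened to 49 tokens, feature size 96

qkv_proj   = 3 * L * C * C  # project x to Q, K, V
attn_score = L * L * C      # Q @ K^T
attn_value = L * L * C      # probabilities @ V
out_proj   = L * C * C      # final projection

total = qkv_proj + attn_score + attn_value + out_proj
assert total == 4 * L * C**2 + 2 * L**2 * C
print(total)  # 2267328 multiply-accumulates per window
```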
```python
"""Difficulty 2: multi-head self-attention"""

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, model_dim, num_head):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_head = num_head
        self.proj_linear_layer = nn.Linear(model_dim, 3 * model_dim)
        self.final_linear_layer = nn.Linear(model_dim, model_dim)

    def forward(self, input, additive_mask=None):
        bs, seqlen, model_dim = input.shape
        num_head = self.num_head
        head_dim = model_dim // num_head
        q, k, v = self.proj_linear_layer(input).chunk(3, dim=-1)  # one projection -> Q, K, V
        # fold the heads into the batch dimension: [bs*num_head, seqlen, head_dim]
        q, k, v = (t.reshape(bs, seqlen, num_head, head_dim).transpose(1, 2)
                    .reshape(bs * num_head, seqlen, head_dim) for t in (q, k, v))
        energy = torch.bmm(q, k.transpose(-1, -2)) / math.sqrt(head_dim)  # [bs*num_head, seqlen, seqlen]
        if additive_mask is not None:
            energy = energy + additive_mask.tile((num_head, 1, 1))  # -inf entries disable invalid pairs
        attn_prob = F.softmax(energy, dim=-1)
        output = torch.bmm(attn_prob, v)  # [bs*num_head, seqlen, head_dim]
        output = output.reshape(bs, num_head, seqlen, head_dim).transpose(1, 2).reshape(bs, seqlen, model_dim)
        return attn_prob, self.final_linear_layer(output)
```
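
A quick smoke test of the module (the sizes and mask pattern are my own illustration): the additive mask is written per batch element, with 0 for valid pairs and -inf for invalid ones, and forward repeats it across heads via tile.

```python
bs, seqlen, model_dim, num_head = 2, 4, 8, 2
mhsa = MultiHeadSelfAttention(model_dim, num_head)
x = torch.randn(bs, seqlen, model_dim)

# block-diagonal mask: the first two and last two positions
# may only attend within their own group
mask = torch.zeros(bs, seqlen, seqlen)
mask[:, :2, 2:] = float("-inf")
mask[:, 2:, :2] = float("-inf")

attn_prob, out = mhsa(x, additive_mask=mask)
print(attn_prob.shape)  # torch.Size([4, 4, 4]) -> [bs*num_head, seqlen, seqlen]
print(out.shape)        # torch.Size([2, 4, 8]) -> [bs, seqlen, model_dim]
```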