手写Swin Transformer
1.如何基于图片生成patch_embedding?
方法一:
- 基于pytorch unfold的API来将图片进行分块,也就是模仿卷积的思路,设置kernel_size=stride=patch_size,得到分块后的图片
- 得到格式为[bs, num_patch, patch_depth]的张量
- 将张量与形状为[patch_depth, model_dim_C]的权重矩阵进行乘法操作,即可得到形状为[bs, num_patch,model_dim_C]的patch_embedding
import torch
import torch.nn as nn
import torch.nn.functional as F
"""难点一: patch_embedding"""
def image2emb_naive(image, patch_size, weight):
"""直观方法去实现patch embedding"""
# image shape: bs * channel * h * w
patch = F.unfold(image, kernel_size=(patch_size,patch_size),
stride=(patch_size,patch_size)).transpose(-1,-2)
patch_embedding = patch @ weight
return patch_embedding
def image2emb_conv(image, kernel, stride):
"""基于二维卷积来实现 patch_embedding, embedding的维度就是卷积的输出通道数"""
conv_output = F.conv2d(image, kernel, stride=stride) # bs*oc*oh*ow
bs,oc,oh,ow = conv_output.shape
patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
return patch_embedding
class MultiHeadSelfAttention(nn.Module):
def __init__(self, model_dim, num_head):
super(MultiHeadSelfAttention, self).__init__()
self.num_head = num_head
self.proj_linear_layer = nn.Linear(model_dim, 3*model_dim)
self.final_linear_layer = nn.Linear(model_dim, model_dim)
def forward(self, input, additive_mask=None):
bs, seqlen, model_dim = input.shape
num_head = self.num_head
head_dim = model_dim // num_head
2.如何基于图片生成patch_embedding?
- 基于输入x进行三个映射分别得到Q,K,V
- 此步复杂度为$3LC^2$,其中L为序列长度,C为特征大小 - 将Q,K,V拆分成多头的形式,注意这里的多头各自计算不影响,所以可以与bs维度进行统一看待
- 计算$QK^T$,并考虑可能的掩码,即让无效的两两位置之间的能量为负无穷,掩码是在shift window MHSA中会需要,而在window MHSA中暂不需要
- 此步复杂度为$L^2C$ - 计算概率值与V的乘积
- 此步复杂度为$L^2C$ - 对输出进行再次映射
- 总体复杂度为$4LC2+2L2C$
全栈爱好者,欢迎交流学习
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· 葡萄城 AI 搜索升级:DeepSeek 加持,客户体验更智能
· 什么是nginx的强缓存和协商缓存
· 一文读懂知识蒸馏