手写Swin Transformer

1.如何基于图片生成patch_embedding?
方法一:
- 基于pytorch unfold的API来将图片进行分块,也就是模仿卷积的思路,设置kernel_size=stride=patch_size,得到分块后的图片
- 得到格式为[bs, num_patch, patch_depth]的张量
- 将张量与形状为[patch_depth, model_dim_C]的权重矩阵进行乘法操作,即可得到形状为[bs, num_patch,model_dim_C]的patch_embedding

import torch
import torch.nn as nn
import torch.nn.functional as F

"""难点一: patch_embedding"""
def image2emb_naive(image, patch_size, weight):
    """直观方法去实现patch embedding"""
    # image shape: bs * channel * h * w
    patch = F.unfold(image, kernel_size=(patch_size,patch_size),
                     stride=(patch_size,patch_size)).transpose(-1,-2)
    patch_embedding = patch @ weight
    return patch_embedding


def image2emb_conv(image, kernel, stride):
    """基于二维卷积来实现 patch_embedding, embedding的维度就是卷积的输出通道数"""
    conv_output = F.conv2d(image, kernel, stride=stride) # bs*oc*oh*ow
    bs,oc,oh,ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    return patch_embedding

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, model_dim, num_head):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_head = num_head

        self.proj_linear_layer = nn.Linear(model_dim, 3*model_dim)
        self.final_linear_layer = nn.Linear(model_dim, model_dim)

    def forward(self, input, additive_mask=None):
        bs, seqlen, model_dim = input.shape
        num_head = self.num_head
        head_dim = model_dim // num_head

2.如何基于图片生成patch_embedding?

  1. 基于输入x进行三个映射分别得到Q,K,V
    - 此步复杂度为$3LC^2$,其中L为序列长度,C为特征大小
  2. 将Q,K,V拆分成多头的形式,注意这里的多头各自计算不影响,所以可以与bs维度进行统一看待
  3. 计算$QK^T$,并考虑可能的掩码,即让无效的两两位置之间的能量为负无穷,掩码是在shift window MHSA中会需要,而在window MHSA中暂不需要
    - 此步复杂度为$L^2C$
  4. 计算概率值与V的乘积
    - 此步复杂度为$L^2C$
  5. 对输出进行再次映射
  6. 总体复杂度为$4LC2+2L2C$
posted @   Trkly  阅读(11)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· 葡萄城 AI 搜索升级:DeepSeek 加持,客户体验更智能
· 什么是nginx的强缓存和协商缓存
· 一文读懂知识蒸馏
点击右上角即可分享
微信分享提示