Multi-Scale and Detail-Enhanced Segment Anything (1): LMSA (Lightweight Multi-Scale Adapter)

Adaptive average pooling produces feature maps at several preset sizes (each configured value is the spatial size of the map after pooling). Each pooled map is then upsampled back to the original resolution, and all of them are concatenated with the original feature map along the channel (depth) dimension.
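A minimal standalone sketch of this pool, upsample, and concat flow (the tensor sizes and the single bin value here are illustrative, not the module's actual configuration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 16, 16)  # original feature map, B C H W
pooled = F.adaptive_avg_pool2d(x, 3)  # pooled down to the preset 3x3 size
up = F.interpolate(pooled, x.shape[2:], mode='bilinear', align_corners=True)  # back to 16x16
fused = torch.cat([x, up], dim=1)  # stack along the channel (depth) dimension
print(fused.shape)  # torch.Size([1, 128, 16, 16])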


import torch.nn as nn
import torch
import torch.nn.functional as F
class ModifyPPM(nn.Module):
    def __init__(self, in_dim, reduction_dim, bins):
        super(ModifyPPM, self).__init__()
        self.features = []
        for bin in bins:
            # one branch per bin: pool to bin x bin, reduce channels with a 1x1 conv,
            # then a depthwise 3x3 conv (no padding, so the map shrinks by 2 before upsampling)
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1),
                nn.GELU(),
                nn.Conv2d(reduction_dim, reduction_dim, kernel_size=3, bias=False, groups=reduction_dim),
                nn.GELU()
            ))
        self.features = nn.ModuleList(self.features)
        # depthwise 3x3 conv applied to the input at its original resolution
        self.local_conv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1, bias=False, groups=in_dim),
            nn.GELU(),
        )

    def forward(self, x):
        x_size = x.size()  # e.g. (1, 64, 16, 16)
        out = [self.local_conv(x)]  # original-resolution branch, shape unchanged: (1, 64, 16, 16)
        for f in self.features:
            y = f(x)  # feature map at this scale; adaptive average pooling fixes the pooled size to the given bin
            y = F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)  # upsample back to the input's H and W
            out.append(y)  # collect the feature map for this scale
        # Concatenate all scales along the channel dimension. The result holds the original-resolution
        # branch plus the four pyramid branches; each pyramid branch has a quarter of the input depth,
        # so together they match it and the output depth is twice the input depth.
        return torch.cat(out, 1)
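
# Quick standalone shape check for ModifyPPM (illustrative values): with in_dim=64,
# reduction_dim=16 and four bins, the output depth is 64 + 4*16 = 128, i.e. 2x the input:
#   mppm = ModifyPPM(in_dim=64, reduction_dim=16, bins=[3, 6, 9, 12])
#   mppm(torch.randn(1, 64, 16, 16)).shape  # torch.Size([1, 128, 16, 16])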

class LMSA(nn.Module):
    def __init__(self, in_dim, hidden_dim, patch_num):
        super().__init__()
        self.down_project = nn.Linear(in_dim, hidden_dim)  # adapter bottleneck: project tokens down to hidden_dim
        self.act = nn.GELU()
        self.mppm = ModifyPPM(hidden_dim, hidden_dim // 4, [3, 6, 9, 12])
        self.patch_num = patch_num  # stored but not used in forward
        self.up_project = nn.Linear(hidden_dim, in_dim)  # project back up for the residual addition
        self.down_conv = nn.Sequential(nn.Conv2d(hidden_dim * 2, hidden_dim, 1),  # fuse the 2*hidden_dim MPPM output
                                       nn.GELU())

    def forward(self, x):
        down_x = self.down_project(x)  # channel dim: in_dim -> hidden_dim (e.g. 128 -> 64)
        down_x = self.act(down_x)

        down_x = down_x.permute(0, 3, 1, 2).contiguous()  # B H W C -> B C H W for the convolutions
        down_x = self.mppm(down_x).contiguous()
        # This part feels a bit roundabout: since MPPM returns 2*hidden_dim channels, a single
        # 1x1 conv straight to the target width followed by the residual addition would also do.
        down_x = self.down_conv(down_x)  # the MPPM output has doubled depth, so reduce it back to hidden_dim
        down_x = down_x.permute(0, 2, 3, 1).contiguous()  # move channels back to the last dimension

        up_x = self.up_project(down_x)  # back up to in_dim so it matches x for the element-wise addition
        return x + up_x
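
# Where this would sit in practice (a sketch, not the paper's exact wiring): LMSA keeps the
# B H W C token layout and returns the input shape unchanged, so it can be dropped into a
# ViT image-encoder block as an extra residual step, e.g. (hypothetical `block` attributes):
#   x = x + block.attn(block.norm1(x))
#   x = lmsa(x)  # adapter injects multi-scale detail
#   x = x + block.mlp(block.norm2(x))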

if __name__ == '__main__':
    in_dim = 128
    hidden_dim = 64
    patch_num = 16

    # fall back to CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = LMSA(in_dim, hidden_dim, patch_num).to(device)
    x = torch.randn(1, patch_num, patch_num, in_dim).to(device)  # input layout is B H W C, not B C H W
    output = block(x)

    print(x.size())
    print(output.size())
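    # Expected output: the residual connection preserves the input shape, so both lines
    # print torch.Size([1, 16, 16, 128]).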
