sMLP

paper

import torch.nn as nn
import torch
class sMLPBlock(nn.Module):
    '''
    Sparse MLP (sMLP) block.

    Rather than sending all features of a sample through one fully-connected
    layer, each position is mixed along only one spatial axis at a time:
    one branch mixes values along the width axis, another along the height
    axis. The two branch outputs are concatenated with the normalized input
    and fused back to ``channels`` channels by a 1x1 convolution.

    Args:
        W: width of the input feature map (size of the last axis of x).
        H: height of the input feature map (size of axis 2 of x).
        channels: number of input and output channels C.

    Layout: input (B, C, H, W) -> output (B, C, H, W).

    The per-axis mixing uses 1x1 Conv2d layers on permuted tensors; an
    equivalent formulation applies nn.Linear to the last axis instead.
    '''

    def __init__(self, W, H, channels):
        super().__init__()
        self.channels = channels
        self.activation = nn.GELU()
        self.BN = nn.BatchNorm2d(channels)
        # NOTE: the attribute names are historical and swapped relative to
        # what the layers do: ``proj_h`` mixes along the W axis and
        # ``proh_w`` mixes along the H axis. The names (and creation order)
        # are kept so existing square (W == H) checkpoints and seeded
        # initialisation are unchanged. Each layer is now sized by the axis
        # it actually mixes, which lifts the old W == H restriction.
        self.proj_h = nn.Conv2d(W, W, (1, 1))
        self.proh_w = nn.Conv2d(H, H, (1, 1))
        # Fuses the concatenation [x, x_h, x_w] (3 * channels) back to channels.
        self.fuse = nn.Conv2d(channels * 3, channels, (1, 1), (1, 1), bias=False)

    def forward(self, x):
        '''
        Apply BN + GELU, mix along W and H separately, then fuse.

        Args:
            x: tensor of shape (B, channels, H, W).

        Returns:
            Tensor of shape (B, channels, H, W).
        '''
        x = self.activation(self.BN(x))
        # (B, C, H, W) -> (B, W, C, H): W becomes the conv channel axis, so
        # the 1x1 conv mixes values along the width; then permute back.
        x_w = self.proj_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # (B, C, H, W) -> (B, H, C, W): H becomes the conv channel axis, so
        # the 1x1 conv mixes values along the height; then permute back.
        x_h = self.proh_w(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        return self.fuse(torch.cat([x, x_h, x_w], dim=1))

if __name__ == '__main__':
    # Use the GPU when present so the demo also runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, 3, 2, 2, device=device)  # input layout: B C H W
    model = sMLPBlock(2, 2, 3).to(device)
    res = model(x)
    print(res.shape)

posted @ 2024-11-12 12:37  iceeci  阅读(1)  评论(0编辑  收藏  举报