import torch
import torch.nn as nn
class sMLPBlock(nn.Module):
    '''
    Sparse MLP (sMLP): instead of sending all spatial features of a sample
    through one fully connected layer, only the tokens along a single axis
    (height or width) are mixed by each fully connected projection.
    '''
    def __init__(self, W, H, channels):
        super().__init__()
        assert W == H  # both axis projections share one size, so the feature map must be square
        self.channels = channels
        self.activation = nn.GELU()
        self.BN = nn.BatchNorm2d(channels)
        self.proj_h = nn.Conv2d(H, H, (1, 1))  # mixes tokens along the height axis
        self.proj_w = nn.Conv2d(W, W, (1, 1))  # mixes tokens along the width axis
        self.fuse = nn.Conv2d(channels * 3, channels, (1, 1), (1, 1), bias=False)
        # Equivalent formulation with nn.Linear:
        # self.proj_h2 = nn.Linear(H, H)
        # self.proj_w2 = nn.Linear(W, W)
        # self.fuse2 = nn.Linear(channels * 3, channels)
    def forward(self, x):
        x = self.activation(self.BN(x))
        # width mixing: move W into the channel slot, apply the 1x1 conv, move back
        x_w = self.proj_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        # height mixing: move H into the channel slot, apply the 1x1 conv, move back
        x_h = self.proj_h(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
        # fuse the identity, height-mixed, and width-mixed branches channel-wise
        out = self.fuse(torch.cat([x, x_h, x_w], dim=1))
        # Variant 2, using the nn.Linear layers above:
        # x_h2 = self.proj_h2(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        # x_w2 = self.proj_w2(x)
        # x_fuse = torch.cat([x_h2, x_w2, x], dim=1)
        # out = self.fuse2(x_fuse.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return out
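
# The comments above claim the 1x1 Conv2d and nn.Linear formulations are
# interchangeable. Below is a minimal sketch verifying that claim; the helper
# name and shapes are illustrative, not from the original code. A 1x1 Conv2d
# applied with an axis moved into the channel slot computes the same affine
# map as an nn.Linear along that axis once the weights are copied across.
def _check_conv_linear_equivalence(H=4):
    conv = nn.Conv2d(H, H, (1, 1))
    linear = nn.Linear(H, H)
    with torch.no_grad():
        linear.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))  # (H, H, 1, 1) -> (H, H)
        linear.bias.copy_(conv.bias)
    x = torch.randn(2, 3, H, H)  # B C H W
    # conv path: move H into the channel slot, project, move back
    y_conv = conv(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
    # linear path: move H into the last slot, project, move back
    y_lin = linear(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
    print(torch.allclose(y_conv, y_lin, atol=1e-6))  # expected: True
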
if __name__ == '__main__':
    # fall back to CPU so the demo also runs without a GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 2, 2).to(device)  # input B C H W
    model = sMLPBlock(2, 2, 3).to(device)
    res = model(x)
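    # Sanity checks: the block preserves the input layout, and the Linear
    # variant matches the Conv2d variant (see the sketch above).
    print(res.shape)  # expected: torch.Size([1, 3, 2, 2])
    _check_conv_linear_equivalence()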