Multi-Scale and Detail-Enhanced Segment Anything (1): MEEM — 差分边缘增强模块 (differential edge enhancement module)
```python
import torch.nn as nn
import torch
class MEEM(nn.Module):
    """Multi-scale edge enhancement block.

    The input is projected to ``hidden_dim`` channels, and that projection is
    kept as the first branch. It is then repeatedly smoothed by a 3x3 average
    pool and re-projected ``width - 1`` times. Each derived map goes through
    its own ``EdgeEnhancer`` and is kept as a further branch. All ``width``
    branches are concatenated along the channel dimension and projected back
    to ``in_dim`` channels.

    Every layer is a 1x1 conv or a stride-1 / padding-1 3x3 pool, so the
    output has the same shape as the input.

    Args:
        in_dim: number of input (and output) channels.
        hidden_dim: channel count of every branch.
        width: total number of branches; must be at least 1.
        norm: normalization layer factory, called with a channel count.
        act: activation factory used after the output projection; it is
            also passed to each ``EdgeEnhancer``.

    Raises:
        ValueError: if ``width`` is less than 1.
    """

    def __init__(self, in_dim, hidden_dim, width=4, norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        if width < 1:
            # out_conv would otherwise get hidden_dim * width <= 0 input channels.
            raise ValueError(f"width must be >= 1, got {width}")
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.width = width
        # 1x1 projection in_dim -> hidden_dim; its output is the first branch.
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_dim, hidden_dim, 1, bias=False),
            norm(hidden_dim),
            nn.Sigmoid(),
        )
        # 3x3 mean filter; stride 1 + padding 1 keep H and W unchanged.
        self.pool = nn.AvgPool2d(3, stride=1, padding=1)
        self.mid_conv = nn.ModuleList()
        self.edge_enhance = nn.ModuleList()
        # One (projection, edge enhancer) pair per extra branch.
        for _ in range(width - 1):
            self.mid_conv.append(nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 1, bias=False),
                norm(hidden_dim),
                nn.Sigmoid(),
            ))
            self.edge_enhance.append(EdgeEnhancer(hidden_dim, norm, act))
        # Fuse all `width` branches back to in_dim channels.
        self.out_conv = nn.Sequential(
            nn.Conv2d(hidden_dim * width, in_dim, 1, bias=False),
            norm(in_dim),
            act(),
        )

    def forward(self, x):
        """Apply multi-scale edge enhancement.

        Args:
            x: feature map of shape (B, in_dim, H, W).

        Returns:
            Tensor of shape (B, in_dim, H, W).
        """
        # Branch 0 is the reduced projection itself. The loop adds the
        # remaining width - 1 branches, each derived from the previous map by
        # another pool + 1x1 projection and then edge-enhanced.
        mid = self.in_conv(x)
        branches = [mid]
        for conv, enhance in zip(self.mid_conv, self.edge_enhance):
            mid = conv(self.pool(mid))
            branches.append(enhance(mid))
        # Concatenate once rather than re-copying a growing tensor per iteration.
        return self.out_conv(torch.cat(branches, dim=1))
class EdgeEnhancer(nn.Module):
    """Detail (edge) enhancement via a high-pass residual.

    A 3x3 average pool produces a smoothed copy of the input, which weakens
    fine detail. Subtracting it from the input isolates the detail component.
    That component goes through a 1x1 conv + norm + sigmoid and is added back
    onto the input, strengthening its detail. Channel count and spatial size
    are unchanged.

    Args:
        in_dim: number of input (and output) channels.
        norm: normalization layer factory, called with a channel count.
        act: accepted (MEEM passes it in) but not used by this module.
    """

    def __init__(self, in_dim, norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1, bias=False),
            norm(in_dim),
            nn.Sigmoid(),
        )
        # Stride 1 + padding 1 keep H and W unchanged. With the default
        # count_include_pad=True, border averages include the zero padding.
        self.pool = nn.AvgPool2d(3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` plus its sigmoid-gated detail component.

        The added term is a sigmoid output, so it lies in (0, 1) element-wise.
        """
        # High-pass: input minus its local 3x3 mean.
        edge = x - self.pool(x)
        return x + self.out_conv(edge)
if __name__ == '__main__':
    # Smoke test: build the block and check that the output shape matches the input.
    in_dim = 128
    width = 4
    hidden_dim = in_dim // width
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = MEEM(in_dim, hidden_dim, width).to(device)
    x = torch.randn(3, in_dim, 64, 64, device=device)  # input layout: B C H W
    output = block(x)
    print(x.size())
    print(output.size())