Multi-Scale and Detail-Enhanced Segment Anything (#1): MEEM — 差分边缘增强模块 (differential edge-enhancement module)

`
import torch.nn as nn
import torch
class MEEM(nn.Module):
    """Multi-branch edge-enhancement block.

    The input is first projected to ``hidden_dim`` channels by a 1x1 conv
    (+ ``norm`` + sigmoid). Starting from that projection, ``width - 1``
    further branches are built one after another: each branch average-pools
    the previous branch's pre-enhancement map (3x3, stride 1, padding 1, so
    H and W are preserved), re-projects it with a 1x1 conv (+ ``norm`` +
    sigmoid), and passes the result through an ``EdgeEnhancer``. The initial
    projection and all enhanced branches are concatenated on the channel
    axis (``hidden_dim * width`` channels) and fused back to ``in_dim``
    channels by a final 1x1 conv (+ ``norm`` + ``act``).

    Args:
        in_dim: Number of input channels, and of output channels.
        hidden_dim: Number of channels in every branch.
        width: Total number of branches, including the initial projection.
            Must be at least 1.
        norm: Normalization layer factory, called with a channel count.
        act: Activation factory used after the output conv. The input and
            intermediate convs always use ``nn.Sigmoid``. It is also
            forwarded to each ``EdgeEnhancer`` (which, as written, ignores it).

    Raises:
        ValueError: If ``width`` is less than 1.
    """

    def __init__(self, in_dim, hidden_dim, width=4, norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        if width < 1:
            # width == 0 would build a zero-input-channel output conv.
            raise ValueError(f"width must be >= 1, got {width}")

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.width = width

        # Reduce channels: in_dim -> hidden_dim.
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_dim, hidden_dim, 1, bias=False),
            norm(hidden_dim),
            nn.Sigmoid(),
        )

        # Stride-1, padding-1 pooling keeps the spatial size unchanged.
        self.pool = nn.AvgPool2d(3, stride=1, padding=1)

        # One (projection, enhancer) pair per extra branch. The initial
        # in_conv output is the first branch, hence width - 1.
        self.mid_conv = nn.ModuleList()
        self.edge_enhance = nn.ModuleList()
        for _ in range(width - 1):
            self.mid_conv.append(nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, 1, bias=False),
                norm(hidden_dim),
                nn.Sigmoid(),
            ))
            self.edge_enhance.append(EdgeEnhancer(hidden_dim, norm, act))

        # Fuse all width branches back to in_dim channels.
        self.out_conv = nn.Sequential(
            nn.Conv2d(hidden_dim * width, in_dim, 1, bias=False),
            norm(in_dim),
            act(),
        )

    def forward(self, x):
        """Apply the multi-branch edge enhancement.

        Args:
            x: A 4-D (B, in_dim, H, W) tensor. Branches are concatenated on
                dim 1, so batched input is required.

        Returns:
            A (B, in_dim, H, W) tensor.
        """
        # Project down. The projection itself is the first branch.
        mid = self.in_conv(x)
        branches = [mid]

        # Each extra branch pools the previous *un-enhanced* map, so the
        # smoothing accumulates from branch to branch. Only the enhanced
        # version is kept for the concatenation.
        for conv, enhancer in zip(self.mid_conv, self.edge_enhance):
            mid = conv(self.pool(mid))
            branches.append(enhancer(mid))

        # A single concatenation avoids re-copying the growing tensor on
        # every iteration.
        return self.out_conv(torch.cat(branches, dim=1))

class EdgeEnhancer(nn.Module):
    """Add a processed high-frequency residual back onto the input.

    A 3x3 average pool (stride 1, padding 1) yields a smoothed copy of the
    input. Subtracting it from the input isolates local detail and edges.
    That residual goes through a 1x1 conv, ``norm`` and a sigmoid, and the
    result is added to the original input. The output has the same shape
    as the input.

    Args:
        in_dim: Number of input channels, and of output channels.
        norm: Normalization layer factory, called with ``in_dim``.
        act: Accepted for signature compatibility but not used. The residual
            branch always ends in ``nn.Sigmoid``.
    """

    def __init__(self, in_dim, norm=nn.BatchNorm2d, act=nn.GELU):
        super().__init__()
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1, bias=False),
            norm(in_dim),
            nn.Sigmoid(),
        )
        self.pool = nn.AvgPool2d(3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` plus a sigmoid-activated projection of its detail residual.

        Average pooling suppresses fine detail. Subtracting the pooled map
        from the input highlights that detail. Adding the processed detail
        back onto the input therefore enhances the input's detail.
        """
        edge = x - self.pool(x)  # high-frequency residual
        edge = self.out_conv(edge)
        return x + edge

if __name__ == '__main__':
    # Smoke test: run one forward pass and print the input and output shapes.
    in_dim = 128
    width = 4
    # With this choice the width concatenated branches total in_dim channels.
    hidden_dim = in_dim // width

    # Fall back to CPU when no GPU is available instead of crashing on .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    block = MEEM(in_dim, hidden_dim, width).to(device)
    x = torch.randn(3, in_dim, 64, 64, device=device)  # input layout: B C H W
    output = block(x)

    print(x.size())
    print(output.size())

posted @ 2024-11-08 17:11  iceeci  阅读(38)  评论(0)  编辑  收藏  举报