Segmentation Models: Architectures and Code


DeepLab Series

DeepLabV1

The authors observe that Deep Convolutional Neural Networks (DCNNs) handle image-level classification very well thanks to their strong translation invariance (spatial detail is highly abstracted away), but they struggle with pixel-level prediction tasks such as pose estimation and semantic segmentation, which require accurate localization.

1. Innovations:

  • Combine the DCNN with a fully connected CRF to improve segmentation accuracy.
  • Introduce the idea of atrous (dilated) convolution.
  • Experiment with multi-scale, multi-level feature fusion.

2. Motivation:

Applying DCNNs to semantic segmentation suffers from two drawbacks:

  • Repeated pooling and downsampling drastically reduce feature resolution, and the lost location information is hard to recover.
  • Obtaining object-centric decisions from a classifier requires invariance to spatial transformations, which sacrifices fine detail and inherently limits the spatial accuracy of a DCNN.

Classification is spatially invariant: affine transformations of an image do not change the class label, and such transformations (data augmentation) even increase the amount of training data and improve accuracy. Segmentation and detection, by contrast, are not spatially invariant.

3. Strategies:

  • Atrous (dilated) convolution (a minimal sketch follows this list).
  • Fully-connected Conditional Random Field (CRF).
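
As mentioned above, a minimal sketch of atrous (dilated) convolution in plain PyTorch; the tensor sizes are arbitrary and the layers are illustrative, not DeepLab's actual backbone:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 65, 65)  # dummy feature map

# Standard 3x3 convolution: 3x3 receptive field.
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)

# Atrous 3x3 convolution with rate 2: 5x5 receptive field,
# same parameter count, same output resolution (padding = dilation).
atrous = nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2)

print(conv(x).shape, atrous(x).shape)  # both torch.Size([1, 64, 65, 65])
print(sum(p.numel() for p in conv.parameters()) == sum(p.numel() for p in atrous.parameters()))  # True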

DeepLabV2

1. Innovations:

  • Atrous convolution as a powerful tool for dense prediction. It explicitly controls the resolution at which feature responses are computed inside the DCNN, and it effectively enlarges the filters' field of view to incorporate larger context without increasing the number of parameters or the amount of computation.
  • Atrous Spatial Pyramid Pooling (ASPP), which uses multi-scale information to obtain more precise segmentation. ASPP probes the incoming features with parallel atrous convolutions at multiple sampling rates, capturing objects and image context at several scales.
  • Combining the DCNN with a probabilistic graphical model (CRF) to refine segmentation boundaries. Max pooling and downsampling in a DCNN provide translation invariance at the cost of accuracy; coupling the final DCNN responses with a fully connected CRF mitigates this.

2. Motivation:

Semantic segmentation with DCNNs faces three challenges:

  • Successive downsampling and pooling leave the final feature map at a very low resolution.
  • Objects appear at multiple scales in an image (a new challenge raised relative to V1).
  • Spatial invariance causes loss of fine detail.

3. Strategies:

  • Remove some pooling operations and use atrous convolution instead.
  • Fuse multi-scale information with atrous convolutions at different dilation rates, i.e. Atrous Spatial Pyramid Pooling (ASPP), the new contribution (a simplified sketch follows this list).
  • Fully connected CRF.
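
A stripped-down sketch of the parallel multi-rate idea behind ASPP (the full ASPP implementation used by segmentation_models_pytorch appears in the PyTorch code section further below); channel widths and rates here are arbitrary:

import torch
import torch.nn as nn

class TinyASPP(nn.Module):
    # Parallel 3x3 convolutions with different dilation rates, concatenated and projected.
    # A simplified illustration, not the paper's exact module (no BN, no image-level branch).
    def __init__(self, in_ch, out_ch, rates=(6, 12, 18)):
        super().__init__()
        self.branches = nn.ModuleList(
            [nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r) for r in rates]
        )
        self.project = nn.Conv2d(out_ch * len(rates), out_ch, 1)

    def forward(self, x):
        return self.project(torch.cat([b(x) for b in self.branches], dim=1))

y = TinyASPP(256, 64)(torch.randn(1, 256, 33, 33))
print(y.shape)  # torch.Size([1, 64, 33, 33])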

DeepLabV3

Chen, L., Papandreou, G., Schroff, F., & Adam, H. (2017). Rethinking Atrous Convolution for Semantic Image Segmentation. ArXiv, abs/1706.05587.

1. Innovations:

  • Adds modules for segmenting objects at multiple scales.
  • Designs cascaded and parallel atrous convolution modules that use several atrous rates (sampling rates) to capture multi-scale context.

2. Motivation:

Semantic segmentation with DCNNs faces two challenges:

  • Successive downsampling and repeated pooling leave the final feature map at a low resolution.
  • Objects appear at multiple scales in an image.

3. Strategies:

  • Use atrous convolution to avoid an overly low feature resolution.
  • Cascade atrous convolutions with different dilation rates, or run them in parallel (the ASPP of V2), to gather more context.

4. Main Contributions:

  • Revisits atrous convolution, which, in both the cascaded-module and spatial-pyramid-pooling frameworks, enlarges the receptive field and captures multi-scale information.
  • Improves the ASPP module, now composed of atrous convolutions at different rates plus batch normalization, laid out either in cascade or in parallel.
  • Discusses an important problem: a 3×3 atrous convolution with a very large rate cannot capture long-range information because of image boundary effects and degenerates toward a 1×1 convolution; the authors therefore fuse image-level features into the ASPP module (see the sketch after this list).
  • Details the training setup and shares practical experience; the proposed "DeepLabv3" improves on previous work and achieves strong results.
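
A minimal sketch of that image-level feature branch (global average pooling, a 1×1 convolution, then bilinear upsampling back to the feature resolution); it mirrors the ASPPPooling class in the code further below, with an illustrative channel count:

import torch
import torch.nn as nn
import torch.nn.functional as F

def image_level_features(x, project):
    # Global context branch: pool to 1x1, project channels, upsample back to input size.
    size = x.shape[-2:]
    pooled = F.adaptive_avg_pool2d(x, 1)   # (N, C, 1, 1)
    pooled = project(pooled)                # (N, C_out, 1, 1)
    return F.interpolate(pooled, size=size, mode="bilinear", align_corners=False)

x = torch.randn(2, 2048, 33, 33)
print(image_level_features(x, nn.Conv2d(2048, 256, 1)).shape)  # torch.Size([2, 256, 33, 33])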

DeepLabV3 in Detail

5. Problem Statement

Semantic segmentation faces two major challenges:

  • First challenge: reduced feature resolution caused by consecutive pooling operations or convolution strides. This lets the DCNN learn increasingly abstract representations, but the resulting invariance hurts dense prediction, where detailed spatial information matters. Atrous convolution is advocated to overcome this.
  • Second challenge: objects at multiple scales. Several approaches have been proposed; the paper mainly considers four families, as shown in Figure 1:

  • Type 1: Image pyramid. Rescale the input to several resolutions, run the DCNN on each, and fuse the predictions into the final output (see the sketch after this list).
  • Type 2: Encoder-decoder. Reuse the encoder's multi-scale features in the decoder to recover spatial resolution.
  • Type 3: Stack extra modules on top of the original network to capture long-range dependencies between pixels, e.g. a dense CRF or additional convolution layers.
  • Type 4: Spatial pyramid pooling. Apply convolutions with several sampling rates and fields of view to capture objects at multiple scales.
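
For the first family, a hedged sketch of image-pyramid inference; `model` stands for any segmentation network that returns per-pixel logits, and the scale set is only an example:

import torch
import torch.nn.functional as F

@torch.no_grad()
def pyramid_inference(model, image, scales=(0.5, 1.0, 1.5)):
    # Run the model on rescaled copies of the input and average the logits at full resolution.
    h, w = image.shape[-2:]
    fused = 0
    for s in scales:
        scaled = F.interpolate(image, scale_factor=s, mode="bilinear", align_corners=False)
        logits = model(scaled)
        fused = fused + F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
    return fused / len(scales)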

DeepLabV3+

Chen, L., Zhu, Y., Papandreou, G., Schroff, F., & Adam, H. (2018). Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. ArXiv, abs/1802.02611.

1. Innovations:

  • A deeper Xception backbone, except that the entry flow is left unchanged for faster computation and memory efficiency.
  • All max-pooling layers are replaced with stride-2 depthwise separable convolutions.
  • Every 3×3 depthwise convolution is followed by BatchNorm and ReLU.
  • The modified Xception serves as the encoder backbone, replacing the ResNet-101 used in the original DeepLabV3.

2. Motivation:

Semantic segmentation mainly faces two problems:

  • Objects at multiple scales (addressed by DeepLabV3).
  • Repeated downsampling in the DCNN shrinks the feature maps, reducing prediction accuracy and losing boundary information (the target of DeepLabV3+).

3. Strategies:

  • Improve Xception by adding more layers.
  • Replace all max-pooling layers with strided depthwise separable convolutions (a minimal sketch follows this list).
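
A minimal sketch of a stride-2 depthwise separable convolution standing in for a max-pooling layer; the SeparableConv2d class in the code below factors the same two steps (channel widths here are arbitrary):

import torch
import torch.nn as nn

class StridedSeparableConv(nn.Module):
    # Depthwise 3x3 conv with stride 2 followed by a pointwise 1x1 conv.
    # Halves the spatial resolution like MaxPool2d(2), but with learnable filters.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1, groups=in_ch, bias=False)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

x = torch.randn(1, 128, 64, 64)
print(StridedSeparableConv(128, 256)(x).shape)  # torch.Size([1, 256, 32, 32])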

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepLabV3Decoder(nn.Sequential):
    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
        super().__init__(
            ASPP(in_channels, out_channels, atrous_rates),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        self.out_channels = out_channels

    def forward(self, *features):
        return super().forward(features[-1])


class DeepLabV3PlusDecoder(nn.Module):
    def __init__(self, encoder_channels, out_channels=256, atrous_rates=(12, 24, 36), output_stride=16):
        super().__init__()
        if output_stride not in {8, 16}:
            raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))

        self.out_channels = out_channels
        self.output_stride = output_stride

        self.aspp = nn.Sequential(
            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
            SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        scale_factor = 2 if output_stride == 8 else 4
        self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)

        highres_in_channels = encoder_channels[-4]
        highres_out_channels = 48  # proposed by authors of paper
        self.block1 = nn.Sequential(
            nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(highres_out_channels),
            nn.ReLU(),
        )
        self.block2 = nn.Sequential(
            SeparableConv2d(highres_out_channels + out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, *features):
        aspp_features = self.aspp(features[-1])
        aspp_features = self.up(aspp_features)
        high_res_features = self.block1(features[-4])
        concat_features = torch.cat([aspp_features, high_res_features], dim=1)
        fused_features = self.block2(concat_features)
        return fused_features


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPSeparableConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            SeparableConv2d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        # Pool to 1x1, project, then upsample back to the input resolution.
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        super(ASPP, self).__init__()
        modules = []
        # 1x1 convolution branch
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        )

        # three atrous branches with different rates, plus an image-level pooling branch
        rate1, rate2, rate3 = tuple(atrous_rates)
        ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
        modules.append(ASPPConvModule(in_channels, out_channels, rate1))
        modules.append(ASPPConvModule(in_channels, out_channels, rate2))
        modules.append(ASPPConvModule(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)


class SeparableConv2d(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        depthwise_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=False,
        )
        pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        super().__init__(depthwise_conv, pointwise_conv)
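
The decoder code above matches segmentation_models_pytorch; assuming that is the source, these decoders are normally driven through the library's top-level model classes rather than instantiated directly. A rough usage sketch (encoder name, weights, and class count are placeholders):

import torch
import segmentation_models_pytorch as smp

# DeepLabV3+ with a ResNet-101 encoder; `classes` is the number of output channels (masks).
model = smp.DeepLabV3Plus(encoder_name="resnet101", encoder_weights="imagenet", classes=21)

with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))
print(out.shape)  # expected: torch.Size([1, 21, 512, 512])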

FPN

Model Structure

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv3x3GNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class FPNBlock(nn.Module):
    def __init__(self, pyramid_channels, skip_channels):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        skip = self.skip_conv(skip)
        x = x + skip
        return x


class SegmentationBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()
        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)


class MergeBlock(nn.Module):
    def __init__(self, policy):
        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy))
        self.policy = policy

    def forward(self, x):
        if self.policy == 'add':
            return sum(x)
        elif self.policy == 'cat':
            return torch.cat(x, dim=1)
        else:
            raise ValueError("`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy))


class FPNDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        encoder_depth=5,
        pyramid_channels=256,
        segmentation_channels=128,
        dropout=0.2,
        merge_policy="add",
    ):
        super().__init__()

        self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
        if encoder_depth < 3:
            raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth))

        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[:encoder_depth + 1]

        self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
        self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
        self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
        self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])

        self.seg_blocks = nn.ModuleList([
            SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
            for n_upsamples in [3, 2, 1, 0]
        ])

        self.merge = MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, *features):
        c2, c3, c4, c5 = features[-4:]

        p5 = self.p5(c5)
        p4 = self.p4(p5, c4)
        p3 = self.p3(p4, c3)
        p2 = self.p2(p3, c2)

        feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])]
        x = self.merge(feature_pyramid)
        x = self.dropout(x)
        return x

LinkNet

Chaurasia, A., & Culurciello, E. (2017). LinkNet: Exploiting encoder representations for efficient semantic segmentation. 2017 IEEE Visual Communications and Image Processing (VCIP), 1-4.

Model Structure

PyTorch Code

import torch.nn as nn

from ..base import modules


class TransposeX2(nn.Sequential):
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        ]
        if use_batchnorm:
            layers.insert(1, nn.BatchNorm2d(out_channels))
        super().__init__(*layers)


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        self.block = nn.Sequential(
            modules.Conv2dReLU(in_channels, in_channels // 4, kernel_size=1, use_batchnorm=use_batchnorm),
            TransposeX2(in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm),
            modules.Conv2dReLU(in_channels // 4, out_channels, kernel_size=1, use_batchnorm=use_batchnorm),
        )

    def forward(self, x, skip=None):
        x = self.block(x)
        if skip is not None:
            x = x + skip
        return x


class LinknetDecoder(nn.Module):
    def __init__(self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True):
        super().__init__()

        encoder_channels = encoder_channels[1:]    # remove first skip
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        channels = list(encoder_channels) + [prefinal_channels]

        self.blocks = nn.ModuleList([
            DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
            for i in range(n_blocks)
        ])

    def forward(self, *features):
        features = features[1:]    # remove first skip
        features = features[::-1]  # reverse channels to start from head of encoder

        x = features[0]
        skips = features[1:]

        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x

MAnet

Fan, T., Wang, G., Li, Y., & Wang, H. (2020). MA-Net: A Multi-Scale Attention Network for Liver and Tumor Segmentation. IEEE Access, 8, 179656-179665.

Model Structure

PAB Block

MFAB Block

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md


class PAB(nn.Module):
    def __init__(self, in_channels, out_channels, pab_channels=64):
        super(PAB, self).__init__()
        # Series of 1x1 conv to generate attention feature maps
        self.pab_channels = pab_channels
        self.in_channels = in_channels
        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.map_softmax = nn.Softmax(dim=1)
        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        bsize = x.size()[0]
        h = x.size()[2]
        w = x.size()[3]
        x_top = self.top_conv(x)
        x_center = self.center_conv(x)
        x_bottom = self.bottom_conv(x)

        x_top = x_top.flatten(2)
        x_center = x_center.flatten(2).transpose(1, 2)
        x_bottom = x_bottom.flatten(2).transpose(1, 2)

        # spatial attention map over all pixel pairs (h*w x h*w)
        sp_map = torch.matmul(x_center, x_top)
        sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w, h * w)
        sp_map = torch.matmul(sp_map, x_bottom)
        sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
        x = x + sp_map
        x = self.out_conv(x)
        return x


class MFAB(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
        # MFAB is just a modified version of SE-blocks, one for skip, one for input
        super(MFAB, self).__init__()
        self.hl_conv = nn.Sequential(
            md.Conv2dReLU(in_channels, in_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
            md.Conv2dReLU(in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm),
        )
        reduced_channels = max(1, skip_channels // reduction)
        self.SE_ll = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(skip_channels, reduced_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, skip_channels, 1),
            nn.Sigmoid(),
        )
        self.SE_hl = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(skip_channels, reduced_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, skip_channels, 1),
            nn.Sigmoid(),
        )
        self.conv1 = md.Conv2dReLU(
            skip_channels + skip_channels,  # we transform C-prime from the high-level path to C from the skip connection
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

    def forward(self, x, skip=None):
        x = self.hl_conv(x)
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        attention_hl = self.SE_hl(x)
        if skip is not None:
            attention_ll = self.SE_ll(skip)
            attention_hl = attention_hl + attention_ll
            x = x * attention_hl
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm
        )
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class MAnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        reduction=16,
        use_batchnorm=True,
        pab_channels=64,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]    # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
        blocks = [
            MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
            if skip_ch > 0
            else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        # the last block has no skip connection -> use a simple decoder block
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):
        features = features[1:]    # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x

PAN

Li, H., Xiong, P., An, J., & Wang, L. (2018). Pyramid Attention Network for Semantic Segmentation. ArXiv, abs/1805.10180.

Model Structure

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBnRelu(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        add_relu: bool = True,
        interpolate: bool = False,
    ):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups,
        )
        self.add_relu = add_relu
        self.interpolate = interpolate
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.add_relu:
            x = self.activation(x)
        if self.interpolate:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return x


class FPABlock(nn.Module):
    def __init__(self, in_channels, out_channels, upscale_mode='bilinear'):
        super(FPABlock, self).__init__()

        self.upscale_mode = upscale_mode
        if self.upscale_mode == 'bilinear':
            self.align_corners = True
        else:
            self.align_corners = False

        # global pooling branch
        self.branch1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0),
        )

        # middle branch
        self.mid = nn.Sequential(
            ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        )
        self.down1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=in_channels, out_channels=1, kernel_size=7, stride=1, padding=3),
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2),
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
        )
        self.conv2 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2)
        self.conv1 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        b1 = self.branch1(x)
        upscale_parameters = dict(mode=self.upscale_mode, align_corners=self.align_corners)
        b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)

        mid = self.mid(x)
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)

        x2 = self.conv2(x2)
        x = x2 + x3
        x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)

        x1 = self.conv1(x1)
        x = x + x1
        x = F.interpolate(x, size=(h, w), **upscale_parameters)

        x = torch.mul(x, mid)
        x = x + b1
        return x


class GAUBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, upscale_mode: str = 'bilinear'):
        super(GAUBlock, self).__init__()

        self.upscale_mode = upscale_mode
        self.align_corners = True if upscale_mode == 'bilinear' else None

        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(in_channels=out_channels, out_channels=out_channels, kernel_size=1, add_relu=False),
            nn.Sigmoid(),
        )
        self.conv2 = ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

    def forward(self, x, y):
        """
        Args:
            x: low level feature
            y: high level feature
        """
        h, w = x.size(2), x.size(3)
        y_up = F.interpolate(y, size=(h, w), mode=self.upscale_mode, align_corners=self.align_corners)
        x = self.conv2(x)
        y = self.conv1(y)
        z = torch.mul(x, y)
        return y_up + z


class PANDecoder(nn.Module):
    def __init__(self, encoder_channels, decoder_channels, upscale_mode: str = 'bilinear'):
        super().__init__()

        self.fpa = FPABlock(in_channels=encoder_channels[-1], out_channels=decoder_channels)
        self.gau3 = GAUBlock(in_channels=encoder_channels[-2], out_channels=decoder_channels, upscale_mode=upscale_mode)
        self.gau2 = GAUBlock(in_channels=encoder_channels[-3], out_channels=decoder_channels, upscale_mode=upscale_mode)
        self.gau1 = GAUBlock(in_channels=encoder_channels[-4], out_channels=decoder_channels, upscale_mode=upscale_mode)

    def forward(self, *features):
        bottleneck = features[-1]
        x5 = self.fpa(bottleneck)         # 1/32
        x4 = self.gau3(features[-2], x5)  # 1/16
        x3 = self.gau2(features[-3], x4)  # 1/8
        x2 = self.gau1(features[-4], x3)  # 1/4
        return x2

PSPNet

Model Structure

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules


class PSPBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pool_size, use_batchnorm=True):
        super().__init__()
        if pool_size == 1:
            use_batchnorm = False  # PyTorch does not support BatchNorm for 1x1 shape
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)),
            modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=use_batchnorm),
        )

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        x = self.pool(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        return x


class PSPModule(nn.Module):
    def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_batchnorm=True):
        super().__init__()
        self.blocks = nn.ModuleList([
            PSPBlock(in_channels, in_channels // len(sizes), size, use_batchnorm=use_batchnorm)
            for size in sizes
        ])

    def forward(self, x):
        # concatenate all pooled branches with the original feature map
        xs = [block(x) for block in self.blocks] + [x]
        x = torch.cat(xs, dim=1)
        return x


class PSPDecoder(nn.Module):
    def __init__(self, encoder_channels, use_batchnorm=True, out_channels=512, dropout=0.2):
        super().__init__()

        self.psp = PSPModule(
            in_channels=encoder_channels[-1],
            sizes=(1, 2, 3, 6),
            use_batchnorm=use_batchnorm,
        )
        self.conv = modules.Conv2dReLU(
            in_channels=encoder_channels[-1] * 2,
            out_channels=out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, *features):
        x = features[-1]
        x = self.psp(x)
        x = self.conv(x)
        x = self.dropout(x)
        return x
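
A short sanity check of the channel bookkeeping in the PSPModule above (2048 is just an example bottleneck width): each of the four pool sizes contributes in_channels // 4 channels, and the input itself is concatenated, which is why PSPDecoder's 1x1 conv expects encoder_channels[-1] * 2 input channels.

in_channels = 2048
sizes = (1, 2, 3, 6)
branch_channels = in_channels // len(sizes)                    # 512 per pooled branch
concat_channels = branch_channels * len(sizes) + in_channels   # 4 branches + identity
assert concat_channels == in_channels * 2                      # 4096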

U-Net

Model Structure

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, attention_type=None):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm
        )
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class CenterBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = md.Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        conv2 = md.Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        super().__init__(conv1, conv2)


class UnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]    # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):
        features = features[1:]    # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x

UNet++

Model Structure

PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import modules as md


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, attention_type=None):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm
        )
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class CenterBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = md.Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        conv2 = md.Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        super().__init__(conv1, conv2)


class UnetPlusPlusDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]    # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels
        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)

        blocks = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
        blocks[f'x_{0}_{len(self.in_channels)-1}'] = DecoderBlock(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, *features):
        features = features[1:]    # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        # start building dense connections
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx + 1])
                    dense_x[f'x_{depth_idx}_{depth_idx}'] = output
                else:
                    dense_l_i = depth_idx + layer_idx
                    cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx + 1, dense_l_i + 1)]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    dense_x[f'x_{depth_idx}_{dense_l_i}'] = self.blocks[f'x_{depth_idx}_{dense_l_i}'](
                        dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features
                    )
        dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
        return dense_x[f'x_{0}_{self.depth}']