Segmentation Model Architectures and Code
DeepLab Series
DeepLabV1
The authors observe that Deep Convolutional Neural Networks (DCNNs) handle image-level classification very well thanks to their strong translation invariance (spatial detail is highly abstracted away), but they struggle with pixel-level prediction tasks such as pose estimation and semantic segmentation, which require precise positional information.
1. Innovations:
- Combine the DCNN with a fully connected CRF to improve segmentation accuracy.
- Introduce the idea of atrous (dilated) convolution.
- Explore multi-scale, multi-level information fusion.
2. Motivation:
Applying DCNNs to semantic segmentation suffers from two drawbacks:
- Repeatedly stacked pooling and downsampling operations drastically reduce resolution, and the lost positional information is hard to recover.
- Obtaining object-centric decisions from a classifier requires invariance to spatial transformations, at the cost of attention to fine detail, which fundamentally limits the spatial accuracy of DCNN models.
Classification itself is spatially invariant: affine transformations of an image do not change the final label, and such transformations (data augmentation) are in fact used to enlarge the training set and improve accuracy. Segmentation and detection, however, are not spatially invariant.
3. Strategies:
- Atrous convolution (see the sketch below)
- Fully connected Conditional Random Field (CRF)
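To make the atrous convolution idea concrete, here is a minimal sketch (illustrative only, not DeepLabV1's actual layer configuration): a dilated 3×3 convolution keeps the same output resolution and the same parameter count as a standard 3×3 convolution while covering a larger field of view.

import torch
import torch.nn as nn

# Illustrative comparison: same kernel size and parameter count, larger field of view.
x = torch.randn(1, 64, 65, 65)

standard = nn.Conv2d(64, 64, kernel_size=3, padding=1)            # 3x3 receptive field
atrous = nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2)  # 5x5 effective receptive field

print(standard(x).shape, atrous(x).shape)              # both: torch.Size([1, 64, 65, 65])
print(sum(p.numel() for p in standard.parameters()))   # 36928
print(sum(p.numel() for p in atrous.parameters()))     # 36928 -- no extra parameters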
DeepLabV2
1. Innovations:
- Atrous convolution as a powerful tool for dense prediction. It explicitly controls the resolution at which feature responses are computed inside a DCNN, and it effectively enlarges the field of view of filters to incorporate larger context without increasing the number of parameters or the amount of computation.
- Atrous Spatial Pyramid Pooling (ASPP), which uses multi-scale information to obtain more precise segmentation. ASPP probes the feature map with parallel atrous convolutions at multiple sampling rates, capturing objects and image context at several scales.
- Combining the DCNN with a probabilistic graphical model (CRF) to refine segmentation boundaries. Max pooling and downsampling in a DCNN achieve translation invariance at a cost to localization accuracy; coupling the responses of the final DCNN layer with a fully connected CRF overcomes this.
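The claim that atrous convolution enlarges the field of view without adding parameters can be quantified: a k×k kernel applied with rate r has an effective kernel size of k + (k − 1)(r − 1). For example, a 3×3 filter with rate 12 (one of the default rates in the code further below) covers a 25×25 region of the feature map while still holding only 9 weights per input–output channel pair.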
2. Motivation:
Semantic segmentation with DCNNs faces three challenges:
- Repeated downsampling and pooling leave the final feature map at low resolution.
- Objects appear in the image at multiple scales (a new challenge compared with V1).
- Spatial invariance leads to the loss of detail.
3. Strategies:
- Remove some of the pooling operations and use atrous convolution instead.
- Fuse multi-scale information with atrous convolutions at different dilation rates: atrous spatial pyramid pooling (ASPP), the new contribution.
- Fully connected CRF.
DeepLabV3
Chen, L., Papandreou, G., Schroff, F., & Adam, H. (2017). Rethinking Atrous Convolution for Semantic Image Segmentation. ArXiv, abs/1706.05587.
1. Innovations:
- Adds modules for segmenting objects at multiple scales.
- Designs cascaded and parallel atrous convolution modules that use several atrous rates (sampling rates) to capture multi-scale context.
2. Motivation:
Semantic segmentation with DCNNs faces two challenges:
- Consecutive downsampling and repeated pooling leave the final feature map at low resolution.
- Objects appear in the image at multiple scales.
3. Strategies:
- Use atrous convolution to keep the feature resolution from dropping too low.
- Cascade atrous convolutions with different dilation rates, or run them in parallel (V2's ASPP), to gather more context.
4. Main contributions:
- Revisits the use of atrous convolution, which allows larger receptive fields, and hence multi-scale information, both in cascaded modules and in the spatial pyramid pooling framework.
- Improves the ASPP module, now built from atrous convolutions at different sampling rates together with batch normalization; the modules are laid out either in cascade or in parallel (a minimal cascaded sketch follows this list).
- Discusses an important issue: a 3×3 atrous convolution with a very large rate cannot capture long-range information because of image boundary effects and degenerates into a 1×1 convolution; the authors therefore fuse image-level features into the ASPP module.
- Describes the training details and shares practical experience; the proposed "DeepLabv3" improves on previous work and obtains strong results.
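The parallel (ASPP) layout appears in the PyTorch code later in this document; the cascaded layout does not, so here is a minimal sketch of the cascaded idea under assumed channel counts and rates (it is not the paper's exact block or Multi-Grid configuration): stacked atrous convolutions with growing rates keep the feature resolution while the receptive field grows with depth.

import torch
import torch.nn as nn

# Minimal cascaded atrous module: rates and channel count are assumptions for illustration.
def atrous_block(channels, rate):
    return nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size=3, padding=rate, dilation=rate, bias=False),
        nn.BatchNorm2d(channels),
        nn.ReLU(inplace=True),
    )

cascade = nn.Sequential(
    atrous_block(256, rate=2),
    atrous_block(256, rate=4),
    atrous_block(256, rate=8),
)

x = torch.randn(2, 256, 33, 33)
print(cascade(x).shape)  # torch.Size([2, 256, 33, 33]) -- resolution is preserved throughout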
DeepLabV3 in Detail
5. Problem statement
Semantic segmentation faces two major challenges:
- First challenge: reduced feature resolution caused by consecutive pooling operations or strided convolutions. This is what lets a DCNN learn increasingly abstract representations, but the resulting invariance can hinder dense prediction, where detailed spatial information matters. Atrous convolution is advocated to overcome this.
- Second challenge: the existence of objects at multiple scales. Several methods have been proposed to handle this; the paper mainly considers four families of approaches (Figure 1 of the paper), listed below. A minimal sketch of the first family follows this list.
- Image Pyramid: the input image is rescaled to several sizes, each scale is run through the DCNN, and the per-scale predictions are fused into the final output.
- Encoder-Decoder: multi-scale features from the encoder stage are reused in the decoder stage to recover spatial resolution.
- Extra modules stacked on top of the original network to capture long-range dependencies between pixels, e.g. a Dense CRF, or additional convolutional layers.
- Spatial Pyramid Pooling: convolutions with several sampling rates and fields of view capture objects at multiple scales.
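Of the four families above, only the spatial-pyramid idea is implemented in the code sections of this document, so here is a minimal sketch of the first family, the image pyramid. The `model` stand-in and the averaging fusion are assumptions for illustration; the papers typically fuse scales with a per-pixel max.

import torch
import torch.nn as nn
import torch.nn.functional as F

def multi_scale_predict(model, image, scales=(0.5, 0.75, 1.0)):
    # Run the same fully convolutional model on rescaled inputs and fuse the predictions.
    h, w = image.shape[-2:]
    fused = 0
    for s in scales:
        scaled = F.interpolate(image, scale_factor=s, mode='bilinear', align_corners=False)
        logits = model(scaled)
        fused = fused + F.interpolate(logits, size=(h, w), mode='bilinear', align_corners=False)
    return fused / len(scales)

# toy stand-in for a segmentation network (21 output classes)
model = nn.Conv2d(3, 21, kernel_size=3, padding=1)
out = multi_scale_predict(model, torch.randn(1, 3, 128, 128))
print(out.shape)  # torch.Size([1, 21, 128, 128])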
DeepLabV3+
Chen, L., Zhu, Y., Papandreou, G., Schroff, F., & Adam, H. (2018). Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. ArXiv, abs/1802.02611.
1. Innovations:
- A deeper Xception backbone; unlike the original, the entry flow network is left unchanged, for fast computation and efficient memory use.
- All max pooling operations are replaced by depthwise separable convolutions with stride 2.
- Every 3×3 depthwise convolution is followed by BN and ReLU.
- The modified Xception is used as the encoder backbone, replacing the ResNet-101 of the original DeepLabV3.
2. Motivation:
Semantic segmentation mainly faces two problems:
- Objects at multiple scales (addressed by DeepLabV3).
- Repeated downsampling in the DCNN shrinks the feature maps, which lowers prediction accuracy and loses boundary information (the problem DeepLabV3+ targets).
3. Strategies:
- Improve Xception by increasing its depth.
- Replace all max pooling layers with strided depthwise separable convolutions (see the sketch below).
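A minimal sketch of the substitution described above, under assumed channel counts (this is not the actual modified-Xception block): downsampling with a stride-2 depthwise separable convolution instead of max pooling, with BN and ReLU after the 3×3 depthwise convolution.

import torch
import torch.nn as nn

def separable_downsample(in_ch, out_ch):
    # stride-2 depthwise 3x3 followed by BN/ReLU, then a pointwise 1x1 projection
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1, groups=in_ch, bias=False),
        nn.BatchNorm2d(in_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

x = torch.randn(2, 128, 64, 64)
pooled = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x)  # original Xception downsampling
learned = separable_downsample(128, 256)(x)                   # DeepLabV3+-style replacement
print(pooled.shape)   # torch.Size([2, 128, 32, 32])
print(learned.shape)  # torch.Size([2, 256, 32, 32])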
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepLabV3Decoder(nn.Sequential):
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
super().__init__(
ASPP(in_channels, out_channels, atrous_rates),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
self.out_channels = out_channels
def forward(self, *features):
return super().forward(features[-1])
class DeepLabV3PlusDecoder(nn.Module):
def __init__(
self,
encoder_channels,
out_channels=256,
atrous_rates=(12, 24, 36),
output_stride=16,
):
super().__init__()
if output_stride not in {8, 16}:
raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))
self.out_channels = out_channels
self.output_stride = output_stride
self.aspp = nn.Sequential(
ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
scale_factor = 2 if output_stride == 8 else 4
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
highres_in_channels = encoder_channels[-4]
highres_out_channels = 48 # proposed by authors of paper
self.block1 = nn.Sequential(
nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(highres_out_channels),
nn.ReLU(),
)
self.block2 = nn.Sequential(
SeparableConv2d(
highres_out_channels + out_channels,
out_channels,
kernel_size=3,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, *features):
aspp_features = self.aspp(features[-1])
aspp_features = self.up(aspp_features)
high_res_features = self.block1(features[-4])
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
fused_features = self.block2(concat_features)
return fused_features
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
super().__init__(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
padding=dilation,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
class ASPPSeparableConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
super().__init__(
SeparableConv2d(
in_channels,
out_channels,
kernel_size=3,
padding=dilation,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, x):
size = x.shape[-2:]
for mod in self:
x = mod(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
super(ASPP, self).__init__()
modules = []
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
)
rate1, rate2, rate3 = tuple(atrous_rates)
ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
modules.append(ASPPConvModule(in_channels, out_channels, rate1))
modules.append(ASPPConvModule(in_channels, out_channels, rate2))
modules.append(ASPPConvModule(in_channels, out_channels, rate3))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5),
)
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
class SeparableConv2d(nn.Sequential):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=True,
):
        depthwise_conv = nn.Conv2d(
in_channels,
in_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=False,
)
pointwise_conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
bias=bias,
)
        super().__init__(depthwise_conv, pointwise_conv)
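A quick usage sketch for the DeepLabV3+ decoder above. The encoder, its channel tuple, and the feature-map sizes are assumptions for illustration (an output_stride=16 encoder whose deepest stages are dilated, so its last features stay at 1/16 resolution); they are not part of the code above.

import torch

# hypothetical encoder outputs for a 256x256 input (strides 1, 2, 4, 8, 16, 16)
features = [
    torch.randn(1, 3, 256, 256),
    torch.randn(1, 64, 128, 128),
    torch.randn(1, 256, 64, 64),    # features[-4]: high-resolution branch input
    torch.randn(1, 512, 32, 32),
    torch.randn(1, 1024, 16, 16),
    torch.randn(1, 2048, 16, 16),   # features[-1]: ASPP input
]

decoder = DeepLabV3PlusDecoder(
    encoder_channels=(3, 64, 256, 512, 1024, 2048),
    out_channels=256,
    atrous_rates=(12, 24, 36),
    output_stride=16,
)
decoder.eval()  # the global-pooling branch yields 1x1 maps, so BatchNorm needs eval mode at batch size 1
out = decoder(*features)
print(out.shape)  # torch.Size([1, 256, 64, 64]) -- 1/4 of the input resolution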
FPN
Model Structure
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv3x3GNReLU(nn.Module):
def __init__(self, in_channels, out_channels, upsample=False):
super().__init__()
self.upsample = upsample
self.block = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
),
nn.GroupNorm(32, out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.block(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
return x
class FPNBlock(nn.Module):
def __init__(self, pyramid_channels, skip_channels):
super().__init__()
self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
skip = self.skip_conv(skip)
x = x + skip
return x
class SegmentationBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_upsamples=0):
super().__init__()
blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
if n_upsamples > 1:
for _ in range(1, n_upsamples):
blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
self.block = nn.Sequential(*blocks)
def forward(self, x):
return self.block(x)
class MergeBlock(nn.Module):
def __init__(self, policy):
super().__init__()
if policy not in ["add", "cat"]:
raise ValueError(
"`merge_policy` must be one of: ['add', 'cat'], got {}".format(
policy
)
)
self.policy = policy
def forward(self, x):
if self.policy == 'add':
return sum(x)
elif self.policy == 'cat':
return torch.cat(x, dim=1)
else:
raise ValueError(
"`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
)
class FPNDecoder(nn.Module):
def __init__(
self,
encoder_channels,
encoder_depth=5,
pyramid_channels=256,
segmentation_channels=128,
dropout=0.2,
merge_policy="add",
):
super().__init__()
self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4
if encoder_depth < 3:
raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth))
encoder_channels = encoder_channels[::-1]
encoder_channels = encoder_channels[:encoder_depth + 1]
self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
self.seg_blocks = nn.ModuleList([
SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
for n_upsamples in [3, 2, 1, 0]
])
self.merge = MergeBlock(merge_policy)
self.dropout = nn.Dropout2d(p=dropout, inplace=True)
def forward(self, *features):
c2, c3, c4, c5 = features[-4:]
p5 = self.p5(c5)
p4 = self.p4(p5, c4)
p3 = self.p3(p4, c3)
p2 = self.p2(p3, c2)
feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])]
x = self.merge(feature_pyramid)
x = self.dropout(x)
return x
LinkNet
Chaurasia, A., & Culurciello, E. (2017). LinkNet: Exploiting encoder representations for efficient semantic segmentation. 2017 IEEE Visual Communications and Image Processing (VCIP), 1-4.
Model Structure
PyTorch Code
import torch.nn as nn
from ..base import modules
class TransposeX2(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
super().__init__()
layers = [
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True)
]
if use_batchnorm:
layers.insert(1, nn.BatchNorm2d(out_channels))
super().__init__(*layers)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
super().__init__()
self.block = nn.Sequential(
modules.Conv2dReLU(in_channels, in_channels // 4, kernel_size=1, use_batchnorm=use_batchnorm),
TransposeX2(in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm),
modules.Conv2dReLU(in_channels // 4, out_channels, kernel_size=1, use_batchnorm=use_batchnorm),
)
def forward(self, x, skip=None):
x = self.block(x)
if skip is not None:
x = x + skip
return x
class LinknetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
prefinal_channels=32,
n_blocks=5,
use_batchnorm=True,
):
super().__init__()
encoder_channels = encoder_channels[1:] # remove first skip
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
channels = list(encoder_channels) + [prefinal_channels]
self.blocks = nn.ModuleList([
DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
for i in range(n_blocks)
])
def forward(self, *features):
features = features[1:] # remove first skip
features = features[::-1] # reverse channels to start from head of encoder
x = features[0]
skips = features[1:]
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)
return x
MAnet
Fan, T., Wang, G., Li, Y., & Wang, H. (2020). MA-Net: A Multi-Scale Attention Network for Liver and Tumor Segmentation. IEEE Access, 8, 179656-179665.
Model Structure
PAB Block
MFAB Block
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md
class PAB(nn.Module):
def __init__(self, in_channels, out_channels, pab_channels=64):
super(PAB, self).__init__()
# Series of 1x1 conv to generate attention feature maps
self.pab_channels = pab_channels
self.in_channels = in_channels
self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.map_softmax = nn.Softmax(dim=1)
self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x):
bsize = x.size()[0]
h = x.size()[2]
w = x.size()[3]
x_top = self.top_conv(x)
x_center = self.center_conv(x)
x_bottom = self.bottom_conv(x)
x_top = x_top.flatten(2)
x_center = x_center.flatten(2).transpose(1, 2)
x_bottom = x_bottom.flatten(2).transpose(1, 2)
sp_map = torch.matmul(x_center, x_top)
sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
sp_map = torch.matmul(sp_map, x_bottom)
sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
x = x + sp_map
x = self.out_conv(x)
return x
class MFAB(nn.Module):
def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
# MFAB is just a modified version of SE-blocks, one for skip, one for input
super(MFAB, self).__init__()
self.hl_conv = nn.Sequential(
md.Conv2dReLU(
in_channels,
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
),
md.Conv2dReLU(
in_channels,
skip_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
)
reduced_channels = max(1, skip_channels // reduction)
self.SE_ll = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, reduced_channels, 1),
nn.ReLU(inplace=True),
nn.Conv2d(reduced_channels, skip_channels, 1),
nn.Sigmoid(),
)
self.SE_hl = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, reduced_channels, 1),
nn.ReLU(inplace=True),
nn.Conv2d(reduced_channels, skip_channels, 1),
nn.Sigmoid(),
)
self.conv1 = md.Conv2dReLU(
            skip_channels + skip_channels,  # we transform C-prime from high level to C from skip connection
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
def forward(self, x, skip=None):
x = self.hl_conv(x)
x = F.interpolate(x, scale_factor=2, mode="nearest")
attention_hl = self.SE_hl(x)
if skip is not None:
attention_ll = self.SE_ll(skip)
attention_hl = attention_hl + attention_ll
x = x * attention_hl
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x
class MAnetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
reduction=16,
use_batchnorm=True,
pab_channels=64
):
super().__init__()
if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)
encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels
self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)
# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
blocks = [
MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
# for the last we dont have skip connection -> use simple decoder block
self.blocks = nn.ModuleList(blocks)
    def forward(self, *features):
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        head = features[0]
        skips = features[1:]
        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        return x
PAN
Li, H., Xiong, P., An, J., & Wang, L. (2018). Pyramid Attention Network for Semantic Segmentation. ArXiv, abs/1805.10180.
Model Structure
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBnRelu(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
add_relu: bool = True,
interpolate: bool = False
):
super(ConvBnRelu, self).__init__()
self.conv = nn.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups
)
self.add_relu = add_relu
self.interpolate = interpolate
self.bn = nn.BatchNorm2d(out_channels)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.add_relu:
x = self.activation(x)
if self.interpolate:
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
return x
class FPABlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
upscale_mode='bilinear'
):
super(FPABlock, self).__init__()
self.upscale_mode = upscale_mode
if self.upscale_mode == 'bilinear':
self.align_corners = True
else:
self.align_corners = False
# global pooling branch
self.branch1 = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
)
        # middle branch
self.mid = nn.Sequential(
ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
)
self.down1 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
ConvBnRelu(in_channels=in_channels, out_channels=1, kernel_size=7, stride=1, padding=3)
)
self.down2 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2)
)
self.down3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
)
self.conv2 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2)
self.conv1 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3)
def forward(self, x):
h, w = x.size(2), x.size(3)
b1 = self.branch1(x)
upscale_parameters = dict(
mode=self.upscale_mode,
align_corners=self.align_corners
)
b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)
mid = self.mid(x)
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)
x2 = self.conv2(x2)
x = x2 + x3
x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)
x1 = self.conv1(x1)
x = x + x1
x = F.interpolate(x, size=(h, w), **upscale_parameters)
x = torch.mul(x, mid)
x = x + b1
return x
class GAUBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
upscale_mode: str = 'bilinear'
):
super(GAUBlock, self).__init__()
self.upscale_mode = upscale_mode
self.align_corners = True if upscale_mode == 'bilinear' else None
self.conv1 = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
ConvBnRelu(in_channels=out_channels, out_channels=out_channels, kernel_size=1, add_relu=False),
nn.Sigmoid()
)
self.conv2 = ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
def forward(self, x, y):
"""
Args:
x: low level feature
y: high level feature
"""
h, w = x.size(2), x.size(3)
y_up = F.interpolate(
y, size=(h, w), mode=self.upscale_mode, align_corners=self.align_corners
)
x = self.conv2(x)
y = self.conv1(y)
z = torch.mul(x, y)
return y_up + z
class PANDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
upscale_mode: str = 'bilinear'
):
super().__init__()
self.fpa = FPABlock(in_channels=encoder_channels[-1], out_channels=decoder_channels)
self.gau3 = GAUBlock(in_channels=encoder_channels[-2], out_channels=decoder_channels, upscale_mode=upscale_mode)
self.gau2 = GAUBlock(in_channels=encoder_channels[-3], out_channels=decoder_channels, upscale_mode=upscale_mode)
self.gau1 = GAUBlock(in_channels=encoder_channels[-4], out_channels=decoder_channels, upscale_mode=upscale_mode)
def forward(self, *features):
bottleneck = features[-1]
x5 = self.fpa(bottleneck) # 1/32
x4 = self.gau3(features[-2], x5) # 1/16
x3 = self.gau2(features[-3], x4) # 1/8
x2 = self.gau1(features[-4], x3) # 1/4
return x2
PSPNet
Model Structure
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules
class PSPBlock(nn.Module):
    def __init__(self, in_channels, out_channels, pool_size, use_batchnorm=True):
        super().__init__()
        if pool_size == 1:
            use_batchnorm = False  # PyTorch does not support BatchNorm for 1x1 shape
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)),
            modules.Conv2dReLU(in_channels, out_channels, (1, 1), use_batchnorm=use_batchnorm)
)
def forward(self, x):
h, w = x.size(2), x.size(3)
x = self.pool(x)
x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
return x
class PSPModule(nn.Module):
    def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_batchnorm=True):
        super().__init__()
        self.blocks = nn.ModuleList([
            PSPBlock(in_channels, in_channels // len(sizes), size, use_batchnorm=use_batchnorm) for size in sizes
])
def forward(self, x):
xs = [block(x) for block in self.blocks] + [x]
x = torch.cat(xs, dim=1)
return x
class PSPDecoder(nn.Module):
def __init__(
self,
encoder_channels,
use_batchnorm=True,
out_channels=512,
dropout=0.2,
):
super().__init__()
self.psp = PSPModule(
in_channels=encoder_channels[-1],
sizes=(1, 2, 3, 6),
            use_batchnorm=use_batchnorm,
)
self.conv = modules.Conv2dReLU(
in_channels=encoder_channels[-1] * 2,
out_channels=out_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
self.dropout = nn.Dropout2d(p=dropout)
def forward(self, *features):
x = features[-1]
x = self.psp(x)
x = self.conv(x)
x = self.dropout(x)
return x
U-Net
Model Structure
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
super().__init__(conv1, conv2)
class UnetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
):
super().__init__()
if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)
encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels
if center:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()
# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
blocks = [
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
self.blocks = nn.ModuleList(blocks)
def forward(self, *features):
features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder
head = features[0]
skips = features[1:]
x = self.center(head)
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)
return x
UNet++
Model Structure
PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
super().__init__(conv1, conv2)
class UnetPlusPlusDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
):
super().__init__()
if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)
encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder
# computing blocks input and output channels
head_channels = encoder_channels[0]
self.in_channels = [head_channels] + list(decoder_channels[:-1])
self.skip_channels = list(encoder_channels[1:]) + [0]
self.out_channels = decoder_channels
if center:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()
# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
blocks = {}
for layer_idx in range(len(self.in_channels) - 1):
for depth_idx in range(layer_idx+1):
if depth_idx == 0:
in_ch = self.in_channels[layer_idx]
skip_ch = self.skip_channels[layer_idx] * (layer_idx+1)
out_ch = self.out_channels[layer_idx]
else:
out_ch = self.skip_channels[layer_idx]
skip_ch = self.skip_channels[layer_idx] * (layer_idx+1-depth_idx)
in_ch = self.skip_channels[layer_idx - 1]
blocks[f'x_{depth_idx}_{layer_idx}'] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
blocks[f'x_{0}_{len(self.in_channels)-1}'] =\
DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1], **kwargs)
self.blocks = nn.ModuleDict(blocks)
self.depth = len(self.in_channels) - 1
def forward(self, *features):
features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder
# start building dense connections
dense_x = {}
for layer_idx in range(len(self.in_channels)-1):
for depth_idx in range(self.depth-layer_idx):
if layer_idx == 0:
output = self.blocks[f'x_{depth_idx}_{depth_idx}'](features[depth_idx], features[depth_idx+1])
dense_x[f'x_{depth_idx}_{depth_idx}'] = output
else:
dense_l_i = depth_idx + layer_idx
cat_features = [dense_x[f'x_{idx}_{dense_l_i}'] for idx in range(depth_idx+1, dense_l_i+1)]
cat_features = torch.cat(cat_features + [features[dense_l_i+1]], dim=1)
dense_x[f'x_{depth_idx}_{dense_l_i}'] =\
self.blocks[f'x_{depth_idx}_{dense_l_i}'](dense_x[f'x_{depth_idx}_{dense_l_i-1}'], cat_features)
dense_x[f'x_{0}_{self.depth}'] = self.blocks[f'x_{0}_{self.depth}'](dense_x[f'x_{0}_{self.depth-1}'])
return dense_x[f'x_{0}_{self.depth}']