【NN 基础模型】:U-Net及其各种变体

最早被提出应用于医学图像分割,后扩展至通用分割,后面在 low-level 领域也发挥着巨大的作用。

U-Net

友情链接:U-Net: Convolutional Networks for Biomedical Image Segmentation)

网络原理

U-Net 网络结构

U-Net 的网络结构其实很简单,类似于传统图像处理中的金字塔结构。对输入进行多次的 conv+relu 特征提取,然后进行 maxpooling 下采样,扩大感受野的同时减小特征图尺寸,循环多次后得到上图中最下面一层的特征图,然后进行 upsample+conv+concat,再对上采样之后的特征图进行 conv+relu 操作,和前面一样重复多次,便得到了最后的结果。

如今的 U-Net 不光是在分割领域,在笔者所从事的 low-level 视觉中也得到了广泛的应用,比如超分、降噪等,他们的一个共同特征是,输入和输出往往是相同尺寸的,image2image 的任务。

pytorch 代码

import torch
import torch.nn as nn


def double_conv_relu(n_in, n_out):
    block = nn.Sequential(
        nn.Conv2d(n_in, n_out, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(n_out, n_out, 3, 1, 1),
        nn.ReLU(),
    )

    return block


class UNet(nn.Module):
    def __init__(self) -> None:
        super(UNet, self).__init__()

        self.conv1 = double_conv_relu(1, 64)
        self.down1 = nn.MaxPool2d(2, 2)

        self.conv2 = double_conv_relu(64, 128)
        self.down2 = nn.MaxPool2d(2, 2)

        self.conv3 = double_conv_relu(128, 256)
        self.down3 = nn.MaxPool2d(2, 2)

        self.conv4 = double_conv_relu(256, 512)
        self.down4 = nn.MaxPool2d(2, 2)

        self.conv5 = double_conv_relu(512, 1024)

        self.up1 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)
        self.conv6 = double_conv_relu(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.conv7 = double_conv_relu(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv8 = double_conv_relu(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv9 = double_conv_relu(128, 64)

        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(self.down1(feat1))
        feat3 = self.conv3(self.down2(feat2))
        feat4 = self.conv4(self.down3(feat3))
        feat5 = self.conv5(self.down4(feat4))
        feat6 = self.conv6(torch.cat((feat4, self.up1(feat5)), dim=1))
        feat7 = self.conv7(torch.cat((feat3, self.up2(feat6)), dim=1))
        feat8 = self.conv8(torch.cat((feat2, self.up3(feat7)), dim=1))
        feat9 = self.conv9(torch.cat((feat1, self.up4(feat8)), dim=1))
        out = self.conv_last(feat9)

        return out


if __name__ == "__main__":
    x = torch.rand(1, 1, 256, 256)

    net = UNet(3, 64, 3)

    y = net(x)
    print(y.shape)

自己写的,比较粗糙,主要看一下网络结构。因为 U-Net 提出是在 2015 年,很早了,后续在应用的时候会把原本代码中的 conv+relu 的结构换成残差块的结构,可以取得更好的效果。

UNet++

友情链接:UNet++: A Nested U-Net Architecture for Medical Image Segmentation

nnUNet

友情链接:nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation

posted @   rangoXTY  阅读(253)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示