# 笨方法实现unet (U-Net implemented the naive way: every layer is written out explicitly)

import logging

# Configure the root logger at import time: INFO level, each record prefixed
# with a timestamp, the source file name, the line number and the level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s-%(filename)s[line:%(lineno)d]-%(levelname)s:%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
import torch
import torch.nn as nn
import torch.nn.functional as F


class iUnet(nn.Module):
    """U-Net style encoder/decoder for dense (per-pixel) prediction.

    Every layer is spelled out explicitly instead of being generated in a
    loop. The encoder has five stages of two ``3x3 conv -> BatchNorm -> ReLU``
    layers (64, 128, 256, 512 and 1024 channels) with a 2x2 max-pool between
    consecutive stages. Each of the four decoder stages upsamples bilinearly
    by a factor of 2, halves the channel count with a 3x3 conv, concatenates
    the encoder feature map of the matching stage along the channel axis and
    applies two more ``conv -> BatchNorm -> ReLU`` layers. A final 1x1 conv
    maps the 64 remaining channels to ``num_classes`` channels.

    The output has the same height and width as the input. Height and width
    do not have to be multiples of 16: when an upsampled map and its skip
    connection differ in size, the upsampled map is resized to the skip map.
    Each side must still be at least 16 pixels so that the four poolings do
    not reduce a dimension to zero.

    Args:
        num_classes: Number of output channels (one score map per class).
        in_channels: Number of channels of the input image. Defaults to 3.
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        # Encoder stage 1: in_channels -> 64 channels at full resolution.
        self.conv1_1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.relu1_1 = nn.ReLU(inplace=True)

        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(inplace=True)

        self.pool1 = nn.MaxPool2d(2)  # e.g. N 64 512 512 -> N 64 256 256

        # Encoder stage 2: 64 -> 128 channels at 1/2 resolution.
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(128)
        self.relu2_1 = nn.ReLU(inplace=True)

        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.relu2_2 = nn.ReLU(inplace=True)

        self.pool2 = nn.MaxPool2d(2)  # e.g. N 128 256 256 -> N 128 128 128

        # Encoder stage 3: 128 -> 256 channels at 1/4 resolution.
        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.relu3_1 = nn.ReLU(inplace=True)

        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256)
        self.relu3_2 = nn.ReLU(inplace=True)

        self.pool3 = nn.MaxPool2d(2)  # e.g. N 256 128 128 -> N 256 64 64

        # Encoder stage 4: 256 -> 512 channels at 1/8 resolution.
        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512)
        self.relu4_1 = nn.ReLU(inplace=True)

        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512)
        self.relu4_2 = nn.ReLU(inplace=True)

        self.pool4 = nn.MaxPool2d(2)  # e.g. N 512 64 64 -> N 512 32 32

        # Encoder stage 5 (bottleneck): 512 -> 1024 channels at 1/16 resolution.
        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.bn5_1 = nn.BatchNorm2d(1024)
        self.relu5_1 = nn.ReLU(inplace=True)

        self.conv5_2 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1)
        self.bn5_2 = nn.BatchNorm2d(1024)
        self.relu5_2 = nn.ReLU(inplace=True)

        # Decoder stage 1: upsample, 1024 -> 512, then fuse with encoder stage 4.
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)

        # Input is 1024 channels again: 512 upsampled + 512 from the skip.
        self._conv1_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)
        self._bn1_1 = nn.BatchNorm2d(512)
        self._relu1_1 = nn.ReLU(inplace=True)

        self._conv1_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self._bn1_2 = nn.BatchNorm2d(512)
        self._relu1_2 = nn.ReLU(inplace=True)

        # Decoder stage 2: upsample, 512 -> 256, then fuse with encoder stage 3.
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)

        self._conv2_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self._bn2_1 = nn.BatchNorm2d(256)
        self._relu2_1 = nn.ReLU(inplace=True)

        self._conv2_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self._bn2_2 = nn.BatchNorm2d(256)
        self._relu2_2 = nn.ReLU(inplace=True)

        # Decoder stage 3: upsample, 256 -> 128, then fuse with encoder stage 2.
        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv3 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)

        self._conv3_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self._bn3_1 = nn.BatchNorm2d(128)
        self._relu3_1 = nn.ReLU(inplace=True)

        self._conv3_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self._bn3_2 = nn.BatchNorm2d(128)
        self._relu3_2 = nn.ReLU(inplace=True)

        # Decoder stage 4: upsample, 128 -> 64, then fuse with encoder stage 1.
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self._conv4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)

        self._conv4_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self._bn4_1 = nn.BatchNorm2d(64)
        self._relu4_1 = nn.ReLU(inplace=True)

        self._conv4_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self._bn4_2 = nn.BatchNorm2d(64)
        self._relu4_2 = nn.ReLU(inplace=True)

        # Output head: 1x1 conv producing one channel per class.
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    @staticmethod
    def _match_size(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized to the height/width of ``skip`` if they differ."""
        if x.shape[2:] != skip.shape[2:]:
            # Max-pooling floors odd sizes, so a 2x upsample can come out
            # smaller than the skip map; without this torch.cat would fail.
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes, H, W)`` holding the raw outputs
            of the final 1x1 conv (no softmax/sigmoid is applied).
        """
        # Encoder: keep the output of every stage (x1..x4) for the skips.
        x = self.relu1_1(self.bn1_1(self.conv1_1(x)))
        x1 = self.relu1_2(self.bn1_2(self.conv1_2(x)))

        x = self.pool1(x1)
        x = self.relu2_1(self.bn2_1(self.conv2_1(x)))
        x2 = self.relu2_2(self.bn2_2(self.conv2_2(x)))

        x = self.pool2(x2)
        x = self.relu3_1(self.bn3_1(self.conv3_1(x)))
        x3 = self.relu3_2(self.bn3_2(self.conv3_2(x)))

        x = self.pool3(x3)
        x = self.relu4_1(self.bn4_1(self.conv4_1(x)))
        x4 = self.relu4_2(self.bn4_2(self.conv4_2(x)))

        x = self.pool4(x4)
        x = self.relu5_1(self.bn5_1(self.conv5_1(x)))
        x = self.relu5_2(self.bn5_2(self.conv5_2(x)))

        # Decoder: upsample, reduce channels, concatenate the skip, refine.
        x = self._match_size(self.upsample1(x), x4)
        x = self._conv1(x)
        x = torch.cat([x, x4], dim=1)
        x = self._relu1_1(self._bn1_1(self._conv1_1(x)))
        x = self._relu1_2(self._bn1_2(self._conv1_2(x)))

        x = self._match_size(self.upsample2(x), x3)
        x = self._conv2(x)
        x = torch.cat([x, x3], dim=1)
        x = self._relu2_1(self._bn2_1(self._conv2_1(x)))
        x = self._relu2_2(self._bn2_2(self._conv2_2(x)))

        x = self._match_size(self.upsample3(x), x2)
        x = self._conv3(x)
        x = torch.cat([x, x2], dim=1)
        x = self._relu3_1(self._bn3_1(self._conv3_1(x)))
        x = self._relu3_2(self._bn3_2(self._conv3_2(x)))

        x = self._match_size(self.upsample4(x), x1)
        x = self._conv4(x)
        x = torch.cat([x, x1], dim=1)
        x = self._relu4_1(self._bn4_1(self._conv4_1(x)))
        x = self._relu4_2(self._bn4_2(self._conv4_2(x)))

        x = self.out(x)
        return x


if __name__ == '__main__':
    # Smoke test: push one random batch of four 3-channel 384x384 images
    # through a default-configured network.
    sample = torch.randn(4, 3, 384, 384)
    model = iUnet()
    logits = model(sample)
# Source: blog post by ddzhen, posted 2023-07-18 18:30.