U-Net Code Analysis Memo

0. Preface

This is not a tutorial; it is purely a record of problems I ran into, so I write things down as they come to mind.
The code here is a heavily simplified version of GitHub: milesial/Pytorch-UNet.

1. Some Details

Models should be built as classes, with the computation done in methods; do not mix plain functions into them. Otherwise the model may fail to run on the GPU, and the parts implemented as functions will not show up when the model is printed. (Addendum: after rewriting them as classes, training also seemed to get faster.) A minimal illustration follows.
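A minimal sketch (not from the repo) of the difference: a layer stored as a module attribute is registered, so print(model) lists it and .to(device) moves its weights, while a layer created inside a plain function call is not.

from torch import nn

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3)   # registered: listed by print(), moved by .to(device)
    def forward(self, x):
        return self.conv(x)

class Unregistered(nn.Module):
    def forward(self, x):
        conv = nn.Conv2d(1, 8, 3)        # created on the fly: not registered, re-initialized every call,
        return conv(x)                   # stays on the CPU and is invisible to print()

print(Registered())    # shows the Conv2d layer
print(Unregistered())  # shows an empty module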

2. Code That Needs Explanation

2.1 Transposed Convolution: nn.ConvTranspose2d

From the documentation:
Applies a 2D transposed convolution operator over an input image composed of several input planes.
This module can be seen as the gradient of Conv2d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation).
Further reading: Dive into Deep Learning, Section 9.10.1.

>>> input = torch.randn(1, 16, 12, 12)
>>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
>>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12])

The upsampling implementation in the code:

class up(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up, self).__init__()
        #---------- upsampling --------------
        self.UpConv = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2)
        #-----------end----------------
        # self.UpConv = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = doubleconv(ch_in, ch_out)
    def forward(self, left, right):
        print(right.shape)
        right = self.UpConv(right)
        #---------- match the spatial size of left ---------
        diffY = left.size()[2] - right.size()[2]  # height
        diffX = left.size()[3] - right.size()[3]  # width
        # F.pad pads with zeros in [left, right, top, bottom] order,
        # so right ends up with the same height/width as left
        right = F.pad(right, [diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2])
        #-------------end----------------------
        print('left : {} ; right : {}'.format(left.shape, right.shape))
        #-------------- concatenate along the channel dimension --------------
        temp = torch.cat([left, right], dim=1)
        #--------------------end-------------------------
        print('temp : {}'.format(temp.shape))
        return self.conv(temp)
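A self-contained sketch of what UpConv plus F.pad do to the shapes, independent of the classes above; the tensor sizes (68, 32, 64/128 channels) are made up for illustration:

import torch
from torch import nn
from torch.nn import functional as F

left = torch.randn(1, 64, 68, 68)             # skip connection from the encoder
right = torch.randn(1, 128, 32, 32)           # feature map from the layer below
up_conv = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
right = up_conv(right)                        # kernel 2 / stride 2 doubles H and W -> (1, 64, 64, 64)
diffY = left.size(2) - right.size(2)          # 4
diffX = left.size(3) - right.size(3)          # 4
right = F.pad(right, [diffX // 2, diffX - diffX // 2,
                      diffY // 2, diffY - diffY // 2])
print(torch.cat([left, right], dim=1).shape)  # torch.Size([1, 128, 68, 68])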

2.2 nn.BatchNorm2d

Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

Here a 4D input means a mini-batch of 2D feature maps with a channel dimension, e.g. a tensor of shape torch.Size([1, 3, 12, 12]).
Purpose: when training runs into slow convergence, exploding gradients, or similar problems that keep the network from converging, BN is worth trying. Even in ordinary situations, adding BN can speed up training and improve model accuracy.

Further reading: "Why does Batch Normalization work so well in deep learning?" – 魏秀参's answer on Zhihu.
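A small sketch (toy shapes, not from the repo) showing that in training mode nn.BatchNorm2d normalizes each channel to roughly zero mean and unit variance:

import torch
from torch import nn

x = torch.randn(8, 3, 12, 12) * 5 + 10   # mini-batch with per-channel mean ~10, std ~5
bn = nn.BatchNorm2d(3)                   # one (gamma, beta) pair and running stats per channel
y = bn(x)
print(y.mean(dim=(0, 2, 3)))             # close to 0 for every channel
print(y.std(dim=(0, 2, 3)))              # close to 1 for every channel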

2.3 Gradient Clipping: nn.utils.clip_grad_value_(model.parameters(), 0.1)
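The notes only record the call itself. As a reminder, here is a minimal runnable sketch (toy model and data, not the U-Net from this repo) of where the clipping goes in a training step: after backward(), before optimizer.step().

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randn(16, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
nn.utils.clip_grad_value_(model.parameters(), 0.1)  # clamp every gradient element to [-0.1, 0.1]
optimizer.step()
print(max(p.grad.abs().max().item() for p in model.parameters()))  # <= 0.1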

2.4 Dynamic Learning-Rate Adjustment: torch.optim.lr_scheduler.ReduceLROnPlateau

From the documentation:
Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This scheduler reads a metrics quantity and if no improvement is seen for a 'patience' number of epochs, the learning rate is reduced.

Common parameters:

  1. optimizer (Optimizer) – Wrapped optimizer.

  2. mode (str) – One of min, max. In min mode, lr will be reduced when the quantity monitored (e.g. a loss) has stopped decreasing; in max mode it will be reduced when the quantity monitored (e.g. classification accuracy or AUC) has stopped increasing. Default: 'min'.

  3. factor (float) – Factor by which the learning rate will be reduced: new_lr = lr * factor. Default: 0.1.

  4. patience (int) – Number of epochs with no improvement after which the learning rate will be reduced. For example, with patience = 2 the first 2 epochs with no improvement are ignored, and the LR is only decreased after the 3rd epoch if the loss still hasn't improved. Default: 10.

n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
# split the dataset into train and val; val is used to evaluate the model and compute val_score
train, val = data.random_split(dataset, [n_train, n_val])
train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = data.DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
# set up the scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
# adjust the learning rate according to val_score
scheduler.step(val_score)

2.5 one-hot to label

torch.argmax()

argmax(self: Tensor, dim: Optional[_int]=None, keepdim: _bool=False)

label to one-hot

n = 5  # number of classes
indices = torch.randint(0, n, size=(15,15))  # 15x15 tensor of class labels in [0, n-1]
one_hot = torch.nn.functional.one_hot(indices, n)  # shape (15, 15, n)
label = torch.argmax(one_hot, dim=-1)  # argmax over the last (class) dimension recovers the labels
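A quick check (added here, not in the original notes) that the round trip gives back the original labels:

print(torch.equal(label, indices))  # True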

2.6 More Details

for ori in train_loader:

    scheduler.step(loss)
# Calling scheduler.step() inside the "for batch in dataloader" loop makes the lr drop too early and slows training down; see the sketch below for the per-epoch placement.
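A runnable toy sketch (not the repo's training loop) of the intended placement: step the scheduler once per epoch, driven by a validation metric, never inside the batch loop.

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
train_loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]

for epoch in range(10):
    for x, y in train_loader:                      # batch loop: no scheduler.step() here
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    with torch.no_grad():                          # stand-in for the real validation score
        val_score = nn.functional.mse_loss(model(x), y).item()
    scheduler.step(val_score)                      # once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])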

3. Auxiliary Modules

3.1 Progress Bar: tqdm

  In[]: from time import sleep
   ...: from tqdm import tqdm
        # n_train: number of training samples
        # total: expected number of iterations, i.e. the maximum value of the progress bar
        # desc: description shown in front of the progress bar
        # unit: the unit shown on the right as unit/s
   ...: with tqdm(total=100, desc='henhenaaaaaaa', unit=' img') as pbar:
   ...:     pbar.set_postfix(**{'loss' : 1}) # set the info displayed on the right of the bar
   ...:     pbar.update(2) # advance the progress bar by 2
   ...:     sleep(0.1)

henhenaaaaaaa:   2%|▏         | 2/100 [00:00<00:05, 18.61 img/s, loss=1]
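A sketch of how the bar would wrap a training loop (dummy data instead of the real train_loader; the variable names are placeholders): total is set to the number of training samples and update() is called with the batch size.

from time import sleep
from tqdm import tqdm

n_train, batch_size = 100, 10
train_loader = [list(range(batch_size)) for _ in range(n_train // batch_size)]

for epoch in range(2):
    with tqdm(total=n_train, desc='Epoch {}/2'.format(epoch + 1), unit=' img') as pbar:
        for batch in train_loader:
            sleep(0.05)                          # stand-in for the forward/backward pass
            pbar.set_postfix(**{'loss': 0.5})    # display the current loss
            pbar.update(len(batch))              # advance by the number of images in the batch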

4. U-Net Model Code

import torch
from torch import nn
from torch.nn import functional as F
from torchsnooper import snoop
#@snoop
class doubleconv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(doubleconv, self).__init__()
        # two 3x3 valid convolutions (no padding), each followed by BN + ReLU
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=(3, 3))
        self.norm1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=(3, 3))
        self.norm2 = nn.BatchNorm2d(ch_out)  # each conv gets its own BN so the running statistics are not shared
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        return x


class down(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(down, self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=(2,2))
        self.conv = doubleconv(ch_in,ch_out)
    def forward(self,x):
        return self.pool(self.conv(x))

# def down(x,ch_in,ch_out):
#     Pool=nn.MaxPool2d(kernel_size=(2,2))
#     Conv=doubleconv(ch_in,ch_out)
#     return Pool(Conv(x))

class up(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up, self).__init__()
        self.UpConv = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2)
        # self.UpConv = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = doubleconv(ch_in, ch_out)
    def forward(self, left, right):
        # print(right.shape)
        right = self.UpConv(right)

        diffY = left.size()[2] - right.size()[2]  # height
        diffX = left.size()[3] - right.size()[3]  # width
        # F.pad pads with zeros in [left, right, top, bottom] order,
        # so right ends up with the same height/width as left
        right = F.pad(right, [diffX // 2, diffX - diffX // 2,
                              diffY // 2, diffY - diffY // 2])
        # print('left : {} ; right : {}'.format(left.shape, right.shape))
        temp = torch.cat([left, right], dim=1)
        # print('temp : {}'.format(temp.shape))
        return self.conv(temp)

class out(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(out, self).__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)
    def forward(self, x):
        return self.conv(x)

class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        self.conv = doubleconv(1,64)

        self.down1 = down(64 , 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 1024)

        self.up1 = up(1024,512)
        self.up2 = up(512,256)
        self.up3 = up(256,128)
        self.up4 = up(128,64)

        self.out = out(64,2)
    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x  = self.up1(x4,x5)
        x  = self.up2(x3, x)
        x  = self.up3(x2, x)
        x  = self.up4(x1, x)
        return self.out(x)

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    temp=torch.randn(1,1,572,572).to(device=device)

    model=unet().to(device)
    pre=model(temp)
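    # Added check: print the prediction shape. The final 1x1 conv gives 2 output channels,
    # and the unpadded 3x3 convolutions make the spatial size smaller than the 572x572 input.
    print(pre.shape)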