A PyTorch Implementation of U-Net

1. Original paper

U-Net: Convolutional Networks for Biomedical Image Segmentation

2. Abstract

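It is widely agreed that successfully training deep networks requires many thousands of annotated training samples. In this paper, we present a network and a training strategy that relies heavily on data augmentation, so that the available annotated samples are used more efficiently. The architecture consists of a contracting path that captures context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopy stacks. Using the same network trained on transmitted-light microscopy images, we won the ISBI cell tracking challenge 2015 by a large margin. Moreover, the network is fast: segmenting a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at: http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net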

3. Network architecture
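As the abstract describes, the network has a contracting path and a symmetric expanding path. In the code in section 4, each contracting step is two unpadded 3x3 convolutions followed by 2x2 max pooling (repeated four times, then a center block), and each expanding step is a 2x2 transposed convolution, concatenation with the center-cropped encoder feature map, and two more 3x3 convolutions. Because no padding is used, every block shrinks the feature map, which is why a 572x572 input yields a 388x388 output. The size bookkeeping can be sketched in plain Python as below; `unet_shapes` is a hypothetical helper for illustration that assumes square inputs and exactly those layer settings.

```python
def unet_shapes(size, depth=4):
    """Trace the spatial size of a square input through a valid-padding U-Net.

    Hypothetical helper. Assumes two 3x3 convs with padding=0 per block (-4 px),
    2x2 max pooling (/2) and 2x2 stride-2 transposed convs (x2), as in section 4.
    """
    skips = []
    for level in range(1, depth + 1):
        size -= 4                      # two unpadded 3x3 convs
        if size <= 0 or size % 2:
            raise ValueError(f"size {size} cannot be max-pooled cleanly at level {level}")
        skips.append(size)             # kept for the skip connection
        size //= 2                     # 2x2 max pooling
    size -= 4                          # center block
    crops = []
    for skip in reversed(skips):
        size *= 2                      # 2x2 up-convolution
        crops.append(skip - size)      # pixels cropped from the skip map (split over both sides)
        size -= 4                      # two unpadded 3x3 convs after concatenation
    if size <= 0:
        raise ValueError("input too small: output size would be non-positive")
    return {"skips": skips, "crops": crops, "output": size}


print(unet_shapes(572))
# {'skips': [568, 280, 136, 64], 'crops': [8, 32, 80, 176], 'output': 388}
```

These numbers agree with the output shapes in the summary in section 4. Under these assumptions the valid square input sizes are exactly those with `size % 16 == 12` and `size >= 188`, and the output is always 184 pixels smaller than the input; any other size reaches an odd size before some max-pooling step, which the paper recommends avoiding by choosing the input tile size so that every 2x2 pooling sees an even size.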

4. PyTorch implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class unetConv2(nn.Module):
    """Two 3x3 convolutions, each followed by optional BatchNorm and ReLU.

    padding=0 ("valid" convolution), so each conv shrinks H and W by 2.
    """
    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.ReLU(inplace=True),
            )

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)

        return outputs


class unetUp(nn.Module):
    """Decoder block: upsample, crop the skip connection, concatenate, then unetConv2."""
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        # Decoder convs are built without BatchNorm (see the summary below).
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            # 2x2 transposed conv: doubles H and W and maps in_size -> out_size channels.
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            # Bilinear upsampling keeps in_size channels; the 1x1 conv reduces them to
            # out_size so that the concatenation below has in_size channels, as self.conv expects.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

    def forward(self, inputs1, inputs2):
        # inputs1: encoder feature map (skip connection); inputs2: feature map from the level below
        outputs2 = self.up(inputs2)
        # The skip map is larger because of the valid convs, so center-crop it to outputs2's size.
        dh = inputs1.size(2) - outputs2.size(2)
        dw = inputs1.size(3) - outputs2.size(3)
        # F.pad order is (left, right, top, bottom); negative padding crops.
        # Splitting as d // 2 and d - d // 2 also handles odd differences.
        padding = [-(dw // 2), -(dw - dw // 2), -(dh // 2), -(dh - dh // 2)]
        outputs1 = F.pad(inputs1, padding)

        return self.conv(torch.cat([outputs1, outputs2], 1))


class unet(nn.Module):
    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        # Channel widths from the paper; feature_scale divides them (1 keeps the original widths).
        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling (contracting path)
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # upsampling (expanding path)
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final 1x1 conv: filters[0] channels -> n_classes (no concatenation)
        self.final = nn.Conv2d(filters[0], n_classes, kernel_size=1)

    def forward(self, inputs):
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)

        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        center = self.center(maxpool4)
        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        final = self.final(up1)

        return final


if __name__ == "__main__":
    model = unet(feature_scale=1)
    # summary() prints the table itself and returns None, so it is not wrapped in print().
    # device="cpu" keeps the dummy input on the same device as the model.
    summary(model, (3, 572, 572), device="cpu")
```
Running the script (feature_scale=1, input 3x572x572) prints:

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]           1,792
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
         unetConv2-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
        unetConv2-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
        unetConv2-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
        unetConv2-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
        unetConv2-39         [-1, 1024, 28, 28]               0
  ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
           Conv2d-41          [-1, 512, 54, 54]       4,719,104
             ReLU-42          [-1, 512, 54, 54]               0
           Conv2d-43          [-1, 512, 52, 52]       2,359,808
             ReLU-44          [-1, 512, 52, 52]               0
        unetConv2-45          [-1, 512, 52, 52]               0
           unetUp-46          [-1, 512, 52, 52]               0
  ConvTranspose2d-47        [-1, 256, 104, 104]         524,544
           Conv2d-48        [-1, 256, 102, 102]       1,179,904
             ReLU-49        [-1, 256, 102, 102]               0
           Conv2d-50        [-1, 256, 100, 100]         590,080
             ReLU-51        [-1, 256, 100, 100]               0
        unetConv2-52        [-1, 256, 100, 100]               0
           unetUp-53        [-1, 256, 100, 100]               0
  ConvTranspose2d-54        [-1, 128, 200, 200]         131,200
           Conv2d-55        [-1, 128, 198, 198]         295,040
             ReLU-56        [-1, 128, 198, 198]               0
           Conv2d-57        [-1, 128, 196, 196]         147,584
             ReLU-58        [-1, 128, 196, 196]               0
        unetConv2-59        [-1, 128, 196, 196]               0
           unetUp-60        [-1, 128, 196, 196]               0
  ConvTranspose2d-61         [-1, 64, 392, 392]          32,832
           Conv2d-62         [-1, 64, 390, 390]          73,792
             ReLU-63         [-1, 64, 390, 390]               0
           Conv2d-64         [-1, 64, 388, 388]          36,928
             ReLU-65         [-1, 64, 388, 388]               0
        unetConv2-66         [-1, 64, 388, 388]               0
           unetUp-67         [-1, 64, 388, 388]               0
           Conv2d-68         [-1, 21, 388, 388]           1,365
================================================================
Total params: 31,040,981
Trainable params: 31,040,981
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3158.15
Params size (MB): 118.41
Estimated Total Size (MB): 3280.31
```
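The parameter column of the summary can be checked by hand: a k x k convolution or transposed convolution with bias has in·out·k² + out parameters, and BatchNorm2d has 2·out learnable parameters (weight and bias). The first conv of each decoder block sees twice the channels because of the concatenation; for example Conv2d-41 in up_concat4 takes 512 + 512 = 1024 channels, giving 1024·512·9 + 512 = 4,719,104. The sketch below uses hypothetical helpers (`conv_params`, `double_conv_params`, `count_unet_params`) that mirror the layer layout of the code above under its defaults (transposed-conv upsampling, BatchNorm only in the encoder and center block) and add everything up. With float32 weights (4 bytes each) it also reproduces the "Params size" and "Input size" lines.

```python
def conv_params(c_in, c_out, k):
    """Weights plus bias of a Conv2d or ConvTranspose2d with a k x k kernel."""
    return c_in * c_out * k * k + c_out


def double_conv_params(c_in, c_out, batchnorm):
    """unetConv2: two 3x3 convs, each optionally followed by BatchNorm2d (weight + bias)."""
    bn = 2 * c_out if batchnorm else 0
    return conv_params(c_in, c_out, 3) + bn + conv_params(c_out, c_out, 3) + bn


def count_unet_params(feature_scale=1, n_classes=21, in_channels=3):
    """Hypothetical helper mirroring unet(..., is_deconv=True, is_batchnorm=True)."""
    f = [int(x / feature_scale) for x in [64, 128, 256, 512, 1024]]
    total = double_conv_params(in_channels, f[0], True)       # conv1
    for i in range(1, 5):                                      # conv2..conv4, then center
        total += double_conv_params(f[i - 1], f[i], True)
    for i in range(4, 0, -1):                                  # up_concat4..up_concat1
        total += conv_params(f[i], f[i - 1], 2)                # ConvTranspose2d, kernel 2
        total += double_conv_params(f[i], f[i - 1], False)     # concat has f[i] channels, no BN
    total += conv_params(f[0], n_classes, 1)                   # final 1x1 conv
    return total


total = count_unet_params()
print(f"{total:,}")                            # 31,040,981
print(round(total * 4 / 2 ** 20, 2))           # params size in MB (float32): 118.41
print(round(3 * 572 * 572 * 4 / 2 ** 20, 2))   # input size in MB: 3.74
```

Because feature_scale divides every channel width (its default in the class is 4), most convolution weights shrink by roughly the square of that factor, which is the main lever for making this model lighter.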

Reference

https://github.com/meetshah1995/pytorch-semseg

posted @ 2019-05-19 16:28  ysyouaremyall