SegNet网络的PyTorch实现
1.文章原文地址
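SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation（Vijay Badrinarayanan, Alex Kendall, Roberto Cipolla，IEEE TPAMI 2017），arXiv: https://arxiv.org/abs/1511.00561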
2.文章摘要
语义分割具有非常广泛的应用，从场景理解、目标间相互关系的推断到自动驾驶。早期依赖低层视觉线索的方法已经迅速地被流行的机器学习算法所取代。特别是最近，深度学习在手写数字识别、语音识别、图像分类和目标检测上取得了巨大成功。如今，语义分割（对每个像素进行归类）是一个活跃的研究领域。然而，最近有一些方法直接采用为图像分类而设计的网络结构来完成语义分割任务。虽然结果十分鼓舞人心，但分割结果仍比较粗糙，其主要原因是最大池化和下采样降低了特征图的分辨率。我们设计SegNet的动机在于：分割任务需要将低分辨率的特征图映射回输入分辨率以进行像素级分类，而这个映射必须产生有助于准确定位边界的特征。
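3.网络结构
SegNet由编码器、解码器和逐像素分类层组成。
编码器共5个阶段，每个阶段分别包含2、2、3、3、3个“卷积+BN+ReLU”层，在论文中对应VGG16的13个卷积层。每个阶段末尾做2×2最大池化，并记录最大值所在的位置（池化索引）。
解码器与编码器对称，先用对应编码阶段的池化索引做反池化（上采样），得到稀疏特征图，再经卷积得到稠密特征图。
最后一层卷积输出每个像素在各类别上的得分，送入softmax分类器。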
4.PyTorch实现
import torch.nn as nn
import torch


class conv2DBatchNormRelu(nn.Module):
    """Conv2d followed by optional BatchNorm2d and optional ReLU.

    All layers live in the ``cbr_unit`` Sequential. Index 0 is the conv;
    index 1 is the BatchNorm when ``is_batchnorm`` is set. This is the same
    layout as before the ``is_relu`` flag was added, so existing state_dicts
    still load.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding, bias, dilation: forwarded to ``nn.Conv2d``.
        is_batchnorm: insert ``nn.BatchNorm2d(out_channels)`` after the conv.
        is_relu: append an in-place ``nn.ReLU``. Set it to ``False`` for layers
            whose output must keep its sign, such as the final class scores.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 bias=True, dilation=1, is_batchnorm=True, is_relu=True):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias,
                      dilation=dilation),
        ]
        if is_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        if is_relu:
            layers.append(nn.ReLU(inplace=True))
        self.cbr_unit = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the conv / BN / ReLU stack to ``inputs``."""
        return self.cbr_unit(inputs)


class segnetDown2(nn.Module):
    """Encoder stage: two 3x3 conv-BN-ReLU layers, then 2x2 max-pooling.

    ``forward`` returns ``(pooled, indices, unpooled_shape)``:
    - ``indices`` are the argmax positions from the pooling layer.
    - ``unpooled_shape`` is the size before pooling. It is passed to the
      matching decoder's ``MaxUnpool2d`` as ``output_size``, so odd spatial
      sizes are restored exactly.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetDown3(nn.Module):
    """Encoder stage: three 3x3 conv-BN-ReLU layers, then 2x2 max-pooling.

    Returns ``(pooled, indices, unpooled_shape)`` like ``segnetDown2``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetUp2(nn.Module):
    """Decoder stage: max-unpool with encoder indices, then two 3x3 conv-BN-ReLU layers.

    ``forward(inputs, indices, output_shape)`` expects the ``indices`` and
    pre-pool ``output_shape`` returned by the mirrored encoder stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        return outputs


class segnetUp3(nn.Module):
    """Decoder stage: max-unpool with encoder indices, then three 3x3 conv-BN-ReLU layers."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = conv2DBatchNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        return outputs


class segnet(nn.Module):
    """SegNet encoder-decoder for per-pixel classification.

    Encoder stages have 64/128/256/512/512 channels with 2/2/3/3/3 convs.
    Each decoder stage mirrors an encoder stage and upsamples with that
    stage's pooling indices.

    Args:
        in_channels: channels of the input image.
        num_classes: number of output score maps.

    ``forward(inputs)`` takes a 4-D ``(N, in_channels, H, W)`` tensor.
    It returns raw per-class scores of shape ``(N, num_classes, H, W)``;
    apply softmax or a cross-entropy loss outside the model.
    """

    def __init__(self, in_channels=3, num_classes=21):
        super().__init__()
        self.down1 = segnetDown2(in_channels=in_channels, out_channels=64)
        self.down2 = segnetDown2(64, 128)
        self.down3 = segnetDown3(128, 256)
        self.down4 = segnetDown3(256, 512)
        self.down5 = segnetDown3(512, 512)

        self.up5 = segnetUp3(512, 512)
        self.up4 = segnetUp3(512, 256)
        self.up3 = segnetUp3(256, 128)
        self.up2 = segnetUp2(128, 64)
        self.up1 = segnetUp2(64, 64)
        # Final classifier: no ReLU. A ReLU here would clamp every class
        # score to >= 0 and zero the gradient of any pixel whose scores are
        # all negative. The BN is kept so the parameter/state_dict layout
        # is unchanged.
        self.finconv = conv2DBatchNormRelu(64, num_classes, 3, 1, 1, is_relu=False)

    def forward(self, inputs):
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        down3, indices_3, unpool_shape3 = self.down3(down2)
        down4, indices_4, unpool_shape4 = self.down4(down3)
        down5, indices_5, unpool_shape5 = self.down5(down4)

        # Each decoder stage reuses the pooling indices and pre-pool size of
        # its mirrored encoder stage.
        up5 = self.up5(down5, indices=indices_5, output_shape=unpool_shape5)
        up4 = self.up4(up5, indices=indices_4, output_shape=unpool_shape4)
        up3 = self.up3(up4, indices=indices_3, output_shape=unpool_shape3)
        up2 = self.up2(up3, indices=indices_2, output_shape=unpool_shape2)
        up1 = self.up1(up2, indices=indices_1, output_shape=unpool_shape1)
        outputs = self.finconv(up1)

        return outputs


if __name__ == "__main__":
    inputs = torch.ones(1, 3, 224, 224)
    model = segnet()
    # Shape check only: inference mode, no autograd graph.
    model.eval()
    with torch.no_grad():
        print(model(inputs).size())
    print(model)
参考