全网最详细的深度学习经典模型 ResNet 解析【京东特邀专家 朱利明】(bilibili 视频学习)(代码解析)
import torch
import torch.nn as nn

try:
    from .utils import load_state_dict_from_url
except ImportError:
    # The relative import only works when this file lives inside its original
    # package. When it is imported standalone, use the helper of the same
    # name that ships with torch itself.
    from torch.hub import load_state_dict_from_url


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


# Download locations of the pretrained checkpoints, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


# Thin wrappers around nn.Conv2d used by the residual blocks below.
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding.

    The padding is set equal to ``dilation`` and the layer has no bias term.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no bias term)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# Residual block definitions.
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Used by the ``resnet18`` and ``resnet34`` factories in this file.

    Args:
        inplanes (int): number of input channels.
        planes (int): number of channels produced by both convolutions.
        stride (int): stride of the first convolution.
        downsample (nn.Module or None): optional module applied to the input
            before it is added back to the block output.
        groups (int): must be 1 for this block.
        base_width (int): must be 64 for this block.
        dilation (int): must not exceed 1 for this block.
        norm_layer (callable or None): constructor of the normalization layer,
            called with the channel count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    # Output channels of the block are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)  # normalization
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-norm-relu, conv-norm, add the shortcut, then ReLU."""
        # Keep the input around for the residual (shortcut) connection.
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Transform the shortcut with the downsample module when one was given.
        if self.downsample is not None:
            identity = self.downsample(x)

        # Add the shortcut first, then apply the ReLU.
        out += identity
        out = self.relu(out)

        return out


# Bottleneck block (used by the 50-layer and deeper factories in this file).
class Bottleneck(nn.Module):
    """Residual block made of a 1x1, a 3x3 and a 1x1 convolution.

    Args:
        inplanes (int): number of input channels.
        planes (int): base channel count; the block outputs
            ``planes * expansion`` channels.
        stride (int): stride of the 3x3 convolution.
        downsample (nn.Module or None): optional module applied to the input
            before it is added back to the block output.
        groups (int): number of groups of the 3x3 convolution.
        base_width (int): width per group; the inner width is
            ``int(planes * (base_width / 64.)) * groups``.
        dilation (int): dilation of the 3x3 convolution.
        norm_layer (callable or None): constructor of the normalization layer,
            called with the channel count; defaults to ``nn.BatchNorm2d``.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    # Channel expansion factor: the block outputs ``planes * expansion`` channels.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv-norm stages, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone followed by global average pooling and a linear layer.

    Args:
        block (type): block class to stack, ``BasicBlock`` or ``Bottleneck``;
            it must expose an ``expansion`` class attribute.
        layers (sequence of int): number of blocks in each of the four stages
            (``layer1`` .. ``layer4``).
        num_classes (int): output size of the final linear layer.
        zero_init_residual (bool): if True, the weight of the last
            normalization layer of every residual block is initialized to 0.
        groups (int): ``groups`` argument forwarded to every block.
        width_per_group (int): ``base_width`` argument forwarded to every block.
        replace_stride_with_dilation (sequence of bool or None): three flags,
            one each for ``layer2``, ``layer3`` and ``layer4``; when a flag is
            True that stage uses dilation instead of its stride of 2.
            ``None`` means ``[False, False, False]``.
        norm_layer (callable or None): constructor of the normalization layer,
            called with the channel count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly
            three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state consumed and updated by _make_layer.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 convolution on a 3-channel input, then max pooling.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # The four residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Parameter initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage as an ``nn.Sequential`` of ``blocks`` blocks.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is
        True) so that the next stage is wired to this stage's output.

        Args:
            block (type): block class to instantiate.
            planes (int): base channel count passed to every block.
            blocks (int): number of blocks in the stage.
            stride (int): stride of the first block of the stage.
            dilate (bool): if True, multiply ``self.dilation`` by ``stride``
                and use a stride of 1 instead.

        Returns:
            nn.Sequential: the stacked blocks.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # The shortcut needs a projection whenever the first block changes the
        # stride or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block receives the stride and the downsample module;
        # it is built with the dilation in effect before this stage.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run the stem, the four stages, pooling, flatten and the classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Forward pass; delegates to ``_forward_impl``."""
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load its pretrained weights.

    Args:
        arch (str): key into ``model_urls`` used when ``pretrained`` is True.
        block (type): block class passed to ``ResNet``.
        layers (sequence of int): per-stage block counts passed to ``ResNet``.
        pretrained (bool): if True, download and load the checkpoint.
        progress (bool): forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to the ``ResNet`` constructor.

    Returns:
        ResNet: the constructed model.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # These two settings define the variant and override any caller value.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # These two settings define the variant and override any caller value.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # This setting defines the variant and overrides any caller value.
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # This setting defines the variant and overrides any caller value.
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)