VGG网络的PyTorch实现
1.文章原文地址
Very Deep Convolutional Networks for Large-Scale Image Recognition(https://arxiv.org/abs/1409.1556)
2.文章摘要
在这项工作中,我们研究了在大规模图像识别任务中卷积神经网络的深度对其准确率的影响。我们的主要贡献是使用非常小(3×3)卷积核的架构,对深度不断增加的网络进行了全面评估,结果表明,将深度增加到16-19个权重层可以在已有最优配置的基础上取得显著提升。这些发现是我们参加ImageNet Challenge 2014的基础,我们的团队在该比赛的定位和分类任务中分别获得了第一名和第二名。我们还证明了该网络学到的表征可以很好地推广到其他数据集上,并在这些数据集上取得了当前最好的结果。我们公开了两个性能最佳的ConvNet模型,以促进在计算机视觉中使用深度视觉表征的进一步研究。
3.网络结构
4.PyTorch实现
import torch.nn as nn

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    # Fallback for PyTorch versions without torch.hub.load_state_dict_from_url.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


# Download locations of the ImageNet-pretrained weights, keyed by model name.
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):
    """VGG classifier: a convolutional trunk, a 7x7 adaptive average pool and
    a three-layer fully connected head.

    Args:
        features (nn.Module): convolutional trunk (e.g. built by ``make_layers``).
            Its output must have 512 channels, because the first classifier
            layer is ``nn.Linear(512 * 7 * 7, 4096)``.
        num_classes (int): number of output logits.
        init_weights (bool): if True, re-initialise all Conv2d / BatchNorm2d /
            Linear modules via ``_initialize_weights``.
        dropout (float): drop probability of the two ``nn.Dropout`` layers in
            the classifier. The default 0.5 matches the original ``nn.Dropout()``.
    """

    def __init__(self, features, num_classes=1000, init_weights=True, dropout=0.5):
        super().__init__()
        self.features = features
        # Pool to a fixed 7x7 grid so the classifier's input size does not
        # depend on the input image resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Layer order (and therefore state_dict keys) is kept identical so the
        # pretrained checkpoints in model_urls still load.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the trunk, pool, flatten per sample and apply the classifier."""
        x = self.features(x)
        x = self.avgpool(x)
        # reshape instead of view: the pooled tensor can be non-contiguous
        # (e.g. channels_last memory format), where view() raises. reshape
        # yields the same (C, H, W) flattening order in every case.
        x = x.reshape(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise parameters in place.

        Conv2d: Kaiming-normal weights (fan_out, ReLU), zero bias.
        BatchNorm2d: weight 1, bias 0.
        Linear: weights ~ N(0, 0.01), zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, batch_norm=False):
    """Build the VGG convolutional trunk from a configuration list.

    Args:
        cfg (list): ints and the string ``'M'``. An int ``v`` appends a 3x3
            convolution (padding 1) with ``v`` output channels followed by
            ReLU, with ``nn.BatchNorm2d`` in between when ``batch_norm`` is
            True. ``'M'`` appends a 2x2 max-pool with stride 2.
        batch_norm (bool): whether to insert batch normalisation after each
            convolution.

    Returns:
        nn.Sequential: the layers in order; the first convolution takes
        3 input channels.
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


# Trunk configurations: 'A' -> VGG11, 'B' -> VGG13, 'D' -> VGG16, 'E' -> VGG19
# (see the vggXX factory docstrings below).
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    """Shared factory: build a VGG for configuration ``cfg`` and optionally
    load the pretrained weights stored at ``model_urls[arch]``.

    When ``pretrained`` is True, weight initialisation is skipped because the
    downloaded state_dict overwrites every parameter anyway.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def vgg11(pretrained=False, progress=True, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)


def vgg11_bn(pretrained=False, progress=True, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)


def vgg13(pretrained=False, progress=True, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)


def vgg13_bn(pretrained=False, progress=True, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)


def vgg16(pretrained=False, progress=True, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)


def vgg16_bn(pretrained=False, progress=True, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)


def vgg19(pretrained=False, progress=True, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)


def vgg19_bn(pretrained=False, progress=True, **kwargs):
    """VGG 19-layer model (configuration "E") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
参考
https://github.com/pytorch/vision/tree/master/torchvision/models