PyTorch ResNet

一、ResNet网络结构

1.1 ResNet特点

  • 深层网络结构
  • 残差模块

 

 

  •  Batch Normalization加速训练

对一个batch中的feature map逐通道做标准化,使其满足均值为0、方差为1的分布,再通过可学习的缩放参数γ和平移参数β恢复网络的表达能力,从而加快并稳定训练。

ResNet通过残差结构(shortcut)解决了网络加深后出现的退化问题(degradation:层数越多,训练误差反而越大);梯度消失和梯度爆炸问题则主要由Batch Normalization和合适的权重初始化来缓解。

1.2 网络结构

 

 residual block中的虚线表示主分支和shortcut的输出shape不同(通道数或高宽不一致),所以要在shortcut上加入一个1×1卷积(步距与主分支相同)和BN层,使其输出的维度和主分支相同,才能进行相加。

1.3 参数列表

 

 二、模型和训练代码

2.1 model.py

  1 import torch.nn as nn
  2 import torch
  3 
  4 
  5 class BasicBlock(nn.Module):
  6     expansion = 1
  7 
  8     def __init__(self, in_channel, out_channel, stride=1, downsample=None):
  9         super(BasicBlock, self).__init__()
 10         self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
 11                                kernel_size=3, stride=stride, padding=1, bias=False)
 12         self.bn1 = nn.BatchNorm2d(out_channel)
 13         self.relu = nn.ReLU()
 14         self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
 15                                kernel_size=3, stride=1, padding=1, bias=False)
 16         self.bn2 = nn.BatchNorm2d(out_channel)
 17         self.downsample = downsample
 18 
 19     def forward(self, x):
 20         identity = x
 21         if self.downsample is not None:
 22             identity = self.downsample(x)
 23 
 24         out = self.conv1(x)
 25         out = self.bn1(out)
 26         out = self.relu(out)
 27 
 28         out = self.conv2(out)
 29         out = self.bn2(out)
 30 
 31         out += identity
 32         out = self.relu(out)
 33 
 34         return out
 35 
 36 
 37 class Bottleneck(nn.Module):
 38     expansion = 4
 39 
 40     def __init__(self, in_channel, out_channel, stride=1, downsample=None):
 41         super(Bottleneck, self).__init__()
 42         self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
 43                                kernel_size=1, stride=1, bias=False)  # squeeze channels
 44         self.bn1 = nn.BatchNorm2d(out_channel)
 45         # -----------------------------------------
 46         self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
 47                                kernel_size=3, stride=stride, bias=False, padding=1)
 48         self.bn2 = nn.BatchNorm2d(out_channel)
 49         # -----------------------------------------
 50         self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
 51                                kernel_size=1, stride=1, bias=False)  # unsqueeze channels
 52         self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
 53         self.relu = nn.ReLU(inplace=True)
 54         self.downsample = downsample
 55 
 56     def forward(self, x):
 57         identity = x
 58         if self.downsample is not None:
 59             identity = self.downsample(x)
 60 
 61         out = self.conv1(x)
 62         out = self.bn1(out)
 63         out = self.relu(out)
 64 
 65         out = self.conv2(out)
 66         out = self.bn2(out)
 67         out = self.relu(out)
 68 
 69         out = self.conv3(out)
 70         out = self.bn3(out)
 71 
 72         out += identity
 73         out = self.relu(out)
 74 
 75         return out
 76 
 77 
 78 class ResNet(nn.Module):
 79 
 80     def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
 81         super(ResNet, self).__init__()
 82         self.include_top = include_top
 83         self.in_channel = 64
 84 
 85         self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
 86                                padding=3, bias=False)
 87         self.bn1 = nn.BatchNorm2d(self.in_channel)
 88         self.relu = nn.ReLU(inplace=True)
 89         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 90         self.layer1 = self._make_layer(block, 64, blocks_num[0])
 91         self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
 92         self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
 93         self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
 94         if self.include_top:
 95             self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
 96             self.fc = nn.Linear(512 * block.expansion, num_classes)
 97 
 98         for m in self.modules():
 99             if isinstance(m, nn.Conv2d):
100                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
101 
102     def _make_layer(self, block, channel, block_num, stride=1):
103         downsample = None
104         if stride != 1 or self.in_channel != channel * block.expansion:
105             downsample = nn.Sequential(
106                 nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
107                 nn.BatchNorm2d(channel * block.expansion))
108 
109         layers = []
110         layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
111         self.in_channel = channel * block.expansion
112 
113         for _ in range(1, block_num):
114             layers.append(block(self.in_channel, channel))
115 
116         return nn.Sequential(*layers)
117 
118     def forward(self, x):
119         x = self.conv1(x)
120         x = self.bn1(x)
121         x = self.relu(x)
122         x = self.maxpool(x)
123 
124         x = self.layer1(x)
125         x = self.layer2(x)
126         x = self.layer3(x)
127         x = self.layer4(x)
128 
129         if self.include_top:
130             x = self.avgpool(x)
131             x = torch.flatten(x, 1)
132             x = self.fc(x)
133 
134         return x
135 
136 
def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34: BasicBlock stages with 3, 4, 6 and 3 blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths,
                  num_classes=num_classes, include_top=include_top)
139 
140 
def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: Bottleneck stages with 3, 4, 23 and 3 blocks."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths,
                  num_classes=num_classes, include_top=include_top)

2.2 train.py(带迁移学习)

  1 import torch
  2 import torch.nn as nn
  3 from torchvision import transforms, datasets
  4 import json
  5 import matplotlib.pyplot as plt
  6 import os
  7 import torch.optim as optim
  8 from model import resnet34, resnet101
  9 
 10 import torchvision.models.resnet
 11 def main():
 12     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 13     print("using {} device.".format(device))
 14 
 15     data_transform = {
 16         "train": transforms.Compose([transforms.RandomResizedCrop(224),
 17                                      transforms.RandomHorizontalFlip(),
 18                                      transforms.ToTensor(),
 19                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
 20         "val": transforms.Compose([transforms.Resize(256),
 21                                    transforms.CenterCrop(224),
 22                                    transforms.ToTensor(),
 23                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
 24 
 25     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
 26     image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
 27     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
 28     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
 29                                          transform=data_transform["train"])
 30     train_num = len(train_dataset)
 31 
 32     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
 33     flower_list = train_dataset.class_to_idx
 34     cla_dict = dict((val, key) for key, val in flower_list.items())
 35     # write dict into json file
 36     json_str = json.dumps(cla_dict, indent=4)
 37     with open('class_indices.json', 'w') as json_file:
 38         json_file.write(json_str)
 39 
 40     batch_size = 16
 41     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
 42     print('Using {} dataloader workers every process'.format(0))
 43 
 44     train_loader = torch.utils.data.DataLoader(train_dataset,
 45                                                batch_size=batch_size, shuffle=True,
 46                                                num_workers=0)
 47 
 48     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
 49                                             transform=data_transform["val"])
 50     val_num = len(validate_dataset)
 51     validate_loader = torch.utils.data.DataLoader(validate_dataset,
 52                                                   batch_size=batch_size, shuffle=False,
 53                                                   num_workers=0)
 54 
 55     print("using {} images for training, {} images fot validation.".format(train_num,
 56                                                                            val_num))
 57     
 58     net = resnet34()
 59     # load pretrain weights transfer learning
 60     # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
 61     model_weight_path = "./resnet34-pre.pth"
 62     assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
 63     missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)
 64     # for param in net.parameters():
 65     #     param.requires_grad = False
 66     # change fc layer structure
 67     in_channel = net.fc.in_features
 68     net.fc = nn.Linear(in_channel, 5)
 69     net.to(device)
 70 
 71     loss_function = nn.CrossEntropyLoss()
 72     optimizer = optim.Adam(net.parameters(), lr=0.0001)
 73 
 74     best_acc = 0.0
 75     save_path = './resNet34.pth'
 76     for epoch in range(3):
 77         # train
 78         net.train()
 79         running_loss = 0.0
 80         for step, data in enumerate(train_loader, start=0):
 81             images, labels = data
 82             optimizer.zero_grad()
 83             logits = net(images.to(device))
 84             loss = loss_function(logits, labels.to(device))
 85             loss.backward()
 86             optimizer.step()
 87 
 88             # print statistics
 89             running_loss += loss.item()
 90             # print train process
 91             rate = (step+1)/len(train_loader)
 92             a = "*" * int(rate * 50)
 93             b = "." * int((1 - rate) * 50)
 94             print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
 95         print()
 96 
 97         # validate
 98         net.eval()
 99         acc = 0.0  # accumulate accurate number / epoch
100         with torch.no_grad():
101             for val_data in validate_loader:
102                 val_images, val_labels = val_data
103                 outputs = net(val_images.to(device))  # eval model only have last output layer
104                 # loss = loss_function(outputs, test_labels)
105                 predict_y = torch.max(outputs, dim=1)[1]
106                 acc += (predict_y == val_labels.to(device)).sum().item()
107             val_accurate = acc / val_num
108             if val_accurate > best_acc:
109                 best_acc = val_accurate
110                 torch.save(net.state_dict(), save_path)
111             print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
112                   (epoch + 1, running_loss / step, val_accurate))
113 
114     print('Finished Training')
115 
116 
117 if __name__ == '__main__':
118     main()

2.3 predict.py

 1 import torch
 2 from model import resnet34
 3 from PIL import Image
 4 from torchvision import transforms
 5 import matplotlib.pyplot as plt
 6 import json
 7 
 8 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 9 
10 data_transform = transforms.Compose(
11     [transforms.Resize(256),
12      transforms.CenterCrop(224),
13      transforms.ToTensor(),
14      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
15 
16 # load image
17 img = Image.open("../rose.jpg")
18 plt.imshow(img)
19 # [N, C, H, W]
20 img = data_transform(img)
21 # expand batch dimension
22 img = torch.unsqueeze(img, dim=0)
23 
24 # read class_indict
25 try:
26     json_file = open('./class_indices.json', 'r')
27     class_indict = json.load(json_file)
28 except Exception as e:
29     print(e)
30     exit(-1)
31 
32 # create model
33 model = resnet34(num_classes=5)
34 # load model weights
35 model_weight_path = "./resNet34.pth"
36 model.load_state_dict(torch.load(model_weight_path, map_location=device))
37 model.eval()
38 with torch.no_grad():
39     # predict class
40     output = torch.squeeze(model(img))
41     predict = torch.softmax(output, dim=0)
42     predict_cla = torch.argmax(predict).numpy()
43 print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
44 plt.show()

 

posted @ 2020-12-20 20:46  荼离伤花  阅读(454)  评论(0)  编辑  收藏  举报