pytorch GoogLeNet

一、GoogLeNet网络结构

1.特点:

采用Inception结构和2个辅助分类器。Inception结构是并行结构:同一输入同时送入1x1卷积、3x3卷积、5x5卷积和3x3最大池化四个分支,各分支输出在通道维度上拼接。在3x3和5x5卷积之前(以及池化之后)加入了1x1的卷积核来实现降维,能够减少训练参数。

2.网络结构

 

3.Inception结构

 

 4.参数列表

 

二、训练代码

model.py

  1 import torch.nn as nn
  2 import torch
  3 import torch.nn.functional as F
  4 
  5 
  6 class GoogLeNet(nn.Module):
  7     def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
  8         super(GoogLeNet, self).__init__()
  9         self.aux_logits = aux_logits
 10 
 11         self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
 12         self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
 13 
 14         self.conv2 = BasicConv2d(64, 64, kernel_size=1)
 15         self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
 16         self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
 17 
 18         self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
 19         self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
 20         self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
 21 
 22         self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
 23         self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
 24         self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
 25         self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
 26         self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
 27         self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
 28 
 29         self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
 30         self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
 31 
 32         if self.aux_logits:
 33             self.aux1 = InceptionAux(512, num_classes)
 34             self.aux2 = InceptionAux(528, num_classes)
 35 
 36         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
 37         self.dropout = nn.Dropout(0.4)
 38         self.fc = nn.Linear(1024, num_classes)
 39         if init_weights:
 40             self._initialize_weights()
 41 
 42     def forward(self, x):
 43         # N x 3 x 224 x 224
 44         x = self.conv1(x)
 45         # N x 64 x 112 x 112
 46         x = self.maxpool1(x)
 47         # N x 64 x 56 x 56
 48         x = self.conv2(x)
 49         # N x 64 x 56 x 56
 50         x = self.conv3(x)
 51         # N x 192 x 56 x 56
 52         x = self.maxpool2(x)
 53 
 54         # N x 192 x 28 x 28
 55         x = self.inception3a(x)
 56         # N x 256 x 28 x 28
 57         x = self.inception3b(x)
 58         # N x 480 x 28 x 28
 59         x = self.maxpool3(x)
 60         # N x 480 x 14 x 14
 61         x = self.inception4a(x)
 62         # N x 512 x 14 x 14
 63         if self.training and self.aux_logits:    # eval model lose this layer
 64             aux1 = self.aux1(x)
 65 
 66         x = self.inception4b(x)
 67         # N x 512 x 14 x 14
 68         x = self.inception4c(x)
 69         # N x 512 x 14 x 14
 70         x = self.inception4d(x)
 71         # N x 528 x 14 x 14
 72         if self.training and self.aux_logits:    # eval model lose this layer
 73             aux2 = self.aux2(x)
 74 
 75         x = self.inception4e(x)
 76         # N x 832 x 14 x 14
 77         x = self.maxpool4(x)
 78         # N x 832 x 7 x 7
 79         x = self.inception5a(x)
 80         # N x 832 x 7 x 7
 81         x = self.inception5b(x)
 82         # N x 1024 x 7 x 7
 83 
 84         x = self.avgpool(x)
 85         # N x 1024 x 1 x 1
 86         x = torch.flatten(x, 1)
 87         # N x 1024
 88         x = self.dropout(x)
 89         x = self.fc(x)
 90         # N x 1000 (num_classes)
 91         if self.training and self.aux_logits:   # eval model lose this layer
 92             return x, aux2, aux1
 93         return x
 94 
 95     def _initialize_weights(self):
 96         for m in self.modules():
 97             if isinstance(m, nn.Conv2d):
 98                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
 99                 if m.bias is not None:
100                     nn.init.constant_(m.bias, 0)
101             elif isinstance(m, nn.Linear):
102                 nn.init.normal_(m.weight, 0, 0.01)
103                 nn.init.constant_(m.bias, 0)
104 
105 
class Inception(nn.Module):
    """Inception module: four parallel branches concatenated on channels.

    Branches, in concatenation order:
        1. 1x1 conv                                  -> ch1x1 channels
        2. 1x1 conv (ch3x3red) -> 3x3 conv, padding 1 -> ch3x3 channels
        3. 1x1 conv (ch5x5red) -> 5x5 conv, padding 2 -> ch5x5 channels
        4. 3x3 max-pool, stride 1, padding 1 -> 1x1 conv -> pool_proj channels

    Every branch keeps the input's H x W, so the output has
    ch1x1 + ch3x3 + ch5x5 + pool_proj channels at the input resolution.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        # Plain 1x1 projection.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # 1x1 reduction, then 3x3 conv (padding=1 keeps the spatial size).
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        # 1x1 reduction, then 5x5 conv (padding=2 keeps the spatial size).
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
        )

        # Size-preserving max-pool, then 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: 5x5 average pool (stride 3), 1x1 conv to 128 channels, then
    flatten, dropout(0.5), fc 2048->1024, ReLU, dropout(0.5), and finally
    fc 1024->num_classes.

    ``fc1`` is hard-wired to 2048 = 128 * 4 * 4 inputs. That matches a
    14 x 14 feature map, since (14 - 5) // 3 + 1 = 4. Other input sizes
    will not fit.
    """

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # -> N x 128 x 4 x 4

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        pooled = self.averagePool(x)                      # N x C x 4 x 4
        features = torch.flatten(self.conv(pooled), 1)    # N x 2048
        features = F.dropout(features, 0.5, training=self.training)
        hidden = F.relu(self.fc1(features), inplace=True)  # N x 1024
        hidden = F.dropout(hidden, 0.5, training=self.training)
        return self.fc2(hidden)                           # N x num_classes


class BasicConv2d(nn.Module):
    """``nn.Conv2d`` followed by an in-place ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))

train.py

  1 import torch
  2 import torch.nn as nn
  3 from torchvision import transforms, datasets
  4 import torchvision
  5 import json
  6 import matplotlib.pyplot as plt
  7 import os
  8 import torch.optim as optim
  9 from model import GoogLeNet
 10 
 11 
 12 def main():
 13     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 14     print("using {} device.".format(device))
 15 
 16     data_transform = {
 17         "train": transforms.Compose([transforms.RandomResizedCrop(224),
 18                                      transforms.RandomHorizontalFlip(),
 19                                      transforms.ToTensor(),
 20                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
 21         "val": transforms.Compose([transforms.Resize((224, 224)),
 22                                    transforms.ToTensor(),
 23                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
 24 
 25     data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
 26     image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
 27     assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
 28     train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
 29                                          transform=data_transform["train"])
 30     train_num = len(train_dataset)
 31 
 32     # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
 33     flower_list = train_dataset.class_to_idx
 34     cla_dict = dict((val, key) for key, val in flower_list.items())
 35     # write dict into json file
 36     json_str = json.dumps(cla_dict, indent=4)
 37     with open('class_indices.json', 'w') as json_file:
 38         json_file.write(json_str)
 39 
 40     batch_size = 32
 41     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
 42     print('Using {} dataloader workers every process'.format(nw))
 43 
 44     train_loader = torch.utils.data.DataLoader(train_dataset,
 45                                                batch_size=batch_size, shuffle=True,
 46                                                num_workers=0)
 47 
 48     validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
 49                                             transform=data_transform["val"])
 50     val_num = len(validate_dataset)
 51     validate_loader = torch.utils.data.DataLoader(validate_dataset,
 52                                                   batch_size=batch_size, shuffle=False,
 53                                                   num_workers=0)
 54 
 55     print("using {} images for training, {} images fot validation.".format(train_num,
 56                                                                            val_num))
 57 
 58     # test_data_iter = iter(validate_loader)
 59     # test_image, test_label = test_data_iter.next()
 60 
 61     # net = torchvision.models.googlenet(num_classes=5)
 62     # model_dict = net.state_dict()
 63     # pretrain_model = torch.load("googlenet.pth")
 64     # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
 65     #             "aux2.fc2.weight", "aux2.fc2.bias",
 66     #             "fc.weight", "fc.bias"]
 67     # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
 68     # model_dict.update(pretrain_dict)
 69     # net.load_state_dict(model_dict)
 70     net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
 71     net.to(device)
 72     loss_function = nn.CrossEntropyLoss()
 73     optimizer = optim.Adam(net.parameters(), lr=0.0003)
 74 
 75     best_acc = 0.0
 76     save_path = './googleNet.pth'
 77     for epoch in range(30):
 78         # train
 79         net.train()
 80         running_loss = 0.0
 81         for step, data in enumerate(train_loader, start=0):
 82             images, labels = data
 83             optimizer.zero_grad()
 84             logits, aux_logits2, aux_logits1 = net(images.to(device))
 85             loss0 = loss_function(logits, labels.to(device))
 86             loss1 = loss_function(aux_logits1, labels.to(device))
 87             loss2 = loss_function(aux_logits2, labels.to(device))
 88             loss = loss0 + loss1 * 0.3 + loss2 * 0.3
 89             loss.backward()
 90             optimizer.step()
 91 
 92             # print statistics
 93             running_loss += loss.item()
 94             # print train process
 95             rate = (step + 1) / len(train_loader)
 96             a = "*" * int(rate * 50)
 97             b = "." * int((1 - rate) * 50)
 98             print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
 99         print()
100 
101         # validate
102         net.eval()
103         acc = 0.0  # accumulate accurate number / epoch
104         with torch.no_grad():
105             for val_data in validate_loader:
106                 val_images, val_labels = val_data
107                 outputs = net(val_images.to(device))  # eval model only have last output layer
108                 predict_y = torch.max(outputs, dim=1)[1]
109                 acc += (predict_y == val_labels.to(device)).sum().item()
110             val_accurate = acc / val_num
111             if val_accurate > best_acc:
112                 best_acc = val_accurate
113                 torch.save(net.state_dict(), save_path)
114             print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
115                   (epoch + 1, running_loss / step, val_accurate))
116 
117     print('Finished Training')
118 
119 
120 if __name__ == '__main__':
121     main()

 predict.py

 1 import torch
 2 from model import GoogLeNet
 3 from PIL import Image
 4 from torchvision import transforms
 5 import matplotlib.pyplot as plt
 6 import json
 7 
 8 data_transform = transforms.Compose(
 9     [transforms.Resize((224, 224)),
10      transforms.ToTensor(),
11      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
12 
13 # load image
14 img = Image.open("../rose.jpg")
15 plt.imshow(img)
16 # [N, C, H, W]
17 img = data_transform(img)
18 # expand batch dimension
19 img = torch.unsqueeze(img, dim=0)
20 
21 # read class_indict
22 try:
23     json_file = open('./class_indices.json', 'r')
24     class_indict = json.load(json_file)
25 except Exception as e:
26     print(e)
27     exit(-1)
28 
29 # create model
30 model = GoogLeNet(num_classes=5, aux_logits=False)
31 # load model weights
32 model_weight_path = "./googleNet.pth"
33 missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)
34 model.eval()
35 with torch.no_grad():
36     # predict class
37     output = torch.squeeze(model(img))
38     predict = torch.softmax(output, dim=0)
39     predict_cla = torch.argmax(predict).numpy()
40 print(class_indict[str(predict_cla)])
41 plt.show()

 

posted @ 2020-12-20 18:16  荼离伤花  阅读(363)  评论(0编辑  收藏  举报