Tutorial on GoogLeNet-based image classification: the Inception module and saving/loading models

 Tutorial on GoogLeNet-based image classification 

2018-06-26 15:50:29 

本文旨在通过案例来学习 GoogLeNet 及其 Inception 结构的定义，并介绍如何保存和加载这类复杂模型。

1. GoogLeNet 的结构:

 1 class Inception(nn.Module):
 2     def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
 3         super(Inception, self).__init__()
 4         # 1x1 conv branch
 5         self.b1 = nn.Sequential(
 6             nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
 7             nn.BatchNorm2d(kernel_1_x),
 8             nn.ReLU(True),
 9         )
10 
11         # 1x1 conv -> 3x3 conv branch
12         self.b2 = nn.Sequential(
13             nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
14             nn.BatchNorm2d(kernel_3_in),
15             nn.ReLU(True),
16             nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
17             nn.BatchNorm2d(kernel_3_x),
18             nn.ReLU(True),
19         )
20 
21         # 1x1 conv -> 5x5 conv branch
22         self.b3 = nn.Sequential(
23             nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
24             nn.BatchNorm2d(kernel_5_in),
25             nn.ReLU(True),
26             nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
27             nn.BatchNorm2d(kernel_5_x),
28             nn.ReLU(True),
29             nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
30             nn.BatchNorm2d(kernel_5_x),
31             nn.ReLU(True),
32         )
33 
34         # 3x3 pool -> 1x1 conv branch
35         self.b4 = nn.Sequential(
36             nn.MaxPool2d(3, stride=1, padding=1),
37             nn.Conv2d(in_planes, pool_planes, kernel_size=1),
38             nn.BatchNorm2d(pool_planes),
39             nn.ReLU(True),
40         )
41 
42     def forward(self, x):
43         y1 = self.b1(x)
44         y2 = self.b2(x)
45         y3 = self.b3(x)
46         y4 = self.b4(x)
47         return torch.cat([y1,y2,y3,y4], 1)
View Code
class GoogLeNet(nn.Module):
    """GoogLeNet-style classifier assembled from Inception blocks.

    The stem is a single stride-1 3x3 conv, so the two stride-2 max pools are
    the only downsampling steps. A 32x32 input reaches the final pooling layer
    as a 1024-channel 8x8 map (the size the original fixed 8x8 average pool
    assumed).

    Args:
        num_classes: number of outputs of the final linear layer (default 10).
    """

    def __init__(self, num_classes=10):
        super(GoogLeNet, self).__init__()
        # Stem: 3 input channels -> 192 channels, spatial size unchanged.
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        # Inception args: in_planes, 1x1, 3x3-reduce, 3x3, 5x5-reduce, 5x5, pool-proj.
        # Each in_planes equals the previous block's summed output channels.
        self.a3 = Inception(192,  64,  96, 128, 16, 32, 32)   # -> 256
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)   # -> 480

        # Parameter-free, so one instance is safely reused twice in forward();
        # each application maps H -> ceil(H / 2) (likewise for W).
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192,  96, 208, 16,  48,  64)  # -> 512
        self.b4 = Inception(512, 160, 112, 224, 24,  64,  64)  # -> 512
        self.c4 = Inception(512, 128, 128, 256, 24,  64,  64)  # -> 512
        self.d4 = Inception(512, 112, 144, 288, 32,  64,  64)  # -> 528
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)  # -> 832

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)  # -> 832
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)  # -> 1024

        # Global average pooling to 1x1. For 32x32 inputs this equals the former
        # AvgPool2d(8, stride=1); unlike it, other input sizes no longer break
        # the flatten -> Linear(1024, ...) step. Pooling has no parameters, so
        # existing state_dicts still load.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` logits."""
        x = self.pre_layers(x)
        x = self.a3(x)
        x = self.b3(x)
        x = self.max_pool(x)
        x = self.a4(x)
        x = self.b4(x)
        x = self.c4(x)
        x = self.d4(x)
        x = self.e4(x)
        x = self.max_pool(x)
        x = self.a5(x)
        x = self.b5(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # (batch, 1024, 1, 1) -> (batch, 1024)
        x = self.linear(x)
        return x
View Code

 

2. 保存和加载模型:

# Option 1: save and load the whole model object (pickles the module, so the
# class definitions must be importable when loading).
torch.save(model_object, 'model.pkl')
# Unpickling a full model can execute arbitrary code: only load trusted files.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
# full-model pickles, so the flag is passed explicitly (available since 1.13).
model = torch.load('model.pkl', weights_only=False)


# Option 2 (recommended): save and load only the parameters (state_dict).
torch.save(model_object.state_dict(), 'params.pkl')
# A state_dict contains only tensors, so the restricted weights_only loader is
# enough. map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
# load_state_dict then copies the values onto the model's own device.
model_object.load_state_dict(torch.load('params.pkl', map_location='cpu', weights_only=True))

 

posted @ 2018-06-26 16:06  AHU-WangXiao  阅读(212)  评论(0编辑  收藏  举报