Tutorial on GoogLeNet-based image classification --- focus on the Inception module and on saving/loading models
Tutorial on GoogLeNet-based image classification
2018-06-26 15:50:29
本文旨在通过案例来学习 GoogLeNet 及其 Inception 结构的定义,并介绍这种复杂模型的保存与读取方法。
1. GoogLeNet 的结构:
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along channels.

    Every branch keeps the spatial size of its input (1x1 convs, 3x3 convs
    with ``padding=1``, and a 3x3 max-pool with ``stride=1, padding=1``), so
    the output has the same height/width as the input and
    ``kernel_1_x + kernel_3_x + kernel_5_x + pool_planes`` channels.

    Args:
        in_planes: Number of input channels.
        kernel_1_x: Output channels of the 1x1 conv branch.
        kernel_3_in: Channels of the 1x1 reduction in the 3x3 branch.
        kernel_3_x: Output channels of the 3x3 conv branch.
        kernel_5_in: Channels of the 1x1 reduction in the "5x5" branch.
        kernel_5_x: Output channels of the "5x5" branch.
        pool_planes: Output channels of the pooling branch.
    """

    def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x,
                 kernel_5_in, kernel_5_x, pool_planes):
        super(Inception, self).__init__()

        # Branch 1: 1x1 conv.
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
            nn.BatchNorm2d(kernel_1_x),
            nn.ReLU(True),
        )

        # Branch 2: 1x1 conv (channel reduction) -> 3x3 conv.
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
            nn.BatchNorm2d(kernel_3_in),
            nn.ReLU(True),
            nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_3_x),
            nn.ReLU(True),
        )

        # Branch 3: 1x1 conv (channel reduction) -> two stacked 3x3 convs.
        # NOTE: there is no literal 5x5 kernel here; the two 3x3 convs cover
        # a 5x5 receptive field, which is why the arguments are named
        # kernel_5_*.
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
            nn.BatchNorm2d(kernel_5_in),
            nn.ReLU(True),
            nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_5_x),
            nn.ReLU(True),
            nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_5_x),
            nn.ReLU(True),
        )

        # Branch 4: 3x3 max-pool (stride 1, size-preserving) -> 1x1 conv.
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.BatchNorm2d(pool_planes),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dim 1."""
        y1 = self.b1(x)
        y2 = self.b2(x)
        y3 = self.b3(x)
        y4 = self.b4(x)
        return torch.cat([y1, y2, y3, y4], 1)
class GoogLeNet(nn.Module):
    """GoogLeNet-style classifier built from a conv stem and Inception blocks.

    The network applies a 3x3 conv stem, nine Inception blocks with two
    stride-2 max-pools in between, global average pooling, and a linear
    classifier.

    The original code pooled with a fixed 8x8 window, which only matches the
    1024-feature linear layer when the final feature map is 8x8 (e.g. 32x32
    inputs). Global adaptive pooling yields the same result for those inputs
    and also accepts other input sizes. It has no parameters, so existing
    ``state_dict`` checkpoints still load.

    Args:
        num_classes: Number of outputs of the final linear layer. Defaults
            to 10, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=10):
        super(GoogLeNet, self).__init__()

        # Stem: one 3x3 conv from 3 input channels to 192 channels.
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        # Each block's output width is the sum of its four branch widths and
        # must equal the next block's in_planes.
        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)      # -> 256
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)    # -> 480

        # Shared downsampling layer, applied after stage 3 and after stage 4.
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)     # -> 512
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)    # -> 512
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)    # -> 512
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)    # -> 528
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)  # -> 832

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)  # -> 832
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)  # -> 1024

        # Global average pooling down to 1x1 regardless of feature-map size.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` output of the linear layer."""
        x = self.pre_layers(x)
        x = self.a3(x)
        x = self.b3(x)
        x = self.max_pool(x)
        x = self.a4(x)
        x = self.b4(x)
        x = self.c4(x)
        x = self.d4(x)
        x = self.e4(x)
        x = self.max_pool(x)
        x = self.a5(x)
        x = self.b5(x)
        x = self.avgpool(x)
        # Flatten (batch, 1024, 1, 1) to (batch, 1024) for the classifier.
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
2. 保存和加载模型:
# Option 1: save and load the entire model object.
# torch.save pickles the whole module, so loading needs the model's class
# definition to be importable, and torch.load should only be used on files
# you trust. NOTE(review): newer PyTorch releases (2.6+) default torch.load
# to weights_only=True, in which case loading a whole pickled model requires
# passing weights_only=False explicitly -- confirm against the installed
# version.
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

# Option 2 (recommended): save and load only the model parameters.
# The model must be constructed first; load_state_dict then fills in its
# parameters from the saved state dict.
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
Stay Hungry,Stay Foolish ...