第五讲 卷积神经网络 -- Inception10 -- cifar10

  1 import tensorflow as tf
  2 import os
  3 import numpy as np
  4 from matplotlib import pyplot as plt
  5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
  6 from tensorflow.keras import Model
  7 
  8 np.set_printoptions(threshold=np.inf)
  9 
 10 ciar10 = tf.keras.datasets.cifar10
 11 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
 12 x_train, x_test = x_train/255.0, x_test/255.0
 13 
 14 class ConvBNRelu(Model):
 15     def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
 16         super(ConvBNRelu, self).__init__()
 17         self.model = tf.keras.models.Sequential([
 18             Conv2D(ch, kernelsz, strides=strides, padding=padding),
 19             BatchNormalization(),
 20             Activation('relu')
 21         ])
 22 
 23     def call(self, x):
 24         x = self.model(x, training=False)
 25         #在training=False时,BN通过整个训练集计算均值、方差去做批归一化,training=True时,通过当前batch的均值、方差去做批归一化。推理时 training=False效果好
 26         return x
 27 
 28 
 29 
 30 class InceptionBlk(Model):
 31     def __init__(self, ch, strides=1):
 32         super(InceptionBlk, self).__init__()
 33         self.ch = ch
 34         self.strides = strides
 35         self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 36         self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 37         self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
 38         self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 39         self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
 40         self.p4_1 = MaxPooling2D(3, strides=1, padding='same')
 41         self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)
 42 
 43     def call(self, x):
 44         x1 = self.c1(x)
 45         x2_1 = self.c2_1(x)
 46         x2_2 = self.c2_2(x2_1)
 47         x3_1 = self.c3_1(x)
 48         x3_2 = self.c3_2(x3_1)
 49         x4_1 = self.p4_1(x)
 50         x4_2 = self.c4_2(x4_1)
 51         # concat along axis=channel
 52         x = tf.concat([x1, x2_2, x3_2, x4_2], axis=1)
 53         return x
 54 
 55 class Inception10(Model):
 56     def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
 57         super(Inception10, self).__init__(**kwargs)
 58         self.in_channels = init_ch
 59         self.out_channels = init_ch
 60         self.num_blocks = num_blocks
 61         self.init_ch = init_ch
 62         self.c1 = ConvBNRelu(init_ch)
 63         self.blocks = tf.keras.models.Sequential()
 64         for block_id in range(num_blocks):
 65             for layer_id in range(2):
 66                 if layer_id == 0:
 67                     block = InceptionBlk(self.out_channels, strides=1)
 68                 else:
 69                     block = InceptionBlk(self.out_channels, strides=1)
 70                 self.blocks.add(block)
 71             # enlarger out_channels per block
 72             self.out_channels *=2
 73         self.p1 = GlobalAveragePooling2D()
 74         self.f1 = Dense(num_classes, activation='softmax')
 75 
 76     def call(self, x):
 77         x = self.c1(x)
 78         x = self.blocks(x)
 79         x = self.p1(x)
 80         y = self.f1(x)
 81         return y
 82 
 83 model = Inception10(num_blocks=2, num_classes=10)
 84 
 85 model.compile(optimizer='adam',
 86               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
 87               metrics=['sparse_categorical_accuracy'])
 88 
 89 
 90 checkpoint_save_path = "./checkpoint/Inception10.ckpt"
 91 if os.path.exists(checkpoint_save_path + '.index'):
 92     print('-------------load the model---------------')
 93     model.load_weights(checkpoint_save_path)
 94 
 95 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_save_path,
 96                                                 save_weights_only = True,
 97                                                 save_best_only = True)
 98 
 99 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test),validation_freq=1,
100                     callbacks=[cp_callback])
101 model.summary()
102 
103 
# Dump every trainable variable to a text file as three lines each:
# its name, its shape, and its values.
with open('./weights.txt', 'w') as weights_file:
    for var in model.trainable_variables:
        record = [str(var.name), str(var.shape), str(var.numpy())]
        weights_file.write('\n'.join(record) + '\n')
109 
110 
111 
def plot_acc_loss_curve(history):
    """Plot training/validation accuracy and loss curves side by side.

    Args:
        history: object whose ``history`` attribute is a dict holding the
            per-epoch series under the keys ``sparse_categorical_accuracy``,
            ``val_sparse_categorical_accuracy``, ``loss`` and ``val_loss``
            (this script passes the return value of ``model.fit``).

    Raises:
        KeyError: if one of the four series is missing from the dict.
    """
    # Imported locally so the function does not depend on module-level names.
    from matplotlib import pyplot as plt

    stats = history.history
    # Read all series up front so a missing key fails before any figure is
    # created. Each entry: (title, [(series, legend label), ...]).
    panels = [
        ('Training and Validation Accuracy', [
            (stats['sparse_categorical_accuracy'], 'Training Accuracy'),
            (stats['val_sparse_categorical_accuracy'], 'Validation Accuracy'),
        ]),
        ('Training and Validation Loss', [
            (stats['loss'], 'Training Loss'),
            (stats['val_loss'], 'Validation Loss'),
        ]),
    ]

    plt.figure(figsize=(15, 5))
    for position, (title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for series, label in curves:
            plt.plot(series, label=label)
        plt.title(title)
        # Both panels get a legend and a grid: previously the accuracy panel
        # had its legend commented out, leaving its two curves unlabelled.
        plt.legend()
        plt.grid()
    plt.show()


plot_acc_loss_curve(history)
137     

 

posted @ 2020-05-10 08:47  WWBlog  阅读(539)  评论(0)  编辑  收藏  举报