第五讲 卷积神经网络 - ResNet18 -- CIFAR-10

  1 import tensorflow as tf
  2 import os
  3 import numpy as np
  4 from matplotlib import pyplot as plt
  5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
  6 from tensorflow.keras import Model
  7 
  8 np.set_printoptions(threshold=np.inf)
  9 
 10 
 11 cifar10 = tf.keras.datasets.cifar10
 12 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
 13 x_train, x_test = x_train/255.0, x_test/255.0
 14 
 15 
 16 
 17 class ResnetBlock(Model):
 18     def __init__(self, filters, strides=1, residual_path=False):
 19         super(ResnetBlock, self).__init__()
 20         self.filters = filters
 21         self.strides = strides
 22         self.residual_path = residual_path
 23 
 24         self.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)
 25         self.b1 = BatchNormalization()
 26         self.a1 = Activation('relu')
 27 
 28         self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)
 29         self.b2 = BatchNormalization()
 30 
 31         # residual_path为True时,对输入进行下采样,即用1x1的卷积核做卷积操作,保证x能和F(x)维度相同,顺利相加
 32         if residual_path:
 33             self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)
 34             self.down_b1 = BatchNormalization()
 35 
 36         self.a2 = Activation('relu')
 37 
 38     def call(self, inputs):
 39         residual = inputs # residual等于输入值本身,即residual=x
 40         x = self.c1(inputs)
 41         x = self.b1(x)
 42         x = self.a1(x)
 43 
 44         x = self.c2(x)
 45         y = self.b2(x)
 46 
 47         if self.residual_path:
 48             residual  = self.down_c1(inputs)
 49             residual  = self.down_b1(residual)
 50 
 51         out = self.a2(y + residual) # 最后输出的是两部分的和,即F(x)+x或F(x)+Wx,再过激活函数
 52         return out
 53 
 54 
 55 
 56 class ResNet18(Model):
 57     def __init__(self, block_list, initial_filters=64): # block_list表示每个block有几个卷积层
 58         super(ResNet18, self).__init__()
 59         self.num_blocks = len(block_list) # 共有几个block
 60         self.block_list = block_list
 61         self.out_filters = initial_filters
 62         self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias = False)
 63         self.b1 = BatchNormalization()
 64         self.a1 = Activation('relu')
 65         self.blocks = tf.keras.models.Sequential()
 66         # 构建ResNet网络结构
 67         for block_id in range(len(block_list)):
 68             for layer_id in range(block_list[block_id]):
 69                 if block_id != 0 and layer_id == 0: # 对除第一个block以外的每个block的输入进行下采样
 70                     block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
 71                 else:
 72                     block = ResnetBlock(self.out_filters, residual_path=False)
 73                 self.blocks.add(block) # 将构建好的block加入resnet
 74             self.out_filters *= 2 # 下一个block的卷积核数是上一个block的2倍
 75         self.p1 = tf.keras.layers.GlobalAveragePooling2D()
 76         self.f1 = tf.keras.layers.Dense(10, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
 77 
 78     
 79     def call(self, inputs):
 80         x = self.c1(inputs)
 81         x = self.b1(x)
 82         x = self.a1(x)
 83         x = self.blocks(x)
 84         x = self.p1(x)
 85         y = self.f1(x)
 86         return y
 87 
 88 
 89 
 90 model = ResNet18([2, 2, 2, 2])
 91 
 92 model.compile(optimizer='adam',
 93               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
 94               metrics=['sparse_categorical_accuracy'])
 95 
 96 
 97 checkpoint_save_path = "./checkpoint/Inception10.ckpt"
 98 if os.path.exists(checkpoint_save_path + '.index'):
 99     print('-------------load the model---------------')
100     model.load_weights(checkpoint_save_path)
101 
102 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_save_path,
103                                                 save_weights_only = True,
104                                                 save_best_only = True)
105 
106 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test),validation_freq=1,
107                     callbacks=[cp_callback])
108 model.summary()
109 
110 
111 
112 with open('./weights.txt', 'w') as f:
113     for v in model.trainable_variables:
114         f.write(str(v.name) + '\n')
115         f.write(str(v.shape) + '\n')
116         f.write(str(v.numpy()) + '\n')
117 
118 
def plot_acc_loss_curve(history):
    """Show training/validation accuracy and loss curves side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the keys ``sparse_categorical_accuracy``,
            ``val_sparse_categorical_accuracy``, ``loss`` and ``val_loss``.
    """
    from matplotlib import pyplot as plt

    logs = history.history
    # Read every series up front so a missing key fails before any figure
    # is created.
    panels = [
        ('Accuracy',
         logs['sparse_categorical_accuracy'],
         logs['val_sparse_categorical_accuracy']),
        ('Loss', logs['loss'], logs['val_loss']),
    ]

    plt.figure(figsize=(15, 5))
    # One subplot per metric: accuracy on the left, loss on the right.
    for position, (metric, train_values, val_values) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_values, label=f'Training {metric}')
        plt.plot(val_values, label=f'Validation {metric}')
        plt.title(f'Training and Validation {metric}')
        plt.legend()
    plt.show()
142 
# Plot the accuracy and loss curves from the object returned by model.fit.
plot_acc_loss_curve(history)

 

posted @ 2020-05-10 10:07  WWBlog  阅读(402)  评论(0)  编辑  收藏  举报