# 第五讲 卷积神经网络 (Lecture 5: Convolutional Neural Networks) — ResNet18 on CIFAR-10
"""Train a ResNet18-style network on CIFAR-10 and plot its learning curves.

Running the script loads CIFAR-10 and builds/compiles the model. It resumes from
a saved checkpoint if one exists, then trains for 5 epochs while keeping the best
weights. Afterwards it dumps every trainable variable to a text file and shows
accuracy/loss curves.
"""
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    MaxPooling2D,
)

# Checkpoint prefix for this model's weights. It must be specific to ResNet18 so
# this script neither resumes from nor overwrites another lecture's checkpoint
# (e.g. Inception10), whose variables do not match this architecture.
CHECKPOINT_SAVE_PATH = "./checkpoint/ResNet18.ckpt"
WEIGHTS_DUMP_PATH = "./weights.txt"


class ResnetBlock(Model):
    """Basic two-convolution residual block: ``out = relu(F(x) + shortcut(x))``.

    ``F(x)`` is conv3x3(stride=strides) -> BN -> ReLU -> conv3x3(stride=1) -> BN.

    Args:
        filters: Number of output channels of both 3x3 convolutions.
        strides: Stride of the first 3x3 convolution. When ``residual_path`` is
            True, it is also the stride of the 1x1 projection shortcut.
        residual_path: If True, the shortcut is a strided 1x1 convolution + BN,
            so it matches ``F(x)`` in spatial size and channel count. If False,
            the input is added unchanged. That requires the input to already have
            ``filters`` channels and ``strides == 1``.
    """

    def __init__(self, filters, strides=1, residual_path=False):
        super().__init__()
        self.filters = filters
        self.strides = strides
        self.residual_path = residual_path

        # NOTE: attribute names and creation order are kept stable. Object-based
        # TF checkpoints and the generated variable names depend on them.
        self.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b2 = BatchNormalization()

        # Projection shortcut: a 1x1 conv with the same stride reshapes x so it
        # can be added to F(x).
        if residual_path:
            self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)
            self.down_b1 = BatchNormalization()

        self.a2 = Activation('relu')

    def call(self, inputs):
        """Return ``relu(F(inputs) + inputs)``, or ``relu(F(inputs) + W·inputs)`` with a projection."""
        residual = inputs  # identity shortcut by default

        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)

        x = self.c2(x)
        y = self.b2(x)

        if self.residual_path:
            residual = self.down_c1(inputs)
            residual = self.down_b1(residual)

        # Sum the residual branch and the shortcut, then apply the final ReLU.
        return self.a2(y + residual)


class ResNet18(Model):
    """ResNet classifier with a 10-way softmax head.

    Layout: stem conv3x3 (stride 1) -> BN -> ReLU -> stacked ResnetBlocks ->
    global average pooling -> Dense(10, softmax, L2-regularized).

    Args:
        block_list: ``block_list[i]`` is the number of ResnetBlocks in stage
            ``i``. Each ResnetBlock contains two 3x3 convolutions. ``[2, 2, 2, 2]``
            gives the ResNet18 layout.
        initial_filters: Channel count of the stem and of the first stage.
            Each later stage doubles it.
    """

    def __init__(self, block_list, initial_filters=64):
        super().__init__()
        self.num_blocks = len(block_list)  # number of stages
        self.block_list = block_list
        self.out_filters = initial_filters

        self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        self.blocks = tf.keras.models.Sequential()
        for block_id, num_layers in enumerate(block_list):
            for layer_id in range(num_layers):
                if block_id != 0 and layer_id == 0:
                    # The first block of every stage after the first halves the
                    # spatial size and needs a projection shortcut, because the
                    # channel count also changes.
                    block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
                else:
                    block = ResnetBlock(self.out_filters, residual_path=False)
                self.blocks.add(block)
            self.out_filters *= 2  # the next stage uses twice as many filters

        self.p1 = tf.keras.layers.GlobalAveragePooling2D()
        self.f1 = tf.keras.layers.Dense(10, activation='softmax',
                                        kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, inputs):
        """Return class probabilities (softmax output) for a batch of images."""
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)
        x = self.blocks(x)
        x = self.p1(x)
        return self.f1(x)


def load_cifar10():
    """Load CIFAR-10 with pixel values scaled from [0, 255] to [0, 1].

    Returns:
        ``((x_train, y_train), (x_test, y_test))``, as returned by
        ``tf.keras.datasets.cifar10.load_data`` but with float32 images.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # Cast before scaling. Dividing the uint8 arrays directly would produce
    # float64 copies, twice the memory of the float32 that Keras computes in by
    # default.
    x_train = x_train.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0
    return (x_train, y_train), (x_test, y_test)


def _dump_trainable_variables(model, path):
    """Write each trainable variable's name, shape and full value to ``path``."""
    with open(path, 'w', encoding='utf-8') as f:
        for v in model.trainable_variables:
            f.write(str(v.name) + '\n')
            f.write(str(v.shape) + '\n')
            f.write(str(v.numpy()) + '\n')


def plot_acc_loss_curve(history):
    """Show training/validation accuracy and loss curves side by side.

    Args:
        history: The ``History`` object returned by ``model.fit``. Its
            ``history`` dict must contain ``sparse_categorical_accuracy``,
            ``val_sparse_categorical_accuracy``, ``loss`` and ``val_loss``.
    """
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.show()


def main():
    """Train ResNet18 on CIFAR-10, checkpoint it, dump its weights and plot curves."""
    # Print arrays in full (no "..." summarization) in the weight dump.
    np.set_printoptions(threshold=np.inf)

    (x_train, y_train), (x_test, y_test) = load_cifar10()

    model = ResNet18([2, 2, 2, 2])
    # The head already applies softmax, so the loss receives probabilities.
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])

    # A TF-format weights checkpoint is stored as "<prefix>.index" plus data
    # shards, so the index file's existence means a checkpoint is available.
    if os.path.exists(CHECKPOINT_SAVE_PATH + '.index'):
        print('-------------load the model---------------')
        model.load_weights(CHECKPOINT_SAVE_PATH)

    # Keep only the best epoch's weights. ModelCheckpoint monitors 'val_loss'
    # unless told otherwise.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=CHECKPOINT_SAVE_PATH,
                                                     save_weights_only=True,
                                                     save_best_only=True)

    history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                        validation_data=(x_test, y_test), validation_freq=1,
                        callbacks=[cp_callback])
    model.summary()

    _dump_trainable_variables(model, WEIGHTS_DUMP_PATH)
    plot_acc_loss_curve(history)


if __name__ == "__main__":
    main()