第五讲 卷积神经网络 --baseline
# Baseline CNN on CIFAR-10: train (resuming from a checkpoint if one exists),
# dump all trainable variables to ./weights.txt, and plot accuracy/loss curves.
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import (
    Conv2D,
    BatchNormalization,
    Activation,
    MaxPool2D,
    Dropout,
    Flatten,
    Dense,
)
from tensorflow.keras import Model


# Print arrays in full (no "..." truncation) so weights.txt contains every value.
np.set_printoptions(threshold=np.inf)

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Scale pixel values into [0, 1]. Both splits MUST use the same divisor: the
# training split was previously divided by 25.0 instead of 255.0, so the model
# was trained on inputs ~10x larger than the ones it was validated on.
x_train, x_test = x_train / 255.0, x_test / 255.0


class BaseLine(Model):
    """Baseline CNN: one conv block followed by a small dense classifier.

    Layer stack (as built in ``__init__`` and applied in ``call``):
    Conv2D(6, 5x5, 'same') -> BatchNormalization -> ReLU
    -> MaxPool2D(2x2, stride 2, 'same') -> Dropout(0.2)
    -> Flatten -> Dense(128, ReLU) -> Dropout(0.2) -> Dense(10, softmax).
    """

    def __init__(self):
        super(BaseLine, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # batch-normalization layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')  # one output per CIFAR-10 class

    def call(self, x):
        """Run the forward pass and return the softmax class probabilities."""
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = BaseLine()

# from_logits=False because the last Dense layer already applies softmax.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
# The ".index" file next to the checkpoint prefix is used as the marker that
# weights were saved by a previous run; if present, resume from them.
if os.path.exists(checkpoint_save_path + ".index"):
    print("--------------------load the model-----------------")
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()

# Dump name, shape and full values of every trainable variable for inspection.
with open('./weights.txt', 'w', encoding='utf-8') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')


def plot_acc_loss_curve(history):
    """Plot training vs. validation accuracy and loss side by side.

    Args:
        history: the object returned by ``model.fit``; its ``history`` dict
            must contain 'sparse_categorical_accuracy',
            'val_sparse_categorical_accuracy', 'loss' and 'val_loss'.
    """
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.show()


plot_acc_loss_curve(history)