# 第四讲 网络八股拓展——用 MNIST 数据集实现断点续训，并绘制准确率曲线和损失曲线
"""Checkpoint-resumable MNIST training with accuracy/loss curves.

Trains a small fully connected classifier on MNIST. If a previous run saved
weights to ``./checkpoint/mnist.ckpt`` they are loaded first, so re-running the
script continues training instead of starting from scratch. After training,
the trainable variables are printed and dumped to ``./weights.txt``, and the
training/validation accuracy and loss curves are plotted side by side.
"""
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

# Print arrays in full (no "..." truncation) so the weight dump is complete.
np.set_printoptions(threshold=np.inf)

# Load MNIST and scale pixel values from [0, 255] to [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# The output layer already applies softmax, so the loss is fed probabilities
# rather than raw logits (from_logits=False).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# Resume from an earlier run: a TF-format weights checkpoint includes a
# "<prefix>.index" file, so its presence means there is something to restore.
# NOTE(review): under Keras 3 (TF >= 2.16) weights-only checkpoints must use a
# path ending in ".weights.h5"; confirm the TF version this targets.
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('----------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save only the weights, and only when the monitored metric improves. No
# `monitor` is passed, so Keras' default monitored quantity is used.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)

model.summary()

# Dump every trainable variable (name, shape, values) to stdout and to a file.
print(model.trainable_variables)
with open('./weights.txt', 'w', encoding='utf-8') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

# Plot training vs. validation accuracy (left) and loss (right).
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()  # show the curve labels set above
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid()

plt.show()