第四讲 网络八股拓展--用mnist数据集实现断点续训
"""Lecture 4 extension: resumable ("checkpoint and resume") training on MNIST.

Trains a small fully connected classifier on MNIST. If a weight checkpoint
from a previous run exists, it is loaded first, so training continues from
the saved weights instead of starting from scratch. During training the
weights are written back to the same checkpoint whenever the monitored
validation loss improves.
"""
import os

import tensorflow as tf

# Prefix of the TensorFlow-format weight checkpoint. Saving weights with this
# prefix writes "<prefix>.index" plus "<prefix>.data-*" shard files, which is
# why the resume check below looks for the ".index" file.
# NOTE(review): a ".ckpt" TF-format weights path works with tf.keras (Keras 2,
# TF <= 2.15). Keras 3 (TF >= 2.16) expects save_weights_only paths to end in
# ".weights.h5" -- switch the suffix and the existence check when upgrading.
checkpoint_save_path = "./checkpoint/mnist.ckpt"


def load_mnist():
    """Return ``((x_train, y_train), (x_test, y_test))`` with pixels scaled to [0, 1]."""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Raw pixel values are 0-255; divide so the network sees floats in [0, 1].
    return (x_train / 255.0, y_train), (x_test / 255.0, y_test)


def build_model():
    """Build and compile the Flatten -> Dense(128, relu) -> Dense(10, softmax) classifier."""
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        # The output layer already applies softmax, so the loss receives
        # probabilities rather than logits (from_logits=False). The "sparse"
        # loss/metric take integer class labels instead of one-hot vectors.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['sparse_categorical_accuracy'],
    )
    return model


def restore_checkpoint(model, path=checkpoint_save_path):
    """Load weights saved at *path* into *model* if a previous run left them.

    Returns True when weights were loaded, False when no checkpoint exists.
    """
    if not os.path.exists(path + ".index"):
        return False
    print("-----------------load the model-----------------------")
    model.load_weights(path)
    return True


def main():
    """Train (resuming from the checkpoint if present) and print the model summary.

    Returns the ``History`` object produced by ``model.fit``.
    """
    (x_train, y_train), (x_test, y_test) = load_mnist()
    model = build_model()
    restore_checkpoint(model, checkpoint_save_path)

    # Only the weights are saved, and only when the monitored quantity
    # (Keras' default monitor, "val_loss") improves.
    # NOTE(review): the callback's "best so far" appears to start fresh on
    # every run, so the first epoch after resuming likely overwrites the
    # checkpoint even if it is worse than the previous run's best. Newer
    # Keras versions offer `initial_value_threshold` for this -- confirm.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True,
    )

    history = model.fit(
        x_train, y_train,
        batch_size=32,
        epochs=5,
        validation_data=(x_test, y_test),
        validation_freq=1,  # validate after every epoch
        callbacks=[cp_callback],
    )
    # summary() comes after fit(): Flatten is given no input shape, so the
    # model's layer shapes are only known once it has seen data.
    model.summary()
    return history


if __name__ == "__main__":
    main()