【tensorflow】神经网络:断点续训
断点续训,即在一次训练结束后,可以先将得到的最优训练参数保存起来,待到下次训练时,直接读取最优参数,在此基础上继续训练。
读取模型参数:
存储模型参数的文件格式为 ckpt(checkpoint)。
生成 ckpt 文件时,会同步生成索引表,所以可通过判断是否存在索引表来判断是否存在模型参数。
# 模型参数保存路径 checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)
保存模型参数:
# 定义回调函数,在模型训练时,回调函数会被执行,完成保留参数操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(
# 文件保存路径
filepath=checkpoint_save_path,
# 是否只保留模型参数
save_weights_only=True,
# 是否只保留最优结果
save_best_only=True
)
# 执行训练过程,保存新的训练参数
history = model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1,
callbacks=[cp_callback])
代码:
import tensorflow as tf
import os
# 读取输入特征和标签
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0
# 声明网络结构
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax")
])
# 配置训练方法
model.compile(optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[tf.keras.metrics.sparse_categorical_accuracy])
# 如果存在参数文件,直接读取,在此基础上继续训练
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt" # 模型参数保存路径
if os.path.exists(checkpoint_save_path + ".index"):
model.load_weights(checkpoint_save_path)
# 定义回调函数,在模型训练时,完成保留参数操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
# 执行训练过程,保存新的训练参数
history = model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1,
callbacks=[cp_callback])
# 打印网络结构和参数
model.summary()