深度学习 Tensorflow(五)
保存模型
一、保存整个模型
整个模型可以保存到一个文件当中,其中包含权重值、模型配置乃至优化器配置。这样,就可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。
在 Keras 中保存完全可正常使用的模型非常有用,可以在 TensorFlow.js 中加载他们,然后在网络浏览器中训练和运行他们。
Keras 使用 HDF5 标准提供基本的保存格式。
model.save('less_model_10_14.h5') # 保存模型,h5 格式
# 使用保存的模型
new_model = tf.keras.models.load_model('less_model_10_14.h5')
二、保存模型架构
保存模型架构,模型的层数设置,不保存权重和优化器设置
json_config = model.to_json()
# 模型恢复,重建
reinitialized_model = tf.keras.models.model_from_json(json_config)
reinitialized_model.summary()
# 重建的模型没有经过配置,权重是随机的,使用时需要配置优化器
reinitialized_model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['acc']
)
三、保存模型权重
保存模型的状态(权重值),可以通过 get_weights() 获取权值,通过 set_weights() 设置权重值
weights = model.get_weights()
reinitialized_model.set_weights(weights)
reinitialized_model.evaluate(test_image, test_label, verbose=0)
# 保存权重到本地文件
model.save_weights('less_weights_10_14.h5')
# 加载权重
reinitialized_model.load_weights('less_weights_10_14.h5')
reinitialized_model.evaluate(test_image, test_label, verbose=0)
四、在训练期间保存检查点
在训练期间或寻来结束自动保存检查点,这样可以使用经过训练的模型,无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断。
回调函数:tf.keras.callbacks.ModelCheckpoint
checkpoint_path = 'training/check_point_10_14.ckpt'
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True) # 仅仅保存了权重值,如果保存整个模型,调用方式同上
model.fit(train_image, train_label, epochs=3, callbacks=[cp_callback])
# 当网络重新开始时调用检查点
model.load_weights(checkpoint_path)
五、自定义训练中保存检查点
cp_dir = './customtrain'
cp_prefix = os.path.join(cp_dir, 'ckpt') # 添加文件前缀
checkpoint = tf.train.Checkpoint(optimizer = optimizer, model = model)
def train():
for epoch in range(5):
for (batch, (images, labels)) in enumerate(dataset):
train_step(model, images, labels)
print('Epoch{} loss is {}'.format(epoch, train_loss.result()))
print('Epoch{} accuracy is {}'.format(epoch, train_accuracy.result()))
train_loss.reset_states()
train_accuracy.reset_states()
if (epoch + 1) % 2 == 0: # 保存的频率
checkpoint.save(file_prefix = cp_prefix)
# 恢复模型
tf.train.latest_checkpoint(cp_dir) # 最新的检查点
checkpoint.restore(tf.train.latest_checkpoint(cp_dir)) # 通过模型的 name 属性对应来进行恢复