深度学习 Tensorflow(五)

保存模型

一、保存整个模型

整个模型可以保存到一个文件当中,其中包含权重值、模型配置乃至优化器配置。这样,就可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。

在 Keras 中保存完全可正常使用的模型非常有用,可以在 TensorFlow.js 中加载他们,然后在网络浏览器中训练和运行他们。

Keras 使用 HDF5 标准提供基本的保存格式。

 

model.save('less_model_10_14.h5')    # 保存模型,h5 格式
# 使用保存的模型
new_model = tf.keras.models.load_model('less_model_10_14.h5')

 

 

二、保存模型架构

保存模型架构,模型的层数设置,不保存权重和优化器设置

json_config = model.to_json()
# 模型恢复,重建
reinitialized_model = tf.keras.models.model_from_json(json_config)
reinitialized_model.summary()
# 重建的模型没有经过配置,权重是随机的,使用时需要配置优化器
reinitialized_model.compile(optimizer='adam',
                            loss='sparse_categorical_crossentropy',
                            metrics=['acc']
)

 

三、保存模型权重

保存模型的状态(权重值),可以通过 get_weights() 获取权值,通过 set_weights() 设置权重值

weights = model.get_weights()
reinitialized_model.set_weights(weights)
reinitialized_model.evaluate(test_image, test_label, verbose=0)

# 保存权重到本地文件
model.save_weights('less_weights_10_14.h5')
# 加载权重
reinitialized_model.load_weights('less_weights_10_14.h5')
reinitialized_model.evaluate(test_image, test_label, verbose=0)

 

四、在训练期间保存检查点

在训练期间或寻来结束自动保存检查点,这样可以使用经过训练的模型,无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断。

回调函数:tf.keras.callbacks.ModelCheckpoint

checkpoint_path = 'training/check_point_10_14.ckpt'
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True)    # 仅仅保存了权重值,如果保存整个模型,调用方式同上

model.fit(train_image, train_label, epochs=3, callbacks=[cp_callback])

# 当网络重新开始时调用检查点
model.load_weights(checkpoint_path)

 

五、自定义训练中保存检查点

cp_dir = './customtrain'
cp_prefix = os.path.join(cp_dir, 'ckpt')    # 添加文件前缀
checkpoint = tf.train.Checkpoint(optimizer = optimizer, model = model)
def train():
    for epoch in range(5):
        for (batch, (images, labels)) in enumerate(dataset):
            train_step(model, images, labels)
        print('Epoch{} loss is {}'.format(epoch, train_loss.result()))
        print('Epoch{} accuracy is {}'.format(epoch, train_accuracy.result()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        if (epoch + 1) % 2 == 0:    # 保存的频率
            checkpoint.save(file_prefix = cp_prefix)

# 恢复模型
tf.train.latest_checkpoint(cp_dir)    # 最新的检查点
checkpoint.restore(tf.train.latest_checkpoint(cp_dir))    # 通过模型的 name 属性对应来进行恢复

posted @ 2020-10-14 20:27  我脑子不好  阅读(169)  评论(0编辑  收藏  举报