tf.train.Saver()模型保存与恢复

1.保存

将训练好的模型参数保存起来,以便以后进行验证或测试。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数经常会用到,max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐,如:

saver=tf.train.Saver(max_to_keep=0)

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,‘ckpt/mnist.ckpt',global_step=step)

第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

 

复制代码
a = tf.Variable(1., tf.float32)
b = tf.Variable(2., tf.float32)
num = 10

model_save_path = './mod/'
model_name = 'mod'

saver = tf.train.Saver()

# with tf.Session() as sess:
#     init_op = tf.global_variables_initializer()
#     sess.run(init_op)
#     for step in np.arange(num):
#         c = sess.run(tf.add(a, b))
#         # checkpoint_path = os.path.join(model_save_path, model_name)
#         # 默认最多同时存放 5 个模型
#         saver.save(sess, os.path.join(model_save_path, model_name), global_step=step)
复制代码

 

Tensorflow 会自动生成4个文件

第一个文件为 model.ckpt.meta,保存了 Tensorflow 计算图的结构,可以简单理解为神经网络的网络结构。

model.ckpt.index 和 model.ckpt.data-*****-of-***** 文件保存了所有变量的取值。

最后一个文件为 checkpoint 文件,保存了一个目录下所有的模型文件列表。

 

 

with tf.Session() as sess:
ckpt=tf.train.get_checkpoint_state('mod/')
print(ckpt)

tf.train.get_checkpoint_state函数通过checkpoint文件找到模型文件名。

tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
该函数返回的是checkpoint文件CheckpointState proto类型的内容,其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。


model_checkpoint_path: "mod/mod-9"
all_model_checkpoint_paths: "mod/mod-5"
all_model_checkpoint_paths: "mod/mod-6"
all_model_checkpoint_paths: "mod/mod-7"
all_model_checkpoint_paths: "mod/mod-8"
all_model_checkpoint_paths: "mod/mod-9"

# 载入模型,不需要提供模型的名字,会通过 checkpoint 文件定位到最新保存的模型
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)

 

posted on   cltt  阅读(806)  评论(0编辑  收藏  举报

编辑推荐:
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 25岁的心里话
· 按钮权限的设计及实现
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

导航

统计

点击右上角即可分享
微信分享提示