tf用 tf.train.Saver类来实现神经网络模型的保存和读取。无论保存还是读取,都首先要创建saver对象。
用saver对象的save方法保存模型
保存的是所有变量
save( sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True )
保存模型需要session,初始化变量
用法示例
import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, "Model/model.ckpt", global_step=3)
输出
1. global_step 放在文件名后面,起个标记作用
2. save方法输出4个文件
// checkpoint 里面是一堆路径,model_checkpoint_path 记录了最新模型的路径,all_model_checkpoint_paths 记录了之前模型的路径
// model.ckpt-3.data-00000-of-00001 存放的是模型参数
// model.ckpt-3.meta 存放的是计算图
3. 最多只能保存近5次模型,比如我们迭代100次,每次保存一下,最后只留下了最近的5次。
用saver对象的restore方法加载模型
加载的是所有变量,以name为准,假如保存的模型中有变量叫 a ,value是2,那么在加载后,即使重新建立变量a,并赋其他value,其value仍然是2
restore(
sess,
save_path
)
加载模型需要session,不需要初始化变量
用法示例(接前例)
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2") # v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint result = v1 + v2 saver = tf.train.Saver() # with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./" print(sess.run(result)) # [ 3.]
1. 重新给 name为 v2的变量 赋值,其结果仍然是3,说明加载了之前的v2
2. 新建name为 v22 的变量,报错, 在保存的模型中没找到v2 。说明寻找变量以name为准,不以变量名为准
继续做如下尝试
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") # v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2") v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint result = v1 + v3 saver = tf.train.Saver() # with tf.Session() as sess: # sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./" # sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint print(sess.run(result)) # [ 3.]
1. 新建name为v22的变量v3,仍然报错,说明新的变量没有被接受
2. 在加载模型前初始化v3,仍然报错,加载模型后初始化v3,仍然报错,这说明在加载的模型中不接受新的变量。
继续尝试
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") # v2 = tf.Variable(tf.constant(7.0, shape=[1]), name="v2") v3 = tf.Variable(tf.constant(7.0, shape=[1]), name="v22") # Key v22 not found in checkpoint result = v1 + v3 saver = tf.train.Saver() # with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint print(sess.run(v3)) # [7.] saver.restore(sess, "./Model/model.ckpt-3") # 注意此处路径前添加"./" sess.run(tf.global_variables_initializer()) # Key v22 not found in checkpoint print(sess.run(result)) # [ 3.]
在加载模型前初始化变量,正确输出,但在加载后,报错,证实了我上面的说法,“不接受新的变量”
总结:
1. 模型加载加载的是所有变量,以name为准
2. 模型加载后不接受任何新的变量
3. 在加载模型时需要重新定义计算图上的所有节点,但是变量无需初始化
加载计算图
直接加载计算图就无需重新定义计算图上的节点
用法示例
saver = tf.train.import_meta_graph("Model/model.ckpt-3.meta") with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt-3") # 注意路径写法 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [3.] # print(sess.run(sess.graph.get_tensor_by_name('add:0'))) # [3.]
重命名变量
在加载模型时不接受新的变量,这会造成很多麻烦。
为解决这个问题,加载模型时可以给变量重命名。
用法示例
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") result = u1 + u2 # 若直接声明Saver类对象,会报错变量找不到 # 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} # 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 saver = tf.train.Saver({"v1": u1, "v2": u2}) with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt-3") print(sess.run(result)) # [ 3.]
注意重命名格式 老变量的name: 新变量名
参考资料:
https://blog.csdn.net/marsjhao/article/details/72829635
https://blog.csdn.net/shuzfan/article/details/79197432
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人