tensorflow保存读取-【老鱼学tensorflow】

当我们对模型进行了训练后,就需要把模型保存起来,便于在预测时直接用已经训练好的模型进行预测。

保存模型的权重和偏置值

假设我们已经训练好了模型,其中有关于weights和biases的值,例如:

import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

然后我们初始化这些变量的值,假装是训练后被设置上的值:

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

最后进行保存:

# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)

这样在打印出:

保存的路径为: D:/todel/python/saver/save_net.ckpt

在那个目录下,我们看到:

这样,这些训练后的参数就被保存起来了。

完整的保存参数的代码为:

import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)

恢复模型的权重和偏置值

在我们训练好模型并把训练后的权重和偏置值保存了之后,当我们需要进行预测时,只要读取这个已经保存好的权重和偏置值就可以进行预测了。
当然,这里的模型结构还是需要进行创建的,因为我们保存的仅仅是权重值和偏置值。

首先定义要恢复的权重和偏置值的结构:

import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

注意:其中的name要跟之前保存时一致。

然后进行加载:

saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))

这样输出为:

weights: [[ 1.  2.  3.]
 [ 3.  4.  5.]]
biases: [[ 1.  2.  3.]]

就是前面我们保存的内容被恢复出来了。

完整的恢复代码为:

import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))

posted @ 2018-02-28 09:33  dreampursuer  阅读(464)  评论(0编辑  收藏  举报