tensorflow三种加载模型的方法和三种模型保存文件(.ckpt,.pb, SavedModel)
一、.ckpt文件的保存和加载
1、保存的文件
这是我保存的文件,保存一次有四个文件:
checkpoint文件:用于告知某些TF函数,这是最新的检查点文件(可以用记事本打开看一下)
.data文件:(后面缀的那一串我也布吉岛是啥)这个文件保存的是图中所有变量的值,没有结构。
.index文件:可能是保存了一些必要的索引叭(这个文件不大清楚)。
.meta文件:保存了计算图的结构,但是不包含里面变量的值。
使用这种方法保存模型时会保存成上面这四个文件,重新加载模型时通常只会用到.meta文件恢复图结构然后用.data文件把各个变量的值再加进去。
2、保存模型的方法
代码:
saver=tf.train.Saver(max_to_keep)
saver.save(sess,'D:/model',global_step=epoch)
创建一个saver(max_to_keep可设置要保存的模型的个数),调用save方法将当前sess会话中的图和变量等信息保存到指定路径,global_step代表当前的轮数,设置之后会在文件名后面缀一个‘-600’这样的东西
3、重加载模型的方法
saver=tf.train.import_meta_graph('model1/my-model-190.meta') #恢复计算图结构
saver.restore(sess, tf.train.latest_checkpoint("model/")) #恢复所有变量信息
现在sess中已经恢复了网络结构和变量信息了,接下来可以直接用节点的名称来调用:
print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})
或者采用:
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name('x:0')
input_y=graph.get_tensor_by_name('y:0')
op=graph.get_tensor_name('op:0')
print(sess.run(op,feed_dict={input_x:2,input_y:3)
这样子使用也可。
4、PS
.ckpt方式保存模型,这种模型文件是依赖 TensorFlow 的,只能在其框架下使用。
https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/这篇文章详细解释和演示了利用ckpt文件保存模型,并进行迁移学习的方法(不过是英文版的)
二、.pb文件的保存和加载
1、保存的文件
.pb文件里面保存了图结构+数据,加载模型时只需要这一个文件就好。
2、保存模型的方法
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op'])
with tf.gfile.FastGFile('D:/pycharm files/model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
3、加载模型的方法
with tf.gfile.FastGFile(pb_file_path, 'rb') as f:
graph_def = tf.GraphDef() # 生成图
graph_def.ParseFromString(f.read()) # 图加载模型
tf.import_graph_def(graph_def, name='')
接下来与前面的相同可以直接用节点的名称来调用:
print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})
或者采用:
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name('x:0')
input_y=graph.get_tensor_by_name('y:0')
op=graph.get_tensor_name('op:0')
print(sess.run(op,feed_dict={input_x:2,input_y:3)
这样子使用也可。
4、
谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小。
加载一个pb文件之后再对其进行微调(也就是将这个pb文件的网络作为自己网络的一部分),然后再保存成pb文件,后一个pb网络会包含前一个pb网络。
三、saved model
1、保存文件
在传入的目录下会有一个pb文件和一个文件夹:
2、保存模型
builder = tf.saved_model.builder.SavedModelBuilder(path)
builder.add_meta_graph_and_variables(sess,
['cpu_server_1'])
3、加载模型
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ['cpu_server_1'], pb_file_path+'savemodel')
接下来可以直接使用名字或者get_tensor_by_name后再进行使用
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op:0')
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
关于savedmodel文件,可阅读这篇博客,讲的很清楚https://blog.csdn.net/thriving_fcl/article/details/75213361
下面代码是实现上面三种保存模式的小例子,可以粘贴复制把相关的代码注释掉,运行一下看看结果,能加深理解:
import tensorflow as tf
with tf.Session() as sess:
#搭建网络
x=tf.placeholder(tf.float32,name='x')
y=tf.placeholder(tf.float32,name='y')
b=tf.Variable(1.,name='b')
xy=tf.multiply(x,y)
op=tf.add(xy,b,name='op')
sess.run(tf.global_variables_initializer())
print(sess.run(op,feed_dict={x:2,y:3}))
#ckpt保存
saver=tf.train.Saver()
saver.save(sess,'D:/pycharm files/111/ckpt/model_ck')
#pb保存
constant_graph=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,['op'])
with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','wb') as f:
f.write(constant_graph.SerializeToString())
#savedmodel文件保存
builder=tf.saved_model.builder.SavedModelBuilder('D:/pycharm files/111/savemodel')
builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
builder.save()
print('over')
#ckpt加载
saver=tf.train.import_meta_graph('D:/pycharm files/111/ckpt/model_ck.meta')
saver.restore(sess,tf.train.latest_checkpoint('D:/pycharm files/111/ckpt'))
#pb加载
with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','rb') as f:
graph_def=tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def,name='')
#savemodel加载
tf.saved_model.loader.load(sess, ['cpu_server_1'], 'D:/pycharm files/111/savemodel')
#测试模型加载是否成功
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op:0')
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)