tensorflow三种加载模型的方法和三种模型保存文件(.ckpt,.pb, SavedModel)

一、.ckpt文件的保存和加载

1、保存的文件

 

 这是我保存的文件,保存一次有四个文件:

checkpoint文件:用于告知某些TF函数,这是最新的检查点文件(可以用记事本打开看一下)

.data文件:(后面缀的那一串我也布吉岛是啥)这个文件保存的是图中所有变量的值,没有结构。

.index文件:可能是保存了一些必要的索引叭(这个文件不大清楚)。

.meta文件:保存了计算图的结构,但是不包含里面变量的值。

使用这种方法保存模型时会保存成上面这四个文件,重新加载模型时通常只会用到.meta文件恢复图结构然后用.data文件把各个变量的值再加进去。

2、保存模型的方法

代码:

saver=tf.train.Saver(max_to_keep)

saver.save(sess,'D:/model',global_step=epoch)

创建一个saver(max_to_keep可设置要保存的模型的个数),调用save方法将当前sess会话中的图和变量等信息保存到指定路径,global_step代表当前的轮数,设置之后会在文件名后面缀一个‘-600’这样的东西

3、重加载模型的方法

saver=tf.train.import_meta_graph('model1/my-model-190.meta')  #恢复计算图结构

saver.restore(sess, tf.train.latest_checkpoint("model/"))  #恢复所有变量信息

现在sess中已经恢复了网络结构和变量信息了,接下来可以直接用节点的名称来调用:

print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})

或者采用:

graph = tf.get_default_graph()

input_x = graph.get_tensor_by_name('x:0')

input_y=graph.get_tensor_by_name('y:0')

op=graph.get_tensor_name('op:0')

print(sess.run(op,feed_dict={input_x:2,input_y:3)

这样子使用也可。

4、PS

.ckpt方式保存模型,这种模型文件是依赖 TensorFlow 的,只能在其框架下使用。

https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/这篇文章详细解释和演示了利用ckpt文件保存模型,并进行迁移学习的方法(不过是英文版的)

二、.pb文件的保存和加载

1、保存的文件

 .pb文件里面保存了图结构+数据,加载模型时只需要这一个文件就好。

2、保存模型的方法

constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op'])

with tf.gfile.FastGFile('D:/pycharm files/model.pb', mode='wb') as f:

  f.write(constant_graph.SerializeToString())

3、加载模型的方法

with tf.gfile.FastGFile(pb_file_path, 'rb') as f:

  graph_def = tf.GraphDef() # 生成图

  graph_def.ParseFromString(f.read()) # 图加载模型

   tf.import_graph_def(graph_def, name='')

接下来与前面的相同可以直接用节点的名称来调用:

print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})

或者采用:

graph = tf.get_default_graph()

input_x = graph.get_tensor_by_name('x:0')

input_y=graph.get_tensor_by_name('y:0')

op=graph.get_tensor_name('op:0')

print(sess.run(op,feed_dict={input_x:2,input_y:3)

这样子使用也可。

4、

谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小。

加载一个pb文件之后再对其进行微调(也就是将这个pb文件的网络作为自己网络的一部分),然后再保存成pb文件,后一个pb网络会包含前一个pb网络。

三、saved model

1、保存文件

在传入的目录下会有一个pb文件和一个文件夹:

2、保存模型

builder = tf.saved_model.builder.SavedModelBuilder(path)

builder.add_meta_graph_and_variables(sess,['cpu_server_1'])

3、加载模型

with tf.Session(graph=tf.Graph()) as sess:

  tf.saved_model.loader.load(sess, ['cpu_server_1'], pb_file_path+'savemodel')

接下来可以直接使用名字或者get_tensor_by_name后再进行使用

  input_x = sess.graph.get_tensor_by_name('x:0')

  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op:0')

  ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})

关于savedmodel文件,可阅读这篇博客,讲的很清楚https://blog.csdn.net/thriving_fcl/article/details/75213361

 

下面代码是实现上面三种保存模式的小例子,可以粘贴复制把相关的代码注释掉,运行一下看看结果,能加深理解:

import tensorflow as tf
with tf.Session() as sess:
  #搭建网络
  x=tf.placeholder(tf.float32,name='x')
  y=tf.placeholder(tf.float32,name='y')
  b=tf.Variable(1.,name='b')
  xy=tf.multiply(x,y)
  op=tf.add(xy,b,name='op')
  sess.run(tf.global_variables_initializer())
  print(sess.run(op,feed_dict={x:2,y:3}))

  #ckpt保存
  saver=tf.train.Saver()
  saver.save(sess,'D:/pycharm files/111/ckpt/model_ck')

  #pb保存
  constant_graph=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,['op'])
  with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','wb') as f:
  f.write(constant_graph.SerializeToString())

  #savedmodel文件保存
  builder=tf.saved_model.builder.SavedModelBuilder('D:/pycharm files/111/savemodel')
  builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
  builder.save()

  print('over')


  #ckpt加载
  saver=tf.train.import_meta_graph('D:/pycharm files/111/ckpt/model_ck.meta')
  saver.restore(sess,tf.train.latest_checkpoint('D:/pycharm files/111/ckpt'))

  #pb加载
  with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','rb') as f:
    graph_def=tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')

  #savemodel加载
  tf.saved_model.loader.load(sess, ['cpu_server_1'], 'D:/pycharm files/111/savemodel')

  #测试模型加载是否成功
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  op = sess.graph.get_tensor_by_name('op:0')
  ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
  print(ret)

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

posted @ 2019-11-18 15:40  彼岸的客人  阅读(18445)  评论(0编辑  收藏  举报