python下tensorflow模型的导出

一 基本流程

1.python脚本中定义自己的模型,训练完成后将tensorflow graph定位导出为protobuf的二进制文件或者文本文件(一个仅有tensor定义但是不含有权重参数的文件);

2.python脚本训练过程保存模型参数文件*.ckpt;

3.调用tensorflow自带的freeze_graph.py小工具,输入格式为*.pb活在*.pbtxt的protobuf文件和*.ckpt的参数文件,输出为一个新的同时包含图定义和参数的*.pb文件(这个步骤的作用是把checkpoint .ckpt文件中的参数转化为常量const operator后和之前的tensor定义绑在一起)。

二 具体操作

1.几个用到的python API

(1)tf.train.write_graph

write_graph(
    graph_or_graph_def,
    logdir,
    name,
    as_text=True
)
The graph is written as text proto unless as_text is False.
v = tf.Variable(0, name='my_variable')
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')

or

v = tf.Variable(0, name='my_variable')
sess = tf.Session()
tf.train.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt')

Args:

  • graph_or_graph_def: A Graph or a GraphDef protocol buffer.
  • logdir:Directory where to write the graph. This can refer to remote fiesystems, such as Google Cloud Storage (GCS).
  • name:Filename for the graph.
  • as_text: If True ,writes the graph as an ASCII proto.

Returns:

The path of the output proto file.

(2)tf.name_scope

Aliases:

  • tf.contrib.keras.backend.name_scope
  • tf.name_scope
name_scope(
    name,
    default_name=None,
    values=None
)

 

For example, to define a new python op called my_op:

def my_op(a, b, c, name=None):
  with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
    a = tf.convert_to_tensor(a, name="a")
    b = tf.convert_to_tensor(b, name="b")
    c = tf.convert_to_tensor(c, name="c")
    # Define some computation that uses `a`, `b`, and `c`.
    return foo_op(..., name=scope)


Args:

  • name: The name argument that is passed to the op function.
  • default_name:The default name to use if the name argument is None.
  • values: The list of Tensor arguments that are passed to the op function.

Returns:

A context manager for use in defining Python ops. Yields the name scope.

Raises:

  • ValueError: if neither name nor default_name is provided but values are.

2.操作举例

    用一个简单的脚本,来训练一个包含1个隐含层的ANN模型来对Iris数据集分类,模型每层节点数:【5,64,3】

具体脚本参考:https://github.com/rockingdingo/tensorflow-tutorial

(1)定义Graph中输入和输出tensor名称

    为了方便我们在调用C++ API时,能够准确根据Tensor的名称取出对应的结果,在Python脚本训练时就定义好每个tensor的tensor_name。如果tensor包含命名空间namespace的如“namespace_A/tensor_A"需要用完整的名称。

 在这个例子中,我们定义以下三个tensor的tensorname:

    class TensorNameConfig(object):  
        input_tensor = "inputs"  
        target_tensor = "target"  
        output_tensor = "output_node"  
        # To Do  

 

(2)输出graph的定义文件*.pb和参数文件*.ckpt

    我们要在训练的脚本nn_model.py中加入两处代码:

第一处是将tensorflow的graph_def保存成./models/目录下一个文件nn_model.pbtxt,里面包含由图中各个tensor的定义名称等信息;

第二处是在训练代码红加入保存参数文件的代码,将训练好的ANN模型的权重Weight和Bias同时保存到./ckpt目录下的*.ckpt,*.meta等文件。

最后执行python nn_model.py就可以完成训练过程。

    # 保存图模型  
    tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)  
      
    # 保存 Checkpoint  
    checkpoint_path = os.path.join(FLAGS.train_dir, "nn_model.ckpt")  
    model.saver.save(session, checkpoint_path)  
      
    # 执行命令完成训练过程  
    python nn_model.py 

(3)使用freeze_graph.py小工具整合模型freeze_graph

    最后利用tensorflow自带的freeze_graph.py小工具把.ckpt文件中的参数固定在graph内,输出nn_model_frozen.pb

# 运行freeze_graph.py 小工具
# freeze the graph and the weights
python freeze_graph.py --input_graph=../model/nn_model.pbtxt --input_checkpoint=../ckpt/nn_model.ckpt --output_graph=../model/nn_model_frozen.pb --output_node_names=output_node

# 或者执行
sh build.sh

# 成功标志: 
# Converted 2 variables to const ops.
# 9 ops in the final graph.

Args:

  • --input_graph:模型的图的定义文件nn_model.pb(不包含权重);
  • --input_checkpoint:模型的参数文件nn_model.ckpt;
  • --output_graph:绑定后包含参数的图模型文件nn_model_frozen.pb;
  • --output_node_names:输出待机算的tensor名字

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

posted on 2017-10-16 22:06  一万种树  阅读(8507)  评论(0编辑  收藏  举报

导航