pb 模型导出（将变量固化为常量）

 

import tensorflow as tf  
from tensorflow.python.framework import graph_util  
  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(3.0, shape=[1]), name="v2")
# Name the output op explicitly. The loader reads tensor "add:0", so the node
# name is now a stated contract rather than an accident of TensorFlow's
# automatic op naming for `v1 + v2`.
result = tf.add(v1, v2, name="add")

# Output location. GFile(..., "wb") does not create missing parent
# directories, so the directory is created before writing.
MODEL_DIR = "d:/model"
MODEL_PATH = MODEL_DIR + "/combined_model.pb"

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Export the GraphDef of the current graph, i.e. the computation from the
    # inputs to the output.
    graph_def = sess.graph.as_graph_def()
    # Freeze the graph: replace the variables feeding `result` with constants
    # holding their current session values, keeping only the nodes needed to
    # compute `result`. Deriving the name from the op avoids a hard-coded
    # string drifting out of sync with the graph.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, [result.op.name])

    tf.gfile.MakeDirs(MODEL_DIR)
    with tf.gfile.GFile(MODEL_PATH, "wb") as f:
        f.write(output_graph_def.SerializeToString())

 

在 Python 中加载并使用 pb 模型

import tensorflow as tf  
from tensorflow.python.platform import gfile  
  
model_filename = "d:/model/combined_model.pb"

# Read and parse the frozen, serialized GraphDef written by the export script.
# GFile replaces the deprecated FastGFile.
with gfile.GFile(model_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into a fresh graph so nodes already in the default graph cannot
# collide with the imported ones. name="" keeps the original node names
# instead of prefixing them with "import/".
with tf.Graph().as_default() as graph:
    # "add:0" is the first output of the op named "add" in the exported graph.
    result = tf.import_graph_def(graph_def, name="",
                                 return_elements=["add:0"])

with tf.Session(graph=graph) as sess:
    # v1 (1.0) + v2 (3.0), frozen as constants at export time.
    print(sess.run(result))  # [array([4.], dtype=float32)]

 

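此模型无法直接在 Android 上通过 TensorFlow Lite 使用。

用 TFLite 的 Interpreter 加载该模型时会报如下错误：`java.lang.IllegalArgumentException: ByteBuffer is not a valid flatbuffer model`。原因是 TFLite 解释器只接受 FlatBuffer 格式的 `.tflite` 模型，而上面导出的是 protobuf 格式的 GraphDef（`.pb`）。需要先用 TFLite 转换器（如 TF 1.x 中的 `tf.lite.TFLiteConverter.from_frozen_graph`）把冻结后的 `.pb` 转换为 `.tflite`，再在 Android 上加载。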

 

posted @ 2019-04-30 10:32  牧 天  阅读(760)  评论(0)    收藏  举报