TensorFlow 模型的保存与调用
TensorFlow 常用的模型保存方式有两种:pb(把变量冻结为常量后的 GraphDef 文件)和 saved_model,两种方式都可以保存模型并在之后加载调用。
1、pb
1.1 将模型保存为 pb 文件
freozen_pb.py
# Build a tiny graph (out_add = in_x * in_y + b), freeze its variables into
# constants, and serialize the frozen GraphDef to PB_PATH.
import tensorflow as tf
from tensorflow.python.framework import graph_util

# Destination of the frozen graph; the loading script reads the same path.
PB_PATH = './model.pb'

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='in_x')
    y = tf.placeholder(tf.int32, name='in_y')
    b = tf.Variable(1, name='b')
    m = tf.multiply(x, y)
    a = tf.add(m, b, name='out_add')

    # Variables must hold values before they can be baked into constants.
    sess.run(tf.global_variables_initializer())

    # Replace the variables that 'out_add' depends on with Const nodes holding
    # their current values, so the .pb file needs no separate checkpoint.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['out_add'])

    # Quick check on the live graph: 10 * 3 + 1 -> 31.
    feed_dict = {x: 10, y: 3}
    print(sess.run(a, feed_dict))

    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile(PB_PATH, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
1.2 加载并调用 pb 模型
call_pb.py
# Load the frozen GraphDef from PB_PATH and evaluate out_add = in_x * in_y + b.
import tensorflow as tf
from tensorflow.python.platform import gfile

# Path of the frozen graph produced by the freezing script.
PB_PATH = './model.pb'

# Read and parse the serialized GraphDef (gfile.GFile replaces the deprecated
# FastGFile).
with gfile.GFile(PB_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into an explicit graph. Graph.as_default() returns a context manager
# and only takes effect inside a `with` statement.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

# A frozen graph stores its former variables as constants, so no variable
# initializer is needed. The `with` block closes the session on exit.
with tf.Session(graph=graph) as sess:
    in_x = graph.get_tensor_by_name('in_x:0')
    in_y = graph.get_tensor_by_name('in_y:0')
    out_add = graph.get_tensor_by_name('out_add:0')

    ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 9})
    print(ret)
2、 saved_model
2.1 将模型保存为 SavedModel
freozen_sm.py
# Build the toy graph (out_add = in_x * in_y + b) and export it in SavedModel
# format under ./sm in the current working directory.
import os
import tensorflow as tf

# Export directory: <cwd>/sm, the same location as the relative path './sm'.
# NOTE: SavedModelBuilder (used by simple_save) raises if this directory
# already exists and is not empty; delete it or choose another path before
# re-running the script.
saved_model_path = os.path.join(os.getcwd(), 'sm')

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='in_x')
    y = tf.placeholder(tf.int32, name='in_y')
    b = tf.Variable(1, name='b')
    m = tf.multiply(x, y)
    a = tf.add(m, b, name='out_add')

    # Variables need values before they can be saved.
    sess.run(tf.global_variables_initializer())

    # The keys of `inputs` / `outputs` become the signature's input/output
    # names in the exported model.
    tf.saved_model.simple_save(
        sess,
        saved_model_path,
        inputs={'in_x': x, 'in_y': y},
        outputs={'out_add': a},
    )
2.2 加载并调用 SavedModel 模型
call_sm.py
# Load the SavedModel exported to EXPORT_DIR and evaluate out_add through the
# model's default serving signature.
import tensorflow as tf

# Directory written by the SavedModel export script.
EXPORT_DIR = './sm'

# The `with` block closes the session on exit.
with tf.Session(graph=tf.Graph()) as sess:
    # TF1 loader: restores the graph and variables into `sess` and returns the
    # MetaGraphDef. (In TF2, tf.saved_model.load has a different signature.)
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], EXPORT_DIR)

    # simple_save registers its inputs/outputs under the default serving
    # signature; resolve tensor names from it instead of hard-coding op names.
    signature = meta_graph_def.signature_def[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    in_x = sess.graph.get_tensor_by_name(signature.inputs['in_x'].name)
    in_y = sess.graph.get_tensor_by_name(signature.inputs['in_y'].name)
    out_add = sess.graph.get_tensor_by_name(signature.outputs['out_add'].name)

    ret = sess.run(out_add, feed_dict={in_x: 8, in_y: 5})
    print(ret)