tensorflow 如何获取graph中的所有tensor name
import tensorflow as tf
saved_model_dir = "./saved_model"
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["serve"], saved_model_dir)
graph = tf.get_default_graph()
[print(n.name) for n in tf.get_default_graph().as_graph_def().node]
# 得到name之后,就可以获取相应的tensor了,例如:
# input_tensor = sess.graph.get_tensor_by_name('input:0')
# output_tensor = sess.graph.get_tensor_by_name('output:0')
找我内推: 字节跳动各种岗位
作者:
ZH奶酪(张贺)
邮箱:
cheesezh@qq.com
出处:
http://www.cnblogs.com/CheeseZH/
*
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。