一、获取pb模型的节点名称
import os

import tensorflow as tf

# Directory and file name of the .pb model to inspect -- fill these in.
# (The original used typographic quotes ‘ ’ here, which is a SyntaxError.)
model_dir = ''
model_name = ''


def create_graph():
    """Parse the GraphDef stored at model_dir/model_name into the default graph."""
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # name='' imports the nodes without adding an "import/" prefix to their names.
        tf.import_graph_def(graph_def, name='')


create_graph()
# Names of every node (operation) in the imported graph.
tensor_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]
# Open in text mode: the names are str, which cannot be written to a binary ('wb')
# file in Python 3. The with-block also guarantees the file is closed.
with open('/home/yk/Desktop/Conv-TasNet-master-20190508/op.txt', 'w', encoding='utf-8') as f:
    for tensor_name in tensor_name_list:
        print(tensor_name)
        f.write(tensor_name + '\n')
二、ckpt转换为pb模型
from tensorflow.python.tools import inspect_checkpoint as chkp
import tensorflow as tf

# Rebuild the graph structure from the checkpoint's .meta file; clear_devices=True
# drops the device placements recorded at save time.
saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)

# [IMPORTANT] Fill in the names of the model's output nodes here.
output_nodes = ["xxx"]

with tf.Session(graph=tf.get_default_graph()) as sess:
    input_graph_def = sess.graph.as_graph_def()
    # Load the trained variable values into the session.
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    # Freeze: replace variables with constants holding their restored values,
    # keeping the part of the graph needed to compute output_nodes.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_nodes)

# Serialize the frozen GraphDef to a binary .pb file.
with open("frozen_model.pb", "wb") as f:
    f.write(output_graph_def.SerializeToString())
三、pb TensorBoard 可视化
# 1. Restore the computation graph from the .pb file.
import tensorflow as tf

model = 'model.pb'  # Change this to the path of your own .pb file.
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
# with-block so the model file handle is closed after reading.
with tf.gfile.FastGFile(model, 'rb') as f:
    graph_def.ParseFromString(f.read())
# name='graph' prefixes every imported node name with "graph/".
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)
# close() flushes pending events so the graph is fully written before the script exits.
summaryWriter.close()

# Running the code above writes an events file into log/, e.g.
#     log/events.out.tfevents.1535079670.DESKTOP-5IRM000
# 2. Load it in TensorBoard:
#     tensorboard --logdir path/to/log
# 3. Open the link TensorBoard prints in a browser.
附加：获取ckpt模型中保存的变量（节点）名称
import os

from tensorflow.python import pywrap_tensorflow

# Checkpoint path prefix to inspect.
checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)

# Mapping of variable name -> shape for everything stored in the checkpoint.
var_to_shape_map = reader.get_variable_to_shape_map()
for var_name in iter(var_to_shape_map):
    print("tensor_name: ", var_name)
    # Uncomment to also print the stored value of each variable:
    # print(reader.get_tensor(var_name))