【TensorFlow】分析模型常用函数
常用函数
获取模型输入节点信息
import tensorflow as tf
from tensorflow.python.tools import saved_model_utils

# Print a mapping from each SavedModel signature *input* name to the name of
# the tensor recorded for it, for the MetaGraphDef tagged
# tf.saved_model.SERVING.

# Path to the SavedModel directory to inspect.
model_dir = 'model_dir'

# Load the MetaGraphDef whose tag set is tf.saved_model.SERVING and grab its
# signature map (signature name -> SignatureDef).
meta_graph_def = saved_model_utils.get_meta_graph_def(
    model_dir, tf.saved_model.SERVING
)
signatures = meta_graph_def.signature_def

# Collect {input name: tensor name} across every signature.
# NOTE: keys are the input names only, so if two signatures declare an input
# with the same name, the entry from the signature visited last wins.
input_tensor_names = {
    input_name: tensor_info.name
    for signature in signatures.values()
    for input_name, tensor_info in signature.inputs.items()
}
print(input_tensor_names)
获取模型输出节点信息
import tensorflow as tf
from tensorflow.python.tools import saved_model_utils

# Print a mapping from each SavedModel signature *output* name to the name of
# the tensor recorded for it, for the MetaGraphDef tagged
# tf.saved_model.SERVING.

# Path to the SavedModel directory to inspect.
model_dir = 'model_dir'

# Load the MetaGraphDef whose tag set is tf.saved_model.SERVING and grab its
# signature map (signature name -> SignatureDef).
meta_graph_def = saved_model_utils.get_meta_graph_def(
    model_dir, tf.saved_model.SERVING
)
signatures = meta_graph_def.signature_def

# Collect {output name: tensor name} across every signature.
# NOTE: keys are the output names only, so if two signatures declare an
# output with the same name, the entry from the signature visited last wins.
output_tensor_names = {
    output_name: tensor_info.name
    for signature in signatures.values()
    for output_name, tensor_info in signature.outputs.items()
}
print(output_tensor_names)