【TensorFlow】分析模型常用函数

常用函数

获取模型输入节点信息

import tensorflow as tf
from tensorflow.python.tools import saved_model_utils

# Directory of the SavedModel to inspect.
model_dir = 'model_dir'
# Load the MetaGraphDef tagged with tf.saved_model.SERVING ("serve").
meta_graph_def = saved_model_utils.get_meta_graph_def(model_dir, tf.saved_model.SERVING)

# Map of signature name -> SignatureDef.
signatures = meta_graph_def.signature_def

# input key -> tensor name in the graph (as stored in TensorInfo.name).
input_tensor_names = {}
# input key -> static shape as a list of ints. Per TensorShapeProto, a size
# of -1 marks an unknown dimension; an empty list means either a scalar or an
# unknown rank (check tensor_shape.unknown_rank to tell them apart).
input_tensor_shapes = {}
# NOTE: both dicts are keyed by input key only, so when several signatures
# share an input key, the signature visited last overwrites the others.
for signature in signatures.values():
    for input_name, input_tensor in signature.inputs.items():
        input_tensor_names[input_name] = input_tensor.name
        # Keep the shape instead of discarding it (the original computed it
        # into a throwaway list on every iteration).
        input_tensor_shapes[input_name] = [
            int(dim.size) for dim in input_tensor.tensor_shape.dim
        ]

print(input_tensor_names)

  

获取模型输出节点信息

import tensorflow as tf
from tensorflow.python.tools import saved_model_utils

# Directory of the SavedModel to inspect.
model_dir = 'model_dir'
# Load the MetaGraphDef tagged with tf.saved_model.SERVING ("serve").
meta_graph_def = saved_model_utils.get_meta_graph_def(model_dir, tf.saved_model.SERVING)

# Map of signature name -> SignatureDef.
signatures = meta_graph_def.signature_def

# output key -> tensor name in the graph (as stored in TensorInfo.name).
output_tensor_names = {}
# output key -> static shape as a list of ints. Per TensorShapeProto, a size
# of -1 marks an unknown dimension; an empty list means either a scalar or an
# unknown rank (check tensor_shape.unknown_rank to tell them apart).
output_tensor_shapes = {}
# NOTE: both dicts are keyed by output key only, so when several signatures
# share an output key, the signature visited last overwrites the others.
for signature in signatures.values():
    for output_name, output_tensor in signature.outputs.items():
        output_tensor_names[output_name] = output_tensor.name
        # Keep the shape instead of discarding it (the original computed it
        # into a throwaway list on every iteration).
        output_tensor_shapes[output_name] = [
            int(dim.size) for dim in output_tensor.tensor_shape.dim
        ]

print(output_tensor_names)

  

posted @ 2024-03-04 11:12  周周周文阳  阅读(4)  评论(0)  编辑  收藏  举报