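Some of this content is compiled from publicly available sources, and the original sources may not always be credited. If anything infringes your rights, please contact me at gnivor@163.com. Thank you.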

TensorFlow notes - MetaGraph explained

References
https://blog.csdn.net/weixin_36670529/article/details/87895103
https://www.jianshu.com/p/b853a46dcbb1

1. MetaGraph

https://github.com/tensorflow/tensorflow/blob/v2.12.1/tensorflow/core/protobuf/meta_graph.proto
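TensorFlow uses MetaGraphDef to store a computation graph (GraphDef) together with the metadata needed to run it: the saver configuration, collections, signatures, asset files, and the object graph of functions and stateful objects.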

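message MetaGraphDef {
  // Meta information about the exported graph (the nested MetaInfoDef
  // message definition is omitted here).
  MetaInfoDef meta_info_def = 1;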

  // GraphDef.
  GraphDef graph_def = 2;

  // SaverDef.
  SaverDef saver_def = 3;

  // collection_def: Map from collection name to collections.
  // See CollectionDef section for details.
  map<string, CollectionDef> collection_def = 4;

  // signature_def: Map from user supplied key for a signature to a single
  // SignatureDef.
  map<string, SignatureDef> signature_def = 5;

  // Asset file def to be used with the defined graph.
  repeated AssetFileDef asset_file_def = 6;

  // Extra information about the structure of functions and stateful objects.
  SavedObjectGraph object_graph_def = 7;
}

2. SignatureDef

https://github.com/tensorflow/tensorflow/blob/v2.10.0/tensorflow/core/protobuf/meta_graph.proto#L257
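In TensorFlow, a signature describes a model's inputs and outputs. It records the names, dtypes, and shapes of the inputs and outputs, so users can understand and call the model without inspecting the graph itself.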

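A SavedModel exported for serving usually has a default signature stored under the key "serving_default". A signature defines one or more inputs and outputs, and each of them carries:

  1. Name: the key of the input or output within the signature, used to refer to it from TensorFlow or from a client.
  2. TensorInfo: the dtype, shape, and name of the underlying tensor in the graph.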

2.1 Protobuf definition

message SignatureDef {
  // Named input parameters.
  map<string, TensorInfo> inputs = 1;
  // Named output parameters.
  map<string, TensorInfo> outputs = 2;
  // Extensible method_name information enabling third-party users to mark a
  // SignatureDef as supporting a particular method. This enables producers and
  // consumers of SignatureDefs, e.g. a model definition library and a serving
  // library to have a clear hand-off regarding the semantics of a computation.
  //
  // Note that multiple SignatureDefs in a single MetaGraphDef may have the same
  // method_name. This is commonly used to support multi-headed computation,
  // where a single graph computation may return multiple results.
  string method_name = 3;
}

2.2 Format (example)

// SignatureDef defines the signature of a computation supported by a TensorFlow
// graph.
//
// For example, a model with two loss computations, sharing a single input,
// might have the following signature_def map, in a MetaGraphDef message.
//
// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key,
// output key, and method_name are identical, and will be used by system(s) that
// implement or rely upon this particular loss method. The output tensor names
// differ, demonstrating how different outputs can exist for the same method.
//
// signature_def {
//   key: "loss_A"
//   value {
//     inputs {
//       key: "input"
//       value {
//         name: "input:0"
//         dtype: DT_STRING
//         tensor_shape: ...
//       }
//     }
//     outputs {
//       key: "loss_output"
//       value {
//         name: "loss_output_A:0"
//         dtype: DT_FLOAT
//         tensor_shape: ...
//       }
//     }
//     method_name: "some/package/compute_loss"
//   }
//   ...
// }
// signature_def {
//   key: "loss_B"
//   value {
//     inputs {
//       key: "input"
//       value {
//         name: "input:0"
//         dtype: DT_STRING
//         tensor_shape: ...
//       }
//     }
//     outputs {
//       key: "loss_output"
//       value {
//         name: "loss_output_B:0"
//         dtype: DT_FLOAT
//         tensor_shape: ...
//       }
//     }
//     method_name: "some/package/compute_loss"
//   }
//   ...
// }

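key: signature name (e.g. "loss_A")
value.inputs.key: logical input name within the signature (e.g. "input")
value.inputs.value.name: name of the actual tensor in the graph (e.g. "input:0")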

2.3 Getting tensor names from a signature name

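In C++, getting the input and output tensor names from a model signature with the TensorFlow SavedModel loader API takes three steps:

  1. Load the SavedModel.
  2. Get the SignatureDef.
  3. Extract the input and output tensor names from the SignatureDef.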
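#include <iostream>
#include <string>
#include <unordered_set>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

int main() {
  const std::string export_dir = "path/to/saved_model";
  tensorflow::SavedModelBundle bundle;

  // Load the model (the MetaGraphDef tagged "serve")
  tensorflow::SessionOptions session_options;
  tensorflow::RunOptions run_options;
  std::unordered_set<std::string> tags = {tensorflow::kSavedModelTagServe};
  tensorflow::Status status = tensorflow::LoadSavedModel(
      session_options, run_options, export_dir, tags, &bundle);
  if (!status.ok()) {
    std::cerr << "Failed to load SavedModel: " << status.ToString() << std::endl;
    return 1;
  }

  // Get the signature by name
  const auto& signature_def_map = bundle.meta_graph_def.signature_def();
  const auto it = signature_def_map.find("serving_default");
  if (it == signature_def_map.end()) {
    std::cerr << "Signature \"serving_default\" not found" << std::endl;
    return 1;
  }
  const tensorflow::SignatureDef& signature_def = it->second;

  // Get the input and output tensor names.
  // "input_layer" / "output_layer" are the signature's input/output keys;
  // replace them with the keys your model actually exports.
  std::string input_tensor_name = signature_def.inputs().at("input_layer").name();
  std::string output_tensor_name = signature_def.outputs().at("output_layer").name();

  std::cout << "Input tensor name: " << input_tensor_name << std::endl;
  std::cout << "Output tensor name: " << output_tensor_name << std::endl;

  return 0;
}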
posted @ 2023-09-07 19:22  流了个火