TensorFlow SavedModel to GraphDef

1. 使用 tf2onnx 工具，把 SavedModel 转换为 TensorFlow 的 GraphDef（不带 function，也就是 TF1 风格的计算图）。

https://github.com/onnx/tensorflow-onnx/blob/v1.9.3/tf2onnx/tf_loader.py

# -*- coding: utf-8 -*-

import os
import multiprocessing
from typing import List, Dict

try:
    from tf2onnx import tf_loader
except ImportError:
    # tf2onnx is missing: install the pinned version into the interpreter that
    # is running this script, then retry the import.
    import importlib
    import subprocess
    import sys

    # Use sys.executable rather than a hard-coded /usr/bin/python3 (and no
    # sudo, which may not exist, e.g. in a root container) so pip installs
    # into the same environment this process imports from (virtualenv/conda).
    # check_call raises CalledProcessError if pip fails, instead of silently
    # continuing into a confusing second ImportError.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tf2onnx==1.9.3"])
    # The import system caches directory listings; refresh them so the
    # freshly installed package is visible to the retry below.
    importlib.invalidate_caches()
    from tf2onnx import tf_loader

from tensorflow.core.protobuf import meta_graph_pb2, config_pb2
from tensorflow.python.grappler import tf_optimizer
from google.protobuf import text_format
from tensorflow.core.protobuf import rewriter_config_pb2
import tensorflow as tf

# Grappler passes that run_graph_grappler applies unless the caller passes
# its own list; convert_saved_model_to_graph_def relies on this pass to strip
# trivial Identity nodes and control dependencies for readability.
DEFAULT_OPTIMIZERS = ("dependency",)


def run_graph_grappler(graph, inputs, outputs, optimizers=DEFAULT_OPTIMIZERS):
    """Run one iteration of the given Grappler passes over a GraphDef.

    Args:
        graph: the ``GraphDef`` to optimize.
        inputs: input tensor names (e.g. ``"x:0"``) that must survive
            optimization.
        outputs: output tensor names that must survive optimization.
        optimizers: names of the Grappler passes to run. A single pass name
            may also be given as a plain string.

    Returns:
        The optimized ``GraphDef`` returned by ``tf_optimizer.OptimizeGraph``.
    """
    if isinstance(optimizers, str):
        # RepeatedField.extend() on a str would add one "optimizer" per character.
        optimizers = [optimizers]

    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.optimizers.extend(optimizers)
    rewrite_options.meta_optimizer_iterations = rewriter_config_pb2.RewriterConfig.ONE

    # export_meta_graph(graph_def=...) refuses to run while eager execution is
    # active. Calling tf.compat.v1.disable_eager_execution() here would switch
    # the whole process to graph mode for good and break later eager work
    # (e.g. converting a second SavedModel in the same process), so enter
    # graph mode only for the duration of the export.
    with tf.Graph().as_default():
        meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph)

    # Grappler treats the "train_op" collection as the set of fetch nodes, so
    # registering the inputs and outputs there keeps them from being pruned.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(inputs)
    fetch_collection.node_list.value.extend(outputs)
    meta_graph.collection_def['train_op'].CopyFrom(fetch_collection)
    return tf_optimizer.OptimizeGraph(config, meta_graph)


def is_control_dependency(node_name: str) -> bool:
    """Tell whether a NodeDef input string is a control edge, i.e. ``^name``."""
    return node_name[:1] == "^"


def is_saved_model_control_node(node: tf.compat.v1.NodeDef) -> bool:
    '''
    Recognize the NoOp sequencing nodes left behind by SavedModel function
    inlining, for example:

    node {
    name: "Func/StatefulPartitionedCall/input_control_node/_0"
    op: "NoOp"
    input: "^deep_fm4_1024"
    input: "^deep_fm4_1552"
    }

    A node qualifies when its op is "NoOp", its name contains
    "input_control_node" or "output_control_node", and every input is a
    control dependency (vacuously true for a node with no inputs).
    Control edges pointing at such nodes have to be dropped (see
    fix_saved_model_control_dependency) before the subgraph can be run
    for inference.
    '''
    markers = ("input_control_node", "output_control_node")
    if node.op == "NoOp" and any(marker in node.name for marker in markers):
        return all(is_control_dependency(name) for name in node.input)
    return False


def fix_saved_model_control_dependency(graph_def: tf.compat.v1.GraphDef):
    """Drop control edges that point at SavedModel input/output control nodes.

    Nodes recognized by ``is_saved_model_control_node`` (NoOp nodes such as
    ``Func/StatefulPartitionedCall/input_control_node/_0``) are collected, and
    every ``^<control node>`` entry is removed from the inputs of all nodes in
    ``graph_def``, so the subgraph can be run for inference. The control
    nodes themselves stay in the graph. All other inputs keep their original
    relative order.

    Args:
        graph_def: the GraphDef to edit in place.

    Returns:
        The same ``graph_def`` object.
    """
    control_inputs = {
        "^" + node.name for node in graph_def.node if is_saved_model_control_node(node)
    }
    if not control_inputs:
        return graph_def

    for node in graph_def.node:
        kept = [name for name in node.input if name not in control_inputs]
        if len(kept) != len(node.input):
            # Rebuild the repeated field rather than swap-and-pop: swapping
            # with the last element shuffled the surviving control inputs,
            # which made the emitted pbtxt harder to read and diff.
            node.input[:] = kept
    return graph_def


def _rename_input_reference(input_name: str, renames: Dict[str, str]) -> str:
    """Map one NodeDef input string ("name", "name:1" or "^name") through renames."""
    prefix = "^" if input_name.startswith("^") else ""
    node_name, sep, port = input_name[len(prefix):].partition(":")
    new_name = renames.get(node_name)
    if new_name is None:
        return input_name
    return prefix + new_name + sep + port


def fix_output_name(graph_def: "tf.compat.v1.GraphDef", outputs: List[str], alias_map: Dict[str, str]):
    '''
    Rename the nodes producing the model outputs to their signature aliases.

    outputs looks like:
    ['Identity:0', 'Identity_1:0',
    'Identity_2:0', 'Identity_3:0',
    'Identity_4:0', 'Identity_5:0',
    'Identity_6:0', 'Identity_7:0']

    alias_map looks like:
    {'Identity:0': 'logit_dislike', 'Identity_1:0': 'logit_like',
    'Identity_2:0': 'logit_play', 'Identity_3:0': 'logit_staytime',
    'Identity_4:0': 'pred_dislike', 'Identity_5:0': 'pred_like',
    'Identity_6:0': 'pred_play', 'Identity_7:0': 'pred_staytime'}

    The alias names are applied in place, so serving does not need an alias
    mapping. Only outputs on port 0 ("<node>:0") are matched. Inputs of other
    nodes that reference a renamed node ("name", "name:k", "^name") are
    rewritten too, so the graph stays connected.
    NOTE(review): it is not checked that an alias is unused by another node.

    Raises KeyError if a matched "<node>:0" output has no entry in alias_map.
    Returns the same graph_def object.
    '''
    output_set = set(outputs)  # O(1) membership instead of scanning a list per node

    # Phase 1: collect old -> new names, so every reference below is mapped
    # exactly once from its original name.
    renames = {}
    for node in graph_def.node:
        tensor_name = node.name + ":0"
        if tensor_name in output_set:
            renames[node.name] = alias_map[tensor_name]
    if not renames:
        return graph_def

    # Phase 2: rename the nodes and every input that refers to them.
    for node in graph_def.node:
        new_name = renames.get(node.name)
        if new_name is not None:
            node.name = new_name
        for i, input_name in enumerate(list(node.input)):
            renamed = _rename_input_reference(input_name, renames)
            if renamed != input_name:
                node.input[i] = renamed
    return graph_def


def convert_saved_model_to_graph_def(export_dir):
    """Convert the SavedModel in ``export_dir`` into a text-format GraphDef.

    The model is loaded with ``tf_loader.from_saved_model``, which yields a
    frozen GraphDef (TF1 style, without functions). Trivial Identity nodes and
    control dependencies are then cleaned up, the output nodes are renamed to
    their aliases, and the result is written to ``<export_dir>/graph.pbtxt``.

    Args:
        export_dir: directory that contains ``saved_model.pb``.

    Returns:
        The path of the written ``graph.pbtxt`` file.

    Raises:
        FileNotFoundError: if ``export_dir`` has no ``saved_model.pb``.
    """
    print("Start to convert saved model to graph def pbtxt", flush=True)
    saved_model_file = os.path.join(export_dir, "saved_model.pb")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(saved_model_file):
        raise FileNotFoundError("SavedModel not found: {}".format(saved_model_file))
    frozen_graph_def, inputs, outputs, alias_map = tf_loader.from_saved_model(
        export_dir, input_names=None, output_names=None,
        return_tensors_to_rename=True)
    # Remove trivial Identity nodes and control dependencies for readability.
    frozen_graph_def = run_graph_grappler(frozen_graph_def, inputs=inputs, outputs=outputs)
    frozen_graph_def = fix_saved_model_control_dependency(frozen_graph_def)
    frozen_graph_def = fix_output_name(frozen_graph_def, outputs, alias_map)
    graph_def_file = os.path.join(export_dir, "graph.pbtxt")
    with open(graph_def_file, 'w', encoding='utf-8') as f:
        f.write(text_format.MessageToString(frozen_graph_def))
    print("Convert saved model to graph def success", flush=True)
    return graph_def_file

  

----2022.09.28补充--------------

通过阅读 tf_loader 的源码，发现它在转换成 graph 的时候已经做了 grappler 优化，使用的优化器是 constfold 和 dependency。使用 constfold 会导致中间节点被折叠起来；如果不想让中间节点被折叠，禁用 constfold 优化即可。但这需要修改 tf_loader.py 的源码（目前还没找到只替换已导入模块中某个函数的方法）。

posted @ 2022-09-29 11:49  灰太狼锅锅  阅读(204)  评论(0)  编辑  收藏  举报