.h5模型转.pb模型(tensorflow2)

转自https://www.cnblogs.com/buctyk/archive/2004/01/13/12932663.html

import tensorflow.compat.v1 as tf1
 
# Switch to TF1 graph mode before any graph state is touched: the TensorFlow
# docs say disable_v2_behavior() should be called at program start, before
# graphs or tensors are created, so that the calls below operate on the
# graph-mode default graph that the model is loaded into.
tf1.disable_v2_behavior()
tf1.reset_default_graph()
# Put Keras into inference mode (learning phase 0) before the model is built.
tf1.keras.backend.set_learning_phase(0)

# Load the trained Keras model whose graph will be frozen into a .pb file.
hdf5_pb_model = tf1.keras.models.load_model('mnist_test_cnn.h5')
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the variables of ``session``'s graph into constants.

    Args:
        session: TF1 session whose graph and variable values are frozen.
        keep_var_names: Optional iterable of variable names that must NOT be
            converted to constants. It is forwarded to
            ``convert_variables_to_constants`` as ``variable_names_blacklist``;
            ``None`` (the default) freezes every variable.
        output_names: Names of the output nodes passed to
            ``convert_variables_to_constants``. ``None`` is treated as an
            empty list.
        clear_devices: When true, the ``device`` field of every node is
            cleared before freezing.

    Returns:
        The GraphDef returned by ``remove_training_nodes`` for the frozen graph.
    """
    graph = session.graph
    with graph.as_default():
        output_names = output_names or []
        print("output_names", output_names)
        input_graph_def = graph.as_graph_def()
        # Node count before freezing, for comparison with the count below.
        print("len node1", len(input_graph_def.node))
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf1.graph_util.convert_variables_to_constants(
            session,
            input_graph_def,
            output_names,
            variable_names_blacklist=keep_var_names,
        )

        # Remove nodes that are not needed for inference.
        outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)
        print("##################################################################")
        for node in outgraph.node:
            print('node:', node.name)
        # Node count after freezing and pruning.
        print("len node2", len(outgraph.node))
        return outgraph

# Directory the frozen graph file is written into.
output_folder2 = 'keras_model'

# Freeze the graph held by the Keras backend session, using the op names of
# the loaded model's outputs as the output nodes. `tf1` already is
# `tensorflow.compat.v1`, so the backend is reached through `tf1.keras`
# exactly as in the setup code that loaded the model.
frozen_graph = freeze_session(
    tf1.keras.backend.get_session(),
    output_names=[out.op.name for out in hdf5_pb_model.outputs],
)
# Write the frozen GraphDef as a binary (non-text) protobuf file.
tf1.train.write_graph(frozen_graph, output_folder2, "classify.pb", as_text=False)

结果如图

 

 

 

 

第一种方法用的是 TensorFlow v1 的 API,第二种用的是 TensorFlow 2 的 API:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
 
def convert_h5to_pb(h5_path="./alexnet_cifar10.h5",
                    logdir="./frozen_models",
                    name="alexnet_tf2.pb"):
    """Convert a Keras .h5 model into a frozen TensorFlow graph (.pb) file.

    The model is loaded without compiling, wrapped in a ``tf.function``,
    traced for the spec of its first input, frozen with
    ``convert_variables_to_constants_v2`` and written as a binary graph file.
    The operation names, inputs and outputs of the frozen function are
    printed along the way.

    Args:
        h5_path: Path of the Keras .h5 model file to load.
        logdir: Directory the frozen graph is written into.
        name: File name of the frozen graph inside ``logdir``.
    """
    model = tf.keras.models.load_model(h5_path, compile=False)
    model.summary()

    # NOTE(review): the lambda argument name `Input` is kept unchanged on
    # purpose; it presumably determines the name of the graph's input
    # placeholder, which consumers of the .pb may rely on. Confirm before
    # renaming.
    full_model = tf.function(lambda Input: model(Input))
    concrete_func = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get the frozen ConcreteFunction (variables converted to constants).
    frozen_func = convert_variables_to_constants_v2(concrete_func)

    print("-" * 50)
    print("Frozen model layers: ")
    for op in frozen_func.graph.get_operations():
        print(op.name)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save the frozen graph from the frozen ConcreteFunction to disk.
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir=logdir,
                      name=name,
                      as_text=False)

 

posted @ 2021-01-27 16:18  Wangtn  阅读(2317)  评论(2)  编辑  收藏  举报