Keras模型转换:h5 -> pb -> saved_model

TensorFlow 模型线上部署需要采用 SavedModel 格式,现将转换过程中踩过的坑记录如下:

"""
    Function: h5 model to pb to saved_model
"""
import os
import keras
import tensorflow
import tensorflow as tf
import keras.backend as K
from keras import backend
from tensorflow.python.platform import gfile
from keras.models import Model


def h5_to_pb(model_path, output_dir, model_name, out_prefix="output_", log_tensorboard=False, out_nodes=None):
    """
    Convert a Keras .h5 model file into a frozen TensorFlow .pb graph file.

    :param model_path: path of the .h5 model file to load
    :param output_dir: directory the .pb file is written to (created if missing)
    :param model_name: file name of the .pb file inside output_dir
    :param out_prefix: unused by this function; kept so existing callers keep working
    :param log_tensorboard: whether to also write a TensorBoard log of the
        frozen graph into output_dir; defaults to False
    :param out_nodes: names of the graph nodes (op names, without the ":0"
        suffix) that the frozen graph must keep as outputs; defaults to
        ["geek_vector/Relu"]
    :return: None; the frozen graph is written to output_dir/model_name
    """
    h5_model = keras.models.load_model(model_path, custom_objects={'tf': tf})
    # makedirs instead of mkdir so nested output paths are created as well,
    # and an already existing directory is not an error.
    os.makedirs(output_dir, exist_ok=True)
    if out_nodes is None:
        # Must match the real output node of the .h5 model; a wrong name
        # leads to hard-to-diagnose problems further down the pipeline.
        out_nodes = ["geek_vector/Relu"]

    # Print the model's input/output tensors so the node names can be checked.
    for i in h5_model.inputs:
        print("输入节点Tensor:{}".format(i))
    for i in h5_model.outputs:
        print("输出节点Tensor:{}".format(i))
    sess = backend.get_session()

    from tensorflow.python.framework import graph_util, graph_io
    # Freeze the variables into constants and write the binary .pb file.
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, list(out_nodes))
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    # Optionally emit a TensorBoard log next to the .pb file.
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def pb_to_savedmodel(pb_path=None, savedmodel_path=None, input_name=None, output_name=None):
    """
    Wrap a frozen .pb graph in a SavedModel with a "serving_default" signature.

    The frozen graph is imported into a dedicated graph rather than the
    process-wide default graph, so nodes of any model already loaded in this
    process (e.g. a Keras model) do not leak into the export.

    :param pb_path: path of the frozen .pb graph file
    :param savedmodel_path: base export directory; the model is written to
        the sub-directory "1" below it
    :param input_name: iterable of input node names (without the ":0"
        suffix); each becomes a signature input keyed by that name
    :param output_name: list of output tensor names (with the ":0" suffix);
        only the first entry is exported, under the key "geek_embedding".
        Defaults to ["output_1:0"]
    :return: None
    :raises TypeError: if input_name is None
    """
    if input_name is None:
        # Fail up front with a clear message instead of deep inside the
        # dict comprehension below.
        raise TypeError("input_name must be an iterable of input node names, got None")
    if output_name is None:
        output_name = ["output_1:0"]

    # Load the frozen graph definition (GFile replaces the deprecated FastGFile).
    with gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    export_path = os.path.join(tf.compat.as_bytes(savedmodel_path), tf.compat.as_bytes('1'))

    graph = tf.Graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    # The context managers make `graph` the default graph for every op created
    # below and guarantee the session is closed when the export is done.
    with graph.as_default(), tf.Session(graph=graph, config=config) as sess:
        tf.import_graph_def(graph_def, name='')  # import without a name prefix
        sess.run(tf.global_variables_initializer())

        # Build the TensorInfo bundles for the signature.
        input_dict = {
            k: tf.saved_model.utils.build_tensor_info(graph.get_tensor_by_name('{}:0'.format(k)))
            for k in input_name
        }
        output_1 = tf.saved_model.utils.build_tensor_info(graph.get_tensor_by_name(output_name[0]))

        # Export the model with a predict signature.
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs=input_dict,
                outputs={'geek_embedding': output_1},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'serving_default':
                    prediction_signature
            },
            main_op=tf.tables_initializer())
        builder.save()
    print('savedmodel 保存成功')
        
if __name__ == "__main__":
    # Locations of the source .h5 model, the intermediate .pb file and the
    # final SavedModel export directory.
    h5_model_path = r"/models/dengyx/embedding_recall/model_bak/20220223_no_history_newbert_integer/online_geek.h5"
    pb_model_path = r"/models/dengyx/embedding_recall/model_bak/pb_model"
    pb_model_name = r"online_geek_encoder.pb"
    saved_model_path = r"/models/dengyx/embedding_recall/model_bak/saved_model"

    # Feature columns on the "geek" side of the model.
    geek_names = [
        "age",
        "gender",
        "degree_code",
        "work_years",
        "work_position_code",
        "expect_position_code",
        "expect_city_code",
        "expect_low_salary",
        "expect_high_salary",
        "expect_type",
        "expect_position_type",
        "expect_sub_location",
    ]
    # One graph input per feature, named G1..G12, plus the query embedding.
    geek_features = ["G{}".format(idx) for idx, _ in enumerate(geek_names, start=1)]
    online_geek_feats = geek_features + ["query_embedding"]

    # IMPORTANT: run the two steps in separate invocations of this script
    # (comment one out while running the other), otherwise the exported
    # model comes out wrong.
    # step 01
    # h5_to_pb(h5_model_path, output_dir=pb_model_path, model_name=pb_model_name)
    # step 02
    pb_to_savedmodel(
        os.path.join(pb_model_path, pb_model_name),
        saved_model_path,
        input_name=online_geek_feats,
        output_name=["geek_vector/Relu:0"],
    )
   

 

posted @ 2022-02-25 16:12  今夜无风  阅读(702)  评论(0编辑  收藏  举报