Keras模型转换:h5 -> pb -> saved_model
tf模型线上部署需要采用saved_model形式,现将踩过的坑记录如下:
""" Function: h5 model to pb to saved_model """
import os
import keras
import tensorflow
import tensorflow as tf
import keras.backend as K
from keras import backend
from tensorflow.python.platform import gfile
from keras.models import Model


def h5_to_pb(model_path, output_dir, model_name, out_prefix="output_", log_tensorboard=False):
    """Convert a Keras .h5 model file into a frozen-graph .pb file.

    :param model_path: path of the .h5 model file to load
    :param output_dir: directory the .pb file is written to (created if absent)
    :param model_name: file name of the resulting .pb model
    :param out_prefix: NOTE(review): not used in this body; adjust per training setup
    :param log_tensorboard: if True, also dump a TensorBoard log of the frozen graph
    :return: None (the .pb file is written into output_dir)
    """
    # 'tf' is registered as a custom object -- presumably the model uses raw tf
    # ops inside Lambda layers; TODO confirm against the training code.
    h5_model = keras.models.load_model(model_path,custom_objects={'tf': tf})
    if not os.path.exists(output_dir):
        # NOTE(review): os.mkdir fails if the parent directory is missing;
        # os.makedirs would be more robust.
        os.mkdir(output_dir)
    # Output node names must exactly match the h5 model's output tensors,
    # otherwise freezing produces hard-to-diagnose problems.
    out_nodes = ["geek_vector/Relu"]
    # Print the model's input/output tensors so the node names above can be verified.
    for i in h5_model.inputs:
        print("输入节点Tensor:{}".format(i))
    for i in h5_model.outputs:
        print("输出节点Tensor:{}".format(i))
    sess = backend.get_session()
    from tensorflow.python.framework import graph_util, graph_io
    # Freeze the graph: convert every variable in the session graph to a constant,
    # keeping only the subgraph needed to compute out_nodes. Then write it as a
    # binary .pb file.
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    # Optionally export the frozen graph for inspection in TensorBoard.
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)


def pb_to_savedmodel(pb_path=None, savedmodel_path=None, input_name=None, output_name=None):
    """Load a frozen .pb graph and export it in SavedModel format (for TF Serving).

    :param pb_path: path of the frozen .pb model file
    :param savedmodel_path: base directory of the SavedModel; version subdir '1' is appended
    :param input_name: iterable of input tensor names WITHOUT the ':0' suffix
    :param output_name: list of output tensor names WITH the ':0' suffix;
                        only the first entry is used; defaults to ["output_1:0"]
    :return: None (the SavedModel is written under savedmodel_path/1)
    """
    if output_name is None:
        output_name = ["output_1:0"]
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    with gfile.FastGFile(pb_path, 'rb') as f:
        # Parse the frozen graph file and import it into the default graph
        # under an empty name scope, so tensor names are preserved verbatim.
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        sess.run(tf.global_variables_initializer())
    # Collected for debugging only; not used below.
    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    # Build a TensorInfo bundle for each input tensor ('<name>:0') and the
    # single output tensor named in output_name[0].
    input_dict = {k:tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name('{}:0'.format(k)))
                  for k in input_name}
    output_1 = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name(output_name[0]))
    # SavedModel directories are versioned; export into subdirectory '1'.
    export_path = os.path.join(tf.compat.as_bytes(savedmodel_path), tf.compat.as_bytes('1'))
    # Export the model with a single PREDICT signature named 'serving_default'.
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=input_dict,
            outputs={'geek_embedding': output_1},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'serving_default': prediction_signature
        },
        main_op=tf.tables_initializer())
    builder.save()
    print('savedmodel 保存成功')


if __name__ == "__main__":
    h5_model_path = r"/models/dengyx/embedding_recall/model_bak/20220223_no_history_newbert_integer/online_geek.h5"
    pb_model_path = r"/models/dengyx/embedding_recall/model_bak/pb_model"
    pb_model_name = r"online_geek_encoder.pb"
    saved_model_path = r"/models/dengyx/embedding_recall/model_bak/saved_model"
    # Geek-side feature names; the graph's input placeholders are presumably
    # named G1..Gn in the same order -- verify against the trained model.
    geek_names = ["age","gender","degree_code","work_years","work_position_code","expect_position_code","expect_city_code","expect_low_salary",
                  "expect_high_salary","expect_type","expect_position_type","expect_sub_location"]
    geek_features = ["G"+str(i) for i in range(1,len(geek_names)+1)]
    online_geek_feats = geek_features + ["query_embedding"]
    ######################################################
    # Run the two conversion steps in SEPARATE runs,
    # otherwise the converted model comes out wrong.
    ######################################################
    # step 01
    # h5_to_pb(h5_model_path, output_dir=pb_model_path, model_name=pb_model_name)
    # step 02
    pb_to_savedmodel(os.path.join(pb_model_path, pb_model_name), saved_model_path, input_name=online_geek_feats, output_name=["geek_vector/Relu:0"])
时刻记着自己要成为什么样的人!