TensorFlow：导出带多个输出 tensor 的 pb 模型并完成加载

转换成 pb 模型，设定多个输出节点

def fun(output_graph='pb_model/query_encoder.pb',
        output_node=("loss/Softmax", "bert/pooler/dense/Tanh", "Mean"),
        ckpt_model=r'best_ckpt',
        bert_config_file=r'model/chinese_L-12_H-768_A-12/bert_config.json',
        max_seq_length=10):
    """Freeze the latest checkpoint into a single .pb file with several output nodes.

    The classifier built by ``create_model`` is rebuilt in a fresh graph and
    the latest checkpoint from ``ckpt_model`` is restored into it. Variables
    are then converted to constants, keeping only the subgraph needed to
    compute ``output_node``. The result is written to ``output_graph``.

    The original note said this "keeps BERT first- and second-layer
    information". Presumably that refers to the ``Mean`` node produced by
    ``create_model``; this cannot be verified from this file.

    Args:
        output_graph: Path of the frozen GraphDef file to write. Missing
            parent directories are created.
        output_node: Op names (no ``:0`` suffix) to keep as graph outputs.
        ckpt_model: Directory searched with ``tf.train.latest_checkpoint``.
        bert_config_file: Path to the BERT ``bert_config.json``.
        max_seq_length: Fixed second dimension of the ``input_ids`` and
            ``input_mask`` placeholders. Callers of the frozen graph must
            feed inputs of exactly this length.

    Raises:
        ValueError: If ``ckpt_model`` contains no checkpoint.
    """
    import os

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    # Build in a private graph so repeated calls do not collide with
    # variables left over in the process-wide default graph.
    graph = tf.Graph()
    # The `with` on the Session guarantees it is closed even if restore fails.
    with graph.as_default(), tf.Session(graph=graph, config=gpu_config) as sess:
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
        bert_config = modeling.BertConfig.from_json_file(bert_config_file)
        # Only the graph side effect matters here: outputs are selected
        # below by node name, so the returned tensors are not needed.
        create_model(bert_config=bert_config, is_training=False,
                     input_ids=input_ids_p, input_mask=input_mask_p,
                     segment_ids=None, labels=None,
                     num_labels=len(label_list),
                     use_one_hot_embeddings=False, fp16=FLAGS.use_fp16)
        checkpoint = tf.train.latest_checkpoint(ckpt_model)
        if checkpoint is None:
            raise ValueError("no checkpoint found in %r" % (ckpt_model,))
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, list(output_node))

    output_dir = os.path.dirname(output_graph)
    if output_dir:
        tf.gfile.MakeDirs(output_dir)
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())
    print('extract vector pb model saved!')

推理部分：加载 pb 模型并获取多个输出 tensor

class BertEncoder(object):
    """Inference wrapper around a frozen multi-output BERT ``.pb`` graph.

    The GraphDef at ``CURRENT_DIR/pb_model/<OUTPUT_GRAPH>`` is imported into
    a private graph under the ``import/`` name scope. An input placeholder
    tensor and an output tensor are exposed for each fed/fetched node:

    * ``input_ids_p``  -> ``import/input_ids:0``
    * ``input_mask_p`` -> ``import/input_mask:0``
    * ``output_1``     -> ``import/loss/Softmax:0``
    * ``output_2``     -> ``import/Mean:0``
    * ``output_3``     -> ``import/bert/pooler/dense/Tanh:0``

    ``model_graph`` holds the parsed GraphDef (key ``'output_graph_def'``)
    and the Session bound to the imported graph (key ``'sess'``).
    Call ``close()``, or use the instance as a context manager, to release
    the Session.
    """

    # Tensor names inside the frozen GraphDef, before the "import/" prefix is
    # applied. The order matches the attributes assigned in __init__.
    _RETURN_ELEMENTS = [
        "input_ids:0",
        "input_mask:0",
        "loss/Softmax:0",
        "Mean:0",
        "bert/pooler/dense/Tanh:0",
    ]

    def __init__(self, OUTPUT_GRAPH):
        """Load the frozen graph file named ``OUTPUT_GRAPH`` and open a Session on it.

        Args:
            OUTPUT_GRAPH: File name of the ``.pb`` file inside
                ``CURRENT_DIR/pb_model``.
        """
        # NOTE(review): the exporter in this post builds placeholders of shape
        # [None, 10]. If this .pb came from it, inputs must be exactly 10
        # tokens long, not 30 — confirm which length is intended.
        self.max_length = 30
        self.tokenizer = TOKENIZER
        self.out_graph = os.path.join(CURRENT_DIR, "pb_model", OUTPUT_GRAPH)
        self.model_graph = {}

        graph_def = tf.compat.v1.GraphDef()
        with open(self.out_graph, "rb") as f:
            graph_def.ParseFromString(f.read())
        self.model_graph['output_graph_def'] = graph_def

        graph = tf.Graph()
        with graph.as_default():
            # A frozen graph has no variables, so no initializer is needed.
            # Importing with return_elements both imports the graph and
            # fetches the tensors in one step.
            (self.input_ids_p,
             self.input_mask_p,
             self.output_1,
             self.output_2,
             self.output_3) = tf.compat.v1.import_graph_def(
                graph_def, name="import", return_elements=self._RETURN_ELEMENTS)
        self.model_graph['sess'] = tf.compat.v1.Session(graph=graph)

    def close(self):
        """Release the underlying Session."""
        self.model_graph['sess'].close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Do not suppress exceptions raised inside the `with` body.
        return False

 

posted @ 2021-05-08 20:14  今夜无风  阅读(147)  评论(0)  编辑  收藏  举报