Squeezing Everything out of BERT -- Combined Output of Predicted Class and Feature Vectors, Parameter Freezing, and Inference
Features:
1) After fine-tuning, downstream tasks keep building on the fine-tuned model with the 12 BERT layers frozen. Method: load the fine-tuned model (not Google's original ckpt), then in custom_optimization.py apply gradient updates only to the variables that still need training:
update_var_list = []
tvars = tf.trainable_variables()
for v in tvars:
    # Collect only the variables that should still be updated.
    if "my_variable" in v.name:
        update_var_list.append(v)

# gvs = optimizer.compute_gradients(loss, tvars)
gvs = optimizer.compute_gradients(loss, update_var_list)
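For comparison, here is a minimal sketch (TF 1.x) of the same idea with a filter that freezes the 12 encoder blocks plus the embeddings. `optimizer` and `loss` are assumed to already exist in custom_optimization.py, and the name checks are only illustrative, so adapt them to your own variable names:

update_var_list = []
for v in tf.trainable_variables():
    # Skip everything inside the 12 frozen transformer blocks and the embedding tables.
    if v.name.startswith("bert/encoder/layer_") or v.name.startswith("bert/embeddings"):
        continue
    update_var_list.append(v)

# Gradients are computed and applied only for the unfrozen (downstream) variables.
gvs = optimizer.compute_gradients(loss, var_list=update_var_list)
train_op = optimizer.apply_gradients(gvs, global_step=tf.train.get_or_create_global_step())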
2) Along the way, output the 768-dim encoding vector of each character. Choose the vector source according to your own needs; it can be fetched directly from the graph and used downstream for similarity queries and retrieval.
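As a rough sketch of where such vectors can come from: assuming `model` is the `modeling.BertModel` instance built inside create_model(), the helper below fuses the first and last encoder layers per token and also grabs the pooled [CLS] vector; tensors of this kind are what the "Mean" and "bert/pooler/dense/Tanh" output nodes used below refer to.

import tensorflow as tf


def build_vector_outputs(model):
    """model: the modeling.BertModel instance (google-research/bert) created inside create_model()."""
    all_layers = model.get_all_encoder_layers()          # list of [batch, seq_len, 768] tensors
    first_last = (all_layers[0] + all_layers[-1]) / 2.0  # per-token fusion of first and last layer
    sentence_vec = tf.reduce_mean(first_last, axis=1)    # [batch, 768]; an unnamed reduce_mean shows up as the "Mean" node
    cls_vec = model.get_pooled_output()                  # [batch, 768]; the "bert/pooler/dense/Tanh" tensor
    return first_last, sentence_vec, cls_vec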
The main code for converting the ckpt to a pb file is recorded below:
import os
import tensorflow as tf
from tensorflow.python.platform import gfile

import modeling  # from google-research/bert; create_model() and FLAGS come from the training script


def bert_first_last_layer():
    """Keep the information of BERT's first and last layers."""
    OUTPUT_GRAPH = 'pb_model/my_model.pb'
    # output_node = ["bert/pooler/dense/Tanh", "Mean"]
    output_node = ["loss/Softmax", "bert/pooler/dense/Tanh", "Mean", "loss/Softmax_1"]
    ckpt_model = r'new_ckpt'
    bert_config_file = r'model/chinese_L-12_H-768_A-12/bert_config.json'
    max_seq_length = 350
    confidence_labels_length = 2

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    sess = tf.Session(config=gpu_config)
    graph = tf.get_default_graph()
    with graph.as_default():
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
        bert_config = modeling.BertConfig.from_json_file(bert_config_file)
        (loss, per_example_loss, logits, probabilities) = create_model(
            bert_config=bert_config,
            is_training=False,
            input_ids=input_ids_p,
            input_mask=input_mask_p,
            segment_ids=None,
            labels=None,
            num_labels=confidence_labels_length,
            use_one_hot_embeddings=False,
            fp16=FLAGS.use_fp16)
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_model))
        # Freeze the graph: fold variables into constants and keep only the listed output nodes.
        graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node)
        with tf.gfile.GFile(OUTPUT_GRAPH, "wb") as f:
            f.write(graph.SerializeToString())
    print('extract vector pb model saved!')


def pb_2_savedmodel(pb_path="pb_model/my_model.pb", savedmodel_path="merge_savedmodel", output_name=None):
    if output_name is None:
        output_name = ["loss/Softmax:0", "bert/pooler/dense/Tanh:0", "Mean:0", "loss/Softmax_1:0"]
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    with gfile.FastGFile(pb_path, 'rb') as f:  # load the frozen-graph (pb) model file
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')  # import the graph definition
    sess.run(tf.global_variables_initializer())

    # Build the tensor info bundles
    input_ids = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name('input_ids:0'))
    input_mask = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name('input_mask:0'))
    output_1 = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name(output_name[0]))
    output_2 = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name(output_name[1]))
    output_3 = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name(output_name[2]))
    output_4 = tf.saved_model.utils.build_tensor_info(sess.graph.get_tensor_by_name(output_name[3]))

    export_path = os.path.join(tf.compat.as_bytes(savedmodel_path), tf.compat.as_bytes('1'))
    # Export the model with a prediction signature
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input_ids': input_ids, 'input_mask': input_mask},
            outputs={'output_class': output_1,
                     'output_cls_vector': output_2,
                     'output_fl_vector': output_3,
                     'output_confidence_class': output_4},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={'a_signature': prediction_signature},
        main_op=tf.tables_initializer())
    builder.save()
    print('savedmodel saved successfully')


if __name__ == '__main__':
    # ckpt_2_pb()
    # read_tfrecord()
    # create_20210119()
    # extract_bert_vector()
    bert_first_last_layer()
    pb_2_savedmodel()
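Once exported, the SavedModel can be consumed through the 'a_signature' signature defined above. A minimal usage sketch, assuming the export directory merge_savedmodel/1 and zero-filled dummy arrays standing in for real tokenizer output:

import numpy as np
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # Load the SavedModel and look up the signature built in pb_2_savedmodel().
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], "merge_savedmodel/1")
    sig = meta_graph.signature_def['a_signature']

    # Dummy inputs; in practice these come from the BERT tokenizer, padded to max_seq_length.
    feed = {
        sig.inputs['input_ids'].name: np.zeros((1, 350), dtype=np.int32),
        sig.inputs['input_mask'].name: np.zeros((1, 350), dtype=np.int32),
    }
    fetches = {key: info.name for key, info in sig.outputs.items()}
    result = sess.run(fetches, feed_dict=feed)
    print(result['output_class'], result['output_fl_vector'].shape)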
3) At inference time, multiple graphs may conflict with one another, so getting the graph requires some care. The core code for loading the graph is recorded below:
class ConfidenceModel(object):
    """Inference wrapper around the frozen graph."""

    def __init__(self):
        self.max_length = 350
        self.tokenizer = TOKENIZER        # module-level BERT tokenizer
        self.out_graph = OUTPUT_GRAPH     # path to the frozen pb graph
        self.model_graph = {}
        graph = tf.Graph()
        with graph.as_default():
            self.model_graph['output_graph_def'] = tf.compat.v1.GraphDef()
            with open(self.out_graph, "rb") as f:
                self.model_graph['output_graph_def'].ParseFromString(f.read())
            # Bind the session to this dedicated graph so multiple models do not clash.
            self.model_graph['sess'] = tf.Session(graph=graph)
            with self.model_graph['sess'].as_default():
                with graph.as_default():
                    self.model_graph['sess'].run(tf.compat.v1.global_variables_initializer())
                    _input_1, _input_2, _output_1, _output_2, _output_3, _output_4 = tf.import_graph_def(
                        self.model_graph['output_graph_def'],
                        return_elements=[INPUT_1, INPUT_2, SOFTMAX_OUTPUT,
                                         FIRST_LAST_OUTPUT, CLS_OUTPUT, CONFIDENCE_OUTPUT])
                    self.input_ids_p = self.model_graph['sess'].graph.get_tensor_by_name("import/input_ids:0")
                    self.input_mask_p = self.model_graph['sess'].graph.get_tensor_by_name("import/input_mask:0")
                    self.output_1 = self.model_graph['sess'].graph.get_tensor_by_name("import/loss/Softmax:0")
                    self.output_2 = self.model_graph['sess'].graph.get_tensor_by_name("import/Mean:0")
                    self.output_3 = self.model_graph['sess'].graph.get_tensor_by_name("import/bert/pooler/dense/Tanh:0")
                    self.output_4 = self.model_graph['sess'].graph.get_tensor_by_name("import/loss/Softmax_1:0")
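On top of these tensors, a predict() method can feed a single sentence and fetch all four outputs in one run. A minimal sketch to add inside the ConfidenceModel class above, assuming TOKENIZER is a BERT FullTokenizer and that padding/truncation mirror the preprocessing used during training:

    def predict(self, text):
        """Run the frozen graph on a single sentence (hypothetical helper)."""
        tokens = ["[CLS]"] + self.tokenizer.tokenize(text)[: self.max_length - 2] + ["[SEP]"]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        pad = self.max_length - len(input_ids)
        input_ids = input_ids + [0] * pad
        input_mask = input_mask + [0] * pad

        # output_1: class probabilities, output_2: first/last-layer "Mean" vector,
        # output_3: pooled [CLS] vector, output_4: confidence probabilities.
        class_prob, fl_vector, cls_vector, confidence_prob = self.model_graph['sess'].run(
            [self.output_1, self.output_2, self.output_3, self.output_4],
            feed_dict={self.input_ids_p: [input_ids], self.input_mask_p: [input_mask]})
        return class_prob, fl_vector, cls_vector, confidence_prob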
Always remember what kind of person you want to become!