BERT TensorFlow 版本的线上预测 demo
import os
import pickle
import time

import tensorflow as tf

from bert_crf import tokenization

model_dir = r'crf_output_bak/'
output_graph = './pb_model/query_model.pb'
bert_dir = r'chinese_L-12_H-768_A-12'

# Load the label -> id dictionary and build the reverse (id -> label) map.
# NOTE: pickle.load can execute arbitrary code; only load files produced by
# your own training run.
with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
    label2id = pickle.load(rf)
id2label = {value: key for key, value in label2id.items()}

with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
    label_list = pickle.load(rf)
num_labels = len(label_list)

tokenizer = tokenization.FullTokenizer(
    vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True)


def load_pb_predict():
    """Load the frozen .pb graph and run a prediction on a sample sentence.

    Prints every node name of the imported graph, then the average time
    spent per input sentence, then the decoded label sequences.
    """
    text = ['深汕特别合作区']
    # The model is fed one token per character.
    sentence = [[s for s in str(each)] for each in text]
    input_ids, input_mask = convert(sentence)

    with tf.Graph().as_default():
        # Use the compat.v1 API throughout, matching the Session below.
        output_graph_def = tf.compat.v1.GraphDef()
        with open(output_graph, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

        # Dump the node names so the input/output tensor names can be checked.
        res = [each.name for each in output_graph_def.node]
        for each in res:
            print(each)

        with tf.compat.v1.Session() as sess:
            t1 = time.time()
            input_ids_p = sess.graph.get_tensor_by_name("input_ids:0")
            input_mask_p = sess.graph.get_tensor_by_name("input_mask:0")
            feed_dict = {input_ids_p: input_ids, input_mask_p: input_mask}

            # Output tensor holding the decoded label ids.
            output_tensor_name = sess.graph.get_tensor_by_name(
                "viterbi/ReverseSequence_1:0")
            out = sess.run(output_tensor_name, feed_dict)
            pred_label_result = convert_id_to_label(out, id2label)
            t2 = time.time()

            # The printed value is elapsed seconds divided by the number of
            # input sentences (i.e. average seconds per sentence).
            print('模型预测吞吐量:{}'.format((t2-t1)/len(input_ids)))
            print(pred_label_result)


def convert_id_to_label(pred_ids_result, idx2label):
    """Map predicted label ids back to label strings.

    :param pred_ids_result: iterable of id sequences, one per sentence
    :param idx2label: mapping from label id to label string
    :return: list of label lists; id 0 and the '[CLS]' / '[SEP]' labels
        are dropped
    """
    special_labels = ('[CLS]', '[SEP]')
    result = []
    for pred_ids in pred_ids_result:
        curr_seq = []
        for ids in pred_ids:
            if ids == 0:
                continue
            curr_label = idx2label[ids]
            if curr_label in special_labels:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result


def convert(samples, max_seq_length=25):
    """Convert token sequences into parallel input-id and input-mask lists.

    :param samples: iterable of token lists (one list per sentence)
    :param max_seq_length: padded length of every sequence; defaults to the
        previously hard-coded 25
        (NOTE(review): presumably must match the length the graph was
        exported with - confirm)
    :return: (input_ids_list, input_mask_list)
    """
    features = [
        convert_single_example(0, line, label_list, max_seq_length)
        for line in samples
    ]
    input_ids_list = [feature.input_ids for feature in features]
    input_mask_list = [feature.input_mask for feature in features]
    return input_ids_list, input_mask_list


def convert_single_example(ex_index, example, label_list, max_seq_length):
    """Turn one tokenized sample into a padded InputFeatures object.

    :param ex_index: index of the sample (not used by this function)
    :param example: list of tokens for one sentence
    :param label_list: list of labels
    :param max_seq_length: length every output sequence is padded/truncated to
    :return: InputFeatures with input_ids, input_mask, segment_ids, label_ids
    """
    # Labels are indexed starting from 1.
    label_map = {label: i for i, label in enumerate(label_list, 1)}

    # Persist the label -> index map if it has not been saved yet.
    label2id_path = os.path.join(model_dir, 'label2id.pkl')
    if not os.path.exists(label2id_path):
        with open(label2id_path, 'wb') as w:
            pickle.dump(label_map, w)

    # Truncate so that [CLS] and [SEP] still fit within max_seq_length.
    # NOTE(review): tokens go straight to the vocab lookup without
    # tokenizer.tokenize(); out-of-vocab characters may fail - confirm.
    tokens = example[0:(max_seq_length - 2)]

    ntokens = ["[CLS]"] + list(tokens) + ["[SEP]"]
    segment_ids = [0] * len(ntokens)
    # Placeholder labels; convert() reads only input_ids and input_mask.
    label_ids = ["[CLS]"] + [0] * len(tokens) + ["[SEP]"]

    # Convert the tokens to vocabulary ids.
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)

    # Zero-pad every sequence up to max_seq_length.
    pad_len = max_seq_length - len(input_ids)
    input_ids.extend([0] * pad_len)
    input_mask.extend([0] * pad_len)
    segment_ids.extend([0] * pad_len)
    label_ids.extend([0] * pad_len)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(label_ids) == max_seq_length

    # Bundle everything into a single feature object.
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
    )
    return feature


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 input_ids,
                 input_mask,
                 segment_ids,
                 label_ids,
                 is_real_example=True):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.is_real_example = is_real_example


if __name__ == '__main__':
    load_pb_predict()
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步