BERT 模型 + RabbitMQ 队列实现实时预测,避免每次预测都重新加载计算图
1. 创建一个新类,使用 TensorFlow 内置的 `tf.data.Dataset.from_generator` 函数,通过生成器传入句子;生成器内部使用 while 循环,通过 channel 从 RabbitMQ 队列中获取待预测的句子
# Real-time predictor fed by a RabbitMQ queue: sentences arrive as queue
# messages and each prediction is printed as soon as it is produced.
class BertPredictByGen(object):
    """Run an Estimator continuously over sentences pulled from RabbitMQ.

    The input pipeline is built on a never-ending generator, so a single
    ``estimator.predict`` call serves every incoming message instead of
    starting a new predict call (and reloading the graph) per sentence.

    Args:
        estimator: Object exposing ``predict(input_fn, yield_single_examples=...)``.
        label_list: Sequence of label names; predictions are indexed into it.
        tokenizer: Tokenizer forwarded to ``convert_examples_to_features``.
        channel: RabbitMQ channel exposing ``basic_get`` and ``basic_ack``.
        queue_name: Name of the queue to read sentences from.
        poll_interval: Seconds to sleep after finding the queue empty before
            polling again. Defaults to 0.1.
    """

    def __init__(self, estimator, label_list, tokenizer, channel, queue_name,
                 poll_interval=0.1):
        self.estimator = estimator
        self.label_list = label_list
        self.tokenizer = tokenizer
        self.channel = channel
        self.queue_name = queue_name
        self.poll_interval = poll_interval

    def _next_text(self):
        """Block until a non-empty message is available and return it as text.

        Every delivered message is acknowledged, including ones with an empty
        body (which are then skipped), so nothing is left unacked on the
        channel. When the queue is empty the loop sleeps for
        ``poll_interval`` seconds rather than spinning.

        Returns:
            The message body decoded as UTF-8.
        """
        import time  # local import keeps this block self-contained

        while True:
            method, _properties, body = self.channel.basic_get(
                self.queue_name, auto_ack=False)
            if method is None:
                # Queue is empty: back off instead of busy-polling the broker.
                time.sleep(self.poll_interval)
                continue
            # Acknowledge before inspecting the body so empty messages are
            # consumed too rather than staying unacked.
            self.channel.basic_ack(delivery_tag=method.delivery_tag)
            if not body:
                continue
            return str(body, encoding='UTF-8')

    def input_fn_builder2(self):
        """Build the ``input_fn`` that ``estimator.predict`` consumes.

        Returns:
            A function ``input_fn(params)`` returning a ``tf.data.Dataset``
            backed by an endless generator of single-sentence batches.
        """

        def gen():
            while True:
                text = self._next_text()
                # guid is not used for prediction and may be anything, but
                # label must be one of the entries in label_list.
                examples = [InputExample(guid=0, text_a=text, text_b=None, label="其他")]
                features = convert_examples_to_features(
                    examples, self.label_list, FLAGS.max_seq_length, self.tokenizer)
                all_input_ids = []
                all_input_mask = []
                all_segment_ids = []
                all_label_ids = []
                for feature in features:
                    all_input_ids.append(feature.input_ids)
                    all_input_mask.append(feature.input_mask)
                    all_segment_ids.append(feature.segment_ids)
                    all_label_ids.append(feature.label_id)
                yield {
                    'input_ids': all_input_ids,
                    'input_mask': all_input_mask,
                    'segment_ids': all_segment_ids,
                    'label_ids': all_label_ids,
                }

        def input_fn(params):
            # params is accepted for the Estimator input_fn signature but is
            # not used here.
            types = {
                'input_ids': tf.int32,
                'input_mask': tf.int32,
                'segment_ids': tf.int32,
                'label_ids': tf.int32,
            }
            shapes = {
                'input_ids': (None, FLAGS.max_seq_length),
                'input_mask': (None, FLAGS.max_seq_length),
                'segment_ids': (None, FLAGS.max_seq_length),
                'label_ids': (None,),
            }
            dataset = tf.data.Dataset.from_generator(
                gen, output_types=types, output_shapes=shapes)
            return dataset.prefetch(1)

        return input_fn

    def predict(self):
        """Predict forever, printing the label chosen for each queued sentence.

        This call does not return on its own: the generator behind the input
        function never finishes.
        """
        results = self.estimator.predict(
            self.input_fn_builder2(), yield_single_examples=False)
        for result in results:
            # np.argmax without an axis works on the flattened array; the
            # generator yields one sentence per batch.
            # NOTE(review): assumes result['probabilities'] holds one row per
            # sentence - confirm against the model_fn.
            answer = self.label_list[np.argmax(result['probabilities'])]
            print("raw result:", answer)
2. 将原始源码中 `if FLAGS.do_predict:` 分支的内容改成如下:
if FLAGS.do_predict:
    # Load the project config file: a custom file written in the same style
    # as bert_config_file, used here to store the RabbitMQ settings.
    project_config = modeling.BertConfig.from_json_file(FLAGS.project_config_file)

    credentials = pika.PlainCredentials(
        username=project_config.queue_username,
        password=project_config.queue_password)
    connection = pika.BlockingConnection(
        pika.ConnectionParameters(
            host=project_config.queue_host,
            virtual_host=project_config.queue_virtual_host,
            credentials=credentials))
    try:
        channel = connection.channel()  # open a channel on the connection
        classifier = BertPredictByGen(
            estimator=estimator,
            label_list=label_list,
            tokenizer=tokenizer,
            channel=channel,
            queue_name=project_config.queue_name)
        # Runs until it raises or the process is interrupted.
        classifier.predict()
    finally:
        # Release the broker connection even when prediction fails or is
        # interrupted; skip if it has already been closed.
        if connection.is_open:
            connection.close()