[Text Summarization Project] Part 4: Decoding Algorithms and Evaluating Test Results

Background

The previous articles in this series covered text preprocessing, model construction, and training. This article covers what comes after training: how to use the model to generate text, and how to measure the model's performance.

Core Content

To get a complete baseline as quickly as possible, this article uses two common decoding algorithms, greedy decoding and beam search, so the implementation below revolves around these two. The prediction code follows roughly the same structure as the training code and lives mainly in predict.py.
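
For orientation, below is a rough sketch of what the predict.py entry point could look like. The Vocab and Seq2Seq constructors, the load_test_dataset helper, and the params keys are assumptions carried over from the earlier articles; greedy_decode, beam_decode, beam_test_batch_generator and get_rouge are the functions covered in the rest of this post.

import tensorflow as tf

# Hypothetical sketch of predict.py; constructor and helper names are assumptions.
def predict(params):
    vocab = Vocab(params['vocab_path'])      # assumed vocabulary loader
    model = Seq2Seq(params)                  # rebuild the trained model

    # restore the latest checkpoint (detailed in "Model Loading" below)
    checkpoint = tf.train.Checkpoint(Seq2Seq=model)
    manager = tf.train.CheckpointManager(checkpoint, params['checkpoint_dir'], max_to_keep=5)
    checkpoint.restore(manager.latest_checkpoint)

    if params.get('decode_mode') == 'beam':
        results = []
        for batch in beam_test_batch_generator(params['beam_size']):
            results.append(beam_decode(model, batch, vocab, params).abstract)
    else:
        test_x = load_test_dataset(params)   # assumed: indexed, padded test inputs
        results = greedy_decode(model, test_x, params['batch_size'], vocab, params)

    get_rouge(results)                       # ROUGE evaluation, see the last section
    return results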

Model Loading

First, reload the trained TensorFlow model. Sample code:

checkpoint = tf.train.Checkpoint(Seq2Seq=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, seq2seq_checkpoint_dir, max_to_keep=5)
checkpoint.restore(checkpoint_manager.latest_checkpoint)
# checkpoint.restore('../../data/checkpoints/training_checkpoints_seq2seq/ckpt-6')
if checkpoint_manager.latest_checkpoint:
    print("Restored from {}".format(checkpoint_manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

In TensorFlow, the latest_checkpoint property of tf.train.CheckpointManager points to the most recently saved checkpoint, so restoring from it automatically reloads the latest trained model. After loading the model, we still need to define the decoding algorithm. This article implements two: greedy search and beam search. Sample code follows.

Decoding Algorithms

def greedy_decode(model, data_X, batch_size, vocab, params):
    # decoded results
    results = []
    # number of samples
    sample_size = len(data_X)
    # number of batch iterations; math.ceil rounds up because the last batch
    # may be smaller than batch_size but still has to be decoded
    steps_epoch = math.ceil(sample_size / batch_size)
    # iterate over [0, steps_epoch)
    for i in tqdm(range(steps_epoch)):
        batch_data = data_X[i * batch_size:(i + 1) * batch_size]
        results += batch_greedy_decode(model, batch_data, vocab, params)
    return results

Beam search is driven by the following loop over the test batches:

b = beam_test_batch_generator(params["beam_size"])
results = []
for batch in b:
    best_hyp = beam_decode(model, batch, vocab, params)
    results.append(best_hyp.abstract)
# get_rouge(results)  # evaluate the generated results; covered in the last section

The beam search code consists of two parts: loading the test data and decoding with the model.
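
The data-loading half (beam_test_batch_generator) is not reproduced in this post. Below is a minimal sketch under the assumption that every yielded "batch" is a single test sample tiled beam_size times, so that beam_decode can track beam_size hypotheses in parallel; load_test_dataset is a hypothetical helper returning indexed, padded test inputs.

def beam_test_batch_generator(beam_size):
    # Assumed behaviour: each yielded batch repeats one test sample beam_size
    # times, so params['batch_size'] is set equal to beam_size when decoding.
    test_x = load_test_dataset()  # hypothetical helper
    for row in test_x:
        yield tf.convert_to_tensor([row for _ in range(beam_size)])

Under that assumption the encoder produces beam_size identical rows per batch, which is why beam_decode initializes every hypothesis with enc_hidden[0].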

def beam_decode(model, batch_data, vocab, params):
    # indices of the special tokens
    start_index = vocab.word_to_index(vocab.START_DECODING)
    stop_index = vocab.STOP_DECODING_INDEX
    unk_index = vocab.UNKNOWN_TOKEN_INDEX
    batch_size = params['batch_size']

    # run the decoder for a single step
    def decoder_one_step(enc_output, dec_input, dec_hidden):
        final_pred, dec_hidden, attention_weights = model.decoder(dec_input, dec_hidden, enc_output)

        # take the top 2*beam_size indices and their probabilities
        top_k_probs, top_k_idx = tf.nn.top_k(tf.squeeze(final_pred), k=params['beam_size'] * 2)

        # switch to log probabilities so scores can be accumulated by addition
        top_k_log_probs = tf.math.log(top_k_probs)

        results = {
            'dec_hidden': dec_hidden,
            'attention_weights': attention_weights,
            'top_k_idx': top_k_idx,
            'top_k_log_probs': top_k_log_probs
        }

        return results

    # encoder input from the test batch
    enc_input = batch_data
    init_enc_hidden = model.encoder.initialize_hidden_state()

    # run the encoder
    enc_output, enc_hidden = model.encoder(enc_input, init_enc_hidden)

    hyps_batch = [Hypothesis(tokens=[start_index],
                             log_probs=[0.],
                             hidden=enc_hidden[0],
                             attn_dists=[]) for _ in range(batch_size)]

    # finished hypotheses
    results = []
    steps = 0  # decoding step counter

    # keep searching while under the length limit and short of finished hypotheses
    while steps < params['max_dec_len'] and len(results) < params['beam_size']:

        # latest token of each live hypothesis
        latest_tokens = [hyps.latest_token for hyps in hyps_batch]
        # replace OOV tokens with the UNK token
        latest_tokens = [token if token in vocab.index2word else unk_index for token in latest_tokens]

        # hidden state of each live hypothesis
        hiddens = [hyps.hidden for hyps in hyps_batch]

        dec_input = tf.expand_dims(latest_tokens, axis=1)
        dec_hidden = tf.stack(hiddens, axis=0)

        # run the decoder for one step
        decoder_results = decoder_one_step(enc_output, dec_input, dec_hidden)

        dec_hidden = decoder_results['dec_hidden']
        attention_weights = decoder_results['attention_weights']
        top_k_log_probs = decoder_results['top_k_log_probs']
        top_k_idx = decoder_results['top_k_idx']

        # all candidate hypotheses at this step
        all_hyps = []

        # number of live hypotheses (only one at the very first step)
        num_ori_hyps = 1 if steps == 0 else len(hyps_batch)

        # extend every live hypothesis with each of its top-k candidates
        for i in range(num_ori_hyps):
            hyps, new_hidden, attn_dist = hyps_batch[i], dec_hidden[i], attention_weights[i]

            for j in range(params['beam_size'] * 2):
                new_hyps = hyps.extend(
                    token=top_k_idx[i, j].numpy(),
                    log_prob=top_k_log_probs[i, j],
                    hidden=new_hidden,
                    attn_dist=attn_dist
                )

                all_hyps.append(new_hyps)


        # reset the live hypotheses and rank all candidates by average log probability
        hyps_batch = []
        sorted_hyps = sorted(all_hyps, key=lambda h: h.avg_log_prob, reverse=True)

        # filter: finished hypotheses go to the result set, the rest stay live
        for h in sorted_hyps:
            if h.latest_token == stop_index:
                # when the stop token is reached and the length requirement is met,
                # strip the start/stop tokens and add the hypothesis to the results
                if steps >= params['min_dec_steps']:
                    h.tokens = h.tokens[1:-1]
                    results.append(h)
            else:
                hyps_batch.append(h)

            if len(hyps_batch) == params['beam_size'] or len(results) == params['beam_size']:
                break

        steps += 1

    if len(results) == 0:
        results = hyps_batch

    hyps_sorted = sorted(results, key=lambda h: h.avg_log_prob, reverse=True)
    print_top_k(hyps_sorted, 3, vocab, batch_data)

    best_hyp = hyps_sorted[0]
    best_hyp.abstract = ' '.join([vocab.index_to_word(index) for index in best_hyp.tokens])

    return best_hyp

The batch-level greedy decoding invoked from greedy_decode above:

def batch_greedy_decode(model, batch_data, vocab, params):
    # batch size
    batch_size = len(batch_data)

    # running prediction string for each sample in the batch
    predicts = [''] * batch_size

    inputs = tf.convert_to_tensor(batch_data)
    # 0. initialize the encoder hidden state
    init_hidden = tf.zeros(shape=(batch_size, params['enc_units']))
    # 1. run the encoder
    enc_output, enc_hidden = model.encoder(inputs, init_hidden)

    # 2. pass the encoder state to the decoder
    dec_hidden = enc_hidden

    # 3. <START> * batch_size
    dec_input = tf.expand_dims([vocab.word_to_index(vocab.START_DECODING)] * batch_size, 1)

    # 4. decode step by step
    for t in range(params['max_dec_len']):
        # 4.0. run the decoder for one step
        predictions, dec_hidden, attention_weights = model.decoder(dec_input, dec_hidden, enc_output)

        # 4.1. take the index with the highest probability for each sample
        predictions_idx = tf.argmax(predictions, axis=1).numpy()

        # 4.2. map each index to its word and append it to the running prediction
        for index, predict_idx in enumerate(predictions_idx):
            predicts[index] += vocab.index_to_word(predict_idx) + ' '

        # 4.3. feed the predicted tokens back in as the next decoder input
        dec_input = tf.expand_dims(predictions_idx, axis=1)

    # 5. post-process the decoded strings
    results = []
    for prediction in predicts:

        prediction = prediction.strip()
        if vocab.STOP_DECODING in prediction:
            prediction = prediction[:prediction.index(vocab.STOP_DECODING)]
        results.append(prediction)

    return results

The Hypothesis helper class used by beam_decode:

class Hypothesis:
    """A single beam-search hypothesis: its token sequence, per-token log
    probabilities, latest decoder hidden state, and attention distributions."""

    def __init__(self, tokens, log_probs, hidden, attn_dists):
        self.tokens = tokens
        self.log_probs = log_probs
        self.hidden = hidden
        self.attn_dists = attn_dists
        self.abstract = ''

    def extend(self, token, log_prob, hidden, attn_dist):

        return Hypothesis(
            tokens=self.tokens + [token],
            log_probs=self.log_probs + [log_prob],
            hidden=hidden,
            attn_dists=self.attn_dists + [attn_dist]
        )

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def total_log_prob(self):
        return sum(self.log_probs)

    @property
    def avg_log_prob(self):
        return self.total_log_prob / len(self.tokens)

Evaluating the Test Results

The test results are evaluated mainly with ROUGE scores. Sample code:

def get_rouge(results):
    # read the reference summaries from the segmented test set
    seg_test_report = pd.read_csv(test_seg_path, header=None).iloc[:, 5].tolist()
    seg_test_report = [' '.join(str(token) for token in str(line).split()) for line in seg_test_report]
    rouge_scores = Rouge().get_scores(results, seg_test_report, avg=True)
    print_rouge = json.dumps(rouge_scores, indent=2)
    with open(os.path.join(os.path.dirname(test_seg_path), 'results.csv'), 'w', encoding='utf8') as f:
        json.dump(list(zip(results, seg_test_report)), f, indent=2, ensure_ascii=False)
    print('*' * 8 + ' rouge score ' + '*' * 8)
    print(print_rouge)
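
As a quick, stand-alone sanity check of the rouge package that get_rouge relies on, a toy call looks roughly like this (the strings below are made up and only illustrate that hypotheses and references must be whitespace-tokenized):

from rouge import Rouge

# toy, whitespace-tokenized strings; not project data
hyps = ['the cat sat on the mat', 'hello world']
refs = ['the cat is on the mat', 'hello there world']

scores = Rouge().get_scores(hyps, refs, avg=True)
# scores is a dict of the form
# {'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}
print(scores['rouge-l']['f'])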

Full Code

GitHub repository
