greedy search和beam search的原理以及实现

在自然语言处理seq2seq模型中,模型训练完成后,预测推理时需要预测每一步输出的最可能的单词,之后组合成完整的预测输出句子。这里每一步最可能的输出单词的选择就用到greedy search或者beam search。下面详细介绍一下这两种搜索的区别,以及实现方法。

贪婪搜索(greedy search)

greedy search比较简单,就是贪婪式的搜索,每一步都选择概率最大的单词输出,最后组成整个句子输出。这种方法给出的结果一般情况结果比较差,因为只考虑了每一步的最优解,往往里全局最优解差距很大。
贪婪搜索实现比较简单,这里就不写了,每一部找一个最大值就好了。

集束搜索(beam search)

beam search是介于全局搜索和贪婪搜索之间。这里先讲一下全局搜索,全局搜索考虑的是全局最优解,需要把每一种可能输出结果都算出来,然后找出概率最大的输出。这种搜索空间是非常巨大的,假设我们的词表大小为N,句子长度为T个单词,整个搜索时间复杂度为O(N*N*T),一般N取值在几万到几十万级别,T在为几百个单词,实际计算比较慢。
而降低时间复杂度的方法就是寻找次优解,具体就是把搜索空间中的N减下来,每一步计算完只保留K个最大的取值路径,这样时间复杂度降为O(K*N*T),K取值一般比N小很多。这样得到的虽然不是最优解,但是在seq2seq模型的推理预测中可以兼顾时间和效果。

beam search实现

下面用python简单实现一下beam-search算法,这里实现的是假设每一步输出的可能概率是提前算好的,然后传入一个k值,计算beam-search的最优输出。
实际seq2seq中beam-search是要考虑每一部状态输出的,下一步预测输出需要上一步的状态,这里的实现暂时未考虑。

写法1:每次先排序,然后选择排序后的k个最大值。

import numpy as np

def beam_search(probs, k):
    seq_scores = [[list(), 1.0]] # 存放路径和概率
    for prob in probs:
        cands = list()
        for i in range(len(seq_scores)):
            seq, score = seq_scores[i]
            for j in range(len(prob)):
                cand = [seq+[j], score * prob[j]]
                cands.append(cand)
        # 寻找topk个最大
        seq_scores = sorted(cands, key=lambda x: x[1], reverse=True)[:k]
    # k个最大中的最优
    seq, score = seq_scores[0]
    return seq
# test data
data = np.array([[0.1, 0.2, 0.4, 0.3],
                [0.3, 0.5, 0.15, 0.05],
                [0.25, 0.2, 0.3, 0.25],
                [0.5, 0.3, 0.08, 0.12],
                [0.1, 0.4, 0.3, 0.2]])

seq = beam_search(data, 3)
print(seq) # output: [2, 1, 2, 0, 1]

写法2:利用快排的思想,找出第k大的值,根据partition划分直接找到k个最大的值。 numpy中argpartition可以完成这个操作。

import numpy as np

def beam_search2(probs, k):
    seqs, scores = [[]], [1.0]
    for prob in probs:
        beam_seqs, beams_scores = [], []
        for i in range(len(seqs)):
            seq, score = seqs[i], scores[i]
            for j in range(len(prob)):
                beam_seqs.append(seq + [j])
                beams_scores.append(score * prob[j])
        ind = np.argpartition(beams_scores, -k)[-k:]
        seqs = np.array(beam_seqs)[ind].tolist()
        scores = np.array(beams_scores)[ind].tolist()
    ind = np.argmax(scores)
    seq = np.array(seqs)[ind]
    return seq
# test data
data = np.array([[0.1, 0.2, 0.4, 0.3],
                [0.3, 0.5, 0.15, 0.05],
                [0.25, 0.2, 0.3, 0.25],
                [0.5, 0.3, 0.08, 0.12],
                [0.1, 0.4, 0.3, 0.2]])

seq = beam_search2(data, 3)
print(seq) # output: [2 1 2 0 1]

写法3:利用堆排序,建立一个k个元素的小顶堆,每次通过与堆顶比较判断大小,并更新堆,最后完成topk最大的目标。 heapq中有现成的方法nlargest。

import numpy as np
import heapq

def beam_search3(probs, k):
    seq_scores = [[list(), 1.0]]
    for prob in probs:
        cands = list()
        for i in range(len(seq_scores)):
            seq, score = seq_scores[i]
            for j in range(len(prob)):
                cand = [seq + [j], score * prob[j]]
                cands.append(cand)
        seq_scores = heapq.nlargest(k, cands, lambda d: d[1])
    seq, score = seq_scores[0]
    return seq
# test data
data = np.array([[0.1, 0.2, 0.4, 0.3],
                [0.3, 0.5, 0.15, 0.05],
                [0.25, 0.2, 0.3, 0.25],
                [0.5, 0.3, 0.08, 0.12],
                [0.1, 0.4, 0.3, 0.2]])

seq = beam_search3(data, 3)
print(seq) # output: [2, 1, 2, 0, 1]

参考文档
[1]. https://zhuanlan.zhihu.com/p/42006406/
[2]. https://blog.csdn.net/qq_16234613/article/details/83012046
[3]. https://zhuanlan.zhihu.com/p/43703136

posted @ 2020-07-18 17:04  黄然小悟  阅读(892)  评论(0编辑  收藏  举报