greedy search和beam search的原理以及实现
在自然语言处理seq2seq模型中,模型训练完成后,预测推理时需要预测每一步输出的最可能的单词,之后组合成完整的预测输出句子。这里每一步最可能的输出单词的选择就用到greedy search或者beam search。下面详细介绍一下这两种搜索的区别,以及实现方法。
贪婪搜索(greedy search)
greedy search比较简单,就是贪婪式的搜索,每一步都选择概率最大的单词输出,最后组成整个句子输出。这种方法给出的结果一般情况结果比较差,因为只考虑了每一步的最优解,往往里全局最优解差距很大。
贪婪搜索实现比较简单,这里就不写了,每一部找一个最大值就好了。
集束搜索(beam search)
beam search是介于全局搜索和贪婪搜索之间。这里先讲一下全局搜索,全局搜索考虑的是全局最优解,需要把每一种可能输出结果都算出来,然后找出概率最大的输出。这种搜索空间是非常巨大的,假设我们的词表大小为N,句子长度为T个单词,整个搜索时间复杂度为O(N*N*T),一般N取值在几万到几十万级别,T在为几百个单词,实际计算比较慢。
而降低时间复杂度的方法就是寻找次优解,具体就是把搜索空间中的N减下来,每一步计算完只保留K个最大的取值路径,这样时间复杂度降为O(K*N*T),K取值一般比N小很多。这样得到的虽然不是最优解,但是在seq2seq模型的推理预测中可以兼顾时间和效果。
beam search实现
下面用python简单实现一下beam-search算法,这里实现的是假设每一步输出的可能概率是提前算好的,然后传入一个k值,计算beam-search的最优输出。
实际seq2seq中beam-search是要考虑每一部状态输出的,下一步预测输出需要上一步的状态,这里的实现暂时未考虑。
写法1:每次先排序,然后选择排序后的k个最大值。
import numpy as np
def beam_search(probs, k):
seq_scores = [[list(), 1.0]] # 存放路径和概率
for prob in probs:
cands = list()
for i in range(len(seq_scores)):
seq, score = seq_scores[i]
for j in range(len(prob)):
cand = [seq+[j], score * prob[j]]
cands.append(cand)
# 寻找topk个最大
seq_scores = sorted(cands, key=lambda x: x[1], reverse=True)[:k]
# k个最大中的最优
seq, score = seq_scores[0]
return seq
# test data
data = np.array([[0.1, 0.2, 0.4, 0.3],
[0.3, 0.5, 0.15, 0.05],
[0.25, 0.2, 0.3, 0.25],
[0.5, 0.3, 0.08, 0.12],
[0.1, 0.4, 0.3, 0.2]])
seq = beam_search(data, 3)
print(seq) # output: [2, 1, 2, 0, 1]
写法2:利用快排的思想,找出第k大的值,根据partition划分直接找到k个最大的值。 numpy中argpartition可以完成这个操作。
import numpy as np
def beam_search2(probs, k):
seqs, scores = [[]], [1.0]
for prob in probs:
beam_seqs, beams_scores = [], []
for i in range(len(seqs)):
seq, score = seqs[i], scores[i]
for j in range(len(prob)):
beam_seqs.append(seq + [j])
beams_scores.append(score * prob[j])
ind = np.argpartition(beams_scores, -k)[-k:]
seqs = np.array(beam_seqs)[ind].tolist()
scores = np.array(beams_scores)[ind].tolist()
ind = np.argmax(scores)
seq = np.array(seqs)[ind]
return seq
# test data
data = np.array([[0.1, 0.2, 0.4, 0.3],
[0.3, 0.5, 0.15, 0.05],
[0.25, 0.2, 0.3, 0.25],
[0.5, 0.3, 0.08, 0.12],
[0.1, 0.4, 0.3, 0.2]])
seq = beam_search2(data, 3)
print(seq) # output: [2 1 2 0 1]
写法3:利用堆排序,建立一个k个元素的小顶堆,每次通过与堆顶比较判断大小,并更新堆,最后完成topk最大的目标。 heapq中有现成的方法nlargest。
import numpy as np
import heapq
def beam_search3(probs, k):
seq_scores = [[list(), 1.0]]
for prob in probs:
cands = list()
for i in range(len(seq_scores)):
seq, score = seq_scores[i]
for j in range(len(prob)):
cand = [seq + [j], score * prob[j]]
cands.append(cand)
seq_scores = heapq.nlargest(k, cands, lambda d: d[1])
seq, score = seq_scores[0]
return seq
# test data
data = np.array([[0.1, 0.2, 0.4, 0.3],
[0.3, 0.5, 0.15, 0.05],
[0.25, 0.2, 0.3, 0.25],
[0.5, 0.3, 0.08, 0.12],
[0.1, 0.4, 0.3, 0.2]])
seq = beam_search3(data, 3)
print(seq) # output: [2, 1, 2, 0, 1]
参考文档
[1]. https://zhuanlan.zhihu.com/p/42006406/
[2]. https://blog.csdn.net/qq_16234613/article/details/83012046
[3]. https://zhuanlan.zhihu.com/p/43703136