[NLP] beam search的简单实现
介绍
本文基于该博客的内容改编。
想象你是一位校长,手下有十个班级,每个班级5个学生,每个学生都坐在自己的座位上,每个人成绩都不一样。
(如果你很难想象,那么就看看下面的代码实例中的data变量。)
现在你的任务是,从一班到十班,根据特定规则,在每个班级寻找一位学生,然后搭配成为最好的学生组合(注意不是找最好的学生,而是学生组合),出道成为偶像,拯救学校的衰落(?)。
由于有一套独特的评分规则,单纯的找出最好的学生并不是最佳的方案,而取决于不同学生组合而产生的最终分数,于是你采用了以下的策略:每次选择的时候,都会寻找前k个最佳组合。
例如你进入了三班,那么手里也许已经有k个来自前一班和二班的不同组合。例如:[1,2],[2,3],[1,3],list 中的每个位置代表班级,数字代表学生,三个list,说明你有三套方案(k=3)。现在你要往这个表单中加入来自三班的新学生,5位同学都跃跃欲试,你把他们5个人,每个人都放在了三套方案的后面,那么也就是15套新方案(5个学生*三个备选方案)。此时计算每一个方案的分数,然后选择前k个方案(k=3)作为结果。
最后你进入了四班,继续开始这套操作。
最后就可以找出最优方案,以及它的备选。
代码部分
import numpy as np
data = np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]])
def beam_search_decoder(data, k=3):
# 第一步:初始化 seq
# seq 是一个大list,最终会包含k个list,每一个list里面带有两个东西:序列(list) & 分数(int)
sequences = [[list(),1.0]]
# 第二步: load 每一行
for row in data:
# 2.1 初始化,到该行为止所有的可能
all_candidates = list()
# 2.2 获得上一轮的结果,为了加入这一次的值
for i in range(len(sequences)):
seq, score = sequences[i]
# 2.3 将 row 中的所有结果 与sequences每一个(k)个进行运算,获得scores
# 如果 seqences中有k个结果,一个row中有5个概率(label 或 该位置的可能符号),那么共产生 5*k个备选
for j in range(len(row)):
# candidate = [ list + 每一个备选的index j, 新分数 ]
candidate = [seq + [j], score * -np.log(row[j])]
# 将生成好的新备选加入候补席位
all_candidates.append(candidate) # 加入备选
# 根据分数进行排序
ordered = sorted(all_candidates, key=lambda tup:tup[1])
# 选择前k个,动态调整 seqences 的数量
sequences = ordered[:k]
# 查看每次sequences 的输出: print(sequences)
return sequences
def greedy_decoder(data):
# greedy decoder
# 每一行最大概率词的索引
return [np.argmax(s) for s in data]
# 数据准备阶段
data = np.array(data)
# 贪婪
greedy_result = greedy_decoder(data)
print("- 贪婪算法的结果:\n{}\n".format(greedy_result))
# beam search
beam_result = beam_search_decoder(data)
print("- beam 搜索的结果:\n{}\n".format(beam_result))
代码结果
-
贪婪算法的结果:
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0] -
beam 搜索的结果:
这里 k = 3,前面是结果,后面是分数
[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]