[NLP] beam search的简单实现

介绍

本文基于该博客的内容改编。

想象你是一位校长,手下有十个班级,每个班级5个学生,每个学生都坐在自己的座位上,每个人成绩都不一样。
(如果你很难想象,那么就看看下面的代码实例中的data变量。)

现在你的任务是,从一班到十班,根据特定规则,在每个班级寻找一位学生,然后搭配成为最好的学生组合(注意不是找最好的学生,而是学生组合),出道成为偶像,拯救学校的衰落(?)。

由于有一套独特的评分规则,单纯的找出最好的学生并不是最佳的方案,而取决于不同学生组合而产生的最终分数,于是你采用了以下的策略:每次选择的时候,都会寻找前k个最佳组合。

例如你进入了三班,那么手里也许已经有k个来自前一班和二班的不同组合。例如:[1,2],[2,3],[1,3],list 中的每个位置代表班级,数字代表学生,三个list,说明你有三套方案(k=3)。现在你要往这个表单中加入来自三班的新学生,5位同学都跃跃欲试,你把他们5个人,每个人都放在了三套方案的后面,那么也就是15套新方案(5个学生*三个备选方案)。此时计算每一个方案的分数,然后选择前k个方案(k=3)作为结果。

最后你进入了四班,继续开始这套操作。

最后就可以找出最优方案,以及它的备选。

代码部分

import numpy as np

data = np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]])

def beam_search_decoder(data, k=3):
    # 第一步:初始化 seq
    # seq 是一个大list,最终会包含k个list,每一个list里面带有两个东西:序列(list) & 分数(int)
    sequences = [[list(),1.0]] 

    # 第二步: load 每一行
    for row in data: 
        # 2.1 初始化,到该行为止所有的可能
        all_candidates = list() 
        # 2.2 获得上一轮的结果,为了加入这一次的值
        for i in range(len(sequences)):
            seq, score = sequences[i] 

            # 2.3 将 row 中的所有结果 与sequences每一个(k)个进行运算,获得scores
            # 如果 seqences中有k个结果,一个row中有5个概率(label 或 该位置的可能符号),那么共产生 5*k个备选
            for j in range(len(row)):
                # candidate = [ list + 每一个备选的index j, 新分数 ]
                candidate = [seq + [j], score * -np.log(row[j])]
                # 将生成好的新备选加入候补席位
                all_candidates.append(candidate) # 加入备选


        # 根据分数进行排序
        ordered = sorted(all_candidates, key=lambda tup:tup[1])

        # 选择前k个,动态调整 seqences 的数量
        sequences = ordered[:k]
        # 查看每次sequences 的输出: print(sequences)

    return sequences

def greedy_decoder(data):
    # greedy decoder
    # 每一行最大概率词的索引
    return [np.argmax(s) for s in data]

# 数据准备阶段
data = np.array(data)
# 贪婪
greedy_result = greedy_decoder(data)
print("- 贪婪算法的结果:\n{}\n".format(greedy_result))
# beam search
beam_result = beam_search_decoder(data)
print("- beam 搜索的结果:\n{}\n".format(beam_result))

代码结果

  • 贪婪算法的结果:
    [4, 0, 4, 0, 4, 0, 4, 0, 4, 0]

  • beam 搜索的结果:
    这里 k = 3,前面是结果,后面是分数
    [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
    [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
    [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]

posted @ 2020-11-25 05:07  schaffen  阅读(606)  评论(0编辑  收藏  举报