Among Codingame's hash-table-themed practice problems, the Markov chain text generation exercise caught my attention. It brings together three topics: Markov chains, state machines, and hash tables. The n-gram Markov chain has been used in the design of text chatbots, which I find quite inspiring; it was a classic technique before ChatGPT. Below is a brief walkthrough of this programming exercise.

Goal

You are making a game and want the NPCs to talk, even if what they say is nonsense. Too lazy to write out all of the nonsensical lines yourself, you decide to build a text generator. Fortunately, you have a collection of texts to use as training data. You will build a classic n-gram Markov chain. Look up what a Markov chain is, what an n-gram is, and how to apply them.

An example

Text t = one fish is good and no fish is bad and that is it
With an n-gram depth of 2, the Markov chain is built step by step as follows:
Step 1 : 'one fish' => ['is']
Step 2 : 'fish is' => ['good']
Step 3 : 'is good' => ['and']
Step 4 : 'good and' => ['no']
Step 5 : 'and no' => ['fish']
Step 6 : 'no fish' => ['is']
Note that 'fish is' already appeared in Step 2, so Step 7 appends the new value to the end of its existing list:
Step 7 : 'fish is' => ['good','bad']
Step 8 : 'is bad' => ['and']
Continue in this way until the end of text t.
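Carrying on in the same way, the remaining entries up to the end of t are:
Step 9 : 'bad and' => ['that']
Step 10 : 'and that' => ['is']
Step 11 : 'that is' => ['it']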

Now we can generate text. For an output of length 5 with the seed text fish is, we can randomly generate either of the following:

  • fish is good and no
  • fish is bad and that
    This is because when we reach 'fish is' we can randomly pick either 'good' or 'bad'; every other step is fully determined. A short trace follows below.
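For instance, the first output above is produced step by step as follows (each line shows the current 2-gram and the word appended):

'fish is'  -> 'good'   (two options: 'good' or 'bad'; picking 'bad' leads to the second output)
'is good'  -> 'and'
'good and' -> 'no'

Result: fish is good and no (5 words).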

Reproducibility

To make the output reproducible, whenever you need to choose the next state for a given state of the n-gram Markov chain, use the following pseudocode to pick it "randomly":

random_seed = 0
function pick_option_index( num_of_options ) {
    random_seed += 7
    return random_seed % num_of_options
}

In the example above, the first lookup returns ['good','bad'], which has two options. Calling pick_option_index(2) returns 7 % 2 = 1, so we append 'bad' to the end of the output text. The function must also be called when there is only one option, since it still advances random_seed.
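For illustration, here is a minimal standalone sketch (not part of the solution below) of how successive calls to the picker behave:

#include <iostream>

int random_seed = 0;

// Deterministic "random" pick, as specified by the exercise
int pick_option_index(int num_of_options) {
    random_seed += 7;
    return random_seed % num_of_options;
}

int main() {
    std::cout << pick_option_index(2) << "\n";  // 7 % 2 = 1 -> append 'bad' after 'fish is'
    std::cout << pick_option_index(1) << "\n";  // 14 % 1 = 0 -> a single-option lookup still advances random_seed
    std::cout << pick_option_index(2) << "\n";  // 21 % 2 = 1
}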

Solution

Following the requirements above, the algorithm splits roughly into two parts: first, build the n-gram Markov chain from the input text and the depth parameter; second, starting from the seed text, repeatedly query the chain to append the next words. Both parts store the chain in an unordered_map<string, vector<string>> keyed by the n-gram, which fits the hash-table theme of the exercise.

#include <iostream>
#include <vector>
#include <map>
#include <unordered_map>
#include <cstdlib>
#include <algorithm>

using namespace std;
int random_seed = 0;

// pick an option index randomly using a predetermined algorithm
int pick_option_index(int num_of_options) {
    random_seed += 7;
    return random_seed % num_of_options;
}

auto split(const string& text)
{
    // Split text into words
    vector<string> words;
    string word = "";
    for (char c : text) {
        if (c == ' ') {
            words.push_back(word);
            word = "";
        } else {
            word += c;
        }
    }
    words.push_back(word);
    return words;
}

// Generate Markov chain from input text
auto generateMarkovChain(vector<string>& words, int depth) {
    // map<string, vector<string>> chain;
    unordered_map<string, vector<string>> chain;

    // Generate n-grams and add to chain
    for (size_t i = 0; i + depth < words.size(); i++) {  // i + depth avoids unsigned underflow when the text is shorter than depth
        string ngram = "";
        for (size_t j = i; j < i + depth; j++) {
            ngram += words[j] + " ";
        }
        ngram.pop_back();  // Remove extra space at the end
        string nextWord = words[i + depth];
        // operator[] default-constructs an empty vector for a new n-gram key
        chain[ngram].push_back(nextWord);
    }

    return chain;
}

// Join words into a single space-separated string
inline string vectorToString(const vector<string>& vec)
{
    string current = "";
    for (const auto& w : vec)
    {
        current += w + " ";
    }
    if (!current.empty()) current.pop_back();  // drop trailing space
    return current;
}

std::string generateOutputText(const unordered_map<std::string, std::vector<std::string>>& chain, int length, std::string seed, int depth) {
    std::string output = seed;
    std::string current = seed;
    int seed_num = std::count(current.begin(), current.end(), ' ') + 1;  // Number of words in the seed
    auto seed_words = split(seed);
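    // Use only the last depth words of the seed as the initial n-gram key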
    std::vector<std::string> ngram_words;
    for (int i= seed_num-depth; i < seed_num; ++i)
    {
        ngram_words.push_back(seed_words[i]);
    }
    current = vectorToString(ngram_words);

    for (int i = 0; i < length - seed_num; i++) {
        if (chain.count(current) == 0) {
            // No match for current ngram in chain, stop generating output
            break;
        }
        std::vector<std::string> options = chain.at(current);
        if (options.empty()) {
            // No options available, stop generating output
            break;
        }
        int index = pick_option_index(options.size());
        if (index >= (int)options.size()) {
            // Index out of bounds, stop generating output
            break;
        }
        std::string next = options[index];
        output += " " + next;
        // Slide the n-gram window: drop the oldest word, append the new one
        ngram_words.push_back(next);
        ngram_words.erase(ngram_words.begin());
        current = vectorToString(ngram_words);
        cerr << current << ":" << output << endl;  // debug trace on stderr
    }
    return output;
}

int main()
{
    string text;
    getline(cin, text);
    int depth, length;
    cin >> depth >> length;
    string seed;
    cin.ignore();
    getline(cin, seed);

    // Split text into words
    auto words = split(text);

    // Generate Markov chain from input text
    // map<string, vector<string>> chain = generateMarkovChain(words, depth);
    unordered_map<string, vector<string>> chain = generateMarkovChain(words, depth);

    // Generate output text using Markov chain and seed
    string output = generateOutputText(chain, length, seed, depth);

    // Print output
    cout << output << endl;
}
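
A quick sanity check against the earlier example (a sketch of a sample run; it assumes the text on the first line, depth and length on the next line, and the seed on the last line, matching how main reads its input):

Input:
one fish is good and no fish is bad and that is it
2 5
fish is

Output (stdout):
fish is bad and that

This matches the second candidate from the example, because the deterministic picker selects 'bad' (7 % 2 = 1) at the 'fish is' branch; the stderr lines are just the debug trace.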

References

https://www.codingame.com/blog/markov-chain-automaton2000/
https://analyticsindiamag.com/hands-on-guide-to-markov-chain-for-text-generation/
https://www.codingame.com/training/hard/code-your-own-automaton2000-step-1

posted on 2023-04-10 15:49 by RayChenCode