[个人理解] llama.cpp之sample策略

最近有点时间看了几天llama.cpp 的code,有几个点,想记录一下,不对的地方,欢迎大家指正。

话说本该去年就看,奈何这个领域变的太快,索性积累到今年,当openAI也开始挤牙膏的时候一并看了。

Summary

- llama是跟chatpgt一样,基于transformer架构的decode only的一挂,这一系列的模型擅长文字接龙

- llama的模型几个部分,tokenalize, embedding, decode, sample

index module  function in llama.cpp
1 tokenize llama_tokenize
2 embedding  
3 decode llama_decode
4 sample llama_sampling_sample
     

这篇文章重点介绍sample这一块

 

sample 策略

总结起来就是,本来decode输出的是字典里面每个token是next token的概率,但是不能直接取概率最大的那一个,需要有一定的sample策略。基本上就是两步走,第一步,要先筛筛,把符合一定条件的token留下,第二步掷骰子,随机选取一个token。

先说这第一步

llama里面用的是多种策略的组合,核心的code是下面这几行。samplers_sequence里面基本上这几种策略都用了,先是top40,把概率排名前40的token留下,然后又是top p就是把概率大于p的都留下,然后是temperature的策略,把所有token的概率重新算一算,温度越高,那么所有的概率就越接近,温度越低,就让概率高的越高,概率低的越低,拉大差距,知道谁在裸泳。

for (auto sampler_type : samplers_sequence) {
    switch (sampler_type) {
        case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
        case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
        case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
        case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
        case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
        case llama_sampler_type::TEMPERATURE:
            if (dynatemp_range > 0) {
                float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
                float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
                llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
            } else {
                llama_sample_temp(ctx_main, &cur_p, temp);
            }
            break;
        default : break;
    }
}

  

再说第二步

这第二步也很简单,code里面核心部分就一行,就是根据每个token的概率值去归一化然后重新产生一组概率,这重新产生的概率和为1,然后根据这个概率分布去掷骰子,看看究竟选了哪个token的id。就结束了。

std::discrete_distribution<> dist(probs.begin(), probs.end());

  

差不多就是这样。

看到网上讲这一块的比较少,就记录一下。

posted @ 2024-07-30 11:37  sunny,lee  阅读(79)  评论(0编辑  收藏  举报