nlp基础-生成模型解码策略

首先参考transformers的源代码

# transformers.generation.utils..GenerationMixin._get_generation_mode
    def _get_generation_mode(
        self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
    ) -> GenerationMode:

        if generation_config.constraints is not None or generation_config.force_words_ids is not None:
            generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
        elif generation_config.num_beams == 1:
            if generation_config.do_sample is False:
                if (
                    generation_config.top_k is not None
                    and generation_config.top_k > 1
                    and generation_config.penalty_alpha is not None
                    and generation_config.penalty_alpha > 0
                ):
                    generation_mode = GenerationMode.CONTRASTIVE_SEARCH
                else:
                    generation_mode = GenerationMode.GREEDY_SEARCH
            else:
                generation_mode = GenerationMode.SAMPLE
        else:
            if generation_config.num_beam_groups > 1:
                generation_mode = GenerationMode.GROUP_BEAM_SEARCH
            elif generation_config.do_sample is True:
                generation_mode = GenerationMode.BEAM_SAMPLE
            else:
                generation_mode = GenerationMode.BEAM_SEARCH


        if assistant_model is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.ASSISTED_GENERATION

        return generation_mode

每次挑概率最大的token作为预测token

目的:避免出现大量重复
核心思想:把当前要生成的token和已经生成的所有token做相似度计算,得到最大的相似度值;然后使得该token的概率与最大的相似度值的差值最大化的那个token就是我们要生成的token;具体的公式如下:
\(x_t=argmax_{v \in V} \{(1-\alpha) * P_{\theta}(v|x_{<t}) - \alpha*(max\{s(h_v,h_{xj}):1\le j \le t-1 \}) \}\)

\(tok_k\)常取3-10$

SAMPLE

top_k:取概率最高的前k个token
top_p:取的token概率不超过p(例0.7),用来避免长尾分布
temperature:输出多样性

SAMPLE 源代码中会有 \(logits\_warper\) 用来sample
image

top_k
length_penalty

源码解读

https://www.bing.com/search?pglt=2081&q=GROUP_BEAM_SEARCH&cvid=ea25d35258734257845368b6be2a2f11&gs_lcrp=EgZjaHJvbWUyBggAEEUYOdIBBzM4M2owajGoAgCwAgA&FORM=ANNTA1&PC=LCTS&mkt=zh-CN

BEAM_SAMPLE

BEAM_SEARCH + SAMPLE

ASSISTED_GENERATION

https://zhuanlan.zhihu.com/p/632253732

posted @   shiiiilong  阅读(163)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
点击右上角即可分享
微信分享提示