大模型入门(七)—— RLHF中的PPO算法理解

  本文主要是结合PPO在大模型中RLHF微调中的应用来理解PPO算法。

一、强化学习介绍

1.1、基本要素

  环境的状态S:t时刻环境的状态$S_{t}$是环境状态集中某一个状态,以RLHF中为例,序列$w1,w2,w3$是当前的状态。

  个体的动作A:t时刻个体采取的动作$A_{t}$,给定序列$w1,w2,w3$,此时得到$w4$,得到$w4$就是执行的一次动作。然后就得到下一状态$S_{t+1} = w1,w2,w3,w4$。

  环境的奖励R:t时刻个体在$S_{t}$采取动作$A_t$得到的奖励$R_t$,奖励是对当前动作,不会考虑到未来的影响。

  个体的策略$\pi$:根据输入的状态获取动作,可以表示为$\pi (a | s)$。

  状态价值函数:价值一般是一个期望函数,即当前状态下所有动作产生的奖励和未来的奖励的期望值,这也是它不同于奖励R的地方,可以表示为$v_{\pi}(s) = \mathbb{E}_{\pi}(G_t|S_t=s ) = \mathbb{E}_{\pi}(R_{t} + \gamma R_{t+1} + \gamma^2R_{t+2}+...|S_t=s)$。所以价值才是真正衡量当前状态下能产生的价值。

  动作价值函数:动作价值函数类似于状态价值函数,只不过是在当前的状态和动作下获得的价值。可以表示为$q_{\pi}(s,a) = \mathbb{E}_{\pi}(G_t|S_t=s, A_t=a) = \mathbb{E}_{\pi}(R_{t} + \gamma R_{t+1} + \gamma^2R_{t+2}+...|S_t=s,A_t=a)$。

  状态转移概率模型:根据当前的动作和状态转移到下一状态的概率。

1.2、有模型和无模型的区别

  有模型和无模型实际上指是否要对环境建模,换句话说核心在于是否有状态转移概率模型,有模型是指有状态转移概率模型,知道状态是怎么转移的,是一个白盒模型,但实际中大多数强化学习的算法都是无模型的,不去构建状态转移概率模型,而是直接得到下一个状态。如上面的例子中,在状态$w1,w2,w3$直接进入到下一状态$w1,w2,w3,w4$,这一过程是由生成模型(策略函数)直接给出的。

1.3、蒙特卡洛(MC)和时序差分(TD)的区别

  蒙特卡洛和时序差分的主要差别在于计算价值函数,在蒙特卡洛中计算价值需要获取完整的状态序列,而时序差分则不需要,但实际训练过程中想要获取完整的状态序列成本是非常高的,所以时序差分也是现在的训练强化学习的主流方法。

  假定给定一条完整的状态序列$S_1,A_1,R_2,S_2,A_2,...S_t,A_t,R_{t+1},...R_T, S_T$。

  价值函数的计算可以表示为:

    $v_{\pi}(s) = \mathbb{E}_{\pi}(G_t|S_t=s ) = \mathbb{E}_{\pi}(R_{t} + \gamma R_{t+1} + \gamma^2R_{t+2}+...|S_t=s)$。

  在这里$G_t$表示某一条状态序列的收获,注意这又引入了一个新的概念。而价值函数是通过采样多个序列得到多个$G_t$的平均值。可以表示为:

    $v_{\pi}(s) \approx average(G_t), s.t. S_t=s$

  价值函数的求解到这里就结束了,但还有一些可以优化的地方,一是同一个状态可能在一个完整的序列中重复出现,这里会涉及到first visit和every visit,感兴趣的自己可以去了解;二是在求解平均值时需要先把所有的$G_t$存储下来,这样太浪费存储空间,最好的是增量计算平均值,这里的推导公式也很简单,最后可以得到这样的价值函数求解方式:

    $V(S_t) = V(S_t)  + \frac{1}{N(S_t)}(G_t -  V(S_t) )$

  在这里$N(S_t)$表示状态$S_t$出现的次数。但是当海量数据分布式迭代时,无法精确地计算出$N(S_t)$,所以这里通常用一个系数$\alpha$来替代,这样就可以表示为:

    $V(S_t) = V(S_t)  + \alpha(G_t -  V(S_t) )$

  以上就是蒙特卡洛的求解价值的方式,现在回到时序差分中,我们说到要得到一个完整的序列是比较难得,这时就想到马尔科夫性,只考虑下一状态的价值,将收获$G_t$表示为:

    $G(t) = R_{t} + \gamma V(S_{t+1})$ 

  一般把$R_{t} + \gamma V(S_{t+1})$ 称为TD目标值。这样我就不需要得到完整的序列,只需要下一状态就可以求解价值函数。

  蒙特卡洛因为采样了完整的序列,能更精准的估计奖励值,可以认为是无偏的估计,但因为序列越长,序列之间的差异越大,会产生较大的方差,导致收敛很慢;而时序差分不需要完整的序列,单步时序差分还只考虑到下一状态,本身是有偏估计,但方差会比较小。  

1.4、on policy和off policy的区别

  在实际的强化学习模型中会有两个策略:行为策略和目标策略,行为策略是用来与环境互动产生数据的策略,即在训练过程中做决策;而目标策略在行为策略产生的数据中不断学习、优化,即学习训练完毕后拿去应用的策略。on policy中行为策略和目标策略是同一个策略,off policy中行为策略和目标策略是不同的策略。通常来说,off policy会先用不同的策略产生大量的样本,如DQN中,通过经验回放的方式构造目标策略的训练样本,经验回放的方式会使得样本产生的策略不同于目标策略;on policy一般是目标策略先生成一条样本,然后接着计算价值去更新目标策略,on policy的这种方式会存在探索-利用的矛盾,因为行为策略和目标策略一致缺乏探索,会导致模型收敛到局部最优,但好处是训练速度够快,容易收敛,实际的on policy也不一定是生成一条样本就去训练,可能会是生成n条样本再去训练,但这种方式是不是还是on policy,这个问题也有不少争议,具体可以见强化学习里的 on-policy 和 off-policy 的区别,像RLHF的ppo中就有类似的设计。

1.5、value-based和policy-based的区别

  在强化学习中可以将训练方法最基本的可以分为两种:value-based和policy-based,这两种方式从形式上来可以看做是:

  value-based:输入状态$s$,输出动作价值函数$Q(s, a)$

  policy-based: 输入状态$s$,输出动作选择的概率$P(s, a)$

  在这里可以看到,value-based是根据当前的状态$s$,配合所有的动作空间中的动作$a$,直接计算动作价值函数,在训练时会通过${\epsilon} -greedy$的方式选择动作,而在预测是会直接按最大的动作价值去选择动作,这种方式会有什么问题呢?

  1)无法处理连续动作空间或者高维离散动作空间,毕竟你需要求解所有动作对应的动作价值函数。

  2)缺乏随机性,基于value-based的策略是确定的,预测时是会选择价值最高的动作。

  基于上面的问题再来看policy-based,policy-based是计算动作发生的概率,当高维离散或者连续动作空间时,也可以转换成连续值的预测,可以很好的解决问题1),再加上动作的生成本身就是按照概率生成的,因此也具有一定的随机性,也多了更多的探索的可能性。

  以上是从策略模型的输出方式来看的,而建模训练时的目标函数也会有比较大的差异,在value-based中的目标函数通常是当前模型输出的Q值和目标Q值的均方误差,而policy-based中的目标函数是最大化价值最大的序列发生的概率。

  除了value-based和policy-based之外,还有一种常用的训练方法Actor-Critic,Actor-Critic结合了value-based和policy-based两种方式,Actor是策略网络,Critic是价值网络。在经典的policy-based reinforce算法中,其策略参数的更新是$\theta = \theta + \alpha \nabla_{\theta}log \pi_{\theta}(s_t,a_t)  v_t$,在reinforce算法中$v_t$是由蒙特卡洛求解得到的,而在Actor-Critic中该值可以直接通过一个Q网络直接计算得到,而且在这里的$v_t$具体选择什么样的评估方式,可以自己选择状态价值、动作价值、TD误差、优势函数等等多种方式。除了更新Actor的参数之外,Critic的参数也会被更新,类似于value-based中一样用均方误差更新即可。关于Actor-Critic详情见强化学习(十四) Actor-Critic

二、PPO算法理解

  在这里不会对PPO算法的做完整的推理,而是就着当前主流RLHF框架(如deepspeed-chat,trl,trlx等)中常用的PPO算法去理解PPO算法是如何训练对齐模型。

2.1、优势函数和泛化优势估计(GAE)

  在前面提到了TD误差和蒙特卡洛两种方法去计算价值函数,但这两种方法分别代表了两个极端:TD误差的高偏差低方差可能会导致无法收敛,蒙特卡洛的高方差低偏差需要大量的训练数据才可能收敛。在这种情况下,有很多偏差和方差折中的方法被提出,如在GAE估计中被用到的$\lambda - return$算法。

  TD误差算法中每次只考虑到当前时刻的回报值$R_t$和下一时刻的状态价值$V(S_{t + 1})$。回顾下TD误差算法的表达式$R_{t+1} + \gamma V(S_{t+1})$,在这里状态值的近似估计会带来偏差,因为收折扣因子$\gamma$的影响,所以带来的偏差可以定义为$\gamma {\epsilon}_{t+1}$。如果此时用到两步的回报值,则TD误差可以写成$R_{t} + \gamma R_{t+1} + {\gamma}^2 V(S_{t+2})$,则偏差可以定义为${\gamma}^2 {\epsilon}_{t+2}$。由于$0 ≤ \gamma ≤ 1$,随着你使用的回报步数的增大,偏差值会越来越小。

  假定一次采样的最终终止状态时刻是$t+N$,则TD误差和蒙特卡洛可以分别表示为:

  $G(t) = R_{t} + \gamma V(S_{t+1})$ 和 $G(t+N) = \sum_{n=0}^{N} {\gamma}^n R(t+N)$。

  $\lambda - return$算法是在TD误差和蒙特卡洛算法中找到一个折中点,通过这种方式在偏差和方差间找到平衡,则$\lambda - return$可以表示为:

  $G^{\lambda} t = (1 - \lambda) \sum_{n=1}^{N-1} {\lambda}^{n - 1}G(t+n) + {\lambda}^{N-1}G(t+N)$

  在上式中,其中$0 ≤ \lambda ≤ 1$,当$\lambda = 0$即为TD算法,当$\lambda = 1$即为蒙特卡洛算法。从$\lambda - return$算法中来看,也是需要采样得到完整的序列才能完成整个计算。

   花了一部分内容也介绍$\lambda - return$算法,再回到我们的优势函数上来,优势函数是用来度量某个状态$s$下选取某个具体动作$a$的合理性,其表达式可以表示为:

  $A(s,a) = Q(s,a) - V(s)$

  从式子中可以看出优势函数表达的就是选取某个动作后得到的动作价值相对于平均动作价值的增量,当优势函数大于0时,说明当前的动作选择是优于平均的,优势函数对策略梯度是非常适合的。

  泛化优势估计实际上就是$\lambda - return$应用在估计优势函数的版本。可以按照$\lambda - return$方法中使用n步回报值的思路去推导,最终的表达是可以表示为:

  $A_t = \sum_{n = 0}^{\infty}(\gamma \lambda)^n {\delta}_{t+n}$

  其中${\delta}_{t+n} = R_{t+n} + \gamma V(S_{t+n+1}) - V(s_{t+n})$。所以这里的${\delta}_{t+n}$就是$t+n$时刻的TD误差。

2.2 PPO算法详解

  PPO算法的推理不做过多的介绍,现在通用的PPO算法的loss一般都是采用下面的loss:  

  结合当前的主流的RLHF框架中的实现,来看看PPO算法中整个训练流程是如何运转的。首先借用复旦NLP组的MOSS-RLHF论文中的一张图:

  PPO算法是基于Actor-Critic架构的策略梯度算法,结合上面的流程图,来梳理下整个PPO的训练流程:

  1)通过监督学习微调好SFT模型和Reward模型,在实际的PPO训练过程中SFT模型主要是作为Actor策略模型,而Reward模型主要是输出环境对当前动作执行的奖励,可以是一个用人类偏好数据训练的打分模型,也可以是多个模型的组合,甚至是融合策略等等,只要最终能对Actor模型生成的回复有一个符合人类偏好的、客观的打分即可。

  2)在准备好SFT模型和Reward模型后,一般来说是以SFT模型初始话Actor(策略模型),Ref(用于约束策略模型的参数变化量),Critic(价值模型),Reward(对策略的执行反馈即时的奖励)4个模型,Ref和Reward代表着环境对Actor的奖励或约束,参数是不会更新的,而Actor和Critic是会迭代优化的。并不是所有的框架都会初始化这4个模型,在trl和trlx中Actor和Critic共用一个模型。

  3)准备一个batch_size=512大小的只有prompt的数据集data,首先输入到Actor模型中,Actor模型会生成相应的回复response,此时Actor只做eval。

  4)将3)中生成的response和对应的prompt组合成new_prompt构成新的数据集new_data,将new_data分别输入到Actor、Ref、Reward和Critic模型中,此时4个模型只做eval,Actor和Ref分别输出相应的logits和ref_logits,并且计算对应的logprobs和ref_logprobs,Reward输出奖励值rewards,Critic输出每次token生成的价值values,这样就通过采样生成了一系列的完整序列样本作为训练数据(new_prompts,logprobs,ref_logprobs,rewards,values)。

  5)从new_data中每次取出mini_batch_size=64的的训练数据mini_data用来执行后面的PPO算法中的模型参数更新。

  6)使用mini_data中的(logprobs,ref_logprobs,rewards)用于更新rewards,具体的做法是利用logprobs和ref_logprobs计算KL散度用于约束Actor模型的参数不要太偏离Ref模型,得到的KL散度的维度是[batch size, seq_len],而Reward模型输出的rewards是[batch_size, 1],将新的new_rewards表示为KL散度,且new_rewards最后一个token的结果加上rewards的值,也就是说新的new_rewards的维度是[batch_size, seq_len],其中所有的值都是KL散度值,维度最后一个token(是指去除了pad token的最后一个token)的值是加上了rewards。大家在这里可能会比较疑惑,为什么rewards只加在了最后一个token上,首先Reward模型输出的rewards本身就是对一条完整的序列输出的奖励值,也就是序列最后时刻输出的奖励值,而且最后时刻的奖励值也是对整个生成结果的的真实反馈,那这样是不是rewards最后就只作用在最后一个token上呢?那不是这样的,因为在计算价值的过程中,在获得完整序列后,之前时刻的价值都是会考虑到最终时刻的奖励。

  7)根据(new_rewards,values)去计算得到优势函数值advantages和回报returns,优势函数值advantages的计算公式在上面已给出,首先计算$\delta$值,考虑到优势函数值advantages在计算过程中是一个累加的操作,可以将序列反向计算,从最后时刻往前计算,回顾前面${\delta}_{t+n} = R_{t+n} + \gamma V(S_{t+n+1}) - V(s_{t+n})$,在这里$R_{t+n}$对应new_rewards,表示即时奖励,$V(s_{t+n})$对应到values,表示Critic模型预估出来的状态价值,在一开始预估出来的价值不一定会很准,所以Critic也是需要跟着迭代的。在时间维度上按照衰减因子$\gamma$累加所有的$\delta$就可以得到每个时刻对应的优势值advantages,那回报returns又是什么呢?通过前面的内容知道优势值advantages表示的是当前动作价值较平均动作价值的增量,平均动作价值又等于状态价值,returns在这里的计算是advantages+values,通过这种方式可以预估出当前的动作价值returns。

  8)得到advantages和returns之后,又该如何构建模型loss,更新模型参数呢?在这之前所有的操作都是在采样训练数据data,并且构建训练目标值y。从现在开始才开始进入了我们所熟知的输入训练数据,构建模型loss,并计算梯度反向迭代模型参数。

  9)首先是Actor模型loss的构建,输入new_prompts进入到Actor模型得到新的new_logporbs,再结合原来的信息组合成(logprobs, advantages, new_logprobs)计算Actor模型的loss,参考上面的PPO loss将这三个值代入,${\pi}_{\theta}$对应new_logprobs, ${\pi}_{{\theta}_{old}}$对应logprobs,$A_t$对应advantages。

  10)对于Critic模型,输入new_prompts进入到Critic模型得到新的new_values,再结合原来的信息组合成(values,returns, new_values),使用MSE作为loss,约束returns和new_values。

  至此整个训练流程基本上讲完了。强化学习的训练模式可以分为生成训练样本,结合环境的奖励计算每个动作下的长期价值函数,用长期价值函数来优化模型。

主要参考:

https://www.cnblogs.com/pinard/category/1254674.html

https://github.com/microsoft/DeepSpeedExamples

https://arxiv.org/abs/2307.04964

除此还有很多大佬的博客,知乎回答等等。

posted @ 2023-07-14 15:28  微笑sun  阅读(6219)  评论(0编辑  收藏  举报