reward model learning papers
1. Fine-Tuning Language Models from Human Preferences
reward model: a 774M-parameter GPT-2 that first went through supervised training
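Training loss:
$$\mathrm{loss}(r) = \mathbb{E}_{(x,\,y_1,\dots,y_4,\,b)}\left[\log\frac{e^{r(x,y_b)}}{\sum_{i=1}^{4} e^{r(x,y_i)}}\right]$$
where r(x,y) is the reward model, x is the input (the prompt), and y is the output (the response). The expectation is taken over the labeled samples.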
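The labeler is shown four candidates, y1, y2, y3, y4, and picks one of them; the index of the chosen candidate is denoted b (i.e. the labeler chose yb).
After training, the reward model's outputs are normalized.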
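As an illustration of the normalization step, here is a minimal sketch that shifts and scales raw reward scores to zero mean and unit standard deviation over a reference set of samples. The helper names `fit_reward_normalizer` and `normalize_rewards` and the sample scores are hypothetical, and reading "normalized" as "zero mean, unit variance" is an assumption of this sketch.

```python
import torch

def fit_reward_normalizer(raw_rewards: torch.Tensor):
    """Return (mean, std) of raw reward scores on a reference set (hypothetical helper)."""
    mean = raw_rewards.mean()
    std = raw_rewards.std(unbiased=False)  # population standard deviation
    return mean, std

def normalize_rewards(raw_rewards: torch.Tensor, mean, std, eps: float = 1e-8):
    """Shift and scale scores; eps guards against division by zero."""
    return (raw_rewards - mean) / (std + eps)

# Made-up raw scores standing in for reward model outputs.
raw = torch.tensor([1.0, 2.0, 3.0, 4.0, 10.0])
mean, std = fit_reward_normalizer(raw)
normalized = normalize_rewards(raw, mean, std)
```

The statistics are fitted once after training and then reused, so the reward scale seen by later stages stays fixed.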
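The drawback of this loss is that it only uses the single candidate the labeler liked best out of the four; the remaining three candidates are never compared with one another. Note also that this loss is to be maximized.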
PyTorch implementation (the code minimizes the cross-entropy, which is the negative of the objective above):
import numpy as np
import torch
from torch.nn import CrossEntropyLoss

def loss_v1():
    loss = CrossEntropyLoss()
    # Four samples (batch_size=4), each with four candidate sentences.
    # Each row, e.g. [1,2,3,4], holds the reward model's scores for the four candidates.
    # In the first row the labeler chose index 0, the lowest-scored candidate,
    # so the label disagrees strongly with the model and the loss is large.
    rewards = torch.tensor(np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]], dtype=np.float32), requires_grad=True)
    label = torch.tensor([0,1,2,3])
    loss_value = loss(rewards, label)
    return loss_value
If your data is already sorted so that the first candidate is the one the labeler chose, no label is needed and the loss can be computed by hand:
# This manual computation appears to be slightly faster than the previous implementation.
def loss_v1_plus():
    rewards = torch.tensor(np.array([[1,2,3,4],[6,5,7,8],[11,10,9,12],[16,13,14,15]], dtype=np.float32), requires_grad=True)
    exp_rewards = rewards.exp()
    loss_value = -torch.log(exp_rewards[:,0] / torch.sum(exp_rewards, 1))
    loss_value = loss_value.mean()
    return loss_value
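The label-free form assumes the chosen candidate already sits in column 0. A minimal sketch of getting there from labeled data is to swap each row's chosen score into column 0; the helper name `move_choice_first` is hypothetical. Because the softmax denominator does not depend on column order, the label-free loss on the reordered scores should match the cross-entropy computed with labels.

```python
import torch
from torch.nn import CrossEntropyLoss

def move_choice_first(rewards: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Swap each row's chosen score into column 0 (hypothetical helper)."""
    rows = torch.arange(rewards.shape[0])
    reordered = rewards.clone()
    reordered[rows, 0] = rewards[rows, labels]
    reordered[rows, labels] = rewards[rows, 0]
    return reordered

rewards = torch.tensor([[1., 2., 3., 4.],
                        [5., 6., 7., 8.],
                        [9., 10., 11., 12.],
                        [13., 14., 15., 16.]])
labels = torch.tensor([0, 1, 2, 3])

reordered = move_choice_first(rewards, labels)

# Loss with explicit labels versus label-free loss on the reordered scores.
with_labels = CrossEntropyLoss()(rewards, labels)
label_free = (torch.logsumexp(reordered, dim=1) - reordered[:, 0]).mean()
```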
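The implementation below is marginally slower (not enough to matter), but it subtracts the preferred candidate's reward, which is likely to be the largest, before taking exp. That tends to make it more numerically stable, so it is the recommended one: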
def loss_v1_plus_plus():
    rewards = torch.tensor(np.array([[1,2,3,4],[6,5,7,8],[11,10,9,12],[16,13,14,15]], dtype=np.float32), requires_grad=True)
    sub_first_rewards = rewards - rewards[:,0][:,None]
    loss_value = torch.sum(sub_first_rewards.exp(), -1).log().mean()
    return loss_value
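Subtracting the first column only helps when the preferred candidate really has the largest score. As a hedged alternative, the sketch below uses `torch.logsumexp`, which subtracts the row maximum internally. The function name `loss_v1_logsumexp` and the test scores are made up for illustration; this is a swapped-in variant, not one of the original implementations.

```python
import torch
from torch.nn import CrossEntropyLoss

def loss_v1_logsumexp(rewards: torch.Tensor) -> torch.Tensor:
    """Softmax loss with the preferred candidate assumed to be in column 0."""
    return (torch.logsumexp(rewards, dim=1) - rewards[:, 0]).mean()

# On ordinary scores the result matches CrossEntropyLoss with label 0.
small = torch.tensor([[1., 2., 3., 4.], [6., 5., 7., 8.]])
reference = CrossEntropyLoss()(small, torch.zeros(2, dtype=torch.long))
stable_small = loss_v1_logsumexp(small)

# Preferred candidate scored far below another one: exp(100) overflows float32.
large = torch.tensor([[0., 100., 0., 0.]])
naive_large = torch.sum((large - large[:, :1]).exp(), dim=1).log().mean()
stable_large = loss_v1_logsumexp(large)
```

Here `naive_large` overflows to infinity while `stable_large` stays finite, which matters early in training when the model still disagrees with the labels.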
2. Learning to summarize from human feedback
reward model: experiments were run with 1.3B and 6.7B models, after supervised pretraining
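Training loss:
$$\mathrm{loss}(r) = -\mathbb{E}_{(x,\,y_0,\,y_1,\,i)}\left[\log \sigma\big(r(x,y_i) - r(x,y_{1-i})\big)\right]$$
where r, x, and y are as above, and i is the index of the y that the labeler judged to be better.
After training, the reward model's outputs are normalized.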
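In this paper the value model is initialized from the reward model, and the policy model, value model, and reward model are all kept the same size. This loss is to be minimized.
Here σ gives the probability that the labeler judges the first candidate to be better than the second, as a function of the reward difference z between them: when the two candidates are equally good, z = 0 and σ(z) = 0.5.
Up to the sign convention, this loss is simply the previous paper's loss with the number of candidates reduced from 4 to 2.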
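The reduction from the softmax form to the sigmoid form can be checked in one line. Writing r_i and r_{1-i} as shorthand (introduced here) for r(x, y_i) and r(x, y_{1-i}):

```latex
\log\frac{e^{r_i}}{e^{r_i} + e^{r_{1-i}}}
  = \log\frac{1}{1 + e^{-(r_i - r_{1-i})}}
  = \log\sigma\left(r_i - r_{1-i}\right)
```

So the two-candidate softmax log-likelihood is exactly the log-sigmoid of the reward difference, and negating it gives the loss to be minimized.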
PyTorch implementation:
def loss_v2():
    loss = CrossEntropyLoss()
    rewards = torch.tensor(np.array([[1,2],[3,4],[5,6],[7,8]], dtype=np.float32), requires_grad=True)
    label = torch.tensor([0,1,0,1])
    loss_value = loss(rewards, label)
    return loss_value
As before, there are faster implementations:
def loss_v2_plus():
    rewards = torch.tensor(np.array([[1,2],[4,3],[5,6],[8,7]], dtype=np.float32), requires_grad=True)
    exp_rewards = rewards.exp()
    loss_value = -torch.log(exp_rewards[:,0] / torch.sum(exp_rewards, 1))
    loss_value = loss_value.mean()
    return loss_value

# Recommended:
def loss_v2_plus_plus():
    rewards = torch.tensor(np.array([[1,2],[4,3],[5,6],[8,7]], dtype=np.float32), requires_grad=True)
    sub_first_rewards = rewards - rewards[:,0][:,None]
    loss_value = torch.sum(sub_first_rewards.exp(), -1).log().mean()
    return loss_value
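For two candidates the same quantity can also be written directly with the sigmoid. The sketch below is an illustrative variant (the name `loss_v2_logsigmoid` is made up) that uses `torch.nn.functional.logsigmoid` and compares the result against the two-way cross-entropy, assuming the preferred candidate is in column 0.

```python
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

def loss_v2_logsigmoid(rewards: torch.Tensor) -> torch.Tensor:
    """Pairwise loss; rewards[:, 0] is assumed to be the preferred candidate."""
    return -F.logsigmoid(rewards[:, 0] - rewards[:, 1]).mean()

rewards = torch.tensor([[1., 2.], [4., 3.], [5., 6.], [8., 7.]])
pairwise = loss_v2_logsigmoid(rewards)

# Two-way softmax cross-entropy with the preferred index fixed at 0.
softmax_ce = CrossEntropyLoss()(rewards, torch.zeros(4, dtype=torch.long))
```

`logsigmoid` is computed in a numerically stable way, so this form avoids an explicit `exp` on large reward differences.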
3. Training language models to follow instructions with human feedback
reward model: 6B
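Training loss:
$$\mathrm{loss}(r) = -\frac{1}{\binom{K}{2}}\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\big(r(x,y_w) - r(x,y_l)\big)\right]$$
where K is the number of candidates (K is between 4 and 9), y_w is the candidate the labeler judged better within a pair, and y_l is the other candidate of that pair. This loss is to be minimized.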
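This loss combines the strengths of the first and second papers: like the first paper it uses multiple candidates, and like the second it uses pairwise comparisons, which makes up for the first paper's drawback.
The authors also found that all data sharing the same x can be placed in a single batch for training, which has two benefits:
1. Less overfitting. If the pairs are split apart and shuffled for training, each (x,y) is compared against the other K-1 candidates, so it is trained on K-1 times, which amounts to duplicating the data.
2. Better computational efficiency. When all data with the same x go into one batch, each (x,y) needs exactly one forward pass per epoch instead of the previous K-1, which makes training cheaper.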
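The batching idea above can be sketched as follows: score the K responses to one prompt once, then form all K(K-1)/2 pairs from that single score vector. The function name `pairwise_ranking_loss`, the assumption that the scores arrive ordered from most to least preferred with no ties, and the example values are all illustrative.

```python
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(rewards_best_first: torch.Tensor) -> torch.Tensor:
    """Loss over all pairs of one prompt.

    rewards_best_first has shape (K,) and is assumed to be ordered from the
    most preferred response to the least preferred one.
    """
    K = rewards_best_first.shape[0]
    # Index pairs (i, j) with i < j, so response i is preferred over response j.
    better, worse = torch.triu_indices(K, K, offset=1)
    diffs = rewards_best_first[better] - rewards_best_first[worse]
    # The mean over the K*(K-1)/2 pairs supplies the 1 / C(K, 2) factor.
    return -F.logsigmoid(diffs).mean()

K = 4
# Stand-in for the scores from ONE forward pass over the K responses.
rewards = torch.tensor([2.0, 1.0, 0.5, -1.0], requires_grad=True)
loss = pairwise_ranking_loss(rewards)
loss.backward()

num_pairs = K * (K - 1) // 2          # 6 comparisons
num_forward_passes = K                # 4 scored responses, each scored once
```

For K = 9 this is 36 comparisons from 9 scored responses, and each response contributes one gradient that aggregates all of its comparisons.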
The full code for this article is as follows:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The TARTRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import time

import numpy as np
import torch
from torch.nn import CrossEntropyLoss


# Prints the contents of combination_dict and re_combination_dict.
def get_conbination():
    for N in range(2, 10):
        A = [list(a) for a in itertools.combinations(range(N), 2)]
        print("{}:{},".format(N, A))
    # get reversed
    for N in range(2, 10):
        A = [[a[1], a[0]] for a in itertools.combinations(range(N), 2)]
        print("{}:{},".format(N, A))


# All index pairs [i, j] with i < j, precomputed for N = 2..9 candidates,
# e.g. combination_dict[3] == [[0, 1], [0, 2], [1, 2]].
combination_dict = {
    N: [list(a) for a in itertools.combinations(range(N), 2)] for N in range(2, 10)
}

# Reversed version of combination_dict,
# e.g. re_combination_dict[3] == [[1, 0], [2, 0], [2, 1]].
re_combination_dict = {
    N: [[a[1], a[0]] for a in itertools.combinations(range(N), 2)] for N in range(2, 10)
}

# torch version of combination_dict
combination_dict_torch = {key: torch.tensor(value) for key, value in combination_dict.items()}


# Recover the number of candidates K from the number of pairs K*(K-1)/2.
# (The pair count equals K only when K == 3, so K must not be read off the pair list length.)
def num_candidates_from_pairs(num_pairs):
    for K, pairs in combination_dict.items():
        if len(pairs) == num_pairs:
            return K
    raise ValueError("no K in 2..9 has {} pairs".format(num_pairs))


# Loss from the first paper (explicit labels)
def loss_v1():
    loss = CrossEntropyLoss()
    rewards = torch.tensor(np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]], dtype=np.float32), requires_grad=True)
    label = torch.tensor([0,1,2,3])
    loss_value = loss(rewards, label)
    return loss_value


# Loss from the first paper; rewards[:,0] must be the candidate the labeler preferred
def loss_v1_plus():
    rewards = torch.tensor(np.array([[1,2,3,4],[6,5,7,8],[11,10,9,12],[16,13,14,15]], dtype=np.float32), requires_grad=True)
    exp_rewards = rewards.exp()
    loss_value = -torch.log(exp_rewards[:,0] / torch.sum(exp_rewards, 1))
    loss_value = loss_value.mean()
    return loss_value


# Loss from the first paper (recommended); rewards[:,0] must be the candidate the labeler preferred
def loss_v1_plus_plus():
    rewards = torch.tensor(np.array([[1,2,3,4],[6,5,7,8],[11,10,9,12],[16,13,14,15]], dtype=np.float32), requires_grad=True)
    sub_first_rewards = rewards - rewards[:,0][:,None]
    loss_value = torch.sum(sub_first_rewards.exp(), -1).log().mean()
    return loss_value


# Loss from the second paper (explicit labels)
def loss_v2():
    loss = CrossEntropyLoss()
    rewards = torch.tensor(np.array([[1,2],[3,4],[5,6],[7,8]], dtype=np.float32), requires_grad=True)
    label = torch.tensor([0,1,0,1])
    loss_value = loss(rewards, label)
    return loss_value


# Loss from the second paper; rewards[:,0] must be the candidate the labeler preferred
def loss_v2_plus():
    rewards = torch.tensor(np.array([[1,2],[4,3],[5,6],[8,7]], dtype=np.float32), requires_grad=True)
    exp_rewards = rewards.exp()
    loss_value = -torch.log(exp_rewards[:,0] / torch.sum(exp_rewards, 1))
    loss_value = loss_value.mean()
    return loss_value


# Loss from the second paper (recommended); rewards[:,0] must be the candidate the labeler preferred
def loss_v2_plus_plus():
    rewards = torch.tensor(np.array([[1,2],[4,3],[5,6],[8,7]], dtype=np.float32), requires_grad=True)
    sub_first_rewards = rewards - rewards[:,0][:,None]
    loss_value = torch.sum(sub_first_rewards.exp(), -1).log().mean()
    return loss_value


# Convert rankings into pairwise labels. label=1 means the first element of the
# pair is better than the second; a smaller rank value means better.
def rank_to_pair_label(ranks):
    # for 3 rewards, e.g., ranks = [[1,2,0],[2,1,0]]
    # [1,2,0] means: the reward at index 0 has rank 1, the reward at index 1
    # has rank 2 (the worst), and the reward at index 2 has rank 0 (the best).
    N = len(ranks[0])
    assert 2 <= N <= 9, "N should be between 2 and 9, but get {}".format(N)
    pair_labels = []
    c_p = combination_dict[N]
    for rank in ranks:
        pair_label = []
        for index in range(len(c_p)):
            if rank[c_p[index][0]] <= rank[c_p[index][1]]:
                pair_label.append(1)
            else:
                pair_label.append(0)
        pair_labels.append(pair_label)
    return pair_labels


# Convert pair labels into pair ids. Within each id pair the first index is the
# better one; label=1 means the first element of the pair is better.
def pair_label_to_gather_id(pair_labels):
    # for 3 rewards, there are 3 pairs, e.g., pair_label = [[1,0,1],[0,0,1]]
    N = num_candidates_from_pairs(len(pair_labels[0]))
    c_p = combination_dict[N]
    re_c_p = re_combination_dict[N]
    gather_ids = []
    for pair_label in pair_labels:
        gather_id = []
        for index, p in enumerate(pair_label):
            if p:
                # p=1, means the first one is better
                gather_id.append(c_p[index])
            else:
                gather_id.append(re_c_p[index])
        gather_ids.append(gather_id)
    return gather_ids


# Convert pair ids back into pair labels. Within each id pair the first index is
# the better one; label=1 means the first element of the pair is better.
def gather_id_to_pair_label(gather_ids):
    # e.g., gather_ids = [[[1, 0], [2, 0], [2, 1]], [[1, 0], [2, 0], [2, 1]], [[0, 1], [0, 2], [1, 2]]]
    N = num_candidates_from_pairs(len(gather_ids[0]))
    pair_labels = []
    c_p = combination_dict[N]
    re_c_p = re_combination_dict[N]
    for gather_id in gather_ids:
        pair_label = []
        for index, g in enumerate(gather_id):
            if g == c_p[index]:
                pair_label.append(1)
            else:
                assert g == re_c_p[index]
                pair_label.append(0)
        pair_labels.append(pair_label)
    return pair_labels


# Timing test: runs three loss functions repeatedly and prints total time for each
def test(loss_v1_f, loss_v2_f, loss_v3_f):
    t1, t2, t3 = 0., 0., 0.
    for i in range(100000):
        s_t = time.time()
        loss = loss_v1_f()
        t1 += time.time() - s_t

        s_t = time.time()
        loss_plus = loss_v2_f()
        t2 += time.time() - s_t

        s_t = time.time()
        loss_plus_plus = loss_v3_f()
        t3 += time.time() - s_t
        if i == 0:
            print("{} {} {}".format(loss, loss_plus, loss_plus_plus))
    print(t1, t2, t3)


# Test of the first paper's loss
def v1_test():
    test(loss_v1, loss_v1_plus, loss_v1_plus_plus)


# Test of the second paper's loss
def v2_test():
    test(loss_v2, loss_v2_plus, loss_v2_plus_plus)


# Loss from the third paper (recommended: fast and simple). gather_ids can be obtained with:
# pair_labels = rank_to_pair_label(ranks)
# gather_ids = torch.tensor(pair_label_to_gather_id(pair_labels))
def loss_v3(gather_ids, rewards):
    r1 = torch.gather(rewards, 1, gather_ids[:,:,0])
    r2 = torch.gather(rewards, 1, gather_ids[:,:,1])
    sub_r = r1 - r2
    loss_value = -torch.nn.functional.logsigmoid(sub_r).mean()
    return loss_value


# Loss from the third paper (more complex and slower, not recommended).
# The arguments can be obtained with:
# pair_labels = torch.tensor(pair_labels)
# N = rewards.shape[1]
# c = combination_dict_torch[N]
# indx = c[None].expand(len(rewards), -1, -1)
def loss_v3_plus(pair_labels, rewards, c, idx):
    rewards_combined = rewards[:,None].expand(-1, len(c), -1).gather(dim=2, index=idx)
    r1 = rewards_combined[pair_labels.nonzero(as_tuple=True)]
    r2 = rewards_combined[(pair_labels == 0).nonzero(as_tuple=True)]
    r1_sub = r1[:,0] - r1[:,1]
    r2_sub = r2[:,1] - r2[:,0]
    sub_r = torch.cat([r1_sub, r2_sub])
    # print("loss_v3_plus:",sub_r.sort())
    loss_value = -torch.nn.functional.logsigmoid(sub_r).mean(-1)
    return loss_value


# Test of the third paper's loss
def v3_test():
    rewards = torch.tensor(np.array([[1,2,3],[2,4,6],[3,9,27]], dtype=np.float32), requires_grad=True)
    all_ranks = [
        [[0,1,2],[0,1,2],[0,1,2]],
        [[0,1,2],[0,1,2],[2,1,0]],
        [[0,1,2],[2,1,0],[2,1,0]],
        [[1,0,2],[2,1,0],[2,1,0]],
        [[2,1,0],[2,1,0],[2,1,0]],
    ]
    # Number of candidates per sample; the index tensors depend only on it.
    N = rewards.shape[1]
    c = combination_dict_torch[N]
    indx = c[None].expand(len(rewards), -1, -1)

    t1, t2 = 0., 0.
    for i in range(5000):
        for ranks in all_ranks:
            pair_labels = rank_to_pair_label(ranks)
            gather_ids = torch.tensor(pair_label_to_gather_id(pair_labels))
            s_t = time.time()
            l1 = loss_v3(gather_ids, rewards)
            t1 += time.time() - s_t

            pair_labels = torch.tensor(pair_labels)
            s_t = time.time()
            l2 = loss_v3_plus(pair_labels, rewards, c, indx)
            t2 += time.time() - s_t
            if i == 0:
                print(l1, l2)
    print(t1, t2)


# Round-trip test of the gather_id conversions
def gather_id_convert():
    pair_labels = [[1,0,1],[0,0,1]]
    gather_ids = pair_label_to_gather_id(pair_labels)
    recovered_pair_labels = gather_id_to_pair_label(gather_ids)
    recovered_gather_ids = pair_label_to_gather_id(recovered_pair_labels)
    print(pair_labels, recovered_pair_labels)
    print(gather_ids, recovered_gather_ids)
    assert pair_labels == recovered_pair_labels, "recovered pair_labels wrong!"
    assert gather_ids == recovered_gather_ids, "recovered gather_ids wrong!"


if __name__ == '__main__':
    v3_test()