Stage 4: Review Module Development, Fine-tuning - RM Part
RM Part
A reward model (Reward Model, RM) is the model used in reinforcement learning to evaluate and score another model's outputs: it assigns a reward or penalty according to output quality, and this signal guides the optimization of the model being trained.
Functions:
- Evaluate output quality: the reward model scores outputs against predefined criteria, giving high-quality outputs high rewards and low-quality outputs low rewards or penalties.
- Guide the optimization direction: through the reward signal, the model learns which outputs match expectations and gradually improves output quality during training. A sample preference record is sketched right after this list.
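In practice, reward models are trained on pairwise preference data: each record pairs one prompt with a response annotators preferred and one they rejected. A minimal sketch of such a record follows; the field names are illustrative assumptions, not necessarily this project's actual schema.

# A hypothetical pairwise preference record (field names are illustrative).
preference_example = {
    "prompt": "Explain why the sky is blue.",
    "chosen": "Sunlight scatters off air molecules, and shorter blue wavelengths scatter the most.",
    "rejected": "Because the ocean reflects its color onto the sky.",
}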
# coding=utf-8
# Implements parameter-efficient training of reward models.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

from utils import (
    PairwiseDataCollatorWithPadding,
    PairwisePeftTrainer,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    compute_accuracy,
    plot_loss
)


def main():
    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorWithPadding(tokenizer)

    training_args.remove_unused_columns = False  # important for pairwise dataset

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else:  # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PairwisePeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=compute_accuracy,
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
The structure of this script mirrors PT_Train. The only real differences are that the trainer is PairwisePeftTrainer, which inherits from PeftTrainer, and that training_args.remove_unused_columns is set to False to preserve every column of the dataset, which matters for a pairwise dataset (a sketch of the collation idea follows below). On the model side, the main job of PairwisePeftTrainer is to compute the pairwise loss used to train the reward model: during training, the pairwise loss compares a preferred candidate response against a rejected one and adjusts the model parameters according to that comparison. In every other respect it behaves exactly like PeftTrainer.
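The source of PairwiseDataCollatorWithPadding is not reproduced here, but the idea can be sketched as follows. This is a minimal illustration under assumptions, not the project's actual implementation: it assumes each feature carries hypothetical "accept_ids" and "reject_ids" columns, stacks the n accepted sequences before the n rejected ones (matching the split that compute_loss performs below), and left-pads so that the score at the last position always belongs to a real token.

import torch

class MinimalPairwiseCollator:
    """A minimal sketch of pairwise collation (not the actual implementation)."""

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def __call__(self, features):
        # Gather 2n sequences: the n accepted ones first, the n rejected ones last.
        sequences = [f["accept_ids"] for f in features] + [f["reject_ids"] for f in features]
        max_len = max(len(seq) for seq in sequences)
        input_ids, attention_mask = [], []
        for seq in sequences:
            pad = [self.pad_token_id] * (max_len - len(seq))
            # Left-pad so the last position always holds the real final token,
            # whose score compute_loss reads as the sentence-level reward.
            input_ids.append(pad + seq)
            attention_mask.append([0] * len(pad) + [1] * len(seq))
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

The PairwisePeftTrainer that consumes such batches is shown next.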
import torch

# PeftTrainer is the project's base trainer, defined alongside this class in utils.

class PairwisePeftTrainer(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True  # override property to return eval_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        We use score on the EOS token to represent reward of the whole sentence.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        Note that the first element will be removed from the output tuple.
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(**inputs)  # values: per-token scores from the value head
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)  # last-position score per sequence
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
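As a quick sanity check, the loss formula behaves as expected on made-up rewards: it is small when the accepted responses already outscore the rejected ones, and large when the ranking is inverted. The scores below are invented for illustration.

import torch

# Made-up last-token scores for a batch of two accepted/rejected pairs.
r_accept = torch.tensor([2.0, 1.5])  # scores of the preferred responses
r_reject = torch.tensor([0.5, 1.0])  # scores of the rejected responses

loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss)  # ~0.34: small, the preferred responses already score higher

inverted = -torch.log(torch.sigmoid(r_reject - r_accept)).mean()
print(inverted)  # ~1.34: large, so gradients push the ranking back into order

Note that torch.nn.functional.logsigmoid(x) is a numerically more stable equivalent of torch.log(torch.sigmoid(x)) and is often preferred when implementing this loss.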