RLHF与PPO:大模型对齐技术详解

张开发
2026/5/6 15:05:05 15 分钟阅读

分享文章

RLHF与PPO:大模型对齐技术详解
RLHF与PPO大模型对齐技术详解前言大语言模型通过海量文本学习到了强大的语言能力但如何让模型的输出符合人类期望和价值观RLHFReinforcement Learning from Human Feedback人类反馈强化学习是解决这一问题的核心技术。本文从原理到实践深入解析RLHF及其核心算法PPO。一、为什么需要RLHF1.1 SFT的局限性监督微调SFT让模型学习人类标注的问答对# SFT的局限sft_data[{instruction:如何制作炸弹,response:抱歉我不能帮助这个问题},{instruction:写一首诗,response:春眠不觉晓处处闻啼鸟...},# 人工标注成本高难以覆盖所有场景]# SFT问题# 1. 人工标注成本每条数据$0.5-$2# 2. 无法覆盖长尾场景安全性、毒性、偏见# 3. 难以表达复杂偏好简洁vs详细、正式vs随意1.2 人类偏好的复杂性人类评估维度维度SFTRLHF真实性❌ 难以保证✅ reward信号安全性❌ 需大量规则✅ 惩罚有害输出有用性人工标注人类偏好学习风格控制困难reward模型学习二、RLHF三阶段流程2.1 阶段一监督微调SFT# 第一步微调基础语言模型fromtransformersimportAutoModelForCausalLM,AutoTokenizer,TrainingArguments base_modelAutoModelForCausalLM.from_pretrained(meta-llama/Llama-2-7b)tokenizerAutoTokenizer.from_pretrained(meta-llama/Llama-2-7b)# 格式化为对话格式defformat_prompt(sample):returnf|user|{sample[instruction]}|assistant|{sample[response]}|end|2.2 阶段二训练Reward Model收集人类偏好数据# 人类偏好数据格式preference_data[{prompt:解释量子纠缠,chosen:量子纠缠是量子力学中两个或多个粒子之间存在的一种特殊关联...,# 人类选择rejected:量子纠缠就是两个粒子连在一起。# 人类拒绝},# 收集10万-100万对比数据]# Bradley-Terry模型P(prefer|AB) σ(r(A) - r(B))# 训练Reward Model预测人类偏好reward_modelAutoModelForCausalLM.from_pretrained(base_model,torch_dtypetorch.float16,)defreward_loss(reward_chosen,reward_rejected): 人类偏好损失chosen的reward应高于rejected logitsreward_chosen-reward_rejectedreturn-F.logsigmoid(logits).mean()2.3 阶段三强化学习PPO# PPO算法核心classPPO:def__init__(self,policy,ref_policy,reward_model,clip_ratio0.2):self.policypolicy# 当前策略self.ref_policyref_policy# SFT基线KL散度约束self.reward_modelreward_model self.clip_ratioclip_ratiodefcompute_advantage(self,rewards,values,gamma0.99,lam0.95): GAE (Generalized Advantage Estimation) 优势函数当前动作比平均水平好多少 advantages[]gae0fortinreversed(range(len(rewards))):deltarewards[t]gamma*values[t1]-values[t]gaedeltagamma*lam*gae advantages.insert(0,gae)returntorch.tensor(advantages)defpolicy_loss(self,log_probs,old_log_probs,advantages): PPO-Clip损失限制策略更新幅度 ratiotorch.exp(log_probs-old_log_probs)clippedtorch.clamp(ratio,1-self.clip_ratio,1self.clip_ratio)return-(torch.min(ratio*advantages,clipped*advantages)).mean()三、InstructGPT完整流程3.1 数据采集# 人类反馈数据采集平台human_feedback_pipeline{step1_generate:给定提示采样多个模型输出,step2_label:人类对输出对进行比较评分,step3_aggregate:使用Bradley-Terry模型估计奖励,}# 典型数据量DATASET_STATS{SFT_data:~10k-100k高质量对话,reward_model_data:~100k-1M偏好对比,PPO_data:~10k-100k提示无需标注,}3.2 PPO训练循环# 简化PPO训练循环defppo_train_step(policy,ref_policy,optimizer,prompts,reward_model):# 1. 用当前策略生成响应responsespolicy.generate(prompts)# 2. 用reward model打分reward_scoresreward_model(prompts,responses)# 3. 计算KL惩罚防止偏离SFT太远kl_penaltycompute_kl_divergence(responses,policy,ref_policy)# 4. 最终reward 模型reward - β * KLfinal_rewardreward_scores-0.04*kl_penalty# 5. PPO更新for_inrange(4):# 4个epoch# 计算优势函数advantagescompute_gae(final_reward)# 计算策略损失并更新policy_lossppo_objective(log_probs,old_log_probs,advantages)optimizer.zero_grad()policy_loss.backward()optimizer.step()returnpolicy_loss.item()3.3 完整训练代码fromtransformersimportAutoModelForCausalLM,AutoTokenizerfromtorch.utils.dataimportDataLoaderimporttorch.nn.functionalasF# 加载模型actorAutoModelForCausalLM.from_pretrained(llama-2-7b)ref_modelAutoModelForCausalLM.from_pretrained(llama-2-7b)reward_modelRewardModel.from_pretrained(reward-model-7b)# PPO配置ppo_config{lr:1e-6,clip_ratio:0.2,kl_coef:0.04,# KL惩罚系数num_epochs:4,batch_size:8,}# 训练循环forstepinrange(1000):# 采样提示promptssample_prompts(ppo_train_prompts,batch_size8)# 生成响应responsesactor.generate(prompts,max_length512)# Reward评估rewardsreward_model.get_reward(prompts,responses)# KL惩罚kl_penaltycompute_kl(responses,actor,ref_model)# 最终rewardfinal_rewardsrewards-ppo_config[kl_coef]*kl_penalty# PPO更新ppo_update(actor,final_rewards)四、DPO绕过PPO的替代方案4.1 DPO原理DPODirect Preference Optimization绕过了reward model训练和PPO# DPO损失函数defdpo_loss(policy_chosen,policy_rejected,ref_chosen,ref_rejected,beta0.1): DPO直接优化人类偏好无需RL # 策略偏好比policy_ratiopolicy_chosen-policy_rejected# 参考模型偏好比ref_ratioref_chosen-ref_rejected# DPO损失logitsbeta*(policy_ratio-ref_ratio)return-F.logsigmoid(logits).mean()4.2 DPO vs PPO对比维度PPORMDPO训练稳定性需要KL约束、clipping更稳定计算量需要4个模型actor/ref/reward/critic只需2个模型显存需求~4x模型大小~2x模型大小效果InstructGPT验证相当或更好超参多clip、kl_coef、gamma等少主要β4.3 DPO实践# DPO训练示例fromtransformersimportAutoModelForCausalLMfromdatasetsimportload_dataset datasetload_dataset(argilla/DPO-Mix-7K)# 现成偏好数据defdpo_train():policyAutoModelForCausalLM.from_pretrained(llama-2-7b)ref_modelAutoModelForCausalLM.from_pretrained(llama-2-7b)forbatchindataloader:# DPO损失lossdpo_loss(policy(batch[chosen]),policy(batch[rejected]),ref_model(batch[chosen]),ref_model(batch[rejected]),)loss.backward()optimizer.step()五、RLHF的挑战与未来5.1 主要挑战# RLHF的典型问题challenges{reward_hacking:模型找到欺骗reward的捷径而非真正完成目标,mode_collapse:输出变得单调缺少多样性,human_bias:人类反馈本身带有偏见模型学会这些偏见,训练不稳定:PPO训练需要仔细调参容易训崩,}5.2 解决方案# 1. Reward Hacking防护reward_hacking_solution{ensemble_rewards:多个reward model投票,constitutional_ai:加入规则约束,red_team:对抗性测试,}# 2. 输出多样性diversity_solution{do_sample:True,# 采样而非greedytemperature:0.7,# 温度采样nucleus_sampling:True,# Top-p采样}# 3. 更多人类反馈feedback_scaling{RLHF:~100K对比,RLAIF:用AI反馈替代人类Google Sparrow,ConstitutionalAI:用规则AI反馈Anthropic,}六、总结RLHF开启了让AI对齐人类价值观的时代RLHF技术演进 ├── 第一代纯SFT规则驱动 │ ├── 第二代RLHFPPOInstructGPT/GPT-3.5 │ └── 问题训练复杂、需4模型 │ └── 第三代DPODirect Preference Optimization └── 优势简化流程、训练更稳定GPT-4、Claude、Gemini等强大模型都经过了RLHF微调正是这项技术让AI从能说话进化到会说话。延伸阅读InstructGPT论文arxiv.org/abs/2203.02155RLHF综述arxiv.org/abs/2304.11540DPO论文arxiv.org/abs/2305.18290ChatGPT技术解析如何让AI学会对话

更多文章