TRL 强化学习训练 — 从 SFT 到 RLHF/DPO 的完整 LLM 训练库
一句话说明
TRL(Transformer Reinforcement Learning,v0.15+)是 HuggingFace 的 LLM 对齐训练库,支持 SFT(监督微调)、RLHF(人类反馈强化学习)、DPO(直接偏好优化)、GRPO 等方法,是 Unsloth、Axolotl 等框架的底层依赖。
安装与配置
# 安装 TRL(v0.15+)
pip install trl # 基础安装
pip install trl[peft] # 含 PEFT 支持
pip install trl[vllm] # 含 vLLM 推理加速(用于 RLHF 生成)
# 验证版本
python -c "import trl; print(trl.__version__)"
核心用法
# TRL 提供三类 Trainer,分别对应三个训练阶段
# === 阶段1:SFTTrainer(监督微调)===
from trl import SFTConfig, SFTTrainer # SFT = Supervised Fine-Tuning
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/Capybara", split="train") # 内置示例数据集
training_args = SFTConfig(
output_dir="./sft_output", # 输出目录
max_seq_length=512, # 最大序列长度
per_device_train_batch_size=4, # 每卡 batch 大小
num_train_epochs=3, # 训练轮数
logging_steps=10, # 日志间隔
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer, # v0.15+ 改用 processing_class 替代 tokenizer
)
trainer.train()
参数详解
# === 阶段2:DPOTrainer(直接偏好优化,替代 PPO 的更稳定方案)===
from trl import DPOConfig, DPOTrainer
# DPO 数据格式:每条数据包含一个好答案和一个坏答案
# {"prompt": "...", "chosen": "好答案", "rejected": "坏答案"}
dpo_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
dpo_args = DPOConfig(
output_dir="./dpo_output",
beta=0.1, # KL 散度惩罚系数(越大越靠近参考模型)
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=5e-7, # DPO 学习率要比 SFT 小很多
num_train_epochs=1,
loss_type="sigmoid", # DPO 损失类型(sigmoid/hinge/ipo/kto_pair)
)
dpo_trainer = DPOTrainer(
model=model, # 待训练模型
ref_model=None, # 参考模型(None=用 SFT 模型的冻结副本)
args=dpo_args,
train_dataset=dpo_dataset,
processing_class=tokenizer,
)
dpo_trainer.train()
# === 阶段3:GRPOTrainer(组相对策略优化,DeepSeek-R1 同款算法)===
from trl import GRPOConfig, GRPOTrainer
def reward_fn(completions, **kwargs): # 自定义奖励函数(核心!)
"""根据模型输出计算奖励分数"""
rewards = []
for completion in completions:
# 示例:答案中包含数字则得高分
score = 1.0 if any(c.isdigit() for c in completion) else 0.0
rewards.append(score)
return rewards
grpo_args = GRPOConfig(
output_dir="./grpo_output",
num_generations=8, # 每个 prompt 生成多少个候选答案
max_completion_length=256, # 生成答案的最大长度
per_device_train_batch_size=4,
learning_rate=1e-6,
)
grpo_trainer = GRPOTrainer(
model=model,
args=grpo_args,
reward_funcs=reward_fn, # 奖励函数(支持列表,多个奖励叠加)
train_dataset=dataset,
processing_class=tokenizer,
)
grpo_trainer.train()
实战案例
# 用 GRPO 训练数学推理模型(类似 DeepSeek-R1 的方法)
import re
from trl import GRPOConfig, GRPOTrainer
# 奖励函数1:格式正确(回答包含 <think> 标签)
def format_reward(completions, **kwargs):
rewards = []
for comp in completions:
has_think = "<think>" in comp and "</think>" in comp
rewards.append(1.0 if has_think else 0.0)
return rewards
# 奖励函数2:答案正确
def accuracy_reward(completions, ground_truth, **kwargs):
rewards = []
for comp, gt in zip(completions, ground_truth):
# 提取最终答案(格式:<answer>42</answer>)
match = re.search(r"<answer>(.*?)</answer>", comp)
pred = match.group(1).strip() if match else ""
rewards.append(1.0 if pred == str(gt) else 0.0)
return rewards
grpo_trainer = GRPOTrainer(
model=model,
args=GRPOConfig(output_dir="./math_model", num_generations=4),
reward_funcs=[format_reward, accuracy_reward], # 多个奖励函数叠加
train_dataset=math_dataset,
processing_class=tokenizer,
)
grpo_trainer.train()
常见报错与解决
| 报错信息 | 原因 | 解决方法 |
|---|
processing_class 参数报错 | 旧版 TRL(<0.13) | 升级 pip install --upgrade trl |
| DPO loss 为 NaN | 学习率太大或 beta 设置问题 | 降低 learning_rate 到 1e-7 |
| GRPO 生成超时 | num_generations 太大 | 减少到 4-8 |
chosen 格式错误 | DPO 数据格式不对 | 确保数据包含 prompt/chosen/rejected 字段 |
| 奖励函数返回值类型错误 | 必须返回 list | 确保返回 list[float] |
速查表
| 训练方法 | Trainer 类 | 适用场景 |
|---|
| 监督微调 | SFTTrainer | 最基础的指令微调 |
| 直接偏好优化 | DPOTrainer | 替代 RLHF,更稳定 |
| 奖励模型 | RewardTrainer | 训练打分模型 |
| PPO | PPOTrainer | 经典 RLHF |
| GRPO | GRPOTrainer | DeepSeek-R1 同款推理优化 |
| KTO | KTOTrainer | 只需好/坏标签,无需配对 |
| 官方文档 | https://huggingface.co/docs/trl | — |