跳转至

TRL 强化学习训练 — 从 SFT 到 RLHF/DPO 的完整 LLM 训练库


一句话说明

TRL(Transformer Reinforcement Learning,v0.15+)是 HuggingFace 的 LLM 对齐训练库,支持 SFT(监督微调)、RLHF(人类反馈强化学习)、DPO(直接偏好优化)、GRPO 等方法,是 Unsloth、Axolotl 等框架的底层依赖。


安装与配置

# 安装 TRL(v0.15+)
pip install trl                            # 基础安装
pip install trl[peft]                      # 含 PEFT 支持
pip install trl[vllm]                      # 含 vLLM 推理加速(用于 RLHF 生成)

# 验证版本
python -c "import trl; print(trl.__version__)"

核心用法

# TRL 提供三类 Trainer,分别对应三个训练阶段

# === 阶段1:SFTTrainer(监督微调)===
from trl import SFTConfig, SFTTrainer                  # SFT = Supervised Fine-Tuning
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/Capybara", split="train")  # 内置示例数据集

training_args = SFTConfig(
    output_dir="./sft_output",             # 输出目录
    max_seq_length=512,                    # 最大序列长度
    per_device_train_batch_size=4,         # 每卡 batch 大小
    num_train_epochs=3,                    # 训练轮数
    logging_steps=10,                      # 日志间隔
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,            # v0.15+ 改用 processing_class 替代 tokenizer
)
trainer.train()

参数详解

# === 阶段2:DPOTrainer(直接偏好优化,替代 PPO 的更稳定方案)===
from trl import DPOConfig, DPOTrainer

# DPO 数据格式:每条数据包含一个好答案和一个坏答案
# {"prompt": "...", "chosen": "好答案", "rejected": "坏答案"}
dpo_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

dpo_args = DPOConfig(
    output_dir="./dpo_output",
    beta=0.1,                              # KL 散度惩罚系数(越大越靠近参考模型)
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-7,                    # DPO 学习率要比 SFT 小很多
    num_train_epochs=1,
    loss_type="sigmoid",                   # DPO 损失类型(sigmoid/hinge/ipo/kto_pair)
)

dpo_trainer = DPOTrainer(
    model=model,                           # 待训练模型
    ref_model=None,                        # 参考模型(None=用 SFT 模型的冻结副本)
    args=dpo_args,
    train_dataset=dpo_dataset,
    processing_class=tokenizer,
)
dpo_trainer.train()

# === 阶段3:GRPOTrainer(组相对策略优化,DeepSeek-R1 同款算法)===
from trl import GRPOConfig, GRPOTrainer

def reward_fn(completions, **kwargs):      # 自定义奖励函数(核心!)
    """根据模型输出计算奖励分数"""
    rewards = []
    for completion in completions:
        # 示例:答案中包含数字则得高分
        score = 1.0 if any(c.isdigit() for c in completion) else 0.0
        rewards.append(score)
    return rewards

grpo_args = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,                     # 每个 prompt 生成多少个候选答案
    max_completion_length=256,             # 生成答案的最大长度
    per_device_train_batch_size=4,
    learning_rate=1e-6,
)

grpo_trainer = GRPOTrainer(
    model=model,
    args=grpo_args,
    reward_funcs=reward_fn,                # 奖励函数(支持列表,多个奖励叠加)
    train_dataset=dataset,
    processing_class=tokenizer,
)
grpo_trainer.train()

实战案例

# 用 GRPO 训练数学推理模型(类似 DeepSeek-R1 的方法)

import re
from trl import GRPOConfig, GRPOTrainer

# 奖励函数1:格式正确(回答包含 <think> 标签)
def format_reward(completions, **kwargs):
    rewards = []
    for comp in completions:
        has_think = "<think>" in comp and "</think>" in comp
        rewards.append(1.0 if has_think else 0.0)
    return rewards

# 奖励函数2:答案正确
def accuracy_reward(completions, ground_truth, **kwargs):
    rewards = []
    for comp, gt in zip(completions, ground_truth):
        # 提取最终答案(格式:<answer>42</answer>)
        match = re.search(r"<answer>(.*?)</answer>", comp)
        pred = match.group(1).strip() if match else ""
        rewards.append(1.0 if pred == str(gt) else 0.0)
    return rewards

grpo_trainer = GRPOTrainer(
    model=model,
    args=GRPOConfig(output_dir="./math_model", num_generations=4),
    reward_funcs=[format_reward, accuracy_reward],  # 多个奖励函数叠加
    train_dataset=math_dataset,
    processing_class=tokenizer,
)
grpo_trainer.train()

常见报错与解决

报错信息原因解决方法
processing_class 参数报错旧版 TRL(<0.13)升级 pip install --upgrade trl
DPO loss 为 NaN学习率太大或 beta 设置问题降低 learning_rate 到 1e-7
GRPO 生成超时num_generations 太大减少到 4-8
chosen 格式错误DPO 数据格式不对确保数据包含 prompt/chosen/rejected 字段
奖励函数返回值类型错误必须返回 list确保返回 list[float]

速查表

训练方法Trainer 类适用场景
监督微调SFTTrainer最基础的指令微调
直接偏好优化DPOTrainer替代 RLHF,更稳定
奖励模型RewardTrainer训练打分模型
PPOPPOTrainer经典 RLHF
GRPOGRPOTrainerDeepSeek-R1 同款推理优化
KTOKTOTrainer只需好/坏标签,无需配对
官方文档https://huggingface.co/docs/trl