跳转至

DeepSpeed 训练加速 — 微软出品的大模型分布式训练加速引擎


一句话说明

DeepSpeed 是微软开源的深度学习训练优化库,通过 ZeRO 优化(Zero Redundancy Optimizer)将模型状态分散到多张 GPU 上,让单机 8 卡可以训练 1750 亿参数的模型,同时提供混合精度、梯度压缩等加速手段。


安装与配置

# 安装 DeepSpeed(需要先装 PyTorch + CUDA)
pip install deepspeed                      # 基础安装

# 完整安装(含所有可选优化组件)
DS_BUILD_OPS=1 pip install deepspeed      # 编译所有 CUDA 扩展(慢但功能全)

# 只编译需要的组件(推荐)
DS_BUILD_FUSED_ADAM=1 pip install deepspeed    # 只编译 FusedAdam 优化器

# 验证安装
ds_report                                  # 查看支持的功能和 GPU 信息

核心用法

# 单机多卡训练(4 张 GPU)
deepspeed --num_gpus=4 train.py --deepspeed ds_config.json

# 多机多卡训练(2台机器,每台 8 卡)
deepspeed --num_nodes=2 --num_gpus=8 \
    --hostfile hostfile \                  # 主机文件,记录各节点地址
    train.py --deepspeed ds_config.json
# 在训练代码中集成 DeepSpeed(替换 PyTorch 训练循环)
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-8B")

# 用 DeepSpeed 初始化引擎(替代 optimizer 和 scaler)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json"                # DeepSpeed 配置文件
)

# 训练循环(用 engine 替代原始 model)
for batch in dataloader:
    outputs = model_engine(batch["input_ids"])  # 前向传播
    loss = outputs.loss
    model_engine.backward(loss)                  # 反向传播(DeepSpeed 管理)
    model_engine.step()                          # 参数更新(DeepSpeed 管理)

参数详解

DeepSpeed 核心配置文件 ds_config.json,有 ZeRO Stage 1/2/3 三个级别:

{
    "train_batch_size": 16,
    "gradient_accumulation_steps": 4,
    "train_micro_batch_size_per_gpu": 1,

    "bf16": {
        "enabled": true
    },

    "zero_optimization": {
        "stage": 2,

        "stage3_param_persistence_threshold": 1e4,
        "stage3_max_live_parameters": 3e7,
        "stage3_prefetch_bucket_size": 3e7,
        "memory_efficient_linear": false,

        "allgather_partitions": true,
        "reduce_scatter": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "contiguous_gradients": true
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    }
}

ZeRO 三个阶段的区别(越高显存越少,通信开销越大):

ZeRO Stage分片内容适用场景
Stage 1优化器状态多卡分担优化器内存
Stage 2优化器状态 + 梯度常用,省显存且通信不多
Stage 3优化器状态 + 梯度 + 模型参数超大模型,但通信开销大

实战案例

# 与 HuggingFace Transformers 集成(最常用方式)

from transformers import TrainingArguments, Trainer, AutoModelForCausalLM

# 只需在 TrainingArguments 中指定 deepspeed 配置
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,          # 单卡 batch(配合梯度累积)
    gradient_accumulation_steps=16,
    num_train_epochs=3,
    learning_rate=2e-5,
    bf16=True,
    deepspeed="ds_config_zero2.json",       # ZeRO Stage 2 配置文件路径
    logging_steps=10,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-8B")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
// ds_config_zero2.json(ZeRO Stage 2 标准配置)
{
    "zero_optimization": {
        "stage": 2,                         // Stage 2:分片优化器状态+梯度
        "overlap_comm": true,              // 通信与计算重叠(提速)
        "contiguous_gradients": true       // 连续内存梯度(提速)
    },
    "bf16": {"enabled": true},
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16
}

常见报错与解决

报错信息原因解决方法
No module named 'deepspeed'未安装pip install deepspeed
NCCL error多卡通信失败检查 GPU 间连接,设置 NCCL_DEBUG=INFO
OOM with ZeRO Stage 2内存仍不足升级到 "stage": 3
config json missing keys配置文件格式错误参考官方模板
训练速度很慢(Stage 3)通信开销改用 Stage 2 或开启 overlap_comm

速查表

ZeRO Stage配置推荐场景
Stage 1"stage": 1显存够用,优化训练效率
Stage 2"stage": 2大多数情况(推荐默认)
Stage 3"stage": 3超大模型(70B+)
Offload CPU"offload_optimizer": {"device": "cpu"}进一步省显存
启动命令deepspeed --num_gpus=N train.py
官方文档https://www.deepspeed.ai/docs/config-json/