跳转至

FSDP 全分片数据并行 — PyTorch 原生的大模型多卡训练方案


一句话说明

FSDP(Fully Sharded Data Parallel)是 PyTorch 内置的分布式训练方式,把模型参数、梯度、优化器状态都分散到多张 GPU 上,不需要额外安装依赖,Meta 内部用它训练 LLaMA 系列模型。


安装与配置

# FSDP 是 PyTorch 内置功能,无需额外安装
pip install torch                          # PyTorch 2.0+ 内置 FSDP

# 配合 accelerate 使用(推荐,更简单)
pip install accelerate transformers        # HuggingFace 生态

# 生成 accelerate 配置(交互式向导)
accelerate config                          # 选择 FSDP 模式,跟着提示走

# 验证
python -c "from torch.distributed.fsdp import FullyShardedDataParallel; print('OK')"

核心用法

# 方式1:直接用 PyTorch FSDP API
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy  # 自动分片策略

# 初始化分布式环境
dist.init_process_group(backend="nccl")   # NCCL 是 GPU 间通信最快的后端
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)

# 加载模型
model = MyLargeModel()

# 用 FSDP 包装模型(核心步骤)
model = FSDP(
    model,
    auto_wrap_policy=size_based_auto_wrap_policy,  # 自动对大于阈值的层分片
    device_id=local_rank,                           # 当前 GPU 编号
)

# 正常训练(FSDP 自动处理分片和通信)
for batch in dataloader:
    output = model(batch)
    loss = loss_fn(output)
    loss.backward()                                # 梯度自动 all-reduce
    optimizer.step()
    optimizer.zero_grad()

参数详解

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,                              # 分片策略
    CPUOffload,                                    # CPU Offload 选项
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools

# 分片策略说明
# FULL_SHARD:全分片(模型+梯度+优化器都分片)= 最省显存
# SHARD_GRAD_OP:只分片梯度和优化器(推理更快)
# NO_SHARD:不分片(等同于 DDP)

# 混合精度配置
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,                   # 参数用 bf16
    reduce_dtype=torch.bfloat16,                  # 梯度聚合用 bf16
    buffer_dtype=torch.bfloat16,                  # 缓冲区用 bf16
)

# 对 Transformer 层自动分片(比 size_based 更精准)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},     # 对每个 Decoder 层单独分片
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,                     # 混合精度
    sharding_strategy=ShardingStrategy.FULL_SHARD, # 全分片策略
    cpu_offload=CPUOffload(offload_params=True),   # 把参数 offload 到 CPU(极度省显存)
    device_id=torch.cuda.current_device(),
)

实战案例

# 方式2:用 accelerate + FSDP(最简单的集成方式)
# 先运行 accelerate config 生成配置,然后:

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

accelerator = Accelerator()                        # 自动读取 accelerate 配置

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-8B")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# 让 accelerate 自动处理 FSDP 分片
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)                     # 用 accelerate 的 backward
    optimizer.step()
    optimizer.zero_grad()

# 保存 FSDP 模型(需要特殊处理)
# FSDP 保存方式
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,                 # 收集所有分片,合并成完整权重
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),  # 只在主进程保存
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:                       # 只有主进程保存
        torch.save(state_dict, "model.pt")
# accelerate FSDP 配置文件示例(~/.cache/huggingface/accelerate/default_config.yaml)
compute_environment: LOCAL_MACHINE
distributed_type: FSDP                            # 指定使用 FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP   # 按 Transformer 层自动分片
  fsdp_backward_prefetch: BACKWARD_PRE            # 预取下一层参数(加速)
  fsdp_offload_params: false                      # 是否 offload 到 CPU
  fsdp_sharding_strategy: 1                       # 1=FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT           # 保存完整权重
num_processes: 4                                  # GPU 数量

常见报错与解决

报错信息原因解决方法
Expected all tensors on same device模型没被 FSDP 正确包装确保在 FSDP() 之后再移到设备
保存的模型是空的保存方式错误必须用 FULL_STATE_DICT 类型保存
训练速度慢CPU Offload 开销显存够用则关闭 offload_params
NCCL timeout节点间通信超时增加 NCCL_TIMEOUT 环境变量
wrap_policy 找不到层层类名不对print(model) 查看实际层类名

速查表

功能方式
全分片ShardingStrategy.FULL_SHARD
CPU OffloadCPUOffload(offload_params=True)
混合精度MixedPrecision(param_dtype=torch.bfloat16)
保存模型StateDictType.FULL_STATE_DICT
accelerate 启动accelerate launch --config_file config.yaml train.py
官方文档https://pytorch.org/docs/stable/fsdp.html