FSDP 全分片数据并行 — PyTorch 原生的大模型多卡训练方案
一句话说明
FSDP(Fully Sharded Data Parallel)是 PyTorch 内置的分布式训练方式,把模型参数、梯度、优化器状态都分散到多张 GPU 上,不需要额外安装依赖,Meta 内部用它训练 LLaMA 系列模型。
安装与配置
# FSDP 是 PyTorch 内置功能,无需额外安装
pip install torch # PyTorch 2.0+ 内置 FSDP
# 配合 accelerate 使用(推荐,更简单)
pip install accelerate transformers # HuggingFace 生态
# 生成 accelerate 配置(交互式向导)
accelerate config # 选择 FSDP 模式,跟着提示走
# 验证
python -c "from torch.distributed.fsdp import FullyShardedDataParallel; print('OK')"
核心用法
# 方式1:直接用 PyTorch FSDP API
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy # 自动分片策略
# 初始化分布式环境
dist.init_process_group(backend="nccl") # NCCL 是 GPU 间通信最快的后端
local_rank = dist.get_rank()
torch.cuda.set_device(local_rank)
# 加载模型
model = MyLargeModel()
# 用 FSDP 包装模型(核心步骤)
model = FSDP(
model,
auto_wrap_policy=size_based_auto_wrap_policy, # 自动对大于阈值的层分片
device_id=local_rank, # 当前 GPU 编号
)
# 正常训练(FSDP 自动处理分片和通信)
for batch in dataloader:
output = model(batch)
loss = loss_fn(output)
loss.backward() # 梯度自动 all-reduce
optimizer.step()
optimizer.zero_grad()
参数详解
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy, # 分片策略
CPUOffload, # CPU Offload 选项
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import functools
# 分片策略说明
# FULL_SHARD:全分片(模型+梯度+优化器都分片)= 最省显存
# SHARD_GRAD_OP:只分片梯度和优化器(推理更快)
# NO_SHARD:不分片(等同于 DDP)
# 混合精度配置
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16, # 参数用 bf16
reduce_dtype=torch.bfloat16, # 梯度聚合用 bf16
buffer_dtype=torch.bfloat16, # 缓冲区用 bf16
)
# 对 Transformer 层自动分片(比 size_based 更精准)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer}, # 对每个 Decoder 层单独分片
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mp_policy, # 混合精度
sharding_strategy=ShardingStrategy.FULL_SHARD, # 全分片策略
cpu_offload=CPUOffload(offload_params=True), # 把参数 offload 到 CPU(极度省显存)
device_id=torch.cuda.current_device(),
)
实战案例
# 方式2:用 accelerate + FSDP(最简单的集成方式)
# 先运行 accelerate config 生成配置,然后:
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
accelerator = Accelerator() # 自动读取 accelerate 配置
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-8B")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# 让 accelerate 自动处理 FSDP 分片
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss) # 用 accelerate 的 backward
optimizer.step()
optimizer.zero_grad()
# 保存 FSDP 模型(需要特殊处理)
# FSDP 保存方式
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT, # 收集所有分片,合并成完整权重
FullStateDictConfig(offload_to_cpu=True, rank0_only=True), # 只在主进程保存
):
state_dict = model.state_dict()
if dist.get_rank() == 0: # 只有主进程保存
torch.save(state_dict, "model.pt")
# accelerate FSDP 配置文件示例(~/.cache/huggingface/accelerate/default_config.yaml)
compute_environment: LOCAL_MACHINE
distributed_type: FSDP # 指定使用 FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP # 按 Transformer 层自动分片
fsdp_backward_prefetch: BACKWARD_PRE # 预取下一层参数(加速)
fsdp_offload_params: false # 是否 offload 到 CPU
fsdp_sharding_strategy: 1 # 1=FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT # 保存完整权重
num_processes: 4 # GPU 数量
常见报错与解决
| 报错信息 | 原因 | 解决方法 |
|---|
Expected all tensors on same device | 模型没被 FSDP 正确包装 | 确保在 FSDP() 之后再移到设备 |
| 保存的模型是空的 | 保存方式错误 | 必须用 FULL_STATE_DICT 类型保存 |
| 训练速度慢 | CPU Offload 开销 | 显存够用则关闭 offload_params |
| NCCL timeout | 节点间通信超时 | 增加 NCCL_TIMEOUT 环境变量 |
wrap_policy 找不到层 | 层类名不对 | 用 print(model) 查看实际层类名 |
速查表
| 功能 | 方式 |
|---|
| 全分片 | ShardingStrategy.FULL_SHARD |
| CPU Offload | CPUOffload(offload_params=True) |
| 混合精度 | MixedPrecision(param_dtype=torch.bfloat16) |
| 保存模型 | StateDictType.FULL_STATE_DICT |
| accelerate 启动 | accelerate launch --config_file config.yaml train.py |
| 官方文档 | https://pytorch.org/docs/stable/fsdp.html |