Colossal-AI 训练框架 — 国产开源的高效大模型训练系统
一句话说明
Colossal-AI 是国内团队(伏羲资本)开源的大模型训练系统,提供内存友好的并行策略和 Booster API,可以用一张消费级显卡微调 7B 模型,同时支持分布式预训练,是对标 DeepSpeed 的国产方案。
安装与配置
# 安装 Colossal-AI(需要 PyTorch 2.1+)
pip install colossalai # 基础安装
# 安装含扩展的版本(推荐,功能更全)
pip install colossalai[flash-attn] # 含 Flash Attention
# 从源码安装(获取最新功能)
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
pip install .
# 验证安装
python -c "import colossalai; print(colossalai.__version__)"
核心用法
# Colossal-AI 核心是 Booster API(类似 accelerate 的包装层)
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin # 两个核心插件
from colossalai.nn.optimizer import HybridAdam # 内存优化的 Adam
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 初始化分布式(单机时自动处理)
colossalai.launch_from_torch()
# 加载模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-7B")
optimizer = HybridAdam(model.parameters(), lr=1e-4) # HybridAdam 比 AdamW 省显存
# 选择插件(决定训练策略)
# GeminiPlugin:最省显存,适合单卡/显存紧张的情况
plugin = GeminiPlugin(
precision="bf16", # 训练精度
placement_policy="auto", # 自动决定参数放 GPU 还是 CPU
initial_scale=2**16, # 混合精度初始缩放系数
)
# 用 Booster 包装(类似 accelerate.prepare)
booster = Booster(plugin=plugin)
model, optimizer, _, dataloader, _ = booster.boost(
model=model,
optimizer=optimizer,
dataloader=dataloader,
)
参数详解
# 两种主要插件对比:
# 1. GeminiPlugin(最省显存)
from colossalai.booster.plugin import GeminiPlugin
plugin = GeminiPlugin(
precision="bf16", # 精度:fp16/bf16/fp32
placement_policy="auto", # auto/cpu/cuda(auto 根据显存自动决定)
pin_memory=True, # 锁页内存(加速 CPU-GPU 传输)
offload_optim_frag=True, # Offload 优化器碎片到 CPU
max_norm=1.0, # 梯度裁剪最大范数
)
# 2. LowLevelZeroPlugin(类似 DeepSpeed ZeRO Stage 1/2)
from colossalai.booster.plugin import LowLevelZeroPlugin
plugin = LowLevelZeroPlugin(
stage=2, # 1=只分片优化器, 2=分片优化器+梯度
precision="bf16",
initial_scale=2**16,
max_norm=1.0,
)
# 3. TorchDDPPlugin(普通数据并行,不省显存但稳定)
from colossalai.booster.plugin import TorchDDPPlugin
plugin = TorchDDPPlugin(
find_unused_parameters=False, # 不检测未使用参数(加速)
)
实战案例
# 完整的 LLaMA 微调流程(单卡或多卡)
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
# 初始化分布式
colossalai.launch_from_torch()
# 模型和分词器
model_name = "meta-llama/Llama-3.2-7B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 数据集
dataset = load_dataset("yahma/alpaca-cleaned", split="train[:1000]")
def collate_fn(batch): # 数据批处理函数
texts = [f"指令:{x['instruction']}\n输出:{x['output']}" for x in batch]
return tokenizer(texts, padding=True, truncation=True,
max_length=512, return_tensors="pt")
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
# 优化器和插件
optimizer = HybridAdam(model.parameters(), lr=2e-5)
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
# Booster 包装
booster = Booster(plugin=plugin)
model, optimizer, _, dataloader, _ = booster.boost(
model=model, optimizer=optimizer, dataloader=dataloader)
# 训练循环
model.train()
for epoch in range(3): # 训练 3 轮
for batch in dataloader:
batch = {k: v.cuda() for k, v in batch.items()} # 移到 GPU
outputs = model(**batch, labels=batch["input_ids"])
loss = outputs.loss
booster.backward(loss, optimizer) # 用 booster.backward
optimizer.step()
optimizer.zero_grad()
print(f"Loss: {loss.item():.4f}")
# 保存模型
booster.save_model(model, "output/llama-ft", shard=True) # 分片保存(大模型必须)
常见报错与解决
| 报错信息 | 原因 | 解决方法 |
|---|
launch_from_torch 失败 | 分布式初始化问题 | 确保用 torchrun 或设置 RANK 环境变量 |
| GeminiPlugin OOM | 显存不够 | 设 placement_policy="cpu" |
shard=True 保存慢 | 大模型分片 | 正常现象,等待即可 |
| 梯度 NaN | 混合精度溢出 | 降低 initial_scale 或用 bf16 |
| 与 PEFT 不兼容 | 版本冲突 | 查阅官方 PEFT 集成文档 |
速查表
| 功能 | 方式 |
|---|
| 最省显存训练 | GeminiPlugin(placement_policy="auto") |
| ZeRO Stage 2 | LowLevelZeroPlugin(stage=2) |
| 内存优化优化器 | HybridAdam |
| 多卡启动 | torchrun --nproc_per_node=N train.py |
| 保存模型 | booster.save_model(model, path, shard=True) |
| GitHub | https://github.com/hpcaitech/ColossalAI |