Weights & Biases (W&B) 完全指南¶
为什么要学 W&B¶
ML 实验追踪的事实标准:W&B 是机器学习实验管理领域使用最广泛的工具。OpenAI、NVIDIA、Microsoft、Toyota 等超过 70,000 个团队在使用。它解决了"上周的模型参数是什么?哪次实验效果最好?"这类 ML 开发中的核心痛点。
两行代码集成:在现有训练代码中加入
wandb.init()和wandb.log()就能开始追踪。与 PyTorch、TensorFlow、HuggingFace、Keras、scikit-learn、XGBoost 等所有主流框架都有深度集成。超参数搜索(Sweeps):内置的 Sweeps 功能让你用声明式配置定义搜索空间,W&B 自动运行贝叶斯优化、网格搜索、随机搜索。分布式执行,一个 agent 跑一组参数。
团队协作与报告:W&B Reports 让你创建交互式实验报告,嵌入实时图表和代码。比起在 Slack 中贴截图,这是更专业的 ML 研究沟通方式。
模型版本管理(Artifacts & Registry):追踪数据集版本、模型 checkpoint、推理结果。完整的血缘追踪——从数据到模型到预测,每一步都可溯源。
核心概念详解¶
W&B 是什么(白话解释)¶
做机器学习实验就像做化学实验:你调整配方(超参数)、记录过程(指标)、保存结果(模型)。没有实验记录本,你很快就会忘记哪个配方效果最好。
W&B 就是 ML 的"数字实验记录本": - 自动记录每次实验的所有参数和指标 - 漂亮的图表对比不同实验 - 模型和数据的版本管理 - 团队共享实验结果
核心产品¶
| 产品 | 用途 | 说明 |
|---|---|---|
| Experiments | 实验追踪 | 记录参数、指标、代码、环境 |
| Sweeps | 超参数搜索 | 自动化参数调优 |
| Artifacts | 版本管理 | 数据集、模型、结果的版本控制 |
| Reports | 实验报告 | 交互式文档,嵌入图表 |
| Model Registry | 模型注册 | 模型生命周期管理 |
| Tables | 数据可视化 | 交互式数据表格和可视化 |
| Launch | 作业管理 | 在计算集群上运行实验 |
| Weave | AI 应用追踪 | LLM 应用的可观测性 |
W&B vs MLflow 对比¶
| 特性 | W&B | MLflow |
|---|---|---|
| 部署 | 云托管(免费+付费) | 自托管为主 |
| UI | 现代化,交互丰富 | 简洁,功能较少 |
| 实验追踪 | 自动+手动,非常详细 | 手动为主 |
| 超参数搜索 | 内置 Sweeps | 无(需集成 Optuna 等) |
| 协作 | 团队功能丰富 | 基础 |
| 报告 | Reports(发布级别) | 无 |
| 数据可视化 | Tables/Charts/Media | 基本图表 |
| Artifact 管理 | 内置 | 内置 |
| 模型服务 | Registry + Deploy | MLflow Serving |
| LLM 支持 | Weave(LLM追踪) | MLflow Tracing |
| 开源 | 客户端开源,服务端闭源 | 全部开源 |
| 价格 | 免费个人版/付费团队版 | 免费(自托管成本) |
| 学习曲线 | 低(两行代码) | 中 |
安装与配置¶
# 安装
pip install wandb
# 登录(获取 API Key)
wandb login
# 或设置环境变量
export WANDB_API_KEY=your_api_key
# 验证
python -c "import wandb; print(wandb.__version__)"
基本配置¶
import wandb
# 初始化项目
run = wandb.init(
project="my-ml-project", # 项目名
name="experiment-001", # 运行名称(可选)
config={ # 超参数
"learning_rate": 0.001,
"batch_size": 32,
"epochs": 50,
"model": "ResNet50",
"optimizer": "Adam",
},
tags=["baseline", "v1"], # 标签
notes="首次基线实验", # 备注
group="experiment-group-1", # 分组
)
离线模式¶
# 无网络环境
import os
os.environ["WANDB_MODE"] = "offline"
# 或在 init 时设置
wandb.init(mode="offline")
# 之后同步到云端
# wandb sync wandb/offline-run-xxx
快速上手:5 分钟最小示例¶
import wandb
import random
import math
# 1. 初始化
wandb.init(
project="quickstart-demo",
config={
"learning_rate": 0.02,
"architecture": "CNN",
"dataset": "CIFAR-10",
"epochs": 10,
}
)
config = wandb.config
# 2. 模拟训练循环
for epoch in range(config.epochs):
# 模拟指标
train_loss = math.exp(-epoch * config.learning_rate) + random.uniform(-0.1, 0.1)
val_loss = train_loss + random.uniform(0, 0.2)
accuracy = 1 - val_loss + random.uniform(-0.05, 0.05)
# 3. 记录指标
wandb.log({
"epoch": epoch,
"train/loss": train_loss,
"val/loss": val_loss,
"val/accuracy": accuracy,
})
print(f"Epoch {epoch}: loss={train_loss:.4f}, acc={accuracy:.4f}")
# 4. 记录最终结果
wandb.summary["best_accuracy"] = 0.95
wandb.summary["total_params"] = 25_000_000
# 5. 结束
wandb.finish()
运行后打开控制台输出的 URL,即可在 W&B 仪表板中看到实验图表。
进阶用法¶
场景一:PyTorch 训练集成¶
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
wandb.init(project="pytorch-mnist", config={
"lr": 1e-3, "batch_size": 64, "epochs": 10, "hidden_size": 128,
})
config = wandb.config
# 模型定义
model = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, config.hidden_size),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(config.hidden_size, 10),
)
# 监控模型梯度和参数
wandb.watch(model, log="all", log_freq=100)
optimizer = optim.Adam(model.parameters(), lr=config.lr)
criterion = nn.CrossEntropyLoss()
# 数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST("data", train=True, download=True, transform=transform),
batch_size=config.batch_size, shuffle=True,
)
# 训练循环
for epoch in range(config.epochs):
model.train()
total_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
if batch_idx % 100 == 0:
wandb.log({
"batch_loss": loss.item(),
"running_accuracy": correct / total,
})
epoch_loss = total_loss / len(train_loader)
epoch_acc = correct / total
wandb.log({"epoch": epoch, "train_loss": epoch_loss, "train_accuracy": epoch_acc})
# 保存模型 Artifact
artifact = wandb.Artifact("mnist-model", type="model")
torch.save(model.state_dict(), "model.pth")
artifact.add_file("model.pth")
wandb.log_artifact(artifact)
wandb.finish()
场景二:HuggingFace Transformers 集成¶
from transformers import Trainer, TrainingArguments
import wandb
wandb.init(project="nlp-classification")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=16,
learning_rate=2e-5,
report_to="wandb", # 自动集成
run_name="bert-classification-v1",
logging_steps=10,
evaluation_strategy="steps",
eval_steps=50,
save_steps=100,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
)
trainer.train()
wandb.finish()
场景三:Sweeps 超参数搜索¶
import wandb
# 1. 定义搜索空间
sweep_config = {
"method": "bayes", # bayes / grid / random
"metric": {"name": "val_accuracy", "goal": "maximize"},
"parameters": {
"learning_rate": {
"distribution": "log_uniform_values",
"min": 1e-5,
"max": 1e-2,
},
"batch_size": {"values": [16, 32, 64, 128]},
"hidden_size": {"values": [64, 128, 256, 512]},
"dropout": {"distribution": "uniform", "min": 0.1, "max": 0.5},
"optimizer": {"values": ["adam", "sgd", "adamw"]},
},
"early_terminate": {
"type": "hyperband",
"min_iter": 3,
"eta": 2,
},
}
# 2. 创建 sweep
sweep_id = wandb.sweep(sweep_config, project="sweep-demo")
# 3. 定义训练函数
def train():
wandb.init()
config = wandb.config
model = create_model(config.hidden_size, config.dropout)
optimizer = get_optimizer(config.optimizer, model, config.learning_rate)
for epoch in range(20):
train_loss, val_acc = train_epoch(model, optimizer, config.batch_size)
wandb.log({"train_loss": train_loss, "val_accuracy": val_acc, "epoch": epoch})
wandb.finish()
# 4. 启动 agent
wandb.agent(sweep_id, function=train, count=50)
场景四:Artifacts(数据集和模型版本管理)¶
import wandb
# 记录数据集 Artifact
run = wandb.init(project="artifact-demo", job_type="data-prep")
artifact = wandb.Artifact(
name="training-data",
type="dataset",
description="清洗后的训练数据集",
metadata={"rows": 10000, "features": 50, "version": "v2"},
)
artifact.add_dir("data/processed/")
wandb.log_artifact(artifact)
wandb.finish()
# 在训练中使用 Artifact
run = wandb.init(project="artifact-demo", job_type="training")
data_artifact = run.use_artifact("training-data:latest")
data_dir = data_artifact.download()
# 训练...
# 记录模型 Artifact
model_artifact = wandb.Artifact("trained-model", type="model",
metadata={"accuracy": 0.95, "framework": "pytorch"})
model_artifact.add_file("model.pth")
wandb.log_artifact(model_artifact)
wandb.finish()
场景五:Tables(数据可视化)¶
import wandb
wandb.init(project="tables-demo")
# 记录预测结果表格
columns = ["image", "true_label", "predicted", "confidence"]
table = wandb.Table(columns=columns)
for img, true, pred, conf in zip(images, labels, predictions, confidences):
table.add_data(
wandb.Image(img),
true,
pred,
conf,
)
wandb.log({"predictions": table})
# 记录混淆矩阵
wandb.log({
"confusion_matrix": wandb.plot.confusion_matrix(
probs=None,
y_true=true_labels,
preds=predicted_labels,
class_names=class_names,
)
})
# 记录 ROC 曲线
wandb.log({
"roc": wandb.plot.roc_curve(
y_true=true_labels,
y_probas=predicted_probas,
labels=class_names,
)
})
wandb.finish()
场景六:Reports(交互式实验报告)¶
import wandb
# Reports 主要通过 W&B 网页 UI 创建
# 但可以通过 API 创建
api = wandb.Api()
# 获取实验数据
runs = api.runs("my-project")
for run in runs:
print(f"{run.name}: accuracy={run.summary.get('accuracy', 'N/A')}")
# 获取特定 run 的历史数据
run = api.run("username/project/run_id")
history = run.history()
print(history[["epoch", "loss", "accuracy"]])
场景七:Alert 告警¶
import wandb
wandb.init(project="alert-demo")
for epoch in range(100):
loss = train_one_epoch()
wandb.log({"loss": loss})
# 训练损失异常时告警
if loss > 10.0:
wandb.alert(
title="训练损失异常",
text=f"Epoch {epoch} 的损失为 {loss:.2f},超过阈值 10.0",
level=wandb.AlertLevel.WARN,
)
# 训练发散时告警
if loss != loss: # NaN check
wandb.alert(
title="训练发散!",
text=f"Epoch {epoch} 出现 NaN 损失",
level=wandb.AlertLevel.ERROR,
)
break
场景八:Model Registry¶
import wandb
# 将 Artifact 链接到 Model Registry
run = wandb.init(project="registry-demo")
# 训练并保存模型...
model_artifact = wandb.Artifact("my-model", type="model")
model_artifact.add_file("model.pth")
run.log_artifact(model_artifact)
# 链接到 Registry
run.link_artifact(
artifact=model_artifact,
target_path="my-org/model-registry/production-model",
aliases=["staging", "v1.2"],
)
wandb.finish()
常见问题与排错¶
问题一:wandb.log 的调用频率¶
# 不要每个 batch 都 log(会很慢)
# 推荐:每 N 个 batch 或每个 epoch log 一次
for step, batch in enumerate(dataloader):
loss = train_step(batch)
if step % 100 == 0: # 每100步记录一次
wandb.log({"loss": loss, "step": step})
问题二:大文件上传慢¶
# 使用 Reference Artifacts(不上传实际文件)
artifact = wandb.Artifact("large-dataset", type="dataset")
artifact.add_reference(f"s3://my-bucket/data/", max_objects=10000)
问题三:断网环境使用¶
问题四:多 GPU 训练中只记录一次¶
import wandb
import torch.distributed as dist
if dist.get_rank() == 0:
wandb.init(project="distributed-training")
# 只在 rank 0 记录
if dist.get_rank() == 0:
wandb.log({"loss": loss})
问题五:如何删除 run¶
问题六:自定义图表¶
# 使用 wandb.plot 或 plotly
import plotly.express as px
fig = px.scatter(df, x="epoch", y="loss", color="experiment")
wandb.log({"custom_chart": fig})
# 使用 wandb.plot 内置图表
wandb.log({
"pr_curve": wandb.plot.pr_curve(y_true, y_scores, labels=["positive", "negative"]),
})
参考资源¶
- 官方文档:https://docs.wandb.ai/
- GitHub:https://github.com/wandb/wandb
- W&B Courses(免费课程):https://www.wandb.courses/
- Fully Connected(博客):https://wandb.ai/fully-connected
- W&B Gallery(示例):https://wandb.ai/gallery
- API Reference:https://docs.wandb.ai/ref/python/
- W&B Community:https://community.wandb.ai/
- 定价:https://wandb.ai/pricing(个人免费)