跳转至

Weights & Biases (W&B) 完全指南

为什么要学 W&B

  1. ML 实验追踪的事实标准:W&B 是机器学习实验管理领域使用最广泛的工具。OpenAI、NVIDIA、Microsoft、Toyota 等超过 70,000 个团队在使用。它解决了"上周的模型参数是什么?哪次实验效果最好?"这类 ML 开发中的核心痛点。

  2. 两行代码集成:在现有训练代码中加入 wandb.init()wandb.log() 就能开始追踪。与 PyTorch、TensorFlow、HuggingFace、Keras、scikit-learn、XGBoost 等所有主流框架都有深度集成。

  3. 超参数搜索(Sweeps):内置的 Sweeps 功能让你用声明式配置定义搜索空间,W&B 自动运行贝叶斯优化、网格搜索、随机搜索。分布式执行,一个 agent 跑一组参数。

  4. 团队协作与报告:W&B Reports 让你创建交互式实验报告,嵌入实时图表和代码。比起在 Slack 中贴截图,这是更专业的 ML 研究沟通方式。

  5. 模型版本管理(Artifacts & Registry):追踪数据集版本、模型 checkpoint、推理结果。完整的血缘追踪——从数据到模型到预测,每一步都可溯源。


核心概念详解

W&B 是什么(白话解释)

做机器学习实验就像做化学实验:你调整配方(超参数)、记录过程(指标)、保存结果(模型)。没有实验记录本,你很快就会忘记哪个配方效果最好。

W&B 就是 ML 的"数字实验记录本": - 自动记录每次实验的所有参数和指标 - 漂亮的图表对比不同实验 - 模型和数据的版本管理 - 团队共享实验结果

核心产品

产品用途说明
Experiments实验追踪记录参数、指标、代码、环境
Sweeps超参数搜索自动化参数调优
Artifacts版本管理数据集、模型、结果的版本控制
Reports实验报告交互式文档,嵌入图表
Model Registry模型注册模型生命周期管理
Tables数据可视化交互式数据表格和可视化
Launch作业管理在计算集群上运行实验
WeaveAI 应用追踪LLM 应用的可观测性

W&B vs MLflow 对比

特性W&BMLflow
部署云托管(免费+付费)自托管为主
UI现代化,交互丰富简洁,功能较少
实验追踪自动+手动,非常详细手动为主
超参数搜索内置 Sweeps无(需集成 Optuna 等)
协作团队功能丰富基础
报告Reports(发布级别)
数据可视化Tables/Charts/Media基本图表
Artifact 管理内置内置
模型服务Registry + DeployMLflow Serving
LLM 支持Weave(LLM追踪)MLflow Tracing
开源客户端开源,服务端闭源全部开源
价格免费个人版/付费团队版免费(自托管成本)
学习曲线低(两行代码)

安装与配置

# 安装
pip install wandb

# 登录(获取 API Key)
wandb login
# 或设置环境变量
export WANDB_API_KEY=your_api_key

# 验证
python -c "import wandb; print(wandb.__version__)"

基本配置

import wandb

# 初始化项目
run = wandb.init(
    project="my-ml-project",    # 项目名
    name="experiment-001",       # 运行名称(可选)
    config={                     # 超参数
        "learning_rate": 0.001,
        "batch_size": 32,
        "epochs": 50,
        "model": "ResNet50",
        "optimizer": "Adam",
    },
    tags=["baseline", "v1"],     # 标签
    notes="首次基线实验",          # 备注
    group="experiment-group-1",   # 分组
)

离线模式

# 无网络环境
import os
os.environ["WANDB_MODE"] = "offline"

# 或在 init 时设置
wandb.init(mode="offline")

# 之后同步到云端
# wandb sync wandb/offline-run-xxx

快速上手:5 分钟最小示例

import wandb
import random
import math

# 1. 初始化
wandb.init(
    project="quickstart-demo",
    config={
        "learning_rate": 0.02,
        "architecture": "CNN",
        "dataset": "CIFAR-10",
        "epochs": 10,
    }
)
config = wandb.config

# 2. 模拟训练循环
for epoch in range(config.epochs):
    # 模拟指标
    train_loss = math.exp(-epoch * config.learning_rate) + random.uniform(-0.1, 0.1)
    val_loss = train_loss + random.uniform(0, 0.2)
    accuracy = 1 - val_loss + random.uniform(-0.05, 0.05)

    # 3. 记录指标
    wandb.log({
        "epoch": epoch,
        "train/loss": train_loss,
        "val/loss": val_loss,
        "val/accuracy": accuracy,
    })

    print(f"Epoch {epoch}: loss={train_loss:.4f}, acc={accuracy:.4f}")

# 4. 记录最终结果
wandb.summary["best_accuracy"] = 0.95
wandb.summary["total_params"] = 25_000_000

# 5. 结束
wandb.finish()

运行后打开控制台输出的 URL,即可在 W&B 仪表板中看到实验图表。


进阶用法

场景一:PyTorch 训练集成

import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

wandb.init(project="pytorch-mnist", config={
    "lr": 1e-3, "batch_size": 64, "epochs": 10, "hidden_size": 128,
})
config = wandb.config

# 模型定义
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, config.hidden_size),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(config.hidden_size, 10),
)

# 监控模型梯度和参数
wandb.watch(model, log="all", log_freq=100)

optimizer = optim.Adam(model.parameters(), lr=config.lr)
criterion = nn.CrossEntropyLoss()

# 数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transform),
    batch_size=config.batch_size, shuffle=True,
)

# 训练循环
for epoch in range(config.epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

        if batch_idx % 100 == 0:
            wandb.log({
                "batch_loss": loss.item(),
                "running_accuracy": correct / total,
            })

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = correct / total
    wandb.log({"epoch": epoch, "train_loss": epoch_loss, "train_accuracy": epoch_acc})

# 保存模型 Artifact
artifact = wandb.Artifact("mnist-model", type="model")
torch.save(model.state_dict(), "model.pth")
artifact.add_file("model.pth")
wandb.log_artifact(artifact)

wandb.finish()

场景二:HuggingFace Transformers 集成

from transformers import Trainer, TrainingArguments
import wandb

wandb.init(project="nlp-classification")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    learning_rate=2e-5,
    report_to="wandb",                  # 自动集成
    run_name="bert-classification-v1",
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
wandb.finish()

场景三:Sweeps 超参数搜索

import wandb

# 1. 定义搜索空间
sweep_config = {
    "method": "bayes",  # bayes / grid / random
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
        "batch_size": {"values": [16, 32, 64, 128]},
        "hidden_size": {"values": [64, 128, 256, 512]},
        "dropout": {"distribution": "uniform", "min": 0.1, "max": 0.5},
        "optimizer": {"values": ["adam", "sgd", "adamw"]},
    },
    "early_terminate": {
        "type": "hyperband",
        "min_iter": 3,
        "eta": 2,
    },
}

# 2. 创建 sweep
sweep_id = wandb.sweep(sweep_config, project="sweep-demo")

# 3. 定义训练函数
def train():
    wandb.init()
    config = wandb.config

    model = create_model(config.hidden_size, config.dropout)
    optimizer = get_optimizer(config.optimizer, model, config.learning_rate)

    for epoch in range(20):
        train_loss, val_acc = train_epoch(model, optimizer, config.batch_size)
        wandb.log({"train_loss": train_loss, "val_accuracy": val_acc, "epoch": epoch})

    wandb.finish()

# 4. 启动 agent
wandb.agent(sweep_id, function=train, count=50)

场景四:Artifacts(数据集和模型版本管理)

import wandb

# 记录数据集 Artifact
run = wandb.init(project="artifact-demo", job_type="data-prep")

artifact = wandb.Artifact(
    name="training-data",
    type="dataset",
    description="清洗后的训练数据集",
    metadata={"rows": 10000, "features": 50, "version": "v2"},
)
artifact.add_dir("data/processed/")
wandb.log_artifact(artifact)
wandb.finish()

# 在训练中使用 Artifact
run = wandb.init(project="artifact-demo", job_type="training")
data_artifact = run.use_artifact("training-data:latest")
data_dir = data_artifact.download()

# 训练...

# 记录模型 Artifact
model_artifact = wandb.Artifact("trained-model", type="model",
    metadata={"accuracy": 0.95, "framework": "pytorch"})
model_artifact.add_file("model.pth")
wandb.log_artifact(model_artifact)
wandb.finish()

场景五:Tables(数据可视化)

import wandb

wandb.init(project="tables-demo")

# 记录预测结果表格
columns = ["image", "true_label", "predicted", "confidence"]
table = wandb.Table(columns=columns)

for img, true, pred, conf in zip(images, labels, predictions, confidences):
    table.add_data(
        wandb.Image(img),
        true,
        pred,
        conf,
    )

wandb.log({"predictions": table})

# 记录混淆矩阵
wandb.log({
    "confusion_matrix": wandb.plot.confusion_matrix(
        probs=None,
        y_true=true_labels,
        preds=predicted_labels,
        class_names=class_names,
    )
})

# 记录 ROC 曲线
wandb.log({
    "roc": wandb.plot.roc_curve(
        y_true=true_labels,
        y_probas=predicted_probas,
        labels=class_names,
    )
})

wandb.finish()

场景六:Reports(交互式实验报告)

import wandb

# Reports 主要通过 W&B 网页 UI 创建
# 但可以通过 API 创建

api = wandb.Api()

# 获取实验数据
runs = api.runs("my-project")
for run in runs:
    print(f"{run.name}: accuracy={run.summary.get('accuracy', 'N/A')}")

# 获取特定 run 的历史数据
run = api.run("username/project/run_id")
history = run.history()
print(history[["epoch", "loss", "accuracy"]])

场景七:Alert 告警

import wandb

wandb.init(project="alert-demo")

for epoch in range(100):
    loss = train_one_epoch()
    wandb.log({"loss": loss})

    # 训练损失异常时告警
    if loss > 10.0:
        wandb.alert(
            title="训练损失异常",
            text=f"Epoch {epoch} 的损失为 {loss:.2f},超过阈值 10.0",
            level=wandb.AlertLevel.WARN,
        )

    # 训练发散时告警
    if loss != loss:  # NaN check
        wandb.alert(
            title="训练发散!",
            text=f"Epoch {epoch} 出现 NaN 损失",
            level=wandb.AlertLevel.ERROR,
        )
        break

场景八:Model Registry

import wandb

# 将 Artifact 链接到 Model Registry
run = wandb.init(project="registry-demo")

# 训练并保存模型...
model_artifact = wandb.Artifact("my-model", type="model")
model_artifact.add_file("model.pth")
run.log_artifact(model_artifact)

# 链接到 Registry
run.link_artifact(
    artifact=model_artifact,
    target_path="my-org/model-registry/production-model",
    aliases=["staging", "v1.2"],
)

wandb.finish()

常见问题与排错

问题一:wandb.log 的调用频率

# 不要每个 batch 都 log(会很慢)
# 推荐:每 N 个 batch 或每个 epoch log 一次

for step, batch in enumerate(dataloader):
    loss = train_step(batch)
    if step % 100 == 0:  # 每100步记录一次
        wandb.log({"loss": loss, "step": step})

问题二:大文件上传慢

# 使用 Reference Artifacts(不上传实际文件)
artifact = wandb.Artifact("large-dataset", type="dataset")
artifact.add_reference(f"s3://my-bucket/data/", max_objects=10000)

问题三:断网环境使用

# 设置离线模式
export WANDB_MODE=offline

# 训练完成后,有网时同步
wandb sync --sync-all wandb/

问题四:多 GPU 训练中只记录一次

import wandb
import torch.distributed as dist

if dist.get_rank() == 0:
    wandb.init(project="distributed-training")

# 只在 rank 0 记录
if dist.get_rank() == 0:
    wandb.log({"loss": loss})

问题五:如何删除 run

api = wandb.Api()
run = api.run("username/project/run_id")
run.delete()

问题六:自定义图表

# 使用 wandb.plot 或 plotly
import plotly.express as px

fig = px.scatter(df, x="epoch", y="loss", color="experiment")
wandb.log({"custom_chart": fig})

# 使用 wandb.plot 内置图表
wandb.log({
    "pr_curve": wandb.plot.pr_curve(y_true, y_scores, labels=["positive", "negative"]),
})

参考资源

  • 官方文档:https://docs.wandb.ai/
  • GitHub:https://github.com/wandb/wandb
  • W&B Courses(免费课程):https://www.wandb.courses/
  • Fully Connected(博客):https://wandb.ai/fully-connected
  • W&B Gallery(示例):https://wandb.ai/gallery
  • API Reference:https://docs.wandb.ai/ref/python/
  • W&B Community:https://community.wandb.ai/
  • 定价:https://wandb.ai/pricing(个人免费)