跳转至

Weights & Biases 实验追踪

一句话概述:W&B(Weights & Biases)是机器学习实验管理平台,自动记录训练指标、超参数、模型版本,生成交互式图表,让你轻松比较不同实验的结果。

核心知识点

概念白话解释
Run运行 = 一次训练实验
Project项目 = 一组相关实验的集合
Sweep扫描 = 自动超参数搜索
Artifact工件 = 版本化的数据集/模型/文件
Dashboard仪表板 = 实验结果的交互式可视化

安装配置

pip install wandb                                     # 安装
wandb login                                           # 登录(输入 API Key)
# API Key 在 https://wandb.ai/authorize 获取

基本使用

import wandb                                          # 导入 W&B

# 初始化实验
wandb.init(
    project="t2d-microbiome",                         # 项目名
    name="rf-baseline",                               # 运行名
    config={                                          # 超参数
        "model": "RandomForest",
        "n_estimators": 100,
        "max_depth": 10,
        "learning_rate": 0.1
    }
)

# 训练循环中记录指标
for epoch in range(100):
    train_loss = train_one_epoch()                    # 训练
    val_loss, val_acc = evaluate()                    # 评估
    wandb.log({                                       # 记录指标
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_accuracy": val_acc
    })

# 记录图片/图表
wandb.log({"confusion_matrix": wandb.Image("cm.png")})  # 记录图片
wandb.log({"roc_curve": wandb.plot.roc_curve(y_test, y_prob)})  # ROC 曲线

# 记录模型
wandb.save("best_model.pkl")                          # 保存模型文件

wandb.finish()                                        # 结束实验

Scikit-learn 集成

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score

wandb.init(project="ml-experiments", config={
    "n_estimators": 200, "max_depth": 10
})

model = RandomForestClassifier(**wandb.config)        # 用 wandb config
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

wandb.log({
    "accuracy": accuracy_score(y_test, y_pred),       # 记录准确率
    "f1": f1_score(y_test, y_pred, average='macro')   # 记录 F1
})
wandb.finish()

PyTorch 集成

wandb.init(project="dl-experiments")
wandb.watch(model, log='all')                         # 自动追踪模型梯度

for epoch in range(epochs):
    for batch_x, batch_y in loader:
        loss = train_step(batch_x, batch_y)
        wandb.log({"batch_loss": loss})               # 记录每个 batch 损失
    wandb.log({"epoch": epoch, "val_acc": val_acc})

高级用法

Sweep 超参数搜索

# 定义搜索空间
sweep_config = {
    'method': 'bayes',                                # 贝叶斯搜索
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        'n_estimators': {'values': [50, 100, 200, 500]},
        'max_depth': {'min': 3, 'max': 20},
        'learning_rate': {'min': 0.01, 'max': 0.3}
    }
}

sweep_id = wandb.sweep(sweep_config, project="my-project")

def train():
    wandb.init()
    config = wandb.config
    model = train_model(config)
    wandb.log({"val_accuracy": evaluate(model)})

wandb.agent(sweep_id, train, count=50)                # 运行 50 次

常见报错

报错信息原因解决方法
wandb: ERROR Run already finished重复 finish每个 run 只 finish 一次
CommError网络问题WANDB_MODE=offline 离线模式
API key not configured未登录wandb login

速查表

wandb.init(project, name, config)      # 初始化
wandb.log({"key": value})              # 记录指标
wandb.log({"img": wandb.Image(path)})  # 记录图片
wandb.watch(model)                     # 追踪模型
wandb.save(path)                       # 保存文件
wandb.finish()                         # 结束

# === 离线模式 ===
# WANDB_MODE=offline python train.py   # 离线训练
# wandb sync ./wandb/offline-run-*     # 后续同步

# === W&B vs MLflow ===
# W&B: 云端免费版、界面更好、协作方便
# MLflow: 完全开源、可私有部署、与 Spark 集成好

参考:W&B 文档 | 更新于 2026 年