Weights & Biases 实验追踪¶
一句话概述:W&B(Weights & Biases)是机器学习实验管理平台,自动记录训练指标、超参数、模型版本,生成交互式图表,让你轻松比较不同实验的结果。
核心知识点¶
| 概念 | 白话解释 |
|---|---|
| Run | 运行 = 一次训练实验 |
| Project | 项目 = 一组相关实验的集合 |
| Sweep | 扫描 = 自动超参数搜索 |
| Artifact | 工件 = 版本化的数据集/模型/文件 |
| Dashboard | 仪表板 = 实验结果的交互式可视化 |
安装配置¶
基本使用¶
import wandb # 导入 W&B
# 初始化实验
wandb.init(
project="t2d-microbiome", # 项目名
name="rf-baseline", # 运行名
config={ # 超参数
"model": "RandomForest",
"n_estimators": 100,
"max_depth": 10,
"learning_rate": 0.1
}
)
# 训练循环中记录指标
for epoch in range(100):
train_loss = train_one_epoch() # 训练
val_loss, val_acc = evaluate() # 评估
wandb.log({ # 记录指标
"epoch": epoch,
"train_loss": train_loss,
"val_loss": val_loss,
"val_accuracy": val_acc
})
# 记录图片/图表
wandb.log({"confusion_matrix": wandb.Image("cm.png")}) # 记录图片
wandb.log({"roc_curve": wandb.plot.roc_curve(y_test, y_prob)}) # ROC 曲线
# 记录模型
wandb.save("best_model.pkl") # 保存模型文件
wandb.finish() # 结束实验
Scikit-learn 集成¶
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
wandb.init(project="ml-experiments", config={
"n_estimators": 200, "max_depth": 10
})
model = RandomForestClassifier(**wandb.config) # 用 wandb config
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
wandb.log({
"accuracy": accuracy_score(y_test, y_pred), # 记录准确率
"f1": f1_score(y_test, y_pred, average='macro') # 记录 F1
})
wandb.finish()
PyTorch 集成¶
wandb.init(project="dl-experiments")
wandb.watch(model, log='all') # 自动追踪模型梯度
for epoch in range(epochs):
for batch_x, batch_y in loader:
loss = train_step(batch_x, batch_y)
wandb.log({"batch_loss": loss}) # 记录每个 batch 损失
wandb.log({"epoch": epoch, "val_acc": val_acc})
高级用法¶
Sweep 超参数搜索¶
# 定义搜索空间
sweep_config = {
'method': 'bayes', # 贝叶斯搜索
'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
'parameters': {
'n_estimators': {'values': [50, 100, 200, 500]},
'max_depth': {'min': 3, 'max': 20},
'learning_rate': {'min': 0.01, 'max': 0.3}
}
}
sweep_id = wandb.sweep(sweep_config, project="my-project")
def train():
wandb.init()
config = wandb.config
model = train_model(config)
wandb.log({"val_accuracy": evaluate(model)})
wandb.agent(sweep_id, train, count=50) # 运行 50 次
常见报错¶
| 报错信息 | 原因 | 解决方法 |
|---|---|---|
wandb: ERROR Run already finished | 重复 finish | 每个 run 只 finish 一次 |
CommError | 网络问题 | 用 WANDB_MODE=offline 离线模式 |
API key not configured | 未登录 | wandb login |
速查表¶
wandb.init(project, name, config) # 初始化
wandb.log({"key": value}) # 记录指标
wandb.log({"img": wandb.Image(path)}) # 记录图片
wandb.watch(model) # 追踪模型
wandb.save(path) # 保存文件
wandb.finish() # 结束
# === 离线模式 ===
# WANDB_MODE=offline python train.py # 离线训练
# wandb sync ./wandb/offline-run-* # 后续同步
# === W&B vs MLflow ===
# W&B: 云端免费版、界面更好、协作方便
# MLflow: 完全开源、可私有部署、与 Spark 集成好
参考:W&B 文档 | 更新于 2026 年