MLflow 实验管理
为什么要学
MLflow 是最广泛使用的 ML 实验追踪和模型管理平台:
- 实验追踪:记录参数、指标、模型、数据集
- 模型注册:版本化管理模型,生命周期管理
- 模型部署:一键部署为 REST API
- 框架无关:PyTorch/TensorFlow/Sklearn/XGBoost 等全支持
- 开源免费:完全开源,可自托管
- 标准化:MLOps 事实标准
任何做 ML 的人都应该用 MLflow 管理实验。
核心概念
| 概念 | 说明 | 类比 |
|---|
| Experiment | 实验组(一组相关Run) | 项目 |
| Run | 一次训练执行 | 一次实验 |
| Parameter | 超参数 | 输入配置 |
| Metric | 评估指标 | 输出结果 |
| Artifact | 产物(模型/图表/数据) | 输出文件 |
| Model Registry | 模型注册表 | 模型仓库 |
| Model Stage | 模型阶段(Staging/Production) | 环境 |
安装配置
pip install mlflow
# 启动UI
mlflow ui
# 访问 http://localhost:5000
快速上手
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
# 设置实验
mlflow.set_experiment("iris-classification")
# 加载数据
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 开始记录
with mlflow.start_run(run_name="rf-baseline"):
# 记录参数
n_estimators = 100
max_depth = 5
mlflow.log_param("n_estimators", n_estimators)
mlflow.log_param("max_depth", max_depth)
# 训练
model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
model.fit(X_train, y_train)
# 预测和评估
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average="macro")
# 记录指标
mlflow.log_metric("accuracy", accuracy)
mlflow.log_metric("f1_score", f1)
# 记录模型
mlflow.sklearn.log_model(model, "model")
print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}")
自动记录
import mlflow
# 自动记录(零代码修改)
mlflow.autolog()
# 正常训练代码
model = RandomForestClassifier(n_estimators=200)
model.fit(X_train, y_train)
# MLflow自动记录了参数、指标、模型!
模型注册和部署
# 注册模型
mlflow.register_model(
f"runs:/{run_id}/model",
"IrisClassifier"
)
# 加载已注册模型
model = mlflow.pyfunc.load_model("models:/IrisClassifier/Production")
predictions = model.predict(X_test)
# 部署为REST API
mlflow models serve -m "models:/IrisClassifier/Production" --port 5001
# 调用
curl -X POST http://localhost:5001/invocations \
-H 'Content-Type: application/json' \
-d '{"inputs": [[5.1, 3.5, 1.4, 0.2]]}'
进阶用法
PyTorch 集成
import mlflow.pytorch
with mlflow.start_run():
mlflow.log_params({"epochs": 10, "lr": 0.001, "batch_size": 32})
for epoch in range(10):
train_loss = train_one_epoch(model, train_loader)
val_loss = evaluate(model, val_loader)
mlflow.log_metrics({"train_loss": train_loss, "val_loss": val_loss}, step=epoch)
mlflow.pytorch.log_model(model, "model")
超参搜索 + MLflow
import optuna
def objective(trial):
with mlflow.start_run(nested=True):
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
depth = trial.suggest_int("depth", 2, 10)
mlflow.log_params({"lr": lr, "depth": depth})
score = train_and_evaluate(lr, depth)
mlflow.log_metric("score", score)
return score
with mlflow.start_run(run_name="optuna-search"):
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
mlflow.log_params(study.best_params)
mlflow.log_metric("best_score", study.best_value)
数据集追踪
import mlflow.data
from mlflow.data.pandas_dataset import PandasDataset
dataset = mlflow.data.from_pandas(df, source="s3://bucket/data.csv")
with mlflow.start_run():
mlflow.log_input(dataset, context="training")
常见问题
Q1: 存储后端选择?
| 后端 | 适合 |
|---|
| 本地文件 | 个人实验 |
| SQLite + 文件 | 小团队 |
| PostgreSQL + S3 | 生产环境 |
| Databricks | 企业级 |
Q2: vs W&B/Neptune?
| 方面 | MLflow | W&B | Neptune |
|---|
| 开源 | 完全 | 部分 | 部分 |
| 自托管 | 支持 | 企业版 | 企业版 |
| UI | 基础 | 优秀 | 良好 |
| 价格 | 免费 | 有免费层 | 有免费层 |
| 模型注册 | 内置 | 内置 | 内置 |
参考资源