跳转至

MLflow 实验管理

为什么要学

MLflow 是最广泛使用的 ML 实验追踪和模型管理平台:

  • 实验追踪:记录参数、指标、模型、数据集
  • 模型注册:版本化管理模型,生命周期管理
  • 模型部署:一键部署为 REST API
  • 框架无关:PyTorch/TensorFlow/Sklearn/XGBoost 等全支持
  • 开源免费:完全开源,可自托管
  • 标准化:MLOps 事实标准

任何做 ML 的人都应该用 MLflow 管理实验。

核心概念

概念说明类比
Experiment实验组(一组相关Run)项目
Run一次训练执行一次实验
Parameter超参数输入配置
Metric评估指标输出结果
Artifact产物(模型/图表/数据)输出文件
Model Registry模型注册表模型仓库
Model Stage模型阶段(Staging/Production)环境

安装配置

pip install mlflow

# 启动UI
mlflow ui
# 访问 http://localhost:5000

快速上手

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

# 设置实验
mlflow.set_experiment("iris-classification")

# 加载数据
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 开始记录
with mlflow.start_run(run_name="rf-baseline"):
    # 记录参数
    n_estimators = 100
    max_depth = 5
    mlflow.log_param("n_estimators", n_estimators)
    mlflow.log_param("max_depth", max_depth)

    # 训练
    model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
    model.fit(X_train, y_train)

    # 预测和评估
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="macro")

    # 记录指标
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1)

    # 记录模型
    mlflow.sklearn.log_model(model, "model")

    print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}")

自动记录

import mlflow

# 自动记录(零代码修改)
mlflow.autolog()

# 正常训练代码
model = RandomForestClassifier(n_estimators=200)
model.fit(X_train, y_train)
# MLflow自动记录了参数、指标、模型!

模型注册和部署

# 注册模型
mlflow.register_model(
    f"runs:/{run_id}/model",
    "IrisClassifier"
)

# 加载已注册模型
model = mlflow.pyfunc.load_model("models:/IrisClassifier/Production")
predictions = model.predict(X_test)
# 部署为REST API
mlflow models serve -m "models:/IrisClassifier/Production" --port 5001

# 调用
curl -X POST http://localhost:5001/invocations \
  -H 'Content-Type: application/json' \
  -d '{"inputs": [[5.1, 3.5, 1.4, 0.2]]}'

进阶用法

PyTorch 集成

import mlflow.pytorch

with mlflow.start_run():
    mlflow.log_params({"epochs": 10, "lr": 0.001, "batch_size": 32})

    for epoch in range(10):
        train_loss = train_one_epoch(model, train_loader)
        val_loss = evaluate(model, val_loader)
        mlflow.log_metrics({"train_loss": train_loss, "val_loss": val_loss}, step=epoch)

    mlflow.pytorch.log_model(model, "model")

超参搜索 + MLflow

import optuna

def objective(trial):
    with mlflow.start_run(nested=True):
        lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        depth = trial.suggest_int("depth", 2, 10)

        mlflow.log_params({"lr": lr, "depth": depth})
        score = train_and_evaluate(lr, depth)
        mlflow.log_metric("score", score)
        return score

with mlflow.start_run(run_name="optuna-search"):
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=50)
    mlflow.log_params(study.best_params)
    mlflow.log_metric("best_score", study.best_value)

数据集追踪

import mlflow.data
from mlflow.data.pandas_dataset import PandasDataset

dataset = mlflow.data.from_pandas(df, source="s3://bucket/data.csv")
with mlflow.start_run():
    mlflow.log_input(dataset, context="training")

常见问题

Q1: 存储后端选择?

后端适合
本地文件个人实验
SQLite + 文件小团队
PostgreSQL + S3生产环境
Databricks企业级

Q2: vs W&B/Neptune?

方面MLflowW&BNeptune
开源完全部分部分
自托管支持企业版企业版
UI基础优秀良好
价格免费有免费层有免费层
模型注册内置内置内置

参考资源