跳转至

Determined AI 训练平台 — 企业级 ML 实验管理和分布式训练调度系统


一句话说明

Determined AI 是一个开源的机器学习训练平台,提供实验调度、超参搜索、分布式训练、模型版本管理一体化解决方案,类似私有化部署的 MLflow + Kubernetes 训练调度器,适合多人团队共享 GPU 资源。


安装与配置

# 客户端安装(用于提交任务)
pip install determined                     # 安装 Determined 客户端

# 服务端安装(需要 Docker 和 PostgreSQL)
pip install determined                     # 客户端同时含服务端工具

# 本地开发模式(单机无需 K8s,最简单)
det deploy local cluster-up \
    --det-version latest \                 # 使用最新版本
    --master-port 8080                     # 管理界面端口

# 验证连接
export DET_MASTER=localhost:8080           # 设置服务器地址
det user login admin                       # 登录(默认密码空)
det user whoami                            # 查看当前用户

# 在浏览器访问管理界面
# http://localhost:8080

核心用法

Determined 使用 YAML 文件定义实验,核心是编写 Trial 类:

# trial.py - 定义训练逻辑

import determined as det
from determined.pytorch import PyTorchTrial, PyTorchTrialContext, DataLoader
import torch
from torch import nn

class MyTrial(PyTorchTrial):
    """继承 PyTorchTrial 来定义训练逻辑"""

    def __init__(self, context: PyTorchTrialContext):
        self.context = context

        # 从超参数配置中读取值(由 Determined 自动注入)
        hparams = context.get_hparams()
        lr = hparams.get("learning_rate", 1e-3)
        hidden_size = hparams.get("hidden_size", 128)

        # 包装模型和优化器(必须通过 context 包装!)
        model = nn.Sequential(
            nn.Linear(784, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10)
        )
        self.model = context.wrap_model(model)           # 必须用 context.wrap_model

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.optimizer = context.wrap_optimizer(optimizer)  # 必须用 context.wrap_optimizer

    def build_training_data_loader(self) -> DataLoader:
        """返回训练数据集"""
        from torchvision import datasets, transforms
        dataset = datasets.MNIST("/data/mnist", train=True, download=True,
                                  transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())

    def build_validation_data_loader(self) -> DataLoader:
        """返回验证数据集"""
        from torchvision import datasets, transforms
        dataset = datasets.MNIST("/data/mnist", train=False,
                                  transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())

    def train_batch(self, batch, epoch_idx, batch_idx):
        """单个 batch 的训练逻辑"""
        x, y = batch
        x = x.view(x.size(0), -1)                       # 展平
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)

        self.context.backward(loss)                      # 用 context.backward 反向传播
        self.context.step_optimizer(self.optimizer)      # 用 context.step_optimizer 更新参数

        return {"loss": loss.item()}                     # 返回指标(自动记录)

    def evaluate_batch(self, batch, batch_idx):
        """单个 batch 的评估逻辑"""
        x, y = batch
        x = x.view(x.size(0), -1)
        logits = self.model(x)
        pred = logits.argmax(dim=1)
        accuracy = (pred == y).float().mean()
        return {"accuracy": accuracy.item()}             # 返回验证指标

参数详解

# experiment.yaml - 实验配置文件

name: mnist-classification                 # 实验名称

# 超参数搜索配置
hyperparameters:
  learning_rate:
    type: double                           # 参数类型
    minval: 0.0001
    maxval: 0.01
  hidden_size:
    type: categorical                      # 类别型超参数
    vals: [64, 128, 256]

# 超参数搜索策略
searcher:
  name: adaptive_asha                     # 自适应超参搜索(比 grid search 高效)
  metric: validation_accuracy             # 优化的目标指标
  smaller_is_better: false               # 越大越好
  max_trials: 20                         # 最多尝试 20 组超参数

# 资源配置
resources:
  slots_per_trial: 2                     # 每个 trial 使用 2 张 GPU
  max_slots: 8                           # 集群最多用 8 张 GPU(并行跑4个trial)

# 训练配置
min_checkpoint_period:
  batches: 500                           # 至少每 500 批次保存一次

# 指向训练代码
entrypoint: trial.py:MyTrial            # 文件名:类名

实战案例

# 提交实验并监控
det experiment create experiment.yaml .   # 提交实验('.' 表示上传当前目录代码)

# 查看实验状态
det experiment list                       # 列出所有实验
det experiment describe 1                 # 查看实验1的详情

# 查看日志
det trial logs 1                          # 实验1第1个trial的日志

# 等待并下载最佳模型
det experiment wait 1                     # 等待实验完成
det checkpoint download best --exp-id 1  # 下载最佳检查点

# 暂停/恢复/取消实验
det experiment pause 1                    # 暂停(释放 GPU)
det experiment activate 1                 # 恢复
det experiment kill 1                     # 取消

常见报错与解决

报错信息原因解决方法
Connection refused服务未启动检查 det deploy local cluster-up 是否运行
wrap_model 未调用模型未被 Determined 管理必须用 context.wrap_model(model)
GPU 资源不足slots_per_trial > 可用GPU减少 slots_per_trial
实验卡在 queued配额问题调整 max_slots 或等待其他实验完成
检查点找不到保存策略问题设置 min_checkpoint_period

速查表

功能命令/方式
提交实验det experiment create config.yaml .
查看实验det experiment list
下载检查点det checkpoint download best --exp-id N
查看日志det trial logs <trial_id>
超参搜索searcher.name: adaptive_asha
Web 界面http://localhost:8080
官方文档https://docs.determined.ai