Determined AI 训练平台 — 企业级 ML 实验管理和分布式训练调度系统¶
一句话说明¶
Determined AI 是一个开源的机器学习训练平台,提供实验调度、超参搜索、分布式训练、模型版本管理一体化解决方案,类似私有化部署的 MLflow + Kubernetes 训练调度器,适合多人团队共享 GPU 资源。
安装与配置¶
# 客户端安装(用于提交任务)
pip install determined # 安装 Determined 客户端
# 服务端安装(需要 Docker 和 PostgreSQL)
pip install determined # 客户端同时含服务端工具
# 本地开发模式(单机无需 K8s,最简单)
det deploy local cluster-up \
--det-version latest \ # 使用最新版本
--master-port 8080 # 管理界面端口
# 验证连接
export DET_MASTER=localhost:8080 # 设置服务器地址
det user login admin # 登录(默认密码空)
det user whoami # 查看当前用户
# 在浏览器访问管理界面
# http://localhost:8080
核心用法¶
Determined 使用 YAML 文件定义实验,核心是编写 Trial 类:
# trial.py - 定义训练逻辑
import determined as det
from determined.pytorch import PyTorchTrial, PyTorchTrialContext, DataLoader
import torch
from torch import nn
class MyTrial(PyTorchTrial):
"""继承 PyTorchTrial 来定义训练逻辑"""
def __init__(self, context: PyTorchTrialContext):
self.context = context
# 从超参数配置中读取值(由 Determined 自动注入)
hparams = context.get_hparams()
lr = hparams.get("learning_rate", 1e-3)
hidden_size = hparams.get("hidden_size", 128)
# 包装模型和优化器(必须通过 context 包装!)
model = nn.Sequential(
nn.Linear(784, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 10)
)
self.model = context.wrap_model(model) # 必须用 context.wrap_model
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.optimizer = context.wrap_optimizer(optimizer) # 必须用 context.wrap_optimizer
def build_training_data_loader(self) -> DataLoader:
"""返回训练数据集"""
from torchvision import datasets, transforms
dataset = datasets.MNIST("/data/mnist", train=True, download=True,
transform=transforms.ToTensor())
return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())
def build_validation_data_loader(self) -> DataLoader:
"""返回验证数据集"""
from torchvision import datasets, transforms
dataset = datasets.MNIST("/data/mnist", train=False,
transform=transforms.ToTensor())
return DataLoader(dataset, batch_size=self.context.get_per_slot_batch_size())
def train_batch(self, batch, epoch_idx, batch_idx):
"""单个 batch 的训练逻辑"""
x, y = batch
x = x.view(x.size(0), -1) # 展平
logits = self.model(x)
loss = nn.functional.cross_entropy(logits, y)
self.context.backward(loss) # 用 context.backward 反向传播
self.context.step_optimizer(self.optimizer) # 用 context.step_optimizer 更新参数
return {"loss": loss.item()} # 返回指标(自动记录)
def evaluate_batch(self, batch, batch_idx):
"""单个 batch 的评估逻辑"""
x, y = batch
x = x.view(x.size(0), -1)
logits = self.model(x)
pred = logits.argmax(dim=1)
accuracy = (pred == y).float().mean()
return {"accuracy": accuracy.item()} # 返回验证指标
参数详解¶
# experiment.yaml - 实验配置文件
name: mnist-classification # 实验名称
# 超参数搜索配置
hyperparameters:
learning_rate:
type: double # 参数类型
minval: 0.0001
maxval: 0.01
hidden_size:
type: categorical # 类别型超参数
vals: [64, 128, 256]
# 超参数搜索策略
searcher:
name: adaptive_asha # 自适应超参搜索(比 grid search 高效)
metric: validation_accuracy # 优化的目标指标
smaller_is_better: false # 越大越好
max_trials: 20 # 最多尝试 20 组超参数
# 资源配置
resources:
slots_per_trial: 2 # 每个 trial 使用 2 张 GPU
max_slots: 8 # 集群最多用 8 张 GPU(并行跑4个trial)
# 训练配置
min_checkpoint_period:
batches: 500 # 至少每 500 批次保存一次
# 指向训练代码
entrypoint: trial.py:MyTrial # 文件名:类名
实战案例¶
# 提交实验并监控
det experiment create experiment.yaml . # 提交实验('.' 表示上传当前目录代码)
# 查看实验状态
det experiment list # 列出所有实验
det experiment describe 1 # 查看实验1的详情
# 查看日志
det trial logs 1 # 实验1第1个trial的日志
# 等待并下载最佳模型
det experiment wait 1 # 等待实验完成
det checkpoint download best --exp-id 1 # 下载最佳检查点
# 暂停/恢复/取消实验
det experiment pause 1 # 暂停(释放 GPU)
det experiment activate 1 # 恢复
det experiment kill 1 # 取消
常见报错与解决¶
| 报错信息 | 原因 | 解决方法 |
|---|---|---|
Connection refused | 服务未启动 | 检查 det deploy local cluster-up 是否运行 |
wrap_model 未调用 | 模型未被 Determined 管理 | 必须用 context.wrap_model(model) |
| GPU 资源不足 | slots_per_trial > 可用GPU | 减少 slots_per_trial |
| 实验卡在 queued | 配额问题 | 调整 max_slots 或等待其他实验完成 |
| 检查点找不到 | 保存策略问题 | 设置 min_checkpoint_period |
速查表¶
| 功能 | 命令/方式 |
|---|---|
| 提交实验 | det experiment create config.yaml . |
| 查看实验 | det experiment list |
| 下载检查点 | det checkpoint download best --exp-id N |
| 查看日志 | det trial logs <trial_id> |
| 超参搜索 | searcher.name: adaptive_asha |
| Web 界面 | http://localhost:8080 |
| 官方文档 | https://docs.determined.ai |