Ray 分布式框架¶
为什么要学 Ray¶
Ray 是一个通用的分布式计算框架,由 UC Berkeley RISELab 开发。它不仅支持数据处理(Ray Data),还支持分布式训练(Ray Train)、超参搜索(Ray Tune)、模型服务(Ray Serve)和强化学习(RLlib)。Ray 让你用简单的 Python 代码就能将计算扩展到集群,是 AI/ML 领域最热门的分布式框架。
核心概念¶
| 概念 | 白话解释 | 用途 |
|---|---|---|
| Task | 远程函数 | 在集群上并行执行函数 |
| Actor | 有状态服务 | 分布式的类实例 |
| Object Store | 对象存储 | 节点间共享数据的内存 |
| Ray Data | 数据处理 | 大规模数据管道 |
| Ray Train | 分布式训练 | 多 GPU/节点训练 |
| Ray Serve | 模型服务 | 在线推理部署 |
安装配置¶
pip install "ray[default]"
# 含 dashboard, cluster launcher
# ML 全家桶
pip install "ray[data,train,tune,serve]"
快速上手¶
基础并行¶
import ray
ray.init() # 初始化本地集群
@ray.remote
def heavy_task(x):
import time; time.sleep(1)
return x * x
# 并行执行(不等待结果)
futures = [heavy_task.remote(i) for i in range(10)]
# 获取结果
results = ray.get(futures)
print(results) # [0, 1, 4, 9, 16, ...]
# 10个1秒任务,并行执行只需约1秒
Actor(有状态服务)¶
@ray.remote
class Counter:
def __init__(self):
self.count = 0
def increment(self):
self.count += 1
return self.count
def get(self):
return self.count
counter = Counter.remote()
[counter.increment.remote() for _ in range(100)]
print(ray.get(counter.get.remote())) # 100
Ray Data¶
import ray
ds = ray.data.read_csv("s3://bucket/data/")
ds = ds.filter(lambda row: row["age"] > 25)
ds = ds.map(lambda row: {**row, "name_upper": row["name"].upper()})
result = ds.to_pandas()
Ray Serve¶
from ray import serve
from transformers import pipeline
@serve.deployment
class SentimentAnalysis:
def __init__(self):
self.model = pipeline("sentiment-analysis")
async def __call__(self, request):
text = (await request.json())["text"]
return self.model(text)
app = SentimentAnalysis.bind()
serve.run(app)
进阶用法¶
分布式训练¶
from ray.train.torch import TorchTrainer
def train_func(config):
model = build_model()
for epoch in range(config["epochs"]):
train_one_epoch(model)
ray.train.report({"loss": loss})
trainer = TorchTrainer(
train_func,
train_loop_config={"epochs": 10},
scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
Ray Tune 超参搜索¶
from ray import tune
def objective(config):
score = train_model(lr=config["lr"], batch_size=config["batch_size"])
tune.report(score=score)
analysis = tune.run(
objective,
config={"lr": tune.loguniform(1e-4, 1e-1), "batch_size": tune.choice([16, 32, 64])},
num_samples=50,
metric="score",
mode="max",
)
print(analysis.best_config)
集群部署¶
# ray-cluster.yaml
cluster_name: my-cluster
provider:
type: aws
region: us-west-2
head_node:
InstanceType: m5.2xlarge
worker_nodes:
InstanceType: m5.4xlarge
min_workers: 2
max_workers: 10
常见问题¶
Q: Ray vs Dask?¶
- Ray:通用分布式框架,ML 生态强,Actor 模型
- Dask:专注数据处理,Pandas/NumPy 兼容性更好
Q: 学习曲线如何?¶
@ray.remote 装饰器非常简单。高级功能(Train/Tune/Serve)需要更多学习。
Q: 本地开发如何模拟集群?¶
ray.init() 默认创建本地集群,利用所有 CPU 核心,开发体验与集群一致。
参考资源¶
- 官网:https://www.ray.io/
- GitHub:https://github.com/ray-project/ray
- 文档:https://docs.ray.io/
- 教程:https://docs.ray.io/en/latest/ray-overview/getting-started.html
- Ray Summit 视频:https://www.youtube.com/c/RayProject