Source: examples/distributed_ray_train
Distributed Ray Training with SkyPilot#
This example shows how to launch distributed Ray jobs with SkyPilot.
Setting Up Your Ray Cluster#
SkyPilot provides templates for common workloads such as Ray: `~/sky_templates/ray/start_cluster` is available on SkyPilot clusters and sets up a Ray cluster for your workloads. Simply call it in your task's run commands:
```yaml
run: |
  ~/sky_templates/ray/start_cluster
```
Under the hood, this script automatically:
- Installs `ray` if not already present
- Starts the head node (rank 0) and workers on all other nodes
- Waits for the head node to be healthy before starting workers
- Ensures all nodes have joined before proceeding
Tip: The script uses SkyPilot's environment variables (`SKYPILOT_NODE_RANK`, `SKYPILOT_NODE_IPS`, `SKYPILOT_NUM_NODES`, `SKYPILOT_NUM_GPUS_PER_NODE`) to coordinate the distributed setup, as sketched below. See Distributed Multi-Node Jobs for more details.
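To make the coordination concrete, here is a minimal sketch (not the actual template) of how these variables can drive head/worker startup; the real script additionally installs Ray, health-checks the head node, and verifies that all nodes have joined:

```bash
# Sketch only: illustrates the head/worker split driven by SkyPilot's variables.
HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n 1)
if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    # Rank 0 becomes the Ray head node, on a port that avoids SkyPilot's 6380.
    ray start --head --port=6379
else
    # All other ranks join the head node as workers.
    sleep 10  # crude stand-in for the template's head health check
    ray start --address="${HEAD_IP}:6379"
fi
```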
Customizing the Ray Cluster#
Customize the Ray cluster by setting environment variables before calling start_cluster:
| Variable | Default | Description |
|---|---|---|
|  | `6379` | Ray head node port (must differ from SkyPilot's internal port 6380) |
|  |  | Ray dashboard port (must differ from SkyPilot's internal port 8266) |
|  |  | Ray dashboard host |
|  |  | Optional dashboard agent listen port |
|  |  | Optional head node IP address override |
| `RAY_CMD` | `ray` | Ray command to invoke (e.g., a Ray binary inside a virtualenv, as in the example below) |
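For instance, to start the cluster with the Ray binary from the task's virtualenv (the same `RAY_CMD` pattern used by the stop command later in this example), set the variable inline before invoking the template. This is a sketch that assumes the virtualenv lives at `~/sky_workdir/.venv`:

```bash
# Use the virtualenv's Ray binary instead of whichever `ray` is on PATH.
RAY_CMD=~/sky_workdir/.venv/bin/ray ~/sky_templates/ray/start_cluster
```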
Managing the Ray Cluster#
Stop your Ray cluster with:
```bash
~/sky_templates/ray/stop_cluster
```
Do not use `ray stop` directly, as it may interfere with SkyPilot's cluster management.
To restart, simply run `start_cluster` again. The script detects if Ray is already running and skips startup if the cluster is healthy.
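For example, on a running cluster you can re-run the template on every node with `sky exec`, mirroring the stop command shown below (the cluster name and `RAY_CMD` path are taken from this example):

```bash
# Re-runs the startup template on all nodes; it is a no-op if Ray is already healthy.
sky exec ray-train --num-nodes 4 'RAY_CMD=~/sky_workdir/.venv/bin/ray ~/sky_templates/ray/start_cluster'
```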
Running the Example#
```bash
# Download the training script
wget https://raw.githubusercontent.com/skypilot-org/skypilot/master/examples/distributed_ray_train/train.py

# Launch on a cluster
sky launch -c ray-train --num-nodes 4 ray_train.yaml

# To stop the Ray cluster
sky exec ray-train --num-nodes 4 'RAY_CMD=~/sky_workdir/.venv/bin/ray ~/sky_templates/ray/stop_cluster'
```
Important: Ray Runtime Best Practices#
SkyPilot uses Ray internally on port 6380 for cluster management. When running your own Ray applications, start a separate Ray cluster on a different port (the template defaults to 6379) to avoid conflicts. Do not use `ray.init(address="auto")`, as it would connect to SkyPilot's internal cluster and cause resource conflicts.
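If your driver script needs to attach to the user-level Ray cluster started by the template, point `ray.init` at that cluster's port explicitly. This is a sketch that assumes the default port 6379 and that the driver runs on the head node:

```python
import ray

# Connect to the Ray cluster started by ~/sky_templates/ray/start_cluster
# (port 6379), not to SkyPilot's internal cluster on port 6380.
# Avoid ray.init(address="auto"), which may pick up the internal cluster.
ray.init(address="127.0.0.1:6379")

print(ray.cluster_resources())  # resources of your own Ray cluster
```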
Included files#
ray_train.yaml
```yaml
# This example starts its own Ray cluster on port 6379, separate from SkyPilot's
# internal Ray cluster (port 6380). Do not use ray.init(address="auto") as it
# connects to SkyPilot's internal cluster, causing resource conflicts.

resources:
  accelerators: L4:2
  memory: 64+

num_nodes: 2

workdir: .

setup: |
  uv venv --python 3.10 --seed
  source .venv/bin/activate
  uv pip install "ray[train]" "click<8.2.0"
  uv pip install tqdm
  uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

run: |
  sudo chmod 777 -R /var/tmp
  source .venv/bin/activate
  # This script is only available on skypilot-nightly>=1.0.0.dev20251114.
  # If you are using an older version, you can copy and paste the script from:
  # https://github.com/skypilot-org/skypilot/blob/master/sky_templates/ray/start_cluster
  ~/sky_templates/ray/start_cluster

  num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    python train.py --num-workers $num_nodes
  fi
```
train.py

```python
import argparse
import os
from typing import Dict

from filelock import FileLock
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor
from tqdm import tqdm


def get_dataloaders(batch_size):
    # Transform to normalize the input images
    transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    with FileLock(os.path.expanduser('~/data.lock')):
        # Download training data from open datasets
        training_data = datasets.FashionMNIST(
            root='~/data',
            train=True,
            download=True,
            transform=transform,
        )

        # Download test data from open datasets
        test_data = datasets.FashionMNIST(
            root='~/data',
            train=False,
            download=True,
            transform=transform,
        )

    # Create data loaders
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    return train_dataloader, test_dataloader


# Model Definition
class NeuralNetwork(nn.Module):

    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_func_per_worker(config: Dict):
    lr = config['lr']
    epochs = config['epochs']
    batch_size = config['batch_size_per_worker']

    # Get dataloaders inside the worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model to the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        model.train()
        for X, y in tqdm(train_dataloader, desc=f'Train Epoch {epoch}'):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f'Test Epoch {epoch}'):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={'loss': test_loss, 'accuracy': accuracy})


def train_fashion_mnist(num_workers=2, use_gpu=False):
    global_batch_size = 32

    train_config = {
        'lr': 1e-3,
        'epochs': 10,
        'batch_size_per_worker': global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start distributed training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f'Training result: {result}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-workers', type=int, default=2)
    args = parser.parse_args()

    train_fashion_mnist(num_workers=args.num_workers, use_gpu=True)
```