Source: examples/distributed_ray_train
Distributed Ray Training with SkyPilot#
This example shows how to launch distributed Ray jobs with SkyPilot.
Important: Ray Runtime Best Practices#
SkyPilot uses Ray internally on port 6380 for cluster management, so always start your own Ray cluster on a different port (e.g. 6379, Ray’s default) when running Ray workloads on SkyPilot. Don’t use ray.init(address="auto"),
as it would connect to SkyPilot’s internal cluster and cause resource conflicts.
The example in ray_train.yaml demonstrates the correct approach:
1. Start the Ray head node on rank 0.
2. Start Ray workers on the other ranks.
3. Connect to your own Ray cluster, not SkyPilot’s internal one.
Running the Example#
# Download the training script
wget https://raw.githubusercontent.com/skypilot-org/skypilot/master/examples/distributed_ray_train/train.py
# Launch on a cluster
sky launch -c ray-train ray_train.yaml
Included files#
ray_train.yaml
# SkyPilot task: run distributed Ray Train (train.py) across two nodes.
#
# This example starts its own Ray cluster on port 6379, separate from SkyPilot's
# internal Ray cluster (port 6380). Do not use ray.init(address="auto") as it
# connects to SkyPilot's internal cluster, causing resource conflicts.
resources:
  accelerators: L4:2
  memory: 64+

num_nodes: 2

workdir: .

# Activate the `ray` conda environment (creating it with Python 3.10 when
# activation fails), then install Ray Train, tqdm and the cu118 PyTorch wheels.
setup: |
  conda activate ray
  if [ $? -ne 0 ]; then
    conda create -n ray python=3.10 -y
    conda activate ray
  fi
  pip install "ray[train]"
  pip install tqdm
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# head_ip is the first line of $SKYPILOT_NODE_IPS and num_nodes is its line
# count. Rank 0 starts the Ray head on port 6379 and then launches train.py
# with one worker per node; every other rank joins the head as a Ray worker.
# On both paths `ray start` is skipped when `ps aux` already lists a process
# whose line matches both "ray" and "6379".
run: |
  sudo chmod 777 -R /var/tmp
  conda activate ray
  head_ip=`echo "$SKYPILOT_NODE_IPS" | head -n1`
  num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port 6379
    sleep 5
    python train.py --num-workers $num_nodes
  else
    sleep 5
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --address $head_ip:6379 --disable-usage-stats
    # Sleep after `ray start` to give Ray enough time to daemonize
    sleep 5
  fi
import argparse
import os
from typing import Dict
from filelock import FileLock
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import Normalize
from torchvision.transforms import ToTensor
from tqdm import tqdm
def get_dataloaders(batch_size):
    """Build the FashionMNIST train and test data loaders.

    Both splits are stored under ``~/data`` and requested with
    ``download=True``; dataset construction is wrapped in a file lock at
    ``~/data.lock``. Images are converted to tensors and normalized with
    mean 0.5 and standard deviation 0.5.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    # Tensor conversion followed by normalization of the input images.
    preprocess = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    lock_path = os.path.expanduser('~/data.lock')
    with FileLock(lock_path):
        # Build the training split first, then the test split, while the
        # lock is held.
        train_split, test_split = [
            datasets.FashionMNIST(
                root='~/data',
                train=is_train,
                download=True,
                transform=preprocess,
            ) for is_train in (True, False)
        ]

    # The loaders themselves are created after the lock is released.
    return (
        DataLoader(train_split, batch_size=batch_size),
        DataLoader(test_split, batch_size=batch_size),
    )
# Model Definition
class NeuralNetwork(nn.Module):
    """Multi-layer perceptron that maps 28x28 inputs to 10 class logits.

    The input is flattened and passed through two hidden layers of width
    512, each followed by ReLU and dropout (p=0.25), and then a final
    linear layer with 10 outputs.

    The final layer has no activation: ``forward`` returns raw logits, which
    is what ``nn.CrossEntropyLoss`` expects. A trailing ReLU would clamp
    every logit to be non-negative. Removing it leaves the parameter names
    in ``state_dict()`` unchanged.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            # Output layer: intentionally not followed by an activation.
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return the class logits for a batch of inputs.

        Args:
            x: Batch of inputs; everything after the batch dimension is
                flattened and must contain 28 * 28 values per sample.

        Returns:
            Tensor with one row of 10 unnormalized logits per sample.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
def train_func_per_worker(config: Dict):
    """Training loop executed by each Ray Train worker.

    Builds the data loaders and the model, wraps them with the Ray Train
    helpers, then alternates one training pass and one evaluation pass per
    epoch, reporting the evaluation metrics after every epoch.

    Args:
        config: Must provide the keys ``'lr'``, ``'epochs'`` and
            ``'batch_size_per_worker'``.
    """
    learning_rate = config['lr']
    num_epochs = config['epochs']
    per_worker_batch = config['batch_size_per_worker']

    # Data loaders are created inside the worker training function.
    raw_train_loader, raw_test_loader = get_dataloaders(
        batch_size=per_worker_batch)

    # [1] Prepare the data loaders for distributed training.
    # Shards the datasets among workers and moves batches to the correct
    # device.
    # ==================================================================
    train_loader = ray.train.torch.prepare_data_loader(raw_train_loader)
    test_loader = ray.train.torch.prepare_data_loader(raw_test_loader)

    # [2] Prepare and wrap the model with DistributedDataParallel.
    # Moves the model to the correct GPU/CPU device.
    # ============================================================
    model = ray.train.torch.prepare_model(NeuralNetwork())

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9)

    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()
        for inputs, targets in tqdm(train_loader, desc=f'Train Epoch {epoch}'):
            batch_loss = criterion(model(inputs), targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # ---- Evaluation pass ----
        model.eval()
        loss_sum = 0
        correct = 0
        seen = 0
        with torch.no_grad():
            for inputs, targets in tqdm(test_loader,
                                        desc=f'Test Epoch {epoch}'):
                outputs = model(inputs)
                loss_sum += criterion(outputs, targets).item()
                seen += targets.shape[0]
                correct += (outputs.argmax(1) == targets).sum().item()

        # Mean loss per batch and fraction of correctly classified samples.
        mean_loss = loss_sum / len(test_loader)
        accuracy = correct / seen

        # [3] Report metrics to Ray Train.
        # ================================
        ray.train.report(metrics={'loss': mean_loss, 'accuracy': accuracy})
def train_fashion_mnist(num_workers=2, use_gpu=False):
    """Run distributed FashionMNIST training with a Ray ``TorchTrainer``.

    A global batch size of 32 is split evenly across the workers; each
    worker runs ``train_func_per_worker`` for 10 epochs with a learning
    rate of 1e-3. The training result is printed when fitting finishes.

    Args:
        num_workers: Number of training workers; must be at least 1.
        use_gpu: Passed through to ``ScalingConfig(use_gpu=...)``.

    Raises:
        ValueError: If ``num_workers`` is less than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (0) or
    # a negative batch size (< 0) further down.
    if num_workers < 1:
        raise ValueError(
            f'num_workers must be at least 1, got {num_workers}')

    global_batch_size = 32

    train_config = {
        'lr': 1e-3,
        'epochs': 10,
        # With more workers than samples in the global batch the integer
        # division yields 0, which is not a valid DataLoader batch size;
        # fall back to one sample per worker in that case.
        'batch_size_per_worker': max(1, global_batch_size // num_workers),
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start distributed training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f'Training result: {result}')
if __name__ == '__main__':
    # Script entry point: read the worker count from the command line
    # (default 2) and start GPU training with that many workers.
    cli = argparse.ArgumentParser()
    cli.add_argument('--num-workers', type=int, default=2)
    worker_count = cli.parse_args().num_workers
    train_fashion_mnist(num_workers=worker_count, use_gpu=True)