Source: llm/verl
Verl: State-of-the-art RL Training for LLMs
Verl is one of the most popular open-source reinforcement learning frameworks for LLMs, supporting PPO, GRPO, and other algorithms.
Why SkyPilot + Verl?
SkyPilot makes RL training easy and cost-effective:
Get GPUs instantly across clouds and Kubernetes
Up to 3x cheaper with managed spot instances
Zero setup - handles distributed Ray clusters automatically
Quick Start
Launch single-node PPO or GRPO training:
# PPO (also pass --secret HF_TOKEN to upload the merged model to Hugging Face):
sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --num-nodes 1 -y
sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --secret HF_TOKEN --num-nodes 1 -y
# GRPO:
sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --num-nodes 1 -y
sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --secret HF_TOKEN --num-nodes 1 -y
Launch a 2-node RLHF training job on the cheapest available GPUs:
sky launch -c verl llm/verl/multinode.yaml
Monitor training progress:
sky logs verl
Training logs showing PPO optimization progress with reward metrics
Access Ray dashboard:
sky status --endpoint 8280 verl
Ray dashboard showing real-time monitoring of distributed training across multiple nodes
Key Features
The example trains Qwen2.5-0.5B-Instruct on the GSM8K dataset using PPO:
Multi-node distributed training with automatic Ray cluster setup
Checkpoint persistence to cloud storage for fault tolerance
Customizable models and datasets via environment variables
Optional: Enable W&B for Training Visualization
To track training curves and metrics in Weights & Biases:
# 1. Set your W&B API key locally
export WANDB_API_KEY=your-api-key
# 2. Launch with the secret flag
sky launch -c verl llm/verl/multinode.yaml --secret WANDB_API_KEY
# 3. Edit multinode.yaml to enable W&B logger (see comments in the file)
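The change itself is documented in comments near the end of multinode.yaml's run section: enabling W&B amounts to adjusting the trainer flags passed to verl, roughly as below (the project and run names are just the examples used in those comments):
trainer.logger=['console','wandb'] \
trainer.wandb_project='verl-rlhf' \
trainer.wandb_run_name='${SKYPILOT_CLUSTER_NAME}-${SKYPILOT_TASK_ID}'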
Advanced Usage
💰 Use Spot Instances for 3x Cost Savings
Note that multinode.yaml sets use_spot: false by default, so either enable spot in the YAML or pass --use-spot when launching the managed job:
sky jobs launch -n verl-job llm/verl/multinode.yaml --use-spot
Training automatically resumes from checkpoints if preempted.
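If you prefer to bake this into the task definition instead of using the CLI flag, the relevant field (per the comment in multinode.yaml) is a minimal sketch like:
resources:
  use_spot: true  # set to true to use spot instances with managed jobs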
🚀 Continue Experiments on the Same Cluster
# Run additional training epochs
sky exec verl llm/verl/multinode.yaml --env TOTAL_EPOCHS=10
# The YAML automatically detects and reuses the existing Ray cluster
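The reuse check lives in multinode.yaml's run section: the head node probes the Ray head port and only starts a new Ray instance if nothing answers (excerpted and abbreviated):
is_ray_alive () {
  ray status --address="$1:$HEAD_PORT" >/dev/null 2>&1
}
if ! is_ray_alive "$head_ip"; then
  ray start --head ...   # start a fresh Ray head on this node
else
  echo "Ray is already running at $head_ip:$HEAD_PORT, reusing existing instance"
fi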
📈 Scale to More Nodes
sky launch -c verl llm/verl/multinode.yaml --num-nodes 4
🔧 Customize Training Configuration
Modify parameters directly:
sky launch -c verl llm/verl/multinode.yaml \
--env MODEL_NAME=meta-llama/Llama-2-7b-hf \
--env ACTOR_LR=5e-6 \
--env CRITIC_LR=1e-5
Train a larger model:
sky launch -c verl llm/verl/multinode.yaml \
--env MODEL_NAME=Qwen/Qwen2.5-7B-Instruct \
--gpus A100-80GB:8 --num-nodes 4
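For models in the 7B+ range you may also need to adjust settings that are hardcoded inside multinode.yaml's training command, for example the rollout parallelism and memory budget (illustrative values only, not tuned recommendations):
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \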
Understanding the Setup
Head node: Prepares data, starts Ray head, submits training job
Worker nodes: Join Ray cluster for distributed training
Smart resumption: Ray cluster is reused if already running, avoiding restart overhead
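In multinode.yaml this head/worker split is keyed off SKYPILOT_NODE_RANK (excerpted and abbreviated):
if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
  # head node: prepare GSM8K data, start or reuse the Ray head, then `ray job submit` the PPO job
  ...
else
  # worker nodes: join the existing Ray cluster
  ray start --address $head_ip:$HEAD_PORT --num-gpus=$SKYPILOT_NUM_GPUS_PER_NODE
fi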
Troubleshooting
OOM errors: Reduce batch sizes or gpu_memory_utilization
Connection issues: Ensure ports 6385 (Ray) and 8280 (dashboard) are not blocked
First run is slow: Model download happens once, subsequent runs are faster
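The OOM-related knobs are hardcoded in the training command inside multinode.yaml; reducing memory pressure looks roughly like this (illustrative values):
data.train_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
critic.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \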
Learn More
Verl repository: https://github.com/volcengine/verl
Included files
code/preprocess_rstar_coder.py
# Copyright 2025 MIT
"""
Preprocess rStar-Coder dataset.
"""
import argparse
import os
import datasets
from verl.utils.hdfs_io import copy
from verl.utils.hdfs_io import makedirs
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default=None)
parser.add_argument("--hdfs_dir", default=None)
parser.add_argument("--local_save_dir", default="~/data/rstar_coder")
args = parser.parse_args()
data_source = "microsoft/rStar-Coder"
dataset = datasets.load_dataset(
data_source,
data_files="synthetic_sft/data-00000-of-00015.parquet",
split="train",
trust_remote_code=True)
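    # Note: the records below are tagged with the gsm8k data source (overriding the
    # rStar-Coder source above), matching the gsm8k-style '#### answer' prompt appended below.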
data_source = 'openai/gsm8k'
# Split into train/test (90/10)
split_dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]
instruction_following = 'Let\'s think step by step and output the final answer after "####".'
def make_map_fn(split):
def process_fn(doc, idx):
question_raw = doc.get("question", "")
question = question_raw + " " + instruction_following
answer = doc.get("response") or doc.get("code", "")
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question
}],
"ability": "code",
"reward_model": {
"style": "rule",
"ground_truth": answer
},
"extra_info": {
"split": split,
"index": idx,
}
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"),
with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"),
with_indices=True)
hdfs_dir = args.hdfs_dir
local_save_dir = args.local_dir
if local_save_dir is not None:
print(
"Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead."
)
else:
local_save_dir = args.local_save_dir
train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_save_dir, dst=hdfs_dir)
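The GRPO task below invokes this script in its setup step; run locally, the equivalent is roughly:
mkdir -p ~/data/code
python3 code/preprocess_rstar_coder.py --local_save_dir ~/data/code   # writes train.parquet and test.parquet
(The --local_dir flag still works but is deprecated in favor of --local_save_dir.)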
multinode.yaml
# Multi-node distributed training with Verl (Volcano Engine Reinforcement Learning) framework.
#
# Verl is a flexible and efficient reinforcement learning framework designed for
# training large language models with RLHF (Reinforcement Learning from Human Feedback).
# This example demonstrates multi-node training using PPO on the GSM8K dataset.
#
# Prerequisites:
# - GPU nodes with at least 40GB memory (e.g., A100)
# - Access to Hugging Face models (Qwen/Qwen2.5-0.5B-Instruct in this example)
#
# Usage:
# # Launch a 2-node training cluster:
# $ sky launch -c verl-cluster examples/verl/multinode.yaml
#
# # Monitor the Ray dashboard (optional):
# $ sky status --endpoint 8280 verl-cluster
#
# # Stream logs:
# $ sky logs verl-cluster
#
# # Cleanup:
# $ sky down verl-cluster
name: verl-multinode-training
resources:
accelerators:
- A100:8
- A100-80GB:8
- H100:8 # H100 for faster training, can also use A100-80GB:1
# cloud: lambda # Optional: specify cloud provider
use_spot: false # Set to true to use spot instances with managed jobs
ports:
- 8280 # Ray dashboard port
num_nodes: 2 # Number of nodes for distributed training
# Environment variables
envs:
HF_HUB_ENABLE_HF_TRANSFER: "1"
TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
# Change this to your own checkpoint bucket
CHECKPOINT_BUCKET_NAME: sky-verl-checkpoints
# Optional: Add your W&B API key for experiment tracking
WANDB_API_KEY: null # Pass with `--secret WANDB_API_KEY` in CLI
# Training configuration
MODEL_NAME: Qwen/Qwen2.5-0.5B-Instruct
TOTAL_EPOCHS: 3
ACTOR_LR: 1e-6
CRITIC_LR: 1e-5
# Mount cloud storage for checkpoints
file_mounts:
/checkpoints:
name: ${CHECKPOINT_BUCKET_NAME}
mode: MOUNT
# Optionally, specify the store to enforce to use one of the stores below:
# r2/azure/gcs/s3/cos
# store: s3
setup: |
# Clone and setup Verl
rm -rf verl
git clone https://github.com/volcengine/verl.git
cd verl
# Create virtual environment and install dependencies
uv venv --seed
source .venv/bin/activate
# Install Verl and its dependencies (skip Megatron for this example)
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
uv pip install --no-deps -e .
uv pip install "ray[default]" # For Ray dashboard
run: |
# Set up distributed training environment
head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
num_nodes=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
echo "Head IP: $head_ip"
echo "Number of nodes: $num_nodes"
cd verl
source .venv/bin/activate
# Create custom runtime environment configuration
cat > runtime_env_custom.yaml <<EOF
working_dir: ./
excludes: ["/.git/", "*.whl", "**/*.whl"]
env_vars:
TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
CUDA_DEVICE_MAX_CONNECTIONS: "1"
HF_HUB_ENABLE_HF_TRANSFER: "1"
EOF
# Ray cluster configuration
HEAD_PORT=6385
DASH_PORT=8280
# Function to check if Ray is already running
is_ray_alive () {
ray status --address="$1:$HEAD_PORT" >/dev/null 2>&1
}
if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
# Head node: prepare data, download model, start Ray head, and submit training job
echo "Setting up head node..."
# Install additional dependencies for data processing
uv pip install datasets transformers
# Prepare GSM8K dataset
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
# Download model to cache
python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')"
# Start Ray head node if not already running
if ! is_ray_alive "$head_ip"; then
echo "Starting Ray head node..."
ray start --head --node-ip-address="$head_ip" \
--port $HEAD_PORT --dashboard-port $DASH_PORT \
--dashboard-host=0.0.0.0 \
--dashboard-agent-listen-port=52366 \
--disable-usage-stats \
--num-gpus=$SKYPILOT_NUM_GPUS_PER_NODE
sleep 10
else
echo "Ray is already running at $head_ip:$HEAD_PORT, reusing existing instance"
ray status --address="$head_ip:$HEAD_PORT"
fi
# Submit the training job to Ray
export RAY_ADDRESS="http://localhost:$DASH_PORT"
echo "Submitting training job to Ray cluster..."
ray job submit --address="$RAY_ADDRESS" --working-dir=. \
--runtime-env=runtime_env_custom.yaml \
-- python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=128 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=$MODEL_NAME \
actor_rollout_ref.actor.optim.lr=$ACTOR_LR \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=$CRITIC_LR \
critic.model.path=$MODEL_NAME \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.ppo_mini_batch_size=64 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.project_name=ppo_training \
trainer.experiment_name=qwen-2.5-0.5B \
trainer.val_before_train=False \
trainer.n_gpus_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
trainer.nnodes=$num_nodes \
trainer.default_local_dir=/checkpoints \
trainer.save_freq=10 \
trainer.test_freq=10 \
trainer.total_epochs=$TOTAL_EPOCHS \
trainer.logger=['console'] \
trainer.resume_mode=auto 2>&1 | tee verl_training.log
# To enable W&B logging:
# 1. Set WANDB_API_KEY in envs or pass via --secret WANDB_API_KEY
# 2. Change trainer.logger to: trainer.logger=['console', 'wandb']
# 3. Add: trainer.wandb_project='verl-rlhf'
# 4. Add: trainer.wandb_run_name='${SKYPILOT_CLUSTER_NAME}-${SKYPILOT_TASK_ID}'
else
# Worker nodes: connect to Ray head
echo "Setting up worker node..."
echo "Head IP: $head_ip"
echo "HEAD_PORT: $HEAD_PORT"
echo "SKYPILOT_NUM_GPUS_PER_NODE: $SKYPILOT_NUM_GPUS_PER_NODE"
# Get this worker's IP address
worker_ip=$(hostname -I | awk '{print $1}')
echo "Worker IP: $worker_ip"
echo "Checking if worker $worker_ip is already in Ray cluster at $head_ip:$HEAD_PORT"
if ray list nodes --address=$head_ip:$HEAD_PORT 2>/dev/null | grep -q "$worker_ip"; then
echo "Worker $worker_ip already connected to Ray cluster"
ray status --address=$head_ip:$HEAD_PORT
else
echo "Worker not connected, waiting for head node to start"
sleep 20
echo "Starting Ray worker"
ray start --address $head_ip:$HEAD_PORT --disable-usage-stats --num-gpus=$SKYPILOT_NUM_GPUS_PER_NODE
echo "Ray start exit code: $?"
# Verify connection after starting
sleep 5
ray status --address=$head_ip:$HEAD_PORT
fi
fi
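As the envs comment above notes ("Change this to your own checkpoint bucket"), CHECKPOINT_BUCKET_NAME should point at a bucket name you control. It can be overridden at launch time and the resulting bucket inspected with SkyPilot's storage commands, e.g.:
sky launch -c verl llm/verl/multinode.yaml --env CHECKPOINT_BUCKET_NAME=my-verl-checkpoints
sky storage ls   # lists SkyPilot-managed buckets, including the checkpoint bucket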
verl-grpo.yaml
# Usage:
# sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --num-nodes 1 -y
#
# sky launch -c verl-grpo llm/verl/verl-grpo.yaml --secret WANDB_API_KEY --secret HF_TOKEN --num-nodes 1 -y
resources:
accelerators: H100:1
memory: 128+
image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
ports:
- 8265
- 9090
envs:
TOTAL_EPOCHS: 1
WANDB_PROJECT_NAME: skypilot-verl
WANDB_EXPERIMENT_NAME: grpo-code
CHECKPOINT_BUCKET_NAME: sky-verl-grpo-checkpoints
HF_UPLOAD_MODEL_NAME: "maknee/verl-grpo-code"
SAVE_FINAL_MODEL_HF_PATH: /checkpoints/hf_model
file_mounts:
/checkpoints:
store: nebius
name: ${CHECKPOINT_BUCKET_NAME}
mode: MOUNT
/code:
name: code
source: llm/verl/code
mode: COPY
secrets:
HF_TOKEN: null
WANDB_API_KEY: null
setup: |
rm -f ~/.pip/pip.conf
rm -f ~/.config/pip/pip.conf
sudo apt install iproute2 -y
uv venv --python 3.10 --seed
source .venv/bin/activate
rm -rf verl
git clone https://github.com/volcengine/verl.git
cd verl
git checkout 83aebcc133663c12ac33ea3d5ba5c5c5b4687286
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
uv pip install -v -e .
uv pip install hf_transfer
uv pip install flashinfer-python
uv pip install "vllm==0.10.0" --torch-backend=auto
uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
uv pip install datasets
uv pip install "ray[train]" "click<8.2.0"
uv pip install tqdm
echo "Downloading code dataset..."
mkdir -p ~/data/code
python3 /code/preprocess_rstar_coder.py --local_dir ~/data/code
echo "code dataset download completed"
run: |
HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
NUM_NODES=$SKYPILOT_NUM_NODES
NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE
#NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'src \K\S+')
#export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE
export VLLM_USE_V1=1
source .venv/bin/activate
python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"
if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
echo "Starting Ray head node..."
ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats \
--port=6379 \
--dashboard-host=0.0.0.0 \
--dashboard-port=8265
# Wait for all worker nodes to join
retry_count=0
max_retries=30
while [ $retry_count -lt $max_retries ]; do
connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || echo "0")
echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"
if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
echo "All nodes connected to Ray cluster"
break
fi
retry_count=$((retry_count+1))
sleep 10
done
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/code/train.parquet \
data.val_files=$HOME/data/code/test.parquet \
data.train_batch_size=32 \
data.max_prompt_length=256 \
data.max_response_length=256 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.ppo_epochs=1 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.rollout.n=1 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=2048 \
actor_rollout_ref.rollout.trace.backend=weave \
actor_rollout_ref.rollout.trace.token2text=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=[console,wandb] \
trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
trainer.nnodes=$NUM_NODES \
trainer.save_freq=10 \
trainer.test_freq=1 \
trainer.total_epochs=${TOTAL_EPOCHS} \
trainer.default_local_dir=/checkpoints \
trainer.project_name=$WANDB_PROJECT_NAME \
trainer.experiment_name=$WANDB_EXPERIMENT_NAME
LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"
    if [ -n "$HF_TOKEN" ]; then   # upload to Hugging Face only when a token is provided
python -m verl.model_merger merge \
--backend fsdp \
--tie-word-embedding \
--local_dir ${CHECKPOINT_DIR} \
--target_dir ${SAVE_FINAL_MODEL_HF_PATH} \
--hf_upload_path ${HF_UPLOAD_MODEL_NAME}
else
python -m verl.model_merger merge \
--backend fsdp \
--tie-word-embedding \
--local_dir ${CHECKPOINT_DIR} \
--target_dir ${SAVE_FINAL_MODEL_HF_PATH}
fi
vllm serve /checkpoints/hf_model \
--host 0.0.0.0 \
--port 9090
else
sleep 15
echo "Starting Ray worker node..."
ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
sleep 10
fi
echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."
verl-ppo.yaml
# Usage:
# sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --num-nodes 1 -y
#
# sky launch -c verl-ppo llm/verl/verl-ppo.yaml --secret WANDB_API_KEY --secret HF_TOKEN --num-nodes 1 -y
resources:
infra: nebius
accelerators: H100:1
memory: 128+
image_id: docker:verlai/verl:app-verl0.6-transformers4.56.1-sglang0.5.2-mcore0.13.0-te2.2
ports:
- 8265
- 9090
num_nodes: 1
envs:
TOTAL_EPOCHS: 1
WANDB_PROJECT_NAME: skypilot-verl
WANDB_EXPERIMENT_NAME: ppo-math
CHECKPOINT_BUCKET_NAME: sky-verl-ppo-checkpoints
HF_UPLOAD_MODEL_NAME: "maknee/verl-ppo-math"
SAVE_FINAL_MODEL_HF_PATH: /checkpoints/hf_model
file_mounts:
/checkpoints:
store: nebius
name: ${CHECKPOINT_BUCKET_NAME}
mode: MOUNT
secrets:
HF_TOKEN: null
WANDB_API_KEY: null
setup: |
rm -f ~/.pip/pip.conf
rm -f ~/.config/pip/pip.conf
sudo apt install iproute2 -y
uv venv --python 3.10 --seed
source .venv/bin/activate
rm -rf verl
git clone https://github.com/volcengine/verl.git
cd verl
git checkout 83aebcc133663c12ac33ea3d5ba5c5c5b4687286
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
uv pip install -v -e .
uv pip install hf_transfer
uv pip install flashinfer-python
uv pip install "vllm==0.10.0" --torch-backend=auto
uv pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
uv pip install datasets
uv pip install "ray[train]" "click<8.2.0"
uv pip install tqdm
echo "Downloading Math dataset..."
mkdir -p ~/data/math
python3 "$(pwd)/examples/data_preprocess/math_dataset.py" --local_dir ~/data/math
echo "Math dataset download completed"
uv pip install zmq
run: |
HEAD_IP=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
NUM_NODES=$SKYPILOT_NUM_NODES
NUM_GPUS_PER_NODE=$SKYPILOT_NUM_GPUS_PER_NODE
#NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'src \K\S+')
#export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
NETWORK_INTERFACE=$(ip route get 8.8.8.8 | grep -oP 'dev \K\S+')
export GLOO_SOCKET_IFNAME=$NETWORK_INTERFACE
export NCCL_SOCKET_IFNAME=$NETWORK_INTERFACE
export VLLM_USE_V1=1
source .venv/bin/activate
python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')"
if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
echo "Starting Ray head node..."
ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats \
--port=6379 \
--dashboard-host=0.0.0.0 \
--dashboard-port=8265
# Wait for all worker nodes to join
retry_count=0
max_retries=30
while [ $retry_count -lt $max_retries ]; do
connected_nodes=$(ray status 2>/dev/null | grep -c "node_" || echo "0")
echo "Connected nodes: $connected_nodes/$NUM_NODES (attempt $((retry_count+1))/$max_retries)"
if [ "$connected_nodes" -ge "$NUM_NODES" ]; then
echo "All nodes connected to Ray cluster"
break
fi
retry_count=$((retry_count+1))
sleep 10
done
python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/math/train.parquet \
data.val_files=$HOME/data/math/test.parquet \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.rollout.trace.backend=weave \
actor_rollout_ref.rollout.trace.token2text=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.model.fsdp_config.model_dtype=bfloat16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=[console,wandb] \
trainer.val_before_train=False \
trainer.n_gpus_per_node=$NUM_GPUS_PER_NODE \
trainer.nnodes=$NUM_NODES \
trainer.save_freq=10 \
trainer.test_freq=1 \
trainer.default_local_dir=/checkpoints \
trainer.total_epochs=${TOTAL_EPOCHS} \
trainer.project_name=$WANDB_PROJECT_NAME \
trainer.experiment_name=$WANDB_EXPERIMENT_NAME
LATEST_STEP=$(cat /checkpoints/latest_checkpointed_iteration.txt)
CHECKPOINT_DIR="/checkpoints/global_step_${LATEST_STEP}/actor"
if [ -n "$HF_TOKEN" ]; then
python -m verl.model_merger merge \
--backend fsdp \
--tie-word-embedding \
--local_dir ${CHECKPOINT_DIR} \
--target_dir ${SAVE_FINAL_MODEL_HF_PATH} \
--hf_upload_path ${HF_UPLOAD_MODEL_NAME}
else
python -m verl.model_merger merge \
--backend fsdp \
--tie-word-embedding \
--local_dir ${CHECKPOINT_DIR} \
--target_dir ${SAVE_FINAL_MODEL_HF_PATH}
fi
vllm serve /checkpoints/hf_model \
--host 0.0.0.0 \
--port 9090
else
sleep 15
echo "Starting Ray worker node..."
ps aux | grep ray | grep $HEAD_IP:6379 &> /dev/null || ray start --address $HEAD_IP:6379 --disable-usage-stats
sleep 10
fi
echo "Node setup and Ray start script finished for rank $SKYPILOT_NODE_RANK."