Model Training Guide#
This guide covers best practices and examples for achieving high-performance distributed training with SkyPilot.
Distributed training basics#
SkyPilot supports all distributed training frameworks, including but not limited to PyTorch DDP and PyTorch Lightning DDP (see the real-world examples at the end of this guide).
The choice of framework depends on your specific needs, but all of them can be easily configured through SkyPilot’s YAML specification.
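For example, a two-node PyTorch DDP task could be sketched roughly as follows (a minimal, illustrative sketch: the task name, accelerator, port, and train.py are placeholders; the SKYPILOT_* environment variables are provided by SkyPilot on each node):

# ddp_train.yaml (hypothetical file name)
name: ddp-train

resources:
  accelerators: A100:1   # Placeholder; pick the GPU you need.

num_nodes: 2             # Number of nodes for distributed training.

run: |
  # The first node's IP serves as the rendezvous/master address.
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  torchrun \
    --nnodes=$SKYPILOT_NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --node_rank=$SKYPILOT_NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --master_port=8008 \
    train.py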
Best practices#
High performance instances#
Choose high-performance instances for optimal training speed. SkyPilot allows you to specify instance types with powerful GPUs and high-bandwidth networking:
- Use the latest GPU accelerators (A100, H100, etc.) for faster training.
- Consider instances with higher memory bandwidth and larger device memory for large models.
Example configuration:
resources:
  accelerators:
    A100:1
    A100-80GB:1
    H100:1
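If you prefer not to hard-code the accelerator, it can also be overridden at launch time on the command line (task.yaml here is a placeholder for your task file):

$ sky launch --gpus H100:1 task.yaml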
Use disk_tier: best#
Fast storage is critical for loading and storing data and model checkpoints. SkyPilot’s disk_tier option gives you access to the fastest available storage, with high-performance local SSDs to reduce I/O bottlenecks.
Example configuration:
resources:
  disk_tier: best  # Use the highest performance disk tier.
  disk_size: 1000  # In GiB. Make the disk large enough for checkpoints.
Use MOUNT_CACHED for checkpointing#
Cloud buckets mounted with the MOUNT_CACHED mode provide high-performance writes, making them ideal for model checkpoints, logs, and other outputs that need fast local writes.
Unlike MOUNT mode, MOUNT_CACHED supports all write and append operations by using local disk as a cache for files to be written to the cloud bucket. It can offer up to 9x the write speed for large checkpoints compared to MOUNT mode.
Example configuration:
file_mounts:
  /checkpoints:
    name: my-checkpoint-bucket
    mode: MOUNT_CACHED
For more on the differences between MOUNT and MOUNT_CACHED, see storage mounting modes.
Robust checkpointing for spot instances#
When using spot instances, robust checkpointing is crucial for recovering from preemptions. Your job should follow two key principles:
- Write checkpoints periodically during training to save your progress.
- Always attempt to load a checkpoint on startup, regardless of whether it’s the first run or a restart after preemption.
This approach ensures your job can seamlessly resume from where it left off after preemption. On the first run, no checkpoints will exist, but on subsequent restarts, your job will automatically recover its state.
Basic checkpointing#
Saving to the bucket is easy: simply save to the mounted directory /checkpoints specified above as if it were a local disk.
import torch

def save_checkpoint(step: int, model: torch.nn.Module):
    # Save the checkpoint, named by step number, to the mounted bucket directory.
    torch.save(model.state_dict(), f'/checkpoints/model_{step}.pt')
To make checkpoint loading robust against preemptions and incomplete checkpoints, follow this recipe:
- Always try loading from the latest checkpoint first.
- If the latest checkpoint is corrupted or incomplete, fall back to earlier checkpoints.
Here’s a simplified example showing the core concepts for checkpoints saved with torch.save:
import logging
from pathlib import Path

import torch

logger = logging.getLogger(__name__)

def load_checkpoint(save_dir: str = '/checkpoints'):
    try:
        # Find all checkpoints written by save_checkpoint(), sorted by step (newest first).
        checkpoints = sorted(
            Path(save_dir).glob('model_*.pt'),
            key=lambda x: int(x.stem.split('_')[-1]),
            reverse=True)
        # Try each checkpoint from newest to oldest.
        for checkpoint in checkpoints:
            step = int(checkpoint.stem.split('_')[-1])
            try:
                # Fill in your own loading logic here if needed.
                result = torch.load(checkpoint)
                return result
            except Exception as e:
                logger.warning(f'Failed to load checkpoint {step}: {e}')
                continue
    except Exception as e:
        logger.error(f'Failed to find checkpoints: {e}')
    return None
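Putting the two pieces together, a training script can always attempt to resume on startup and then checkpoint periodically. Below is a minimal sketch: num_steps, save_every, and train_one_step are placeholders for your own training loop, and in practice you would also store the step number inside the checkpoint so training can resume from the right step.

# Always attempt to resume; on the very first run no checkpoint exists yet.
state_dict = load_checkpoint('/checkpoints')
if state_dict is not None:
    model.load_state_dict(state_dict)

# Write checkpoints periodically during training.
for step in range(num_steps):
    train_one_step(model)              # Placeholder for your training logic.
    if step % save_every == 0:
        save_checkpoint(step, model)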
Robust checkpointing with error handling#
For a more complete implementation, you can wrap saving and loading in small decorators that add features like custom prefixes, extended metadata, and more detailed error handling.
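As a rough illustration (an assumption about how such a decorator could be structured, not an official SkyPilot API), a save_checkpoint decorator might simply ensure the directory exists and log failures:

import functools
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def save_checkpoint(save_dir: str = 'checkpoints', checkpoint_prefix: str = 'model'):
    """Wrap a checkpoint-saving function with directory creation and error handling."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(step: int, *args, **kwargs):
            # Make sure the (possibly bucket-mounted) directory exists.
            Path(save_dir).mkdir(parents=True, exist_ok=True)
            try:
                result = func(step, *args, **kwargs)
                logger.info(f'[{checkpoint_prefix}] saved checkpoint for step {step} in {save_dir}')
                return result
            except Exception as e:
                logger.error(f'[{checkpoint_prefix}] failed to save checkpoint at step {step}: {e}')
                return None
        return wrapper
    return decorator

A matching load_checkpoint decorator can be built the same way, falling back to older checkpoints on failure as in the function shown earlier.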
Here are some common ways to use the checkpointing system:
Basic model saving:
@save_checkpoint(save_dir="checkpoints")
def save_model(step: int, model: torch.nn.Module):
    torch.save(model.state_dict(), f"checkpoints/model_{step}.pt")
Saving with optimizer state:
@save_checkpoint(save_dir="checkpoints")
def save_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step
    }, f"checkpoints/training_{step}.pt")
Saving with metrics and custom prefix:
@save_checkpoint(save_dir="checkpoints", checkpoint_prefix="experiment1")
def save_with_metrics(step: int, model: torch.nn.Module, metrics: Dict[str, float]):
    torch.save({
        'model': model.state_dict(),
        'metrics': metrics,
        'step': step
    }, f"checkpoints/experiment1_step_{step}.pt")
Loading checkpoints:
# Basic model loading
@load_checkpoint(save_dir="checkpoints")
def load_model(step: int, model: torch.nn.Module):
    model.load_state_dict(torch.load(f"checkpoints/model_{step}.pt"))

# Loading with optimizer
@load_checkpoint(save_dir="checkpoints")
def load_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    checkpoint = torch.load(f"checkpoints/training_{step}.pt")
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['step']

# Loading with custom prefix and metrics
@load_checkpoint(save_dir="checkpoints", checkpoint_prefix="experiment1")
def load_with_metrics(step: int, model: torch.nn.Module):
    checkpoint = torch.load(f"checkpoints/experiment1_step_{step}.pt")
    model.load_state_dict(checkpoint['model'])
    return checkpoint['metrics']
Examples#
BERT end-to-end#
We can take a SkyPilot YAML for BERT fine-tuning and add checkpointing/recovery to get everything working end-to-end.
Note
You can find all the code for this example in the documentation
In this example, we fine-tune a BERT model on a question-answering task with HuggingFace.
This example:
- has SkyPilot find a V100 instance on any cloud,
- uses spot instances to save cost, and
- uses checkpointing to recover preempted jobs quickly.
# bert_qa.yaml
name: bert-qa

resources:
  accelerators: V100:1
  use_spot: true    # Use spot instances to save cost.
  disk_tier: best   # Use the highest performance disk tier.

file_mounts:
  /checkpoint:
    name: # NOTE: Fill in your bucket name
    mode: MOUNT_CACHED

envs:
  # Fill in your wandb key: copy from https://wandb.ai/authorize
  # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY`
  # to pass the key in the command line, during `sky jobs launch`.
  WANDB_API_KEY:

# Assume your working directory is under `~/transformers`.
workdir: ~/transformers

setup: |
  pip install -e .
  cd examples/pytorch/question-answering/
  pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
  pip install wandb

run: |
  cd examples/pytorch/question-answering/
  python run_qa.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name squad \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 12 \
    --learning_rate 3e-5 \
    --num_train_epochs 50 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --report_to wandb \
    --output_dir /checkpoint/bert_qa/ \
    --run_name $SKYPILOT_TASK_ID \
    --save_total_limit 10 \
    --save_steps 1000
The file_mounts section above adds a bucket for checkpoints. As HuggingFace has built-in support for periodic checkpointing, we just need to pass the checkpointing arguments (--output_dir, --save_total_limit, and --save_steps) to save checkpoints to the bucket. (See more on the HuggingFace API.) To see another example of periodic checkpointing with PyTorch, check out our ResNet example.
We also set --run_name to $SKYPILOT_TASK_ID so that the logs for all recoveries of the same job will be saved to the same run in Weights & Biases.
Note
The environment variable $SKYPILOT_TASK_ID (example: “sky-managed-2022-10-06-05-17-09-750781_bert-qa_8-0”) can be used to identify the same job, i.e., it is kept identical across all recoveries of the job.
It can be accessed in the task’s run commands or directly in the program itself (e.g., read from os.environ and passed to Weights & Biases for tracking purposes in your training script). It is made available to the task whenever it is invoked. See more about environment variables provided by SkyPilot.
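Inside your training script, this could look roughly like the following (a minimal sketch; the project name is a placeholder and the wandb client is assumed to be installed):

import os

import wandb

# SKYPILOT_TASK_ID stays identical across all recoveries of the same managed job,
# so every retry logs to the same W&B run name.
task_id = os.environ.get('SKYPILOT_TASK_ID', 'local-run')
wandb.init(project='bert-qa', name=task_id)  # 'bert-qa' is a placeholder project name.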
With these changes, the managed job can now resume training after preemption! We can enjoy the cost savings of spot instances without worrying about preemption or losing progress.
$ sky jobs launch -n bert-qa bert_qa.yaml
Real-world examples#
- Vicuna LLM chatbot: instructions, YAML
- Large-scale vector database ingestion, and the blog post about it
- BERT (shown above): YAML
- PyTorch DDP, ResNet: YAML
- PyTorch Lightning DDP, CIFAR-10: YAML