Cloud TPU#
SkyPilot supports running jobs on Google’s Cloud TPU, a specialized hardware accelerator for ML workloads.
Free TPUs via TPU Research Cloud (TRC)#
ML researchers and students are encouraged to apply for free TPU access through the TPU Research Cloud (TRC) program.
Getting TPUs in one command#
Use one command to quickly get TPU nodes for development:
# Use latest TPU v6 (Trillium) VMs:
sky launch --gpus tpu-v6e-8
# Use TPU v4 (Titan) VMs:
sky launch --gpus tpu-v4-8
# Preemptible TPUs:
sky launch --gpus tpu-v6e-8 --use-spot
After the command finishes, you will be dropped into a TPU host VM and can start developing code right away.
Below, we show examples of using SkyPilot to (1) train LLMs on TPU VMs/Pods and (2) train MNIST on TPU Nodes (legacy).
TPU Architectures#
Two different TPU architectures are available on GCP: TPU VMs/Pods and TPU Nodes.
Both are supported by SkyPilot. We recommend TPU VMs and Pods, the newer architecture that GCP encourages.
The two architectures differ as follows.
For TPU VMs/Pods, you can directly SSH into the “TPU host” VM that is physically connected to the TPU device.
For TPU Nodes, a user VM (an n1 instance) must be separately provisioned to communicate with an inaccessible TPU host over gRPC.
More details can be found in the GCP documentation.
TPU VMs/Pods#
Google’s latest TPU v6 (Trillium) VMs are now supported by SkyPilot.
To use TPU VMs/Pods, set the following in a task YAML’s resources field:
resources:
  accelerators: tpu-v6e-8
  accelerator_args:
    runtime_version: v2-alpha-tpuv6e # optional
The accelerators field specifies the TPU type. The accelerator_args dict accepts two optional fields: tpu_vm, a bool that defaults to true (meaning a TPU VM is used), and runtime_version, the TPU runtime version.
To show what TPU types are supported, run sky show-gpus.
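As a minimal sketch of how these fields fit together, the Python below assembles a dict shaped like the resources field and renders it as YAML text. The helpers tpu_resources and to_yaml are hypothetical names used only for illustration. They are not part of SkyPilot, and the sketch validates nothing beyond the tpu- prefix.

```python
from typing import Any, Dict, Optional


def tpu_resources(accelerator: str,
                  runtime_version: Optional[str] = None,
                  tpu_vm: bool = True) -> Dict[str, Any]:
    """Build a dict shaped like the `resources` field of a task YAML.

    Hypothetical helper for illustration; not a SkyPilot API.
    """
    if not accelerator.startswith("tpu-"):
        raise ValueError(f"not a TPU accelerator name: {accelerator!r}")
    accelerator_args: Dict[str, Any] = {}
    if runtime_version is not None:
        accelerator_args["runtime_version"] = runtime_version
    if not tpu_vm:
        # tpu_vm defaults to true, so only the non-default value is written.
        accelerator_args["tpu_vm"] = False
    resources: Dict[str, Any] = {"accelerators": accelerator}
    if accelerator_args:
        resources["accelerator_args"] = accelerator_args
    return resources


def to_yaml(resources: Dict[str, Any]) -> str:
    """Render the two-level dict as YAML text without an external library."""
    lines = ["resources:"]
    for key, value in resources.items():
        if isinstance(value, dict):
            lines.append(f"  {key}:")
            for sub_key, sub_value in value.items():
                lines.append(f"    {sub_key}: {sub_value}")
        else:
            lines.append(f"  {key}: {value}")
    return "\n".join(lines)


# TPU VM with an explicit runtime version (tpu_vm is left at its default).
print(to_yaml(tpu_resources("tpu-v6e-8", runtime_version="v2-alpha-tpuv6e")))
```

Passing tpu_vm=False adds the tpu_vm: False entry that the legacy TPU Node configuration requires.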
Here is a complete task YAML that trains a Llama 3 model on a TPU VM using Torch XLA.
resources:
  accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use
envs:
  HF_TOKEN: # fill in your huggingface token
workdir: .
setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
  # Setup runtime for training
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  pip3 install -e .
  pip3 install datasets evaluate scikit-learn accelerate
run: |
  unset LD_PRELOAD
  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/sky_workdir/config-8B.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20
This YAML lives under the SkyPilot repo, or you can paste it into a local file.
Launch it with:
$ HF_TOKEN=<your-huggingface-token> sky launch train-llama3-8b.yaml -c llama-3-train --env HF_TOKEN
You should see the following output when the job finishes.
$ sky launch train-llama3-8b.yaml -c llama-3-train
(task, pid=17499) ***** train metrics *****
(task, pid=17499) epoch = 1.1765
(task, pid=17499) total_flos = 109935420GF
(task, pid=17499) train_loss = 10.6011
(task, pid=17499) train_runtime = 0:11:12.77
(task, pid=17499) train_samples = 282
(task, pid=17499) train_samples_per_second = 0.476
(task, pid=17499) train_steps_per_second = 0.03
Multi-Host TPU Pods#
A TPU Pod is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training.
To use a TPU Pod, simply change the accelerators field in the task YAML (e.g., tpu-v6e-8 -> tpu-v6e-32).
resources:
  accelerators: tpu-v6e-32 # Pods have > 8 cores (the last number)
Note
Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs.
To show all available TPU Pod types, run sky show-gpus (more than 8 cores means Pods):
GOOGLE_TPU AVAILABLE_QUANTITIES
tpu-v6e-8 1
tpu-v6e-32 1
tpu-v6e-128 1
tpu-v6e-256 1
...
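The rule of thumb above (the trailing number in the type name is the core count, and more than 8 means a Pod) can be sketched in Python as follows. The function names parse_tpu_type and is_tpu_pod are hypothetical, and the name pattern is an assumption based on the type names shown in this document, not a SkyPilot or GCP API.

```python
import re
from typing import Tuple

# Assumed name shape, based on the types listed in this document:
#   tpu-<generation>-<cores>, e.g. tpu-v6e-32 or tpu-v2-8.
_TPU_NAME = re.compile(r"^tpu-(?P<generation>v\d+[a-z]*)-(?P<size>\d+)$")


def parse_tpu_type(name: str) -> Tuple[str, int]:
    """Split a TPU type name into (generation, trailing number)."""
    match = _TPU_NAME.match(name)
    if match is None:
        raise ValueError(f"unrecognized TPU type name: {name!r}")
    return match.group("generation"), int(match.group("size"))


def is_tpu_pod(name: str) -> bool:
    """Apply the document's rule: more than 8 cores means a Pod."""
    _, size = parse_tpu_type(name)
    return size > 8


for tpu_type in ["tpu-v6e-8", "tpu-v6e-32", "tpu-v6e-256"]:
    kind = "Pod" if is_tpu_pod(tpu_type) else "single host"
    print(f"{tpu_type}: {kind}")
```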
After creating a TPU Pod, multiple host VMs are launched (e.g., tpu-v6e-32 comes with 4 host VMs).
Normally, the user needs to SSH into all hosts to prepare files and set up environments, and then launch the job on each host, which is a tedious and error-prone process.
SkyPilot automates away this complexity. From your laptop, a single sky launch command will:
sync workdir/file_mounts; and
execute the setup/run commands on every host of the pod.
We can run the same Llama 3 training job on a TPU Pod with the following command, after a slight change to the YAML (--per_device_train_batch_size from 16 to 32):
$ HF_TOKEN=<your-huggingface-token> sky launch -c tpu-pod --gpus tpu-v6e-32 train-llama3-8b.yaml --env HF_TOKEN
You should see the following output.
(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894) epoch = 2.5
(head, rank=0, pid=17894) total_flos = 219870840GF
(head, rank=0, pid=17894) train_loss = 10.1527
(head, rank=0, pid=17894) train_runtime = 0:11:13.18
(head, rank=0, pid=17894) train_samples = 282
(head, rank=0, pid=17894) train_samples_per_second = 0.951
(head, rank=0, pid=17894) train_steps_per_second = 0.03
(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03
(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03
(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03
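When comparing hosts, it can help to group these metric lines by rank. The sketch below is a hypothetical post-processing helper, not part of SkyPilot, and it assumes the log prefix format shown above, "(node, rank=N, ...) key = value". Lines without a rank, such as the single-host output earlier, are skipped.

```python
import re
from collections import defaultdict
from typing import Dict

# Matches lines such as:
#   (worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
_METRIC_LINE = re.compile(
    r"^\((?P<node>[^,)]+), rank=(?P<rank>\d+)[^)]*\)\s+"
    r"(?P<key>\w+)\s+=\s+(?P<value>\S+)$")


def metrics_by_rank(log_text: str) -> Dict[int, Dict[str, str]]:
    """Group 'key = value' metric lines by the rank in their prefix.

    Values are kept as strings because some are not plain numbers,
    for example train_runtime.
    """
    metrics: Dict[int, Dict[str, str]] = defaultdict(dict)
    for line in log_text.splitlines():
        match = _METRIC_LINE.match(line.strip())
        if match:
            rank = int(match.group("rank"))
            metrics[rank][match.group("key")] = match.group("value")
    return dict(metrics)


sample = """\
(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894) train_samples_per_second = 0.951
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
"""
for rank, values in sorted(metrics_by_rank(sample).items()):
    print(rank, values["train_samples_per_second"])
```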
To submit more jobs to the same TPU Pod, use sky exec:
$ HF_TOKEN=<your-huggingface-token> sky exec tpu-pod train-llama3-8b.yaml --env HF_TOKEN
You can find more examples, such as serving LLMs on TPUs, in the SkyPilot repo.
TPU Nodes (Legacy)#
In a TPU Node, a normal CPU VM (an n1 instance) needs to be provisioned to communicate with the TPU host/device.
To use a TPU Node, set the following in a task YAML’s resources field:
resources:
  instance_type: n1-highmem-8
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.12.0 # optional, TPU runtime version.
    tpu_vm: False
The above YAML uses n1-highmem-8 as the host machine and tpu-v2-8 as the TPU node resource.
You can modify the host instance type or the TPU type.
Here is a complete task YAML that runs MNIST training on a TPU Node using TensorFlow.
name: mnist-tpu-node
resources:
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.12.0 # optional, TPU runtime version.
    tpu_vm: False
# TPU node requires loading data from a GCS bucket.
# We use SkyPilot bucket mounting to mount a GCS bucket to /dataset.
file_mounts:
  /dataset:
    name: mnist-tpu-node
    store: gcs
    mode: MOUNT
setup: |
  git clone https://github.com/tensorflow/models.git
  conda activate mnist
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n mnist python=3.8 -y
    conda activate mnist
    pip install tensorflow==2.12.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
  fi
run: |
  conda activate mnist
  cd models/official/legacy/image_classification/
  export STORAGE_BUCKET=gs://mnist-tpu-node
  export MODEL_DIR=${STORAGE_BUCKET}/mnist
  export DATA_DIR=${STORAGE_BUCKET}/data
  export PYTHONPATH=/home/gcpuser/sky_workdir/models
  python3 mnist_main.py \
    --tpu=${TPU_NAME} \
    --model_dir=${MODEL_DIR} \
    --data_dir=${DATA_DIR} \
    --train_epochs=10 \
    --distribution_strategy=tpu \
    --download
Note
A TPU Node requires loading data from a GCS bucket. The file_mounts spec above simplifies this by using SkyPilot bucket mounting to create a new bucket or mount an existing one.
If you encounter a bucket Permission denied error, make sure the bucket is created in the same region as the host VM/TPU Nodes and that the IAM permission for Cloud TPU is correctly set up (follow instructions here).
Note
The special environment variable $TPU_NAME is automatically set by SkyPilot at run time, so it can be used in the run commands.
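A training script launched from the run commands can also read this variable directly. The following is a minimal sketch. The helper get_tpu_name is a hypothetical name, and the only assumption taken from this document is that SkyPilot exports TPU_NAME into the environment of the run commands.

```python
import os
from typing import Mapping, Optional


def get_tpu_name(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the TPU name from the environment, failing early if absent.

    Hypothetical helper; `env` defaults to os.environ and can be replaced
    with a plain dict for testing.
    """
    env = os.environ if env is None else env
    name = env.get("TPU_NAME", "").strip()
    if not name:
        raise RuntimeError(
            "TPU_NAME is not set; this script expects to be started from "
            "the run commands of a TPU Node task.")
    return name


# Demonstration with an explicit mapping instead of the real environment.
print(get_tpu_name({"TPU_NAME": "example-tpu"}))
```

Failing early with a clear message avoids passing an empty --tpu value to the training script when it is started outside of a SkyPilot task.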
This YAML lives under the SkyPilot repo (examples/tpu/tpu_node_mnist.yaml). Launch it with:
$ sky launch examples/tpu/tpu_node_mnist.yaml -c mycluster
...
(mnist-tpu-node pid=28961) Epoch 9/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 19ms/step - loss: 0.1181 - sparse_categorical_accuracy: 0.9646 - val_loss: 0.0921 - val_sparse_categorical_accuracy: 0.9719
(mnist-tpu-node pid=28961) Epoch 10/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 20ms/step - loss: 0.1139 - sparse_categorical_accuracy: 0.9655 - val_loss: 0.0831 - val_sparse_categorical_accuracy: 0.9742
...
(mnist-tpu-node pid=28961) {'accuracy_top_1': 0.9741753339767456, 'eval_loss': 0.0831054300069809, 'loss': 0.11388632655143738, 'training_accuracy_top_1': 0.9654667377471924}