Source: examples/tpu
TPU#
This example shows how to launch TPU jobs with SkyPilot.
Note: Some examples may be old. See the v6e/ files for the latest examples. See also: https://docs.skypilot.co/en/latest/reference/tpu.html.
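Each YAML below can be launched with the sky CLI (e.g. sky launch <file>.yaml) or from Python. The snippet below is a minimal sketch, assuming the SkyPilot Python API (sky.Task.from_yaml / sky.launch); the cluster name is arbitrary and the path is relative to the SkyPilot repo root:
# Minimal sketch: launch one of the included YAMLs programmatically.
import sky

# Load a task definition from a YAML file in this directory.
task = sky.Task.from_yaml('examples/tpu/tpuvm_mnist.yaml')

# Provision a TPU VM, run the task's setup, then its run section.
sky.launch(task, cluster_name='tpu-example')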
Included files#
tpu_app.py
import sky

with sky.Dag() as dag:
    # The working directory contains all code and will be synced to remote.
    workdir = './examples/tpu/tpu_app_code'

    # The setup command. Will be run under the working directory.
    setup = 'pip install --upgrade pip && \
        conda activate huggingface || \
        (conda create -n huggingface python=3.8 -y && \
        conda activate huggingface && \
        pip install -r requirements.txt)'

    # The command to run. Will be run under the working directory.
    run = 'conda activate huggingface && python -u run_tpu.py'

    train = sky.Task(
        'train',
        workdir=workdir,
        setup=setup,
        run=run,
    )

    train.set_resources({
        sky.Resources(accelerators='tpu-v3-8',
                      accelerator_args={
                          'runtime_version': '2.12.0',
                          'tpu_name': 'weilin-bert-test-big'
                      }),
    })

sky.launch(dag)
tpu_app.yaml
name: tpu_app

# The working directory contains all code and will be synced to remote.
workdir: ./examples/tpu/tpu_app_code

resources:
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.12.0
    tpu_vm: False

# The setup command. Will be run under the working directory.
setup: |
  pip install --upgrade pip

  conda activate huggingface
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n huggingface python=3.8 -y
    conda activate huggingface
    pip install -r requirements.txt
  fi

# The command to run. Will be run under the working directory.
run: |
  conda activate huggingface
  python -u run_tpu.py
tpu_app_code/requirements.txt
tensorflow==2.5.1
tensorflow-datasets==4.4.0
transformers==4.12.0
tensorflow-text==2.5.0
cloud-tpu-client==0.10
tpu_app_code/run_tpu.py
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
from transformers import TFBertForSequenceClassification
from transformers import TFDistilBertForSequenceClassification

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)

ds_train, ds_info = tfds.load('amazon_us_reviews/Books_v1_02',
                              split='train[:5%]',
                              with_info=True,
                              data_dir="gs://weilin-bert-test")

MAX_SEQ_LEN = 512
bert_tokenizer = tf_text.BertTokenizer(
    vocab_lookup_table='gs://weilin-bert-test/vocab.txt',
    token_out_type=tf.int64,
    lower_case=True)


def preprocessing_fn(inputs):
    """Preprocess input column of text into transformed columns of.
      * input token ids
      * input mask
      * input type ids
    """
    CLS_ID = tf.constant(101, dtype=tf.int64)
    SEP_ID = tf.constant(102, dtype=tf.int64)
    PAD_ID = tf.constant(0, dtype=tf.int64)

    def tokenize_text(text, sequence_length=MAX_SEQ_LEN):
        """
        Perform the BERT preprocessing from text -> input token ids
        """
        # convert text into token ids
        tokens = bert_tokenizer.tokenize(text)
        # flatten the output ragged tensors
        tokens = tokens.merge_dims(1, 2)[:, :sequence_length]
        # Add start and end token ids to the id sequence
        start_tokens = tf.fill([tf.shape(text)[0], 1], CLS_ID)
        end_tokens = tf.fill([tf.shape(text)[0], 1], SEP_ID)
        tokens = tokens[:, :sequence_length - 2]
        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)
        # truncate sequences greater than MAX_SEQ_LEN
        tokens = tokens[:, :sequence_length]
        # pad shorter sequences with the pad token id
        tokens = tokens.to_tensor(default_value=PAD_ID)
        pad = sequence_length - tf.shape(tokens)[1]
        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=PAD_ID)
        # and finally reshape the word token ids to fit the output
        # data structure of TFT
        return tf.reshape(tokens, [-1, sequence_length])

    def preprocess_bert_input(text):
        """
        Convert input text into the input_word_ids, input_mask, input_type_ids
        """
        input_word_ids = tokenize_text(text)
        input_mask = tf.cast(input_word_ids > 0, tf.int64)
        input_mask = tf.reshape(input_mask, [-1, MAX_SEQ_LEN])

        zeros_dims = tf.stack(tf.shape(input_mask))
        input_type_ids = tf.fill(zeros_dims, 0)
        input_type_ids = tf.cast(input_type_ids, tf.int64)
        return (tf.squeeze(input_word_ids,
                           axis=0), tf.squeeze(input_mask, axis=0),
                tf.squeeze(input_type_ids, axis=0))

    input_word_ids, input_mask, input_type_ids = preprocess_bert_input(
        [inputs['data']['review_body']])
    return (dict({
        'input_ids': input_word_ids,
        'token_type_ids': input_type_ids,
        'attention_mask': input_mask
    }), inputs['data']['star_rating'])


from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def dataset_fn(ds):
    return ds.filter(lambda x: x['data']['helpful_votes'] >= 7)


ds_train_filtered = ds_train.apply(dataset_fn)


def process(example):
    return (dict(tokenizer(
        example['data']['review_body'].numpy().decode('utf-8')),
                 truncation=True,
                 padding=True), example['data']['star_rating'].numpy())


def process_py(inp1, inp2):
    return [
        dict(tokenizer(inp1.numpy().decode('utf-8')),
             truncation=True,
             padding=True),
        inp2.numpy()
    ]


ds_train_filtered_2 = ds_train_filtered.map(preprocessing_fn)

tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
with strategy.scope():
    model = TFBertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=1)
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer,
                  loss=model.compute_loss)  # can also use any keras loss fn
    model.summary()

inuse_dataset = ds_train_filtered_2.shuffle(1000).batch(256).prefetch(
    tf.data.experimental.AUTOTUNE)
model.fit(inuse_dataset, epochs=1, batch_size=256)
tpu_node_mnist.yaml
name: mnist-tpu-node

resources:
  instance_type: n1-highmem-8
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.12.0
    tpu_vm: False

file_mounts:
  /dataset:
    name: demo-mnist-tpu
    store: gcs
    mode: MOUNT

# The setup command. Will be run under the working directory.
setup: |
  git clone https://github.com/tensorflow/models.git

  conda activate mnist
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n mnist python=3.8 -y
    conda activate mnist
    pip install tensorflow==2.12.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
  fi

# The command to run. Will be run under the working directory.
run: |
  conda activate mnist
  cd models/official/legacy/image_classification/

  export STORAGE_BUCKET=gs://demo-mnist-tpu
  export MODEL_DIR=${STORAGE_BUCKET}/mnist
  export DATA_DIR=${STORAGE_BUCKET}/data
  export PYTHONPATH=/home/gcpuser/sky_workdir/models

  python3 mnist_main.py \
    --tpu=${TPU_NAME} \
    --model_dir=${MODEL_DIR} \
    --data_dir=${DATA_DIR} \
    --train_epochs=10 \
    --distribution_strategy=tpu \
    --download
tpuvm_mnist.yaml
name: tpuvm_mnist

resources:
  accelerators: tpu-v2-8

# The setup command. Will be run under the working directory.
setup: |
  git clone https://github.com/google/flax.git --branch v0.10.1

  conda activate flax
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.10 -y
    conda activate flax
    # Make sure to install TPU related packages in a conda env to avoid package conflicts.
    pip install \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.35" \
      clu \
      tensorflow tensorflow-datasets
    pip install -e flax
  fi

# The command to run. Will be run under the working directory.
run: |
  conda activate flax
  pip install clu

  cd flax/examples/mnist
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10
v6e/README.md
TPU v6e
Trillium (also referred to as TPU v6e) is Cloud TPU’s latest-generation AI accelerator. SkyPilot supports TPU v6e for provisioning, training, and serving.
Catalogs
Currently, the public APIs for TPU v6e regions and pricing have not been released, and pricing info for us-central1, us-central2, and us-south1 is not available. We set the price to 0.0 in those regions for now.
Provisioning
To provision TPU v6e, use the following command:
$ sky launch --gpus tpu-v6e-16 -c tpu-v6e
After that, you can SSH to the instance and start developing your model:
$ ssh tpu-v6e
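The same provisioning step can be driven from Python. The snippet below is a minimal sketch, assuming the SkyPilot SDK (sky.Task, sky.Resources, sky.launch); the placeholder run command is only there to have something to execute:
# Minimal sketch: Python-API equivalent of the `sky launch --gpus tpu-v6e-16 -c tpu-v6e` command above.
import sky

task = sky.Task(run='echo "TPU v6e ready"')  # placeholder run command
task.set_resources(sky.Resources(accelerators='tpu-v6e-16'))
sky.launch(task, cluster_name='tpu-v6e')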
Training
The example in this directory (train-llama3-8b.yaml) shows how to use TPU v6e to train a Llama 3 8B model with PyTorch (XLA) on the wikitext dataset. To start training, use the following command:
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
Single-Host Training
The training throughput for a tpu-v6e-8 instance should be around 0.5 samples/s:
(task, pid=17499) ***** train metrics *****
(task, pid=17499) epoch = 1.1765
(task, pid=17499) total_flos = 109935420GF
(task, pid=17499) train_loss = 10.6011
(task, pid=17499) train_runtime = 0:11:12.77
(task, pid=17499) train_samples = 282
(task, pid=17499) train_samples_per_second = 0.476
(task, pid=17499) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
Multi-Host Training
By changing the TPU type to tpu-v6e-16 and --per_device_train_batch_size to 32, the training throughput increases to around 1 sample/s:
(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894) epoch = 2.5
(head, rank=0, pid=17894) total_flos = 219870840GF
(head, rank=0, pid=17894) train_loss = 10.1527
(head, rank=0, pid=17894) train_runtime = 0:11:13.18
(head, rank=0, pid=17894) train_samples = 282
(head, rank=0, pid=17894) train_samples_per_second = 0.951
(head, rank=0, pid=17894) train_steps_per_second = 0.03
(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03
(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03
(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
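The accelerator override does not require editing the YAML: on the CLI it can be passed as sky launch train-llama3-8b.yaml --gpus tpu-v6e-16 --env HF_TOKEN. Below is a minimal Python sketch of the same idea, assuming the SDK's sky.Task.from_yaml, Task.update_envs, and Task.set_resources; note that --per_device_train_batch_size lives inside the YAML's run command and still has to be changed in the file itself:
# Minimal sketch: override tpu-v6e-8 with a multi-host tpu-v6e-16 slice at launch time.
import os
import sky

task = sky.Task.from_yaml('train-llama3-8b.yaml')
task.update_envs({'HF_TOKEN': os.environ['HF_TOKEN']})  # same as --env HF_TOKEN
task.set_resources(sky.Resources(accelerators='tpu-v6e-16'))
sky.launch(task, cluster_name='train-llama3-8b')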
Serving
TPU v6e also supports serving. The example in this directory (serve-llama2-7b.yaml) shows how to use TPU v6e to serve a Llama 2 7B model with PyTorch (XLA) and the JetStream library. To start serving, use the following command:
$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN
After the server is ready, you should see the following message:
(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads
(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads
(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False
(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False
(task, pid=26431) Started jetstream_server....
You can now start a benchmark to test the serving performance:
$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml
... (emitted logs)
(task, pid=25491) Successful requests: 100
(task, pid=25491) Benchmark duration: 8.753792 s
(task, pid=25491) Total input tokens: 21888
(task, pid=25491) Total generated tokens: 18803
(task, pid=25491) Request throughput: 11.42 requests/s
(task, pid=25491) Input token throughput: 2500.40 tokens/s
(task, pid=25491) Output token throughput: 2147.98 tokens/s
(task, pid=25491) Mean TTFT: 1981.93 ms
(task, pid=25491) Median TTFT: 1829.33 ms
(task, pid=25491) P99 TTFT: 4511.95 ms
(task, pid=25491) Mean TPOT: 130.71 ms
(task, pid=25491) Median TPOT: 18.88 ms
(task, pid=25491) P99 TPOT: 2487.37 ms
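For reference, the benchmark step can also be driven from Python. This is a minimal sketch, assuming the serving cluster launched from serve-llama2-7b.yaml is still up and that the SDK's sky.exec (which reuses an existing cluster and only runs the task's run section) and sky.down are available:
# Minimal sketch: re-run the benchmark on the existing serving cluster, then tear it down.
# CLI equivalents: `sky exec serve-llama2-7b benchmark-llama2-7b.yaml` and `sky down serve-llama2-7b`.
import sky

bench = sky.Task.from_yaml('benchmark-llama2-7b.yaml')
sky.exec(bench, cluster_name='serve-llama2-7b')  # reuse the cluster; setup is skipped

sky.down('serve-llama2-7b')  # terminate the TPU cluster when done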
v6e/benchmark-llama2-7b.yaml
envs:
  model_name: llama-2
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name
v6e/config-8B.json
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}
v6e/fsdp_config.json
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
v6e/serve-llama2-7b.yaml
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for serving
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"
v6e/train-llama3-8b.yaml
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token

workdir: .

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for training
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  pip3 install -e .
  pip3 install datasets evaluate scikit-learn accelerate

run: |
  unset LD_PRELOAD
  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/sky_workdir/config-8B.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20