Source: examples/airflow
Running SkyPilot tasks in Airflow with the SkyPilot API Server#
In this guide, we show how a training workflow involving data preprocessing, training and evaluation can first be developed easily with SkyPilot, and then orchestrated in Airflow.
This example uses a remote SkyPilot API Server to manage shared state across invocations.
💡 Tip: SkyPilot also supports defining and running pipelines without Airflow. Check out Jobs Pipelines for more information.
Why use SkyPilot with Airflow?#
In AI workflows, the transition from development to production is hard.
Workflow development happens ad-hoc, with a lot of interaction required with the code and data. When moving this to an Airflow DAG in production, managing dependencies, environments and the infra requirements of the workflow gets complex. Porting the code to Airflow requires significant time to test and validate any changes, often requiring re-writing the code as Airflow operators.
SkyPilot seamlessly bridges the dev -> production gap.
SkyPilot can operate on any of your infra, allowing you to package and run the same code that you ran during development on a production Airflow cluster. Behind the scenes, SkyPilot handles environment setup, dependency management, and infra orchestration, allowing you to focus on your code.
Here’s how you can use SkyPilot to take your dev workflows to production in Airflow:
1. Define and test your workflow as SkyPilot tasks. Use sky launch and Sky VSCode integration to run, debug and iterate on your code.
2. Orchestrate SkyPilot tasks in Airflow by invoking sky launch on their YAMLs as tasks in the Airflow DAG. Airflow does the scheduling, logging, and monitoring, while SkyPilot handles the infra setup and task execution.
Prerequisites#
- Airflow installed locally (SequentialExecutor)
- A SkyPilot API server endpoint to send requests to. If you do not have one, refer to the API server docs to deploy one.
- For this specific example: the API server should have AWS/GCS access to create buckets to store intermediate task outputs (see the quick check below).
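If you want to sanity-check the API server's cloud access before wiring it into Airflow, one option (assuming you have a recent sky CLI installed locally and the endpoint is reachable) is to point your local CLI at the server and run sky check:
# Point the local SkyPilot CLI at the remote API server, then list enabled clouds
sky api login -e http://<skypilot-api-server-endpoint>
sky check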
Configuring the API server endpoint#
Once your API server is deployed, you will need to configure Airflow to use it. Set the SKYPILOT_API_SERVER_ENDPOINT
variable in Airflow - it will be used by the run_sky_task
function to send requests to the API server:
airflow variables set SKYPILOT_API_SERVER_ENDPOINT http://<skypilot-api-server-endpoint>
You can also use the Airflow web UI to set the variable.
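To confirm the variable was set, you can read it back with the Airflow CLI:
airflow variables get SKYPILOT_API_SERVER_ENDPOINT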
Defining the tasks#
We will define the following tasks to mock a training workflow:
- data_preprocessing.yaml: Generates data and writes it to a bucket.
- train.yaml: Trains a model on the data in the bucket.
- eval.yaml: Evaluates the model and writes evaluation results to the bucket.
We have defined these tasks in this directory and uploaded them to a Git repository.
When developing the workflow, you can run the tasks independently using sky launch:
# Run the data preprocessing task, replacing <bucket-name> with a unique bucket name
sky launch -c data --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 data_preprocessing.yaml
The train and eval steps can be run in a similar way:
# Run the train task
sky launch -c train --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 train.yaml
Hint: You can use ssh
and VSCode to interactively develop and debug the tasks.
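For example, once the data cluster from the command above is up, you can SSH into it directly (assuming sky launch was run from your local machine, so the cluster was added to your SSH config) and attach VSCode to it over Remote-SSH:
# SSH into the cluster named 'data'
ssh data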
Note: eval can optionally be run on the same cluster as train with sky exec, as shown below.
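For example, after the train cluster has finished, a sketch of reusing it for evaluation (with the same bucket environment variables) is:
# Reuse the existing 'train' cluster instead of launching a new one
sky exec train --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 eval.yaml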
Writing the Airflow DAG#
Once we have developed the tasks, we can seamlessly run them in Airflow.
- No changes required to our tasks - we use the same YAMLs we wrote in the previous step to create an Airflow DAG in sky_train_dag.py.
- Airflow native logging - SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI.
- Easy debugging - If a task fails, you can independently run the task using sky launch to debug the issue. SkyPilot will recreate the environment in which the task failed.
Here’s a snippet of the DAG declaration in sky_train_dag.py:
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    # Path to SkyPilot YAMLs. Can be a git repo or local directory.
    base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'

    # Generate bucket UUID as first task
    bucket_uuid = generate_bucket_uuid()

    # Use the bucket_uuid from previous task
    common_envs = {
        'DATA_BUCKET_NAME': f"sky-data-demo-{bucket_uuid}",
        'DATA_BUCKET_STORE_TYPE': 's3',
    }

    # The API server endpoint is passed to each task via templates_dict
    # on the @task.virtualenv decorator (shown below).
    preprocess_task = run_sky_task.override(task_id="data_preprocess")(
        base_path,
        'data_preprocessing.yaml',
        envs_override=common_envs)
    train_task = run_sky_task.override(task_id="train")(
        base_path,
        'train.yaml',
        envs_override=common_envs)
    eval_task = run_sky_task.override(task_id="eval")(
        base_path,
        'eval.yaml',
        envs_override=common_envs)

    # Define the workflow
    bucket_uuid >> preprocess_task >> train_task >> eval_task
Behind the scenes, the run_sky_task function uses the Airflow-native PythonVirtualenvOperator (@task.virtualenv), which creates a Python virtual environment with skypilot installed. We need to run the task in a virtual environment because there is a dependency conflict between the latest skypilot and airflow Python packages. The SKYPILOT_API_SERVER_ENDPOINT Airflow variable is passed to the task through templates_dict:
@task.virtualenv(
    python_version='3.11',
    requirements=['skypilot-nightly[gcp,aws,kubernetes]'],
    system_site_packages=False,
    templates_dict={
        'SKYPILOT_API_SERVER_ENDPOINT': "{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}",
    })
def run_sky_task(...):
    ...
The task YAML files can be sourced in two ways:
From a Git repository (as shown above):
repo_url = 'https://github.com/skypilot-org/mock-train-workflow.git'
run_sky_task(...)(repo_url, 'path/to/yaml', git_branch='optional_branch')
The task will automatically clone the repository and checkout the specified branch before execution.
From a local path:
local_path = '/path/to/local/directory'
run_sky_task(...)(local_path, 'path/to/yaml')
This is useful during development or when your tasks are stored locally.
All clusters are set to auto-down after the task is done, so no dangling clusters are left behind.
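If your local sky CLI is pointed at the same API server, you can double-check that no clusters are left running after a DAG run:
sky status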
Running the DAG#
1. Copy the DAG file to the Airflow DAGs directory:
cp sky_train_dag.py /path/to/airflow/dags
# If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod
# kubectl cp sky_train_dag.py <airflow-pod-name>:/opt/airflow/dags
2. Run airflow dags list to confirm that the DAG is loaded.
3. Find the DAG in the Airflow UI (typically http://localhost:8080) and enable it. The UI may take a couple of minutes to reflect the changes. If the DAG is paused, force unpause it with airflow dags unpause sky_train_dag.
4. Trigger the DAG from the Airflow UI using the Trigger DAG button, or from the CLI as shown after these steps.
5. Navigate to the run in the Airflow UI to see the DAG progress and logs of each task.
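Alternatively, you can trigger the DAG from the Airflow CLI:
airflow dags trigger sky_train_dag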
If a task fails, SkyPilot will automatically tear down the SkyPilot cluster.
Optional: Configure cloud accounts#
By default, the SkyPilot task will run using the same cloud credentials the SkyPilot API Server has. This may not be ideal if you have an existing service account for your task.
In this example, we'll use a GCP service account to demonstrate how to use custom credentials. Refer to the GCP service account guide on how to set up a service account.
Once you have the JSON key for your service account, create an Airflow connection to store the credentials:
airflow connections add \
--conn-type google_cloud_platform \
--conn-extra "{\"keyfile_dict\": \"<YOUR_SERVICE_ACCOUNT_JSON_KEY>\"}" \
skypilot_gcp_task
You can also use the Airflow web UI to add the connection.
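To double-check that the connection was registered, you can query it with the Airflow CLI:
airflow connections get skypilot_gcp_task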
Next, we will define data_preprocessing_gcp_sa.yaml, which contains small modifications to data_preprocessing.yaml to use our GCP service account. The key changes are to mount the GCP service account JSON key to our SkyPilot cluster, and to activate it using gcloud auth activate-service-account.
We will also need a new task to read the GCP service account JSON key from our Airflow connection, and then change the preprocess task in our DAG to refer to this new YAML file.
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    ...
    # Get GCP credentials
    gcp_service_account_json = get_gcp_service_account_json()
    ...
    preprocess_task = run_sky_task.override(task_id="data_preprocess")(
        base_path,
        'data_preprocessing_gcp_sa.yaml',
        gcp_service_account_json=gcp_service_account_json,
        envs_override=common_envs)
    ...
    bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >> eval_task
Future work: a native Airflow Executor built on SkyPilot#
Currently this example relies on a helper run_sky_task function to wrap the SkyPilot invocation in @task, but in the future SkyPilot could provide a native Airflow Executor.
In such a setup, SkyPilot state management would also not be required, as the executor would handle SkyPilot cluster launching and termination.
Included files#
data_preprocessing.yaml
resources:
  cpus: 1

envs:
  DATA_BUCKET_NAME: sky-demo-data-test
  DATA_BUCKET_STORE_TYPE: s3

file_mounts:
  /data:
    name: $DATA_BUCKET_NAME
    store: $DATA_BUCKET_STORE_TYPE

setup: |
  echo "Setting up dependencies for data preprocessing..."

run: |
  echo "Running data preprocessing..."
  # Generate a few files with random data to simulate data preprocessing
  for i in {0..9}; do
    dd if=/dev/urandom of=/data/file_$i bs=1M count=10
  done
  echo "Data preprocessing completed, wrote to $DATA_BUCKET_NAME"
data_preprocessing_gcp_sa.yaml
resources:
  cpus: 1+

envs:
  DATA_BUCKET_NAME: sky-demo-data-test
  DATA_BUCKET_STORE_TYPE: s3
  GCP_SERVICE_ACCOUNT_JSON_PATH: null

file_mounts:
  /data:
    name: $DATA_BUCKET_NAME
    store: $DATA_BUCKET_STORE_TYPE
  /tmp/gcp-service-account.json: $GCP_SERVICE_ACCOUNT_JSON_PATH

setup: |
  echo "Setting up dependencies for data preprocessing..."
  curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz
  tar -xf google-cloud-cli-linux-x86_64.tar.gz
  ./google-cloud-sdk/install.sh --quiet --path-update true
  source ~/.bashrc
  gcloud auth activate-service-account --key-file=/tmp/gcp-service-account.json

run: |
  echo "Running data preprocessing on behalf of $(gcloud auth list --filter=status:ACTIVE --format="value(account)")..."
  # Generate a few files with random data to simulate data preprocessing
  for i in {0..9}; do
    dd if=/dev/urandom of=/data/file_$i bs=1M count=10
  done
  echo "Data preprocessing completed, wrote to $DATA_BUCKET_NAME"
eval.yaml
resources:
  cpus: 1
  # Add GPUs here

envs:
  DATA_BUCKET_NAME: sky-demo-data-test
  DATA_BUCKET_STORE_TYPE: s3

file_mounts:
  /data:
    name: $DATA_BUCKET_NAME
    store: $DATA_BUCKET_STORE_TYPE

setup: |
  echo "Setting up dependencies for eval..."

run: |
  echo "Evaluating the trained model..."
  # Run a mock evaluation job that reads the trained model from /data/trained_model.txt
  cat /data/trained_model.txt || true
  # Generate a mock accuracy
  ACCURACY=$(shuf -i 90-100 -n 1)
  echo "Metric - accuracy: $ACCURACY%"
  echo "Evaluation report" > /data/evaluation_report.txt
  echo "Evaluation completed, report written to $DATA_BUCKET_NAME"
sky_train_dag.py
from typing import Optional
import uuid

from airflow import DAG
from airflow.decorators import task
from airflow.exceptions import AirflowNotFoundException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
import pendulum

default_args = {
    'owner': 'airflow',
    'start_date': pendulum.today('UTC').add(days=-1,),
}
@task.virtualenv(
    python_version='3.11',
    requirements=['skypilot-nightly[gcp,aws,kubernetes]'],
    system_site_packages=False,
    templates_dict={
        'SKYPILOT_API_SERVER_ENDPOINT':
            ('{{ var.value.SKYPILOT_API_SERVER_ENDPOINT }}'),
    },
)
def run_sky_task(
        base_path: str,  # pylint: disable=redefined-outer-name
        yaml_path: str,
        gcp_service_account_json: Optional[str] = None,  # pylint: disable=redefined-outer-name
        envs_override: dict = None,
        git_branch: str = None,
        **kwargs,
):
    """Generic function to run a SkyPilot task.

    This is a blocking call that runs the SkyPilot task and streams the logs.
    In the future, we can use deferrable tasks to avoid blocking the worker
    while waiting for cluster to start.

    Args:
        base_path: Base path (local directory or git repo URL)
        yaml_path: Path to the YAML file (relative to base_path)
        gcp_service_account_json: GCP service account JSON-encoded string
        envs_override: Dictionary of environment variables to override in the
            task config
        git_branch: Optional branch name to checkout (only used if base_path
            is a git repo)
    """
    # pylint: disable=import-outside-toplevel
    import os
    import subprocess
    import tempfile
    # pylint: disable=import-outside-toplevel, reimported, redefined-outer-name
    import uuid

    import yaml

    def _run_sky_task(yaml_path: str, envs_override: dict):
        """Internal helper to run the sky task after directory setup."""
        # pylint: disable=import-outside-toplevel
        import sky

        with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f:
            task_config = yaml.safe_load(f)

        # Initialize envs if not present
        if 'envs' not in task_config:
            task_config['envs'] = {}

        # Update the envs with the override values
        # task.update_envs() is not used here, see
        # https://github.com/skypilot-org/skypilot/issues/4363
        task_config['envs'].update(envs_override)

        # pylint: disable=redefined-outer-name
        task = sky.Task.from_yaml_config(task_config)

        cluster_uuid = str(uuid.uuid4())[:4]
        task_name = os.path.splitext(os.path.basename(yaml_path))[0]
        cluster_name = f'{task_name}-{cluster_uuid}'
        print(f'Starting SkyPilot task with cluster: {cluster_name}')

        launch_request_id = sky.launch(task,
                                       cluster_name=cluster_name,
                                       down=True)
        job_id, _ = sky.stream_and_get(launch_request_id)
        # TODO(romilb): In the future, we can use deferrable tasks to avoid
        # blocking the worker while waiting for cluster to start.

        # Stream the logs for airflow logging
        sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)

        # Terminate the cluster after the task is done
        down_id = sky.down(cluster_name)
        sky.stream_and_get(down_id)

        return cluster_name
    # Set the SkyPilot API server endpoint
    if kwargs['templates_dict']:
        os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = (
            kwargs['templates_dict']['SKYPILOT_API_SERVER_ENDPOINT'])

    original_cwd = os.getcwd()

    # Write GCP service account JSON to a temporary file,
    # which will be mounted to the SkyPilot cluster.
    if gcp_service_account_json:
        with tempfile.NamedTemporaryFile(delete=False,
                                         suffix='.json') as temp_file:
            temp_file.write(gcp_service_account_json.encode('utf-8'))
            envs_override['GCP_SERVICE_ACCOUNT_JSON_PATH'] = temp_file.name

    try:
        # Handle git repos vs local paths
        if base_path.startswith(('http://', 'https://', 'git://')):
            with tempfile.TemporaryDirectory() as temp_dir:
                # TODO(romilb): This assumes git credentials are available
                # in the airflow worker
                subprocess.run(['git', 'clone', base_path, temp_dir],
                               check=True)

                # Checkout specific branch if provided
                if git_branch:
                    subprocess.run(['git', 'checkout', git_branch],
                                   cwd=temp_dir,
                                   check=True)

                full_yaml_path = os.path.join(temp_dir, yaml_path)
                # Change to the temp dir to set context
                os.chdir(temp_dir)

                # Run the sky task
                return _run_sky_task(full_yaml_path, envs_override or {})
        else:
            full_yaml_path = os.path.join(base_path, yaml_path)
            os.chdir(base_path)

            # Run the sky task
            return _run_sky_task(full_yaml_path, envs_override or {})
    finally:
        os.chdir(original_cwd)
@task
def generate_bucket_uuid():
    """Generate a unique bucket UUID for this DAG run."""
    bucket_uuid = str(uuid.uuid4())[:4]  # pylint: disable=redefined-outer-name
    return bucket_uuid


@task
def get_gcp_service_account_json() -> Optional[str]:
    """Fetch GCP credentials from Airflow connection."""
    try:
        hook = GoogleBaseHook(gcp_conn_id='skypilot_gcp_task')
        status, message = hook.test_connection()
        print(f'GCP connection status: {status}, message: {message}')
    except AirflowNotFoundException:
        print('GCP connection not found, skipping')
        return None

    conn = hook.get_connection(hook.gcp_conn_id)
    service_account_json = conn.extra_dejson.get('keyfile_dict')
    return service_account_json
with DAG(dag_id='sky_train_dag', default_args=default_args,
         catchup=False) as dag:
    # Path to SkyPilot YAMLs. Can be a git repo or local directory.
    base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'

    # Generate bucket UUID as first task
    bucket_uuid = generate_bucket_uuid()

    # Get GCP credentials (if available)
    gcp_service_account_json = get_gcp_service_account_json()

    # Use the bucket_uuid from previous task
    common_envs = {
        'DATA_BUCKET_NAME': f'sky-data-demo-{bucket_uuid}',
        'DATA_BUCKET_STORE_TYPE': 's3',
    }

    preprocess_task = run_sky_task.override(task_id='data_preprocess')(
        base_path,
        # Or data_preprocessing_gcp_sa.yaml if you want
        # to use a custom GCP service account
        'data_preprocessing.yaml',
        gcp_service_account_json=gcp_service_account_json,
        envs_override=common_envs,
    )
    train_task = run_sky_task.override(task_id='train')(
        base_path, 'train.yaml', envs_override=common_envs)
    eval_task = run_sky_task.override(task_id='eval')(base_path,
                                                      'eval.yaml',
                                                      envs_override=common_envs)

    # Define the workflow
    # pylint: disable=pointless-statement
    (bucket_uuid >> gcp_service_account_json >> preprocess_task >> train_task >>
     eval_task)
train.yaml
resources:
  cpus: 1
  # Add GPUs here

envs:
  DATA_BUCKET_NAME: sky-demo-data-test
  DATA_BUCKET_STORE_TYPE: s3
  NUM_EPOCHS: 2

file_mounts:
  /data:
    name: $DATA_BUCKET_NAME
    store: $DATA_BUCKET_STORE_TYPE

setup: |
  echo "Setting up dependencies for training..."

run: |
  echo "Running training..."
  # Run a mock training job that loops through the files in /data starting with 'file_'
  for (( i=1; i<=NUM_EPOCHS; i++ )); do
    for file in /data/file_*; do
      echo "Epoch $i: Training on $file"
      sleep 2
    done
  done
  # Mock checkpointing the trained model to the data bucket
  echo "Trained model" > /data/trained_model.txt
  echo "Training completed, model written to $DATA_BUCKET_NAME"