Source: examples/temporal
Running SkyPilot Tasks in Temporal Workflows#
This example demonstrates how to launch SkyPilot tasks and manage them in a Temporal workflow.
All activities, such as launching clusters, executing tasks, and tearing down clusters, are run on the same worker, eliminating the need for SkyPilot’s state management across multiple workers.
Defining the Tasks#
We will define the following tasks to mock a training workflow:
- data_preprocessing.yaml: Generates data and writes it to a bucket.
- train.yaml: Trains a model using the data in the bucket.
- eval.yaml: Evaluates the model and writes the evaluation results to the bucket.
These tasks are defined in the mock_training_workflow repository. The repository is cloned during the workflow to execute the tasks.
Workflow Overview#
We define a Temporal workflow consisting of the following steps:
Clone the repository containing the tasks using git.
Launch a SkyPilot cluster to run the data preprocessing job.
Terminate the cluster after preprocessing.
Launch another cluster for training the model.
Execute an evaluation task on the same training cluster.
Terminate the cluster after evaluation.
Temporal Activities#
These steps are implemented as Temporal activities, which are functions that can be executed by the Temporal worker:
- run_sky_launch: Launches a SkyPilot cluster with a specified configuration.
- run_sky_down: Terminates the specified SkyPilot cluster.
- run_sky_exec: Executes a task on an existing SkyPilot cluster.
- run_git_clone: Clones a Git repository to a specified location.
Single Worker Execution#
In this workflow, all tasks are handled by the same Temporal worker. This simplifies the workflow, as SkyPilot’s internal state does not need to be transferred between different workers, ensuring seamless orchestration.
This is achieved by registering all activities (run_sky_launch, run_sky_down, run_sky_exec, run_git_clone) to the same worker and enqueueing them in the same task queue:
async with Worker(
client,
task_queue='skypilot-task-queue',
workflows=[SkyPilotWorkflow],
activities=[run_sky_launch, run_sky_down, run_sky_exec, run_git_clone]
):
Running the Workflow#
If running Temporal locally, start the Temporal server:
temporal server start-dev
Launch the workflow:
python skypilot_workflow.py
Monitor the workflow execution in the Temporal Web UI (typically http://localhost:8233).
When the workflow completes, all logs will be available in the Temporal Web UI.
Included files#
skypilot_workflow.py
import asyncio
import os
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional

from temporalio import activity
from temporalio import workflow
from temporalio.client import Client
from temporalio.worker import Worker
@dataclass
class SkyLaunchCommand:
    """Input for the run_sky_launch activity."""
    cluster_name: str  # Name passed to `sky launch -c`.
    entrypoint: str  # Path to the SkyPilot task YAML to launch.
    flags: str  # Extra CLI flags, e.g. '--cloud kubernetes'.
@dataclass
class SkyDownCommand:
    """Input for the run_sky_down activity."""
    cluster_name: str  # Name of the cluster to terminate.
@dataclass
class SkyExecCommand:
    """Input for the run_sky_exec activity."""
    cluster_name: str  # Name of the existing cluster to run on.
    entrypoint: str  # Path to the SkyPilot task YAML to execute.
    flags: str  # Extra CLI flags forwarded to `sky exec`.
@activity.defn
async def run_sky_launch(input: SkyLaunchCommand) -> str:
    """Activity: launch a SkyPilot cluster via the `sky launch` CLI.

    Returns the stripped stdout of the command. Re-raises
    subprocess.CalledProcessError on a non-zero exit so Temporal records
    the activity as failed.
    """
    activity.logger.info(
        f'Running Sky Launch on cluster: {input.cluster_name} '
        f'with entrypoint: {input.entrypoint} and flags: {input.flags}')
    # Assemble the CLI invocation; -y skips interactive confirmation.
    launch_cmd = (f'sky launch -y -c {input.cluster_name} '
                  f'{input.flags} {input.entrypoint}')
    try:
        completed = subprocess.run(launch_cmd.split(),
                                   capture_output=True,
                                   text=True,
                                   check=True)
    except subprocess.CalledProcessError as e:
        activity.logger.error(f'Sky launch failed with error: {e}')
        activity.logger.error(f'Stdout: {e.stdout}')
        activity.logger.error(f'Stderr: {e.stderr}')
        raise  # Propagate so Temporal marks the activity as failed.
    activity.logger.info(f'Sky launch output: {completed.stdout}')
    return completed.stdout.strip()
@activity.defn
async def run_sky_down(input: SkyDownCommand) -> str:
    """Activity: tear down a SkyPilot cluster via the `sky down` CLI.

    Returns the stripped stdout of the command. Re-raises
    subprocess.CalledProcessError on a non-zero exit so Temporal records
    the activity as failed.
    """
    activity.logger.info(f'Running Sky Down on cluster: {input.cluster_name}')
    # Assemble the CLI invocation; -y skips interactive confirmation.
    down_cmd = f'sky down -y {input.cluster_name}'
    try:
        completed = subprocess.run(down_cmd.split(),
                                   capture_output=True,
                                   text=True,
                                   check=True)
    except subprocess.CalledProcessError as e:
        activity.logger.error(f'Sky down failed with error: {e}')
        activity.logger.error(f'Stdout: {e.stdout}')
        activity.logger.error(f'Stderr: {e.stderr}')
        raise  # Propagate so Temporal marks the activity as failed.
    activity.logger.info(f'Sky down output: {completed.stdout}')
    return completed.stdout.strip()
@activity.defn
async def run_sky_exec(input: SkyExecCommand) -> str:
    """Activity: run a task on an existing SkyPilot cluster via `sky exec`.

    Returns the stripped stdout of the command. Re-raises
    subprocess.CalledProcessError on a non-zero exit so Temporal records
    the activity as failed.
    """
    activity.logger.info(
        f'Running Sky exec on cluster: {input.cluster_name} '
        f'with entrypoint: {input.entrypoint} and flags: {input.flags}')
    # Run the sky exec command using subprocess
    full_command = f'sky exec {input.cluster_name} {input.flags} {input.entrypoint}'
    try:
        # Split into argv and run WITHOUT a shell: consistent with the
        # other activities in this file, and avoids shell injection via
        # the interpolated cluster_name/flags/entrypoint fields.
        # (Previously this used shell=True on the raw string.)
        result = subprocess.run(full_command.split(),
                                capture_output=True,
                                text=True,
                                check=True)
        activity.logger.info(f'Sky exec output: {result.stdout}')
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        activity.logger.error(f'Sky exec failed with error: {e}')
        activity.logger.error(f'Stdout: {e.stdout}')
        activity.logger.error(f'Stderr: {e.stderr}')
        raise  # Re-raise the exception to indicate failure
@dataclass
class GitCloneInput:
    """Input for the run_git_clone activity."""
    repo_url: str  # URL of the Git repository to clone.
    clone_path: str  # Local directory to clone (or pull) into.
@activity.defn
async def run_git_clone(input: GitCloneInput) -> str:
    """Activity: clone a Git repository, or pull if it is already cloned.

    Returns the stripped stdout of the git command. Re-raises
    subprocess.CalledProcessError on a non-zero exit so Temporal records
    the activity as failed.
    """
    activity.logger.info(
        f'Cloning git repository: {input.repo_url} to {input.clone_path}')
    # Create clone path if it doesn't exist (`git clone` tolerates an
    # existing empty directory).
    os.makedirs(input.clone_path, exist_ok=True)
    # Check if the repository already exists
    if os.path.exists(os.path.join(input.clone_path, '.git')):
        # If it exists, pull the latest changes
        command = f'git -C {input.clone_path} pull'
    else:
        # If it doesn't exist, clone the repository
        command = f'git clone {input.repo_url} {input.clone_path}'
    try:
        result = subprocess.run(command.split(),
                                capture_output=True,
                                text=True,
                                check=True)
        activity.logger.info(f'Git clone output: {result.stdout}')
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        activity.logger.error(f'Git clone failed with error: {e}')
        # Log captured output for parity with the sky_* activities; the
        # previous version dropped these, hiding git's actual error text.
        activity.logger.error(f'Stdout: {e.stdout}')
        activity.logger.error(f'Stderr: {e.stderr}')
        raise  # Re-raise the exception to indicate failure
@dataclass
class SkyPilotWorkflowInput:
    """Input for SkyPilotWorkflow.run.

    When data_bucket_url is set, it is forwarded to the SkyPilot tasks as
    the DATA_BUCKET_URL environment variable; when None, no --env flag is
    passed.
    """
    cluster_prefix: str  # Prefix used to derive per-stage cluster names.
    repo_url: str  # Git repository containing the task YAMLs.
    # Fixed annotation: the default is None, so the field is Optional[str]
    # (was incorrectly declared as a plain `str`).
    data_bucket_url: Optional[str] = None
@workflow.defn
class SkyPilotWorkflow:
    """Temporal workflow orchestrating a mock training pipeline on SkyPilot.

    Steps: clone the task repo, run preprocessing on its own cluster and
    tear it down, then launch a training cluster, run evaluation on that
    same cluster, and tear it down. All activities run on one worker, so
    SkyPilot state never has to move between workers.
    """

    @workflow.run
    async def run(self, input: SkyPilotWorkflowInput) -> str:
        """Execute the pipeline and return a summary of stage outputs.

        Raises whatever activity error Temporal surfaces if any stage
        fails; no cleanup of already-launched clusters is attempted here.
        """
        cluster_prefix = input.cluster_prefix
        repo_url = input.repo_url
        data_bucket_url = input.data_bucket_url
        workflow.logger.info(
            f'Running SkyPilot workflow with cluster prefix: {cluster_prefix} ')
        # 1. Clone the repository containing the task YAMLs.
        clone_path = '/tmp/skypilot_repo'
        clone_result = await workflow.execute_activity(
            run_git_clone,
            GitCloneInput(repo_url, clone_path),
            start_to_close_timeout=timedelta(minutes=5),
        )
        workflow.logger.info(f'Clone result: {clone_result}')
        # Forward the bucket URL to tasks via --env only when provided.
        data_bucket_flag = '--env DATA_BUCKET_URL=' + data_bucket_url if data_bucket_url else ''
        # 2. Launch data preprocessing on its own cluster.
        cluster_name = f'{cluster_prefix}-preprocess'
        preprocess_result = await workflow.execute_activity(
            run_sky_launch,
            SkyLaunchCommand(cluster_name,
                             f'{clone_path}/data_preprocessing.yaml',
                             f'--cloud kubernetes {data_bucket_flag}'),
            start_to_close_timeout=timedelta(minutes=30),
        )
        workflow.logger.info(f'Preprocessing result: {preprocess_result}')
        # 3. Down the preprocessing cluster.
        down_result = await workflow.execute_activity(
            run_sky_down,
            SkyDownCommand(cluster_name),
            start_to_close_timeout=timedelta(minutes=10),
        )
        workflow.logger.info(f'Down result: {down_result}')
        # 4. Launch training on a fresh cluster.
        cluster_name = f'{cluster_prefix}-train'
        train_result = await workflow.execute_activity(
            run_sky_launch,
            SkyLaunchCommand(cluster_name, f'{clone_path}/train.yaml',
                             f'--cloud kubernetes {data_bucket_flag}'),
            start_to_close_timeout=timedelta(minutes=60),
        )
        workflow.logger.info(f'Training result: {train_result}')
        # 5. Execute evaluation on the same (still-running) training cluster.
        eval_result = await workflow.execute_activity(
            run_sky_exec,
            SkyExecCommand(cluster_name, f'{clone_path}/eval.yaml',
                           f'{data_bucket_flag}'),
            start_to_close_timeout=timedelta(minutes=30),
        )
        workflow.logger.info(f'Evaluation result: {eval_result}')
        # 6. Down the training/eval cluster.
        down_result = await workflow.execute_activity(
            run_sky_down,
            SkyDownCommand(cluster_name),
            start_to_close_timeout=timedelta(minutes=10),
        )
        workflow.logger.info(f'Down result: {down_result}')
        # Return the combined result of all three stages.
        return f'Preprocessing: {preprocess_result}, Training: {train_result}, Evaluation: {eval_result}'
async def main():
    """Connect to Temporal, start a worker, and run the workflow once.

    Assumes a Temporal server is reachable on localhost:7233 (e.g. started
    with `temporal server start-dev`). Blocks until the workflow completes,
    then prints its combined result.
    """
    # Start client
    client = await Client.connect('localhost:7233')
    # Run a worker for the workflow; the worker serves the same task queue
    # the workflow is submitted to, so this one process handles everything.
    async with Worker(
            client,
            task_queue='skypilot-task-queue',
            workflows=[SkyPilotWorkflow],
            activities=[run_sky_launch, run_sky_down, run_sky_exec, run_git_clone
                       ],  # Register all Sky activities to the same worker
    ):
        # Execute the workflow with cluster name and config path
        result = await client.execute_workflow(
            SkyPilotWorkflow.run,
            SkyPilotWorkflowInput(
                cluster_prefix='my-workflow',  # cluster name prefix
                repo_url=
                'https://github.com/romilbhardwaj/mock_train_workflow.git',
                data_bucket_url='gs://sky-example-data'),  # repo url
            id='skypilot-workflow-id',
            task_queue='skypilot-task-queue',
        )
        print(f'SkyPilot Workflow Result: {result}')


if __name__ == '__main__':
    asyncio.run(main())