Source code for sky.jobs.client.sdk

"""SDK functions for managed jobs."""
import json
import typing
from typing import Dict, List, Optional, Union
import webbrowser

import click

from sky import sky_logging
from sky.adaptors import common as adaptors_common
from sky.client import common as client_common
from sky.client import sdk
from sky.server import common as server_common
from sky.server.requests import payloads
from sky.skylet import constants
from sky.usage import usage_lib
from sky.utils import common_utils
from sky.utils import dag_utils

if typing.TYPE_CHECKING:
    import io

    import requests

    import sky
else:
    requests = adaptors_common.LazyImport('requests')

logger = sky_logging.init_logger(__name__)


@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def launch(
    task: Union['sky.Task', 'sky.Dag'],
    name: Optional[str] = None,
    # Internal only:
    # pylint: disable=invalid-name
    _need_confirmation: bool = False,
) -> server_common.RequestId:
    """Launches a managed job.

    Please refer to sky.cli.job_launch for documentation.

    Args:
        task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a
            managed job.
        name: Name of the managed job.
        _need_confirmation: (Internal only) Whether to show a confirmation
            prompt before launching the job.

    Returns:
        The request ID of the launch request.

    Request Returns:
        job_id (Optional[int]): Job ID for the managed job
        controller_handle (Optional[ResourceHandle]): ResourceHandle of the
            controller

    Request Raises:
        ValueError: cluster does not exist. Or, the entrypoint is not a valid
            chain dag.
        sky.exceptions.NotSupportedError: the feature is not supported.
    """
    dag = dag_utils.convert_entrypoint_to_dag(task)
    sdk.validate(dag)

    if _need_confirmation:
        # Submit an optimize request and wait on it via stream_and_get before
        # asking the user whether to proceed.
        request_id = sdk.optimize(dag)
        sdk.stream_and_get(request_id)
        # The prompt is always a concrete string here, so no None check is
        # needed. With abort=True, click raises click.Abort when the user
        # declines, which stops the launch.
        click.confirm(f'Launching a managed job {dag.name!r}. Proceed?',
                      default=True,
                      abort=True,
                      show_default=True)

    # upload_mounts_to_api_server returns the dag that gets serialized below.
    dag = client_common.upload_mounts_to_api_server(dag)
    dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
    body = payloads.JobsLaunchBody(
        task=dag_str,
        name=name,
    )
    response = requests.post(
        f'{server_common.get_server_url()}/jobs/launch',
        # Round-trip through JSON so `requests` receives plain JSON data.
        json=json.loads(body.model_dump_json()),
        # (connect timeout, read timeout): 5s to connect, no read timeout.
        timeout=(5, None),
    )
    return server_common.get_request_id(response)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def queue(refresh: bool,
          skip_finished: bool = False,
          all_users: bool = False) -> server_common.RequestId:
    """Requests the statuses of managed jobs.

    Please refer to sky.cli.job_queue for documentation.

    Args:
        refresh: Whether to restart the jobs controller if it is stopped.
        skip_finished: Whether to skip finished jobs.
        all_users: Whether to show all users' jobs.

    Returns:
        The request ID of the queue request.

    Request Returns:
        job_records (List[Dict[str, Any]]): A list of dicts, with each dict
          containing the information of a job.

          .. code-block:: python

            [
              {
                'job_id': (int) job id,
                'job_name': (str) job name,
                'resources': (str) resources of the job,
                'submitted_at': (float) timestamp of submission,
                'end_at': (float) timestamp of end,
                'duration': (float) duration in seconds,
                'recovery_count': (int) Number of retries,
                'status': (sky.jobs.ManagedJobStatus) of the job,
                'cluster_resources': (str) resources of the cluster,
                'region': (str) region of the cluster,
              }
            ]

    Request Raises:
        sky.exceptions.ClusterNotUpError: the jobs controller is not up or
            does not exist.
        RuntimeError: if failed to get the managed jobs with ssh.
    """
    queue_body = payloads.JobsQueueBody(refresh=refresh,
                                        skip_finished=skip_finished,
                                        all_users=all_users)
    endpoint = f'{server_common.get_server_url()}/jobs/queue'
    # Convert the body to plain JSON data for `requests`.
    payload = json.loads(queue_body.model_dump_json())
    # 5s to connect; no read timeout.
    resp = requests.post(endpoint, json=payload, timeout=(5, None))
    return server_common.get_request_id(response=resp)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def cancel(
    name: Optional[str] = None,
    job_ids: Optional[List[int]] = None,
    all: bool = False,  # pylint: disable=redefined-builtin
    all_users: bool = False,
) -> server_common.RequestId:
    """Requests cancellation of managed jobs.

    Please refer to sky.cli.job_cancel for documentation.

    Args:
        name: Name of the managed job to cancel.
        job_ids: IDs of the managed jobs to cancel.
        all: Whether to cancel all managed jobs.
        all_users: Whether to cancel all managed jobs from all users.

    Returns:
        The request ID of the cancel request.

    Request Raises:
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
        RuntimeError: failed to cancel the job.
    """
    cancel_body = payloads.JobsCancelBody(name=name,
                                          job_ids=job_ids,
                                          all=all,
                                          all_users=all_users)
    endpoint = f'{server_common.get_server_url()}/jobs/cancel'
    # Convert the body to plain JSON data for `requests`.
    payload = json.loads(cancel_body.model_dump_json())
    # 5s to connect; no read timeout.
    resp = requests.post(endpoint, json=payload, timeout=(5, None))
    return server_common.get_request_id(response=resp)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def tail_logs(name: Optional[str] = None,
              job_id: Optional[int] = None,
              follow: bool = True,
              controller: bool = False,
              refresh: bool = False,
              output_stream: Optional['io.TextIOBase'] = None) -> int:
    """Tails logs of managed jobs.

    Either a job name or a job ID may be given to select the job. If neither
    is provided, the logs of the latest job will be shown.

    Args:
        name: Name of the managed job to tail logs.
        job_id: ID of the managed job to tail logs.
        follow: Whether to follow the logs.
        controller: Whether to tail logs from the jobs controller.
        refresh: Whether to restart the jobs controller if it is stopped.
        output_stream: The stream to write the logs to. If None, print to the
            console.

    Returns:
        Exit code based on success or failure of the job. 0 if success,
        100 if the job failed. See exceptions.JobExitCode for possible exit
        codes.

    Request Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
    """
    logs_body = payloads.JobsLogsBody(name=name,
                                      job_id=job_id,
                                      follow=follow,
                                      controller=controller,
                                      refresh=refresh)
    endpoint = f'{server_common.get_server_url()}/jobs/logs'
    # Convert the body to plain JSON data for `requests`.
    payload = json.loads(logs_body.model_dump_json())
    # stream=True: the response body is handed to sdk.stream_response below
    # rather than being read here. 5s to connect; no read timeout.
    resp = requests.post(endpoint,
                         json=payload,
                         stream=True,
                         timeout=(5, None))
    req_id = server_common.get_request_id(resp)
    return sdk.stream_response(req_id, resp, output_stream)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def download_logs(
        name: Optional[str],
        job_id: Optional[int],
        refresh: bool,
        controller: bool,
        local_dir: str = constants.SKY_LOGS_DIRECTORY) -> Dict[int, str]:
    """Syncs down logs of managed jobs.

    Please refer to sky.cli.job_logs for documentation.

    Args:
        name: Name of the managed job to sync down logs.
        job_id: ID of the managed job to sync down logs.
        refresh: Whether to restart the jobs controller if it is stopped.
        controller: Whether to sync down logs from the jobs controller.
        local_dir: Local directory to sync down logs.

    Returns:
        A dictionary mapping job ID to the local path.

    Request Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
    """
    request_body = payloads.JobsDownloadLogsBody(name=name,
                                                 job_id=job_id,
                                                 refresh=refresh,
                                                 controller=controller,
                                                 local_dir=local_dir)
    endpoint = f'{server_common.get_server_url()}/jobs/download_logs'
    # Convert the body to plain JSON data for `requests`.
    payload = json.loads(request_body.model_dump_json())
    # 5s to connect; no read timeout.
    resp = requests.post(endpoint, json=payload, timeout=(5, None))

    # The request result maps each job ID to a remote log path.
    remote_path_by_job = sdk.stream_and_get(
        server_common.get_request_id(resp))
    # Fetch those remote paths; the helper returns a remote -> local mapping.
    local_path_by_remote = client_common.download_logs_from_api_server(
        remote_path_by_job.values())

    # Compose the two mappings into job ID -> local path.
    local_path_by_job: Dict[int, str] = {}
    for job, remote_path in remote_path_by_job.items():
        local_path_by_job[job] = local_path_by_remote[remote_path]
    return local_path_by_job


# Deprecated `spot_*` aliases, each wrapping the matching managed-jobs
# function via common_utils.deprecated_function.
spot_launch = common_utils.deprecated_function(
    launch,
    name='sky.jobs.launch',
    deprecated_name='spot_launch',
    removing_version='0.8.0',
    override_argument={'use_spot': True})
spot_queue = common_utils.deprecated_function(queue,
                                              name='sky.jobs.queue',
                                              deprecated_name='spot_queue',
                                              removing_version='0.8.0')
spot_cancel = common_utils.deprecated_function(cancel,
                                               name='sky.jobs.cancel',
                                               deprecated_name='spot_cancel',
                                               removing_version='0.8.0')
spot_tail_logs = common_utils.deprecated_function(
    tail_logs,
    name='sky.jobs.tail_logs',
    deprecated_name='spot_tail_logs',
    removing_version='0.8.0')


@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def dashboard() -> None:
    """Starts a dashboard for managed jobs.

    Builds the dashboard URL on the API server, with the current user hash as
    a query parameter, logs it, and opens it with `webbrowser`.
    """
    current_user_hash = common_utils.get_user_hash()
    base_url = server_common.get_server_url()
    query = f'user_hash={current_user_hash}'
    url = f'{base_url}/jobs/dashboard?{query}'
    logger.info(f'Opening dashboard in browser: {url}')
    webbrowser.open(url)