"""SDK functions for managed jobs."""
import json
import pathlib
import threading
import typing
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import zlib
import click
from sky import sky_logging
from sky.backends import backend_utils
from sky.client import common as client_common
from sky.client import sdk
from sky.schemas.api import responses
from sky.serve.client import impl
from sky.server import common as server_common
from sky.server import constants as server_constants
from sky.server import rest
from sky.server import versions
from sky.server.requests import payloads
from sky.server.requests import request_names
from sky.skylet import constants
from sky.usage import usage_lib
from sky.utils import admin_policy_utils
from sky.utils import common_utils
from sky.utils import context
from sky.utils import dag_utils
if typing.TYPE_CHECKING:
import io
import sky
from sky import backends
from sky.serve import serve_utils
logger = sky_logging.init_logger(__name__)
@context.contextual
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def launch(
    task: Union['sky.Task', 'sky.Dag'],
    name: Optional[str] = None,
    pool: Optional[str] = None,
    num_jobs: Optional[int] = None,
    # Internal only:
    # pylint: disable=invalid-name
    _need_confirmation: bool = False,
) -> server_common.RequestId[Tuple[Optional[List[int]],
                                   Optional['backends.ResourceHandle']]]:
    """Launches a managed job.

    Please refer to sky.cli.job_launch for documentation.

    Args:
        task: sky.Task, or sky.Dag (experimental; 1-task only) to launch as a
            managed job.
        name: Name of the managed job.
        pool: Name of the pool to submit the job(s) to. Requires an API
            server with version >= 12.
        num_jobs: Number of jobs to launch. Only valid together with ``pool``.
        _need_confirmation: (Internal only) Whether to show a confirmation
            prompt before launching the job.

    Returns:
        The request ID of the launch request.

    Request Returns:
        job_ids (Optional[List[int]]): Job IDs for the managed jobs
        controller_handle (Optional[ResourceHandle]): ResourceHandle of the
            controller

    Request Raises:
        ValueError: cluster does not exist. Or, the entrypoint is not a valid
            chain dag.
        sky.exceptions.NotSupportedError: the feature is not supported.
    """
    # Fail fast on argument combinations the remote API server cannot handle.
    remote_api_version = versions.get_remote_api_version()
    if (pool is not None and
            (remote_api_version is None or remote_api_version < 12)):
        raise click.UsageError('Pools are not supported in your API server. '
                               'Please upgrade to a newer API server to use '
                               'pools.')
    if pool is None and num_jobs is not None:
        raise click.UsageError('Cannot specify num_jobs without pool.')

    dag = dag_utils.convert_entrypoint_to_dag(task)
    if name is not None:
        dag.name = name

    with admin_policy_utils.apply_and_use_config_in_current_request(
            dag,
            request_name=request_names.AdminPolicyRequestName.JOBS_LAUNCH,
            at_client_side=True) as dag:
        sdk.validate(dag)
        if _need_confirmation:
            job_identity = 'a managed job'
            if pool is None:
                # Show the optimizer's resource plan before asking the user
                # to confirm.
                optimize_request_id = sdk.optimize(dag)
                sdk.stream_and_get(optimize_request_id)
            else:
                pool_status_request_id = pool_status(pool)
                pool_statuses = sdk.get(pool_status_request_id)
                if not pool_statuses:
                    raise click.UsageError(f'Pool {pool!r} not found.')
                # Show the job's requested resources, not the pool worker
                # resources.
                job_resources_str = backend_utils.get_task_resources_str(
                    dag.tasks[0], is_managed_job=True)
                click.secho(
                    f'Use resources from pool {pool!r}: {job_resources_str}.',
                    fg='green')
                if num_jobs is not None:
                    job_identity = f'{num_jobs} managed jobs'
            # NOTE: the prompt is always a non-empty string, so the previous
            # `if prompt is not None` guard was dead code and is removed.
            click.confirm(f'Launching {job_identity} {dag.name!r}. Proceed?',
                          default=True,
                          abort=True,
                          show_default=True)

        # Inject the client's API server endpoint for tasks with
        # api_server_access. Done client-side because get_server_url()
        # returns the externally reachable endpoint here, whereas
        # the server sees 127.0.0.1.
        any_api_access = any(t.api_server_access for t in dag.tasks)
        if any_api_access:
            remote_api_version = versions.get_remote_api_version()
            if (remote_api_version is not None and remote_api_version <
                    server_constants.MIN_API_ACCESS_API_VERSION):
                logger.debug(
                    'Skipping api_server_access injection: API server '
                    'version too old (need >= %s, got %s).',
                    server_constants.MIN_API_ACCESS_API_VERSION,
                    remote_api_version)
            else:
                endpoint = server_common.get_server_url()
                if server_common.is_api_server_local(endpoint):
                    logger.debug(
                        'Skipping api_server_access injection: '
                        'API server appears to be local (%s).', endpoint)
                else:
                    for task_ in dag.tasks:
                        if task_.api_server_access:
                            task_.update_envs({
                                constants.SKY_API_SERVER_URL_ENV_VAR: endpoint,
                            })

        # Upload local file mounts so the server can access them, then ship
        # the DAG as YAML in the request body.
        dag, file_mounts_blob_id = (
            client_common.upload_mounts_to_api_server(dag))
        dag_str = dag_utils.dump_dag_to_yaml_str(dag)
        body = payloads.JobsLaunchBody(
            task=dag_str,
            name=name,
            pool=pool,
            num_jobs=num_jobs,
            file_mounts_blob_id=file_mounts_blob_id,
        )
        response = server_common.make_authenticated_request(
            'POST',
            '/jobs/launch',
            json=json.loads(body.model_dump_json()),
            timeout=(5, None))
        return server_common.get_request_id(response)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(18)
def queue_v2(
    refresh: bool,
    skip_finished: bool = False,
    all_users: bool = False,
    job_ids: Optional[List[int]] = None,
    limit: Optional[int] = None,
    fields: Optional[List[str]] = None,
    sort_by: Optional[str] = None,
    sort_order: Optional[str] = None,
) -> server_common.RequestId[Tuple[List[responses.ManagedJobRecord], int, Dict[
        str, int], int]]:
    """Gets statuses of managed jobs.

    Please refer to sky.cli.job_queue for documentation.

    Args:
        refresh: Whether to restart the jobs controller if it is stopped.
        skip_finished: Whether to skip finished jobs.
        all_users: Whether to show all users' jobs.
        job_ids: IDs of the managed jobs to show.
        limit: Number of jobs to show.
        fields: Fields to get for the managed jobs.
        sort_by: Field to sort by (e.g., 'job_id', 'name', 'submitted_at').
        sort_order: Sort direction ('asc' or 'desc').

    Returns:
        The request ID of the queue request.

    Request Returns:
        job_records (List[responses.ManagedJobRecord]): A list of dicts, with
            each dict containing the information of a job.

        .. code-block:: python

            [
                {
                    'job_id': (int) job id,
                    'job_name': (str) job name,
                    'resources': (str) resources of the job,
                    'submitted_at': (float) timestamp of submission,
                    'end_at': (float) timestamp of end,
                    'job_duration': (float) duration in seconds,
                    'recovery_count': (int) Number of retries,
                    'status': (sky.jobs.ManagedJobStatus) of the job,
                    'cluster_resources': (str) resources of the cluster,
                    'region': (str) region of the cluster,
                    'task_id': (int), set to 0 (except in pipelines, which may have multiple tasks), # pylint: disable=line-too-long
                    'task_name': (str), same as job_name (except in pipelines, which may have multiple tasks), # pylint: disable=line-too-long
                    'internal_external_ips': (List[Tuple[str, str]]) List of (internal_ip, external_ip) tuples for all nodes, # pylint: disable=line-too-long
                    'internal_services': (Dict[str, str]) K8s DNS entries, which maps Pod name to internal service (only for K8s), # pylint: disable=line-too-long
                }
            ]

        total (int): Total number of jobs after filter,
        status_counts (Dict[str, int]): Status counts after filter,
        total_no_filter (int): Total number of jobs before filter,

    Request Raises:
        sky.exceptions.ClusterNotUpError: the jobs controller is not up or
            does not exist.
        RuntimeError: if failed to get the managed jobs with ssh.
    """
    # Filter out fields not supported by older servers.
    # Maps minimum API version -> fields introduced in that version.
    version_to_fields = {
        31: {'is_primary_in_job_group'},
        49: {'batch_total_batches', 'batch_completed_batches'},
    }
    if fields is not None:
        remote_api_version = versions.get_remote_api_version()
        for min_version, new_fields in version_to_fields.items():
            if remote_api_version is None or remote_api_version < min_version:
                fields = [f for f in fields if f not in new_fields]
    body = payloads.JobsQueueV2Body(
        refresh=refresh,
        skip_finished=skip_finished,
        all_users=all_users,
        job_ids=job_ids,
        limit=limit,
        fields=fields,
        sort_by=sort_by,
        sort_order=sort_order,
    )
    path = '/jobs/queue/v2'
    response = server_common.make_authenticated_request(
        'POST',
        path,
        json=json.loads(body.model_dump_json()),
        timeout=(5, None))
    return server_common.get_request_id(response=response)
# Deprecated. Please use queue_v2 instead for better performance.
# In https://github.com/skypilot-org/skypilot/pull/7695, the `queue` function
# is updated to return new typed data for performance improvement if the API
# server supports it, which breaks the backward compatibility.
# In https://github.com/skypilot-org/skypilot/pull/8015, we revert the change
# and add a new function `queue_v2` to return the new typed data.
# TODO(lloyd): Remove version=1 support before 0.13.
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def queue(
    refresh: bool,
    skip_finished: bool = False,
    all_users: bool = False,
    job_ids: Optional[List[int]] = None,
    version: int = 1,
) -> server_common.RequestId[Union[List[responses.ManagedJobRecord], Tuple[
        List[responses.ManagedJobRecord], int, Dict[str, int], int]]]:
    """Gets statuses of managed jobs.

    Deprecated. Please use queue_v2 instead for better performance.

    Please refer to sky.cli.job_queue for documentation.

    Args:
        refresh: Whether to restart the jobs controller if it is stopped.
        skip_finished: Whether to skip finished jobs.
        all_users: Whether to show all users' jobs.
        job_ids: IDs of the managed jobs to show.
        version: Queue API version to use. Must be 1 or 2.

    Returns:
        The request ID of the queue request.

    Request Returns:
        job_records (List[responses.ManagedJobRecord]): A list of dicts, with
            each dict containing the information of a job.

        .. code-block:: python

            [
                {
                    'job_id': (int) job id,
                    'job_name': (str) job name,
                    'resources': (str) resources of the job,
                    'submitted_at': (float) timestamp of submission,
                    'end_at': (float) timestamp of end,
                    'job_duration': (float) duration in seconds,
                    'recovery_count': (int) Number of retries,
                    'status': (sky.jobs.ManagedJobStatus) of the job,
                    'cluster_resources': (str) resources of the cluster,
                    'region': (str) region of the cluster,
                    'task_id': (int), set to 0 (except in pipelines, which may have multiple tasks), # pylint: disable=line-too-long
                    'task_name': (str), same as job_name (except in pipelines, which may have multiple tasks), # pylint: disable=line-too-long
                    'internal_external_ips': (List[Tuple[str, str]]) List of (internal_ip, external_ip) tuples for all nodes, # pylint: disable=line-too-long
                    'internal_services': (Dict[str, str]) K8s DNS entries, which maps Pod name to internal service (only for K8s), # pylint: disable=line-too-long
                }
            ]

    Request Raises:
        sky.exceptions.ClusterNotUpError: the jobs controller is not up or
            does not exist.
        RuntimeError: if failed to get the managed jobs with ssh.
    """
    if version not in (1, 2):
        raise ValueError(f'Invalid queue version: {version}. Must be 1 or 2.')
    if version == 2:
        # Delegate to the typed v2 endpoint; it returns the richer tuple.
        return queue_v2(refresh=refresh,
                        skip_finished=skip_finished,
                        all_users=all_users,
                        job_ids=job_ids)
    logger.warning('sky.jobs.queue(version=1) is deprecated and will be '
                   'removed in v0.13. Use sky.jobs.queue(version=2) or '
                   'sky.jobs.queue_v2() instead.')
    body = payloads.JobsQueueBody(
        refresh=refresh,
        skip_finished=skip_finished,
        all_users=all_users,
        job_ids=job_ids,
    )
    response = server_common.make_authenticated_request(
        'POST',
        '/jobs/queue',
        json=json.loads(body.model_dump_json()),
        timeout=(5, None))
    return server_common.get_request_id(response=response)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def cancel(
    name: Optional[str] = None,
    job_ids: Optional[Sequence[int]] = None,
    all: bool = False,  # pylint: disable=redefined-builtin
    all_users: bool = False,
    pool: Optional[str] = None,
    graceful: bool = False,
    graceful_timeout: Optional[int] = None,
) -> server_common.RequestId[None]:
    """Cancels managed jobs.

    Please refer to sky.cli.job_cancel for documentation.

    Args:
        name: Name of the managed job to cancel.
        job_ids: IDs of the managed jobs to cancel.
        all: Whether to cancel all managed jobs.
        all_users: Whether to cancel all managed jobs from all users.
        pool: Pool name to cancel.
        graceful: Cancel the user's task but block until MOUNT_CACHED data is
            fully uploaded. This helps with preserving user data integrity.
        graceful_timeout: If not None, sets a timeout for the graceful option
            above (in seconds).

    Returns:
        The request ID of the cancel request.

    Request Raises:
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
        RuntimeError: failed to cancel the job.
    """
    # Warn (or error) up front when the remote server is too old for the
    # requested options, instead of failing server-side.
    remote_api_version = versions.get_remote_api_version()
    if (pool is not None and
            (remote_api_version is None or remote_api_version < 12)):
        raise click.UsageError('Pools are not supported in your API server. '
                               'Please upgrade to a newer API server to use '
                               'pools.')
    if graceful and (remote_api_version is None or remote_api_version < 39):
        logger.warning('`--graceful` is ignored because the server does '
                       'not support it yet.')
    if graceful and pool is not None:
        logger.warning('Pools are not cleaned up after job cancel, so '
                       '`--graceful` is ignored.')
    body = payloads.JobsCancelBody(
        name=name,
        job_ids=job_ids,
        all=all,
        all_users=all_users,
        pool=pool,
        graceful=graceful,
        graceful_timeout=graceful_timeout,
    )
    response = server_common.make_authenticated_request(
        'POST',
        '/jobs/cancel',
        json=json.loads(body.model_dump_json()),
        timeout=(5, None))
    return server_common.get_request_id(response=response)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@rest.retry_transient_errors()
def tail_logs(name: Optional[str] = None,
              job_id: Optional[int] = None,
              follow: bool = True,
              controller: bool = False,
              refresh: bool = False,
              tail: Optional[int] = None,
              tail_offset: Optional[int] = None,
              output_stream: Optional['io.TextIOBase'] = None,
              task: Optional[Union[str, int]] = None) -> Optional[int]:
    """Tails logs of managed jobs.

    You can provide either a job name or a job ID to tail logs. If both are
    not provided, the logs of the latest job will be shown.

    Args:
        name: Name of the managed job to tail logs.
        job_id: ID of the managed job to tail logs.
        follow: Whether to follow the logs.
        controller: Whether to tail logs from the jobs controller.
        refresh: Whether to restart the jobs controller if it is stopped.
        tail: Number of lines to tail from the end of the log file.
        tail_offset: Number of lines to skip from the end of the log file.
            Must be None or a non-negative integer.
        output_stream: The stream to write the logs to. If None, print to the
            console.
        task: Task identifier to view logs for a specific task in a JobGroup.
            If an int, it is treated as a task ID. If a str, it is treated as
            a task name. If None, logs for all tasks are shown.

    Returns:
        Exit code based on success or failure of the job. 0 if success,
        100 if the job failed. See exceptions.JobExitCode for possible exit
        codes.
        Will return None if follow is False
        (see note in sky/client/sdk.py::stream_response)

    Request Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
    """
    # Validate client-side so bad values never reach the server.
    if tail is not None and tail <= 0:
        raise ValueError(
            f'tail must be None or a positive integer, got {tail}.')
    if tail_offset is not None and tail_offset < 0:
        raise ValueError(f'tail_offset must be None or a non-negative integer, '
                         f'got {tail_offset}.')
    body = payloads.JobsLogsBody(
        name=name,
        job_id=job_id,
        follow=follow,
        controller=controller,
        refresh=refresh,
        tail=tail,
        tail_offset=tail_offset,
        task=task,
    )
    response = server_common.make_authenticated_request(
        'POST',
        '/jobs/logs',
        json=json.loads(body.model_dump_json()),
        stream=True,
        timeout=(5, None))
    request_id: server_common.RequestId[int] = server_common.get_request_id(
        response)
    # Log request is idempotent when tail is None or 0 (both stream from
    # the beginning), thus can resume previous streaming point on retry.
    return sdk.stream_response(request_id=request_id,
                               response=response,
                               output_stream=output_stream,
                               resumable=(tail is None or tail == 0),
                               get_result=follow)
@context.contextual
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(45)
def wait(
    name: Optional[str] = None,
    job_id: Optional[int] = None,
    timeout: Optional[int] = None,
    poll_interval: int = 15,
    task: Optional[Union[str, int]] = None,
) -> server_common.RequestId[int]:
    """Waits for a managed job to reach a terminal state.

    Blocks until the specified managed job finishes (succeeds, fails, or is
    cancelled), or until the timeout is exceeded.

    You can provide either a job name or a job ID. If a name is provided and
    multiple jobs share that name, the most recent one is used.

    For JobGroups (jobs with multiple tasks), if ``task`` is specified, waits
    only for that specific task. Otherwise, waits until all tasks in the job
    are in a terminal state.

    Args:
        name: Name of the managed job to wait for.
        job_id: ID of the managed job to wait for.
        timeout: Maximum time to wait in seconds. None means wait forever.
        poll_interval: Time between status polls in seconds. Minimum 5,
            default 15.
        task: Task identifier to wait for a specific task in a JobGroup.
            If an int, it is treated as a task ID. If a str, it is treated
            as a task name. If None, waits for all tasks.

    Returns:
        The request ID of the wait request. The result is an exit code (int)
        based on the terminal job status. 0 if success, 100 if failed.
        See exceptions.JobExitCode for possible exit codes.

    Request Raises:
        ValueError: if arguments are invalid or job/task is not found.
        TimeoutError: if the timeout is exceeded before the job finishes.
    """
    # Serialize the arguments into the request payload and dispatch it to
    # the API server; the server performs the actual polling.
    wait_body = payloads.JobsWaitBody(name=name,
                                      job_id=job_id,
                                      timeout=timeout,
                                      poll_interval=poll_interval,
                                      task=task)
    resp = server_common.make_authenticated_request(
        'POST',
        '/jobs/wait',
        json=json.loads(wait_body.model_dump_json()),
        timeout=(5, None))
    return server_common.get_request_id(response=resp)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def download_logs_streaming(
    name: Optional[str],
    job_id: Optional[int],
    refresh: bool,
    controller: bool,
    local_dir: str = constants.SKY_LOGS_DIRECTORY,
) -> Optional[Dict[int, str]]:
    """Download a managed job's log via the streaming /api/stream path.

    Returns None when the server stream is empty (e.g. terminal job
    whose worker cluster is gone) — the caller should fall back to
    ``download_logs``.

    This dispatches the same /jobs/logs (tail=None, follow=False) path
    that the live-tail UI uses, then attaches to /api/stream with
    compress=gz so gzip framing saves bandwidth on the wire. The
    response is decompressed on the client and saved as a plain log
    file inside a per-job directory; the directory shape matches the
    legacy ``download_logs`` output (``<dir>/controller.log`` for
    ``--controller``, ``<dir>/run.log`` otherwise) so callers that
    walk the returned path with ``[ -d ]`` / ``cat <dir>/foo.log``
    keep working.

    Returns:
        ``{job_id: local_directory}``. The directory contains
        ``controller.log`` (controller mode) or ``run.log``
        (non-controller).
    """
    body = payloads.JobsLogsBody(
        name=name,
        job_id=job_id,
        follow=False,
        controller=controller,
        refresh=refresh,
        tail=None,
    )
    dispatch = server_common.make_authenticated_request(
        'POST',
        '/jobs/logs',
        json=json.loads(body.model_dump_json()),
        stream=True,
        timeout=(5, None))
    if not dispatch.ok:
        raise RuntimeError(
            f'Failed to dispatch /jobs/logs: HTTP {dispatch.status_code}')
    request_id = dispatch.headers.get(server_constants.STREAM_REQUEST_HEADER) \
        or dispatch.headers.get('X-SkyPilot-Request-ID')
    if not request_id:
        raise RuntimeError(
            '/jobs/logs response missing X-SkyPilot-Request-ID header')

    # Drain the dispatch body in a background thread. Cancelling/closing
    # would tell the API server the client disconnected and the running
    # tail_logs task would be cancelled, leaving /api/stream with only
    # a partial log. Reading and discarding keeps the request alive.
    def _drain() -> None:
        try:
            for _ in dispatch.iter_content(chunk_size=64 * 1024):
                pass
        except Exception:  # pylint: disable=broad-except
            # Best-effort: the drain only exists to keep the request alive.
            pass

    threading.Thread(target=_drain, daemon=True).start()
    stream_url = (f'/api/stream?request_id={request_id}'
                  '&format=plain&compress=gz')
    stream_resp = server_common.make_authenticated_request('GET',
                                                           stream_url,
                                                           stream=True,
                                                           timeout=(5, None))
    # BUGFIX: close the stream response even when an exception is raised
    # mid-stream; previously a failed iteration leaked the connection.
    try:
        if not stream_resp.ok:
            raise RuntimeError(
                f'Failed to attach to /api/stream: '
                f'HTTP {stream_resp.status_code}')
        # Save into a per-job directory matching the legacy download_logs
        # shape (<dir>/controller.log or <dir>/run.log) so existing scripts
        # that grep <path>/controller.log keep working. Decompress on the
        # client when the server gzipped the stream — older API servers
        # without compress=gz support silently ignore the query param and
        # return text/plain, so sniff Content-Type and skip decompression
        # in that case.
        content_type = (stream_resp.headers.get('Content-Type') or '').lower()
        is_gzipped = content_type.startswith('application/gzip')
        # wbits=16+MAX_WBITS tells zlib to expect a gzip (not raw zlib)
        # header/trailer.
        decompressor = (zlib.decompressobj(16 + zlib.MAX_WBITS)
                        if is_gzipped else None)
        log_type = 'controller' if controller else 'job'
        log_filename = 'controller.log' if controller else 'run.log'
        job_label = job_id if job_id is not None else (name or 'latest')
        job_dir = (pathlib.Path(local_dir).expanduser() / 'managed_jobs' /
                   f'managed-{log_type}-{job_label}')
        job_dir.mkdir(parents=True, exist_ok=True)
        local_path = job_dir / log_filename
        bytes_written = 0
        with open(local_path, 'wb') as f:
            for chunk in stream_resp.iter_content(chunk_size=64 * 1024):
                if not chunk:
                    continue
                out = decompressor.decompress(chunk) if decompressor else chunk
                if out:
                    f.write(out)
                    bytes_written += len(out)
            if decompressor is not None:
                # Flush any bytes buffered inside the decompressor.
                tail_bytes = decompressor.flush()
                if tail_bytes:
                    f.write(tail_bytes)
                    bytes_written += len(tail_bytes)
    finally:
        stream_resp.close()
    if bytes_written == 0:
        # Server sent nothing (e.g., terminal job, worker cluster gone) —
        # the underlying tail_logs has no source. Remove the empty file
        # + dir and return None so the caller falls back to sync-down.
        try:
            local_path.unlink()
            job_dir.rmdir()
        except OSError:
            pass
        return None
    key = int(job_id) if job_id is not None else 0
    return {key: str(job_dir)}
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def download_logs(
        name: Optional[str],
        job_id: Optional[int],
        refresh: bool,
        controller: bool,
        local_dir: str = constants.SKY_LOGS_DIRECTORY) -> Dict[int, str]:
    """Sync down logs of managed jobs.

    Please refer to sky.cli.job_logs for documentation.

    Args:
        name: Name of the managed job to sync down logs.
        job_id: ID of the managed job to sync down logs.
        refresh: Whether to restart the jobs controller if it is stopped.
        controller: Whether to sync down logs from the jobs controller.
        local_dir: Local directory to sync down logs.

    Returns:
        A dictionary mapping job ID to the local path.

    Request Raises:
        ValueError: invalid arguments.
        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
    """
    # Ask the server to stage the logs, then pull them down locally.
    download_body = payloads.JobsDownloadLogsBody(
        name=name,
        job_id=job_id,
        refresh=refresh,
        controller=controller,
        local_dir=local_dir,
    )
    resp = server_common.make_authenticated_request(
        'POST',
        '/jobs/download_logs',
        json=json.loads(download_body.model_dump_json()),
        timeout=(5, None))
    req_id: server_common.RequestId[Dict[
        str, str]] = server_common.get_request_id(resp)
    # Maps job id -> path of the staged logs on the API server.
    remote_paths_by_job = sdk.stream_and_get(req_id)
    # Maps remote (server-side) path -> downloaded local path.
    local_paths_by_remote = client_common.download_logs_from_api_server(
        remote_paths_by_job.values())
    result: Dict[int, str] = {}
    for jid, remote_path in remote_paths_by_job.items():
        result[int(jid)] = local_paths_by_remote[remote_path]
    return result
# Backwards-compatible aliases for the legacy `sky.spot` entrypoints. Each
# wrapper forwards to the `sky.jobs` equivalent and emits a deprecation
# warning naming the replacement.
spot_launch = common_utils.deprecated_function(
    launch,
    name='sky.jobs.launch',
    deprecated_name='spot_launch',
    removing_version='0.8.0',
    override_argument={'use_spot': True})
spot_queue = common_utils.deprecated_function(
    queue,
    name='sky.jobs.queue',
    deprecated_name='spot_queue',
    removing_version='0.8.0')
spot_cancel = common_utils.deprecated_function(
    cancel,
    name='sky.jobs.cancel',
    deprecated_name='spot_cancel',
    removing_version='0.8.0')
spot_tail_logs = common_utils.deprecated_function(
    tail_logs,
    name='sky.jobs.tail_logs',
    deprecated_name='spot_tail_logs',
    removing_version='0.8.0')
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def dashboard() -> None:
    """Starts a dashboard for managed jobs."""
    # Build the dashboard URL against the configured API server, carrying
    # the current user's hash as a query parameter, and open it in the
    # default browser.
    query = f'user_hash={common_utils.get_user_hash()}'
    url = f'{server_common.get_server_url()}/jobs/dashboard?{query}'
    logger.info(f'Opening dashboard in browser: {url}')
    common_utils.open_browser(url)
@context.contextual
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(12)
def pool_apply(
    task: Optional[Union['sky.Task', 'sky.Dag']],
    pool_name: str,
    mode: 'serve_utils.UpdateMode',
    workers: Optional[int] = None,
    # Internal only:
    # pylint: disable=invalid-name
    _need_confirmation: bool = False
) -> server_common.RequestId[None]:
    """Apply a config to a pool."""
    # Updating the worker count needs server-side support (API >= 19);
    # reject early with a clear message rather than failing remotely.
    server_version = versions.get_remote_api_version()
    workers_supported = (server_version is not None and server_version >= 19)
    if workers is not None and not workers_supported:
        raise click.UsageError('Updating the number of workers in a pool is '
                               'not supported in your API server. Please '
                               'upgrade to a newer API server to use this '
                               'feature.')
    # Shared implementation with `sky serve`; `pool=True` selects pool
    # semantics.
    return impl.apply(task,
                      workers,
                      pool_name,
                      mode,
                      pool=True,
                      _need_confirmation=_need_confirmation)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(12)
def pool_down(
    pool_names: Optional[Union[str, List[str]]],
    all: bool = False,  # pylint: disable=redefined-builtin
    purge: bool = False,
) -> server_common.RequestId[None]:
    """Delete a pool."""
    # Thin wrapper over the shared serve implementation; `pool=True`
    # selects pool semantics.
    return impl.down(pool_names, all, purge, pool=True)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(12)
def pool_status(
    pool_names: Optional[Union[str, List[str]]],
) -> server_common.RequestId[List[Dict[str, Any]]]:
    """Query a pool."""
    # Thin wrapper over the shared serve implementation; `pool=True`
    # selects pool semantics.
    return impl.status(pool_names, pool=True)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@rest.retry_transient_errors()
@versions.minimal_api_version(16)
def pool_tail_logs(pool_name: str,
                   target: Union[str, 'serve_utils.ServiceComponent'],
                   worker_id: Optional[int] = None,
                   follow: bool = True,
                   output_stream: Optional['io.TextIOBase'] = None,
                   tail: Optional[int] = None) -> None:
    """Tails logs of a pool."""
    # Forward to the shared serve implementation with pool semantics.
    return impl.tail_logs(pool_name, target, worker_id, follow, output_stream,
                          tail, pool=True)
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@rest.retry_transient_errors()
@versions.minimal_api_version(16)
def pool_sync_down_logs(pool_name: str,
                        local_dir: str,
                        *,
                        targets: Optional[Union[
                            str, 'serve_utils.ServiceComponent', Sequence[Union[
                                str, 'serve_utils.ServiceComponent']]]] = None,
                        worker_ids: Optional[List[int]] = None,
                        tail: Optional[int] = None) -> None:
    """Sync down logs of a pool."""
    # Forward to the shared serve implementation; note pool "workers" map
    # onto serve "replicas" in the underlying call.
    return impl.sync_down_logs(pool_name,
                               local_dir,
                               targets=targets,
                               replica_ids=worker_ids,
                               tail=tail,
                               pool=True)