Admin Policy Enforcement#

SkyPilot provides an admin policy mechanism that admins can use to enforce certain policies on users’ SkyPilot usage. An admin policy applies custom validation and mutation logic to a user’s tasks and SkyPilot config.

Overview#

SkyPilot has a client-server architecture, where a centralized API server can be deployed and users can interact with the server through a client.

To deploy an admin policy, follow these steps:

  1. Implement the policy in a Python package or host the policy as a RESTful server.

  2. Install the policy at the server-side to enforce it for all users.

  3. Optionally, install the admin policy at the client-side as well if it needs to access the user’s local environment.

The order of policy application is shown in the following figure:

[Figure: Admin policy application order]

Note

A client-side policy cannot be enforced, because end users may modify or remove it. However, a client-side policy is useful for automation that can only run on the client. See Use local GCP credentials for all tasks for an example.
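As a rough mental model, the application order can be sketched with plain Python stand-ins. The sketch assumes the client policy runs before the server policy, as the Quickstart output suggests. The functions and the request dictionary below are hypothetical and are not SkyPilot APIs.

```python
from typing import Any, Callable, Dict, List

Request = Dict[str, Any]
Policy = Callable[[Request], Request]


def client_policy(request: Request) -> Request:
    # Hypothetical client-side step: runs on the user's machine, so it can
    # read the local environment, but the user can modify or remove it.
    return {**request, 'applied_by': request['applied_by'] + ['client']}


def server_policy(request: Request) -> Request:
    # Hypothetical server-side step: runs on the API server, where the
    # admin controls which policy is installed.
    return {**request, 'applied_by': request['applied_by'] + ['server']}


def submit(request: Request, policies: List[Policy]) -> Request:
    # Each policy receives the output of the previous one.
    for policy in policies:
        request = policy(request)
    return request


result = submit({'task': 'sky-cmd', 'applied_by': []},
                [client_policy, server_policy])
print(result['applied_by'])
```

Because the server-side step runs last, it can verify or override whatever the client-side step produced.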

Quickstart#

Install the example policy package:

git clone https://github.com/skypilot-org/skypilot.git
cd skypilot
pip install examples/admin_policy/example_policy

Then, set the admin_policy field in the SkyPilot config to use an example policy.

admin_policy: example_policy.DoNothingPolicy

Tip

You can replace DoNothingPolicy with any of the example policies.

Then, run a task:

$ sky launch

You should see that the admin policy is applied to the task:

$ sky launch
Applying client admin policy: DoNothingPolicy
Applying server admin policy: DoNothingPolicy
...

You can also deploy an admin policy as a RESTful server.

Deploy an admin policy#

Server-side#

If you have a centralized API server deployed, you can enforce a policy for all users by setting it at the server-side.

To use a RESTful policy, open the SkyPilot dashboard at https://api.server.com/dashboard/config and set the admin_policy field to the URL of the RESTful policy. To host a RESTful policy, see Host admin policy as a RESTful server.

admin_policy: https://example.com/policy

Alternatively, to use a policy implemented as a Python package, first install the package on the API server host:

pip install mypackage.subpackage

For helm deployment, refer to Setting an admin policy to install the policy package.

Then, open the server’s dashboard, go to the server’s SkyPilot config and set the admin_policy field to the path of the Python package that implements the policy.

admin_policy: mypackage.subpackage.MyPolicy

Client-side#

If the policy needs to access the user’s local environment, you can apply it at the client-side by following the steps below.

First, install the Python package that implements the policy:

pip install mypackage.subpackage

Then, set the admin_policy field in the SkyPilot config to the path of the Python package that implements the policy.

admin_policy: mypackage.subpackage.MyPolicy

Optionally, you can also apply different policies in different projects by leveraging the layered config, e.g. set a different policy in $pwd/.sky.yaml for the current project:

admin_policy: mypackage.subpackage.AnotherPolicy

Hint

SkyPilot loads the policy from the given package in the same Python environment. You can test the existence of the policy by running:

python -c "from mypackage.subpackage import MyPolicy"
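The same check can be done programmatically. The sketch below resolves a dotted path such as the value of admin_policy into a class with importlib. It assumes the last path component is the class name and everything before it is the module path. The helper name load_class is hypothetical, and SkyPilot's own loader may differ in its details.

```python
import importlib
from typing import Any


def load_class(dotted_path: str) -> Any:
    """Resolves 'package.module.ClassName' to the class object."""
    module_path, _, class_name = dotted_path.rpartition('.')
    if not module_path:
        raise ValueError(f'Expected <module>.<class>, got {dotted_path!r}')
    # Raises ImportError if the package is not installed in this environment.
    module = importlib.import_module(module_path)
    # Raises AttributeError if the module does not define the class.
    return getattr(module, class_name)


# A standard-library class is used here so the example runs anywhere;
# replace it with your own path, e.g. 'mypackage.subpackage.MyPolicy'.
policy_cls = load_class('collections.OrderedDict')
print(policy_cls.__name__)
```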

Note

It is possible to call a RESTful policy from the client-side. However, a RESTful policy is executed on the policy server host, so it cannot access the user’s local environment (e.g., local files).

Host admin policy as a RESTful server#

You can host an admin policy as a RESTful API server and configure SkyPilot to call its URL to apply the policy.

We recommend deriving your implementation from the AdminPolicy interface to ensure that the request and response bodies are correctly typed. You can also import existing policies on the server and compose them to fit your needs. Here is an example of a policy server implemented with Python and FastAPI:

Example Policy Server
#!/usr/bin/env python3
"""Example RESTful admin policy server for SkyPilot."""

import argparse
from typing import List

import example_policy
from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import JSONResponse
import uvicorn

import sky

app = FastAPI(title="Example Admin Policy Server", version="1.0.0")


class DoNothingPolicy(sky.AdminPolicy):
    """Example policy: do nothing."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Returns the user request unchanged."""
        return sky.MutatedUserRequest(user_request.task,
                                      user_request.skypilot_config)


@app.post('/')
async def apply_policy(request: Request) -> JSONResponse:
    """Apply an admin policy loaded from external package to a user request"""
    # Decode
    json_data = await request.json()
    user_request = sky.UserRequest.decode(json_data)
    # Example: change the following list to apply different policies.
    policies: List[sky.AdminPolicy] = [
        # Example: policy implemented in the server package.
        DoNothingPolicy,
        # Example: policy from third party packages.
        example_policy.UseSpotForGpuPolicy,
    ]
    try:
        for policy in policies:
            mutated_request = policy.validate_and_mutate(user_request)
            user_request.task = mutated_request.task
            user_request.skypilot_config = mutated_request.skypilot_config
    except Exception as e:  # pylint: disable=broad-except
        return JSONResponse(content=str(e), status_code=400)

    return JSONResponse(content=mutated_request.encode())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--host',
                        default='0.0.0.0',
                        help='Host to bind to (default: 0.0.0.0)')
    parser.add_argument('--port',
                        type=int,
                        default=8080,
                        help='Port to bind to (default: 8080)')
    args = parser.parse_args()
    uvicorn.run(app,
                workers=1,
                host=args.host,
                port=args.port,
                log_level="info")

Optionally, the server can also be implemented in other languages as long as it follows the API convention:

The Admin Policy API

POST /<api-path>

Request body is a marshalled sky.UserRequest in JSON format:

{
  "task": {
    "name": "sky-cmd",
    "resources": {
      "cpus": "1+"
    },
    "num_nodes": 1
  },
  "skypilot_config": {},
  "request_options": {
    "cluster_name": "test",
    "idle_minutes_to_autostop": null,
    "down": false,
    "dryrun": false
  },
  "at_client_side": false
}

Response body is a marshalled sky.MutatedUserRequest in JSON format:

{
  "task": {
    "name": "sky-cmd",
    "resources": {
      "cpus": "1+"
    },
    "num_nodes": 1
  },
  "skypilot_config": {}
}
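To make the request-to-response mapping concrete, here is a minimal sketch of the handler logic using only the standard library. It follows the JSON shapes shown above; the exact wire encoding produced by SkyPilot may differ, so check it against sky.UserRequest.encode() for your version. The function name apply_policy, the node limit, and the spot mutation are hypothetical illustrations, not part of the API.

```python
import json
from typing import Any, Dict

MAX_NODES = 4  # Hypothetical limit, for illustration only.


def apply_policy(request_body: str) -> str:
    """Maps a request body to a response body with the shapes shown above."""
    request: Dict[str, Any] = json.loads(request_body)
    task = dict(request['task'])
    config = dict(request.get('skypilot_config') or {})

    # Validation: raise to reject the request.
    if task.get('num_nodes', 1) > MAX_NODES:
        raise ValueError(f'At most {MAX_NODES} nodes are allowed per task.')

    # Mutation: force spot instances on the task's resources.
    resources = dict(task.get('resources') or {})
    resources['use_spot'] = True
    task['resources'] = resources

    # The response carries only the (possibly mutated) task and config.
    return json.dumps({'task': task, 'skypilot_config': config})


example_request = json.dumps({
    'task': {
        'name': 'sky-cmd',
        'resources': {'cpus': '1+'},
        'num_nodes': 1,
    },
    'skypilot_config': {},
    'request_options': {
        'cluster_name': 'test',
        'idle_minutes_to_autostop': None,
        'down': False,
        'dryrun': False,
    },
    'at_client_side': False,
})
response = json.loads(apply_policy(example_request))
print(response['task']['resources'])
```

A real server would wrap this function in an HTTP handler and translate the exception into an error status, as the FastAPI example does with status code 400.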

Implement an admin policy#

Admin policies are implemented by extending the sky.AdminPolicy interface:

AdminPolicy Interface
class AdminPolicy(PolicyInterface):
    """Abstract interface of an admin-defined policy for all user requests.

    Admins can implement a subclass of AdminPolicy with the following signature:

        import sky

        class SkyPilotPolicyV1(sky.AdminPolicy):
            def validate_and_mutate(user_request: UserRequest) -> MutatedUserRequest:
                ...
                return MutatedUserRequest(task=..., skypilot_config=...)

    The policy can mutate both task and skypilot_config. Admins then distribute
    a simple module that contains this implementation, installable in a way
    that it can be imported by users from the same Python environment where
    SkyPilot is running.

    Users can register a subclass of AdminPolicy in the SkyPilot config file
    under the key 'admin_policy', e.g.

        admin_policy: my_package.SkyPilotPolicyV1
    """

    @classmethod
    @abc.abstractmethod
    def validate_and_mutate(cls,
                            user_request: UserRequest) -> MutatedUserRequest:
        """Validates and mutates the user request and returns mutated request.

        Args:
            user_request: The user request to validate and mutate.
                UserRequest contains (sky.Task, sky.Config)

        Returns:
            MutatedUserRequest: The mutated user request.

        Raises:
            Exception to throw if the user request failed the validation.
        """
        raise NotImplementedError(
            'Your policy must implement validate_and_mutate')

    def apply(self, user_request: UserRequest) -> MutatedUserRequest:
        return self.validate_and_mutate(user_request)

Your custom admin policy should look like this:

import sky

class MyPolicy(sky.AdminPolicy):
    @classmethod
    def validate_and_mutate(cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        # Logic to validate and modify the user request.
        ...
        return sky.MutatedUserRequest(user_request.task,
                                      user_request.skypilot_config)

UserRequest and MutatedUserRequest are defined as follows (see source code for more details):

UserRequest Class
@dataclasses.dataclass
class UserRequest:
    """A user request.

    A "user request" is defined as a `sky launch / exec` command or its API
    equivalent.

    `sky jobs launch / serve up` involves multiple launch requests, including
    the launch of controller and clusters for a job (which can have multiple
    tasks if it is a pipeline) or service replicas. Each launch is a separate
    request.

    This class wraps the underlying task, the global skypilot config used to run
    a task, and the request options.

    Args:
        task: User specified task.
        skypilot_config: Global skypilot config to be used in this request.
        request_options: Request options. It is None for jobs and services.
        at_client_side: Is the request intercepted by the policy at client-side?
        user: User who made the request.
              Only available on the server side.
              This value is None if at_client_side is True.
    """
    task: 'sky.Task'
    skypilot_config: 'sky.Config'
    request_name: request_names.AdminPolicyRequestName
    request_options: Optional['RequestOptions'] = None
    at_client_side: bool = False
    user: Optional['models.User'] = None

    def encode(self) -> str:
        return _UserRequestBody(
            task=yaml_utils.dump_yaml_str(self.task.to_yaml_config()),
            skypilot_config=yaml_utils.dump_yaml_str(dict(
                self.skypilot_config)),
            request_name=self.request_name.value,
            request_options=self.request_options,
            at_client_side=self.at_client_side,
            user=(yaml_utils.dump_yaml_str(self.user.to_dict())
                  if self.user is not None else ''),
        ).model_dump_json()

    @classmethod
    def decode(cls, body: str) -> 'UserRequest':
        user_request_body = _UserRequestBody.model_validate_json(body)
        user_dict = yaml_utils.read_yaml_str(
            user_request_body.user) if user_request_body.user != '' else None
        user = models.User(
            id=user_dict['id'],
            name=user_dict['name']) if user_dict is not None else None
        return cls(
            task=sky.Task.from_yaml_config(
                yaml_utils.read_yaml_all_str(user_request_body.task)[0]),
            skypilot_config=config_utils.Config.from_dict(
                yaml_utils.read_yaml_all_str(
                    user_request_body.skypilot_config)[0]),
            request_name=request_names.AdminPolicyRequestName(
                user_request_body.request_name),
            request_options=user_request_body.request_options,
            at_client_side=user_request_body.at_client_side,
            user=user,
        )

MutatedUserRequest Class
@dataclasses.dataclass
class MutatedUserRequest:
    """Mutated user request."""

    task: 'sky.Task'
    skypilot_config: 'sky.Config'

    def encode(self) -> str:
        return _MutatedUserRequestBody(
            task=yaml_utils.dump_yaml_str(self.task.to_yaml_config()),
            skypilot_config=yaml_utils.dump_yaml_str(dict(
                self.skypilot_config),)).model_dump_json()

    @classmethod
    def decode(cls, mutated_user_request_body: str,
               original_request: UserRequest) -> 'MutatedUserRequest':
        mutated_user_request_body = _MutatedUserRequestBody.model_validate_json(
            mutated_user_request_body)
        task = sky.Task.from_yaml_config(
            yaml_utils.read_yaml_all_str(mutated_user_request_body.task)[0])
        # Some internal Task fields are not serialized. We need to manually
        # restore them from the original request.
        task.managed_job_dag = original_request.task.managed_job_dag
        task.service_name = original_request.task.service_name
        return cls(task=task,
                   skypilot_config=config_utils.Config.from_dict(
                       yaml_utils.read_yaml_all_str(
                           mutated_user_request_body.skypilot_config)[0],))

In other words, an AdminPolicy can mutate any field of a user request, including the task and the SkyPilot config for that specific request, giving admins a lot of flexibility to control users’ SkyPilot usage.

An AdminPolicy can be used to both validate and mutate user requests. If a request should be rejected, the policy should raise an exception.

The sky.Config and sky.RequestOptions classes are defined as follows:

Config Class
class Config(Dict[str, Any]):
    """SkyPilot config that supports setting/getting values with nested keys."""

    def get_nested(
        self,
        keys: Tuple[str, ...],
        default_value: Any,
        override_configs: Optional[Dict[str, Any]] = None,
        allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
        disallowed_override_keys: Optional[List[Tuple[str,
                                                      ...]]] = None) -> Any:
        """Gets a nested key.

        If any key is not found, or any intermediate key does not point to a
        dict value, returns 'default_value'.

        Args:
            keys: A tuple of strings representing the nested keys.
            default_value: The default value to return if the key is not found.
            override_configs: A dict of override configs with the same schema as
                the config file, but only containing the keys to override.
            allowed_override_keys: A list of keys that are allowed to be
                overridden.
            disallowed_override_keys: A list of keys that are disallowed to be
                overridden.

        Returns:
            The value of the nested key, or 'default_value' if not found.
        """
        config = copy.deepcopy(self)
        if override_configs is not None:
            config = _recursive_update(config, override_configs,
                                       allowed_override_keys,
                                       disallowed_override_keys)
        return _get_nested(config, keys, default_value, pop=False)

    def set_nested(self, keys: Tuple[str, ...], value: Any) -> None:
        """In-place sets a nested key to value.

        Like get_nested(), if any key is not found, this will not raise an
        error.
        """
        override = {}
        for i, key in enumerate(reversed(keys)):
            if i == 0:
                override = {key: value}
            else:
                override = {key: override}
        _recursive_update(self, override)

    def pop_nested(self, keys: Tuple[str, ...], default_value: Any) -> Any:
        """Pops a nested key."""
        return _get_nested(self, keys, default_value, pop=True)

    @classmethod
    def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config':
        if config is None:
            return cls()
        return cls(**config)

RequestOptions Class
class RequestOptions(pydantic.BaseModel):
    """Request options for admin policy.

    Args:
        cluster_name: Name of the cluster to create/reuse. It is None if not
            specified by the user.
        idle_minutes_to_autostop: Autostop setting requested by a user. The
            cluster will be set to autostop after this many minutes of idleness.
        down: If true, use autodown rather than autostop.
        dryrun: Is the request a dryrun?
    """
    cluster_name: Optional[str]
    # Keep these two fields for backward compatibility. The values are copied
    # from task.resources.autostop_config, so that legacy admin policy plugins
    # can still read the correct autostop config from request options before
    # we drop the compatibility.
    # TODO(aylei): remove these fields after 0.12.0
    idle_minutes_to_autostop: Optional[int]
    down: bool
    dryrun: bool
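The nested-key accessors on Config are what most policies use to read and write settings. The sketch below reproduces the documented behavior of get_nested (return the default when a key is missing or an intermediate value is not a dict) on plain dictionaries. These functions are simplified stand-ins, not the SkyPilot implementation: the real set_nested goes through _recursive_update and may treat dict values differently.

```python
from typing import Any, Dict, Tuple


def get_nested(config: Dict[str, Any], keys: Tuple[str, ...],
               default_value: Any) -> Any:
    """Returns the value at the nested key path, or default_value."""
    current: Any = config
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default_value
        current = current[key]
    return current


def set_nested(config: Dict[str, Any], keys: Tuple[str, ...],
               value: Any) -> None:
    """Sets the value at the nested key path, creating parents as needed."""
    current = config
    for key in keys[:-1]:
        child = current.get(key)
        if not isinstance(child, dict):
            child = {}
            current[key] = child
        current = child
    current[keys[-1]] = value


config: Dict[str, Any] = {'kubernetes': {'custom_metadata': {}}}
path = ('kubernetes', 'custom_metadata', 'labels')
labels = get_nested(config, path, {})
labels['app'] = 'skypilot'
set_nested(config, path, labels)
print(config)
```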

Tips for writing an admin policy#

When writing an admin policy, you can leverage the following tips to make your policy more robust and flexible.

Access user information#

Use UserRequest.user to access the user who made the request. This field is only available when the policy is applied at the server-side.

Useful for:

  • Logging the user who made a specific request.

  • Implementing per-user quotas or rate limits.

  • Changing the behavior of the policy based on the user.

Access request name#

Use UserRequest.request_name to access the request name.

Useful for:

  • Activating a policy selectively for certain request types.

  • Implementing per-request rate limits.

  • Changing the behavior of the policy based on the request name.
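As an illustration of combining the user and the request name, the sketch below tracks a simple per-user, per-request-type quota in memory. The class QuotaTracker and the string request names are hypothetical; in a real policy the inputs would come from user_request.user.name (server-side only) and user_request.request_name. State kept in memory is lost on restart and is not shared across server processes.

```python
import collections
from typing import Dict, Tuple


class QuotaTracker:
    """Counts requests per (user, request type) against fixed limits."""

    def __init__(self, limits: Dict[str, int]) -> None:
        # Maps a request type to the maximum number of allowed requests.
        self._limits = limits
        self._counts: Dict[Tuple[str, str], int] = collections.Counter()

    def check(self, user_name: str, request_name: str) -> None:
        """Records the request, or raises if the user's quota is used up."""
        limit = self._limits.get(request_name)
        if limit is None:
            return  # This request type is not limited.
        key = (user_name, request_name)
        if self._counts[key] >= limit:
            raise RuntimeError(
                f'Quota of {limit} {request_name} requests exceeded for '
                f'user {user_name}')
        self._counts[key] += 1


tracker = QuotaTracker({'cluster_launch': 2})
tracker.check('alice', 'cluster_launch')
tracker.check('alice', 'cluster_launch')
tracker.check('alice', 'cluster_exec')  # Not limited, always passes.
```

Inside validate_and_mutate, a call to check would sit before the return statement, so that a raised exception rejects the request.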

Inspect and modify resources#

Use UserRequest.task.get_resource_config() to get the resource configuration of a task.

The resource configuration is a dictionary conforming to the resource config schema.

Once the resource configuration is modified, you can use UserRequest.task.set_resources(resource_config) to set the modified resource configuration back to the task.

resource_config = user_request.task.get_resource_config()
resource_config['use_spot'] = True
if 'any_of' in resource_config:
    for any_of_config in resource_config['any_of']:
        any_of_config['use_spot'] = True
elif 'ordered' in resource_config:
    for ordered_config in resource_config['ordered']:
        ordered_config['use_spot'] = True
user_request.task.set_resources(resource_config)

Useful for:

  • Enforcing resource constraints (e.g. use spot instances for all GPU tasks, enforce autostop for all tasks).
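The snippet above can be generalized into a small helper that applies one override to the top-level config and to every any_of or ordered candidate. This is a sketch on plain dictionaries; the helper name override_resource_field is hypothetical. It returns a modified copy, which you would then pass to task.set_resources.

```python
import copy
from typing import Any, Dict


def override_resource_field(resource_config: Dict[str, Any], key: str,
                            value: Any) -> Dict[str, Any]:
    """Returns a copy with key=value set on every resource candidate."""
    updated = copy.deepcopy(resource_config)
    updated[key] = value
    # Candidate lists, if present, each carry their own resource fields.
    for group in ('any_of', 'ordered'):
        for candidate in updated.get(group, []):
            candidate[key] = value
    return updated


original = {
    'cpus': '1+',
    'any_of': [{'accelerators': 'A100:1'}, {'accelerators': 'V100:1'}],
}
spot_only = override_resource_field(original, 'use_spot', True)
print(spot_only)
```

Working on a copy keeps the original configuration available, for example for logging what the user requested before the policy changed it.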

Example policies#

We have provided a few example policies in examples/admin_policy/example_policy. You can test these policies by installing the example policy package in your Python environment.

git clone https://github.com/skypilot-org/skypilot.git
cd skypilot
pip install examples/admin_policy/example_policy

Reject all tasks#

class RejectAllPolicy(sky.AdminPolicy):
    """Example policy: rejects all user requests."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Rejects all user requests."""
        raise RuntimeError('Reject all policy')

admin_policy: example_policy.RejectAllPolicy

Add labels for all tasks on Kubernetes#

class AddLabelsPolicy(sky.AdminPolicy):
    """Example policy: adds a kubernetes label for skypilot_config."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        config = user_request.skypilot_config
        labels = config.get_nested(('kubernetes', 'custom_metadata', 'labels'),
                                   {})
        labels['app'] = 'skypilot'
        config.set_nested(('kubernetes', 'custom_metadata', 'labels'), labels)
        return sky.MutatedUserRequest(user_request.task, config)

admin_policy: example_policy.AddLabelsPolicy

Always disable public IP for AWS tasks#

class DisablePublicIpPolicy(sky.AdminPolicy):
    """Example policy: disables public IP for all AWS tasks."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        config = user_request.skypilot_config
        config.set_nested(('aws', 'use_internal_ip'), True)
        if config.get_nested(('aws', 'vpc_name'), None) is None:
            # If no VPC name is specified, it is likely a mistake. We should
            # reject the request
            raise RuntimeError('VPC name should be set. Check organization '
                               'wiki for more information.')
        return sky.MutatedUserRequest(user_request.task, config)

admin_policy: example_policy.DisablePublicIpPolicy

Use spot for all GPU tasks#

class UseSpotForGpuPolicy(sky.AdminPolicy):
    """Example policy: use spot instances for all GPU tasks."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Sets use_spot to True for all GPU tasks."""
        task = user_request.task
        new_resources = []
        for r in task.resources:
            if r.accelerators:
                new_resources.append(r.copy(use_spot=True))
            else:
                new_resources.append(r)

        task.set_resources(type(task.resources)(new_resources))

        return sky.MutatedUserRequest(
            task=task, skypilot_config=user_request.skypilot_config)

admin_policy: example_policy.UseSpotForGpuPolicy

Enforce autostop for all tasks#

class EnforceAutostopPolicy(sky.AdminPolicy):
    """Example policy: enforce autostop for all tasks."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Enforces autostop for all tasks.

        Note that with this policy enforced, users can still change the autostop
        setting for an existing cluster by using `sky autostop`.

        Since we refresh the cluster status with `sky.status` whenever this
        policy is applied, we should expect a few seconds of latency when a
        user runs a request.
        """
        if user_request.request_name not in [
                sky.AdminPolicyRequestName.CLUSTER_LAUNCH,
                sky.AdminPolicyRequestName.CLUSTER_EXEC,
        ]:
            return sky.MutatedUserRequest(
                task=user_request.task,
                skypilot_config=user_request.skypilot_config)

        request_options = user_request.request_options
        # Request options is not None when a task is executed with `sky launch`.
        assert request_options is not None
        # Get the cluster record to operate on.
        cluster_name = request_options.cluster_name
        cluster_records: List[responses.StatusResponse] = []
        if cluster_name is not None:
            try:
                cluster_records = sky.get(
                    sky.status([cluster_name],
                               refresh=common.StatusRefreshMode.AUTO,
                               all_users=True))
            except Exception as e:
                raise RuntimeError('Failed to get cluster status for '
                                   f'{cluster_name}: {e}') from None

        # Check if the user request should specify autostop settings.
        need_autostop = False
        if not cluster_records:
            # Cluster does not exist
            need_autostop = True
        elif cluster_records[0]['status'] == sky.ClusterStatus.STOPPED:
            # Cluster is stopped
            need_autostop = True
        elif cluster_records[0]['autostop'] < 0:
            # Cluster is running but autostop is not set
            need_autostop = True

        # Check if the user request is setting autostop settings.
        is_setting_autostop = False
        idle_minutes_to_autostop = request_options.idle_minutes_to_autostop
        is_setting_autostop = (idle_minutes_to_autostop is not None and
                               idle_minutes_to_autostop >= 0)

        # If the cluster requires autostop but the user request is not setting
        # autostop settings, raise an error.
        if need_autostop and not is_setting_autostop:
            raise RuntimeError('Autostop/down must be set for all clusters.')

        return sky.MutatedUserRequest(
            task=user_request.task,
            skypilot_config=user_request.skypilot_config)

admin_policy: example_policy.EnforceAutostopPolicy

Set max autostop idle minutes for all tasks#

class SetMaxAutostopIdleMinutesPolicy(sky.AdminPolicy):
    """Example policy: set max autostop idle minutes for all tasks."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Sets max autostop idle minutes for all tasks."""
        max_idle_minutes = 10
        task = user_request.task
        resources = task.get_resource_config()
        if 'autostop' not in resources:
            # autostop is disabled
            resources['autostop'] = {'idle_minutes': max_idle_minutes}
        elif ('idle_minutes' not in resources['autostop'] or
              int(resources['autostop']['idle_minutes']) > max_idle_minutes):
            # Autostop idle minutes is too long
            resources['autostop']['idle_minutes'] = max_idle_minutes
        task.set_resources(resources)

        return sky.MutatedUserRequest(
            task=task, skypilot_config=user_request.skypilot_config)

admin_policy: example_policy.SetMaxAutostopIdleMinutesPolicy

Dynamically update Kubernetes contexts to use#

class DynamicKubernetesContextsUpdatePolicy(sky.AdminPolicy):
    """Example policy: update the kubernetes context to use."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Updates the kubernetes context to use."""
        # Append any new kubernetes clusters in local kubeconfig. An example
        # implementation of this method can be:
        #  1. Query an organization's internal Kubernetes cluster registry,
        #     which can be some internal API, or a secret vault.
        #  2. Append the new credentials to the local kubeconfig.
        update_current_kubernetes_clusters_from_registry()
        # Get the allowed contexts for the user. Similarly, it can retrieve
        # the latest allowed contexts from an organization's internal API.
        allowed_contexts = get_allowed_contexts()

        # Update the kubernetes allowed contexts in skypilot config.
        config = user_request.skypilot_config
        config.set_nested(('kubernetes', 'allowed_contexts'), allowed_contexts)
        return sky.MutatedUserRequest(task=user_request.task,
                                      skypilot_config=config)

admin_policy: example_policy.DynamicKubernetesContextsUpdatePolicy

Use local GCP credentials for all tasks#

class UseLocalGcpCredentialsPolicy(sky.AdminPolicy):
    """Example policy: use local GCP credentials in the task."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        # Only apply the policy at client-side.
        if not user_request.at_client_side:
            if _LOCAL_GCP_CREDENTIALS_SET_ENV_VAR not in user_request.task.envs:
                raise RuntimeError(
                    f'Policy {cls.__name__} was not applied at client-side. '
                    'Please install the policy and retry.')
            cv = user_request.task.envs[_LOCAL_GCP_CREDENTIALS_SET_ENV_VAR]
            if (cv != _POLICY_VERSION):
                raise RuntimeError(
                    f'Policy {cls.__name__} at {cv} was applied at client-side '
                    f'but the server requires {_POLICY_VERSION} to be applied. '
                    'Please upgrade the policy and retry.')
            return sky.MutatedUserRequest(user_request.task,
                                          user_request.skypilot_config)

        task = user_request.task
        if task.file_mounts is None:
            task.file_mounts = {}
        # Use the env var to detect whether an explicit credential path is
        # specified.
        cred_path = os.environ.get(_GOOGLE_APPLICATION_CREDENTIALS_ENV)

        if cred_path is not None:
            task.file_mounts[_GOOGLE_APPLICATION_CREDENTIALS_PATH] = cred_path
            activate_cmd = (f'gcloud auth activate-service-account --key-file '
                            f'{_GOOGLE_APPLICATION_CREDENTIALS_PATH}')
            if task.run is None:
                task.run = activate_cmd
            elif isinstance(task.run, str):
                task.run = f'{activate_cmd} && {task.run}'
            else:
                # Impossible according to current code base, but just in case.
                logger.warning('The task run command is not a string, '
                               f'so the local {cred_path} will not be used.')
        else:
            # Otherwise upload the entire default credential directory to get
            # consistent identity in the task and the local environment.
            task.file_mounts['~/.config/gcloud'] = '~/.config/gcloud'
        task.envs[_LOCAL_GCP_CREDENTIALS_SET_ENV_VAR] = _POLICY_VERSION
        return sky.MutatedUserRequest(task, user_request.skypilot_config)

Specify the following config in the SkyPilot config at the client:

admin_policy: example_policy.UseLocalGcpCredentialsPolicy

Then specify the same policy at the server, or call this policy in the RESTful policy server.

This policy only mutates the task when applied at the client-side. At the server-side, it only verifies that the expected version of the policy was applied at the client-side and rejects the request otherwise.

Add volumes to all tasks#

class AddVolumesPolicy(sky.AdminPolicy):
    """Example policy: add volumes to the task."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        task = user_request.task
        if task.is_controller_task():
            # Skip applying admin policy to job/serve controller
            return sky.MutatedUserRequest(task, user_request.skypilot_config)
        # Use `task.set_volumes` to set the volumes.
        # Or use `task.update_volumes` to update in-place
        # instead of overwriting.
        task.set_volumes({'/mnt/data0': 'pvc0'})
        return sky.MutatedUserRequest(task, user_request.skypilot_config)

admin_policy: example_policy.AddVolumesPolicy

Rate limit cluster launch requests#

class RateLimitLaunchPolicy(sky.AdminPolicy):
    """Example policy: rate limit cluster launch requests."""
    _RATE_LIMITER = TokenBucketRateLimiter(capacity=10, fill_rate=1)

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Rate limit cluster launch requests."""
        if (user_request.at_client_side or user_request.request_name !=
                sky.AdminPolicyRequestName.CLUSTER_LAUNCH):
            return sky.MutatedUserRequest(
                task=user_request.task,
                skypilot_config=user_request.skypilot_config)

        # user is not None when the policy is applied at the server-side
        assert user_request.user is not None
        user_name = user_request.user.name
        if not cls._RATE_LIMITER.allow_request(user_name):
            raise RuntimeError(f'Rate limit exceeded for user {user_name}')

        return sky.MutatedUserRequest(
            task=user_request.task,
            skypilot_config=user_request.skypilot_config)

The RateLimitLaunchPolicy uses the TokenBucketRateLimiter class to rate limit cluster launch requests.
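Before reading the database-backed class, it may help to see the token bucket idea on its own. The sketch below is a simplified in-memory variant, not the class from the example package: the name InMemoryTokenBucket and the injectable clock parameter are assumptions made for illustration, and the sketch is neither thread-safe nor persistent.

```python
import time
from typing import Callable, Dict, Tuple


class InMemoryTokenBucket:
    """Per-user token bucket: bursts up to capacity, refills at fill_rate."""

    def __init__(self,
                 capacity: float,
                 fill_rate: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.capacity = float(capacity)
        self.fill_rate = float(fill_rate)  # Tokens added per second.
        self._clock = clock
        # Maps a user name to (tokens left, time of last update).
        self._buckets: Dict[str, Tuple[float, float]] = {}

    def allow_request(self, user_name: str) -> bool:
        now = self._clock()
        # A user seen for the first time starts with a full bucket.
        tokens, last = self._buckets.get(user_name, (self.capacity, now))
        # Refill for the elapsed time, never beyond the capacity.
        tokens = min(self.capacity, tokens + (now - last) * self.fill_rate)
        allowed = tokens >= 1.0
        if allowed:
            tokens -= 1.0
        self._buckets[user_name] = (tokens, now)
        return allowed


# A fake clock makes the behavior deterministic for the demonstration.
fake_now = [0.0]
limiter = InMemoryTokenBucket(capacity=2,
                              fill_rate=1,
                              clock=lambda: fake_now[0])
first_three = [limiter.allow_request('user1') for _ in range(3)]
fake_now[0] = 1.0  # One second later, one token has been refilled.
after_refill = limiter.allow_request('user1')
```

The capacity bounds the burst size, while fill_rate bounds the sustained request rate. The class below persists the same kind of state in a database so that it survives restarts.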

TokenBucketRateLimiter Class
class TokenBucketRateLimiter:
    """Token bucket rate limiter.

    This rate limiter lets each user burst up to `capacity` requests,
    then sustains at most `fill_rate` requests per second.

    Args:
        capacity: The maximum number of tokens (requests) the bucket holds.
        fill_rate: The number of tokens added to the bucket per second.

    Example:
        .. code-block:: python

            rate_limiter = TokenBucketRateLimiter(capacity=2, fill_rate=1)
            # The first two calls use up the two tokens in the bucket.
            assert rate_limiter.allow_request('user1') is True
            assert rate_limiter.allow_request('user1') is True
            # The third call is denied as the bucket is empty.
            assert rate_limiter.allow_request('user1') is False
            # Wait for 1 second, the bucket is refilled with 1 token.
            time.sleep(1)
            assert rate_limiter.allow_request('user1') is True
    """

    def __init__(self, capacity, fill_rate):
        """Initializes the token bucket rate limiter.

        Args:
            capacity: The maximum number of tokens (requests) the bucket holds.
            fill_rate: The number of tokens added to the bucket per second.
        """
        # import the modules here so users importing this module
        # to use other policies do not need to import these modules.
        # pylint: disable=import-outside-toplevel
        import threading

        import sqlalchemy
        from sqlalchemy import orm
        from sqlalchemy.dialects import postgresql
        from sqlalchemy.dialects import sqlite

        self.capacity = float(capacity)
        self.fill_rate = float(fill_rate)  # tokens per second
        self.lock = threading.Lock()
        # Tip: you can swap out the connection string
        # to use a postgres database.
        self._db_engine = sqlalchemy.create_engine('sqlite:///rate_limit.db')
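        if self._db_engine.dialect.name == 'sqlite':
            self.insert_func = sqlite.insert
        elif self._db_engine.dialect.name == 'postgresql':
            self.insert_func = postgresql.insert
        else:
            # Fail fast instead of hitting an AttributeError in allow_request.
            raise ValueError('Unsupported database dialect: '
                             f'{self._db_engine.dialect.name}')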

        self.rate_limit_table = sqlalchemy.Table(
            'rate_limit',
            orm.declarative_base().metadata,
            sqlalchemy.Column('user_name', sqlalchemy.Text, primary_key=True),
            sqlalchemy.Column('tokens', sqlalchemy.REAL),
            sqlalchemy.Column('last_refill_time', sqlalchemy.REAL),
        )

        with orm.Session(self._db_engine) as init_session:
            init_session.execute(
                sqlalchemy.text('CREATE TABLE IF NOT EXISTS rate_limit '
                                '(user_name TEXT PRIMARY KEY, tokens REAL, '
                                'last_refill_time REAL)'))
            init_session.commit()

    def allow_request(self, user_name):
        """Determines if a request is allowed for a given user.

        Args:
            user_name: The name of the user.

        Returns:
            True if the request is allowed, False otherwise.
        """
        # import the modules here so users importing this module
        # to use other policies do not need to import these modules.
        # pylint: disable=import-outside-toplevel
        import time

        from sqlalchemy import orm

        with self.lock:
            now = time.time()
            with orm.Session(self._db_engine) as session:
                # with_for_update() locks the row until commit() or rollback()
                # is called, or until the code escapes the with block.
                result = session.query(self.rate_limit_table).filter(
                    self.rate_limit_table.c.user_name ==
                    user_name).with_for_update().first()
                if result:
                    tokens = result.tokens
                    last_refill_time = result.last_refill_time
                else:
                    tokens = self.capacity
                    last_refill_time = now
                time_elapsed = now - last_refill_time
                # Refill the bucket based on the fill rate and the time elapsed
                # since the last refill.
                tokens = min(self.capacity,
                             tokens + time_elapsed * self.fill_rate)
                # Check if the request is allowed.
                if tokens >= 1:
                    tokens -= 1
                    allowed = True
                else:
                    allowed = False
                # update the bucket in the database.
                insert_or_update_stmt = (self.insert_func(
                    self.rate_limit_table).values(
                        user_name=user_name,
                        tokens=tokens,
                        last_refill_time=now).on_conflict_do_update(
                            index_elements=[self.rate_limit_table.c.user_name],
                            set_={
                                self.rate_limit_table.c.tokens: tokens,
                                self.rate_limit_table.c.last_refill_time: now
                            }))
                session.execute(insert_or_update_stmt)
                session.commit()
            return allowed

admin_policy: example_policy.RateLimitLaunchPolicy
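
To see the token bucket logic in isolation, here is a minimal in-memory sketch without the database. It is not part of the example policy package: the class name InMemoryTokenBucket and its clock parameter are hypothetical, introduced only so the refill arithmetic can be exercised deterministically. Unlike the database-backed version above, its state is lost on restart and is not shared across server processes.

```python
import threading
import time
from typing import Callable, Dict, Tuple


class InMemoryTokenBucket:
    """Hypothetical in-memory token bucket, for illustration only."""

    def __init__(self,
                 capacity: float,
                 fill_rate: float,
                 clock: Callable[[], float] = time.monotonic):
        self.capacity = float(capacity)
        self.fill_rate = float(fill_rate)  # tokens per second
        self._clock = clock  # injectable so tests can control time
        self._lock = threading.Lock()
        # Maps user_name -> (tokens, last_refill_time).
        self._buckets: Dict[str, Tuple[float, float]] = {}

    def allow_request(self, user_name: str) -> bool:
        with self._lock:
            now = self._clock()
            # A user seen for the first time starts with a full bucket.
            tokens, last_refill = self._buckets.get(user_name,
                                                    (self.capacity, now))
            # Refill in proportion to the elapsed time, capped at capacity.
            tokens = min(self.capacity,
                         tokens + (now - last_refill) * self.fill_rate)
            allowed = tokens >= 1
            if allowed:
                tokens -= 1  # Each allowed request consumes one token.
            self._buckets[user_name] = (tokens, now)
            return allowed


# Drive the limiter with a fake clock instead of time.sleep().
fake_now = [0.0]
limiter = InMemoryTokenBucket(capacity=2,
                              fill_rate=1,
                              clock=lambda: fake_now[0])
results = [limiter.allow_request('user1') for _ in range(3)]
fake_now[0] += 1.0  # One second passes, so one token is refilled.
results.append(limiter.allow_request('user1'))
print(results)  # [True, True, False, True]
```

Buckets are tracked per user, so one user exhausting their tokens does not affect requests from another user.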

Enforce a static GPU quota for each user on cluster launch requests#

Note

This policy calls sky.status() to get the total number of GPUs currently used by the user and therefore adds a few seconds of latency to every cluster launch request.

This policy should be considered an educational example and not a production-ready policy.

class GPUStaticQuotaPolicy(sky.AdminPolicy):
    """Example policy: Enforce a static GPU quota
    for each user for cluster launch requests."""

    # GPU quota allotted for each user.
    GPU_QUOTA_PER_USER = {
        'H100': 2,
        'L40S': 10,
    }

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Enforce a static GPU quota for each user for cluster launch requests.

        This policy does not enforce a quota for jobs launch requests.

        Note: This policy calls sky.status() to get the total number
        of GPUs currently used by the user and therefore adds a
        few seconds of latency for every cluster launch request.

        Raises:
            RuntimeError: If the user has exceeded the GPU quota for any
            accelerator type.
        """
        # Import ast here so users importing this module
        # to use other policies do not need to import this module.
        # pylint: disable=import-outside-toplevel
        import ast

        # If the request is at client side or not a cluster launch request,
        # do not enforce GPU quota.
        if (user_request.at_client_side or user_request.request_name !=
                sky.AdminPolicyRequestName.CLUSTER_LAUNCH):
            return sky.MutatedUserRequest(
                task=user_request.task,
                skypilot_config=user_request.skypilot_config)

        assert user_request.user is not None, (
            'Failed to get user initiating the request.')
        user_name = user_request.user.name
        assert user_name is not None, (
            'Failed to get user name initiating the request.')

        # Get the total number of GPUs currently used by the user.
        try:
            cluster_records = sky.get(
                sky.status(refresh=common.StatusRefreshMode.NONE,
                           all_users=True,
                           _summary_response=True))
        except Exception as e:
            raise RuntimeError('Failed to get cluster records for '
                               f'all users: {e}') from None
        accelerators_used: Dict[str, int] = {}
        cluster_records_for_user = [
            record for record in cluster_records
            if record.user_name == user_name
        ]
        for record in cluster_records_for_user:
            if not record.accelerators:
                continue
            accelerators = ast.literal_eval(record.accelerators)
            for accelerator, count in accelerators.items():
                accelerators_used[accelerator] = accelerators_used.get(
                    accelerator, 0) + (count * record.nodes)
        # At this point, accelerators_used is a dictionary of the
        # GPUs currently used by the user in the format of
        # {accelerator_type: count}.

        # Now, check if any resource request exceeds the GPU quota.
        for resource in user_request.task.resources:
            if resource.accelerators:
                for accelerator, count in resource.accelerators.items():
                    count *= user_request.task.num_nodes
                    quota = cls.GPU_QUOTA_PER_USER.get(accelerator, 0)
                    if accelerators_used.get(accelerator, 0) + count > quota:
                        raise RuntimeError(
                            f'User {user_name} has exceeded the '
                            f'GPU quota for {accelerator}. '
                            f'In use: {accelerators_used.get(accelerator, 0)}, '
                            f'Requested: {count}, '
                            f'Quota: {quota}')

        return sky.MutatedUserRequest(
            task=user_request.task,
            skypilot_config=user_request.skypilot_config)

admin_policy: example_policy.GPUStaticQuotaPolicy
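
The quota arithmetic in the policy above can be sketched without any SkyPilot dependency. The helper names total_accelerators_in_use and first_over_quota are hypothetical. The sketch assumes each cluster record is reduced to a pair of an accelerator string such as "{'H100': 1}" (the form the policy parses with ast.literal_eval) and a node count.

```python
import ast
from typing import Dict, List, Optional, Tuple


def total_accelerators_in_use(
        records: List[Tuple[Optional[str], int]]) -> Dict[str, int]:
    """Sums accelerators over (accelerators_repr, num_nodes) pairs."""
    used: Dict[str, int] = {}
    for accelerators_repr, nodes in records:
        if not accelerators_repr:
            continue  # CPU-only cluster: nothing to count.
        # Per-node counts are multiplied by the number of nodes.
        for name, count in ast.literal_eval(accelerators_repr).items():
            used[name] = used.get(name, 0) + count * nodes
    return used


def first_over_quota(used: Dict[str, int], requested: Dict[str, int],
                     num_nodes: int, quota: Dict[str, int]) -> Optional[str]:
    """Returns the first accelerator whose quota would be exceeded, if any."""
    for name, count in requested.items():
        # Accelerators missing from the quota table get a quota of 0.
        if used.get(name, 0) + count * num_nodes > quota.get(name, 0):
            return name
    return None


quota = {'H100': 2, 'L40S': 10}
used = total_accelerators_in_use([
    ("{'H100': 1}", 1),  # 1 node with 1 H100
    (None, 2),           # CPU-only cluster
    ("{'L40S': 4}", 2),  # 2 nodes with 4 L40S each
])
print(used)  # {'H100': 1, 'L40S': 8}
print(first_over_quota(used, {'H100': 1}, 1, quota))  # None
print(first_over_quota(used, {'H100': 1}, 2, quota))  # H100
```

Because unlisted accelerators default to a quota of 0, any accelerator type missing from the quota table is rejected.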