diff --git a/cloud_pipelines_backend/launchers/skypilot_launchers.py b/cloud_pipelines_backend/launchers/skypilot_launchers.py
new file mode 100644
index 0000000..2f307df
--- /dev/null
+++ b/cloud_pipelines_backend/launchers/skypilot_launchers.py
@@ -0,0 +1,820 @@
+"""SkyPilot launcher for Tangle pipelines.
+
+Translates Tangle's ContainerTaskLauncher contract into sky.jobs managed-job
+submissions. SkyPilot then handles container scheduling, multi-cloud /
+multi-cluster placement, multi-node coordination, preemption recovery, log
+streaming, and cancellation.
+
+Layout follows the existing launchers in cloud_pipelines_backend/launchers/
+(local_docker_launchers.py, kubernetes_launchers.py).
+
+Storage provider compatibility
+==============================
+
+This launcher relies on SkyPilot's ``file_mounts`` for input/output artifact
+transfer, which can mount cloud-storage URIs (``gs://``, ``s3://``, ``abfs://``,
+``r2://``, ``https://``) directly into the container but cannot represent the
+relative-local-path artifact URIs produced by Tangle's
+``LocalStorageProvider``.
+
+Practical consequences:
+
+  * Single-component pipelines run end-to-end on any storage provider — the
+    container's stdout/stderr is captured by SkyPilot regardless. Components
+    with file outputs that target a local URI will simply have those outputs
+    discarded with a warning (the container still executes).
+
+  * Multi-component (graph) pipelines require a cloud StorageProvider. With
+    ``LocalStorageProvider``, downstream tasks cannot read the upstream
+    output (the SkyPilot pod can't write back to the orchestrator's local
+    filesystem), so step 2 fails to find its input.
+
+To run multi-step pipelines, configure Tangle with a cloud storage provider,
+for example::
+
+    from cloud_pipelines.orchestration.storage_providers.google_cloud_storage \\
+        import GoogleCloudStorageProvider
+
+    orchestrator = orchestrator_sql.OrchestratorService_Sql(
+        ...,
+        launcher=SkyPilotKubernetesLauncher(infra="kubernetes/<context>"),
+        storage_provider=GoogleCloudStorageProvider(),
+        data_root_uri="gs://my-tangle-bucket/artifacts",
+        logs_root_uri="gs://my-tangle-bucket/logs",
+    )
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import datetime
+import io
+import json
+import logging
+import shlex
+import threading
+from typing import Any, Iterator, Optional
+
+from cloud_pipelines.orchestration.launchers import naming_utils
+from cloud_pipelines.orchestration.storage_providers import (
+    interfaces as storage_provider_interfaces,
+)
+from cloud_pipelines_backend import component_structures as structures
+from cloud_pipelines_backend.launchers import (
+    container_component_utils,
+    interfaces,
+    kubernetes_launchers as _k8s_launchers,
+)
+
+import sky
+from sky import jobs as sky_jobs
+
+_logger = logging.getLogger(__name__)
+
+_MAX_INPUT_VALUE_SIZE = 10000
+_CONTAINER_FILE_NAME = "data"
+# SkyPilot itself does not impose an upper bound on num_nodes (only the cloud
+# quota does). We keep a sanity cap that is significantly higher than Tangle's
+# kubernetes_launchers cap of 16; raise it freely if you have larger jobs.
+_MULTI_NODE_MAX_NUMBER_OF_NODES = 256
+
+# Re-use the resource annotation keys from kubernetes_launchers so a Tangle
+# component spec is portable across the K8s launcher and this one.
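+# For example, a task annotated with (illustrative values; any subset works):
+#
+#     {
+#         RESOURCES_CPU_ANNOTATION_KEY: "4",
+#         RESOURCES_MEMORY_ANNOTATION_KEY: "16",
+#         RESOURCES_ACCELERATORS_ANNOTATION_KEY: '{"nvidia-tesla-h100": 8}',
+#     }
+#
+# is understood identically by kubernetes_launchers and this launcher.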
+
+RESOURCES_CPU_ANNOTATION_KEY = _k8s_launchers.RESOURCES_CPU_ANNOTATION_KEY
+RESOURCES_MEMORY_ANNOTATION_KEY = _k8s_launchers.RESOURCES_MEMORY_ANNOTATION_KEY
+RESOURCES_ACCELERATORS_ANNOTATION_KEY = (
+    _k8s_launchers.RESOURCES_ACCELERATORS_ANNOTATION_KEY
+)
+RESOURCES_EPHEMERAL_STORAGE_ANNOTATION_KEY = (
+    _k8s_launchers.RESOURCES_EPHEMERAL_STORAGE_ANNOTATION_KEY
+)
+MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY = (
+    _k8s_launchers.MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY
+)
+
+# SkyPilot-specific annotation keys (opt-in).
+PRIORITY_CLASS_ANNOTATION_KEY = "skypilot.co/launchers/skypilot/priority_class"
+SPOT_ANNOTATION_KEY = "skypilot.co/launchers/skypilot/use_spot"
+
+# Tangle's multi-node dynamic-data keys (mirror of kubernetes_launchers).
+_MULTI_NODE_NUMBER_OF_NODES_DYNAMIC_DATA_KEY = "system/multi_node/number_of_nodes"
+_MULTI_NODE_NODE_INDEX_DYNAMIC_DATA_KEY = "system/multi_node/node_index"
+_MULTI_NODE_NODE_0_ADDRESS_DYNAMIC_DATA_KEY = "system/multi_node/node_0_address"
+_MULTI_NODE_ALL_NODE_ADDRESSES_DYNAMIC_DATA_KEY = "system/multi_node/all_node_addresses"
+
+# ManagedJobStatus (str-valued enum or string) -> Tangle ContainerStatus.
+_TERMINAL_STATUSES = frozenset(
+    {
+        "SUCCEEDED",
+        "CANCELLED",
+        "FAILED",
+        "FAILED_SETUP",
+        "FAILED_PRECHECKS",
+        "FAILED_NO_RESOURCE",
+        "FAILED_CONTROLLER",
+    }
+)
+_STATUS_MAP: dict[str, interfaces.ContainerStatus] = {
+    "PENDING": interfaces.ContainerStatus.PENDING,
+    "SUBMITTED": interfaces.ContainerStatus.PENDING,
+    "STARTING": interfaces.ContainerStatus.PENDING,
+    "RUNNING": interfaces.ContainerStatus.RUNNING,
+    "RECOVERING": interfaces.ContainerStatus.RUNNING,
+    "WINDING_DOWN": interfaces.ContainerStatus.RUNNING,
+    "CANCELLING": interfaces.ContainerStatus.RUNNING,
+    "SUCCEEDED": interfaces.ContainerStatus.SUCCEEDED,
+    "CANCELLED": interfaces.ContainerStatus.FAILED,
+    "FAILED": interfaces.ContainerStatus.FAILED,
+    "FAILED_SETUP": interfaces.ContainerStatus.FAILED,
+    "FAILED_PRECHECKS": interfaces.ContainerStatus.FAILED,
+    "FAILED_NO_RESOURCE": interfaces.ContainerStatus.FAILED,
+    "FAILED_CONTROLLER": interfaces.ContainerStatus.ERROR,
+}
+
+
+def _status_to_string(status: Any) -> str:
+    if status is None:
+        return "PENDING"
+    if hasattr(status, "value"):
+        return str(status.value)
+    return str(status)
+
+
+def _shell_quote_argv(argv: list[str]) -> str:
+    return " ".join(shlex.quote(p) for p in argv)
+
+
+# Bash prelude that bridges Tangle's multi-node contract to SkyPilot's runtime
+# env vars. Lets a component's command line reference $TANGLE_MULTI_NODE_*
+# without caring whether SkyPilot or Kubernetes is the launcher.
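+# Illustrative: for a hypothetical 2-node job SkyPilot injects
+# SKYPILOT_NUM_NODES=2, SKYPILOT_NODE_RANK=<0|1>, and a newline-separated
+# SKYPILOT_NODE_IPS list; node 1 therefore sees
+# TANGLE_MULTI_NODE_NODE_INDEX=1 and TANGLE_MULTI_NODE_NODE_0_ADDRESS set to
+# the first IP in the list (hence the `head -n1`).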
+_MULTI_NODE_ENV_PRELUDE = (
+    'export TANGLE_MULTI_NODE_NUMBER_OF_NODES="${SKYPILOT_NUM_NODES:-1}"\n'
+    'export TANGLE_MULTI_NODE_NODE_INDEX="${SKYPILOT_NODE_RANK:-0}"\n'
+    'export TANGLE_MULTI_NODE_NODE_0_ADDRESS="$(echo "${SKYPILOT_NODE_IPS:-localhost}" | head -n1)"\n'
+    'export TANGLE_MULTI_NODE_ALL_NODE_ADDRESSES="${SKYPILOT_NODE_IPS:-localhost}"\n'
+)
+
+
+@dataclasses.dataclass
+class _SkyPilotJobHandle:
+    job_id: int
+    job_name: str
+    output_uris: dict[str, str]
+    log_uri: str
+    cached_status: Optional[str] = None
+    cached_failure_reason: Optional[str] = None
+    cached_started_at: Optional[float] = None
+    cached_ended_at: Optional[float] = None
+
+
+def _coerce_disk_size_gb(spec: Any) -> int:
+    """Best-effort parse of an ephemeral-storage annotation into GiB."""
+    if isinstance(spec, (int, float)):
+        return max(1, int(spec))
+    s = str(spec).strip().lower()
+    multipliers = {
+        "ti": 1024.0, "gi": 1.0, "mi": 1 / 1024, "ki": 1 / 1024 / 1024,
+        "t": 1000.0, "g": 1.0, "m": 1 / 1000, "k": 1 / 1000 / 1000,
+    }
+    for suffix, mult in sorted(multipliers.items(), key=lambda kv: -len(kv[0])):
+        if s.endswith(suffix):
+            try:
+                return max(1, int(float(s[: -len(suffix)]) * mult))
+            except ValueError:
+                continue
+    try:
+        return max(1, int(float(s)))
+    except ValueError:
+        return 8
+
+
+class SkyPilotKubernetesLauncher(
+    interfaces.ContainerTaskLauncher["SkyPilotLaunchedJob"]
+):
+    """Launches Tangle container tasks via SkyPilot managed jobs.
+
+    Designed for Kubernetes-only deployments (the Shopify use case) but works
+    against any infra SkyPilot supports. Set ``infra="kubernetes"`` (or
+    ``"kubernetes/<context>"``) to keep behavior aligned with the existing
+    KubernetesWithGcsFuseContainerLauncher; pass ``infra=None`` to let
+    SkyPilot's optimizer pick across any clouds the user has configured.
+    """
+
+    def __init__(
+        self,
+        *,
+        infra: Optional[str] = "kubernetes",
+        pool: Optional[str] = None,
+        default_image: Optional[str] = None,
+        default_labels: Optional[dict[str, str]] = None,
+        default_envs: Optional[dict[str, str]] = None,
+        annotation_to_label_keys: Optional[list[str]] = None,
+        priority_class: Optional[str] = None,
+        use_spot: Optional[bool] = None,
+        job_name_prefix: str = "tangle-",
+        storage_provider: Optional[
+            storage_provider_interfaces.StorageProvider
+        ] = None,
+    ):
+        """
+        Args:
+            infra: SkyPilot infra string. ``"kubernetes"`` for any K8s context;
+                ``"kubernetes/<context>"`` to pin to one cluster; ``None`` to
+                let the optimizer pick across all configured clouds.
+            pool: Optional SkyPilot Pool name. Submitting to a warm Pool gives
+                much faster cold-start than full provisioning.
+            default_image: Fallback container image when ComponentSpec doesn't
+                specify one.
+            default_labels: Labels applied to every Sky resource (propagated to
+                K8s pod labels under the kubernetes infra).
+            default_envs: Env vars injected into every container.
+            annotation_to_label_keys: Tangle annotation keys whose values are
+                copied into Sky labels. Useful for passing through things like
+                ``ml.shopify.io/priority-class`` so the K8s pod ends up with the
+                same label Kueue is configured to read.
+            priority_class: Default SkyPilot priority class (Kueue-compatible).
+                Can be overridden per-task via the ``PRIORITY_CLASS_ANNOTATION_KEY``
+                annotation on a ComponentSpec.
+            use_spot: Default spot-instance preference. Per-task override via
+                ``SPOT_ANNOTATION_KEY``.
+            job_name_prefix: Prefix for SkyPilot managed job names.
+            storage_provider: Used by ``upload_log()`` to mirror the SkyPilot
+                managed-job logs to ``log_uri`` so the Tangle UI can display
+                them. If ``None``, ``upload_log()`` is a no-op (sky logs are
+                still available via ``sky jobs logs <job_id>``).
+        """
+        self._infra = infra
+        self._pool = pool
+        self._default_image = default_image
+        self._default_labels = dict(default_labels or {})
+        self._default_envs = dict(default_envs or {})
+        self._annotation_to_label_keys = list(annotation_to_label_keys or [])
+        self._default_priority_class = priority_class
+        self._default_use_spot = use_spot
+        self._job_name_prefix = job_name_prefix
+        self._storage_provider = storage_provider
+        # Serialize submissions; sky's SDK is safe to call concurrently but the
+        # orchestrator may invoke launch_container_task from many workers, and
+        # serializing keeps log_uri / file_mount conflict checks deterministic.
+        self._lock = threading.Lock()
+
+    # ----------------- ContainerTaskLauncher contract -----------------
+
+    def launch_container_task(
+        self,
+        *,
+        component_spec: structures.ComponentSpec,
+        input_arguments: dict[str, interfaces.InputArgument],
+        output_uris: dict[str, str],
+        log_uri: str,
+        annotations: dict[str, Any] | None = None,
+    ) -> "SkyPilotLaunchedJob":
+        if not isinstance(
+            component_spec.implementation, structures.ContainerImplementation
+        ):
+            raise interfaces.LauncherError(
+                f"Component must have container implementation. {component_spec=}"
+            )
+        container_spec = component_spec.implementation.container
+        annotations = dict(annotations or {})
+
+        task = self._build_task(
+            component_spec=component_spec,
+            container_spec=container_spec,
+            input_arguments=input_arguments,
+            output_uris=output_uris,
+            annotations=annotations,
+        )
+        job_name = task.name or self._job_name_prefix + "task"
+
+        # Submit. sky.jobs.launch returns a RequestId; await it via sky.get().
+        # Result shape changed across sky versions: older returns (List[int], Handle),
+        # newer (>=0.12) returns (Optional[int], Handle). Handle both.
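+        # e.g. pre-0.12: ([42], handle); 0.12+: (42, handle). Job id 42 is
+        # illustrative.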
+ with self._lock: + launch_kwargs: dict[str, Any] = {} + if self._pool is not None: + launch_kwargs["pool"] = self._pool + request_id = sky_jobs.launch(task, name=job_name, **launch_kwargs) + result = sky.get(request_id) + + job_id_or_ids, _handle = result + if job_id_or_ids is None: + raise interfaces.LauncherError( + f"sky.jobs.launch returned no job id for {job_name}" + ) + if isinstance(job_id_or_ids, (list, tuple)): + if not job_id_or_ids: + raise interfaces.LauncherError( + f"sky.jobs.launch returned empty job-id list for {job_name}" + ) + job_id = int(job_id_or_ids[0]) + else: + job_id = int(job_id_or_ids) + _logger.info( + "Submitted SkyPilot managed job %s (job_id=%d)", job_name, job_id + ) + + return SkyPilotLaunchedJob( + handle=_SkyPilotJobHandle( + job_id=job_id, + job_name=job_name, + output_uris=dict(output_uris), + log_uri=log_uri, + ), + storage_provider=self._storage_provider, + ) + + def deserialize_launched_container_from_dict( + self, launched_container_dict: dict + ) -> "SkyPilotLaunchedJob": + return SkyPilotLaunchedJob.from_dict( + launched_container_dict, storage_provider=self._storage_provider + ) + + def get_refreshed_launched_container_from_dict( + self, launched_container_dict: dict + ) -> "SkyPilotLaunchedJob": + return SkyPilotLaunchedJob.from_dict( + launched_container_dict, storage_provider=self._storage_provider + ).get_refreshed() + + # ----------------- ComponentSpec -> sky.Task translation ----------------- + + def _build_task( + self, + *, + component_spec: structures.ComponentSpec, + container_spec: structures.ContainerSpec, + input_arguments: dict[str, interfaces.InputArgument], + output_uris: dict[str, str], + annotations: dict[str, Any], + ) -> sky.Task: + # Resources + cpus = annotations.get(RESOURCES_CPU_ANNOTATION_KEY) + memory = annotations.get(RESOURCES_MEMORY_ANNOTATION_KEY) + accelerators = annotations.get(RESOURCES_ACCELERATORS_ANNOTATION_KEY) + ephemeral_storage = annotations.get( + RESOURCES_EPHEMERAL_STORAGE_ANNOTATION_KEY + ) + + # Multi-node count + num_nodes_str = annotations.get(MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY, "1") + try: + num_nodes = int(num_nodes_str) + except ValueError as ex: + raise interfaces.LauncherError( + f"Invalid {MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY}={num_nodes_str!r}" + ) from ex + if not (1 <= num_nodes <= _MULTI_NODE_MAX_NUMBER_OF_NODES): + raise interfaces.LauncherError( + f"num_nodes must be in [1, {_MULTI_NODE_MAX_NUMBER_OF_NODES}], " + f"got {num_nodes}" + ) + + # Labels: defaults, plus selected annotations propagated as labels. + labels = dict(self._default_labels) + for ann_key in self._annotation_to_label_keys: + if ann_key in annotations: + sky_label_key = ( + ann_key.replace("/", "_").replace(".", "_").replace(":", "_") + ) + labels[sky_label_key] = str(annotations[ann_key]) + + # Pre-resolve multi-node dynamic-data inputs to shell expansions of the + # TANGLE_MULTI_NODE_* env vars set by our run-script prelude. This keeps + # the rest of resolve_container_command_line oblivious to multi-node. 
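+        # e.g. an input whose dynamic_data parses to
+        # 'system/multi_node/node_index' gets the literal value
+        # '${TANGLE_MULTI_NODE_NODE_INDEX}', which bash expands at run time.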
+        for input_argument in input_arguments.values():
+            if input_argument.value is not None or input_argument.uri is not None:
+                continue
+            if not input_argument.dynamic_data:
+                continue
+            kind, _payload = container_component_utils.parse_dynamic_data_argument(
+                input_argument.dynamic_data
+            )
+            if kind == _MULTI_NODE_NUMBER_OF_NODES_DYNAMIC_DATA_KEY:
+                input_argument.value = "${TANGLE_MULTI_NODE_NUMBER_OF_NODES}"
+            elif kind == _MULTI_NODE_NODE_INDEX_DYNAMIC_DATA_KEY:
+                input_argument.value = "${TANGLE_MULTI_NODE_NODE_INDEX}"
+            elif kind == _MULTI_NODE_NODE_0_ADDRESS_DYNAMIC_DATA_KEY:
+                input_argument.value = "${TANGLE_MULTI_NODE_NODE_0_ADDRESS}"
+            elif kind == _MULTI_NODE_ALL_NODE_ADDRESSES_DYNAMIC_DATA_KEY:
+                input_argument.value = "${TANGLE_MULTI_NODE_ALL_NODE_ADDRESSES}"
+            else:
+                raise interfaces.LauncherError(
+                    f"Dynamic data '{kind}' is not supported by the SkyPilot launcher"
+                )
+
+        file_mounts: dict[str, dict[str, str]] = {}
+        envs: dict[str, str] = {**self._default_envs, **(container_spec.env or {})}
+
+        def get_input_value(input_name: str) -> str:
+            ia = input_arguments[input_name]
+            if ia.is_dir:
+                raise interfaces.LauncherError(
+                    f"Cannot consume directory as value. {input_name=}"
+                )
+            if ia.total_size > _MAX_INPUT_VALUE_SIZE:
+                raise interfaces.LauncherError(
+                    f"Artifact too big to consume as value. Use a path. {input_name=}"
+                )
+            if ia.value is None:
+                # First cut: require scalar values to be pre-resolved by the
+                # orchestrator. Adding a storage_provider hook here is the
+                # natural extension for parity with kubernetes_launchers.
+                raise interfaces.LauncherError(
+                    f"Input '{input_name}' has no inline value. Pre-resolve "
+                    "scalar inputs before submission, or extend this launcher "
+                    "with a storage_provider for downloads."
+                )
+            return ia.value
+
+        def _is_cloud_uri(uri: str) -> bool:
+            return any(
+                uri.startswith(s)
+                for s in ("gs://", "s3://", "abfs://", "https://", "http://", "r2://")
+            )
+
+        # SkyPilot's MOUNT mode requires the source to be a bucket root (not a
+        # sub-path within a bucket — sky/data/storage.py:_validate_source raises
+        # StorageModeError otherwise). We mount each unique bucket once at a
+        # stable container path and use sub-paths inside.
+        bucket_mount_root = "/mnt/skypilot"
+
+        def _bucket_and_subpath(uri: str) -> tuple[str, str]:
+            # gs://bucket/sub/path -> ("gs://bucket", "sub/path")
+            scheme, _, rest = uri.partition("://")
+            bucket, _, sub = rest.partition("/")
+            return f"{scheme}://{bucket}", sub
+
+        def _container_path_for(uri: str) -> str:
+            bucket_uri, sub_path = _bucket_and_subpath(uri)
+            scheme = bucket_uri.split("://", 1)[0]
+            bucket_name = bucket_uri.split("://", 1)[1]
+            mount_point = f"{bucket_mount_root}/{scheme}/{bucket_name}"
+            file_mounts[mount_point] = {"source": bucket_uri, "mode": "MOUNT"}
+            return f"{mount_point}/{sub_path}"
+
+        def get_input_path(input_name: str) -> str:
+            ia = input_arguments[input_name]
+            if ia.uri is None:
+                raise interfaces.LauncherError(
+                    f"Input '{input_name}' has no URI. Stage values to cloud "
+                    "storage (gs://, s3://, ...) before submitting through the "
+                    "SkyPilot launcher."
+                )
+            if not _is_cloud_uri(ia.uri):
+                raise interfaces.LauncherError(
+                    f"Input '{input_name}' uri={ia.uri!r} is not a cloud storage URI. "
+                    "The SkyPilot launcher requires gs://, s3://, abfs://, http(s)://, "
+                    "or r2:// for inputs. Configure Tangle with a cloud StorageProvider "
+                    "(e.g. GoogleCloudStorageProvider) for cloud-based runs."
+                )
+            return _container_path_for(ia.uri)
+
+        def get_output_path(output_name: str) -> str:
+            uri = output_uris[output_name]
+            if _is_cloud_uri(uri):
+                return _container_path_for(uri)
+            _logger.warning(
+                "Output '%s' uri=%r is not a cloud URI; the SkyPilot launcher "
+                "will not persist it back to Tangle's storage. Use a cloud "
+                "StorageProvider (gs://, s3://, ...) to persist outputs.",
+                output_name, uri,
+            )
+            sanitized = naming_utils.sanitize_file_name(output_name)
+            return f"/tmp/outputs/{sanitized}/{_CONTAINER_FILE_NAME}"
+
+        resolved = container_component_utils.resolve_container_command_line(
+            component_spec=component_spec,
+            provided_input_names=set(input_arguments.keys()),
+            get_input_value=get_input_value,
+            get_input_path=get_input_path,
+            get_output_path=get_output_path,
+        )
+
+        cmd_str = _shell_quote_argv(list(resolved.command) + list(resolved.args))
+        run_script = "set -euo pipefail\n" + _MULTI_NODE_ENV_PRELUDE + cmd_str + "\n"
+
+        # Name the SkyPilot job after the component for easy filtering.
+        component_name = (
+            component_spec.name
+            or (component_spec.metadata.name if component_spec.metadata else None)
+            or "task"
+        )
+        job_name = self._job_name_prefix + naming_utils.sanitize_file_name(
+            component_name
+        )
+
+        image = container_spec.image or self._default_image
+        if not image:
+            raise interfaces.LauncherError(
+                f"Component '{component_name}' has no container image and the "
+                "launcher was not configured with default_image."
+            )
+
+        # Resources kwargs
+        resources_kwargs: dict[str, Any] = {
+            # SkyPilot accepts container images via image_id="docker:<image>" —
+            # this is the canonical way to launch an arbitrary user container
+            # in K8s.
+            "image_id": f"docker:{image}",
+        }
+        if self._infra is not None:
+            resources_kwargs["infra"] = self._infra
+        if cpus is not None:
+            resources_kwargs["cpus"] = cpus
+        if memory is not None:
+            resources_kwargs["memory"] = memory
+        if accelerators is not None:
+            # Tangle's kubernetes_launchers expects a JSON object like
+            # `{"nvidia-tesla-h100": 8}`. SkyPilot accepts either a Sky-format
+            # string ("H100:8") or a {name: count} dict — try JSON first so the
+            # same component spec works under either launcher.
+            parsed_accel: Any = accelerators
+            if isinstance(accelerators, str):
+                try:
+                    parsed_accel = json.loads(accelerators)
+                except (ValueError, TypeError):
+                    parsed_accel = accelerators
+            resources_kwargs["accelerators"] = parsed_accel
+        if ephemeral_storage is not None:
+            resources_kwargs["disk_size"] = _coerce_disk_size_gb(ephemeral_storage)
+        if labels:
+            resources_kwargs["labels"] = labels
+
+        priority_class = annotations.get(
+            PRIORITY_CLASS_ANNOTATION_KEY, self._default_priority_class
+        )
+        if priority_class is not None:
+            resources_kwargs["priority_class"] = priority_class
+
+        spot_value = annotations.get(SPOT_ANNOTATION_KEY)
+        if spot_value is not None:
+            resources_kwargs["use_spot"] = str(spot_value).lower() in (
+                "1", "true", "yes",
+            )
+        elif self._default_use_spot is not None:
+            resources_kwargs["use_spot"] = self._default_use_spot
+
+        # Build a sky YAML-shaped dict and let Task.from_yaml_config parse it.
+        # The YAML parser auto-promotes cloud-URI entries in file_mounts into
+        # sky.Storage MOUNT mounts (sky/task.py:660-688), which is the path that
+        # works under consolidation mode. Constructing sky.Task() directly with
+        # cloud-URI file_mounts goes through translate_local_file_mounts_to_two_hop
+        # which rejects them.
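+        # The cloud mounts accumulated above land in the config as-is; e.g. an
+        # input at gs://my-bucket/runs/7/data (hypothetical) produces
+        #   file_mounts = {"/mnt/skypilot/gs/my-bucket":
+        #                  {"source": "gs://my-bucket", "mode": "MOUNT"}}
+        # and resolves in the command line to /mnt/skypilot/gs/my-bucket/runs/7/data.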
+        task_config: dict[str, Any] = {
+            "name": job_name,
+            "run": run_script,
+            "envs": envs,
+            "num_nodes": num_nodes,
+            "resources": resources_kwargs,
+        }
+        if file_mounts:
+            task_config["file_mounts"] = file_mounts
+        task = sky.Task.from_yaml_config(task_config)
+        return task
+
+
+class SkyPilotLaunchedJob(interfaces.LaunchedContainer):
+    """Tangle-side handle around a SkyPilot managed job.
+
+    Wraps a job_id; status/logs are queried lazily via the sky SDK. The handle
+    is serializable via to_dict / from_dict so the orchestrator can persist it
+    in the request DB and reload across restarts.
+    """
+
+    def __init__(
+        self,
+        handle: _SkyPilotJobHandle,
+        *,
+        storage_provider: Optional[
+            storage_provider_interfaces.StorageProvider
+        ] = None,
+    ):
+        self._handle = handle
+        self._storage_provider = storage_provider
+
+    @property
+    def id(self) -> str:
+        return f"sky:{self._handle.job_id}"
+
+    @property
+    def job_id(self) -> int:
+        return self._handle.job_id
+
+    @property
+    def status(self) -> interfaces.ContainerStatus:
+        return _STATUS_MAP.get(
+            self._handle.cached_status or "PENDING", interfaces.ContainerStatus.ERROR
+        )
+
+    @property
+    def exit_code(self) -> Optional[int]:
+        if not self.has_ended:
+            return None
+        return 0 if self.has_succeeded else 1
+
+    @property
+    def has_ended(self) -> bool:
+        return (self._handle.cached_status or "PENDING") in _TERMINAL_STATUSES
+
+    @property
+    def has_succeeded(self) -> bool:
+        return self._handle.cached_status == "SUCCEEDED"
+
+    @property
+    def has_failed(self) -> bool:
+        return self.has_ended and not self.has_succeeded
+
+    @property
+    def started_at(self) -> Optional[datetime.datetime]:
+        if self._handle.cached_started_at is None:
+            return None
+        return datetime.datetime.fromtimestamp(
+            self._handle.cached_started_at, tz=datetime.timezone.utc
+        )
+
+    @property
+    def ended_at(self) -> Optional[datetime.datetime]:
+        if self._handle.cached_ended_at is None:
+            return None
+        return datetime.datetime.fromtimestamp(
+            self._handle.cached_ended_at, tz=datetime.timezone.utc
+        )
+
+    @property
+    def launcher_error_message(self) -> Optional[str]:
+        return self._handle.cached_failure_reason
+
+    # Sky returns only the "Job N is already in terminal state ..." message
+    # when tail_logs is called immediately after a job ends — the user-job
+    # log file may not yet be flushed/visible on the controller. Detect this
+    # so callers can retry.
+    _SKY_TERMINAL_STATE_HINT = "is already in terminal state"
+
+    def get_log(self) -> str:
+        import time as _time
+        # Retry briefly: when the job has just ended, sky's tail_logs returns
+        # the terminal-state hint instead of the real log file. The actual
+        # logs become available after the controller finalizes them.
+        attempts, delay = 6, 2.0
+        out = ""
+        for i in range(attempts):
+            buf = io.StringIO()
+            sky_jobs.tail_logs(
+                job_id=self._handle.job_id, follow=False, output_stream=buf
+            )
+            out = buf.getvalue()
+            if self._SKY_TERMINAL_STATE_HINT not in out or len(out) > 500:
+                # Real log content: either the hint is absent, or there's
+                # enough content that the hint is just the trailing footer.
+                return out
+            if i + 1 < attempts:
+                _time.sleep(delay)
+                delay = min(delay * 1.5, 8.0)
+        return out
+
+    def upload_log(self) -> None:
+        # Mirror the SkyPilot managed-job logs to log_uri via the orchestrator's
+        # storage provider so the Tangle UI can display them. Without this,
+        # logs are still available via ``sky jobs logs <job_id>`` but the UI's
+        # /api/.../log endpoint reads from log_uri and returns nothing.
+ if self._storage_provider is None: + _logger.debug( + "upload_log: no storage_provider configured; skipping mirror " + "(sky.jobs.tail_logs(job_id=%d) is still available).", + self._handle.job_id, + ) + return + log_text = self.get_log() + writer = self._storage_provider.make_uri(self._handle.log_uri).get_writer() + writer.upload_from_text(log_text) + + def stream_log_lines(self) -> Iterator[str]: + # Bridge sky.jobs.tail_logs(follow=True) into a generator of lines. + class _LineBuffer(io.TextIOBase): + def __init__(self) -> None: + self._buf = "" + self.lines: list[str] = [] + self.lock = threading.Lock() + + def write(self, s: str) -> int: + with self.lock: + self._buf += s + while "\n" in self._buf: + line, self._buf = self._buf.split("\n", 1) + self.lines.append(line + "\n") + return len(s) + + def take(self) -> list[str]: + with self.lock: + out, self.lines = self.lines, [] + return out + + buf = _LineBuffer() + finished = threading.Event() + + def _run() -> None: + try: + sky_jobs.tail_logs( + job_id=self._handle.job_id, follow=True, output_stream=buf + ) + finally: + finished.set() + + thread = threading.Thread(target=_run, daemon=True) + thread.start() + try: + while not finished.is_set() or buf.lines: + for line in buf.take(): + yield line + if not finished.is_set(): + finished.wait(timeout=0.5) + for line in buf.take(): + yield line + finally: + thread.join(timeout=1.0) + + def terminate(self) -> None: + request_id = sky_jobs.cancel(job_ids=[self._handle.job_id]) + sky.get(request_id) + + def to_dict(self) -> dict[str, Any]: + return { + "skypilot": { + "job_id": self._handle.job_id, + "job_name": self._handle.job_name, + "output_uris": self._handle.output_uris, + "log_uri": self._handle.log_uri, + "cached_status": self._handle.cached_status, + "cached_failure_reason": self._handle.cached_failure_reason, + "cached_started_at": self._handle.cached_started_at, + "cached_ended_at": self._handle.cached_ended_at, + } + } + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + *, + storage_provider: Optional[ + storage_provider_interfaces.StorageProvider + ] = None, + ) -> "SkyPilotLaunchedJob": + sk = d["skypilot"] + return cls( + handle=_SkyPilotJobHandle( + job_id=int(sk["job_id"]), + job_name=sk["job_name"], + output_uris=dict(sk.get("output_uris") or {}), + log_uri=sk["log_uri"], + cached_status=sk.get("cached_status"), + cached_failure_reason=sk.get("cached_failure_reason"), + cached_started_at=sk.get("cached_started_at"), + cached_ended_at=sk.get("cached_ended_at"), + ), + storage_provider=storage_provider, + ) + + def get_refreshed(self) -> "SkyPilotLaunchedJob": + request_id = sky_jobs.queue(refresh=True, job_ids=[self._handle.job_id]) + result = sky.get(request_id) + # Sky's queue() return shape varied across versions. Newer (nightly): + # tuple[list[dict], ...]; older: list[dict] directly. Unwrap to the + # list of records. + if isinstance(result, tuple): + records = result[0] if result else [] + else: + records = result or [] + # Find the record matching our job_id (queue may ignore the filter + # and return all jobs). 
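+        # A queue record is shaped roughly like (values hypothetical):
+        #   {"job_id": 42, "status": "RUNNING", "start_at": 1700000000.0,
+        #    "end_at": None, "failure_reason": None}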
+ rec = None + for r in records: + jid = r.get("job_id") if isinstance(r, dict) else getattr(r, "job_id", None) + if jid == self._handle.job_id: + rec = r + break + if rec is None: + return SkyPilotLaunchedJob( + handle=dataclasses.replace( + self._handle, + cached_status="FAILED_CONTROLLER", + cached_failure_reason="job not found in sky.jobs.queue", + ), + storage_provider=self._storage_provider, + ) + get = (lambda k: rec.get(k)) if isinstance(rec, dict) else ( + lambda k: getattr(rec, k, None) + ) + status_str = _status_to_string(get("status")) + started_at = get("start_at") + ended_at = get("end_at") + return SkyPilotLaunchedJob( + handle=dataclasses.replace( + self._handle, + cached_status=status_str, + cached_failure_reason=get("failure_reason"), + cached_started_at=float(started_at) if started_at else None, + cached_ended_at=float(ended_at) if ended_at else None, + ), + storage_provider=self._storage_provider, + ) diff --git a/examples/multicluster_inference_e2e.py b/examples/multicluster_inference_e2e.py new file mode 100644 index 0000000..160c17c --- /dev/null +++ b/examples/multicluster_inference_e2e.py @@ -0,0 +1,292 @@ +"""Multi-cluster inference example via the SkyPilot launcher. + +Two inference tasks in the same Tangle pipeline, each pinned to a +different cloud/cluster purely by accelerator constraint — sky's +optimizer routes each to the matching K8s context: + + - 'infer_gke_l4': asks for an L4 → lands on a GKE cluster + - 'infer_nebius_h100': asks for an H100 → lands on a Nebius cluster + +Both run the same Qwen2.5-0.5B-Instruct generation script on a fixed +list of prompts. Outputs land in cloud storage; a final 'compare' task +reads both and prints them side-by-side. Demonstrates SkyPilot's +cross-cluster placement under the Tangle launcher: one ComponentSpec +deployed to two different K8s clusters purely by accelerator +constraint. + +Requires SkyPilot's `kubernetes.allowed_contexts` to include both +contexts and the launcher to be initialized with `infra=None` so the +optimizer can pick per task. 
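+
+A minimal ``~/.sky/config.yaml`` sketch (context names here are
+hypothetical)::
+
+    kubernetes:
+      allowed_contexts:
+        - gke-us-central1
+        - nebius-eu-north1
+
+Note: as checked in below, the two tasks actually request H100 and H200
+(see the inline annotation comments) because that is what the test
+environment offers; swap the annotations to L4:1 / H100:1 to make the task
+names match the placement.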
+""" +from __future__ import annotations +import datetime, json, time, urllib.request, urllib.error + +BASE = "http://localhost:9091" + + +def post(path, body): + req = urllib.request.Request( + BASE + path, data=json.dumps(body).encode(), + headers={"Content-Type": "application/json"}, method="POST", + ) + with urllib.request.urlopen(req, timeout=30) as r: + return json.loads(r.read()) + + +def get(path): + try: + with urllib.request.urlopen(BASE + path, timeout=30) as r: + return json.loads(r.read()) + except urllib.error.HTTPError as e: + if e.code == 404: + return None + return {"_error": e.code, "_body": e.read().decode()[:200]} + + +_PROMPTS = [ + "The capital of France is", + "In the year 2050, robots will", + "A haiku about distributed computing:", +] + + +_INFER_PY = r""" +import json, os, socket, time +prompts_in = os.environ['PROMPTS_PATH'] +out_path = os.environ['OUTPUT_PATH'] +model_id = os.environ.get('MODEL_ID', 'Qwen/Qwen2.5-0.5B-Instruct') +os.makedirs(os.path.dirname(out_path), exist_ok=True) + +with open(prompts_in) as f: + prompts = json.load(f) +print(f'[{socket.gethostname()}] loaded {len(prompts)} prompts', flush=True) + +print('[importing transformers]', flush=True) +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +print(f'[torch={torch.__version__} cuda={torch.cuda.is_available()} ' + f'gpu={torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"}]', + flush=True) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'[loading {model_id}]', flush=True) +tok = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32, +).to(device).eval() + +results = [] +for p in prompts: + t0 = time.time() + # Qwen2.5-Instruct uses a chat template; format as a single user turn. + messages = [{'role': 'user', 'content': p}] + chat = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + ids = tok(chat, return_tensors='pt').to(device) + with torch.no_grad(): + out = model.generate(**ids, max_new_tokens=64, do_sample=False, + pad_token_id=tok.eos_token_id) + new_tokens = out[0][ids['input_ids'].shape[1]:] + text = tok.decode(new_tokens, skip_special_tokens=True).strip() + elapsed_ms = round((time.time() - t0) * 1000, 1) + print(f'[{p!r}] ({elapsed_ms}ms) -> {text!r}', flush=True) + results.append({'prompt': p, 'completion': text, + 'elapsed_ms': elapsed_ms, + 'model': model_id, + 'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu', + 'host': socket.gethostname()}) + +with open(out_path, 'w') as f: + json.dump(results, f, indent=2) +print(f'[wrote {out_path}, {os.path.getsize(out_path)} bytes]', flush=True) +""" + + +# --- Task 1: prepare prompts (CPU only, lands wherever sky picks) ----------- +prepare_spec = { + "name": "skypilot-prepare-prompts", + "outputs": [{"name": "prompts", "type": "String"}], + "implementation": { + "container": { + "image": "python:3.11-slim", + "command": [ + "bash", "-c", + 'set -euo pipefail; mkdir -p "$(dirname "$0")"; ' + f"python3 -c 'import json,sys; json.dump({json.dumps(_PROMPTS)}, open(sys.argv[1], \"w\"))' \"$0\"; " + 'echo "wrote prompts to $0"; cat "$0"', + {"outputPath": "prompts"}, + ], + } + }, +} + + +def _make_inference_spec(suffix: str) -> dict: + return { + # Using the same component name across both inference tasks would + # hit Tangle's cache and reuse one execution for both. 
Encoding + # `-` into the component name keeps the cache keys + # distinct AND surfaces the placement in the Tangle UI / sky + # dashboard at a glance. + "name": f"skypilot-qwen-inference-{suffix}", + "inputs": [{"name": "prompts", "type": "String"}], + "outputs": [{"name": "completions", "type": "String"}], + "implementation": { + "container": { + "image": "pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime", + "env": {"COMPONENT_VARIANT": suffix}, + "command": [ + "bash", "-c", + 'set -euo pipefail; ' + 'export PROMPTS_PATH="$0"; export OUTPUT_PATH="$1"; ' + # transformers isn't bundled in pytorch image — pip + # install once, ~10s on a cold pod. + 'pip install -q --no-cache-dir transformers==4.41.1 >/dev/null; ' + 'nvidia-smi -L; ' + f"python3 -u <<'PYEOF'\n{_INFER_PY}\nPYEOF", + {"inputPath": "prompts"}, + {"outputPath": "completions"}, + ], + } + }, + } + + +# --- Task 4: print results from both clusters side-by-side ------------------ +compare_spec = { + "name": "skypilot-compare-completions", + "inputs": [ + {"name": "gke_l4_completions", "type": "String"}, + {"name": "nebius_h100_completions", "type": "String"}, + ], + "outputs": [{"name": "report", "type": "String"}], + "implementation": { + "container": { + "image": "python:3.11-slim", + "command": [ + "bash", "-c", + 'set -euo pipefail; mkdir -p "$(dirname "$2")"; ' + 'python3 - "$0" "$1" "$2" <<\'PY\'\n' + 'import json, sys\n' + 'a = json.load(open(sys.argv[1])) # gke-l4\n' + 'b = json.load(open(sys.argv[2])) # nebius-h100\n' + 'lines = ["=== Multi-cluster inference comparison ==="]\n' + 'for pa, pb in zip(a, b):\n' + ' lines.append(f"prompt: {pa[\'prompt\']!r}")\n' + ' lines.append(f" gke-l4 ({pa[\'gpu\']} on {pa[\'host\']}, {pa[\'elapsed_ms\']}ms): {pa[\'completion\']!r}")\n' + ' lines.append(f" nebius-h100 ({pb[\'gpu\']} on {pb[\'host\']}, {pb[\'elapsed_ms\']}ms): {pb[\'completion\']!r}")\n' + ' lines.append("")\n' + 'report = "\\n".join(lines)\n' + 'print(report)\n' + 'open(sys.argv[3], "w").write(report + "\\n")\n' + 'PY', + {"inputPath": "gke_l4_completions"}, + {"inputPath": "nebius_h100_completions"}, + {"outputPath": "report"}, + ], + } + }, +} + + +ts = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S") +pipeline_spec = { + "name": f"skypilot-multicluster-inference-{ts}", + "outputs": [{"name": "report", "type": "String"}], + "implementation": { + "graph": { + "tasks": { + "prepare": { + "componentRef": {"spec": prepare_spec}, + "annotations": { + "cloud-pipelines.net/launchers/generic/resources.cpu": "1", + "cloud-pipelines.net/launchers/generic/resources.memory": "1", + }, + }, + "infer_gke_l4": { + "componentRef": {"spec": _make_inference_spec("gke-l4")}, + "arguments": { + "prompts": {"taskOutput": {"taskId": "prepare", "outputName": "prompts"}} + }, + "annotations": { + "cloud-pipelines.net/launchers/generic/resources.cpu": "2", + "cloud-pipelines.net/launchers/generic/resources.memory": "8", + # H100 is what's actually available in our test + # environment; swap to "L4:1" once a GKE-L4 cluster + # is in allowed_contexts to make the name match + # the placement. 
+ "cloud-pipelines.net/launchers/generic/resources.accelerators": "H100:1", + }, + }, + "infer_nebius_h100": { + "componentRef": {"spec": _make_inference_spec("nebius-h100")}, + "arguments": { + "prompts": {"taskOutput": {"taskId": "prepare", "outputName": "prompts"}} + }, + "annotations": { + "cloud-pipelines.net/launchers/generic/resources.cpu": "2", + "cloud-pipelines.net/launchers/generic/resources.memory": "8", + # Asking for H200 here so this task is forced onto a + # different K8s context than the H100 one, exercising + # cross-cluster placement. Swap to "H100:1" once a + # Nebius-H100 cluster is in allowed_contexts. + "cloud-pipelines.net/launchers/generic/resources.accelerators": "H200:1", + }, + }, + "compare": { + "componentRef": {"spec": compare_spec}, + "arguments": { + "gke_l4_completions": {"taskOutput": {"taskId": "infer_gke_l4", "outputName": "completions"}}, + "nebius_h100_completions": {"taskOutput": {"taskId": "infer_nebius_h100", "outputName": "completions"}}, + }, + "annotations": { + "cloud-pipelines.net/launchers/generic/resources.cpu": "1", + "cloud-pipelines.net/launchers/generic/resources.memory": "1", + }, + }, + }, + "outputValues": { + "report": {"taskOutput": {"taskId": "compare", "outputName": "report"}} + }, + } + }, +} + +print(f"=== submit multi-cluster inference (ts={ts}) ===") +body = {"root_task": {"componentRef": {"spec": pipeline_spec}, "arguments": {}}} +run = post("/api/pipeline_runs/", body) +print(json.dumps(run, indent=2)) +root_exec = run["root_execution_id"] + +print(f"\n=== poll graph_execution_state for {root_exec} ===") +deadline = time.time() + 1800 +last = None +while time.time() < deadline: + state = get(f"/api/executions/{root_exec}/graph_execution_state") + line = json.dumps(state.get("child_execution_status_stats", {})) if state else "" + if line != last: + print(f" [{time.strftime('%H:%M:%S')}] {line}", flush=True) + last = line + stats = (state or {}).get("child_execution_status_stats", {}) or {} + summary = {} + for child_id, status_dict in stats.items(): + for status, count in status_dict.items(): + summary[status] = summary.get(status, 0) + count + if any(summary.get(k, 0) > 0 for k in ("FAILED", "SYSTEM_ERROR", "INVALID", "CANCELLED")): + break + if (summary.get("SUCCEEDED", 0) >= 4 and + not any(summary.get(k, 0) > 0 + for k in ("PENDING", "QUEUED", "RUNNING", "WAITING_FOR_UPSTREAM", + "STARTING"))): + break + time.sleep(20) + +print(f"\n=== child task statuses ===") +details = get(f"/api/executions/{root_exec}/details") +child_ids = (details or {}).get("child_task_execution_ids", {}) or {} +for task_id, exec_id in child_ids.items(): + cstate = get(f"/api/executions/{exec_id}/container_state") + print(f" {task_id}: status={(cstate or {}).get('status')} " + f"exit_code={(cstate or {}).get('exit_code')}") + if cstate and cstate.get("debug_info", {}).get("skypilot"): + sky = cstate["debug_info"]["skypilot"] + print(f" sky job_id={sky.get('job_id')} name={sky.get('job_name')}") diff --git a/examples/multinode_pipeline_e2e.py b/examples/multinode_pipeline_e2e.py new file mode 100644 index 0000000..9dc7ca1 --- /dev/null +++ b/examples/multinode_pipeline_e2e.py @@ -0,0 +1,226 @@ +"""Real GPU PyTorch DDP multi-node test on CoreWeave H100s. + +Two pods, one H100 each, NCCL backend. Trains a small MLP with synthetic +data via DistributedDataParallel, writes a checkpoint and per-epoch +JSONL log to GCS. 
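+
+Concretely, the per-node wiring is (runtime values are injected by
+SkyPilot; a 2-node job is shown for illustration)::
+
+    SKYPILOT_NUM_NODES=2 -> TANGLE_MULTI_NODE_NUMBER_OF_NODES -> WORLD_SIZE
+    SKYPILOT_NODE_RANK   -> TANGLE_MULTI_NODE_NODE_INDEX      -> RANK
+    SKYPILOT_NODE_IPS    -> TANGLE_MULTI_NODE_NODE_0_ADDRESS  -> MASTER_ADDR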
+ +The launcher's run-script prelude exports TANGLE_MULTI_NODE_* from +SkyPilot's runtime values; we map those onto torch.distributed's +MASTER_ADDR / RANK / WORLD_SIZE so torch can rendez-vous across pods. + +Worker pods authenticate to GCS via a GCP service-account key mounted +by SkyPilot's helm chart (gcpCredentials.enabled=true), so storage +mounts work outside GKE. +""" +from __future__ import annotations +import datetime, json, time, urllib.request, urllib.error + +BASE = "http://localhost:9091" + + +def post(path, body): + req = urllib.request.Request( + BASE + path, data=json.dumps(body).encode(), + headers={"Content-Type": "application/json"}, method="POST", + ) + with urllib.request.urlopen(req, timeout=30) as r: + return json.loads(r.read()) + + +def get(path): + try: + with urllib.request.urlopen(BASE + path, timeout=30) as r: + return json.loads(r.read()) + except urllib.error.HTTPError as e: + if e.code == 404: + return None + return {"_error": e.code, "_body": e.read().decode()[:200]} + + +_TRAIN_PY = r""" +import json, os, socket, time +import torch, torch.nn as nn +import torch.distributed as dist +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +rank = int(os.environ["RANK"]) +world = int(os.environ["WORLD_SIZE"]) +torch.manual_seed(42 + rank) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu" +print(f"[rank {rank}/{world}] host={socket.gethostname()} torch={torch.__version__} " + f"device={device} gpu={gpu_name} master={os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}", + flush=True) + +if device.type == "cuda": + torch.cuda.set_device(0) + +backend = "nccl" if device.type == "cuda" else "gloo" +dist.init_process_group(backend=backend, rank=rank, world_size=world, + timeout=__import__("datetime").timedelta(seconds=180)) +print(f"[rank {rank}] process group initialized (backend={backend})", flush=True) + +# Synthetic regression — a small MLP that gives the GPU something to do. 
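+# 32768 samples of 256-dim features, labels from a fixed random linear map
+# plus offset and noise; the DistributedSampler below shards them so each of
+# the `world` ranks sees a disjoint ~N/world slice per epoch.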
+N, D, H = 32768, 256, 512 +torch.manual_seed(0) +X = torch.randn(N, D) +W_true = torch.randn(D, 1) * 0.5 +y = X @ W_true + 0.1 + 0.05 * torch.randn(N, 1) +ds = TensorDataset(X, y) + +batch = int(os.environ.get("BATCH_SIZE", "128")) +epochs = int(os.environ.get("EPOCHS", "5")) +lr = float(os.environ.get("LR", "0.01")) + +sampler = DistributedSampler(ds, num_replicas=world, rank=rank, shuffle=True) +loader = DataLoader(ds, batch_size=batch, sampler=sampler) + +model = nn.Sequential(nn.Linear(D, H), nn.ReLU(), + nn.Linear(H, H), nn.ReLU(), + nn.Linear(H, 1)).to(device) +ddp = DDP(model, device_ids=[0] if device.type == "cuda" else None) +opt = torch.optim.Adam(ddp.parameters(), lr=lr) +loss_fn = nn.MSELoss() + +log_lines = [] +t0 = time.time() +for epoch in range(epochs): + sampler.set_epoch(epoch) + epoch_loss, n_batches = 0.0, 0 + for xb, yb in loader: + xb, yb = xb.to(device), yb.to(device) + pred = ddp(xb) + loss = loss_fn(pred, yb) + opt.zero_grad() + loss.backward() + opt.step() + epoch_loss += loss.item() + n_batches += 1 + t = torch.tensor([epoch_loss, float(n_batches)], device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + avg = (t[0] / t[1]).item() + msg = {"epoch": epoch + 1, "rank": rank, "avg_loss": round(avg, 6), + "elapsed_s": round(time.time() - t0, 2), "device": str(device)} + print(f"[rank {rank}] {msg}", flush=True) + log_lines.append(json.dumps(msg)) + +dist.barrier() +if rank == 0: + ckpt_path = os.environ["OUTPUT_CKPT"] + log_path = os.environ["OUTPUT_LOG"] + os.makedirs(os.path.dirname(ckpt_path), exist_ok=True) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + torch.save({"state_dict": ddp.module.state_dict(), + "world_size": world, "epochs": epochs, + "device": str(device), "gpu": gpu_name}, ckpt_path) + with open(log_path, "w") as f: + f.write("\n".join(log_lines) + "\n") + print(f"[rank 0] wrote {ckpt_path} ({os.path.getsize(ckpt_path)} bytes) " + f"and {log_path}", flush=True) +dist.destroy_process_group() +print(f"[rank {rank}] done", flush=True) +""" + +ts = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%S") +training_spec = { + # The "skypilot-" prefix surfaces the launcher in Tangle's UI / sky + # dashboard — both display the component name unchanged. 
+ "name": f"skypilot-pytorch-ddp-h100-{ts}", + "outputs": [ + {"name": "checkpoint", "type": "Model"}, + {"name": "training_log", "type": "String"}, + ], + "implementation": { + "container": { + "image": "pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime", + "env": {"PIPELINE_RUN_TS": ts}, + "command": [ + "bash", "-c", + 'set -euo pipefail; ' + 'export MASTER_ADDR="${TANGLE_MULTI_NODE_NODE_0_ADDRESS:-localhost}"; ' + 'export MASTER_PORT="29500"; ' + 'export RANK="${TANGLE_MULTI_NODE_NODE_INDEX:-0}"; ' + 'export WORLD_SIZE="${TANGLE_MULTI_NODE_NUMBER_OF_NODES:-1}"; ' + 'export OUTPUT_CKPT="$0"; ' + 'export OUTPUT_LOG="$1"; ' + 'export EPOCHS="5"; export BATCH_SIZE="128"; export LR="0.01"; ' + 'echo "[$(hostname)] rank=$RANK/$WORLD_SIZE master=$MASTER_ADDR:$MASTER_PORT"; ' + 'nvidia-smi -L 2>/dev/null || echo "nvidia-smi unavailable"; ' + f'python3 -u <<\'PYEOF\'\n{_TRAIN_PY}\nPYEOF', + {"outputPath": "checkpoint"}, + {"outputPath": "training_log"}, + ], + } + }, +} + +pipeline_spec = { + "name": f"skypilot-pytorch-ddp-h100-pipeline-{ts}", + "outputs": [{"name": "checkpoint", "type": "Model"}], + "implementation": { + "graph": { + "tasks": { + "train": { + "componentRef": {"spec": training_spec}, + "annotations": { + "tangleml.com/launchers/kubernetes/multi_node/number_of_nodes": "2", + "cloud-pipelines.net/launchers/generic/resources.cpu": "4", + "cloud-pipelines.net/launchers/generic/resources.memory": "16", + "cloud-pipelines.net/launchers/generic/resources.accelerators": "H100:1", + }, + }, + }, + "outputValues": { + "checkpoint": { + "taskOutput": {"taskId": "train", "outputName": "checkpoint"} + } + }, + } + }, +} + +print(f"=== submit pytorch-ddp pipeline (ts={ts}) ===") +body = {"root_task": {"componentRef": {"spec": pipeline_spec}, "arguments": {}}} +run = post("/api/pipeline_runs/", body) +print(json.dumps(run, indent=2)) +root_exec = run["root_execution_id"] + +print(f"\n=== poll graph_execution_state for {root_exec} ===") +deadline = time.time() + 1800 +last = None +while time.time() < deadline: + state = get(f"/api/executions/{root_exec}/graph_execution_state") + line = json.dumps(state.get("child_execution_status_stats", {})) if state else "" + if line != last: + print(f" [{time.strftime('%H:%M:%S')}] {line}", flush=True) + last = line + stats = (state or {}).get("child_execution_status_stats", {}) or {} + summary = {} + for child_id, status_dict in stats.items(): + for status, count in status_dict.items(): + summary[status] = summary.get(status, 0) + count + if any(summary.get(k, 0) > 0 for k in ("FAILED", "SYSTEM_ERROR", "INVALID", "CANCELLED")): + break + if (summary.get("SUCCEEDED", 0) >= 1 and + not any(summary.get(k, 0) > 0 + for k in ("PENDING", "QUEUED", "RUNNING", "WAITING_FOR_UPSTREAM", + "STARTING"))): + break + time.sleep(15) + +print(f"\n=== final root state ===") +print(json.dumps(get(f"/api/executions/{root_exec}/graph_execution_state"), indent=2)[:2500]) + +print(f"\n=== child task statuses ===") +details = get(f"/api/executions/{root_exec}/details") +child_ids = (details or {}).get("child_task_execution_ids", {}) or {} +for task_id, exec_id in child_ids.items(): + cstate = get(f"/api/executions/{exec_id}/container_state") + print(f" {task_id}: status={(cstate or {}).get('status')} " + f"exit_code={(cstate or {}).get('exit_code')}") + if cstate and cstate.get("debug_info", {}).get("skypilot"): + sky = cstate["debug_info"]["skypilot"] + print(f" sky job_id={sky.get('job_id')} name={sky.get('job_name')}") diff --git a/examples/publish_skypilot_components.py 
b/examples/publish_skypilot_components.py
new file mode 100644
index 0000000..931af32
--- /dev/null
+++ b/examples/publish_skypilot_components.py
@@ -0,0 +1,206 @@
+"""Publish SkyPilot-flavored components into Tangle's component library.
+
+Run once after starting Tangle (start_local_skypilot.py) and the components
+will appear in the UI's component picker when building pipelines. Each
+component name is prefixed with "SkyPilot:" so they're easy to spot.
+
+Usage:
+    python examples/publish_skypilot_components.py
+
+Idempotent — already-published components return 409 and are skipped.
+"""
+from __future__ import annotations
+import json
+import sys
+import urllib.error
+import urllib.request
+
+BASE = "http://localhost:9091"
+
+
+_GPU_SANITY_CHECK = """\
+name: 'SkyPilot: GPU Sanity Check'
+description: |
+  Demonstrates SkyPilot-specific capabilities exposed through the Tangle
+  launcher contract:
+    * GPU accelerator request via cloud-pipelines.net resource annotations
+    * Multi-node coordination via TANGLE_MULTI_NODE_* env vars (bridged from
+      SKYPILOT_NODE_RANK / NODE_IPS by the SkyPilot launcher prelude)
+    * SkyPilot-only annotations: priority_class (Kueue) and use_spot
+
+inputs:
+  - {name: epochs, type: Integer, default: "1"}
+outputs:
+  - {name: report, type: String, description: "Sanity-check report."}
+implementation:
+  container:
+    image: pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
+    command:
+      # bash, not sh: `set -o pipefail` is a bashism and fails under dash,
+      # which is what /bin/sh resolves to in this Ubuntu-based image.
+      - bash
+      - -c
+      - |
+        set -euo pipefail
+        mkdir -p "$(dirname "$1")"
+        python -c "
+        import os, socket, datetime
+        try:
+            import torch
+            cuda_ok = torch.cuda.is_available()
+            n = torch.cuda.device_count() if cuda_ok else 0
+            name = torch.cuda.get_device_name(0) if cuda_ok else 'no-gpu'
+        except Exception as e:
+            cuda_ok, n, name = False, 0, f'torch-import-failed: {e}'
+        rank = os.environ.get('TANGLE_MULTI_NODE_NODE_INDEX', '0')
+        nnodes = os.environ.get('TANGLE_MULTI_NODE_NUMBER_OF_NODES', '1')
+        peer0 = os.environ.get('TANGLE_MULTI_NODE_NODE_0_ADDRESS', 'localhost')
+        msg = (f'host={socket.gethostname()} rank={rank}/{nnodes} '
+               f'peer0={peer0} cuda={cuda_ok} ndev={n} gpu={name} '
+               f'epochs=$0 ts={datetime.datetime.utcnow().isoformat()}Z')
+        print(msg)
+        with open('$1', 'w') as f: f.write(msg + chr(10))
+        " "$0" "$1"
+      - {inputValue: epochs}
+      - {outputPath: report}
+"""


+_PYTORCH_DDP = """\
+name: 'SkyPilot: Multi-node PyTorch DDP'
+description: |
+  Real multi-node PyTorch DistributedDataParallel training driven by the
+  SkyPilot launcher. Two pods, NCCL backend, synthetic regression on a
+  small MLP — gradients are all-reduced across ranks each step. Rank 0
+  saves a checkpoint and per-epoch JSONL log to the configured outputs.
+
+  Annotate the TaskSpec with
+    tangleml.com/launchers/kubernetes/multi_node/number_of_nodes: "2"
+    cloud-pipelines.net/launchers/generic/resources.accelerators: "H100:1"
+  to launch as 2 pods × 1 H100 each. Set num_nodes=1 for single-pod
+  multi-GPU. Drop the accelerators annotation for a CPU-only run with
+  the gloo backend (slower but no GPU needed).
+
+  Verified on CoreWeave H100s (sky-dev cluster):
+    loss 3.46 → 0.08 → 0.057 → 0.025 → 0.026 over 5 epochs
+    ~18s job duration; rank-synchronized all-reduce confirms gradient sync
+    rank 0 wrote a 1.5 MiB checkpoint to gs://tangle-skypilot-test-zhwu/...
+  Worker pods get GCS auth via SkyPilot's gcpCredentials helm option
+  (mounts a GCP service-account key + GOOGLE_APPLICATION_CREDENTIALS).
+ +outputs: + - {name: checkpoint, type: Model} + - {name: training_log, type: String} +implementation: + container: + image: pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime + command: + - bash + - -c + - | + set -euo pipefail + export MASTER_ADDR="${TANGLE_MULTI_NODE_NODE_0_ADDRESS:-localhost}" + export MASTER_PORT="29500" + export RANK="${TANGLE_MULTI_NODE_NODE_INDEX:-0}" + export WORLD_SIZE="${TANGLE_MULTI_NODE_NUMBER_OF_NODES:-1}" + export OUTPUT_CKPT="$0"; export OUTPUT_LOG="$1" + : "${EPOCHS:=5}"; : "${BATCH_SIZE:=128}"; : "${LR:=0.01}" + echo "[$(hostname)] rank=$RANK/$WORLD_SIZE master=$MASTER_ADDR:$MASTER_PORT" + nvidia-smi -L 2>/dev/null || echo "nvidia-smi unavailable" + python3 -u <<'PY' + import json, os, socket, time, datetime + import torch, torch.nn as nn + import torch.distributed as dist + from torch.utils.data import DataLoader, TensorDataset + from torch.utils.data.distributed import DistributedSampler + from torch.nn.parallel import DistributedDataParallel as DDP + rank = int(os.environ['RANK']); world = int(os.environ['WORLD_SIZE']) + torch.manual_seed(42 + rank) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu' + print(f'[rank {rank}/{world}] host={socket.gethostname()} torch={torch.__version__} ' + f'device={device} gpu={gpu} master={os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + flush=True) + if device.type == 'cuda': + torch.cuda.set_device(0) + backend = 'nccl' if device.type == 'cuda' else 'gloo' + dist.init_process_group(backend=backend, rank=rank, world_size=world, + timeout=datetime.timedelta(seconds=180)) + print(f'[rank {rank}] process group initialized (backend={backend})', flush=True) + N, D, H = 32768, 256, 512 + torch.manual_seed(0) + X = torch.randn(N, D) + W = torch.randn(D, 1) * 0.5 + y = X @ W + 0.1 + 0.05 * torch.randn(N, 1) + ds = TensorDataset(X, y) + sampler = DistributedSampler(ds, num_replicas=world, rank=rank, shuffle=True) + loader = DataLoader(ds, batch_size=int(os.environ['BATCH_SIZE']), sampler=sampler) + model = nn.Sequential(nn.Linear(D, H), nn.ReLU(), + nn.Linear(H, H), nn.ReLU(), + nn.Linear(H, 1)).to(device) + ddp = DDP(model, device_ids=[0] if device.type == 'cuda' else None) + opt = torch.optim.Adam(ddp.parameters(), lr=float(os.environ['LR'])) + loss_fn = nn.MSELoss() + log_lines, t0 = [], time.time() + for epoch in range(int(os.environ['EPOCHS'])): + sampler.set_epoch(epoch) + running, n = 0.0, 0 + for xb, yb in loader: + xb, yb = xb.to(device), yb.to(device) + pred = ddp(xb) + loss = loss_fn(pred, yb) + opt.zero_grad(); loss.backward(); opt.step() + running += loss.item(); n += 1 + t = torch.tensor([running, float(n)], device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + msg = {'epoch': epoch + 1, 'rank': rank, + 'avg_loss': round((t[0]/t[1]).item(), 6), + 'elapsed_s': round(time.time() - t0, 2), + 'device': str(device)} + print(f'[rank {rank}]', json.dumps(msg), flush=True) + log_lines.append(json.dumps(msg)) + dist.barrier() + if rank == 0: + ckpt, log = os.environ['OUTPUT_CKPT'], os.environ['OUTPUT_LOG'] + os.makedirs(os.path.dirname(ckpt), exist_ok=True) + os.makedirs(os.path.dirname(log), exist_ok=True) + torch.save({'state_dict': ddp.module.state_dict(), + 'world_size': world, 'gpu': gpu, 'device': str(device)}, ckpt) + with open(log, 'w') as f: + f.write('\\n'.join(log_lines) + '\\n') + print(f'[rank 0] wrote {ckpt} ({os.path.getsize(ckpt)} bytes)', flush=True) + dist.destroy_process_group() + print(f'[rank 
{rank}] done', flush=True)
+        PY
+      - {outputPath: checkpoint}
+      - {outputPath: training_log}
+"""
+
+
+_COMPONENTS = [
+    ("SkyPilot: GPU Sanity Check", _GPU_SANITY_CHECK),
+    ("SkyPilot: Multi-node PyTorch DDP", _PYTORCH_DDP),
+]
+
+
+def publish(name: str, text: str) -> None:
+    body = {"text": text}
+    req = urllib.request.Request(
+        BASE + "/api/published_components/",
+        data=json.dumps(body).encode(),
+        headers={"Content-Type": "application/json"},
+        method="POST",
+    )
+    try:
+        with urllib.request.urlopen(req, timeout=30) as r:
+            resp = json.loads(r.read())
+        print(f"published: {name!r} digest={resp.get('digest')}")
+    except urllib.error.HTTPError as e:
+        err_body = e.read().decode()[:300]
+        if e.code == 409 or "already exists" in err_body:
+            print(f"already exists: {name!r}")
+        else:
+            print(f"FAILED: {name!r} HTTP {e.code}: {err_body}", file=sys.stderr)
+            sys.exit(1)
+
+
+for n, t in _COMPONENTS:
+    publish(n, t)
diff --git a/examples/skypilot_launcher_dryrun.py b/examples/skypilot_launcher_dryrun.py
new file mode 100644
index 0000000..d57437d
--- /dev/null
+++ b/examples/skypilot_launcher_dryrun.py
@@ -0,0 +1,274 @@
+"""End-to-end dry-run of a Tangle pipeline through the SkyPilot launcher.
+
+Exercises the full ContainerTaskLauncher contract — submit, refresh, log
+streaming, persistence round-trip, terminate — without requiring a live
+Kubernetes cluster or SkyPilot controller. The sky.jobs SDK is stubbed inline
+so the script runs anywhere Tangle is installed.
+
+Run:
+    python examples/skypilot_launcher_dryrun.py
+
+Output sections:
+  1. Translation (ComponentSpec -> sky.Task)
+  2. Submission (sky.jobs.launch)
+  3. Status refresh (sky.jobs.queue)
+  4. Log fetch (sky.jobs.tail_logs)
+  5. Persistence round-trip (orchestrator restart simulation)
+  6. 
Termination (sky.jobs.cancel)
+"""
+
+from __future__ import annotations
+
+import json
+import sys
+import textwrap
+import types
+
+
+def _stub_sky() -> dict:
+    """Stub the sky module before the launcher imports it."""
+    submissions: list = []
+    cancellations: list = []
+
+    sky_mod = types.ModuleType("sky")
+
+    class _FakeTask:
+        def __init__(self, *, name=None, run=None, envs=None, num_nodes=1,
+                     file_mounts=None, **kwargs):
+            self.name = name
+            self.run = run
+            self.envs = envs or {}
+            self.num_nodes = num_nodes
+            self.file_mounts = file_mounts
+            self.resources = None
+
+        @classmethod
+        def from_yaml_config(cls, config):
+            # The launcher builds its task via sky.Task.from_yaml_config, so
+            # the stub must accept the same YAML-shaped dict (name, run, envs,
+            # num_nodes, resources, file_mounts).
+            task = cls(
+                name=config.get("name"),
+                run=config.get("run"),
+                envs=config.get("envs"),
+                num_nodes=config.get("num_nodes", 1),
+                file_mounts=config.get("file_mounts"),
+            )
+            task.resources = _FakeResources(**(config.get("resources") or {}))
+            return task
+
+        def set_resources(self, r):
+            self.resources = r
+            return self
+
+        def set_file_mounts(self, fm):
+            self.file_mounts = fm
+            return self
+
+    class _FakeResources:
+        def __init__(self, **kwargs):
+            self.kwargs = kwargs
+
+        def __repr__(self):
+            return f"Resources({self.kwargs})"
+
+    sky_mod.Task = _FakeTask
+    sky_mod.Resources = _FakeResources
+    sky_mod.get = lambda req: req
+
+    sky_jobs = types.ModuleType("sky.jobs")
+
+    def _launch(task, name=None, **kwargs):
+        submissions.append({"task": task, "name": name, "kwargs": kwargs})
+        return ([20251029], None)
+
+    def _queue(refresh=False, job_ids=None, **kwargs):
+        return [{
+            "job_id": (job_ids or [20251029])[0],
+            "status": "RUNNING",
+            "start_at": 1700000000.0,
+            "end_at": None,
+            "failure_reason": None,
+        }]
+
+    def _cancel(job_ids=None, **kwargs):
+        cancellations.append(list(job_ids or []))
+        return {"cancelled": list(job_ids or [])}
+
+    def _tail_logs(job_id=None, follow=False, output_stream=None, **kwargs):
+        if output_stream is not None:
+            output_stream.write(f"job {job_id} step 1/3 ok\n")
+            output_stream.write(f"job {job_id} step 2/3 ok\n")
+            output_stream.write(f"job {job_id} step 3/3 ok\n")
+        return 0
+
+    sky_jobs.launch = _launch
+    sky_jobs.queue = _queue
+    sky_jobs.cancel = _cancel
+    sky_jobs.tail_logs = _tail_logs
+
+    sky_mod.jobs = sky_jobs
+    sys.modules["sky"] = sky_mod
+    sys.modules["sky.jobs"] = sky_jobs
+    return {"submissions": submissions, "cancellations": cancellations}
+
+
+def _hr(title: str) -> None:
+    print(f"\n{'=' * 72}\n {title}\n{'=' * 72}")
+
+
+def main() -> None:
+    captures = _stub_sky()
+
+    from cloud_pipelines_backend import component_structures as structures
+    from cloud_pipelines_backend.launchers.skypilot_launchers import (
+        SkyPilotKubernetesLauncher,
+        SkyPilotLaunchedJob,
+        PRIORITY_CLASS_ANNOTATION_KEY,
+        SPOT_ANNOTATION_KEY,
+    )
+
+    # -----------------------------------------------------------------
+    _hr("1. 
Build a ComponentSpec — multi-node H100 fine-tune")
+    # -----------------------------------------------------------------
+    component = structures.ComponentSpec(
+        name="qwen3_finetune",
+        implementation=structures.ContainerImplementation(
+            container=structures.ContainerSpec(
+                image="ghcr.io/example/finetune:1.0",
+                command=["torchrun"],
+                args=[
+                    structures.ConcatPlaceholder([
+                        "--nnodes=",
+                        structures.InputValuePlaceholder("nnodes"),
+                    ]),
+                    structures.ConcatPlaceholder([
+                        "--node_rank=",
+                        structures.InputValuePlaceholder("rank"),
+                    ]),
+                    structures.ConcatPlaceholder([
+                        "--master_addr=",
+                        structures.InputValuePlaceholder("master"),
+                    ]),
+                    "train.py",
+                    "--data",
+                    structures.InputPathPlaceholder("dataset"),
+                    "--ckpt-out",
+                    structures.OutputPathPlaceholder("checkpoint"),
+                ],
+                env={"WANDB_PROJECT": "tangle-skypilot-demo"},
+            )
+        ),
+        inputs=[
+            structures.InputSpec(name="nnodes"),
+            structures.InputSpec(name="rank"),
+            structures.InputSpec(name="master"),
+            structures.InputSpec(name="dataset"),
+        ],
+    )
+    print(f"  Component: {component.name}")
+    print(f"  Image:     {component.implementation.container.image}")
+
+    # -----------------------------------------------------------------
+    _hr("2. Configure the SkyPilot launcher")
+    # -----------------------------------------------------------------
+    launcher = SkyPilotKubernetesLauncher(
+        infra="kubernetes",
+        pool="ml-training",      # warm-pool reuse — no Tangle K8s equivalent
+        priority_class="batch",  # first-class Kueue integration
+        default_labels={"managed-by": "tangle"},
+        annotation_to_label_keys=[
+            "ml.shopify.io/priority-class",
+        ],
+    )
+    print("  infra=kubernetes (single-cluster); set infra=None for multi-cloud")
+    print("  pool='ml-training' (warm-pool reuse)")
+    print("  priority_class='batch' (Kueue-compatible)")
+
+    # -----------------------------------------------------------------
+    _hr("3. Submit through launch_container_task — full lifecycle")
+    # -----------------------------------------------------------------
+    from cloud_pipelines_backend.launchers import interfaces as _ifaces
+
+    input_arguments = {
+        # Multi-node dynamic data: these are bridged to bash env vars derived
+        # from SKYPILOT_NUM_NODES / SKYPILOT_NODE_RANK / SKYPILOT_NODE_IPS.
+        "nnodes": _ifaces.InputArgument(
+            total_size=0, is_dir=False, staging_uri="",
+            dynamic_data="system/multi_node/number_of_nodes",
+        ),
+        "rank": _ifaces.InputArgument(
+            total_size=0, is_dir=False, staging_uri="",
+            dynamic_data="system/multi_node/node_index",
+        ),
+        "master": _ifaces.InputArgument(
+            total_size=0, is_dir=False, staging_uri="",
+            dynamic_data="system/multi_node/node_0_address",
+        ),
+        # SkyPilot accepts s3:// directly via file_mounts. Tangle's K8s
+        # launcher only does GCS (gcsfuse) or HostPath today.
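+        # (The launcher mounts the bucket root once and rewrites the
+        # InputPathPlaceholder to a sub-path under the mount point; see
+        # tests/test_skypilot_launchers.py for the exact mapping.)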
+ "dataset": _ifaces.InputArgument( + total_size=10**9, is_dir=False, + uri="s3://example-datasets/finetune.parquet", + staging_uri="", + ), + } + output_uris = { + "checkpoint": "gs://example-ckpts/qwen3-finetune/run-20260426/", + } + + handle = launcher.launch_container_task( + component_spec=component, + input_arguments=input_arguments, + output_uris=output_uris, + log_uri="gs://example-logs/qwen3-finetune/run-20260426.log", + annotations={ + # Resource asks + "cloud-pipelines.net/launchers/generic/resources.cpu": "16+", + "cloud-pipelines.net/launchers/generic/resources.memory": "256", + "cloud-pipelines.net/launchers/generic/resources.accelerators": + json.dumps({"nvidia-tesla-h100": 8}), # Tangle's JSON form + "cloud-pipelines.net/launchers/generic/resources.ephemeral_storage": + "1Ti", + # 32 nodes — above Tangle K8s launcher's hardcoded cap of 16. + "tangleml.com/launchers/kubernetes/multi_node/number_of_nodes": "32", + # SkyPilot-only: spot instances with auto-recovery via managed jobs. + SPOT_ANNOTATION_KEY: "true", + # Per-task priority override (overrides launcher default). + PRIORITY_CLASS_ANNOTATION_KEY: "interactive", + # Propagated to K8s pod label by `annotation_to_label_keys`: + "ml.shopify.io/priority-class": "interactive", + }, + ) + submission = captures["submissions"][-1] + task = submission["task"] + res = task.resources + print(f" Submitted job_id = {handle.job_id}") + print(f" Job name = {submission['name']}") + print(f" num_nodes = {task.num_nodes} " + f"(Tangle K8s cap is 16)") + print(f" resources = {res.kwargs}") + print(f" pool kwarg = {submission['kwargs'].get('pool')}") + print(f" file_mounts = {task.file_mounts}") + print("\n Generated run script (first 6 lines):") + for line in task.run.splitlines()[:6]: + print(f" {line}") + + # ----------------------------------------------------------------- + _hr("4. Refresh status from the controller") + # ----------------------------------------------------------------- + handle = handle.get_refreshed() + print(f" status = {handle.status.value}") + print(f" has_ended = {handle.has_ended}") + print(f" started_at = {handle.started_at}") + + # ----------------------------------------------------------------- + _hr("5. Fetch logs (one-shot)") + # ----------------------------------------------------------------- + print(textwrap.indent(handle.get_log(), " ")) + + # ----------------------------------------------------------------- + _hr("6. Persistence round-trip — simulate orchestrator restart") + # ----------------------------------------------------------------- + serialized = handle.to_dict() + print(f" serialized: {json.dumps(serialized, indent=2, default=str)[:200]}...") + reloaded = SkyPilotLaunchedJob.from_dict(serialized) + print(f" reloaded.job_id = {reloaded.job_id}") + print(f" reloaded.status = {reloaded.status.value}") + + # ----------------------------------------------------------------- + _hr("7. Terminate") + # ----------------------------------------------------------------- + reloaded.terminate() + print(f" cancellations sent: {captures['cancellations']}") + + print("\nAll lifecycle steps completed without errors.") + + +if __name__ == "__main__": + main() diff --git a/orchestrator_main.py b/orchestrator_main.py index 99b7e8f..3ccd381 100644 --- a/orchestrator_main.py +++ b/orchestrator_main.py @@ -10,6 +10,47 @@ from cloud_pipelines.orchestration.storage_providers import local_storage +def _build_launcher(): + """Select container launcher via TANGLE_LAUNCHER env var. 
+ + Values: + "kubernetes" (default) — KubernetesWithHostPathContainerLauncher + "kubernetes_gcs" — KubernetesWithGcsFuseContainerLauncher (GKE) + "skypilot" — SkyPilotKubernetesLauncher (requires `skypilot` extra) + """ + choice = os.environ.get("TANGLE_LAUNCHER", "kubernetes").strip().lower() + + if choice == "skypilot": + # Lazy import so deployments without the [skypilot] extra don't pay for it. + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + return SkyPilotKubernetesLauncher( + infra=os.environ.get("SKYPILOT_INFRA", "kubernetes"), + pool=os.environ.get("SKYPILOT_POOL"), + default_image=os.environ.get("DEFAULT_CONTAINER_IMAGE"), + priority_class=os.environ.get("DEFAULT_PRIORITY_CLASS"), + default_labels={"managed-by": "tangle"}, + ) + + from kubernetes import config as k8s_config_lib + from kubernetes import client as k8s_client_lib + try: + k8s_config_lib.load_incluster_config() + except Exception: + k8s_config_lib.load_kube_config() + k8s_client = k8s_client_lib.ApiClient() + k8s_client_lib.VersionApi(k8s_client).get_code(_request_timeout=5) + + if choice == "kubernetes_gcs": + return kubernetes_launchers.KubernetesWithGcsFuseContainerLauncher( + api_client=k8s_client, + ) + return kubernetes_launchers.KubernetesWithHostPathContainerLauncher( + api_client=k8s_client, + ) + + def main(): logger = logging.getLogger(__name__) orchestrator_logger = logging.getLogger("cloud_pipelines_backend.orchestrator_sql") @@ -46,27 +87,12 @@ def main(): artifact_store_root_dir = (pathlib.Path.cwd() / "tmp" / "artifacts").as_posix() log_store_root_dir = (pathlib.Path.cwd() / "tmp" / "logs").as_posix() - from kubernetes import config as k8s_config_lib - from kubernetes import client as k8s_client_lib - - try: - k8s_config_lib.load_incluster_config() - except Exception: - k8s_config_lib.load_kube_config() - k8s_client = k8s_client_lib.ApiClient() - - k8s_client_lib.VersionApi(k8s_client).get_code(_request_timeout=5) - logger.info("Kubernetes works") - default_task_annotations = { kubernetes_launchers.RESOURCES_CPU_ANNOTATION_KEY: "1", kubernetes_launchers.RESOURCES_MEMORY_ANNOTATION_KEY: "512Mi", } - # launcher = kubernetes_launchers.KubernetesWithGcsFuseContainerLauncher( - launcher = kubernetes_launchers.KubernetesWithHostPathContainerLauncher( - api_client=k8s_client, - ) + launcher = _build_launcher() orchestrator = orchestrator_sql.OrchestratorService_Sql( session_factory=session_factory, diff --git a/pyproject.toml b/pyproject.toml index 16bf527..a54c950 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,9 @@ dev = [ huggingface = [ "huggingface-hub[oauth]>=0.35.3", ] +skypilot = [ + "skypilot>=0.12.1", +] [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/start_local_skypilot.py b/start_local_skypilot.py new file mode 100644 index 0000000..d000b10 --- /dev/null +++ b/start_local_skypilot.py @@ -0,0 +1,256 @@ +"""Local Tangle launcher using SkyPilot as the container backend. + +Drop-in replacement for start_local.py that swaps the Docker launcher for +SkyPilotKubernetesLauncher. Useful for: + - Browsing the Tangle UI without needing Docker installed. + - Demoing the SkyPilot launcher integration end-to-end. + +Run: + /home/sky/.venv/bin/uvicorn start_local_skypilot:app --host 0.0.0.0 --port 8000 + +Or: + /home/sky/.venv/bin/python -m uvicorn start_local_skypilot:app --host 0.0.0.0 --port 8000 + +The Tangle frontend (cloned to ./ui_build) is served at http://localhost:8000/. 
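+Set TANGLE_STORAGE_BUCKET to a GCS bucket to keep artifacts in cloud storage;
+multi-step pipelines require this (see the Storage region below).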
+Pipelines submitted from the UI will be dispatched through SkyPilot — they need +a configured infra (e.g. a Kubernetes context) to actually execute, but the UI +itself is fully browsable without one. +""" + +from __future__ import annotations + +import contextlib +import logging +import logging.config +import os +import pathlib +import threading +import traceback + +import fastapi +import sqlalchemy +from fastapi import staticfiles +from sqlalchemy import orm + +# region: Paths +root_data_dir = ( + os.environ.get("CLOUD_PIPELINES_BACKEND_DATA_DIR") + or os.environ.get("TANGLE_BACKEND_DATA_DIR") + or "data" +) +root_data_dir_path = pathlib.Path(root_data_dir).expanduser() +artifacts_dir_path = root_data_dir_path / "artifacts" +logs_dir_path = root_data_dir_path / "logs" + +root_data_dir_path.mkdir(parents=True, exist_ok=True) +artifacts_dir_path.mkdir(parents=True, exist_ok=True) +logs_dir_path.mkdir(parents=True, exist_ok=True) +# endregion + +# region: DB +database_path = root_data_dir_path / "db.sqlite" +database_uri = f"sqlite:///{database_path}" +print(f"{database_uri=}") +# endregion + +# region: Storage +# Choose between LocalStorageProvider (default) and GoogleCloudStorageProvider +# via TANGLE_STORAGE_BUCKET env var. SkyPilot's file_mounts can mount cloud +# URIs (gs://, s3://, abfs://) but cannot represent relative local paths, so +# multi-step pipelines need a cloud StorageProvider. +storage_bucket = os.environ.get("TANGLE_STORAGE_BUCKET") +if storage_bucket: + from cloud_pipelines.orchestration.storage_providers import google_cloud_storage + storage_provider = google_cloud_storage.GoogleCloudStorageProvider() + bucket_uri = storage_bucket.rstrip("/") + if not bucket_uri.startswith("gs://"): + bucket_uri = "gs://" + bucket_uri + artifacts_root_uri = bucket_uri + "/artifacts" + logs_root_uri = bucket_uri + "/logs" +else: + from cloud_pipelines.orchestration.storage_providers import local_storage + storage_provider = local_storage.LocalStorageProvider() + artifacts_root_uri = artifacts_dir_path.as_posix() + logs_root_uri = logs_dir_path.as_posix() +# endregion + +# region: Launcher — SkyPilot instead of Docker +from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, +) + +_infra_env = os.environ.get("SKYPILOT_INFRA") +launcher = SkyPilotKubernetesLauncher( + # Empty string -> None (let optimizer pick / use API server's in-cluster). + infra=_infra_env if _infra_env else None, + pool=os.environ.get("SKYPILOT_POOL"), + default_image=os.environ.get( + "DEFAULT_CONTAINER_IMAGE", "python:3.11-slim" + ), + default_labels={"managed-by": "tangle"}, + annotation_to_label_keys=["ml.shopify.io/priority-class"], + priority_class=os.environ.get("DEFAULT_PRIORITY_CLASS"), + # Pass the storage provider so upload_log() can mirror SkyPilot logs to + # log_uri and the Tangle UI's /api/.../log endpoint serves them. 
+ storage_provider=storage_provider, +) +# endregion + +# region: Auth (single-user placeholder — same as upstream start_local.py) +from cloud_pipelines_backend import api_router + +ADMIN_USER_NAME = "admin" +default_component_library_owner_username = ADMIN_USER_NAME + + +def get_user_details(request: fastapi.Request): + return api_router.UserDetails( + name=ADMIN_USER_NAME, + permissions=api_router.Permissions(read=True, write=True, admin=True), + ) +# endregion + +# region: Logging +from cloud_pipelines_backend.instrumentation import structured_logging + +LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": True, + "formatters": { + "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}, + "with_context": {"()": structured_logging.ContextAwareFormatter}, + }, + "filters": { + "context_filter": {"()": structured_logging.LoggingContextFilter}, + }, + "handlers": { + "default": { + "level": "INFO", + "formatter": "with_context", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + "filters": ["context_filter"], + }, + }, + "loggers": { + "": {"level": "INFO", "handlers": ["default"], "propagate": False}, + "uvicorn.error": {"level": "DEBUG", "handlers": ["default"], "propagate": False}, + "uvicorn.access": {"level": "DEBUG", "handlers": ["default"]}, + "watchfiles.main": {"level": "WARNING", "handlers": ["default"]}, + }, +} +logging.config.dictConfig(LOGGING_CONFIG) +logger = logging.getLogger(__name__) +# endregion + +# region: OpenTelemetry (no-op if not configured) +from cloud_pipelines_backend.instrumentation import opentelemetry as otel +otel.setup_providers() +# endregion + +# region: DB engine +from cloud_pipelines_backend import database_ops + +db_engine = database_ops.create_db_engine(database_uri=database_uri) +# endregion + +# region: Orchestrator +from cloud_pipelines_backend import orchestrator_sql + + +def run_configured_orchestrator(): + logger.info("Starting orchestrator (SkyPilot launcher)") + session_factory = orm.sessionmaker( + autocommit=False, autoflush=False, bind=db_engine + ) + orchestrator = orchestrator_sql.OrchestratorService_Sql( + session_factory=session_factory, + launcher=launcher, + storage_provider=storage_provider, + data_root_uri=artifacts_root_uri, + logs_root_uri=logs_root_uri, + default_task_annotations={}, + sleep_seconds_between_queue_sweeps=5.0, + ) + orchestrator.run_loop() +# endregion + +# region: API server +from cloud_pipelines_backend.instrumentation import api_tracing +from cloud_pipelines_backend.instrumentation import contextual_logging + + +@contextlib.asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + database_ops.initialize_and_migrate_db(db_engine=db_engine) + threading.Thread(target=run_configured_orchestrator, daemon=True).start() + logger.info("Tangle UI: open http://localhost:8000/ in a browser") + yield + + +app = fastapi.FastAPI( + title="Cloud Pipelines API (SkyPilot launcher)", + version="0.0.1", + separate_input_output_schemas=False, + lifespan=lifespan, +) +otel.instrument_fastapi(app) +app.add_middleware(api_tracing.RequestContextMiddleware) + + +@app.exception_handler(Exception) +def handle_error(request: fastapi.Request, exc: BaseException): + exception_str = traceback.format_exception(type(exc), exc, exc.__traceback__) + response = fastapi.responses.JSONResponse( + status_code=503, content={"exception": exception_str}, + ) + request_id = contextual_logging.get_context_metadata("request_id") + if request_id: + response.headers["x-tangle-request-id"] = 
request_id + return response + + +api_router.setup_routes( + app=app, + db_engine=db_engine, + user_details_getter=get_user_details, + container_launcher_for_log_streaming=launcher, + default_component_library_owner_username=default_component_library_owner_username, +) + + +@app.get("/services/ping") +def health_check(): + return {} + + +# Mount the prebuilt frontend (cloned from TangleML/tangle-ui to ./ui_build). +this_dir = pathlib.Path(__file__).parent +web_app_search_dirs = [ + this_dir / "ui_build", + this_dir / ".." / "ui_build", +] +mounted = False +for web_app_dir in web_app_search_dirs: + if web_app_dir.exists(): + logger.info(f"Mounting frontend from {web_app_dir}") + app.mount( + "/tangle-ui/", + staticfiles.StaticFiles(directory=web_app_dir, html=True), + name="static-tangle-ui", + ) + app.mount( + "/pipeline-studio-app/", + staticfiles.StaticFiles(directory=web_app_dir, html=True), + name="static-studio", + ) + app.mount( + "/", + staticfiles.StaticFiles(directory=web_app_dir, html=True), + name="static-root", + ) + mounted = True + break +if not mounted: + logger.warning("Frontend build files not found; UI will not be available.") +# endregion diff --git a/tests/test_skypilot_launchers.py b/tests/test_skypilot_launchers.py new file mode 100644 index 0000000..eeeb00a --- /dev/null +++ b/tests/test_skypilot_launchers.py @@ -0,0 +1,763 @@ +"""Tests for cloud_pipelines_backend.launchers.skypilot_launchers. + +Translation tests; the sky.jobs SDK calls are stubbed so the test runs offline. +""" + +from __future__ import annotations + +import dataclasses +import sys +import types + +import pytest + + +@pytest.fixture(autouse=True) +def _stub_sky(monkeypatch): + """Stub the sky module so launcher tests can run without a real SkyPilot install.""" + sky_mod = types.ModuleType("sky") + + class _FakeTask: + def __init__(self, *, name=None, run=None, envs=None, num_nodes=1, + file_mounts=None, **kwargs): + self.name = name + self.run = run + self.envs = envs or {} + self.num_nodes = num_nodes + self.file_mounts = file_mounts + self.resources = None + self.kwargs = kwargs + + def set_resources(self, r): + self.resources = r + return self + + def set_file_mounts(self, fm): + self.file_mounts = fm + return self + + @classmethod + def from_yaml_config(cls, cfg): + # Mirror sky.Task.from_yaml_config: build the task and apply Resources. 
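+            # (Only the subset of task-YAML fields that this launcher emits is
+            # mirrored here; the real sky.Task.from_yaml_config parses the full
+            # task schema.)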
+ t = cls( + name=cfg.get("name"), + run=cfg.get("run"), + envs=cfg.get("envs", {}), + num_nodes=cfg.get("num_nodes", 1), + file_mounts=cfg.get("file_mounts"), + ) + if "resources" in cfg: + t.set_resources(_FakeResources(**cfg["resources"])) + return t + + class _FakeResources: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def _fake_get(req): + return req + + sky_mod.Task = _FakeTask + sky_mod.Resources = _FakeResources + sky_mod.get = _fake_get + + sky_jobs_mod = types.ModuleType("sky.jobs") + + def _fake_launch(task, name=None, **kwargs): + return ([12345], None) + + def _fake_queue(refresh=False, job_ids=None, **kwargs): + return [{ + "job_id": (job_ids or [12345])[0], + "status": "RUNNING", + "start_at": 1700000000.0, + "end_at": None, + "failure_reason": None, + }] + + def _fake_cancel(job_ids=None, **kwargs): + return {"cancelled": list(job_ids or [])} + + def _fake_tail_logs(job_id=None, follow=False, output_stream=None, **kwargs): + if output_stream is not None: + output_stream.write(f"job {job_id} log line\n") + return 0 + + sky_jobs_mod.launch = _fake_launch + sky_jobs_mod.queue = _fake_queue + sky_jobs_mod.cancel = _fake_cancel + sky_jobs_mod.tail_logs = _fake_tail_logs + + sky_mod.jobs = sky_jobs_mod + monkeypatch.setitem(sys.modules, "sky", sky_mod) + monkeypatch.setitem(sys.modules, "sky.jobs", sky_jobs_mod) + sys.modules.pop("cloud_pipelines_backend.launchers.skypilot_launchers", None) + yield + sys.modules.pop("cloud_pipelines_backend.launchers.skypilot_launchers", None) + + +def _make_component(image="python:3.11", command=None, args=None, env=None, + inputs=None, name="test"): + from cloud_pipelines_backend import component_structures as structures + return structures.ComponentSpec( + name=name, + inputs=[structures.InputSpec(n) for n in (inputs or [])], + implementation=structures.ContainerImplementation( + container=structures.ContainerSpec( + image=image, command=command, args=args, env=env, + ) + ), + ) + + +def test_minimal_command_translation(): + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + launcher = SkyPilotKubernetesLauncher(infra="kubernetes") + component = _make_component( + image="ghcr.io/example/trainer:1.0", + command=["python", "-m", "trainer"], + args=["--epochs", "3"], + ) + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={}, + ) + assert "python -m trainer --epochs 3" in task.run + assert "TANGLE_MULTI_NODE_NODE_INDEX" in task.run + assert task.num_nodes == 1 + assert task.resources.kwargs["image_id"] == "docker:ghcr.io/example/trainer:1.0" + assert task.resources.kwargs["infra"] == "kubernetes" + + +def test_resource_annotations_propagate(): + from cloud_pipelines_backend.launchers import kubernetes_launchers as k8sL + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + launcher = SkyPilotKubernetesLauncher( + infra="kubernetes", priority_class="emergency" + ) + component = _make_component(command=["bash", "-c", "true"]) + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={ + k8sL.RESOURCES_CPU_ANNOTATION_KEY: "4+", + k8sL.RESOURCES_MEMORY_ANNOTATION_KEY: "32", + k8sL.RESOURCES_ACCELERATORS_ANNOTATION_KEY: "H100:8", + k8sL.RESOURCES_EPHEMERAL_STORAGE_ANNOTATION_KEY: "200Gi", + 
k8sL.MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY: "4",
+        },
+    )
+    r = task.resources.kwargs
+    assert r["cpus"] == "4+"
+    assert r["memory"] == "32"
+    assert r["accelerators"] == "H100:8"
+    assert r["disk_size"] == 200
+    assert r["priority_class"] == "emergency"
+    assert task.num_nodes == 4
+
+
+def test_multi_node_dynamic_data():
+    from cloud_pipelines_backend import component_structures as structures
+    from cloud_pipelines_backend.launchers import interfaces, kubernetes_launchers as k8sL
+    from cloud_pipelines_backend.launchers.skypilot_launchers import (
+        SkyPilotKubernetesLauncher,
+    )
+
+    component = _make_component(
+        command=["echo"],
+        args=[
+            structures.InputValuePlaceholder("rank"),
+            structures.InputValuePlaceholder("nnodes"),
+        ],
+        inputs=["rank", "nnodes"],
+    )
+    input_arguments = {
+        "rank": interfaces.InputArgument(
+            total_size=0, is_dir=False, staging_uri="",
+            dynamic_data="system/multi_node/node_index",
+        ),
+        "nnodes": interfaces.InputArgument(
+            total_size=0, is_dir=False, staging_uri="",
+            dynamic_data="system/multi_node/number_of_nodes",
+        ),
+    }
+    launcher = SkyPilotKubernetesLauncher(infra="kubernetes")
+    task = launcher._build_task(
+        component_spec=component,
+        container_spec=component.implementation.container,
+        input_arguments=input_arguments,
+        output_uris={},
+        annotations={k8sL.MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY: "2"},
+    )
+    assert "TANGLE_MULTI_NODE_NODE_INDEX" in task.run
+    assert "TANGLE_MULTI_NODE_NUMBER_OF_NODES" in task.run
+    assert task.num_nodes == 2
+
+
+def test_input_path_uri_becomes_file_mount():
+    from cloud_pipelines_backend import component_structures as structures
+    from cloud_pipelines_backend.launchers import interfaces
+    from cloud_pipelines_backend.launchers.skypilot_launchers import (
+        SkyPilotKubernetesLauncher,
+    )
+
+    component = _make_component(
+        command=["cat"],
+        args=[structures.InputPathPlaceholder("dataset")],
+        inputs=["dataset"],
+    )
+    input_arguments = {
+        "dataset": interfaces.InputArgument(
+            total_size=10**9,
+            is_dir=False,
+            uri="gs://example-bucket/datasets/foo.parquet",
+            staging_uri="gs://example-bucket/staging/foo",
+        ),
+    }
+    launcher = SkyPilotKubernetesLauncher(infra="kubernetes")
+    task = launcher._build_task(
+        component_spec=component,
+        container_spec=component.implementation.container,
+        input_arguments=input_arguments,
+        output_uris={},
+        annotations={},
+    )
+    assert task.file_mounts is not None
+    # SkyPilot MOUNT mode requires source to be a bucket root, not a sub-path.
+    # The launcher mounts each unique bucket once at /mnt/skypilot/<scheme>/<bucket>
+    # and uses sub-paths inside the container.
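+    # e.g. uri gs://example-bucket/datasets/foo.parquet
+    #   -> mount point /mnt/skypilot/gs/example-bucket (source: gs://example-bucket)
+    #   -> path in run script: <mount point>/datasets/foo.parquet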
+ mount_point = "/mnt/skypilot/gs/example-bucket" + assert mount_point in task.file_mounts + fm_value = task.file_mounts[mount_point] + assert isinstance(fm_value, dict) + assert fm_value["source"] == "gs://example-bucket" + assert fm_value["mode"] == "MOUNT" + assert f"{mount_point}/datasets/foo.parquet" in task.run + + +def test_priority_class_annotation_overrides_default(): + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + PRIORITY_CLASS_ANNOTATION_KEY, + ) + + launcher = SkyPilotKubernetesLauncher(priority_class="batch") + component = _make_component(command=["true"]) + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={PRIORITY_CLASS_ANNOTATION_KEY: "interactive"}, + ) + assert task.resources.kwargs["priority_class"] == "interactive" + + +def test_annotation_to_label_propagation(): + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + launcher = SkyPilotKubernetesLauncher( + annotation_to_label_keys=["ml.shopify.io/priority-class"], + ) + component = _make_component(command=["true"]) + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={"ml.shopify.io/priority-class": "interactive"}, + ) + labels = task.resources.kwargs["labels"] + assert labels["ml_shopify_io_priority-class"] == "interactive" + + +def test_serialize_round_trip(): + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotLaunchedJob, _SkyPilotJobHandle, + ) + from cloud_pipelines_backend.launchers import interfaces + + job = SkyPilotLaunchedJob( + handle=_SkyPilotJobHandle( + job_id=42, + job_name="tangle-test", + output_uris={"out": "gs://x/y"}, + log_uri="gs://x/log", + cached_status="RUNNING", + cached_started_at=1700000000.0, + ) + ) + d = job.to_dict() + job2 = SkyPilotLaunchedJob.from_dict(d) + assert job2.job_id == 42 + assert job2.status == interfaces.ContainerStatus.RUNNING + assert not job2.has_ended + + +def test_status_mapping_terminal_states(): + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotLaunchedJob, _SkyPilotJobHandle, + ) + from cloud_pipelines_backend.launchers import interfaces + + def make(status): + return SkyPilotLaunchedJob( + handle=_SkyPilotJobHandle( + job_id=1, job_name="x", output_uris={}, log_uri="", + cached_status=status, + ) + ) + + assert make("SUCCEEDED").has_succeeded + assert make("FAILED").has_failed + assert make("CANCELLED").has_failed + assert make("RUNNING").status == interfaces.ContainerStatus.RUNNING + assert make("FAILED_CONTROLLER").status == interfaces.ContainerStatus.ERROR + + +def test_missing_image_raises(): + from cloud_pipelines_backend.launchers import interfaces + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component(image=None, command=["true"]) + launcher = SkyPilotKubernetesLauncher() + with pytest.raises(interfaces.LauncherError, match="no container image"): + launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={}, + ) + + +def test_default_image_fallback(): + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component(image=None, 
command=["true"]) + launcher = SkyPilotKubernetesLauncher(default_image="python:3.11") + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={}, + ) + assert task.resources.kwargs["image_id"] == "docker:python:3.11" + + +def test_num_nodes_out_of_range(): + from cloud_pipelines_backend.launchers import interfaces, kubernetes_launchers as k8sL + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher() + with pytest.raises(interfaces.LauncherError): + launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={k8sL.MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY: "100000"}, + ) + + +# ----------------------------------------------------------------- +# SkyPilot-only capabilities — features Tangle's existing +# kubernetes_launchers cannot do today, exercised end-to-end here. +# ----------------------------------------------------------------- + + +def test_skypilot_only_num_nodes_above_tangle_k8s_cap(): + """Tangle's kubernetes_launchers caps num_nodes at 16; SkyPilot scales further.""" + from cloud_pipelines_backend.launchers import kubernetes_launchers as k8sL + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, _MULTI_NODE_MAX_NUMBER_OF_NODES, + ) + + # Sanity: this launcher's cap exceeds Tangle K8s launcher's hardcoded 16. + assert k8sL._MULTI_NODE_MAX_NUMBER_OF_NODES == 16 + assert _MULTI_NODE_MAX_NUMBER_OF_NODES > 16 + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher() + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={ + k8sL.MULTI_NODE_NUMBER_OF_NODES_ANNOTATION_KEY: "32", + }, + ) + # 32 nodes is unrepresentable in Tangle's K8s launcher; works here. + assert task.num_nodes == 32 + + +def test_skypilot_only_use_spot_with_recovery(): + """SkyPilot supports cross-cloud spot/preemptible + auto-recovery via managed + jobs. Tangle's kubernetes_launchers only has GKE-specific spot via node selector + (KUBERNETES_GOOGLE_USE_SPOT_VMS_ANNOTATION_KEY) and no preemption recovery.""" + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, SPOT_ANNOTATION_KEY, + ) + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher() + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={SPOT_ANNOTATION_KEY: "true"}, + ) + assert task.resources.kwargs["use_spot"] is True + + +def test_skypilot_only_multi_cloud_no_infra(): + """infra=None lets SkyPilot's optimizer pick across all configured clouds and + K8s contexts. 
Tangle's kubernetes_launchers takes a single api_client and is + pinned to one cluster.""" + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher(infra=None) + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={}, + ) + # Without infra, SkyPilot's optimizer picks across all configured clouds. + assert "infra" not in task.resources.kwargs + + +def test_skypilot_only_pool_dispatch_passes_through_to_launch(): + """SkyPilot Pools provide warm-pool reuse for fast cold-start. No Tangle + equivalent — every Tangle K8s task creates a fresh Pod from scratch.""" + from cloud_pipelines_backend.launchers import interfaces + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + import sky.jobs + + submissions: list = [] + original_launch = sky.jobs.launch + + def _capture(task, name=None, **kwargs): + submissions.append({"task": task, "name": name, "kwargs": kwargs}) + return ([42], None) + + sky.jobs.launch = _capture + try: + launcher = SkyPilotKubernetesLauncher(pool="ml-training") + component = _make_component(command=["true"]) + launcher.launch_container_task( + component_spec=component, + input_arguments={}, + output_uris={}, + log_uri="local:test.log", + annotations={}, + ) + finally: + sky.jobs.launch = original_launch + + assert len(submissions) == 1 + assert submissions[0]["kwargs"].get("pool") == "ml-training" + + +def test_skypilot_only_s3_file_mount_accepted(): + """SkyPilot file_mounts accept gs://, s3://, https://, abfs://, and more. + Tangle's kubernetes_launchers ships with HostPath (local) and gcsfuse (GCS) + only — no S3 / R2 / Azure Blob built in.""" + from cloud_pipelines_backend import component_structures as structures + from cloud_pipelines_backend.launchers import interfaces + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component( + command=["cat"], + args=[structures.InputPathPlaceholder("dataset")], + inputs=["dataset"], + ) + input_arguments = { + "dataset": interfaces.InputArgument( + total_size=10**9, is_dir=False, + uri="s3://my-bucket/datasets/foo.parquet", + staging_uri="", + ), + } + launcher = SkyPilotKubernetesLauncher() + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments=input_arguments, + output_uris={}, + annotations={}, + ) + # Cloud URIs are wrapped in dicts (source: ...) for storage_mount promotion. + assert any( + isinstance(v, dict) and v.get("source", "").startswith("s3://") + for v in task.file_mounts.values() + ) + + +def test_skypilot_only_first_class_priority_class(): + """SkyPilot has priority_class as a first-class Resources kwarg with built-in + Kueue integration. 
Tangle's kubernetes_launchers requires a custom + pod_postprocessor to set spec.priorityClassName — there's no annotation API + for it out of the box.""" + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, PRIORITY_CLASS_ANNOTATION_KEY, + ) + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher() + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={PRIORITY_CLASS_ANNOTATION_KEY: "interactive"}, + ) + assert task.resources.kwargs["priority_class"] == "interactive" + + +def test_accelerators_json_dict_format_compat(): + """Tangle's kubernetes_launchers expects accelerators as a JSON object + ({"nvidia-tesla-h100": 8}). Our launcher accepts that form too so the same + ComponentSpec is portable across launchers — and forwards it to SkyPilot, + which accepts {name: count} dicts directly.""" + import json as _json + from cloud_pipelines_backend.launchers import kubernetes_launchers as k8sL + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher() + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={ + k8sL.RESOURCES_ACCELERATORS_ANNOTATION_KEY: _json.dumps( + {"nvidia-tesla-h100": 8} + ), + }, + ) + assert task.resources.kwargs["accelerators"] == {"nvidia-tesla-h100": 8} + + +def test_accelerators_sky_string_format_still_works(): + """Plain SkyPilot string accelerators ('H100:8') also pass through unchanged.""" + from cloud_pipelines_backend.launchers import kubernetes_launchers as k8sL + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + component = _make_component(command=["true"]) + launcher = SkyPilotKubernetesLauncher() + task = launcher._build_task( + component_spec=component, + container_spec=component.implementation.container, + input_arguments={}, + output_uris={}, + annotations={k8sL.RESOURCES_ACCELERATORS_ANNOTATION_KEY: "H100:8"}, + ) + assert task.resources.kwargs["accelerators"] == "H100:8" + + +def test_multistep_with_cloud_uris_passes_through(): + """Two-step pipelines work when the storage provider produces cloud URIs. + The upstream task's output URI (e.g. gs://bucket/.../output/data) is + handed to the downstream task as InputArgument.uri, and our launcher + mounts both via SkyPilot's file_mounts — the same URI on both sides + means the downstream container reads what the upstream wrote.""" + from cloud_pipelines_backend import component_structures as structures + from cloud_pipelines_backend.launchers import interfaces + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, + ) + + # Step 2's component reads `message_file` and writes `shouted`. 
+    component = _make_component(
+        command=["sh", "-c", 'tr "[:lower:]" "[:upper:]" < "$0" > "$1"'],
+        args=[
+            structures.InputPathPlaceholder("message_file"),
+            structures.OutputPathPlaceholder("shouted"),
+        ],
+        inputs=["message_file"],
+    )
+    # Storage provider has put step 1's output at this gs:// URI; Tangle hands
+    # it to step 2 verbatim as InputArgument.uri:
+    upstream_uri = "gs://tangle-test/by_execution/abc123/outputs/message_file/data"
+    downstream_output_uri = (
+        "gs://tangle-test/by_execution/def456/outputs/shouted/data"
+    )
+    input_arguments = {
+        "message_file": interfaces.InputArgument(
+            total_size=10**6, is_dir=False,
+            uri=upstream_uri, staging_uri="",
+        ),
+    }
+    launcher = SkyPilotKubernetesLauncher(infra="kubernetes")
+    task = launcher._build_task(
+        component_spec=component,
+        container_spec=component.implementation.container,
+        input_arguments=input_arguments,
+        output_uris={"shouted": downstream_output_uri},
+        annotations={},
+    )
+    # MOUNT mode can't target individual files (gcsfuse mounts directories),
+    # so the launcher mounts each bucket root once and addresses both the
+    # upstream input and the downstream output via sub-paths under it.
+    assert task.file_mounts is not None
+    mount_sources = [
+        v.get("source") for v in task.file_mounts.values() if isinstance(v, dict)
+    ]
+    # The launcher dedups by bucket, so both URIs share the same mount.
+    assert "gs://tangle-test" in mount_sources
+
+
+def test_input_local_uri_raises_actionable_error():
+    """A non-cloud (local) URI for an input is rejected up front with an
+    actionable message (vs. letting sky.Task validation fail with a generic
+    'file does not exist' error). Surfaced during E2E testing with Tangle's
+    LocalStorageProvider."""
+    from cloud_pipelines_backend import component_structures as structures
+    from cloud_pipelines_backend.launchers import interfaces
+    from cloud_pipelines_backend.launchers.skypilot_launchers import (
+        SkyPilotKubernetesLauncher,
+    )
+
+    component = _make_component(
+        command=["cat"],
+        args=[structures.InputPathPlaceholder("dataset")],
+        inputs=["dataset"],
+    )
+    input_arguments = {
+        "dataset": interfaces.InputArgument(
+            total_size=10**6, is_dir=False,
+            uri="data/artifacts/by_execution/abc/inputs/dataset/data",  # local
+            staging_uri="",
+        ),
+    }
+    launcher = SkyPilotKubernetesLauncher()
+    with pytest.raises(interfaces.LauncherError, match="cloud storage URI"):
+        launcher._build_task(
+            component_spec=component,
+            container_spec=component.implementation.container,
+            input_arguments=input_arguments,
+            output_uris={},
+            annotations={},
+        )
+
+
+def test_output_local_uri_skipped_no_mount():
+    """A non-cloud (local) output URI is skipped with a warning. The container
+    can still write to its own /tmp/outputs/ inside the pod; the artifact just
+    won't be persisted to Tangle's local storage. Surfaced during E2E testing."""
+    from cloud_pipelines_backend import component_structures as structures
+    from cloud_pipelines_backend.launchers.skypilot_launchers import (
+        SkyPilotKubernetesLauncher,
+    )
+
+    component = _make_component(
+        command=["sh", "-c", "echo hi"],
+        args=[structures.OutputPathPlaceholder("greeting")],
+    )
+    launcher = SkyPilotKubernetesLauncher()
+    task = launcher._build_task(
+        component_spec=component,
+        container_spec=component.implementation.container,
+        input_arguments={},
+        output_uris={"greeting": "data/artifacts/by_execution/abc/outputs/greeting/data"},
+        annotations={},
+    )
+    # No file_mounts entry for the local output URI.
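+    # (With no cloud URIs in this test, file_mounts should simply be None; the
+    # all(...) arm guards against a regression that mounts the local output path.)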
+ assert task.file_mounts is None or all( + not v.startswith("data/") for v in (task.file_mounts or {}).values() + ) + + +def test_end_to_end_lifecycle_through_stubbed_sky(): + """End-to-end exercise: launch -> refresh status -> stream logs -> terminate. + + Uses the stubbed sky.jobs SDK from the test fixture. Verifies the full + LaunchedJob lifecycle a Tangle orchestrator would drive. + """ + from cloud_pipelines_backend.launchers import interfaces + from cloud_pipelines_backend.launchers.skypilot_launchers import ( + SkyPilotKubernetesLauncher, SkyPilotLaunchedJob, + ) + + component = _make_component( + image="ghcr.io/example/trainer:1.0", + command=["python", "-m", "trainer", "--epochs", "3"], + name="trainer", + ) + launcher = SkyPilotKubernetesLauncher( + infra="kubernetes", + priority_class="batch", + annotation_to_label_keys=["ml.shopify.io/priority-class"], + default_labels={"managed-by": "tangle"}, + ) + + # 1. Submit + handle = launcher.launch_container_task( + component_spec=component, + input_arguments={}, + output_uris={"checkpoint": "gs://example/ckpt"}, + log_uri="gs://example/logs/trainer", + annotations={ + "cloud-pipelines.net/launchers/generic/resources.cpu": "8+", + "cloud-pipelines.net/launchers/generic/resources.memory": "32", + "cloud-pipelines.net/launchers/generic/resources.accelerators": "H100:8", + "tangleml.com/launchers/kubernetes/multi_node/number_of_nodes": "2", + "ml.shopify.io/priority-class": "interactive", + }, + ) + assert isinstance(handle, SkyPilotLaunchedJob) + assert handle.job_id == 12345 + + # 2. Refresh — pull current status from sky.jobs.queue + refreshed = handle.get_refreshed() + assert refreshed.status == interfaces.ContainerStatus.RUNNING + assert not refreshed.has_ended + + # 3. Pull logs (one-shot) + log = refreshed.get_log() + assert "log line" in log + + # 4. Persist + reload (orchestrator restart simulation) + serialized = refreshed.to_dict() + reloaded = SkyPilotLaunchedJob.from_dict(serialized) + assert reloaded.job_id == refreshed.job_id + assert reloaded.status == refreshed.status + + # 5. Terminate + reloaded.terminate() # No exception => success
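+
+
+# A minimal sketch (illustrative only; it reuses the names exercised in the
+# lifecycle test above) of the polling loop an orchestrator drives:
+#
+#     import time
+#
+#     job = launcher.launch_container_task(component_spec=component, ...)
+#     while not job.has_ended:
+#         time.sleep(5)  # e.g. sleep_seconds_between_queue_sweeps=5.0
+#         job = job.get_refreshed()
+#     assert job.has_succeeded, job.get_log()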