From 288d0c67d79e7876ce28d2cc6f29300e19dd3a6b Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Tue, 17 Mar 2026 14:36:05 -0700 Subject: [PATCH] Add CLI mode to Shared Pathways Service This commit adds a new script `run_tpu_workload.py`, which allows users to provide a command to the Shared Pathways Service. The user can simply add `pathwaysutils.initialize()` to their script and run it with `--command` flag. PiperOrigin-RevId: 885215367 --- .../shared_pathways_service/gke_utils.py | 2 + .../shared_pathways_service/isc_pathways.py | 27 +-- .../shared_pathways_service/run_workload.py | 167 ++++++++++++++++++ 3 files changed, 183 insertions(+), 13 deletions(-) create mode 100644 pathwaysutils/experimental/shared_pathways_service/run_workload.py diff --git a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py index 3b30dc1..a47199a 100644 --- a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py +++ b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py @@ -3,6 +3,7 @@ import logging import socket import subprocess +import time import urllib.parse import portpicker @@ -189,6 +190,7 @@ def wait_for_pod(job_name: str) -> str: RuntimeError: If the pod is not ready. """ _logger.info("Waiting for pod to be created...") + time.sleep(1) pod_name = get_pod_from_job(job_name) _logger.info( diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index b25df7e..8ad4f20 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -192,8 +192,10 @@ def __repr__(self): def __enter__(self): """Enters the context manager, ensuring cluster exists.""" - self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY) - self._old_jax_backend_target = os.environ.get(_JAX_BACKEND_TARGET_KEY) + self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper()) + self._old_jax_backend_target = os.environ.get( + _JAX_BACKEND_TARGET_KEY.upper() + ) self._old_jax_platforms_config = getattr( jax.config, _JAX_PLATFORMS_KEY, None ) @@ -224,16 +226,13 @@ def __enter__(self): ) # Update the JAX backend to use the proxy. - os.environ[_JAX_PLATFORMS_KEY] = _JAX_PLATFORM_PROXY - os.environ[ - _JAX_BACKEND_TARGET_KEY - ] = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}" - + jax_backend_target = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}" jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY) - jax.config.update( - _JAX_BACKEND_TARGET_KEY, - f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}", - ) + jax.config.update(_JAX_BACKEND_TARGET_KEY, jax_backend_target) + # Update the environment variables for the CLI mode of Shared Pathways + # Service. + os.environ[_JAX_PLATFORMS_KEY.upper()] = _JAX_PLATFORM_PROXY + os.environ[_JAX_BACKEND_TARGET_KEY.upper()] = jax_backend_target pathwaysutils.initialize() _logger.info( @@ -281,8 +280,10 @@ def _cleanup(self) -> None: # 4. Restore JAX variables. _logger.info("Restoring JAX env and config variables...") - _restore_env_var(_JAX_PLATFORMS_KEY, self._old_jax_platforms) - _restore_env_var(_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target) + _restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms) + _restore_env_var( + _JAX_BACKEND_TARGET_KEY.upper(), self._old_jax_backend_target + ) jax.config.update(_JAX_PLATFORMS_KEY, self._old_jax_platforms_config) jax.config.update( _JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target_config diff --git a/pathwaysutils/experimental/shared_pathways_service/run_workload.py b/pathwaysutils/experimental/shared_pathways_service/run_workload.py new file mode 100644 index 0000000..f662c35 --- /dev/null +++ b/pathwaysutils/experimental/shared_pathways_service/run_workload.py @@ -0,0 +1,167 @@ +r"""Run a TPU workload with Shared Pathways Service. + +Run your TPU workload locally using Shared Pathways Service, the service will +deploy a Pathways proxy to run the TPU-specific components of your workload on +the requested TPU slices. + +Example: +python3 run_workload.py \ + --cluster my-cluster \ + --project my-project \ + --region=us-central1 \ + --gcs_bucket=my-gcs-bucket \ + --pathways_service=pathways-head:8000 \ + --tpu_type=tpuv6e:4x8 \ + --tpu_count=1 \ + --command "python3 my_workload.py ..." + +""" + +from collections.abc import Callable, Sequence +import os +import shlex +import subprocess +from typing import Any, ContextManager + +from absl import app +from absl import flags +from absl import logging +from pathwaysutils.experimental.shared_pathways_service import isc_pathways + + +_CLUSTER = flags.DEFINE_string( + "cluster", None, "The name of the GKE cluster.", required=True +) +_PROJECT = flags.DEFINE_string( + "project", None, "The GCP project ID.", required=True +) +_REGION = flags.DEFINE_string( + "region", None, "The GCP region.", required=True +) +_GCS_BUCKET = flags.DEFINE_string( + "gcs_bucket", None, "The Google Cloud Storage bucket.", required=True +) +_PATHWAYS_SERVICE = flags.DEFINE_string( + "pathways_service", + None, + "The address and port of the Pathways Resource Manager. See" + " https://github.com/AI-Hypercomputer/pathways-utils/tree/main/pathwaysutils/experimental/shared_pathways_service#4-find-the-pathways-service-address" + " for instructions on how to get the Pathways service address.", + required=True, +) +_TPU_TYPE = flags.DEFINE_string( + "tpu_type", "tpuv6e:2x2", "The TPU machine type and topology." +) +_TPU_COUNT = flags.DEFINE_integer("tpu_count", 1, "The number of TPU slices.") +_PROXY_SERVER_IMAGE = flags.DEFINE_string( + "proxy_server_image", + "", + "The proxy server image to use. If not provided, a default will be used.", +) +_PROXY_OPTIONS = flags.DEFINE_list( + "proxy_options", + [], + "Configuration options for the Pathways proxy. Specify entries in the form" + ' "key:value". For example: --proxy_options=use_insecure_credentials:true', +) +_COMMAND = flags.DEFINE_string( + "command", None, "The command to run on TPUs.", required=True +) + +flags.register_validator( + "proxy_options", + lambda value: all( + ":" in item + and len(item.split(":")) > 1 + and item.split(":", 1)[0] + and item.split(":", 1)[1] + for item in value + ), + message='--proxy_options must be in the format "key:value".', +) + + +def run_command( + *, + cluster: str, + project: str, + region: str, + gcs_bucket: str, + pathways_service: str, + tpu_type: str, + tpu_count: int, + command: str, + proxy_server_image: str | None = None, + proxy_options: Sequence[str] | None = None, + connect_fn: Callable[..., ContextManager[Any]] = isc_pathways.connect, +) -> None: + """Run the TPU workload within a Shared Pathways connection. + + Args: + cluster: The name of the GKE cluster. + project: The GCP project ID. + region: The GCP region. + gcs_bucket: The Google Cloud Storage bucket. + pathways_service: The address and port of the Pathways Resource Manager. + tpu_type: The TPU machine type and topology. + tpu_count: The number of TPU slices. + command: The command to run on TPUs. + proxy_server_image: The proxy server image to use. + proxy_options: Configuration options for the Pathways proxy. + connect_fn: The function to use for establishing the connection context, + expected to be a callable that returns a context manager. + + Raises: + subprocess.CalledProcessError: If the workload command fails. + """ + parsed_proxy_options = isc_pathways.ProxyOptions.from_list(proxy_options) + + logging.info("Connecting to Shared Pathways Service...") + with connect_fn( + cluster=cluster, + project=project, + region=region, + gcs_bucket=gcs_bucket, + pathways_service=pathways_service, + expected_tpu_instances={tpu_type: tpu_count}, + proxy_server_image=( + proxy_server_image + if proxy_server_image + else isc_pathways.DEFAULT_PROXY_IMAGE + ), + proxy_options=parsed_proxy_options, + ): + logging.info("Connection established. Running command: %r", command) + try: + command_args = shlex.split(command) + subprocess.run(command_args, check=True, env=os.environ.copy()) + except subprocess.CalledProcessError: + logging.error( + "Command failed! Find the underlying error in the logs above, where" + " the command is invoked." + ) + raise + finally: + logging.info("Command execution finished.") + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + run_command( + cluster=_CLUSTER.value, + project=_PROJECT.value, + region=_REGION.value, + gcs_bucket=_GCS_BUCKET.value, + pathways_service=_PATHWAYS_SERVICE.value, + tpu_type=_TPU_TYPE.value, + tpu_count=_TPU_COUNT.value, + command=_COMMAND.value, + proxy_server_image=_PROXY_SERVER_IMAGE.value, + proxy_options=_PROXY_OPTIONS.value, + ) + + +if __name__ == "__main__": + app.run(main)