diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 797b1a2..c6d2806 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -2,6 +2,7 @@ from collections.abc import Iterator, Mapping import contextlib +import dataclasses import gc import logging import os @@ -29,17 +30,41 @@ _JAX_PLATFORM_PROXY = "proxy" _JAX_BACKEND_TARGET_KEY = "jax_backend_target" _JAX_BACKEND_TARGET_HOSTNAME = "grpc://127.0.0.1" -_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest" +DEFAULT_PROXY_IMAGE = ( + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest" +) _logger = logging.getLogger(__name__) +@dataclasses.dataclass +class ProxyOptions: + """Configuration options for the Pathways proxy. + + Attributes: + use_insecure_credentials: Whether to use insecure gRPC credentials for the + proxy server. + """ + use_insecure_credentials: bool = False + + @classmethod + def from_dict(cls, options: Mapping[str, str] | None) -> "ProxyOptions": + """Creates a ProxyOptions object from a dictionary of options.""" + options = options or {} + use_insecure = ( + options.get("use_insecure_credentials", "false").lower() == "true" + ) + return cls(use_insecure_credentials=use_insecure) + + def _deploy_pathways_proxy_server( - *, pathways_service: str, + *, + pathways_service: str, proxy_job_name: str, expected_instances: Mapping[Any, Any], gcs_scratch_location: str, proxy_server_image: str, + proxy_options: ProxyOptions | None = None, ) -> None: """Deploys the Pathways proxy pods to the GKE cluster. @@ -50,6 +75,8 @@ def _deploy_pathways_proxy_server( instances. gcs_scratch_location: The Google Cloud Storage location to use. proxy_server_image: The image to use for the proxy server. + proxy_options: Configuration options for the Pathways proxy. If not + provided, no extra options will be used. Raises: subprocess.CalledProcessError: If the kubectl command fails. @@ -67,6 +94,13 @@ def _deploy_pathways_proxy_server( instance_type, count = next(iter(expected_instances.items())) instances_str = ",".join(instance_type for _ in range(count)) + proxy_options = proxy_options or ProxyOptions() + + proxy_env_str = ( + ' - name: IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS\n' + ' value: "true"\n' + ) if proxy_options.use_insecure_credentials else "" + template = string.Template(yaml_template) substituted_yaml = template.substitute( PROXY_JOB_NAME=proxy_job_name, @@ -76,6 +110,7 @@ def _deploy_pathways_proxy_server( EXPECTED_INSTANCES=instances_str, GCS_SCRATCH_LOCATION=gcs_scratch_location, PROXY_SERVER_IMAGE=proxy_server_image, + PROXY_ENV=proxy_env_str, ) _logger.info("Deploying Pathways proxy: %s", proxy_job_name) @@ -97,6 +132,7 @@ class _ISCPathways: of instances. proxy_job_name: The name to use for the deployed proxy. proxy_server_image: The image to use for the proxy server. + proxy_options: Configuration options for the Pathways proxy. """ def __init__( @@ -109,6 +145,7 @@ def __init__( expected_tpu_instances: Mapping[Any, Any], proxy_job_name: str, proxy_server_image: str, + proxy_options: ProxyOptions | None = None, ): """Initializes the TPU manager.""" self.cluster = cluster @@ -121,6 +158,7 @@ def __init__( self._port_forward_process = None self._proxy_port = None self.proxy_server_image = proxy_server_image + self.proxy_options = proxy_options or ProxyOptions() def __repr__(self): return ( @@ -128,7 +166,8 @@ def __repr__(self): f"region='{self.region}', bucket='{self.bucket}', " f"pathways_service='{self.pathways_service}', " f"expected_tpu_instances={self.expected_tpu_instances}, " - f"_proxy_job_name='{self._proxy_job_name}')" + f"_proxy_job_name='{self._proxy_job_name}', " + f"proxy_options={self.proxy_options})" ) def __enter__(self): @@ -140,6 +179,7 @@ def __enter__(self): expected_instances=self.expected_tpu_instances, gcs_scratch_location=self.bucket, proxy_server_image=self.proxy_server_image, + proxy_options=self.proxy_options, ) # Print a link to Cloud Logging cloud_logging_link = gke_utils.get_log_link( @@ -215,7 +255,8 @@ def connect( pathways_service: str, expected_tpu_instances: Mapping[str, int], proxy_job_name: str | None = None, - proxy_server_image: str = _DEFAULT_PROXY_IMAGE, + proxy_server_image: str = DEFAULT_PROXY_IMAGE, + proxy_options: ProxyOptions | None = None, ) -> Iterator["_ISCPathways"]: """Connects to a Pathways server if the cluster exists. If not, creates it. @@ -231,6 +272,8 @@ def connect( random name will be generated. proxy_server_image: The proxy server image to use. If not provided, a default will be used. + proxy_options: Configuration options for the Pathways proxy. If not + provided, no extra options will be used. Yields: The Pathways manager. @@ -259,5 +302,6 @@ def connect( expected_tpu_instances=expected_tpu_instances, proxy_job_name=proxy_job_name, proxy_server_image=proxy_server_image, + proxy_options=proxy_options, ) as t: yield t diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py index 21145ac..01c71ea 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py @@ -9,6 +9,8 @@ from pathwaysutils.experimental.shared_pathways_service import isc_pathways +from google3.pyglib.flags.contrib import dict_flag + FLAGS = flags.FLAGS flags.DEFINE_string("cluster", None, "The name of the GKE cluster.") @@ -35,6 +37,12 @@ None, "The proxy server image to use. If not provided, a default will be used.", ) +dict_flag.DEFINE_dict( + "proxy_options", + None, + "Configuration options for the Pathways proxy. Specify entries in the form" + ' "key:value". For example: --proxy_options=use_insecure_credentials:true', +) flags.mark_flags_as_required([ "cluster", @@ -49,11 +57,7 @@ def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") - kwargs = {} - if FLAGS.proxy_job_name: - kwargs["proxy_job_name"] = FLAGS.proxy_job_name - if FLAGS.proxy_server_image: - kwargs["proxy_server_image"] = FLAGS.proxy_server_image + proxy_options = isc_pathways.ProxyOptions.from_dict(FLAGS.proxy_options) with isc_pathways.connect( cluster=FLAGS.cluster, @@ -62,7 +66,10 @@ def main(argv: Sequence[str]) -> None: gcs_bucket=FLAGS.gcs_bucket, pathways_service=FLAGS.pathways_service, expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count}, - **kwargs, + proxy_job_name=FLAGS.proxy_job_name, + proxy_server_image=FLAGS.proxy_server_image + or isc_pathways.DEFAULT_PROXY_IMAGE, + proxy_options=proxy_options, ): orig_matrix = jnp.zeros(5) result_matrix = orig_matrix + 1 diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml index 9c17221..91d6aa4 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml @@ -21,6 +21,8 @@ spec: - --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT} - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} - --virtual_slices=${EXPECTED_INSTANCES} + env: +${PROXY_ENV} ports: - containerPort: ${PROXY_SERVER_PORT} protocol: TCP