Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Iterator, Mapping
import contextlib
import dataclasses
import gc
import logging
import os
Expand Down Expand Up @@ -29,17 +30,41 @@
_JAX_PLATFORM_PROXY = "proxy"
_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
_JAX_BACKEND_TARGET_HOSTNAME = "grpc://127.0.0.1"
_DEFAULT_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
DEFAULT_PROXY_IMAGE = (
"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
)

_logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ProxyOptions:
"""Configuration options for the Pathways proxy.

Attributes:
use_insecure_credentials: Whether to use insecure gRPC credentials for the
proxy server.
"""
use_insecure_credentials: bool = False

@classmethod
def from_dict(cls, options: Mapping[str, str] | None) -> "ProxyOptions":
"""Creates a ProxyOptions object from a dictionary of options."""
options = options or {}
use_insecure = (
options.get("use_insecure_credentials", "false").lower() == "true"
)
return cls(use_insecure_credentials=use_insecure)


def _deploy_pathways_proxy_server(
*, pathways_service: str,
*,
pathways_service: str,
proxy_job_name: str,
expected_instances: Mapping[Any, Any],
gcs_scratch_location: str,
proxy_server_image: str,
proxy_options: ProxyOptions | None = None,
) -> None:
"""Deploys the Pathways proxy pods to the GKE cluster.

Expand All @@ -50,6 +75,8 @@ def _deploy_pathways_proxy_server(
instances.
gcs_scratch_location: The Google Cloud Storage location to use.
proxy_server_image: The image to use for the proxy server.
proxy_options: Configuration options for the Pathways proxy. If not
provided, no extra options will be used.

Raises:
subprocess.CalledProcessError: If the kubectl command fails.
Expand All @@ -67,6 +94,13 @@ def _deploy_pathways_proxy_server(
instance_type, count = next(iter(expected_instances.items()))
instances_str = ",".join(instance_type for _ in range(count))

proxy_options = proxy_options or ProxyOptions()

proxy_env_str = (
' - name: IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS\n'
' value: "true"\n'
) if proxy_options.use_insecure_credentials else ""

template = string.Template(yaml_template)
substituted_yaml = template.substitute(
PROXY_JOB_NAME=proxy_job_name,
Expand All @@ -76,6 +110,7 @@ def _deploy_pathways_proxy_server(
EXPECTED_INSTANCES=instances_str,
GCS_SCRATCH_LOCATION=gcs_scratch_location,
PROXY_SERVER_IMAGE=proxy_server_image,
PROXY_ENV=proxy_env_str,
)

_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
Expand All @@ -97,6 +132,7 @@ class _ISCPathways:
of instances.
proxy_job_name: The name to use for the deployed proxy.
proxy_server_image: The image to use for the proxy server.
proxy_options: Configuration options for the Pathways proxy.
"""

def __init__(
Expand All @@ -109,6 +145,7 @@ def __init__(
expected_tpu_instances: Mapping[Any, Any],
proxy_job_name: str,
proxy_server_image: str,
proxy_options: ProxyOptions | None = None,
):
"""Initializes the TPU manager."""
self.cluster = cluster
Expand All @@ -121,14 +158,16 @@ def __init__(
self._port_forward_process = None
self._proxy_port = None
self.proxy_server_image = proxy_server_image
self.proxy_options = proxy_options or ProxyOptions()

def __repr__(self):
return (
f"_ISCPathways(cluster='{self.cluster}', project='{self.project}', "
f"region='{self.region}', bucket='{self.bucket}', "
f"pathways_service='{self.pathways_service}', "
f"expected_tpu_instances={self.expected_tpu_instances}, "
f"_proxy_job_name='{self._proxy_job_name}')"
f"_proxy_job_name='{self._proxy_job_name}', "
f"proxy_options={self.proxy_options})"
)

def __enter__(self):
Expand All @@ -140,6 +179,7 @@ def __enter__(self):
expected_instances=self.expected_tpu_instances,
gcs_scratch_location=self.bucket,
proxy_server_image=self.proxy_server_image,
proxy_options=self.proxy_options,
)
# Print a link to Cloud Logging
cloud_logging_link = gke_utils.get_log_link(
Expand Down Expand Up @@ -215,7 +255,8 @@ def connect(
pathways_service: str,
expected_tpu_instances: Mapping[str, int],
proxy_job_name: str | None = None,
proxy_server_image: str = _DEFAULT_PROXY_IMAGE,
proxy_server_image: str = DEFAULT_PROXY_IMAGE,
proxy_options: ProxyOptions | None = None,
) -> Iterator["_ISCPathways"]:
"""Connects to a Pathways server if the cluster exists. If not, creates it.

Expand All @@ -231,6 +272,8 @@ def connect(
random name will be generated.
proxy_server_image: The proxy server image to use. If not provided, a
default will be used.
proxy_options: Configuration options for the Pathways proxy. If not
provided, no extra options will be used.

Yields:
The Pathways manager.
Expand Down Expand Up @@ -259,5 +302,6 @@ def connect(
expected_tpu_instances=expected_tpu_instances,
proxy_job_name=proxy_job_name,
proxy_server_image=proxy_server_image,
proxy_options=proxy_options,
) as t:
yield t
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pathwaysutils.experimental.shared_pathways_service import isc_pathways


from google3.pyglib.flags.contrib import dict_flag

FLAGS = flags.FLAGS

flags.DEFINE_string("cluster", None, "The name of the GKE cluster.")
Expand All @@ -35,6 +37,12 @@
None,
"The proxy server image to use. If not provided, a default will be used.",
)
dict_flag.DEFINE_dict(
"proxy_options",
None,
"Configuration options for the Pathways proxy. Specify entries in the form"
' "key:value". For example: --proxy_options=use_insecure_credentials:true',
)

flags.mark_flags_as_required([
"cluster",
Expand All @@ -49,11 +57,7 @@ def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

kwargs = {}
if FLAGS.proxy_job_name:
kwargs["proxy_job_name"] = FLAGS.proxy_job_name
if FLAGS.proxy_server_image:
kwargs["proxy_server_image"] = FLAGS.proxy_server_image
proxy_options = isc_pathways.ProxyOptions.from_dict(FLAGS.proxy_options)

with isc_pathways.connect(
cluster=FLAGS.cluster,
Expand All @@ -62,7 +66,10 @@ def main(argv: Sequence[str]) -> None:
gcs_bucket=FLAGS.gcs_bucket,
pathways_service=FLAGS.pathways_service,
expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
**kwargs,
proxy_job_name=FLAGS.proxy_job_name,
proxy_server_image=FLAGS.proxy_server_image
or isc_pathways.DEFAULT_PROXY_IMAGE,
proxy_options=proxy_options,
):
orig_matrix = jnp.zeros(5)
result_matrix = orig_matrix + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ spec:
- --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT}
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
- --virtual_slices=${EXPECTED_INSTANCES}
env:
${PROXY_ENV}
ports:
- containerPort: ${PROXY_SERVER_PORT}
protocol: TCP
Expand Down
Loading