|
| 1 | +"""GiGL Example Graph Store Server. |
| 2 | +
|
| 3 | +Derived from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py |
| 4 | +
|
| 5 | +TODO(kmonte): Figure out how we should split out common utils from this file. |
| 6 | +
|
| 7 | +Cluster Setup |
| 8 | +============= |
| 9 | +
|
| 10 | +In Graph Store mode, storage nodes hold the graph data and serve sampling requests from compute nodes. |
| 11 | +Each storage node initializes a GLT (GraphLearn-Torch) server and waits for connections from compute nodes. |
| 12 | +
|
| 13 | +Storage nodes accept connections from compute nodes **sequentially, by compute node**. For example: |
| 14 | +- First, all connections from Compute Node 0 are established to Storage Nodes 0, 1, 2, ... |
| 15 | +- Then, all connections from Compute Node 1 are established to Storage Nodes 0, 1, 2, ... |
| 16 | +- And so on. |
| 17 | +
|
| 18 | +It's important to distinguish between: |
| 19 | +- **Compute Node**: A physical machine in the compute cluster (e.g., a VM with multiple GPUs). |
| 20 | +- **Compute Process**: A process running on a compute node (typically one per GPU). |
| 21 | +
|
| 22 | +Each compute node may have multiple compute processes (e.g., one per GPU), and each compute process |
| 23 | +establishes its own connection to every storage node. For example, if a compute node has 4 GPUs, |
| 24 | +it will establish 4 connections to each storage node. |
| 25 | +
|
| 26 | +This sequential connection setup is required because the GLT server uses a per-server lock when |
| 27 | +initializing samplers. If connections from multiple compute nodes were established concurrently, |
| 28 | +it could cause a deadlock. |
| 29 | +
|
| 30 | +Connection Diagram |
| 31 | +------------------ |
| 32 | +
|
| 33 | +╔═══════════════════════════════════════════════════════════════════════════════════════╗ |
| 34 | +║ COMPUTE TO STORAGE NODE CONNECTIONS ║ |
| 35 | +╚═══════════════════════════════════════════════════════════════════════════════════════╝ |
| 36 | +
|
| 37 | + COMPUTE NODES STORAGE NODES |
| 38 | + ═════════════ ═════════════ |
| 39 | +
|
| 40 | + ┌──────────────────────┐ (1) ┌───────────────┐ |
| 41 | + │ COMPUTE NODE 0 │ │ │ |
| 42 | + │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ |
| 43 | + │ │GPU │GPU │GPU │GPU │ ╱ │ │ |
| 44 | + │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ |
| 45 | + │ └────┴────┴────┴────┤ (2) ╲ ╱ |
| 46 | + └──────────────────────┘ ╲ ╱ |
| 47 | + ╳ |
| 48 | + (3) ╱ ╲ (4) |
| 49 | + ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ |
| 50 | + │ COMPUTE NODE 1 │ ╱ ╲ │ │ |
| 51 | + │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ |
| 52 | + │ │GPU │GPU │GPU │GPU │ │ │ |
| 53 | + │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ |
| 54 | + │ └────┴────┴────┴────┤ └───────────────┘ |
| 55 | + └──────────────────────┘ |
| 56 | +
|
| 57 | + ┌─────────────────────────────────────────────────────────────────────────────┐ |
| 58 | + │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ |
| 59 | + │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ |
| 60 | + │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ |
| 61 | + │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ |
| 62 | + └─────────────────────────────────────────────────────────────────────────────┘ |
| 63 | +
|
| 64 | +Storage nodes wait for all compute processes to connect, then serve sampling requests until |
| 65 | +the compute processes signal shutdown via `gigl.distributed.graph_store.compute.shutdown_compute_process`. |
| 66 | +
|
| 67 | +""" |
| 68 | +import argparse |
| 69 | +import os |
| 70 | +from distutils.util import strtobool |
| 71 | +from typing import Literal, Optional |
| 72 | + |
| 73 | +# TODO(kmonte): Remove GLT imports from this file. |
| 74 | +import graphlearn_torch as glt |
| 75 | +import torch |
| 76 | + |
| 77 | +from gigl.common import Uri, UriFactory |
| 78 | +from gigl.common.logger import Logger |
| 79 | +from gigl.distributed.dataset_factory import build_dataset |
| 80 | +from gigl.distributed.dist_dataset import DistDataset |
| 81 | +from gigl.distributed.dist_range_partitioner import DistRangePartitioner |
| 82 | +from gigl.distributed.graph_store.storage_utils import register_dataset |
| 83 | +from gigl.distributed.utils import get_free_ports_from_master_node, get_graph_store_info |
| 84 | +from gigl.distributed.utils.networking import get_free_ports_from_master_node |
| 85 | +from gigl.distributed.utils.serialized_graph_metadata_translator import ( |
| 86 | + convert_pb_to_serialized_graph_metadata, |
| 87 | +) |
| 88 | +from gigl.env.distributed import GraphStoreInfo |
| 89 | +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper |
| 90 | + |
| 91 | +logger = Logger() |
| 92 | + |
| 93 | + |
def _run_storage_process(
    storage_rank: int,
    cluster_info: GraphStoreInfo,
    dataset: DistDataset,
    torch_process_port: int,
    storage_world_backend: Optional[str],
) -> None:
    """
    Runs a storage process.

    This function does the following:

    1. "Registers" the dataset so that gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset can access it.
    2. Initializes the GLT server.
       Under the hood this is synchronized with the clients initializing via gigl.distributed.graph_store.compute.init_compute_process,
       and after this call there will be Torch RPC connections between the storage nodes and compute nodes.
    3. Initializes the Torch Distributed process group for the storage node.
    4. Waits for the server to exit.
       Will wait until clients are also shut down (with `gigl.distributed.graph_store.compute.shutdown_compute_process`).

    Args:
        storage_rank (int): The rank of the storage node.
        cluster_info (GraphStoreInfo): The cluster information.
        dataset (DistDataset): The dataset served by this storage process.
        torch_process_port (int): The port on the storage-cluster master used to bootstrap
            the storage-only Torch Distributed process group.
        storage_world_backend (Optional[str]): The backend for the storage Torch Distributed
            process group. If None, torch.distributed picks a default backend.
    """

    # "Register" the dataset so that gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset can access it.
    register_dataset(dataset)
    cluster_master_ip = cluster_info.storage_cluster_master_ip
    logger.info(
        f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}"
    )
    # Initialize the GLT server before starting the Torch Distributed process group.
    # Otherwise, we saw intermittent hangs when initializing the server.
    glt.distributed.init_server(
        num_servers=cluster_info.num_storage_nodes,
        server_rank=storage_rank,
        dataset=dataset,
        master_addr=cluster_master_ip,
        master_port=cluster_info.rpc_master_port,
        num_clients=cluster_info.compute_cluster_world_size,
    )

    init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}"
    logger.info(
        f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}"
    )

    # Torch Distributed process group is needed so that the storage cluster can talk to each other.
    # This is needed for `RemoteDistDataset.get_free_ports_on_storage_cluster` to work.
    # Note this is called on the *compute* cluster, but requires the storage cluster to have a process group initialized.
    torch.distributed.init_process_group(
        backend=storage_world_backend,
        world_size=cluster_info.num_storage_nodes,
        rank=storage_rank,
        init_method=init_method,
    )

    logger.info(
        f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit"
    )
    # Wait for the server to exit.
    # Will wait until clients are also shut down (with `gigl.distributed.graph_store.compute.shutdown_compute_process`).
    glt.distributed.wait_and_shutdown_server()
    logger.info(f"Storage node {storage_rank} exited")
| 161 | + |
| 162 | + |
def storage_node_process(
    storage_rank: int,
    cluster_info: GraphStoreInfo,
    task_config_uri: Uri,
    sample_edge_direction: Literal["in", "out"],
    should_load_tf_records_in_parallel: bool = True,
    tf_record_uri_pattern: str = r".*-of-.*\.tfrecord(\.gz)?$",
    storage_world_backend: Optional[str] = None,
) -> None:
    """Run a storage node process.

    Should be called *once* per storage node (machine).

    This function:
    1. Initializes a temporary "gloo" process group across the storage cluster so the
       storage nodes can coordinate while building the dataset.
    2. Builds the dataset from the serialized graph metadata referenced by the task config.
    3. Reserves a free port on the master node for the long-lived storage process group,
       then tears down the temporary process group.
    4. Spawns the storage server process(es) and blocks until they exit.

    Args:
        storage_rank (int): The rank of the storage node.
        cluster_info (GraphStoreInfo): The cluster information.
        task_config_uri (Uri): The task config URI.
        sample_edge_direction (Literal["in", "out"]): Direction in which edges are sampled.
        should_load_tf_records_in_parallel (bool): Whether to load serialized tensors in
            parallel. Defaults to True.
        tf_record_uri_pattern (str): Regex pattern used to match TF Record file URIs.
        storage_world_backend (Optional[str]): The backend for the storage Torch Distributed
            process group.
    """
    init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}"
    logger.info(
        f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']} init method: {init_method}"
    )
    torch.distributed.init_process_group(
        backend="gloo",
        world_size=cluster_info.num_storage_nodes,
        rank=storage_rank,
        init_method=init_method,
        group_name="gigl_server_comms",
    )
    logger.info(
        f"Storage node {storage_rank} / {cluster_info.num_storage_nodes} process group initialized"
    )
    gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
        gbml_config_uri=task_config_uri
    )
    serialized_graph_metadata = convert_pb_to_serialized_graph_metadata(
        preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
        graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
        tfrecord_uri_pattern=tf_record_uri_pattern,
    )
    dataset = build_dataset(
        serialized_graph_metadata=serialized_graph_metadata,
        sample_edge_direction=sample_edge_direction,
        should_load_tensors_in_parallel=should_load_tf_records_in_parallel,
        partitioner_class=DistRangePartitioner,
    )
    # Reserve a port for the server-specific process group while the temporary
    # process group is still alive, then tear the temporary group down.
    torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
    torch.distributed.destroy_process_group()
    server_processes = []
    mp_context = torch.multiprocessing.get_context("spawn")
    # TODO(kmonte): Enable more than one server process per machine
    for i in range(1):
        server_process = mp_context.Process(
            target=_run_storage_process,
            args=(
                storage_rank + i,  # storage_rank
                cluster_info,  # cluster_info
                dataset,  # dataset
                torch_process_port,  # torch_process_port
                storage_world_backend,  # storage_world_backend
            ),
        )
        server_processes.append(server_process)
    for server_process in server_processes:
        server_process.start()
    for server_process in server_processes:
        server_process.join()
| 233 | + |
| 234 | + |
| 235 | +if __name__ == "__main__": |
| 236 | + # TODO(kmonte): We want to expose splitter class here probably. |
| 237 | + parser = argparse.ArgumentParser() |
| 238 | + parser.add_argument("--task_config_uri", type=str, required=True) |
| 239 | + parser.add_argument("--resource_config_uri", type=str, required=True) |
| 240 | + parser.add_argument("--job_name", type=str, required=True) |
| 241 | + parser.add_argument("--sample_edge_direction", type=str, required=True) |
| 242 | + parser.add_argument("--should_load_tf_records_in_parallel", type=str, default="True") |
| 243 | + args = parser.parse_args() |
| 244 | + logger.info(f"Running storage node with arguments: {args}") |
| 245 | + |
| 246 | + # Setup cluster-wide (e.g. storage and compute nodes) Torch Distributed process group. |
| 247 | + # This is needed so we can get the cluster information (e.g. number of storage and compute nodes) and rank/world_size. |
| 248 | + torch.distributed.init_process_group(backend="gloo") |
| 249 | + cluster_info = get_graph_store_info() |
| 250 | + logger.info(f"Cluster info: {cluster_info}") |
| 251 | + logger.info( |
| 252 | + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" |
| 253 | + ) |
| 254 | + # Tear down the """"global""" process group so we can have a server-specific process group. |
| 255 | + |
| 256 | + torch.distributed.destroy_process_group() |
| 257 | + storage_node_process( |
| 258 | + storage_rank=cluster_info.storage_node_rank, |
| 259 | + cluster_info=cluster_info, |
| 260 | + task_config_uri=UriFactory.create_uri(args.task_config_uri), |
| 261 | + sample_edge_direction=args.sample_edge_direction, |
| 262 | + should_load_tf_records_in_parallel=bool( |
| 263 | + strtobool(args.should_load_tf_records_in_parallel) |
| 264 | + ), |
| 265 | + ) |
0 commit comments