
Commit 46c089b

[Custom Storage 2/3] Implement custom storage main (#462)
Co-authored-by: kmontemayor <kyle.e.montemayor@gmail.com>
1 parent fe0fc39 commit 46c089b

20 files changed

Lines changed: 1175 additions & 232 deletions

File tree

examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,10 @@ inferencerConfig:
   num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
   inferenceBatchSize: 512
   command: python -m examples.link_prediction.graph_store.homogeneous_inference
+graphStoreStorageConfig:
+  command: python -m examples.link_prediction.graph_store.storage_main
+  storageArgs:
+    sample_edge_direction: "in"
 sharedConfig:
   shouldSkipInference: false
   # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines.
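The new graphStoreStorageConfig stanza tells the orchestrator which command to run on each storage node and which arguments to pass to it. As a minimal sketch of that wiring — assuming the launcher flattens each storageArgs entry into a "--key value" flag appended to command (the launcher itself is not part of this diff, and build_storage_argv is a hypothetical helper) — the translation could look like:

# Hypothetical sketch; the real launcher is not shown in this commit.
import shlex


def build_storage_argv(command: str, storage_args: dict[str, str]) -> list[str]:
    # Flatten each storageArgs entry into a "--key value" CLI flag.
    argv = shlex.split(command)
    for key, value in storage_args.items():
        argv.extend([f"--{key}", str(value)])
    return argv


# build_storage_argv(
#     "python -m examples.link_prediction.graph_store.storage_main",
#     {"sample_edge_direction": "in"},
# )
# == ["python", "-m", "examples.link_prediction.graph_store.storage_main",
#     "--sample_edge_direction", "in"]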
examples/link_prediction/graph_store/storage_main.py

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
"""GiGL Example Graph Store Server.

Derived from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py

TODO(kmonte): Figure out how we should split out common utils from this file.

Cluster Setup
=============

In Graph Store mode, storage nodes hold the graph data and serve sampling requests from compute nodes.
Each storage node initializes a GLT (GraphLearn-Torch) server and waits for connections from compute nodes.

Storage nodes accept connections from compute nodes **sequentially, by compute node**. For example:
- First, all connections from Compute Node 0 are established to Storage Nodes 0, 1, 2, ...
- Then, all connections from Compute Node 1 are established to Storage Nodes 0, 1, 2, ...
- And so on.

It's important to distinguish between:
- **Compute Node**: A physical machine in the compute cluster (e.g., a VM with multiple GPUs).
- **Compute Process**: A process running on a compute node (typically one per GPU).

Each compute node may have multiple compute processes (e.g., one per GPU), and each compute process
establishes its own connection to every storage node. For example, if a compute node has 4 GPUs,
it will establish 4 connections to each storage node.

This sequential connection setup is required because the GLT server uses a per-server lock when
initializing samplers. If connections from multiple compute nodes were established concurrently,
it could cause a deadlock.

Connection Diagram
------------------

╔═════════════════════════════════════════════════════════════════════════╗
║                   COMPUTE TO STORAGE NODE CONNECTIONS                    ║
╚═════════════════════════════════════════════════════════════════════════╝

      COMPUTE NODES                                      STORAGE NODES
      ═════════════                                      ═════════════

┌──────────────────────┐            (1)                 ┌───────────────┐
│    COMPUTE NODE 0    │                                │               │
│ ┌────┬────┬────┬─────┤ ══════════════════════════════│   STORAGE 0   │
│ │GPU │GPU │GPU │GPU  │                        ╱       │               │
│ │ 0  │ 1  │ 2  │ 3   │ ═══════════════╲      ╱        └───────────────┘
│ └────┴────┴────┴─────┤      (2)        ╲    ╱
└──────────────────────┘                  ╲  ╱
                                     (3)   ╲╱   (4)
┌──────────────────────┐                   ╱╲           ┌───────────────┐
│    COMPUTE NODE 1    │                  ╱  ╲          │               │
│ ┌────┬────┬────┬─────┤ ════════════════╱    ╲════════│   STORAGE 1   │
│ │GPU │GPU │GPU │GPU  │                                │               │
│ │ 0  │ 1  │ 2  │ 3   │ ══════════════════════════════│               │
│ └────┴────┴────┴─────┤                                └───────────────┘
└──────────────────────┘

┌─────────────────────────────────────────────────────────────────────────┐
│ (1) Compute Node 0 → Storage 0   (4 connections, one per GPU)           │
│ (2) Compute Node 0 → Storage 1   (4 connections, one per GPU)           │
│ (3) Compute Node 1 → Storage 0   (4 connections, one per GPU)           │
│ (4) Compute Node 1 → Storage 1   (4 connections, one per GPU)           │
└─────────────────────────────────────────────────────────────────────────┘

Storage nodes wait for all compute processes to connect, then serve sampling requests until
the compute processes signal shutdown via `gigl.distributed.graph_store.compute.shutdown_compute_process`.
"""
import argparse
import os
from distutils.util import strtobool
from typing import Literal, Optional

# TODO(kmonte): Remove GLT imports from this file.
import graphlearn_torch as glt
import torch

from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.distributed.dataset_factory import build_dataset
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.dist_range_partitioner import DistRangePartitioner
from gigl.distributed.graph_store.storage_utils import register_dataset
from gigl.distributed.utils import get_free_ports_from_master_node, get_graph_store_info
from gigl.distributed.utils.serialized_graph_metadata_translator import (
    convert_pb_to_serialized_graph_metadata,
)
from gigl.env.distributed import GraphStoreInfo
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper

logger = Logger()


def _run_storage_process(
    storage_rank: int,
    cluster_info: GraphStoreInfo,
    dataset: DistDataset,
    torch_process_port: int,
    storage_world_backend: Optional[str],
) -> None:
    """
    Runs a storage process.

    This function does the following:

    1. "Registers" the dataset so that gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset can access it.
    2. Initializes the GLT server.
       Under the hood this is synchronized with the clients initializing via gigl.distributed.graph_store.compute.init_compute_process,
       and after this call there will be Torch RPC connections between the storage nodes and compute nodes.
    3. Initializes the Torch Distributed process group for the storage node.
    4. Waits for the server to exit.
       Will wait until clients are also shut down (with `gigl.distributed.graph_store.compute.shutdown_compute_process`).

    Args:
        storage_rank (int): The rank of the storage node.
        cluster_info (GraphStoreInfo): The cluster information.
        dataset (DistDataset): The dataset.
        torch_process_port (int): The port for the Torch process.
        storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
    """

    # "Register" the dataset so that gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset can access it.
    register_dataset(dataset)
    cluster_master_ip = cluster_info.storage_cluster_master_ip
    logger.info(
        f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}"
    )
    # Initialize the GLT server before starting the Torch Distributed process group.
    # Otherwise, we saw intermittent hangs when initializing the server.
    glt.distributed.init_server(
        num_servers=cluster_info.num_storage_nodes,
        server_rank=storage_rank,
        dataset=dataset,
        master_addr=cluster_master_ip,
        master_port=cluster_info.rpc_master_port,
        num_clients=cluster_info.compute_cluster_world_size,
    )

    init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}"
    logger.info(
        f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}"
    )

    # The Torch Distributed process group is needed so that the storage cluster nodes can talk to each other.
    # This is needed for `RemoteDistDataset.get_free_ports_on_storage_cluster` to work.
    # Note that method is called on the *compute* cluster, but it requires the storage cluster to have a process group initialized.
    torch.distributed.init_process_group(
        backend=storage_world_backend,
        world_size=cluster_info.num_storage_nodes,
        rank=storage_rank,
        init_method=init_method,
    )

    logger.info(
        f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit"
    )
    # Wait for the server to exit.
    # Will wait until clients are also shut down (with `gigl.distributed.graph_store.compute.shutdown_compute_process`).
    glt.distributed.wait_and_shutdown_server()
    logger.info(f"Storage node {storage_rank} exited")

def storage_node_process(
    storage_rank: int,
    cluster_info: GraphStoreInfo,
    task_config_uri: Uri,
    sample_edge_direction: Literal["in", "out"],
    should_load_tf_records_in_parallel: bool = True,
    tf_record_uri_pattern: str = r".*-of-.*\.tfrecord(\.gz)?$",
    storage_world_backend: Optional[str] = None,
) -> None:
    """Run a storage node process.

    Should be called *once* per storage node (machine).

    Args:
        storage_rank (int): The rank of the storage node.
        cluster_info (GraphStoreInfo): The cluster information.
        task_config_uri (Uri): The task config URI.
        sample_edge_direction (Literal["in", "out"]): The direction in which to sample edges.
        should_load_tf_records_in_parallel (bool): Whether to load TFRecords in parallel. Defaults to True.
        tf_record_uri_pattern (str): The TF Record URI pattern.
        storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
    """
    init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}"
    logger.info(
        f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}, init method: {init_method}"
    )
    torch.distributed.init_process_group(
        backend="gloo",
        world_size=cluster_info.num_storage_nodes,
        rank=storage_rank,
        init_method=init_method,
        group_name="gigl_server_comms",
    )
    logger.info(
        f"Storage node {storage_rank} / {cluster_info.num_storage_nodes} process group initialized"
    )
    gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
        gbml_config_uri=task_config_uri
    )
    serialized_graph_metadata = convert_pb_to_serialized_graph_metadata(
        preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
        graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
        tfrecord_uri_pattern=tf_record_uri_pattern,
    )
    dataset = build_dataset(
        serialized_graph_metadata=serialized_graph_metadata,
        sample_edge_direction=sample_edge_direction,
        should_load_tensors_in_parallel=should_load_tf_records_in_parallel,
        partitioner_class=DistRangePartitioner,
    )
    torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
    torch.distributed.destroy_process_group()
    server_processes = []
    mp_context = torch.multiprocessing.get_context("spawn")
    # TODO(kmonte): Enable more than one server process per machine.
    for i in range(1):
        server_process = mp_context.Process(
            target=_run_storage_process,
            args=(
                storage_rank + i,  # storage_rank
                cluster_info,  # cluster_info
                dataset,  # dataset
                torch_process_port,  # torch_process_port
                storage_world_backend,  # storage_world_backend
            ),
        )
        server_processes.append(server_process)
    for server_process in server_processes:
        server_process.start()
    for server_process in server_processes:
        server_process.join()


if __name__ == "__main__":
    # TODO(kmonte): We probably want to expose the splitter class here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task_config_uri", type=str, required=True)
    parser.add_argument("--resource_config_uri", type=str, required=True)
    parser.add_argument("--job_name", type=str, required=True)
    parser.add_argument("--sample_edge_direction", type=str, required=True)
    parser.add_argument("--should_load_tf_records_in_parallel", type=str, default="True")
    args = parser.parse_args()
    logger.info(f"Running storage node with arguments: {args}")

    # Set up the cluster-wide (i.e. storage and compute nodes) Torch Distributed process group.
    # This is needed so we can get the cluster information (e.g. number of storage and compute nodes) and rank/world_size.
    torch.distributed.init_process_group(backend="gloo")
    cluster_info = get_graph_store_info()
    logger.info(f"Cluster info: {cluster_info}")
    logger.info(
        f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}"
    )
    # Tear down the "global" process group so we can have a server-specific process group.
    torch.distributed.destroy_process_group()
    storage_node_process(
        storage_rank=cluster_info.storage_node_rank,
        cluster_info=cluster_info,
        task_config_uri=UriFactory.create_uri(args.task_config_uri),
        sample_edge_direction=args.sample_edge_direction,
        should_load_tf_records_in_parallel=bool(
            strtobool(args.should_load_tf_records_in_parallel)
        ),
    )
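For orientation, here is a minimal compute-side counterpart sketch. Only the module path gigl.distributed.graph_store.compute and the function names init_compute_process and shutdown_compute_process are confirmed by this commit; the signatures used below are assumptions for illustration, not the actual API.

# Hypothetical compute-side sketch; signatures below are assumed, not confirmed.
from gigl.distributed.graph_store.compute import (
    init_compute_process,
    shutdown_compute_process,
)
from gigl.distributed.utils import get_graph_store_info


def compute_main() -> None:
    cluster_info = get_graph_store_info()
    # Connecting unblocks glt.distributed.init_server on the storage side once
    # all cluster_info.compute_cluster_world_size clients have connected.
    init_compute_process(cluster_info)  # assumed signature
    try:
        ...  # issue sampling requests against the storage cluster
    finally:
        # Signals the storage nodes so wait_and_shutdown_server() can return.
        shutdown_compute_process()  # assumed signature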

gigl/distributed/graph_store/storage_main.py

Lines changed: 26 additions & 9 deletions
@@ -1,23 +1,30 @@
 """Built-in GiGL Graph Store Server.
 
-Derivved from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py
+Derived from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py
 
+TODO(kmonte): Remove this, and only expose utils.
+We keep this around so we can use the utils in tests/integration/distributed/graph_store/graph_store_integration_test.py.
 """
 import argparse
 import os
-from typing import Optional
+from typing import Literal, Optional
 
 import graphlearn_torch as glt
 import torch
 
 from gigl.common import Uri, UriFactory
 from gigl.common.logger import Logger
-from gigl.distributed import build_dataset_from_task_config_uri
+from gigl.distributed.dataset_factory import build_dataset
 from gigl.distributed.dist_dataset import DistDataset
+from gigl.distributed.dist_range_partitioner import DistRangePartitioner
 from gigl.distributed.graph_store.storage_utils import register_dataset
-from gigl.distributed.utils import get_graph_store_info
+from gigl.distributed.utils import get_free_ports_from_master_node, get_graph_store_info
 from gigl.distributed.utils.networking import get_free_ports_from_master_node
+from gigl.distributed.utils.serialized_graph_metadata_translator import (
+    convert_pb_to_serialized_graph_metadata,
+)
 from gigl.env.distributed import GraphStoreInfo
+from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
 
 logger = Logger()
 
@@ -67,7 +74,7 @@ def storage_node_process(
     storage_rank: int,
     cluster_info: GraphStoreInfo,
     task_config_uri: Uri,
-    is_inference: bool = True,
+    sample_edge_direction: Literal["in", "out"],
     tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$",
     storage_world_backend: Optional[str] = None,
 ) -> None:
@@ -97,10 +104,18 @@ def storage_node_process(
     logger.info(
         f"Storage node {storage_rank} / {cluster_info.num_storage_nodes} process group initialized"
     )
-    dataset = build_dataset_from_task_config_uri(
-        task_config_uri=task_config_uri,
-        is_inference=is_inference,
-        _tfrecord_uri_pattern=tf_record_uri_pattern,
+    gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
+        gbml_config_uri=task_config_uri
+    )
+    serialized_graph_metadata = convert_pb_to_serialized_graph_metadata(
+        preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
+        graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
+        tfrecord_uri_pattern=tf_record_uri_pattern,
+    )
+    dataset = build_dataset(
+        serialized_graph_metadata=serialized_graph_metadata,
+        sample_edge_direction=sample_edge_direction,
+        partitioner_class=DistRangePartitioner,
     )
     torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
     torch.distributed.destroy_process_group()
@@ -130,6 +145,7 @@ def storage_node_process(
     parser.add_argument("--task_config_uri", type=str, required=True)
     parser.add_argument("--resource_config_uri", type=str, required=True)
     parser.add_argument("--job_name", type=str, required=True)
+    parser.add_argument("--sample_edge_direction", type=str, required=True)
     args = parser.parse_args()
     logger.info(f"Running storage node with arguments: {args}")
 
@@ -145,4 +161,5 @@ def storage_node_process(
         storage_rank=cluster_info.storage_node_rank,
         cluster_info=cluster_info,
         task_config_uri=UriFactory.create_uri(args.task_config_uri),
+        sample_edge_direction=args.sample_edge_direction,
     )
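One note on the new flag: --sample_edge_direction is parsed as a free-form string, while storage_node_process types it as Literal["in", "out"]. A small sketch (not part of this commit) of enforcing that at the CLI boundary with argparse choices:

# Sketch only: mirror Literal["in", "out"] at parse time so bad values fail fast.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sample_edge_direction",
    type=str,
    choices=["in", "out"],  # rejects anything outside the Literal
    required=True,
)
args = parser.parse_args(["--sample_edge_direction", "in"])
assert args.sample_edge_direction == "in"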
