Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,11 @@ def _setup_for_graph_store(
# Extract supervision edge types and derive label edge types from the
# ABLPInputNodes.labels dict (keyed by supervision edge type).
self._supervision_edge_types = list(first_input.labels.keys())
has_negatives = any(neg is not None for _, neg in first_input.labels.values())
has_negatives = any(
negative_labels is not None
for ablp_input in input_nodes.values()
for _, negative_labels in ablp_input.labels.values()
)

self._positive_label_edge_types = [
message_passing_to_positive_label(et) for et in self._supervision_edge_types
Expand Down
8 changes: 7 additions & 1 deletion gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
FetchABLPInputRequest,
FetchNodesRequest,
)
from gigl.distributed.graph_store.sharding import ServerSlice
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.sampler_options import SamplerOptions
from gigl.distributed.utils.neighborloader import shard_nodes_by_process
Expand Down Expand Up @@ -282,7 +283,7 @@ def get_node_ids(

Args:
request: The node-fetch request, including split, node type,
and round-robin rank/world_size.
and either round-robin rank/world_size or a contiguous slice.

Returns:
The node ids.
Expand All @@ -305,6 +306,7 @@ def get_node_ids(
node_type=request.node_type,
rank=request.rank,
world_size=request.world_size,
server_slice=request.server_slice,
)

def _get_node_ids(
Expand All @@ -313,6 +315,7 @@ def _get_node_ids(
node_type: Optional[NodeType],
rank: Optional[int] = None,
world_size: Optional[int] = None,
server_slice: Optional[ServerSlice] = None,
) -> torch.Tensor:
"""Core implementation for fetching node IDs by split, type, and sharding.

Expand Down Expand Up @@ -365,6 +368,8 @@ def _get_node_ids(
f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}."
)

if server_slice is not None:
return server_slice.slice_tensor(nodes)
if rank is not None and world_size is not None:
return shard_nodes_by_process(nodes, rank, world_size)
return nodes
Expand Down Expand Up @@ -419,6 +424,7 @@ def get_ablp_input(
node_type=request.node_type,
rank=request.rank,
world_size=request.world_size,
server_slice=request.server_slice,
)
positive_label_edge_type, negative_label_edge_type = select_label_edge_types(
request.supervision_edge_type, self.dataset.get_edge_types()
Expand Down
19 changes: 17 additions & 2 deletions gigl/distributed/graph_store/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Literal, Optional, Union

from gigl.distributed.graph_store.sharding import ServerSlice
from gigl.src.common.types.graph_data import EdgeType, NodeType


Expand Down Expand Up @@ -36,18 +37,25 @@ class FetchNodesRequest:
world_size: Optional[int] = None
split: Optional[Union[Literal["train", "val", "test"], str]] = None
node_type: Optional[NodeType] = None
server_slice: Optional[ServerSlice] = None

def validate(self) -> None:
    """Validate that the request's sharding options are consistent.

    A request may shard either round-robin (``rank`` together with
    ``world_size``) or by a contiguous ``server_slice``, but not both.
    Leaving all three unset is valid.

    Raises:
        ValueError: If only one of ``rank`` or ``world_size`` is provided.
        ValueError: If ``server_slice`` is provided together with ``rank``
            or ``world_size``.
    """
    # XOR is true exactly when one of the pair is set and the other is not.
    if (self.rank is None) ^ (self.world_size is None):
        raise ValueError(
            "rank and world_size must be provided together. "
            f"Received rank={self.rank}, world_size={self.world_size}"
        )
    # The two sharding modes are mutually exclusive.
    if self.server_slice is not None and (
        self.rank is not None or self.world_size is not None
    ):
        raise ValueError("server_slice cannot be combined with rank/world_size.")


@dataclass(frozen=True)
Expand Down Expand Up @@ -78,15 +86,22 @@ class FetchABLPInputRequest:
supervision_edge_type: EdgeType
rank: Optional[int] = None
world_size: Optional[int] = None
server_slice: Optional[ServerSlice] = None

def validate(self) -> None:
    """Validate that the request's sharding options are consistent.

    A request may shard either by ``rank`` together with ``world_size``
    or by a ``server_slice``, but not both. Leaving all three unset is
    valid.

    Raises:
        ValueError: If only one of ``rank`` or ``world_size`` is provided.
        ValueError: If ``server_slice`` is provided together with ``rank``
            or ``world_size``.
    """
    # XOR is true exactly when one of the pair is set and the other is not.
    if (self.rank is None) ^ (self.world_size is None):
        raise ValueError(
            "rank and world_size must be provided together. "
            f"Received rank={self.rank}, world_size={self.world_size}"
        )
    # The two sharding modes are mutually exclusive.
    if self.server_slice is not None and (
        self.rank is not None or self.world_size is not None
    ):
        raise ValueError("server_slice cannot be combined with rank/world_size.")
Loading