Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 82 additions & 31 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
reverse_edge_type,
select_label_edge_types,
)
from gigl.utils.data_splitters import get_labels_for_anchor_nodes
from gigl.utils.data_splitters import PADDING_NODE, get_labels_for_anchor_nodes
from gigl.utils.sampling import ABLPInputNodes

logger = Logger()
Expand Down Expand Up @@ -755,6 +755,58 @@ def _setup_for_graph_store(
),
)

def _compute_label_matches(
    self,
    local_to_global: torch.Tensor,
    label_tensor: torch.Tensor,
    num_anchors: int,
) -> dict[int, torch.Tensor]:
    """
    Compute, for every anchor, the local node indices whose global node IDs
    appear in that anchor's label list, using fully vectorized operations.

    Args:
        local_to_global (torch.Tensor): [N] tensor mapping local node idx to global node ID
        label_tensor (torch.Tensor): [A, M] tensor of label global node IDs (padded with PADDING_NODE)
        num_anchors (int): Number of anchor nodes (A); must equal label_tensor.size(0)

    Returns:
        dict[int, torch.Tensor]: Mapping from anchor_idx to the tensor of matching
            local node indices, moved to ``self.to_device``. Every anchor in
            ``range(num_anchors)`` is present; anchors with no matches map to an
            empty tensor.
    """
    # Broadcast comparison: [1, N, 1] vs [A, 1, M] -> [A, N, M].
    # NOTE(review): this materializes an A*N*M boolean intermediate; assumed
    # small enough per sampled batch -- confirm for large fanouts/label counts.
    matches = local_to_global.view(1, -1, 1) == label_tensor.unsqueeze(1)

    # PADDING_NODE entries are fill values, not real labels -- never match them.
    padding_mask = (label_tensor == PADDING_NODE).unsqueeze(1)  # [A, 1, M]
    matches &= ~padding_mask

    # Reduce over the label dimension: [A, N], True where local node n matches
    # any label of anchor a.
    any_match = matches.any(dim=2)

    # Single nonzero call for all anchors: parallel (anchor_idx, node_idx)
    # coordinate tensors, ordered row-major (grouped by anchor).
    anchor_indices, node_indices = torch.nonzero(any_match, as_tuple=True)

    # Matches per anchor. torch.bincount with minlength already returns
    # zeros(num_anchors) for an empty input (on the same device), so no
    # empty-tensor special case is needed.
    counts = torch.bincount(anchor_indices, minlength=num_anchors)

    # Transfer node_indices to the target device ONCE before splitting; this
    # avoids num_anchors small per-anchor transfers. torch.split then returns
    # views (no copy), one per anchor, in anchor order.
    node_indices = node_indices.to(self.to_device)
    per_anchor = torch.split(node_indices, counts.tolist())

    return dict(enumerate(per_anchor))

def _set_labels(
self,
data: Union[Data, HeteroData],
Expand All @@ -765,6 +817,10 @@ def _set_labels(
Sets the labels and relevant fields in the torch_geometric Data object, converting the global node ids for labels to their
local index. Removes inserted supervision edge type from the data variables, since this is an implementation detail and should not be
exposed in the final HeteroData/Data object.

This method uses fully vectorized operations to efficiently process all anchor nodes in a batch simultaneously,
including a single nonzero call for all anchors followed by efficient splitting.

Args:
data (Union[Data, HeteroData]): Graph to provide labels for
positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[positive label edge type, label ID tensor],
Expand All @@ -784,48 +840,42 @@ def _set_labels(
node_type_to_local_node_to_global_node[
DEFAULT_HOMOGENEOUS_NODE_TYPE
] = data.node

output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict(
dict
)
output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict(
dict
)

# We always have supervision edge types of the form (anchor_node_type, to, supervision_node_type)
# So we can index into the edge type accordingly.
edge_index = 2

# Process positive labels with fully vectorized operations
for edge_type, label_tensor in positive_labels_by_label_edge_type.items():
for local_anchor_node_id in range(label_tensor.size(0)):
positive_mask = (
node_type_to_local_node_to_global_node[
edge_type[edge_index]
].unsqueeze(1)
== label_tensor[local_anchor_node_id]
) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node

# Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node
output_positive_labels[
label_edge_type_to_message_passing_edge_type(edge_type)
][local_anchor_node_id] = torch.nonzero(positive_mask)[:, 0].to(
self.to_device
)
# Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node
local_to_global = node_type_to_local_node_to_global_node[
edge_type[edge_index]
]
num_anchors = label_tensor.size(0)

mp_edge_type = label_edge_type_to_message_passing_edge_type(edge_type)
output_positive_labels[mp_edge_type] = self._compute_label_matches(
local_to_global, label_tensor, num_anchors
)

# Process negative labels with fully vectorized operations
for edge_type, label_tensor in negative_labels_by_label_edge_type.items():
for local_anchor_node_id in range(label_tensor.size(0)):
negative_mask = (
node_type_to_local_node_to_global_node[
edge_type[edge_index]
].unsqueeze(1)
== label_tensor[local_anchor_node_id]
) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node

# Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node
output_negative_labels[
label_edge_type_to_message_passing_edge_type(edge_type)
][local_anchor_node_id] = torch.nonzero(negative_mask)[:, 0].to(
self.to_device
)
# Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node
local_to_global = node_type_to_local_node_to_global_node[
edge_type[edge_index]
]
num_anchors = label_tensor.size(0)

mp_edge_type = label_edge_type_to_message_passing_edge_type(edge_type)
output_negative_labels[mp_edge_type] = self._compute_label_matches(
local_to_global, label_tensor, num_anchors
)

if not output_positive_labels:
raise ValueError("No positive labels were found in the data!")
elif len(output_positive_labels) == 1:
Expand All @@ -837,6 +887,7 @@ def _set_labels(
data.y_negative = next(iter(output_negative_labels.values()))
elif len(output_negative_labels) > 0:
data.y_negative = output_negative_labels

return data

def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
Expand Down