diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 92059df87..c52081c86 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -25,12 +25,12 @@ NEGATIVE_LABEL_METADATA_KEY, POSITIVE_LABEL_METADATA_KEY, ABLPNodeSamplerInput, - metadata_key_with_prefix, ) from gigl.distributed.sampler_options import SamplerOptions, resolve_sampler_options from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, + extract_metadata, labeled_to_homogeneous, set_missing_features, shard_nodes_by_process, @@ -772,72 +772,64 @@ def _setup_for_graph_store( ), ) - def _get_labels( - self, msg: SampleMessage + def _extract_labels( + self, metadata: dict[str, torch.Tensor] ) -> tuple[ - SampleMessage, dict[EdgeType, torch.Tensor], dict[EdgeType, torch.Tensor], + dict[str, torch.Tensor], ]: + """Partition pre-extracted metadata into labels and remaining metadata. + # TODO (mkolodner-sc): Remove the need to modify metadata once GLT's `to_hetero_data` function is fixed - f""" - Gets the labels from the output SampleMessage and removes them from the metadata. We need to remove the labels from GLT's metadata since the - `to_hetero_data` function strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of + + Takes the metadata dict already extracted by ``_extract_metadata`` (keys + without the ``#META.`` prefix) and separates label entries from + non-label entries. We need to remove the labels from GLT's metadata since the `to_hetero_data` function + strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of building the HeteroData object. + Label keys use ``POSITIVE_LABEL_METADATA_KEY`` / ``NEGATIVE_LABEL_METADATA_KEY`` + prefixes followed by a string-encoded edge type tuple. If ``edge_dir`` + is ``"in"``, the edge type is reversed because GLT swaps src/dst + internally. 
+ Args: - msg (SampleMessage): All possible results from a sampler, including subgraph data, features, and used defined metadata + metadata: Dict of metadata keys (without ``#META.`` prefix) to tensors, + as returned by ``_extract_metadata``. + Returns: - SampleMessage: Updated sample messsage with the label fields removed dict[EdgeType, torch.Tensor]: Dict[positive label edge type, label ID tensor], where the ith row of the tensor corresponds to the ith anchor node ID. dict[EdgeType, torch.Tensor]: Dict[negative label edge type, label ID tensor], where the ith row of the tensor corresponds to the ith anchor node ID. May be empty if no negative labels are present. + dict[str, torch.Tensor]: Non-label metadata entries """ - metadata: dict[str, torch.Tensor] = {} positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} - # We update metadata with sepcial POSITIVE_LABEL_METADATA_KEY and NEGATIVE_LABEL_METADATA_KEY keys - # in gigl/distributed/dist_neighbor_sampler.py. - # We need to encode the tuples as strings because GLT requires the keys to be strings. - # As such, we decode the strings back into tuples, - # And then pop those keys out of the metadata as they are not needed otherwise. - # If edge_dir is "in", we need to reverse the edge type because GLT swaps src/dst for edge_dir = "out". - # NOTE: GLT *prepends* the keys with "#META." 
- positive_label_metadata_key_prefix = metadata_key_with_prefix( - POSITIVE_LABEL_METADATA_KEY - ) - negative_label_metadata_key_prefix = metadata_key_with_prefix( - NEGATIVE_LABEL_METADATA_KEY - ) - for k in list(msg.keys()): - if k.startswith(positive_label_metadata_key_prefix): - edge_type_str = k[len(positive_label_metadata_key_prefix) :] + remaining_metadata: dict[str, torch.Tensor] = {} + + for key, value in metadata.items(): + if key.startswith(POSITIVE_LABEL_METADATA_KEY): + edge_type_str = key[len(POSITIVE_LABEL_METADATA_KEY) :] edge_type = ast.literal_eval(edge_type_str) if self.edge_dir == "in": edge_type = reverse_edge_type(edge_type) - positive_labels_by_label_edge_type[edge_type] = msg[k].to( - self.to_device - ) - del msg[k] - elif k.startswith(negative_label_metadata_key_prefix): - edge_type_str = k[len(negative_label_metadata_key_prefix) :] + positive_labels_by_label_edge_type[edge_type] = value + elif key.startswith(NEGATIVE_LABEL_METADATA_KEY): + edge_type_str = key[len(NEGATIVE_LABEL_METADATA_KEY) :] edge_type = ast.literal_eval(edge_type_str) if self.edge_dir == "in": edge_type = reverse_edge_type(edge_type) - negative_labels_by_label_edge_type[edge_type] = msg[k].to( - self.to_device - ) - del msg[k] - elif k.startswith("#META."): - meta_key = str(k[len("#META.") :]) - metadata[meta_key] = msg[k].to(self.to_device) - del msg[k] + negative_labels_by_label_edge_type[edge_type] = value + else: + remaining_metadata[key] = value + return ( - msg, positive_labels_by_label_edge_type, negative_labels_by_label_edge_type, + remaining_metadata, ) def _set_labels( @@ -925,8 +917,14 @@ def _set_labels( return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: - msg, positive_labels, negative_labels = self._get_labels(msg) - data = super()._collate_fn(msg) + # _extract_metadata separates #META. keys from the message to work + # around a GLT bug in to_hetero_data. 
_extract_labels then partitions + # the metadata into labels vs remaining non-label metadata. + all_metadata, stripped_msg = extract_metadata(msg, self.to_device) + positive_labels, negative_labels, non_label_metadata = self._extract_labels( + all_metadata + ) + data = super()._collate_fn(stripped_msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, @@ -941,5 +939,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" ) data = labeled_to_homogeneous(self._supervision_edge_types[0], data) + for key, value in non_label_metadata.items(): + data[key] = value data = self._set_labels(data, positive_labels, negative_labels) return data diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py new file mode 100644 index 000000000..453ca1f72 --- /dev/null +++ b/gigl/distributed/dist_ppr_sampler.py @@ -0,0 +1,684 @@ +import heapq +from collections import defaultdict +from typing import Optional, Union + +import torch +from graphlearn_torch.channel import SampleMessage +from graphlearn_torch.sampler import ( + HeteroSamplerOutput, + NeighborOutput, + NodeSamplerInput, + SamplerOutput, +) +from graphlearn_torch.typing import EdgeType, NodeType +from graphlearn_torch.utils import merge_dict + +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler + +# Sentinel type names for homogeneous graphs. The PPR algorithm uses +# dict[NodeType, ...] internally for both homo and hetero graphs; these +# sentinels let the homogeneous path reuse the same dict-based code. 
+_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" +_PPR_HOMOGENEOUS_EDGE_TYPE = ( + _PPR_HOMOGENEOUS_NODE_TYPE, + "to", + _PPR_HOMOGENEOUS_NODE_TYPE, +) + + +class DistPPRNeighborSampler(DistNeighborSampler): + """ + Personalized PageRank (PPR) based neighbor sampler that inherits from GLT DistNeighborSampler. + + Instead of uniform random sampling, this sampler uses PPR scores to select the most + relevant neighbors for each seed node. The PPR algorithm approximates the stationary + distribution of a random walk with restart probability alpha. + + This sampler supports both homogeneous and heterogeneous graphs. For heterogeneous graphs, + the PPR algorithm traverses across all edge types, switching edge types based on the + current node type and the configured edge direction. + + Degree tensors are sourced automatically from the dataset at initialization time. + + Args: + alpha: Restart probability (teleport probability back to seed). Higher values + keep samples closer to seeds. Typical values: 0.15-0.25. + eps: Convergence threshold. Smaller values give more accurate PPR scores + but require more computation. Typical values: 1e-4 to 1e-6. + max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. + num_nbrs_per_hop: Maximum number of neighbors to fetch per hop. + total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults to + ``torch.int32``, which supports total degrees up to ~2 billion. Use a + larger dtype if nodes have exceptionally high aggregate degrees. 
+ """ + + def __init__( + self, + *args, + alpha: float = 0.5, + eps: float = 1e-4, + max_ppr_nodes: int = 50, + num_nbrs_per_hop: int = 100000, + total_degree_dtype: torch.dtype = torch.int32, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._alpha = alpha + self._eps = eps + self._max_ppr_nodes = max_ppr_nodes + self._requeue_threshold_factor = alpha * eps + self._num_nbrs_per_hop = num_nbrs_per_hop + + assert isinstance( + self.data, DistDataset + ), "DistPPRNeighborSampler requires a GiGL DistDataset to access degree tensors." + degree_tensors = self.data.degree_tensor + + # Build mapping from node type to edge types that can be traversed from that node type. + self._node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict( + list + ) + + if hasattr(self, "edge_types") and self.edge_types is not None: + self._is_homogeneous = False + # Heterogeneous case: map each node type to its outgoing/incoming edge types + for etype in self.edge_types: + if self.edge_dir == "in": + # For incoming edges, we traverse FROM the destination node type + anchor_type = etype[-1] + else: # "out" + # For outgoing edges, we traverse FROM the source node type + anchor_type = etype[0] + + self._node_type_to_edge_types[anchor_type].append(etype) + else: + self._node_type_to_edge_types[_PPR_HOMOGENEOUS_NODE_TYPE] = [ + _PPR_HOMOGENEOUS_EDGE_TYPE + ] + self._is_homogeneous = True + + # Precompute total degree per node type: the sum of degrees across all + # edge types traversable from that node type. This is a graph-level + # property used on every PPR iteration, so computing it once at init + # avoids per-node summation and cache lookups in the hot loop. + # TODO (mkolodner-sc): This trades memory for throughput — we + # materialize a tensor per node type to avoid recomputing total degree + # on every neighbor during sampling. Computing it here (rather than in + # the dataset) also keeps the door open for edge-specific degree + # strategies. 
If memory becomes a bottleneck, revisit this. + self._total_degree_by_node_type: dict[ + NodeType, torch.Tensor + ] = self._build_total_degree_tensors(degree_tensors, total_degree_dtype) + + def _build_total_degree_tensors( + self, + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + dtype: torch.dtype, + ) -> dict[NodeType, torch.Tensor]: + """Build total-degree tensors by summing per-edge-type degrees for each node type. + + For homogeneous graphs, the total degree is just the single degree tensor. + For heterogeneous graphs, it sums degree tensors across all edge types + traversable from each node type, padding shorter tensors with zeros. + + Args: + degree_tensors: Per-edge-type degree tensors from the dataset. + dtype: Dtype for the output tensors. + + Returns: + Dict mapping node type to a 1-D tensor of total degrees. + """ + result: dict[NodeType, torch.Tensor] = {} + + if self._is_homogeneous: + assert isinstance(degree_tensors, torch.Tensor) + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + else: + assert isinstance(degree_tensors, dict) + for node_type, edge_types in self._node_type_to_edge_types.items(): + max_len = 0 + for et in edge_types: + if et not in degree_tensors: + raise ValueError( + f"Edge type {et} not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + max_len = max(max_len, len(degree_tensors[et])) + + summed = torch.zeros(max_len, dtype=dtype) + for et in edge_types: + et_degrees = degree_tensors[et] + summed[: len(et_degrees)] += et_degrees.to(dtype) + result[node_type] = summed + + return result + + def _get_total_degree(self, node_id: int, node_type: NodeType) -> int: + """Look up the precomputed total degree of a node. + + Args: + node_id: The ID of the node to look up. + node_type: The node type. + + Returns: + The total degree (sum across all edge types) for the node. 
+ + Raises: + ValueError: If the node ID is out of range, indicating corrupted + graph data or a sampler bug. + """ + degree_tensor = self._total_degree_by_node_type[node_type] + if node_id >= len(degree_tensor): + raise ValueError( + f"Node ID {node_id} exceeds total degree tensor length " + f"({len(degree_tensor)}) for node type {node_type}." + ) + return int(degree_tensor[node_id].item()) + + def _get_destination_type(self, edge_type: EdgeType) -> NodeType: + """Get the node type at the destination end of an edge type.""" + return edge_type[0] if self.edge_dir == "in" else edge_type[-1] + + async def _batch_fetch_neighbors( + self, + nodes_by_edge_type: dict[EdgeType, set[int]], + neighbor_target: dict[tuple[int, EdgeType], list[int]], + device: torch.device, + ) -> None: + """ + Batch fetch neighbors for nodes grouped by edge type. + + Fetches neighbors for all nodes in nodes_by_edge_type, populating + neighbor_target with neighbor lists. Degrees are looked up separately + from the in-memory degree_tensors. + + Args: + nodes_by_edge_type: Dict mapping edge type to set of node IDs to fetch + neighbor_target: Dict to populate with (node_id, edge_type) -> neighbor list + device: Torch device for tensor creation + """ + for etype, node_ids in nodes_by_edge_type.items(): + if not node_ids: + continue + nodes_list = list(node_ids) + lookup_tensor = torch.tensor(nodes_list, dtype=torch.long, device=device) + + # _sample_one_hop expects None for homogeneous graphs, not the PPR sentinel. + output: NeighborOutput = await self._sample_one_hop( + srcs=lookup_tensor, + num_nbr=self._num_nbrs_per_hop, + etype=etype if etype != _PPR_HOMOGENEOUS_EDGE_TYPE else None, + ) + neighbors = output.nbr + neighbor_counts = output.nbr_num + + neighbors_list = neighbors.tolist() + counts_list = neighbor_counts.tolist() + del neighbors, neighbor_counts + + # neighbors_list is a flat concatenation of all neighbors for all looked-up nodes. 
+ # We use offset to slice out each node's neighbors: node i's neighbors are at + # neighbors_list[offset : offset + count], then we advance offset by count. + offset = 0 + for node_id, count in zip(nodes_list, counts_list): + cache_key = (node_id, etype) + neighbor_target[cache_key] = neighbors_list[offset : offset + count] + offset += count + + async def _compute_ppr_scores( + self, + seed_nodes: torch.Tensor, + seed_node_type: Optional[NodeType] = None, + ) -> tuple[ + Union[torch.Tensor, dict[NodeType, torch.Tensor]], + Union[torch.Tensor, dict[NodeType, torch.Tensor]], + Union[torch.Tensor, dict[NodeType, torch.Tensor]], + ]: + """ + Compute PPR scores for seed nodes using the push-based approximation algorithm. + + This implements the Forward Push algorithm (Andersen et al., 2006) which + iteratively pushes probability mass from nodes with high residual to their + neighbors. For heterogeneous graphs, the algorithm traverses across all + edge types, switching based on the current node type. + + Algorithm Overview (each iteration of the main loop): + 1. Fetch neighbors: Drain all nodes from the queue, group by edge type, + and perform a batched neighbor lookup to populate neighbor/degree caches. + 2. Push residual + re-queue (single pass): For each queued node, add its + residual to its PPR score, reset its residual to zero, then distribute + (1-alpha) * residual to all neighbors proportionally by degree. After + each push, immediately check if the neighbor's accumulated residual + exceeds alpha * eps * total_degree; if so, add it to the queue for + the next iteration. Total degree lookups are cached across the entire + PPR computation to avoid redundant summation. + + Args: + seed_nodes: Tensor of seed node IDs [batch_size] + seed_node_type: Node type of seed nodes. Should be None for homogeneous graphs. 
+ + Returns: + tuple of (flat_neighbor_ids, flat_weights, valid_counts) where each is either + a 1-D tensor (homogeneous) or a dict mapping NodeType to a 1-D tensor + (heterogeneous): + - flat_neighbor_ids: global neighbor IDs in top-k order, concatenated + across all seeds. Length equals sum(valid_counts). + - flat_weights: corresponding PPR scores, same length as flat_neighbor_ids. + - valid_counts: number of PPR neighbors found per seed [batch_size]. + """ + if seed_node_type is None: + seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE + device = seed_nodes.device + batch_size = seed_nodes.size(0) + + # Per-seed PPR state, nested by node type for efficient type-grouped access. + # + # ppr_scores[i][node_type][node_id] = accumulated PPR score for node_id + # of type node_type, relative to seed i. Updated each iteration by + # absorbing the node's residual. + # + # residuals[i][node_type][node_id] = unconverged probability mass at node_id + # of type node_type for seed i. Each iteration, a node's residual is + # absorbed into its PPR score and then distributed to its neighbors. + # + # queue[i][node_type] = set of node IDs whose residual exceeds the + # convergence threshold (alpha * eps * total_degree). The algorithm + # terminates when all queues are empty. A set is used because multiple + # neighbors can push residual to the same node in one iteration — + # deduplication avoids redundant processing, and the O(1) membership + # check matters since it runs in the innermost loop. 
+ ppr_scores: list[dict[NodeType, dict[int, float]]] = [ + defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) + ] + residuals: list[dict[NodeType, dict[int, float]]] = [ + defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) + ] + queue: list[dict[NodeType, set[int]]] = [ + defaultdict(set) for _ in range(batch_size) + ] + + seed_list = seed_nodes.tolist() + + for i, seed in enumerate(seed_list): + residuals[i][seed_node_type][seed] = self._alpha + queue[i][seed_node_type].add(seed) + + # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type + neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} + + num_nodes_in_queue = batch_size + one_minus_alpha = 1 - self._alpha + + while num_nodes_in_queue > 0: + # Drain all nodes from all queues and group by edge type for batched lookups + queued_nodes: list[dict[NodeType, set[int]]] = [ + defaultdict(set) for _ in range(batch_size) + ] + nodes_by_edge_type: dict[EdgeType, set[int]] = defaultdict(set) + + for i in range(batch_size): + if queue[i]: + queued_nodes[i] = queue[i] + queue[i] = defaultdict(set) + for node_type, node_ids in queued_nodes[i].items(): + num_nodes_in_queue -= len(node_ids) + edge_types_for_node = self._node_type_to_edge_types[node_type] + for node_id in node_ids: + for etype in edge_types_for_node: + cache_key = (node_id, etype) + if cache_key not in neighbor_cache: + nodes_by_edge_type[etype].add(node_id) + + await self._batch_fetch_neighbors( + nodes_by_edge_type, neighbor_cache, device + ) + + # Push residual to neighbors and re-queue in a single pass. This + # is safe because each seed's state is independent, and residuals + # are always positive so the merged loop can never miss a re-queue. 
+ for i in range(batch_size): + for source_type, source_nodes in queued_nodes[i].items(): + for source_node in source_nodes: + source_residual = residuals[i][source_type].get( + source_node, 0.0 + ) + + ppr_scores[i][source_type][source_node] += source_residual + residuals[i][source_type][source_node] = 0.0 + + edge_types_for_node = self._node_type_to_edge_types[source_type] + + total_degree = self._get_total_degree(source_node, source_type) + + if total_degree == 0: + continue + + residual_per_neighbor = ( + one_minus_alpha * source_residual / total_degree + ) + + for etype in edge_types_for_node: + cache_key = (source_node, etype) + neighbor_list = neighbor_cache[cache_key] + if not neighbor_list: + continue + + neighbor_type = self._get_destination_type(etype) + + for neighbor_node in neighbor_list: + residuals[i][neighbor_type][ + neighbor_node + ] += residual_per_neighbor + + requeue_threshold = ( + self._requeue_threshold_factor + * self._get_total_degree( + neighbor_node, neighbor_type + ) + ) + should_requeue = ( + neighbor_node not in queue[i][neighbor_type] + and residuals[i][neighbor_type][neighbor_node] + >= requeue_threshold + ) + if should_requeue: + queue[i][neighbor_type].add(neighbor_node) + num_nodes_in_queue += 1 + + # Extract top-k nodes by PPR score, grouped by node type. + # Results are three flat tensors per node type (no padding): + # - flat_ids: [id_seed0_0, id_seed0_1, ..., id_seed1_0, ...] + # - flat_weights: [wt_seed0_0, wt_seed0_1, ..., wt_seed1_0, ...] + # - valid_counts: [count_seed0, count_seed1, ...] + # + # valid_counts[i] records how many top-k neighbors seed i contributed. 
+ # Callers use it to slice flat_ids/flat_weights back into per-seed + # groups and to build PyG edge-index tensors via repeat_interleave: + # + # Example: 3 seeds, valid_counts = [2, 3, 1] + # flat_dst = [dst_0a, dst_0b, dst_1a, dst_1b, dst_1c, dst_2a] + # + # src_indices = repeat_interleave(arange(3), valid_counts) + # = [0, 0, 1, 1, 1, 2] + # + # edge_index = stack([src_indices, flat_dst]) + # = [[0, 0, 1, 1, 1, 2], + # [dst_0a, dst_0b, dst_1a, dst_1b, dst_1c, dst_2a]] + # + # Column j means "edge from seed src_indices[j] to neighbor flat_dst[j]" + # with PPR weight flat_weights[j]. + all_node_types = self._node_type_to_edge_types.keys() + + flat_ids_by_ntype: dict[NodeType, torch.Tensor] = {} + flat_weights_by_ntype: dict[NodeType, torch.Tensor] = {} + valid_counts_by_ntype: dict[NodeType, torch.Tensor] = {} + + for ntype in all_node_types: + flat_ids: list[int] = [] + flat_weights: list[float] = [] + valid_counts: list[int] = [] + + for i in range(batch_size): + type_scores = ppr_scores[i].get(ntype, {}) + top_k = heapq.nlargest( + self._max_ppr_nodes, type_scores.items(), key=lambda x: x[1] + ) + for node_id, weight in top_k: + flat_ids.append(node_id) + flat_weights.append(weight) + valid_counts.append(len(top_k)) + + flat_ids_by_ntype[ntype] = torch.tensor( + flat_ids, dtype=torch.long, device=device + ) + flat_weights_by_ntype[ntype] = torch.tensor( + flat_weights, dtype=torch.float, device=device + ) + valid_counts_by_ntype[ntype] = torch.tensor( + valid_counts, dtype=torch.long, device=device + ) + + if self._is_homogeneous: + assert ( + len(flat_ids_by_ntype) == 1 + and _PPR_HOMOGENEOUS_NODE_TYPE in flat_ids_by_ntype + ) + return ( + flat_ids_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], + flat_weights_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], + valid_counts_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], + ) + else: + return flat_ids_by_ntype, flat_weights_by_ntype, valid_counts_by_ntype + + async def _sample_from_nodes( + self, + inputs: NodeSamplerInput, + ) -> 
Optional[SampleMessage]: + """ + Override the base sampling method to use PPR-based neighbor selection. + + Supports both NodeSamplerInput and ABLPNodeSamplerInput. For ABLP, PPR + scores are computed from both anchor and supervision nodes, so the sampled + subgraph includes neighbors relevant to all seed types. + + For heterogeneous graphs, PPR traverses across all edge types, switching + edge types based on the current node type. + + Output format (PyG edge-index style, no padding): + + - ``ppr_neighbor_ids`` (homo) / ``ppr_neighbor_ids_{seed_type}_{ntype}`` (hetero): + shape ``[2, num_edges]`` — row 0 is local seed indices, row 1 is local + neighbor indices. Both index into ``data[ntype].node``. + - ``ppr_weights`` (homo) / ``ppr_weights_{seed_type}_{ntype}`` (hetero): + shape ``[num_edges]`` — PPR score for each edge, aligned with the columns + of ``ppr_neighbor_ids``. + + Local indices are produced by the inducer (see below), so row 1 of + ``ppr_neighbor_ids`` directly indexes into ``data[ntype].x`` without any + additional global→local remapping. + + The inducer is GLT's C++ data structure (backed by a per-node-type hash map) + that maintains a single global-ID → local-index mapping for the entire + subgraph being built. We use it here instead of a Python dict for two reasons: + + 1. **Consistency across seed types.** For heterogeneous ABLP inputs, + ``_compute_ppr_scores`` is called once per seed type (anchors, supervision + nodes, …). A node reachable from multiple seed types must receive the + *same* local index in ``node_dict[ntype]`` regardless of which seed type + discovered it. The inducer is shared across all those calls, so it + guarantees this automatically. + + 2. **Performance.** The inducer's C++ hash map is faster than a Python dict + for per-node lookups on large graphs, and its lifecycle is already managed + by GLT's inducer pool (``_acquire_inducer`` / ``inducer_pool.put``). 
+ + The API used here mirrors GLT's own ``DistNeighborSampler._sample_from_nodes``: + + - ``inducer.init_node(seeds)`` registers seed nodes and returns their global + IDs (local indices 0, 1, … are assigned internally). + - ``inducer.induce_next(srcs, flat_nbrs, counts)`` (homo) or + ``inducer.induce_next(nbr_dict)`` (hetero) deduplicates neighbors against + all previously seen nodes and returns: + + - ``new_nodes``: global IDs of nodes not yet registered. + - ``cols``: flat local destination indices for *every* neighbor edge, + in the same order as the input ``flat_nbrs``. Combined with + ``repeat_interleave``-expanded seed indices, this forms the + ``[2, num_edges]`` edge-index tensor directly. + """ + sample_loop_inputs = self._prepare_sample_loop_inputs(inputs) + input_seeds = inputs.node.to(self.device) + input_type = inputs.input_type + is_hetero = self.dist_graph.data_cls == "hetero" + metadata = sample_loop_inputs.metadata + nodes_to_sample = sample_loop_inputs.nodes_to_sample + + # The inducer is GLT's C++ data structure that maintains a global-ID → + # local-index mapping for the subgraph being built. It serves two roles: + # + # 1. Deduplication: when the same global node ID appears from multiple + # seeds or seed types, induce_next assigns it a single local index. + # This ensures node_dict[ntype] has no duplicates. + # + # 2. Local index assignment: init_node registers seeds at local indices + # 0..N-1. induce_next then assigns the next available indices to + # newly discovered neighbors. The returned "cols" tensor contains + # the local destination index for every neighbor (including those + # that were already registered), which we use directly as row 1 of + # the PyG edge-index tensor. + # + # Acquired once per sample call; returned to the pool at the end. 
+ inducer = self._acquire_inducer() + + if is_hetero: + assert isinstance(nodes_to_sample, dict) + assert input_type is not None + + # Register all seeds (anchors + supervision nodes for ABLP) with the + # inducer first, so they occupy the lowest local indices. src_dict maps + # NodeType -> global IDs (same values as nodes_to_sample). + src_dict = inducer.init_node(nodes_to_sample) + + # Compute PPR for each seed type, collecting flat global neighbor IDs, + # weights, and per-seed counts. Build nbr_dict for a single + # inducer.induce_next call using virtual edge types (seed_type, 'ppr', ntype) + # — the inducer only cares about etype[0] and etype[-1] as source/dest + # node types, so the relation name is arbitrary. + nbr_dict: dict[EdgeType, list[torch.Tensor]] = {} + all_flat_weights: dict[tuple[NodeType, NodeType], torch.Tensor] = {} + all_valid_counts: dict[tuple[NodeType, NodeType], torch.Tensor] = {} + + for seed_type, seed_nodes in nodes_to_sample.items(): + ( + flat_ids_by_type, + flat_weights_by_type, + valid_counts_by_type, + ) = await self._compute_ppr_scores(seed_nodes, seed_type) + assert isinstance(flat_ids_by_type, dict) + assert isinstance(flat_weights_by_type, dict) + assert isinstance(valid_counts_by_type, dict) + + for ntype, flat_ids in flat_ids_by_type.items(): + valid_counts = valid_counts_by_type[ntype] + all_flat_weights[(seed_type, ntype)] = flat_weights_by_type[ntype] + all_valid_counts[(seed_type, ntype)] = valid_counts + + # Skip empty pairs; induce_next handles deduplication across + # seed types so a neighbor reachable from multiple seed types + # gets one consistent local index in node_dict[ntype]. + if flat_ids.numel() > 0: + virtual_etype: EdgeType = (seed_type, "ppr", ntype) + nbr_dict[virtual_etype] = [ + src_dict[seed_type], + flat_ids, + valid_counts, + ] + + # induce_next assigns local indices to all neighbors not yet registered, + # deduplicating across all virtual edge types in one pass. 
+ # new_nodes_dict: newly discovered global IDs per node type. + # cols_dict: flat local destination indices per virtual edge type, + # in the same order the flat neighbors were provided. + new_nodes_dict, _rows_dict, cols_dict = inducer.induce_next(nbr_dict) + + # node_dict = seeds (already in src_dict) + newly discovered PPR + # neighbors. merge_dict appends tensors into lists; cat collapses them. + out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = defaultdict(list) + merge_dict(src_dict, out_nodes_hetero) + merge_dict(new_nodes_dict, out_nodes_hetero) + node_dict = { + ntype: torch.cat(nodes) + for ntype, nodes in out_nodes_hetero.items() + if nodes + } + + # Build PyG-style edge-index output per (seed_type, ntype) pair. + # cols_dict[(seed_type, 'ppr', ntype)] gives flat local dst indices in + # the same order as the flat neighbors passed to induce_next. + # repeat_interleave expands seed local indices to match. + for (seed_type, ntype), flat_weights in all_flat_weights.items(): + valid_counts = all_valid_counts[(seed_type, ntype)] + virtual_etype = (seed_type, "ppr", ntype) + cols = cols_dict.get(virtual_etype) + if cols is not None: + seed_batch_size = nodes_to_sample[seed_type].size(0) + src_indices = torch.repeat_interleave( + torch.arange(seed_batch_size, device=self.device), valid_counts + ) + ppr_edge_index = torch.stack([src_indices, cols]) + else: + ppr_edge_index = torch.zeros( + 2, 0, dtype=torch.long, device=self.device + ) + flat_weights = torch.zeros(0, dtype=torch.float, device=self.device) + metadata[f"ppr_neighbor_ids_{seed_type}_{ntype}"] = ppr_edge_index + metadata[f"ppr_weights_{seed_type}_{ntype}"] = flat_weights + + sample_output = HeteroSamplerOutput( + node=node_dict, + row={}, # PPR doesn't maintain edge structure + col={}, + edge={}, # Empty dict — GLT SampleQueue requires all values to be tensors + batch={input_type: input_seeds}, + num_sampled_nodes={ + ntype: [nodes.size(0)] for ntype, nodes in node_dict.items() + }, + 
num_sampled_edges={}, + input_type=input_type, + metadata=metadata, + ) + + else: + assert isinstance(nodes_to_sample, torch.Tensor) + + # Register seeds; local indices 0..N-1 are assigned internally. + # srcs holds their global IDs (same values as nodes_to_sample). + srcs = inducer.init_node(nodes_to_sample) + + ( + homo_flat_ids, + homo_flat_weights, + homo_valid_counts, + ) = await self._compute_ppr_scores(nodes_to_sample, None) + assert isinstance(homo_flat_ids, torch.Tensor) + assert isinstance(homo_flat_weights, torch.Tensor) + assert isinstance(homo_valid_counts, torch.Tensor) + + # induce_next deduplicates homo_flat_ids against already-seen nodes + # (the seeds registered above) and returns: + # new_nodes: global IDs of nodes not yet registered. + # cols: flat local destination indices for every neighbor, in the + # same order as homo_flat_ids. + new_nodes, _rows, cols = inducer.induce_next( + srcs, homo_flat_ids, homo_valid_counts + ) + all_nodes = torch.cat([srcs, new_nodes]) + + # Build PyG-style edge-index: row 0 = local seed indices (expanded via + # repeat_interleave), row 1 = local neighbor indices from inducer cols. 
+ src_indices = torch.repeat_interleave( + torch.arange(nodes_to_sample.size(0), device=self.device), + homo_valid_counts, + ) + ppr_edge_index = torch.stack([src_indices, cols]) + + metadata["ppr_neighbor_ids"] = ppr_edge_index + metadata["ppr_weights"] = homo_flat_weights + + sample_output = SamplerOutput( + node=all_nodes, + row=torch.tensor([], dtype=torch.long, device=self.device), + col=torch.tensor([], dtype=torch.long, device=self.device), + edge=torch.tensor( + [], dtype=torch.long, device=self.device + ), # Empty tensor — GLT SampleQueue requires all values to be tensors + batch=input_seeds, + num_sampled_nodes=[srcs.size(0), new_nodes.size(0)], + num_sampled_edges=[], + metadata=metadata, + ) + + self.inducer_pool.put(inducer) + return sample_output diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 7b07ff006..c07e8caaf 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -34,8 +34,17 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset +from gigl.common.logger import Logger +from gigl.distributed.dist_dataset import DistDataset as GiglDistDataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler -from gigl.distributed.sampler_options import KHopNeighborSamplerOptions, SamplerOptions +from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler +from gigl.distributed.sampler_options import ( + KHopNeighborSamplerOptions, + PPRSamplerOptions, + SamplerOptions, +) + +logger = Logger() def _sampling_worker_loop( @@ -89,9 +98,19 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - # Resolve sampler class from options + # Resolve sampler class and any extra kwargs from options + extra_sampler_kwargs: dict[str, object] = {} if isinstance(sampler_options, KHopNeighborSamplerOptions): sampler_cls = DistNeighborSampler + elif 
isinstance(sampler_options, PPRSamplerOptions): + sampler_cls = DistPPRNeighborSampler + extra_sampler_kwargs = { + "alpha": sampler_options.alpha, + "eps": sampler_options.eps, + "max_ppr_nodes": sampler_options.max_ppr_nodes, + "num_nbrs_per_hop": sampler_options.num_nbrs_per_hop, + "total_degree_dtype": sampler_options.total_degree_dtype, + } else: raise NotImplementedError( f"Unsupported sampler options type: {type(sampler_options)}" @@ -110,6 +129,7 @@ def _sampling_worker_loop( worker_options.worker_concurrency, current_device, seed=sampling_config.seed, + **extra_sampler_kwargs, ) dist_sampler.start_loop() @@ -193,6 +213,23 @@ def __init__( def init(self): r"""Create the subprocess pool. Init samplers and rpc server.""" + # PPR sampling requires degree tensors in the sampler __init__. + # Worker subprocesses only initialize RPC (not torch.distributed), + # so the lazy degree computation would fail there. Eagerly compute + # here — where torch.distributed IS initialized — so the cached + # tensor is shared to workers via IPC. + if isinstance(self._sampler_options, PPRSamplerOptions): + assert isinstance(self.data, GiglDistDataset) + degree_tensor = self.data.degree_tensor + if isinstance(degree_tensor, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across {len(degree_tensor)} edge types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with {degree_tensor.size(0)} nodes." 
+ ) + if self.sampling_config.seed is not None: seed_everything(self.sampling_config.seed) if not self.sampling_config.shuffle: diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 9c72500b1..ab3a5bbc1 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -25,6 +25,7 @@ from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, + extract_metadata, labeled_to_homogeneous, set_missing_features, shard_nodes_by_process, @@ -558,7 +559,13 @@ def _setup_for_colocated( ) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: - data = super()._collate_fn(msg) + # Extract user-defined metadata (e.g. PPR scores) before + # super()._collate_fn, which calls GLT's to_hetero_data. + # to_hetero_data misinterprets #META. keys as edge types and + # fails when edge_dir="out" (tries to reverse_edge_type on them). + # We strip them here and re-apply after conversion. + non_edge_metadata, stripped_msg = extract_metadata(msg, self.to_device) + data = super()._collate_fn(stripped_msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, @@ -569,4 +576,6 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = strip_label_edges(data) if self._is_homogeneous_with_labeled_edge_type: data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data) + for key, value in non_edge_metadata.items(): + data[key] = value return data diff --git a/gigl/distributed/sampler.py b/gigl/distributed/sampler.py index 5d0d63fa9..e99dd65dc 100644 --- a/gigl/distributed/sampler.py +++ b/gigl/distributed/sampler.py @@ -10,14 +10,6 @@ NEGATIVE_LABEL_METADATA_KEY: Final[str] = "gigl_negative_labels." -def metadata_key_with_prefix(key: str) -> str: - """Prefixes the key with "#META - Do this as GLT also does this. 
- https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L714 - """ - return f"#META.{key}" - - class ABLPNodeSamplerInput(NodeSamplerInput): """ Sampler input specific for ABLP use case. Contains additional information about positive labels, negative labels, and the corresponding diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index f7bbb2e4b..678756795 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -1,14 +1,16 @@ """Sampler option types for configuring which sampler class to use in distributed loading. -Provides ``KHopNeighborSamplerOptions`` for using GiGL's built-in ``DistNeighborSampler``. +Provides ``KHopNeighborSamplerOptions`` for using GiGL's built-in ``DistNeighborSampler``, +and ``PPRSamplerOptions`` for PPR-based sampling using ``DistPPRNeighborSampler``. -Frozen dataclass so it is safe to pickle across RPC boundaries +Frozen dataclasses so they are safe to pickle across RPC boundaries (required for Graph Store mode). """ from dataclasses import dataclass from typing import Optional, Union +import torch from graphlearn_torch.typing import EdgeType from gigl.common.logger import Logger @@ -28,7 +30,37 @@ class KHopNeighborSamplerOptions: num_neighbors: Union[list[int], dict[EdgeType, list[int]]] -SamplerOptions = KHopNeighborSamplerOptions +@dataclass(frozen=True) +class PPRSamplerOptions: + """Sampler options for PPR-based neighbor sampling using DistPPRNeighborSampler. + + Degree tensors are sourced automatically from the dataset at sampler + initialization time and do not need to be provided here. + + Attributes: + alpha: Restart probability (teleport probability back to seed). Higher + values keep samples closer to seeds. Typical values: 0.15-0.25. + eps: Convergence threshold for the Forward Push algorithm. 
Smaller + values give more accurate PPR scores but require more computation. + Typical values: 1e-4 to 1e-6. + max_ppr_nodes: Maximum number of nodes to return per seed based on PPR + scores. + num_nbrs_per_hop: Maximum number of neighbors fetched per node per edge + type during PPR traversal. Set large to approximate fetching all + neighbors. + total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults + to ``torch.int32``, which supports total degrees up to ~2 billion. + Use a larger dtype if nodes have exceptionally high aggregate degrees. + """ + + alpha: float = 0.5 + eps: float = 1e-4 + max_ppr_nodes: int = 50 + num_nbrs_per_hop: int = 100000 + total_degree_dtype: torch.dtype = torch.int32 + + +SamplerOptions = Union[KHopNeighborSamplerOptions, PPRSamplerOptions] def resolve_sampler_options( @@ -37,12 +69,14 @@ def resolve_sampler_options( ) -> SamplerOptions: """Resolve sampler_options from user-provided values. - If ``sampler_options`` is ``None``, wraps ``num_neighbors`` in a - ``KHopNeighborSamplerOptions``. If ``KHopNeighborSamplerOptions`` is - provided, validates that its ``num_neighbors`` matches the explicit value. + If ``sampler_options`` is a ``PPRSamplerOptions``, returns it directly + (``num_neighbors`` is unused for PPR). If ``sampler_options`` is ``None``, + wraps ``num_neighbors`` in a ``KHopNeighborSamplerOptions``. If + ``KHopNeighborSamplerOptions`` is provided, validates that its + ``num_neighbors`` matches the explicit value. Args: - num_neighbors: Fanout per hop (always required). + num_neighbors: Fanout per hop (required for KHop; ignored for PPR). sampler_options: Sampler configuration, or None. Returns: @@ -52,6 +86,9 @@ def resolve_sampler_options( ValueError: If ``KHopNeighborSamplerOptions.num_neighbors`` conflicts with the explicit ``num_neighbors``. 
""" + if isinstance(sampler_options, PPRSamplerOptions): + return sampler_options + if sampler_options is None: return KHopNeighborSamplerOptions(num_neighbors) diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index fdac550bc..2faf2bf67 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -6,6 +6,7 @@ from typing import Literal, Optional, TypeVar, Union import torch +from graphlearn_torch.channel import SampleMessage from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType, NodeType @@ -264,3 +265,36 @@ def set_missing_features( ) return data + + +def extract_metadata( + msg: SampleMessage, device: torch.device +) -> tuple[dict[str, torch.Tensor], SampleMessage]: + """Separate user-defined metadata from a SampleMessage. + + GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as + edge types, causing failures with ``edge_dir="out"`` (it tries to call + ``reverse_edge_type`` on metadata key strings). This function separates + metadata from the sampling data so the stripped message can be passed to + GLT's ``_collate_fn`` without triggering the bug. + + The original ``msg`` is not modified. + + Args: + msg: The SampleMessage to extract metadata from. + device: The device to move metadata tensors to. + + Returns: + A 2-tuple of: + - metadata: Dict mapping metadata key (without ``#META.`` prefix) to tensor. + - stripped_msg: A new SampleMessage with ``#META.``-prefixed keys removed. + """ + meta_prefix = "#META." 
+    metadata: dict[str, torch.Tensor] = {}
+    stripped_msg: SampleMessage = {}
+    for k, v in msg.items():
+        if k.startswith(meta_prefix):
+            metadata[k[len(meta_prefix) :]] = v.to(device)
+        else:
+            stripped_msg[k] = v
+    return metadata, stripped_msg
diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py
new file mode 100644
index 000000000..dc2eb78a5
--- /dev/null
+++ b/tests/unit/distributed/dist_ppr_sampler_test.py
@@ -0,0 +1,617 @@
+"""Unit tests for DistPPRNeighborSampler correctness via DistNeighborLoader.
+
+Verifies that the PPR scores produced by the distributed sampler match
+NetworkX's ``pagerank`` with personalization — a well-tested, independent
+PPR implementation.
+
+Note on compatibility with NetworkX:
+
+Both our forward push algorithm (Andersen et al., 2006) and NetworkX's
+``pagerank`` (power iteration) compute Personalized PageRank — they are
+different solvers for the same quantity. With a small residual tolerance
+(eps=1e-6), forward push converges close enough that per-node scores match
+NetworkX within atol=1e-3.
+
+Another note is that our ``alpha`` is the *restart* (teleport) probability — the probability of
+jumping back to the seed at each step. NetworkX's ``alpha`` is the *damping
+factor* — the probability of following an edge. These are complements::
+
+    nx_alpha = 1 - our_alpha
+
+Finally, with ``edge_dir="in"``, the PPR walk from node v follows *incoming* edges —
+it moves to nodes u where edge (u, v) exists in the graph. NetworkX's
+``pagerank`` follows *outgoing* edges. To make NetworkX traverse the same
+neighbors as the sampler, we reverse the edges when building the reference
+graph (add dst→src instead of src→dst). When ``edge_dir="out"``, no
+reversal is needed since both follow the original edge direction.
+""" + +import heapq +from collections import defaultdict +from typing import Literal + +import networkx as nx +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from parameterized import param, parameterized +from torch_geometric.data import Data, HeteroData + +from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.sampler_options import PPRSamplerOptions +from tests.test_assets.distributed.test_dataset import ( + STORY, + STORY_TO_USER, + USER, + USER_TO_STORY, + create_heterogeneous_dataset, + create_heterogeneous_dataset_for_ablp, + create_homogeneous_dataset, +) +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +# --------------------------------------------------------------------------- +# Homogeneous test graph (5 nodes, undirected edges stored as bidirectional) +# +# 0 --- 1 --- 3 +# | | +# 2 --- + +# | +# 4 +# +# Undirected edges: {0-1, 0-2, 1-2, 1-3, 2-4} +# --------------------------------------------------------------------------- +_TEST_EDGE_INDEX = torch.tensor( + [ + [0, 1, 0, 2, 1, 2, 1, 3, 2, 4], + [1, 0, 2, 0, 2, 1, 3, 1, 4, 2], + ] +) +_NUM_TEST_NODES = 5 + +# --------------------------------------------------------------------------- +# Heterogeneous bipartite test graph (3 users, 3 stories) +# USER_TO_STORY: user 0 -> {story 0, story 1} +# user 1 -> {story 1, story 2} +# user 2 -> {story 0, story 2} +# STORY_TO_USER: reverse of USER_TO_STORY +# --------------------------------------------------------------------------- +_TEST_HETERO_EDGE_INDICES = { + USER_TO_STORY: torch.tensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 0, 2]]), + STORY_TO_USER: torch.tensor([[0, 0, 1, 1, 2, 2], [0, 2, 0, 1, 1, 2]]), +} +_NUM_TEST_USERS = 3 +_NUM_TEST_STORIES = 3 + +_TEST_ALPHA = 0.5 
+_TEST_EPS = 1e-6 +_TEST_MAX_PPR_NODES = 5 + + +# --------------------------------------------------------------------------- +# Reference PPR implementations (NetworkX-based) +# --------------------------------------------------------------------------- +def _build_reference_graph(edge_dir: Literal["in", "out"] = "in") -> nx.DiGraph: + """Build a NetworkX DiGraph matching the homogeneous test edge_index. + + With ``edge_dir="in"``, edges are reversed (dst→src) so that NetworkX's + outgoing-edge traversal matches GLT's incoming-edge PPR walk. With + ``edge_dir="out"``, edges keep their original direction (src→dst). + + See the module docstring for a full explanation of why reversal is needed. + """ + graph = nx.DiGraph() + graph.add_nodes_from(range(_NUM_TEST_NODES)) + src = _TEST_EDGE_INDEX[0].tolist() + dst = _TEST_EDGE_INDEX[1].tolist() + if edge_dir == "in": + graph.add_edges_from(zip(dst, src)) + else: + graph.add_edges_from(zip(src, dst)) + return graph + + +def _reference_ppr( + graph: nx.DiGraph, + seed: int, + alpha: float, + max_ppr_nodes: int, +) -> dict[int, float]: + """Compute reference PPR scores for a homogeneous graph using NetworkX. + + See the module docstring for the alpha mapping rationale. + + Args: + graph: NetworkX DiGraph with edges oriented for the sampling direction. + seed: Seed node ID. + alpha: Restart probability (our convention). + max_ppr_nodes: Maximum number of top-scoring nodes to return. + + Returns: + Dict mapping node_id -> PPR score for the top-k nodes. + """ + personalization = {n: 0.0 for n in graph.nodes()} + personalization[seed] = 1.0 + + scores = nx.pagerank( + graph, alpha=1 - alpha, personalization=personalization, tol=1e-12 + ) + top_k = heapq.nlargest(max_ppr_nodes, scores.items(), key=lambda x: x[1]) + return dict(top_k) + + +def _build_hetero_reference_graph(edge_dir: Literal["in", "out"] = "in") -> nx.DiGraph: + """Build a NetworkX DiGraph for the heterogeneous test graph. + + Nodes are ``(type_str, id)`` tuples. 
Edge direction is handled the same + way as :func:`_build_reference_graph` — see the module docstring for the + full explanation of why reversal is needed for ``edge_dir="in"``. + """ + graph = nx.DiGraph() + for i in range(_NUM_TEST_USERS): + graph.add_node((str(USER), i)) + for i in range(_NUM_TEST_STORIES): + graph.add_node((str(STORY), i)) + + for edge_type, edge_index in _TEST_HETERO_EDGE_INDICES.items(): + src_type, _, dst_type = edge_type + src = edge_index[0].tolist() + dst = edge_index[1].tolist() + if edge_dir == "in": + for s, d in zip(src, dst): + graph.add_edge((str(dst_type), d), (str(src_type), s)) + else: + for s, d in zip(src, dst): + graph.add_edge((str(src_type), s), (str(dst_type), d)) + + return graph + + +def _reference_ppr_hetero( + graph: nx.DiGraph, + seed: int, + seed_type: str, + alpha: float, + max_ppr_nodes: int, +) -> dict[str, dict[int, float]]: + """Compute reference PPR scores for a heterogeneous graph using NetworkX. + + See the module docstring for the alpha mapping rationale. + + Args: + graph: NetworkX DiGraph with ``(type_str, id)`` tuple nodes. + seed: Seed node ID. + seed_type: Node type string of the seed. + alpha: Restart probability (our convention). + max_ppr_nodes: Maximum top-scoring nodes to return per node type. + + Returns: + Dict mapping node_type_str -> {node_id: PPR score} for top-k per type. 
+ """ + personalization = {n: 0.0 for n in graph.nodes()} + personalization[(seed_type, seed)] = 1.0 + + scores = nx.pagerank( + graph, alpha=1 - alpha, personalization=personalization, tol=1e-12 + ) + + type_to_scores: dict[str, dict[int, float]] = defaultdict(dict) + for (ntype, nid), score in scores.items(): + type_to_scores[ntype][nid] = score + + result: dict[str, dict[int, float]] = {} + for ntype, type_scores in type_to_scores.items(): + top_k = heapq.nlargest(max_ppr_nodes, type_scores.items(), key=lambda x: x[1]) + result[ntype] = dict(top_k) + + return result + + +# --------------------------------------------------------------------------- +# Shared verification helpers +# --------------------------------------------------------------------------- +def _extract_hetero_ppr_scores( + datum: HeteroData, + seed_type: str, + node_types: list[str], +) -> dict[str, dict[int, float]]: + """Extract and validate PPR metadata from a HeteroData batch. + + Verifies tensor shapes and invariants (positive weights, valid indices), + maps local indices to global IDs, and returns scores grouped by node type. + + Args: + datum: A single HeteroData batch (batch_size=1). + seed_type: The seed node type used to key PPR metadata attributes. + node_types: Node types to extract PPR scores for. + + Returns: + Dict mapping node_type_str -> {global_node_id: ppr_score}. 
+ """ + sampler_ppr_by_type: dict[str, dict[int, float]] = {} + for ntype in node_types: + key_ids = f"ppr_neighbor_ids_{seed_type}_{ntype}" + key_weights = f"ppr_weights_{seed_type}_{ntype}" + + assert hasattr(datum, key_ids), f"Missing {key_ids}" + assert hasattr(datum, key_weights), f"Missing {key_weights}" + + ppr_edge_index = getattr(datum, key_ids) + ppr_weights = getattr(datum, key_weights) + + assert ( + ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + ), f"Expected [2, X] edge_index, got shape {list(ppr_edge_index.shape)}" + assert ppr_weights.dim() == 1 + assert ppr_edge_index.size(1) == ppr_weights.size(0) + assert (ppr_weights > 0).all(), f"PPR weights for {ntype} must be positive" + assert ( + ppr_edge_index[0] == 0 + ).all(), "All src indices must be 0 for batch_size=1" + + global_node_ids = datum[ntype].node + type_ppr: dict[int, float] = {} + for j in range(ppr_edge_index.size(1)): + local_dst = ppr_edge_index[1, j].item() + global_dst = global_node_ids[local_dst].item() + type_ppr[global_dst] = ppr_weights[j].item() + sampler_ppr_by_type[str(ntype)] = type_ppr + + return sampler_ppr_by_type + + +def _assert_ppr_scores_match_reference( + sampler_ppr_by_type: dict[str, dict[int, float]], + reference_ppr: dict[str, dict[int, float]], + seed_id: int, + context_label: str = "", +) -> None: + """Assert sampler PPR scores match reference scores per node type. + + Checks that top-k node sets are identical and that per-node scores + are within atol=1e-3. The forward push error per node is bounded by + O(alpha * eps * degree), so atol=1e-3 is generous for eps=1e-6. + + Args: + sampler_ppr_by_type: Sampler output from :func:`_extract_hetero_ppr_scores`. + reference_ppr: Reference output from :func:`_reference_ppr_hetero`. + seed_id: Global seed node ID (for error messages). + context_label: Optional prefix for error messages (e.g. "ABLP"). 
+ """ + prefix = f"{context_label} seed" if context_label else f"Seed" + for ntype_str in reference_ppr: + assert set(sampler_ppr_by_type[ntype_str].keys()) == set( + reference_ppr[ntype_str].keys() + ), ( + f"{prefix} {seed_id}, type {ntype_str}: top-k node sets differ.\n" + f" Sampler: {sorted(sampler_ppr_by_type[ntype_str].keys())}\n" + f" Reference: {sorted(reference_ppr[ntype_str].keys())}" + ) + + for node_id in reference_ppr[ntype_str]: + ref_score = reference_ppr[ntype_str][node_id] + sam_score = sampler_ppr_by_type[ntype_str][node_id] + assert abs(sam_score - ref_score) < 1e-3, ( + f"{prefix} {seed_id}, type {ntype_str}, node {node_id}: " + f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" + ) + + +# --------------------------------------------------------------------------- +# Spawned process functions +# --------------------------------------------------------------------------- +def _run_ppr_loader_correctness_check( + _: int, + alpha: float, + max_ppr_nodes: int, + edge_dir: Literal["in", "out"], +) -> None: + """Iterate homogeneous PPR loader and verify each batch against NetworkX PPR.""" + create_test_process_group() + + dataset = create_homogeneous_dataset(edge_index=_TEST_EDGE_INDEX, edge_dir=edge_dir) + + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[10], # Unused by PPR sampler; required by interface + sampler_options=PPRSamplerOptions( + alpha=alpha, + eps=_TEST_EPS, + max_ppr_nodes=max_ppr_nodes, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + reference_graph = _build_reference_graph(edge_dir) + + batches_checked = 0 + for datum in loader: + assert isinstance(datum, Data) + + # GLT's to_data() unpacks metadata dict keys directly onto the Data + # object (data[k] = v), so PPR results are top-level attributes. 
+ assert hasattr(datum, "ppr_neighbor_ids"), "Missing ppr_neighbor_ids on Data" + assert hasattr(datum, "ppr_weights"), "Missing ppr_weights on Data" + + ppr_edge_index = datum.ppr_neighbor_ids + ppr_weights = datum.ppr_weights + + assert ( + ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + ), f"Expected [2, X] edge_index, got shape {list(ppr_edge_index.shape)}" + assert ppr_weights.dim() == 1, f"Expected 1D weights, got {ppr_weights.dim()}D" + assert ppr_edge_index.size(1) == ppr_weights.size( + 0 + ), f"Edge count mismatch: {ppr_edge_index.size(1)} vs {ppr_weights.size(0)}" + assert (ppr_weights > 0).all(), "PPR weights must be positive" + assert ( + ppr_edge_index[0] == 0 + ).all(), "All src indices must be 0 for batch_size=1" + + # Map local indices to global IDs + global_node_ids = datum.node + seed_global_id = datum.batch[0].item() + + sampler_ppr: dict[int, float] = {} + for j in range(ppr_edge_index.size(1)): + local_dst = ppr_edge_index[1, j].item() + global_dst = global_node_ids[local_dst].item() + sampler_ppr[global_dst] = ppr_weights[j].item() + + # Compute reference PPR + reference_ppr = _reference_ppr( + graph=reference_graph, + seed=seed_global_id, + alpha=alpha, + max_ppr_nodes=max_ppr_nodes, + ) + + # Verify same top-k node set + assert set(sampler_ppr.keys()) == set(reference_ppr.keys()), ( + f"Seed {seed_global_id}: top-k node sets differ.\n" + f" Sampler: {sorted(sampler_ppr.keys())}\n" + f" Reference: {sorted(reference_ppr.keys())}" + ) + + # Forward push is an approximation; with eps=1e-6 the per-node error + # is bounded by O(alpha * eps * degree), so atol=1e-3 is generous. 
+ for node_id in reference_ppr: + ref_score = reference_ppr[node_id] + sam_score = sampler_ppr[node_id] + assert abs(sam_score - ref_score) < 1e-3, ( + f"Seed {seed_global_id}, node {node_id}: " + f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" + ) + + batches_checked += 1 + + assert ( + batches_checked == _NUM_TEST_NODES + ), f"Expected {_NUM_TEST_NODES} batches, got {batches_checked}" + shutdown_rpc() + + +def _run_ppr_hetero_loader_correctness_check( + _: int, + alpha: float, + max_ppr_nodes: int, + edge_dir: Literal["in", "out"], +) -> None: + """Iterate heterogeneous PPR loader and verify each batch against NetworkX PPR.""" + create_test_process_group() + + dataset = create_heterogeneous_dataset( + edge_indices=_TEST_HETERO_EDGE_INDICES, edge_dir=edge_dir + ) + + node_ids = dataset.node_ids + assert isinstance(node_ids, dict) + + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(USER, node_ids[USER]), + num_neighbors=[10], # Unused by PPR sampler; required by interface + sampler_options=PPRSamplerOptions( + alpha=alpha, + eps=_TEST_EPS, + max_ppr_nodes=max_ppr_nodes, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + reference_graph = _build_hetero_reference_graph(edge_dir) + + batches_checked = 0 + for datum in loader: + assert isinstance(datum, HeteroData) + + seed_global_id = datum[USER].batch[0].item() + + sampler_ppr_by_type = _extract_hetero_ppr_scores( + datum, str(USER), [USER, STORY] + ) + + reference_ppr = _reference_ppr_hetero( + graph=reference_graph, + seed=seed_global_id, + seed_type=str(USER), + alpha=alpha, + max_ppr_nodes=max_ppr_nodes, + ) + + _assert_ppr_scores_match_reference( + sampler_ppr_by_type, reference_ppr, seed_global_id + ) + + batches_checked += 1 + + assert ( + batches_checked == _NUM_TEST_USERS + ), f"Expected {_NUM_TEST_USERS} batches, got {batches_checked}" + shutdown_rpc() + + +def _run_ppr_ablp_loader_correctness_check( + _: int, + alpha: float, + max_ppr_nodes: int, + edge_dir: 
Literal["in", "out"], +) -> None: + """Iterate ABLP PPR loader and verify anchor-seed PPR against NetworkX reference. + + Checks both anchor (USER) seed PPR scores for correctness against NetworkX, + and verifies that supervision (STORY) seed PPR metadata is present with + valid shapes. Also confirms that ABLP-specific output (y_positive) is + produced alongside PPR metadata. + + The ABLP dataset is created inside this spawned process because the + splitter requires torch.distributed to be initialized. + """ + create_test_process_group() + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels={0: [0, 1], 1: [1, 2], 2: [0, 2]}, + train_node_ids=[0, 1], + val_node_ids=[2], + test_node_ids=[], + edge_indices=_TEST_HETERO_EDGE_INDICES, + edge_dir=edge_dir, + ) + + train_node_ids = dataset.train_node_ids + assert isinstance(train_node_ids, dict) + + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[10], # Unused by PPR sampler; required by interface + input_nodes=(USER, train_node_ids[USER]), + supervision_edge_type=USER_TO_STORY, + sampler_options=PPRSamplerOptions( + alpha=alpha, + eps=_TEST_EPS, + max_ppr_nodes=max_ppr_nodes, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + reference_graph = _build_hetero_reference_graph(edge_dir=edge_dir) + + batches_checked = 0 + for datum in loader: + assert isinstance(datum, HeteroData) + + # ABLP should produce positive labels alongside PPR metadata + assert hasattr(datum, "y_positive"), "Missing y_positive on HeteroData" + + seed_global_id = datum[USER].batch[0].item() + + # --- Verify anchor (USER) seed PPR correctness against NetworkX --- + sampler_ppr_by_type = _extract_hetero_ppr_scores( + datum, str(USER), [USER, STORY] + ) + + reference_ppr = _reference_ppr_hetero( + graph=reference_graph, + seed=seed_global_id, + seed_type=str(USER), + alpha=alpha, + max_ppr_nodes=max_ppr_nodes, + ) + + _assert_ppr_scores_match_reference( + sampler_ppr_by_type, reference_ppr, seed_global_id, 
context_label="ABLP" + ) + + # --- Verify supervision (STORY) seed PPR metadata --- + # ABLP adds supervision nodes as additional seeds, producing PPR metadata + # keyed by the STORY seed type. We only check shapes here (not correctness + # against NetworkX) because the supervision seeds vary per batch depending + # on the label edges, making deterministic reference computation complex. + for ntype in [USER, STORY]: + key_ids = f"ppr_neighbor_ids_{STORY}_{ntype}" + key_weights = f"ppr_weights_{STORY}_{ntype}" + + assert hasattr(datum, key_ids), f"Missing {key_ids}" + assert hasattr(datum, key_weights), f"Missing {key_weights}" + + ppr_edge_index = getattr(datum, key_ids) + ppr_weights = getattr(datum, key_weights) + + assert ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + assert ppr_weights.dim() == 1 + assert ppr_edge_index.size(1) == ppr_weights.size(0) + if ppr_weights.numel() > 0: + assert (ppr_weights > 0).all() + assert (ppr_edge_index[1] >= 0).all() + assert (ppr_edge_index[1] < datum[ntype].node.size(0)).all() + + batches_checked += 1 + + assert batches_checked > 0, "Expected at least one ABLP batch" + shutdown_rpc() + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- +class DistPPRSamplerTest(TestCase): + def setUp(self) -> None: + super().setUp() + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + super().tearDown() + + @parameterized.expand( + [ + param("edge_dir_in", edge_dir="in"), + param("edge_dir_out", edge_dir="out"), + ] + ) + def test_ppr_sampler_correctness_homogeneous(self, _, edge_dir: str) -> None: + """Verify PPR scores match NetworkX pagerank on a small homogeneous graph.""" + mp.spawn( + fn=_run_ppr_loader_correctness_check, + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), + ) + + @parameterized.expand( + [ + param("edge_dir_in", 
edge_dir="in"), + param("edge_dir_out", edge_dir="out"), + ] + ) + def test_ppr_sampler_correctness_heterogeneous(self, _, edge_dir: str) -> None: + """Verify PPR scores match NetworkX pagerank on a heterogeneous bipartite graph.""" + mp.spawn( + fn=_run_ppr_hetero_loader_correctness_check, + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), + ) + + @parameterized.expand( + [ + param("edge_dir_out", edge_dir="out"), + ] + ) + def test_ppr_sampler_ablp_correctness(self, _, edge_dir: str) -> None: + """Verify PPR scores through DistABLPLoader on a heterogeneous graph. + + Only tests ``edge_dir="out"`` because ``DistNodeAnchorLinkSplitter`` + with ``edge_dir="in"`` reverses the supervision edge type, requiring + a reversed labeled edge type that the test dataset does not include. + """ + mp.spawn( + fn=_run_ppr_ablp_loader_correctness_check, + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index 603b2dadb..51a156f0c 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -7,6 +7,7 @@ from torch_geometric.typing import EdgeType from gigl.distributed.utils.neighborloader import ( + extract_metadata, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -490,5 +491,64 @@ def test_set_custom_features_heterogeneous(self): ) +class ExtractMetadataTest(TestCase): + def setUp(self): + self._device = torch.device("cpu") + super().setUp() + + def test_separates_metadata_from_sampling_data(self): + msg = { + "#META.ppr_scores": torch.tensor([1.0, 2.0]), + "#META.custom_key": torch.tensor([3]), + "user.ids": torch.tensor([10, 20]), + "user__to__item.rows": torch.tensor([0, 1]), + } + metadata, stripped_msg = extract_metadata(msg, self._device) + + self.assertEqual(set(metadata.keys()), {"ppr_scores", "custom_key"}) 
+ self.assert_tensor_equality(metadata["ppr_scores"], torch.tensor([1.0, 2.0])) + self.assert_tensor_equality(metadata["custom_key"], torch.tensor([3])) + + self.assertEqual(set(stripped_msg.keys()), {"user.ids", "user__to__item.rows"}) + self.assert_tensor_equality(stripped_msg["user.ids"], torch.tensor([10, 20])) + + def test_no_metadata_keys(self): + msg = { + "user.ids": torch.tensor([10, 20]), + "#IS_HETERO": torch.tensor([1]), + } + metadata, stripped_msg = extract_metadata(msg, self._device) + + self.assertEqual(metadata, {}) + self.assertEqual(set(stripped_msg.keys()), {"user.ids", "#IS_HETERO"}) + + def test_only_metadata_keys(self): + msg = { + "#META.scores": torch.tensor([1.0]), + } + metadata, stripped_msg = extract_metadata(msg, self._device) + + self.assertEqual(set(metadata.keys()), {"scores"}) + self.assertEqual(stripped_msg, {}) + + def test_does_not_modify_original_message(self): + original_tensor = torch.tensor([1.0, 2.0]) + msg = { + "#META.scores": original_tensor, + "user.ids": torch.tensor([10]), + } + original_keys = set(msg.keys()) + + extract_metadata(msg, self._device) + + self.assertEqual(set(msg.keys()), original_keys) + self.assertIn("#META.scores", msg) + + def test_empty_message(self): + metadata, stripped_msg = extract_metadata({}, self._device) + self.assertEqual(metadata, {}) + self.assertEqual(stripped_msg, {}) + + if __name__ == "__main__": absltest.main()