From 2e870f0126be91826b6293e1d039d548285dfe57 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 4 Mar 2026 22:07:36 +0000 Subject: [PATCH 01/46] update --- gigl/distributed/base_dist_loader.py | 16 ++++++++-------- gigl/distributed/dist_ablp_neighborloader.py | 8 ++++---- gigl/distributed/distributed_neighborloader.py | 8 ++++---- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d4ae3e452..3455c8a11 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -83,9 +83,9 @@ class BaseDistLoader(DistLoader): 2. Determine mode (colocated vs graph store). 3. Call ``create_sampling_config()`` to build the SamplingConfig. 4. For colocated: call ``create_colocated_channel()`` and construct the - ``DistMpSamplingProducer`` (or subclass), then pass the producer as ``sampler``. + ``DistMpSamplingProducer`` (or subclass), then pass the producer as ``producer``. 5. For graph store: pass the RPC function (e.g. ``DistServer.create_sampling_producer``) - as ``sampler``. + as ``producer``. 6. Call ``super().__init__()`` with the prepared data. Args: @@ -98,7 +98,7 @@ class BaseDistLoader(DistLoader): sampling_config: Configuration for the sampler (created via ``create_sampling_config``). device: Target device for sampled results. runtime: Resolved distributed runtime information. - sampler: Either a pre-constructed ``DistMpSamplingProducer`` (colocated mode) + producer: Either a pre-constructed ``DistMpSamplingProducer`` (colocated mode) or a callable to dispatch on the ``DistServer`` (graph store mode). process_start_gap_seconds: Delay between each process for staggered colocated init. 
""" @@ -204,7 +204,7 @@ def __init__( sampling_config: SamplingConfig, device: torch.device, runtime: DistributedRuntimeInfo, - sampler: Union[DistMpSamplingProducer, Callable[..., int]], + producer: Union[DistMpSamplingProducer, Callable[..., int]], process_start_gap_seconds: float = 60.0, ): # Set right away so __del__ can clean up if we throw during init. @@ -239,7 +239,7 @@ def __init__( self._epoch = 0 # --- Mode-specific attributes and connection initialization --- - if isinstance(sampler, DistMpSamplingProducer): + if isinstance(producer, DistMpSamplingProducer): assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) assert isinstance(sampler_input, NodeSamplerInput) @@ -263,7 +263,7 @@ def __init__( self._shutdowned = False self._init_colocated_connections( dataset=dataset, - producer=sampler, + producer=producer, runtime=runtime, process_start_gap_seconds=process_start_gap_seconds, ) @@ -271,7 +271,7 @@ def __init__( assert isinstance(dataset, RemoteDistDataset) assert isinstance(worker_options, RemoteDistSamplingWorkerOptions) assert isinstance(sampler_input, list) - assert callable(sampler) + assert callable(producer) self.data = None self._is_mp_worker = False @@ -300,7 +300,7 @@ def __init__( self._shutdowned = False self._init_graph_store_connections( dataset=dataset, - create_producer_fn=sampler, + create_producer_fn=producer, ) @staticmethod diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 268bc4b5a..8dd09820d 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -323,13 +323,13 @@ def __init__( drop_last=drop_last, ) - # Build the sampler: a pre-constructed producer for colocated mode, + # Build the producer: a pre-constructed producer for colocated mode, # or an RPC callable for graph store mode. 
if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) - sampler: Union[ + producer: Union[ DistABLPSamplingProducer, Callable[..., int] ] = DistABLPSamplingProducer( dataset, @@ -339,7 +339,7 @@ def __init__( channel, ) else: - sampler = DistServer.create_sampling_ablp_producer + producer = DistServer.create_sampling_ablp_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). @@ -351,7 +351,7 @@ def __init__( sampling_config=sampling_config, device=device, runtime=runtime, - sampler=sampler, + producer=producer, process_start_gap_seconds=process_start_gap_seconds, ) diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 6d385ef39..3033030d1 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -231,13 +231,13 @@ def __init__( drop_last=drop_last, ) - # Build the sampler: a pre-constructed producer for colocated mode, + # Build the producer: a pre-constructed producer for colocated mode, # or an RPC callable for graph store mode. if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) - sampler: Union[ + producer: Union[ DistMpSamplingProducer, Callable[..., int] ] = DistMpSamplingProducer( dataset, @@ -247,7 +247,7 @@ def __init__( channel, ) else: - sampler = GiglDistServer.create_sampling_producer + producer = GiglDistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # (including staggered init for colocated mode). 
@@ -259,7 +259,7 @@ def __init__( sampling_config=sampling_config, device=device, runtime=runtime, - sampler=sampler, + producer=producer, process_start_gap_seconds=process_start_gap_seconds, ) From 9f2331e49989ad5d8741d6ef5fabd397eb370738 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 4 Mar 2026 22:50:21 +0000 Subject: [PATCH 02/46] initial changes --- gigl/distributed/dist_neighbor_sampler.py | 282 ++++++++++++++++------ 1 file changed, 204 insertions(+), 78 deletions(-) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index ec9692611..e8817f01b 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -1,6 +1,7 @@ import asyncio import gc from collections import defaultdict +from dataclasses import dataclass from typing import Optional, Union import torch @@ -22,78 +23,88 @@ ) from gigl.utils.data_splitters import PADDING_NODE -# TODO (mkolodner-sc): Investigate upstreaming this change back to GLT +@dataclass +class PreparedSamplingInputs: + """Prepared inputs for the sampling loop. -class DistABLPNeighborSampler(DistNeighborSampler): + Attributes: + input_seeds: The original anchor node seeds. + input_type: The node type of the anchor seeds. + nodes_to_sample: Dict mapping node types to tensors of nodes to include + in sampling. May include additional nodes beyond input_seeds + (e.g., supervision nodes in ABLP). + metadata: Metadata dict to include in the sampler output. + use_original_seeds_for_batch: If True, use input_seeds as the batch + (ABLP behavior). If False, use the inducer result as the batch + (GLT base behavior). Defaults to False for GLT compatibility. """ - We inherit from the GLT DistNeighborSampler base class and override the _sample_from_nodes function. Specifically, we - introduce functionality to read parse ABLPNodeSamplerInput, which contains information about the supervision nodes and node types - that we also want to fanout around. 
We add the supervision nodes to the initial fanout seeds, and inject the label information into the - output SampleMessage metadata. + + input_seeds: torch.Tensor + input_type: NodeType + nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] + metadata: dict[str, torch.Tensor] + use_original_seeds_for_batch: bool = False + + +class GiglDistNeighborSampler(DistNeighborSampler): + """GiGL's base distributed neighbor sampler with template method pattern. + + Extends GLT's DistNeighborSampler and overrides _sample_from_nodes to use + a hook (_prepare_sampling_inputs) that subclasses can override to customize + input preparation without duplicating the core sampling loop. + + The default implementation behaves identically to the base GLT sampler. + Subclasses can override _prepare_sampling_inputs to add additional nodes + to the sampling (e.g., supervision nodes) or populate metadata. """ - async def _sample_from_nodes( + def _prepare_sampling_inputs( self, inputs: NodeSamplerInput, - ) -> Optional[SampleMessage]: - assert isinstance(inputs, ABLPNodeSamplerInput) + ) -> PreparedSamplingInputs: + """Prepare inputs for the sampling loop. + + Override this method in subclasses to customize input preparation. + The default implementation uses the input seeds directly with no + additional nodes or metadata. + + Args: + inputs: The node sampler input. + + Returns: + PreparedSamplingInputs containing the seeds, node type, nodes to + sample from, and metadata. + """ input_seeds = inputs.node.to(self.device) input_type = inputs.input_type + return PreparedSamplingInputs( + input_seeds=input_seeds, + input_type=input_type, + nodes_to_sample={input_type: input_seeds}, + metadata={}, + ) - # Since GLT swaps src/dst for edge_dir = "out", - # and GiGL assumes that supervision edge types are always (anchor_node_type, to, supervision_node_type), - # we need to index into supervision edge types accordingly. 
- label_edge_index = 0 if self.edge_dir == "in" else 2 + async def _sample_from_nodes( + self, + inputs: NodeSamplerInput, + ) -> Optional[SampleMessage]: + """Sample subgraph from seed nodes. - # Go through the positive and negative labels and add them to the metadata and input seeds builder. - # We need to sample from the supervision nodes as well, and ensure that we are sampling from the correct node type. - metadata: dict[str, torch.Tensor] = {} - input_seeds_builder: dict[ - Union[str, NodeType], list[torch.Tensor] - ] = defaultdict(list) - input_seeds_builder[input_type].append(input_seeds) - for edge_type, label_tensor in inputs.positive_label_by_edge_types.items(): - filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( - self.device - ) - input_seeds_builder[edge_type[label_edge_index]].append( - filtered_label_tensor - ) - # Update the metadata per positive label edge type. - # We do this because GLT only supports dict[str, torch.Tensor] for metadata. - metadata[ - f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" - ] = label_tensor - for edge_type, label_tensor in inputs.negative_label_by_edge_types.items(): - filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( - self.device - ) - input_seeds_builder[edge_type[label_edge_index]].append( - filtered_label_tensor - ) - # Update the metadata per negative label edge type. - # We do this because GLT only supports dict[str, torch.Tensor] for metadata. - metadata[ - f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" - ] = label_tensor - # As a perf optimization, we *could* have `input_nodes` be only the unique nodes, - # but since torch.unique() calls a sort, we should investigate if it's worth it. - # TODO(kmonte, mkolodner-sc): Investigate if this is worth it. 
- input_nodes: dict[Union[str, NodeType], torch.Tensor] = { - node_type: torch.cat(seeds, dim=0).to(self.device) - for node_type, seeds in input_seeds_builder.items() - } - del filtered_label_tensor, label_tensor - for value in input_seeds_builder.values(): - value.clear() - input_seeds_builder.clear() - del input_seeds_builder - gc.collect() + Uses _prepare_sampling_inputs hook to allow subclasses to customize + which nodes to sample from and what metadata to include. + """ + prepared = self._prepare_sampling_inputs(inputs) + input_seeds = prepared.input_seeds + input_type = prepared.input_type + nodes_to_sample = prepared.nodes_to_sample + metadata = prepared.metadata + use_original_seeds_for_batch = prepared.use_original_seeds_for_batch - self.max_input_size: int = max(self.max_input_size, input_seeds.numel()) + self.max_input_size = max(self.max_input_size, input_seeds.numel()) inducer = self._acquire_inducer() is_hetero = self.dist_graph.data_cls == "hetero" + output: NeighborOutput if is_hetero: assert input_type is not None @@ -103,8 +114,13 @@ async def _sample_from_nodes( out_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {} num_sampled_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {} num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {} - src_dict = inducer.init_node(input_nodes) - batch = {input_type: input_seeds} + + src_dict = inducer.init_node(nodes_to_sample) + # GLT uses src_dict (inducer result) for batch; ABLP uses original seeds + batch = ( + {input_type: input_seeds} if use_original_seeds_for_batch else src_dict + ) + merge_dict(src_dict, out_nodes_hetero) count_dict(src_dict, num_sampled_nodes_hetero, 1) @@ -167,20 +183,22 @@ async def _sample_from_nodes( ) else: assert ( - len(input_nodes) == 1 - ), f"Expected 1 input node type, got {len(input_nodes)}" + len(nodes_to_sample) == 1 + ), f"Expected 1 input node type, got {len(nodes_to_sample)}" assert ( - input_type == list(input_nodes.keys())[0] - ), f"Expected input type 
{input_type}, got {list(input_nodes.keys())[0]}" - srcs = inducer.init_node(input_nodes[input_type]) - batch = input_seeds + input_type == list(nodes_to_sample.keys())[0] + ), f"Expected input type {input_type}, got {list(nodes_to_sample.keys())[0]}" + + srcs = inducer.init_node(nodes_to_sample[input_type]) + # GLT uses srcs (inducer result) for batch; ABLP uses original seeds + batch = input_seeds if use_original_seeds_for_batch else srcs out_nodes: list[torch.Tensor] = [] out_edges: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] num_sampled_nodes: list[torch.Tensor] = [] num_sampled_edges: list[torch.Tensor] = [] out_nodes.append(srcs) num_sampled_nodes.append(srcs.size(0)) - # Sample subgraph. + for req_num in self.num_neighbors: output = await self._sample_one_hop(srcs, req_num, None) if output.nbr.numel() == 0: @@ -194,17 +212,125 @@ async def _sample_from_nodes( num_sampled_edges.append(cols.size(0)) srcs = nodes - sample_output = SamplerOutput( - node=torch.cat(out_nodes), - row=torch.cat([e[0] for e in out_edges]), - col=torch.cat([e[1] for e in out_edges]), - edge=(torch.cat([e[2] for e in out_edges]) if self.with_edge else None), - batch=batch, - num_sampled_nodes=num_sampled_nodes, - num_sampled_edges=num_sampled_edges, - metadata=metadata, - ) + if not out_edges: + sample_output = SamplerOutput( + node=torch.cat(out_nodes), + row=torch.tensor([]).to(self.device), + col=torch.tensor([]).to(self.device), + edge=(torch.tensor([]).to(self.device) if self.with_edge else None), + batch=batch, + num_sampled_nodes=num_sampled_nodes, + num_sampled_edges=num_sampled_edges, + metadata=metadata, + ) + else: + sample_output = SamplerOutput( + node=torch.cat(out_nodes), + row=torch.cat([e[0] for e in out_edges]), + col=torch.cat([e[1] for e in out_edges]), + edge=( + torch.cat([e[2] for e in out_edges]) if self.with_edge else None + ), + batch=batch, + num_sampled_nodes=num_sampled_nodes, + num_sampled_edges=num_sampled_edges, + metadata=metadata, + ) - # 
Reclaim inducer into pool. self.inducer_pool.put(inducer) return sample_output + + +class DistABLPNeighborSampler(GiglDistNeighborSampler): + """ABLP-specific neighbor sampler that adds supervision nodes to sampling. + + Overrides _prepare_sampling_inputs to parse ABLPNodeSamplerInput, adding + supervision nodes (positive/negative labels) to the sampling seeds and + including label information in the output metadata. + """ + + def _prepare_sampling_inputs( + self, + inputs: NodeSamplerInput, + ) -> PreparedSamplingInputs: + """Prepare ABLP inputs with supervision nodes and label metadata. + + Parses ABLPNodeSamplerInput to extract positive/negative label nodes, + adds them to the sampling seeds, and builds metadata with label tensors. + + Args: + inputs: Must be an ABLPNodeSamplerInput. + + Returns: + PreparedSamplingInputs with supervision nodes included in + nodes_to_sample and label tensors in metadata. + """ + assert isinstance(inputs, ABLPNodeSamplerInput) + input_seeds = inputs.node.to(self.device) + input_type = inputs.input_type + + # Since GLT swaps src/dst for edge_dir = "out", + # and GiGL assumes that supervision edge types are always + # (anchor_node_type, to, supervision_node_type), + # we need to index into supervision edge types accordingly. + label_edge_index = 0 if self.edge_dir == "in" else 2 + + # Build metadata and input nodes from positive/negative labels. + # We need to sample from the supervision nodes as well, and ensure + # that we are sampling from the correct node type. 
+ metadata: dict[str, torch.Tensor] = {} + input_seeds_builder: dict[ + Union[str, NodeType], list[torch.Tensor] + ] = defaultdict(list) + input_seeds_builder[input_type].append(input_seeds) + + for edge_type, label_tensor in inputs.positive_label_by_edge_types.items(): + filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( + self.device + ) + input_seeds_builder[edge_type[label_edge_index]].append( + filtered_label_tensor + ) + # Update the metadata per positive label edge type. + # We do this because GLT only supports dict[str, torch.Tensor] for metadata. + metadata[ + f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" + ] = label_tensor + + for edge_type, label_tensor in inputs.negative_label_by_edge_types.items(): + filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( + self.device + ) + input_seeds_builder[edge_type[label_edge_index]].append( + filtered_label_tensor + ) + # Update the metadata per negative label edge type. + # We do this because GLT only supports dict[str, torch.Tensor] for metadata. + metadata[ + f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" + ] = label_tensor + + # As a perf optimization, we *could* have `nodes_to_sample` be only the + # unique nodes, but since torch.unique() calls a sort, we should + # investigate if it's worth it. + # TODO(kmonte, mkolodner-sc): Investigate if this is worth it. 
+ nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] = { + node_type: torch.cat(seeds, dim=0).to(self.device) + for node_type, seeds in input_seeds_builder.items() + } + + # Memory cleanup + del filtered_label_tensor, label_tensor + for value in input_seeds_builder.values(): + value.clear() + input_seeds_builder.clear() + del input_seeds_builder + gc.collect() + + return PreparedSamplingInputs( + input_seeds=input_seeds, + input_type=input_type, + nodes_to_sample=nodes_to_sample, + metadata=metadata, + use_original_seeds_for_batch=True, + ) From d75f96e8d78b400afec190ec80c1b0be6e8be14f Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 4 Mar 2026 23:40:14 +0000 Subject: [PATCH 03/46] Update --- CLAUDE.md | 4 +- gigl/distributed/dist_ablp_neighborloader.py | 8 +- gigl/distributed/dist_neighbor_sampler.py | 241 +++++++++--------- gigl/distributed/dist_sampling_producer.py | 11 +- .../distributed/distributed_neighborloader.py | 6 +- gigl/distributed/graph_store/dist_server.py | 80 +----- 6 files changed, 137 insertions(+), 213 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index b09346fe7..a781eb5ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -56,8 +56,8 @@ GiGL extends GraphLearn-for-PyTorch (GLT) for distributed GNN training. 
Key clas labels, split metadata, and feature info - **`DistNeighborLoader`** (extends GLT `DistLoader`) - Standard node-based sampling loader - **`DistABLPLoader`** (extends GLT `DistLoader`) - Anchor-Based Link Prediction sampling loader -- **`DistABLPNeighborSampler`** (extends GLT `DistNeighborSampler`) - Custom sampler supporting ABLP with - positive/negative label injection +- **`DistNeighborSampler`** (extends GLT `DistNeighborSampler`) - Unified sampler supporting both standard + neighbor sampling and ABLP with positive/negative label injection **Two deployment modes:** diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 8dd09820d..684187d67 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -17,7 +17,7 @@ from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_sampling_producer import DistABLPSamplingProducer +from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset @@ -330,8 +330,8 @@ def __init__( assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) producer: Union[ - DistABLPSamplingProducer, Callable[..., int] - ] = DistABLPSamplingProducer( + DistSamplingProducer, Callable[..., int] + ] = DistSamplingProducer( dataset, sampler_input, sampling_config, @@ -339,7 +339,7 @@ def __init__( channel, ) else: - producer = DistServer.create_sampling_ablp_producer + producer = DistServer.create_sampling_producer # Call base class — handles metadata storage and connection initialization # 
(including staggered init for colocated mode). diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index e8817f01b..c32a5dc3f 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -6,7 +6,7 @@ import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistNeighborSampler +from graphlearn_torch.distributed import DistNeighborSampler as GltDistNeighborSampler from graphlearn_torch.sampler import ( HeteroSamplerOutput, NeighborOutput, @@ -25,7 +25,7 @@ @dataclass -class PreparedSamplingInputs: +class SamplingInputs: """Prepared inputs for the sampling loop. Attributes: @@ -35,73 +35,156 @@ class PreparedSamplingInputs: in sampling. May include additional nodes beyond input_seeds (e.g., supervision nodes in ABLP). metadata: Metadata dict to include in the sampler output. - use_original_seeds_for_batch: If True, use input_seeds as the batch - (ABLP behavior). If False, use the inducer result as the batch - (GLT base behavior). Defaults to False for GLT compatibility. """ input_seeds: torch.Tensor input_type: NodeType nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] metadata: dict[str, torch.Tensor] - use_original_seeds_for_batch: bool = False -class GiglDistNeighborSampler(DistNeighborSampler): - """GiGL's base distributed neighbor sampler with template method pattern. +class DistNeighborSampler(GltDistNeighborSampler): + """GiGL's distributed neighbor sampler supporting both standard and ABLP inputs. - Extends GLT's DistNeighborSampler and overrides _sample_from_nodes to use - a hook (_prepare_sampling_inputs) that subclasses can override to customize - input preparation without duplicating the core sampling loop. 
+ Extends GLT's DistNeighborSampler and overrides _sample_from_nodes to support + both NodeSamplerInput (standard neighbor sampling) and ABLPNodeSamplerInput + (anchor-based link prediction with supervision nodes). - The default implementation behaves identically to the base GLT sampler. - Subclasses can override _prepare_sampling_inputs to add additional nodes - to the sampling (e.g., supervision nodes) or populate metadata. + For ABLPNodeSamplerInput, supervision nodes (positive/negative labels) are + added to the sampling seeds, and label information is included in the output + metadata. """ def _prepare_sampling_inputs( self, inputs: NodeSamplerInput, - ) -> PreparedSamplingInputs: + ) -> SamplingInputs: """Prepare inputs for the sampling loop. - Override this method in subclasses to customize input preparation. - The default implementation uses the input seeds directly with no - additional nodes or metadata. + Handles both standard NodeSamplerInput and ABLPNodeSamplerInput. + For ABLP inputs, adds supervision nodes to the sampling seeds and + builds label metadata. Args: - inputs: The node sampler input. + inputs: Either a NodeSamplerInput or ABLPNodeSamplerInput. Returns: - PreparedSamplingInputs containing the seeds, node type, nodes to + SamplingInputs containing the seeds, node type, nodes to sample from, and metadata. """ input_seeds = inputs.node.to(self.device) input_type = inputs.input_type - return PreparedSamplingInputs( + + if isinstance(inputs, ABLPNodeSamplerInput): + return self._prepare_ablp_inputs(inputs, input_seeds, input_type) + + return SamplingInputs( input_seeds=input_seeds, input_type=input_type, nodes_to_sample={input_type: input_seeds}, metadata={}, ) + def _prepare_ablp_inputs( + self, + inputs: ABLPNodeSamplerInput, + input_seeds: torch.Tensor, + input_type: NodeType, + ) -> SamplingInputs: + """Prepare ABLP inputs with supervision nodes and label metadata. + + Args: + inputs: The ABLPNodeSamplerInput containing label information. 
+ input_seeds: The anchor node seeds (already moved to device). + input_type: The node type of the anchor seeds. + + Returns: + SamplingInputs with supervision nodes included in nodes_to_sample + and label tensors in metadata. + """ + # Since GLT swaps src/dst for edge_dir = "out", + # and GiGL assumes that supervision edge types are always + # (anchor_node_type, to, supervision_node_type), + # we need to index into supervision edge types accordingly. + label_edge_index = 0 if self.edge_dir == "in" else 2 + + # Build metadata and input nodes from positive/negative labels. + # We need to sample from the supervision nodes as well, and ensure + # that we are sampling from the correct node type. + metadata: dict[str, torch.Tensor] = {} + input_seeds_builder: dict[ + Union[str, NodeType], list[torch.Tensor] + ] = defaultdict(list) + input_seeds_builder[input_type].append(input_seeds) + + for edge_type, label_tensor in inputs.positive_label_by_edge_types.items(): + filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( + self.device + ) + input_seeds_builder[edge_type[label_edge_index]].append( + filtered_label_tensor + ) + # Update the metadata per positive label edge type. + # We do this because GLT only supports dict[str, torch.Tensor] for metadata. + metadata[ + f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" + ] = label_tensor + + for edge_type, label_tensor in inputs.negative_label_by_edge_types.items(): + filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( + self.device + ) + input_seeds_builder[edge_type[label_edge_index]].append( + filtered_label_tensor + ) + # Update the metadata per negative label edge type. + # We do this because GLT only supports dict[str, torch.Tensor] for metadata. 
+ metadata[ + f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" + ] = label_tensor + + # As a perf optimization, we *could* have `nodes_to_sample` be only the + # unique nodes, but since torch.unique() calls a sort, we should + # investigate if it's worth it. + # TODO(kmonte, mkolodner-sc): Investigate if this is worth it. + nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] = { + node_type: torch.cat(seeds, dim=0).to(self.device) + for node_type, seeds in input_seeds_builder.items() + } + + # Memory cleanup + del filtered_label_tensor, label_tensor + for value in input_seeds_builder.values(): + value.clear() + input_seeds_builder.clear() + del input_seeds_builder + gc.collect() + + return SamplingInputs( + input_seeds=input_seeds, + input_type=input_type, + nodes_to_sample=nodes_to_sample, + metadata=metadata, + ) + async def _sample_from_nodes( self, inputs: NodeSamplerInput, ) -> Optional[SampleMessage]: """Sample subgraph from seed nodes. - Uses _prepare_sampling_inputs hook to allow subclasses to customize - which nodes to sample from and what metadata to include. + Supports both NodeSamplerInput and ABLPNodeSamplerInput. For ABLP, + supervision nodes are included in sampling and label metadata is + attached to the output. 
""" prepared = self._prepare_sampling_inputs(inputs) - input_seeds = prepared.input_seeds input_type = prepared.input_type nodes_to_sample = prepared.nodes_to_sample metadata = prepared.metadata - use_original_seeds_for_batch = prepared.use_original_seeds_for_batch - self.max_input_size = max(self.max_input_size, input_seeds.numel()) + self.max_input_size: int = max( + self.max_input_size, prepared.input_seeds.numel() + ) inducer = self._acquire_inducer() is_hetero = self.dist_graph.data_cls == "hetero" @@ -116,10 +199,10 @@ async def _sample_from_nodes( num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {} src_dict = inducer.init_node(nodes_to_sample) - # GLT uses src_dict (inducer result) for batch; ABLP uses original seeds - batch = ( - {input_type: input_seeds} if use_original_seeds_for_batch else src_dict - ) + # Extract only the anchor node type for batch tracking. + # This excludes supervision nodes (for ABLP) and uses the + # inducer result (deduplicated). + batch = {input_type: src_dict[input_type]} merge_dict(src_dict, out_nodes_hetero) count_dict(src_dict, num_sampled_nodes_hetero, 1) @@ -190,8 +273,7 @@ async def _sample_from_nodes( ), f"Expected input type {input_type}, got {list(nodes_to_sample.keys())[0]}" srcs = inducer.init_node(nodes_to_sample[input_type]) - # GLT uses srcs (inducer result) for batch; ABLP uses original seeds - batch = input_seeds if use_original_seeds_for_batch else srcs + batch = srcs out_nodes: list[torch.Tensor] = [] out_edges: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] num_sampled_nodes: list[torch.Tensor] = [] @@ -239,98 +321,3 @@ async def _sample_from_nodes( self.inducer_pool.put(inducer) return sample_output - - -class DistABLPNeighborSampler(GiglDistNeighborSampler): - """ABLP-specific neighbor sampler that adds supervision nodes to sampling. 
- - Overrides _prepare_sampling_inputs to parse ABLPNodeSamplerInput, adding - supervision nodes (positive/negative labels) to the sampling seeds and - including label information in the output metadata. - """ - - def _prepare_sampling_inputs( - self, - inputs: NodeSamplerInput, - ) -> PreparedSamplingInputs: - """Prepare ABLP inputs with supervision nodes and label metadata. - - Parses ABLPNodeSamplerInput to extract positive/negative label nodes, - adds them to the sampling seeds, and builds metadata with label tensors. - - Args: - inputs: Must be an ABLPNodeSamplerInput. - - Returns: - PreparedSamplingInputs with supervision nodes included in - nodes_to_sample and label tensors in metadata. - """ - assert isinstance(inputs, ABLPNodeSamplerInput) - input_seeds = inputs.node.to(self.device) - input_type = inputs.input_type - - # Since GLT swaps src/dst for edge_dir = "out", - # and GiGL assumes that supervision edge types are always - # (anchor_node_type, to, supervision_node_type), - # we need to index into supervision edge types accordingly. - label_edge_index = 0 if self.edge_dir == "in" else 2 - - # Build metadata and input nodes from positive/negative labels. - # We need to sample from the supervision nodes as well, and ensure - # that we are sampling from the correct node type. - metadata: dict[str, torch.Tensor] = {} - input_seeds_builder: dict[ - Union[str, NodeType], list[torch.Tensor] - ] = defaultdict(list) - input_seeds_builder[input_type].append(input_seeds) - - for edge_type, label_tensor in inputs.positive_label_by_edge_types.items(): - filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( - self.device - ) - input_seeds_builder[edge_type[label_edge_index]].append( - filtered_label_tensor - ) - # Update the metadata per positive label edge type. - # We do this because GLT only supports dict[str, torch.Tensor] for metadata. 
- metadata[ - f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" - ] = label_tensor - - for edge_type, label_tensor in inputs.negative_label_by_edge_types.items(): - filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to( - self.device - ) - input_seeds_builder[edge_type[label_edge_index]].append( - filtered_label_tensor - ) - # Update the metadata per negative label edge type. - # We do this because GLT only supports dict[str, torch.Tensor] for metadata. - metadata[ - f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}" - ] = label_tensor - - # As a perf optimization, we *could* have `nodes_to_sample` be only the - # unique nodes, but since torch.unique() calls a sort, we should - # investigate if it's worth it. - # TODO(kmonte, mkolodner-sc): Investigate if this is worth it. - nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] = { - node_type: torch.cat(seeds, dim=0).to(self.device) - for node_type, seeds in input_seeds_builder.items() - } - - # Memory cleanup - del filtered_label_tensor, label_tensor - for value in input_seeds_builder.values(): - value.clear() - input_seeds_builder.clear() - del input_seeds_builder - gc.collect() - - return PreparedSamplingInputs( - input_seeds=input_seeds, - input_type=input_type, - nodes_to_sample=nodes_to_sample, - metadata=metadata, - use_original_seeds_for_batch=True, - ) diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 3e80d1d68..174cd0b09 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -1,5 +1,6 @@ -# All code in this file is directly taken from GraphLearn-for-PyTorch (graphlearn_torch/python/distributed/dist_sampling_producer.py), -# with the exception that we call the GiGL DistNeighborSampler with custom link prediction logic instead of the GLT DistNeighborSampler. 
+# A significant amount of code in this file is directly taken from GraphLearn-for-PyTorch (graphlearn_torch/python/distributed/dist_sampling_producer.py), +# Sampling producer that uses GiGL's DistNeighborSampler (which supports both +# standard neighbor sampling and ABLP) instead of GLT's DistNeighborSampler. import datetime import queue @@ -32,7 +33,7 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset -from gigl.distributed.dist_neighbor_sampler import DistABLPNeighborSampler +from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler def _sampling_worker_loop( @@ -84,7 +85,7 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - dist_sampler = DistABLPNeighborSampler( + dist_sampler = DistNeighborSampler( data, sampling_config.num_neighbors, sampling_config.with_edge, @@ -165,7 +166,7 @@ def _sampling_worker_loop( shutdown_rpc(graceful=False) -class DistABLPSamplingProducer(DistMpSamplingProducer): +class DistSamplingProducer(DistMpSamplingProducer): def init(self): r"""Create the subprocess pool. 
Init samplers and rpc server.""" if self.sampling_config.seed is not None: diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 3033030d1..e914ffd05 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -9,7 +9,6 @@ MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions, ) -from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer from graphlearn_torch.sampler import NodeSamplerInput from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -19,6 +18,7 @@ from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( @@ -238,8 +238,8 @@ def __init__( assert isinstance(worker_options, MpDistSamplingWorkerOptions) channel = BaseDistLoader.create_colocated_channel(worker_options) producer: Union[ - DistMpSamplingProducer, Callable[..., int] - ] = DistMpSamplingProducer( + DistSamplingProducer, Callable[..., int] + ] = DistSamplingProducer( dataset, input_data, sampling_config, diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 1506432b2..46649c56e 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -1,7 +1,8 @@ """ GiGL implementation of GLT DistServer. -Main change here is that we use gigl DistAblpSamplingProducer instead of GLT DistMpSamplingProducer. 
+Uses GiGL's DistSamplingProducer which supports both standard neighbor sampling +and ABLP (anchor-based link prediction) via the unified DistNeighborSampler. Based on https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py """ @@ -17,7 +18,6 @@ import torch from graphlearn_torch.channel import QueueTimeoutError, ShmChannel from graphlearn_torch.distributed import ( - DistMpSamplingProducer, RemoteDistSamplingWorkerOptions, barrier, init_rpc, @@ -33,7 +33,7 @@ from gigl.common.logger import Logger from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.dist_sampling_producer import DistABLPSamplingProducer +from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.sampler import ABLPNodeSamplerInput from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.src.common.types.graph_data import EdgeType, NodeType @@ -80,7 +80,7 @@ def __init__(self, dataset: DistDataset) -> None: # The mapping from the key in worker options (such as 'train', 'test') # to producer id self._worker_key2producer_id: dict[str, int] = {} - self._producer_pool: dict[int, DistMpSamplingProducer] = {} + self._producer_pool: dict[int, DistSamplingProducer] = {} self._msg_buffer_pool: dict[int, ShmChannel] = {} self._epoch: dict[int, int] = {} # last epoch for the producer # Per-producer locks that guard the lifecycle of individual producers @@ -424,40 +424,6 @@ def get_ablp_input( ) return anchors, positive_labels, negative_labels - def create_sampling_ablp_producer( - self, - sampler_input: Union[ - NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, ABLPNodeSamplerInput - ], - sampling_config: SamplingConfig, - worker_options: RemoteDistSamplingWorkerOptions, - ) -> int: - r"""Create and initialize an instance of ``DistABLPSamplingProducer`` with - a group of subprocesses for distributed sampling. 
- - Args: - sampler_input (NodeSamplerInput or EdgeSamplerInput): The input data - for sampling. - sampling_config (SamplingConfig): Configuration of sampling meta info. - worker_options (RemoteDistSamplingWorkerOptions): Options for launching - remote sampling workers by this server. - - Returns: - A unique id of created sampling producer on this server. - """ - - if not isinstance(sampler_input, ABLPNodeSamplerInput): - raise ValueError( - f"Sampler input must be an instance of ABLPNodeSamplerInput. Received: {type(sampler_input)}" - ) - - return self._create_producer( - sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, - producer_cls=DistABLPSamplingProducer, - ) - def create_sampling_producer( self, sampler_input: Union[ @@ -469,36 +435,8 @@ def create_sampling_producer( r"""Create and initialize an instance of ``DistSamplingProducer`` with a group of subprocesses for distributed sampling. - Args: - sampler_input (NodeSamplerInput or EdgeSamplerInput): The input data - for sampling. - sampling_config (SamplingConfig): Configuration of sampling meta info. - worker_options (RemoteDistSamplingWorkerOptions): Options for launching - remote sampling workers by this server. - - Returns: - A unique id of created sampling producer on this server. - """ - return self._create_producer( - sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, - producer_cls=DistMpSamplingProducer, - ) - - def _create_producer( - self, - sampler_input: Union[ - NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, ABLPNodeSamplerInput - ], - sampling_config: SamplingConfig, - worker_options: RemoteDistSamplingWorkerOptions, - producer_cls: type[Union[DistABLPSamplingProducer, DistMpSamplingProducer]], - ) -> int: - r"""Shared logic to create and initialize a sampling producer. 
- - Converts remote sampler inputs to local, creates a ``ShmChannel`` buffer, - instantiates the given ``producer_cls``, and registers it in the internal pools. + Supports both standard NodeSamplerInput and ABLPNodeSamplerInput since the + underlying DistNeighborSampler handles both input types. Args: sampler_input (NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, @@ -506,11 +444,9 @@ def _create_producer( sampling_config (SamplingConfig): Configuration of sampling meta info. worker_options (RemoteDistSamplingWorkerOptions): Options for launching remote sampling workers by this server. - producer_cls: The producer class to instantiate - (``DistABLPSamplingProducer`` or ``DistMpSamplingProducer``). Returns: - int: A unique id of created sampling producer on this server. + A unique id of created sampling producer on this server. """ if isinstance(sampler_input, RemoteSamplerInput): sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset) @@ -530,7 +466,7 @@ def _create_producer( buffer = ShmChannel( worker_options.buffer_capacity, worker_options.buffer_size ) - producer = producer_cls( + producer = DistSamplingProducer( self.dataset, sampler_input, sampling_config, worker_options, buffer ) producer.init() From 7a36925594c8e7a247e8e66e86a74afefc5b7934 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 4 Mar 2026 23:42:56 +0000 Subject: [PATCH 04/46] Update --- gigl/distributed/dist_neighbor_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index c32a5dc3f..6639a1f9a 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -6,7 +6,7 @@ import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistNeighborSampler as GltDistNeighborSampler +from graphlearn_torch.distributed import DistNeighborSampler as GLTDistNeighborSampler from 
graphlearn_torch.sampler import ( HeteroSamplerOutput, NeighborOutput, @@ -43,7 +43,7 @@ class SamplingInputs: metadata: dict[str, torch.Tensor] -class DistNeighborSampler(GltDistNeighborSampler): +class DistNeighborSampler(GLTDistNeighborSampler): """GiGL's distributed neighbor sampler supporting both standard and ABLP inputs. Extends GLT's DistNeighborSampler and overrides _sample_from_nodes to support From 21d68ebf361324c279ddfb9b8ceb04f1a83144a3 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 4 Mar 2026 23:49:12 +0000 Subject: [PATCH 05/46] Update --- gigl/distributed/dist_neighbor_sampler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index 6639a1f9a..f3e63909b 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -177,13 +177,13 @@ async def _sample_from_nodes( supervision nodes are included in sampling and label metadata is attached to the output. 
""" - prepared = self._prepare_sampling_inputs(inputs) - input_type = prepared.input_type - nodes_to_sample = prepared.nodes_to_sample - metadata = prepared.metadata + prepared_inputs = self._prepare_sampling_inputs(inputs) + input_type = prepared_inputs.input_type + nodes_to_sample = prepared_inputs.nodes_to_sample + metadata = prepared_inputs.metadata self.max_input_size: int = max( - self.max_input_size, prepared.input_seeds.numel() + self.max_input_size, prepared_inputs.input_seeds.numel() ) inducer = self._acquire_inducer() is_hetero = self.dist_graph.data_cls == "hetero" From f114822eb0198d1565584628c943d69d524e9674 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 4 Mar 2026 23:53:30 +0000 Subject: [PATCH 06/46] Update --- gigl/distributed/dist_neighbor_sampler.py | 29 +++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index f3e63909b..51c61b953 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -25,16 +25,21 @@ @dataclass -class SamplingInputs: - """Prepared inputs for the sampling loop. +class SampleLoopInputs: + """Inputs prepared for the neighbor sampling loop in _sample_from_nodes. + + This dataclass holds the processed inputs that are passed to the core + sampling loop. It allows _prepare_sampling_inputs to customize what nodes + are sampled from and what metadata is attached to the output, without + duplicating the sampling loop logic. Attributes: - input_seeds: The original anchor node seeds. + input_seeds: The original anchor/batch node seeds (used for batch tracking). input_type: The node type of the anchor seeds. nodes_to_sample: Dict mapping node types to tensors of nodes to include - in sampling. May include additional nodes beyond input_seeds - (e.g., supervision nodes in ABLP). - metadata: Metadata dict to include in the sampler output. + in the sampling. 
For standard sampling, this equals {input_type: input_seeds}. + For ABLP, this also includes supervision nodes (positive/negative labels). + metadata: Metadata dict to attach to the sampler output (e.g., label tensors). """ input_seeds: torch.Tensor @@ -58,7 +63,7 @@ class DistNeighborSampler(GLTDistNeighborSampler): def _prepare_sampling_inputs( self, inputs: NodeSamplerInput, - ) -> SamplingInputs: + ) -> SampleLoopInputs: """Prepare inputs for the sampling loop. Handles both standard NodeSamplerInput and ABLPNodeSamplerInput. @@ -69,7 +74,7 @@ def _prepare_sampling_inputs( inputs: Either a NodeSamplerInput or ABLPNodeSamplerInput. Returns: - SamplingInputs containing the seeds, node type, nodes to + SampleLoopInputs containing the seeds, node type, nodes to sample from, and metadata. """ input_seeds = inputs.node.to(self.device) @@ -78,7 +83,7 @@ def _prepare_sampling_inputs( if isinstance(inputs, ABLPNodeSamplerInput): return self._prepare_ablp_inputs(inputs, input_seeds, input_type) - return SamplingInputs( + return SampleLoopInputs( input_seeds=input_seeds, input_type=input_type, nodes_to_sample={input_type: input_seeds}, @@ -90,7 +95,7 @@ def _prepare_ablp_inputs( inputs: ABLPNodeSamplerInput, input_seeds: torch.Tensor, input_type: NodeType, - ) -> SamplingInputs: + ) -> SampleLoopInputs: """Prepare ABLP inputs with supervision nodes and label metadata. Args: @@ -99,7 +104,7 @@ def _prepare_ablp_inputs( input_type: The node type of the anchor seeds. Returns: - SamplingInputs with supervision nodes included in nodes_to_sample + SampleLoopInputs with supervision nodes included in nodes_to_sample and label tensors in metadata. 
""" # Since GLT swaps src/dst for edge_dir = "out", @@ -160,7 +165,7 @@ def _prepare_ablp_inputs( del input_seeds_builder gc.collect() - return SamplingInputs( + return SampleLoopInputs( input_seeds=input_seeds, input_type=input_type, nodes_to_sample=nodes_to_sample, From 7af2d0cd5890073352564d628aaa48b60f01bd4e Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 5 Mar 2026 01:08:08 +0000 Subject: [PATCH 07/46] Update --- gigl/distributed/dist_neighbor_sampler.py | 39 ++++++++++------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index 51c61b953..150de2288 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -34,16 +34,12 @@ class SampleLoopInputs: duplicating the sampling loop logic. Attributes: - input_seeds: The original anchor/batch node seeds (used for batch tracking). - input_type: The node type of the anchor seeds. nodes_to_sample: Dict mapping node types to tensors of nodes to include in the sampling. For standard sampling, this equals {input_type: input_seeds}. For ABLP, this also includes supervision nodes (positive/negative labels). metadata: Metadata dict to attach to the sampler output (e.g., label tensors). """ - input_seeds: torch.Tensor - input_type: NodeType nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] metadata: dict[str, torch.Tensor] @@ -60,7 +56,7 @@ class DistNeighborSampler(GLTDistNeighborSampler): metadata. """ - def _prepare_sampling_inputs( + def _prepare_sample_loop_inputs( self, inputs: NodeSamplerInput, ) -> SampleLoopInputs: @@ -74,8 +70,8 @@ def _prepare_sampling_inputs( inputs: Either a NodeSamplerInput or ABLPNodeSamplerInput. Returns: - SampleLoopInputs containing the seeds, node type, nodes to - sample from, and metadata. + SampleLoopInputs containing the node type, nodes to sample from, + and any metadata related to the task. 
""" input_seeds = inputs.node.to(self.device) input_type = inputs.input_type @@ -84,8 +80,6 @@ def _prepare_sampling_inputs( return self._prepare_ablp_inputs(inputs, input_seeds, input_type) return SampleLoopInputs( - input_seeds=input_seeds, - input_type=input_type, nodes_to_sample={input_type: input_seeds}, metadata={}, ) @@ -166,8 +160,6 @@ def _prepare_ablp_inputs( gc.collect() return SampleLoopInputs( - input_seeds=input_seeds, - input_type=input_type, nodes_to_sample=nodes_to_sample, metadata=metadata, ) @@ -182,14 +174,12 @@ async def _sample_from_nodes( supervision nodes are included in sampling and label metadata is attached to the output. """ - prepared_inputs = self._prepare_sampling_inputs(inputs) - input_type = prepared_inputs.input_type - nodes_to_sample = prepared_inputs.nodes_to_sample - metadata = prepared_inputs.metadata + sample_loop_inputs = self._prepare_sample_loop_inputs(inputs) + input_type = inputs.input_type + nodes_to_sample = sample_loop_inputs.nodes_to_sample + metadata = sample_loop_inputs.metadata - self.max_input_size: int = max( - self.max_input_size, prepared_inputs.input_seeds.numel() - ) + self.max_input_size = max(self.max_input_size, inputs.node.numel()) inducer = self._acquire_inducer() is_hetero = self.dist_graph.data_cls == "hetero" @@ -204,10 +194,10 @@ async def _sample_from_nodes( num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {} src_dict = inducer.init_node(nodes_to_sample) - # Extract only the anchor node type for batch tracking. - # This excludes supervision nodes (for ABLP) and uses the - # inducer result (deduplicated). - batch = {input_type: src_dict[input_type]} + # Use the original anchor seeds (inputs.node) for batch tracking, + # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes + # supervision nodes which should not be part of the batch. 
+ batch = {input_type: inputs.node.to(self.device)} merge_dict(src_dict, out_nodes_hetero) count_dict(src_dict, num_sampled_nodes_hetero, 1) @@ -278,7 +268,10 @@ async def _sample_from_nodes( ), f"Expected input type {input_type}, got {list(nodes_to_sample.keys())[0]}" srcs = inducer.init_node(nodes_to_sample[input_type]) - batch = srcs + # Use the original anchor seeds (inputs.node) for batch tracking, + # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes + # supervision nodes which should not be part of the batch. + batch = inputs.node.to(self.device) out_nodes: list[torch.Tensor] = [] out_edges: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] num_sampled_nodes: list[torch.Tensor] = [] From 6498544081c6269bba24862b950d7ee77b77cd56 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 5 Mar 2026 08:50:02 +0000 Subject: [PATCH 08/46] Fix --- gigl/distributed/dist_neighbor_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index 150de2288..b1abbf5b0 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -179,7 +179,7 @@ async def _sample_from_nodes( nodes_to_sample = sample_loop_inputs.nodes_to_sample metadata = sample_loop_inputs.metadata - self.max_input_size = max(self.max_input_size, inputs.node.numel()) + self.max_input_size: int = max(self.max_input_size, inputs.node.numel()) inducer = self._acquire_inducer() is_hetero = self.dist_graph.data_cls == "hetero" From ec43f274e37c37264797170205c08d3863090b95 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 5 Mar 2026 19:01:00 +0000 Subject: [PATCH 09/46] Update --- gigl/distributed/base_dist_loader.py | 16 +++++++------- gigl/distributed/dist_neighbor_sampler.py | 24 +++++++++++++-------- gigl/distributed/graph_store/dist_server.py | 10 ++++----- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git 
a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 3455c8a11..d1d9632fb 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -23,7 +23,6 @@ get_context, ) from graphlearn_torch.distributed.dist_client import async_request_server -from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer from graphlearn_torch.distributed.rpc import rpc_is_initialized from graphlearn_torch.sampler import ( NodeSamplerInput, @@ -39,6 +38,7 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( @@ -83,7 +83,7 @@ class BaseDistLoader(DistLoader): 2. Determine mode (colocated vs graph store). 3. Call ``create_sampling_config()`` to build the SamplingConfig. 4. For colocated: call ``create_colocated_channel()`` and construct the - ``DistMpSamplingProducer`` (or subclass), then pass the producer as ``producer``. + ``DistSamplingProducer`` (or subclass), then pass the producer as ``producer``. 5. For graph store: pass the RPC function (e.g. ``DistServer.create_sampling_producer``) as ``producer``. 6. Call ``super().__init__()`` with the prepared data. @@ -98,7 +98,7 @@ class BaseDistLoader(DistLoader): sampling_config: Configuration for the sampler (created via ``create_sampling_config``). device: Target device for sampled results. runtime: Resolved distributed runtime information. 
- producer: Either a pre-constructed ``DistMpSamplingProducer`` (colocated mode) + producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) or a callable to dispatch on the ``DistServer`` (graph store mode). process_start_gap_seconds: Delay between each process for staggered colocated init. """ @@ -204,7 +204,7 @@ def __init__( sampling_config: SamplingConfig, device: torch.device, runtime: DistributedRuntimeInfo, - producer: Union[DistMpSamplingProducer, Callable[..., int]], + producer: Union[DistSamplingProducer, Callable[..., int]], process_start_gap_seconds: float = 60.0, ): # Set right away so __del__ can clean up if we throw during init. @@ -239,7 +239,7 @@ def __init__( self._epoch = 0 # --- Mode-specific attributes and connection initialization --- - if isinstance(producer, DistMpSamplingProducer): + if isinstance(producer, DistSamplingProducer): assert isinstance(dataset, DistDataset) assert isinstance(worker_options, MpDistSamplingWorkerOptions) assert isinstance(sampler_input, NodeSamplerInput) @@ -356,7 +356,7 @@ def create_colocated_channel( worker_options: The colocated worker options (must already be fully configured). Returns: - A ShmChannel ready to be passed to a DistMpSamplingProducer. + A ShmChannel ready to be passed to a DistSamplingProducer. """ channel = ShmChannel( worker_options.channel_capacity, worker_options.channel_size @@ -368,7 +368,7 @@ def create_colocated_channel( def _init_colocated_connections( self, dataset: DistDataset, - producer: DistMpSamplingProducer, + producer: DistSamplingProducer, runtime: DistributedRuntimeInfo, process_start_gap_seconds: float, ) -> None: @@ -381,7 +381,7 @@ def _init_colocated_connections( Args: dataset: The local DistDataset. - producer: A pre-constructed DistMpSamplingProducer (or subclass). + producer: A pre-constructed DistSamplingProducer (or subclass). runtime: Resolved distributed runtime info (used for staggered sleep). 
process_start_gap_seconds: Delay multiplier for staggered init. """ diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index b1abbf5b0..bbacca3eb 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -34,13 +34,13 @@ class SampleLoopInputs: duplicating the sampling loop logic. Attributes: - nodes_to_sample: Dict mapping node types to tensors of nodes to include - in the sampling. For standard sampling, this equals {input_type: input_seeds}. + nodes_to_sample: For homogeneous graphs, a tensor of node IDs. For + heterogeneous graphs, a dict mapping node types to tensors. For ABLP, this also includes supervision nodes (positive/negative labels). metadata: Metadata dict to attach to the sampler output (e.g., label tensors). """ - nodes_to_sample: dict[Union[str, NodeType], torch.Tensor] + nodes_to_sample: Union[torch.Tensor, dict[NodeType, torch.Tensor]] metadata: dict[str, torch.Tensor] @@ -79,6 +79,13 @@ def _prepare_sample_loop_inputs( if isinstance(inputs, ABLPNodeSamplerInput): return self._prepare_ablp_inputs(inputs, input_seeds, input_type) + # For homogeneous graphs (input_type is None), return tensor directly. + # For heterogeneous graphs, return dict mapping node type to tensor. 
+ if input_type is None: + return SampleLoopInputs( + nodes_to_sample=input_seeds, + metadata={}, + ) return SampleLoopInputs( nodes_to_sample={input_type: input_seeds}, metadata={}, @@ -186,6 +193,7 @@ async def _sample_from_nodes( output: NeighborOutput if is_hetero: assert input_type is not None + assert isinstance(nodes_to_sample, dict) out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {} out_rows_hetero: dict[EdgeType, list[torch.Tensor]] = {} out_cols_hetero: dict[EdgeType, list[torch.Tensor]] = {} @@ -261,13 +269,11 @@ async def _sample_from_nodes( ) else: assert ( - len(nodes_to_sample) == 1 - ), f"Expected 1 input node type, got {len(nodes_to_sample)}" - assert ( - input_type == list(nodes_to_sample.keys())[0] - ), f"Expected input type {input_type}, got {list(nodes_to_sample.keys())[0]}" + input_type is None + ), f"Expected input_type to be None for homogeneous graph, got {input_type}" + assert isinstance(nodes_to_sample, torch.Tensor) - srcs = inducer.init_node(nodes_to_sample[input_type]) + srcs = inducer.init_node(nodes_to_sample) # Use the original anchor seeds (inputs.node) for batch tracking, # not the deduped nodes_to_sample. For ABLP, nodes_to_sample includes # supervision nodes which should not be part of the batch. diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 46649c56e..5d730b9da 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -1,8 +1,8 @@ """ GiGL implementation of GLT DistServer. -Uses GiGL's DistSamplingProducer which supports both standard neighbor sampling -and ABLP (anchor-based link prediction) via the unified DistNeighborSampler. +Uses GiGL's DistSamplingProducer which supports neighbor sampling +and ABLP (anchor-based link prediction) via the DistNeighborSampler. 
Based on https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py """ @@ -435,8 +435,8 @@ def create_sampling_producer( r"""Create and initialize an instance of ``DistSamplingProducer`` with a group of subprocesses for distributed sampling. - Supports both standard NodeSamplerInput and ABLPNodeSamplerInput since the - underlying DistNeighborSampler handles both input types. + Supports standard NodeSamplerInput and ABLPNodeSamplerInput through the + DistNeighborSampler Args: sampler_input (NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, @@ -446,7 +446,7 @@ def create_sampling_producer( remote sampling workers by this server. Returns: - A unique id of created sampling producer on this server. + int: A unique id of created sampling producer on this server. """ if isinstance(sampler_input, RemoteSamplerInput): sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset) From 5129dd0052101f104c52616c01f9b98d0cad85cd Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 5 Mar 2026 19:55:19 +0000 Subject: [PATCH 10/46] fix: empty tensor dtype mismatch, docstring cleanup, and del guard Co-Authored-By: Claude Opus 4.6 --- gigl/distributed/base_dist_loader.py | 2 +- gigl/distributed/dist_neighbor_sampler.py | 24 ++++++++++++++------- gigl/distributed/dist_sampling_producer.py | 7 +++--- gigl/distributed/graph_store/dist_server.py | 6 +++--- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d1d9632fb..ae87d4b6d 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -95,7 +95,7 @@ class BaseDistLoader(DistLoader): dataset_schema: Contains edge types, feature info, edge dir, etc. worker_options: ``MpDistSamplingWorkerOptions`` (colocated) or ``RemoteDistSamplingWorkerOptions`` (graph store). 
- sampling_config: Configuration for the sampler (created via ``create_sampling_config``). + sampling_config: Configuration for sampling (created via ``create_sampling_config``). device: Target device for sampled results. runtime: Resolved distributed runtime information. producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py index bbacca3eb..0892e090c 100644 --- a/gigl/distributed/dist_neighbor_sampler.py +++ b/gigl/distributed/dist_neighbor_sampler.py @@ -29,7 +29,7 @@ class SampleLoopInputs: """Inputs prepared for the neighbor sampling loop in _sample_from_nodes. This dataclass holds the processed inputs that are passed to the core - sampling loop. It allows _prepare_sampling_inputs to customize what nodes + sampling loop. It allows _prepare_sample_loop_inputs to customize what nodes are sampled from and what metadata is attached to the output, without duplicating the sampling loop logic. @@ -70,8 +70,8 @@ def _prepare_sample_loop_inputs( inputs: Either a NodeSamplerInput or ABLPNodeSamplerInput. Returns: - SampleLoopInputs containing the node type, nodes to sample from, - and any metadata related to the task. + SampleLoopInputs containing the nodes to sample from and any + metadata related to the task (e.g., label tensors for ABLP). 
""" input_seeds = inputs.node.to(self.device) input_type = inputs.input_type @@ -158,8 +158,12 @@ def _prepare_ablp_inputs( for node_type, seeds in input_seeds_builder.items() } - # Memory cleanup - del filtered_label_tensor, label_tensor + # Memory cleanup — only del loop vars if any labels were processed + has_labels = bool( + inputs.positive_label_by_edge_types or inputs.negative_label_by_edge_types + ) + if has_labels: + del filtered_label_tensor, label_tensor for value in input_seeds_builder.values(): value.clear() input_seeds_builder.clear() @@ -301,9 +305,13 @@ async def _sample_from_nodes( if not out_edges: sample_output = SamplerOutput( node=torch.cat(out_nodes), - row=torch.tensor([]).to(self.device), - col=torch.tensor([]).to(self.device), - edge=(torch.tensor([]).to(self.device) if self.with_edge else None), + row=torch.empty(0, dtype=torch.long, device=self.device), + col=torch.empty(0, dtype=torch.long, device=self.device), + edge=( + torch.empty(0, dtype=torch.long, device=self.device) + if self.with_edge + else None + ), batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 174cd0b09..0949400ad 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -1,6 +1,7 @@ -# A significant amount of code in this file is directly taken from GraphLearn-for-PyTorch (graphlearn_torch/python/distributed/dist_sampling_producer.py), -# Sampling producer that uses GiGL's DistNeighborSampler (which supports both -# standard neighbor sampling and ABLP) instead of GLT's DistNeighborSampler. +# Significant portions of this file are taken from GraphLearn-for-PyTorch +# (graphlearn_torch/python/distributed/dist_sampling_producer.py). +# This version uses GiGL's DistNeighborSampler (which supports both standard +# neighbor sampling and ABLP) instead of GLT's DistNeighborSampler. 
import datetime import queue diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 5d730b9da..dc4c0b082 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -432,11 +432,11 @@ def create_sampling_producer( sampling_config: SamplingConfig, worker_options: RemoteDistSamplingWorkerOptions, ) -> int: - r"""Create and initialize an instance of ``DistSamplingProducer`` with + """Create and initialize an instance of ``DistSamplingProducer`` with a group of subprocesses for distributed sampling. - Supports standard NodeSamplerInput and ABLPNodeSamplerInput through the - DistNeighborSampler + Supports both standard ``NodeSamplerInput`` and ``ABLPNodeSamplerInput`` + through the unified ``DistNeighborSampler``. Args: sampler_input (NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput, From 2377963106c33683695ef9f1474ce8aa2e8937b4 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 5 Mar 2026 20:55:18 +0000 Subject: [PATCH 11/46] Update format --- CLAUDE.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index a781eb5ff..76bde27e4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -56,8 +56,8 @@ GiGL extends GraphLearn-for-PyTorch (GLT) for distributed GNN training. 
Key clas labels, split metadata, and feature info - **`DistNeighborLoader`** (extends GLT `DistLoader`) - Standard node-based sampling loader - **`DistABLPLoader`** (extends GLT `DistLoader`) - Anchor-Based Link Prediction sampling loader -- **`DistNeighborSampler`** (extends GLT `DistNeighborSampler`) - Unified sampler supporting both standard - neighbor sampling and ABLP with positive/negative label injection +- **`DistNeighborSampler`** (extends GLT `DistNeighborSampler`) - Unified sampler supporting both standard neighbor + sampling and ABLP with positive/negative label injection **Two deployment modes:** From e59ca077f03e223b3433ae0b16bda789e73223b0 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 00:20:48 +0000 Subject: [PATCH 12/46] Add SamplerOptions to enable configurable sampler classes --- gigl/distributed/base_dist_loader.py | 7 ++ gigl/distributed/dist_ablp_neighborloader.py | 8 ++ gigl/distributed/dist_sampling_producer.py | 33 ++++++- .../distributed/distributed_neighborloader.py | 8 ++ gigl/distributed/graph_store/dist_server.py | 12 ++- gigl/distributed/sampler_options.py | 80 +++++++++++++++++ .../dist_ablp_neighborloader_test.py | 86 +++++++++++++++++++ .../distributed_neighborloader_test.py | 60 +++++++++++++ 8 files changed, 291 insertions(+), 3 deletions(-) create mode 100644 gigl/distributed/sampler_options.py diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index ae87d4b6d..c3a58fd91 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -41,6 +41,7 @@ from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.neighborloader import ( DatasetSchema, patch_fanout_for_sampling, @@ -100,6 +101,8 @@ class 
BaseDistLoader(DistLoader): runtime: Resolved distributed runtime information. producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) or a callable to dispatch on the ``DistServer`` (graph store mode). + sampler_options: Controls which sampler class is instantiated. If ``None``, + falls back to the default ``NeighborSamplerOptions``. process_start_gap_seconds: Delay between each process for staggered colocated init. """ @@ -205,6 +208,7 @@ def __init__( device: torch.device, runtime: DistributedRuntimeInfo, producer: Union[DistSamplingProducer, Callable[..., int]], + sampler_options: Optional[SamplerOptions] = None, process_start_gap_seconds: float = 60.0, ): # Set right away so __del__ can clean up if we throw during init. @@ -218,6 +222,8 @@ def __init__( self._node_feature_info = dataset_schema.node_feature_info self._edge_feature_info = dataset_schema.edge_feature_info + self._sampler_options = sampler_options + # --- Attributes shared by both modes (mirrors GLT DistLoader.__init__) --- self.input_data = sampler_input self.sampling_type = sampling_config.sampling_type @@ -504,6 +510,7 @@ def _init_graph_store_connections( inp_data, self.sampling_config, self.worker_options, + self._sampler_options, ) rpc_futures.append((server_rank, fut)) logger.info( diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 684187d67..e848225c2 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -27,6 +27,7 @@ ABLPNodeSamplerInput, metadata_key_with_prefix, ) +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, @@ -82,6 +83,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + sampler_options: Optional[SamplerOptions] = None, context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate 
this local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this @@ -189,6 +191,10 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + sampler_options (Optional[SamplerOptions]): Controls which sampler class is + instantiated. Pass ``NeighborSamplerOptions`` to use the built-in sampler, + or ``CustomSamplerOptions`` to dynamically import a custom sampler class. + If ``None``, defaults to ``NeighborSamplerOptions(num_neighbors)``. context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node. 
@@ -337,6 +343,7 @@ def __init__( sampling_config, worker_options, channel, + sampler_options=sampler_options, ) else: producer = DistServer.create_sampling_producer @@ -352,6 +359,7 @@ def __init__( device=device, runtime=runtime, producer=producer, + sampler_options=sampler_options, process_start_gap_seconds=process_start_gap_seconds, ) diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 0949400ad..9542d89d5 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -34,7 +34,12 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset -from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler +from gigl.distributed.sampler_options import ( + CustomSamplerOptions, + NeighborSamplerOptions, + SamplerOptions, + resolve_sampler_class, +) def _sampling_worker_loop( @@ -48,6 +53,7 @@ def _sampling_worker_loop( task_queue: mp.Queue, sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, + sampler_options: Optional[SamplerOptions] = None, ): dist_sampler = None try: @@ -86,7 +92,16 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - dist_sampler = DistNeighborSampler( + + sampler_cls = resolve_sampler_class( + sampler_options + if sampler_options is not None + else NeighborSamplerOptions(num_neighbors=sampling_config.num_neighbors) + ) + extra_kwargs: dict[str, object] = {} + if isinstance(sampler_options, CustomSamplerOptions): + extra_kwargs = sampler_options.class_args + dist_sampler = sampler_cls( data, sampling_config.num_neighbors, sampling_config.with_edge, @@ -99,6 +114,7 @@ def _sampling_worker_loop( worker_options.worker_concurrency, current_device, seed=sampling_config.seed, + **extra_kwargs, ) dist_sampler.start_loop() @@ -168,6 +184,18 @@ def _sampling_worker_loop( class DistSamplingProducer(DistMpSamplingProducer): + def __init__( + 
self, + data: DistDataset, + sampler_input: Union[NodeSamplerInput, EdgeSamplerInput], + sampling_config: SamplingConfig, + worker_options: MpDistSamplingWorkerOptions, + channel: ChannelBase, + sampler_options: Optional[SamplerOptions] = None, + ): + super().__init__(data, sampler_input, sampling_config, worker_options, channel) + self._sampler_options = sampler_options + def init(self): r"""Create the subprocess pool. Init samplers and rpc server.""" if self.sampling_config.seed is not None: @@ -197,6 +225,7 @@ def init(self): task_queue, self.sampling_completed_worker_count, barrier, + self._sampler_options, ), ) w.daemon = True diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index e914ffd05..d63052b27 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -21,6 +21,7 @@ from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, @@ -81,6 +82,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + sampler_options: Optional[SamplerOptions] = None, ): """ Distributed Neighbor Loader. @@ -146,6 +148,10 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + sampler_options (Optional[SamplerOptions]): Controls which sampler class is + instantiated. Pass ``NeighborSamplerOptions`` to use the built-in sampler, + or ``CustomSamplerOptions`` to dynamically import a custom sampler class. 
+ If ``None``, defaults to ``NeighborSamplerOptions(num_neighbors)``. """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, @@ -245,6 +251,7 @@ def __init__( sampling_config, worker_options, channel, + sampler_options=sampler_options, ) else: producer = GiglDistServer.create_sampling_producer @@ -260,6 +267,7 @@ def __init__( device=device, runtime=runtime, producer=producer, + sampler_options=sampler_options, process_start_gap_seconds=process_start_gap_seconds, ) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index dc4c0b082..68642d615 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -35,6 +35,7 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.sampler import ABLPNodeSamplerInput +from gigl.distributed.sampler_options import SamplerOptions from gigl.distributed.utils.neighborloader import shard_nodes_by_process from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.types.graph import ( @@ -431,6 +432,7 @@ def create_sampling_producer( ], sampling_config: SamplingConfig, worker_options: RemoteDistSamplingWorkerOptions, + sampler_options: Optional[SamplerOptions] = None, ) -> int: """Create and initialize an instance of ``DistSamplingProducer`` with a group of subprocesses for distributed sampling. @@ -444,6 +446,9 @@ def create_sampling_producer( sampling_config (SamplingConfig): Configuration of sampling meta info. worker_options (RemoteDistSamplingWorkerOptions): Options for launching remote sampling workers by this server. + sampler_options (Optional[SamplerOptions]): Controls which sampler class + is instantiated. If ``None``, defaults to the built-in + ``DistNeighborSampler``. Returns: int: A unique id of created sampling producer on this server. 
@@ -467,7 +472,12 @@ def create_sampling_producer( worker_options.buffer_capacity, worker_options.buffer_size ) producer = DistSamplingProducer( - self.dataset, sampler_input, sampling_config, worker_options, buffer + self.dataset, + sampler_input, + sampling_config, + worker_options, + buffer, + sampler_options=sampler_options, ) producer.init() self._producer_pool[producer_id] = producer diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py new file mode 100644 index 000000000..1a137da31 --- /dev/null +++ b/gigl/distributed/sampler_options.py @@ -0,0 +1,80 @@ +"""Sampler option types for configuring which sampler class to use in distributed loading. + +Provides two options: +- ``NeighborSamplerOptions``: Uses GiGL's built-in ``DistNeighborSampler``. +- ``CustomSamplerOptions``: Dynamically imports and uses a user-provided sampler class. + +Both are frozen dataclasses so they are safe to pickle across RPC boundaries +(required for Graph Store mode). +""" + +import importlib +from dataclasses import dataclass, field +from typing import Any, Union + +from graphlearn_torch.typing import EdgeType + + +@dataclass(frozen=True) +class NeighborSamplerOptions: + """Default sampler options using GiGL's DistNeighborSampler. + + Attributes: + num_neighbors: Fanout per hop, either a flat list (homogeneous) or a + dict mapping edge types to per-hop fanout lists (heterogeneous). + """ + + num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + + +@dataclass(frozen=True) +class CustomSamplerOptions: + """Custom sampler options that dynamically import a user-provided sampler class. + + The class at ``class_path`` must conform to the same interface as + ``DistNeighborSampler`` (extend ``GLTDistNeighborSampler`` or at minimum + support ``start_loop``, ``sample_from_nodes``, etc.). + + Attributes: + class_path: Fully qualified Python import path, e.g. + ``"my.module.MySampler"``. 
+ class_args: Additional keyword arguments passed to the sampler + constructor (on top of the standard GLT arguments). + """ + + class_path: str + class_args: dict[str, Any] = field(default_factory=dict) + + +SamplerOptions = Union[NeighborSamplerOptions, CustomSamplerOptions] + + +def resolve_sampler_class(sampler_options: SamplerOptions) -> type: + """Resolve a sampler class from the given options. + + Args: + sampler_options: Either ``NeighborSamplerOptions`` (returns the built-in + ``DistNeighborSampler``) or ``CustomSamplerOptions`` (dynamically + imports the class at ``class_path``). + + Returns: + The sampler class to instantiate. + + Raises: + TypeError: If ``sampler_options`` is not a recognized type. + ImportError: If the module in ``class_path`` cannot be imported. + AttributeError: If the class name is not found in the module. + """ + if isinstance(sampler_options, NeighborSamplerOptions): + from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler + + return DistNeighborSampler + elif isinstance(sampler_options, CustomSamplerOptions): + module_path, class_name = sampler_options.class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + else: + raise TypeError( + f"Unsupported sampler_options type: {type(sampler_options)}. " + f"Expected NeighborSamplerOptions or CustomSamplerOptions." 
+ ) diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index fd14aadd5..aba64690c 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -15,6 +15,11 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner +from gigl.distributed.sampler_options import ( + CustomSamplerOptions, + NeighborSamplerOptions, + SamplerOptions, +) from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, ) @@ -307,6 +312,30 @@ def _run_toy_heterogeneous_ablp( shutdown_rpc() +def _run_distributed_ablp_loader_with_sampler_options( + _, + dataset: DistDataset, + expected_data_count: int, + sampler_options: SamplerOptions, +): + create_test_process_group() + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[2, 2], + input_nodes=to_homogeneous(dataset.train_node_ids), + pin_memory_device=torch.device("cpu"), + sampler_options=sampler_options, + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data) + assert hasattr(datum, "y_positive") + count += 1 + assert count == expected_data_count + + shutdown_rpc() + + def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( _, input_nodes: tuple[NodeType, torch.Tensor], @@ -918,6 +947,63 @@ def test_ablp_dataloder_multiple_supervision_edge_types( ), ), + @parameterized.expand( + [ + param( + "NeighborSamplerOptions", + sampler_options=NeighborSamplerOptions(num_neighbors=[2, 2]), + ), + param( + "CustomSamplerOptions with DistNeighborSampler", + sampler_options=CustomSamplerOptions( + class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", + ), + ), + ] + ) + def test_ablp_loader_with_sampler_options( + self, + _: str, + sampler_options: SamplerOptions, 
+ ): + create_test_process_group() + cora_supervised_info = get_mocked_dataset_artifact_metadata()[ + CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + + gbml_config_pb_wrapper = ( + GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=cora_supervised_info.frozen_gbml_config_uri + ) + ) + + serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( + preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, + graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, + tfrecord_uri_pattern=".*.tfrecord(.gz)?$", + ) + + splitter = DistNodeAnchorLinkSplitter( + sampling_direction="in", should_convert_labels_to_edges=True + ) + + dataset = build_dataset( + serialized_graph_metadata=serialized_graph_metadata, + sample_edge_direction="in", + splitter=splitter, + ) + + assert dataset.train_node_ids is not None, "Train node ids must exist." + + mp.spawn( + fn=_run_distributed_ablp_loader_with_sampler_options, + args=( + dataset, + to_homogeneous(dataset.train_node_ids).numel(), + sampler_options, + ), + ) + @parameterized.expand( [ param( diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 30ef1d910..87bf78683 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -11,6 +11,11 @@ from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.sampler_options import ( + CustomSamplerOptions, + NeighborSamplerOptions, + SamplerOptions, +) from gigl.distributed.utils import get_free_port from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, @@ -275,6 +280,30 @@ def 
_run_distributed_neighbor_loader_with_node_labels_heterogeneous( shutdown_rpc() +def _run_distributed_neighbor_loader_with_sampler_options( + _, + dataset: DistDataset, + expected_data_count: int, + sampler_options: SamplerOptions, +): + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + sampler_options=sampler_options, + ) + + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + + assert count == expected_data_count + + shutdown_rpc() + + def _run_cora_supervised_node_classification( _, dataset: DistDataset, @@ -577,6 +606,37 @@ def test_isolated_homogeneous_neighbor_loader( args=(dataset, 18), ) + @parameterized.expand( + [ + param( + "NeighborSamplerOptions", + sampler_options=NeighborSamplerOptions(num_neighbors=[2, 2]), + ), + param( + "CustomSamplerOptions with DistNeighborSampler", + sampler_options=CustomSamplerOptions( + class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", + ), + ), + ] + ) + def test_distributed_neighbor_loader_with_sampler_options( + self, + _: str, + sampler_options: SamplerOptions, + ): + expected_data_count = 2708 + dataset = run_distributed_dataset( + rank=0, + world_size=self._world_size, + mocked_dataset_info=CORA_NODE_ANCHOR_MOCKED_DATASET_INFO, + _port=get_free_port(), + ) + mp.spawn( + fn=_run_distributed_neighbor_loader_with_sampler_options, + args=(dataset, expected_data_count, sampler_options), + ) + @parameterized.expand( [ param( From c0fdb03bec82c3901c36ee1feef610ce019d195e Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 00:35:00 +0000 Subject: [PATCH 13/46] Rename to KHopNeighborSamplerOptions, inline resolve logic, clean up imports --- gigl/distributed/base_dist_loader.py | 2 +- gigl/distributed/dist_ablp_neighborloader.py | 4 +- gigl/distributed/dist_sampling_producer.py | 21 +++++----- .../distributed/distributed_neighborloader.py | 4 +- 
gigl/distributed/sampler_options.py | 38 ++----------------- .../dist_ablp_neighborloader_test.py | 6 +-- .../distributed_neighborloader_test.py | 6 +-- 7 files changed, 24 insertions(+), 57 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index c3a58fd91..6db2385e0 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -102,7 +102,7 @@ class BaseDistLoader(DistLoader): producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) or a callable to dispatch on the ``DistServer`` (graph store mode). sampler_options: Controls which sampler class is instantiated. If ``None``, - falls back to the default ``NeighborSamplerOptions``. + falls back to the default ``KHopNeighborSamplerOptions``. process_start_gap_seconds: Delay between each process for staggered colocated init. """ diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index e848225c2..4aecda7e7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -192,9 +192,9 @@ def __init__( shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). sampler_options (Optional[SamplerOptions]): Controls which sampler class is - instantiated. Pass ``NeighborSamplerOptions`` to use the built-in sampler, + instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, or ``CustomSamplerOptions`` to dynamically import a custom sampler class. - If ``None``, defaults to ``NeighborSamplerOptions(num_neighbors)``. + If ``None``, defaults to ``KHopNeighborSamplerOptions(num_neighbors)``. context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process. 
local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node. diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 9542d89d5..b76a866fc 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -4,6 +4,7 @@ # neighbor sampling and ABLP) instead of GLT's DistNeighborSampler. import datetime +import importlib import queue from threading import Barrier from typing import Optional, Union, cast @@ -34,12 +35,8 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset -from gigl.distributed.sampler_options import ( - CustomSamplerOptions, - NeighborSamplerOptions, - SamplerOptions, - resolve_sampler_class, -) +from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler +from gigl.distributed.sampler_options import CustomSamplerOptions, SamplerOptions def _sampling_worker_loop( @@ -93,14 +90,16 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - sampler_cls = resolve_sampler_class( - sampler_options - if sampler_options is not None - else NeighborSamplerOptions(num_neighbors=sampling_config.num_neighbors) - ) + # Resolve sampler class from options extra_kwargs: dict[str, object] = {} if isinstance(sampler_options, CustomSamplerOptions): + module_path, class_name = sampler_options.class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + sampler_cls = getattr(module, class_name) extra_kwargs = sampler_options.class_args + else: + sampler_cls = DistNeighborSampler + dist_sampler = sampler_cls( data, sampling_config.num_neighbors, diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index d63052b27..8567b6706 100644 --- 
a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -149,9 +149,9 @@ def __init__( shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). sampler_options (Optional[SamplerOptions]): Controls which sampler class is - instantiated. Pass ``NeighborSamplerOptions`` to use the built-in sampler, + instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, or ``CustomSamplerOptions`` to dynamically import a custom sampler class. - If ``None``, defaults to ``NeighborSamplerOptions(num_neighbors)``. + If ``None``, defaults to ``KHopNeighborSamplerOptions(num_neighbors)``. """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 1a137da31..1466d3829 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -1,14 +1,13 @@ """Sampler option types for configuring which sampler class to use in distributed loading. Provides two options: -- ``NeighborSamplerOptions``: Uses GiGL's built-in ``DistNeighborSampler``. +- ``KHopNeighborSamplerOptions``: Uses GiGL's built-in ``DistNeighborSampler``. - ``CustomSamplerOptions``: Dynamically imports and uses a user-provided sampler class. Both are frozen dataclasses so they are safe to pickle across RPC boundaries (required for Graph Store mode). """ -import importlib from dataclasses import dataclass, field from typing import Any, Union @@ -16,7 +15,7 @@ @dataclass(frozen=True) -class NeighborSamplerOptions: +class KHopNeighborSamplerOptions: """Default sampler options using GiGL's DistNeighborSampler. 
Attributes: @@ -46,35 +45,4 @@ class CustomSamplerOptions: class_args: dict[str, Any] = field(default_factory=dict) -SamplerOptions = Union[NeighborSamplerOptions, CustomSamplerOptions] - - -def resolve_sampler_class(sampler_options: SamplerOptions) -> type: - """Resolve a sampler class from the given options. - - Args: - sampler_options: Either ``NeighborSamplerOptions`` (returns the built-in - ``DistNeighborSampler``) or ``CustomSamplerOptions`` (dynamically - imports the class at ``class_path``). - - Returns: - The sampler class to instantiate. - - Raises: - TypeError: If ``sampler_options`` is not a recognized type. - ImportError: If the module in ``class_path`` cannot be imported. - AttributeError: If the class name is not found in the module. - """ - if isinstance(sampler_options, NeighborSamplerOptions): - from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler - - return DistNeighborSampler - elif isinstance(sampler_options, CustomSamplerOptions): - module_path, class_name = sampler_options.class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - return getattr(module, class_name) - else: - raise TypeError( - f"Unsupported sampler_options type: {type(sampler_options)}. " - f"Expected NeighborSamplerOptions or CustomSamplerOptions." 
- ) +SamplerOptions = Union[KHopNeighborSamplerOptions, CustomSamplerOptions] diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index aba64690c..84b58ec78 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -17,7 +17,7 @@ from gigl.distributed.dist_range_partitioner import DistRangePartitioner from gigl.distributed.sampler_options import ( CustomSamplerOptions, - NeighborSamplerOptions, + KHopNeighborSamplerOptions, SamplerOptions, ) from gigl.distributed.utils.serialized_graph_metadata_translator import ( @@ -950,8 +950,8 @@ def test_ablp_dataloder_multiple_supervision_edge_types( @parameterized.expand( [ param( - "NeighborSamplerOptions", - sampler_options=NeighborSamplerOptions(num_neighbors=[2, 2]), + "KHopNeighborSamplerOptions", + sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2, 2]), ), param( "CustomSamplerOptions with DistNeighborSampler", diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 87bf78683..24ea12160 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -13,7 +13,7 @@ from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.sampler_options import ( CustomSamplerOptions, - NeighborSamplerOptions, + KHopNeighborSamplerOptions, SamplerOptions, ) from gigl.distributed.utils import get_free_port @@ -609,8 +609,8 @@ def test_isolated_homogeneous_neighbor_loader( @parameterized.expand( [ param( - "NeighborSamplerOptions", - sampler_options=NeighborSamplerOptions(num_neighbors=[2, 2]), + "KHopNeighborSamplerOptions", + sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2, 2]), ), param( "CustomSamplerOptions with DistNeighborSampler", From 
e41b957bb1ca35e0ea4f473f1284e9c1f2f6ea05 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 00:36:44 +0000 Subject: [PATCH 14/46] Pass sampler_options positionally in DistABLPLoader --- gigl/distributed/dist_ablp_neighborloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 4aecda7e7..99322ae56 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -343,7 +343,7 @@ def __init__( sampling_config, worker_options, channel, - sampler_options=sampler_options, + sampler_options, ) else: producer = DistServer.create_sampling_producer From 4b00573dd9f45efd82a2ec4fce0d41c5af5f9ed9 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 00:38:03 +0000 Subject: [PATCH 15/46] Pass sampler_options positionally in DistNeighborLoader and DistServer --- gigl/distributed/distributed_neighborloader.py | 2 +- gigl/distributed/graph_store/dist_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 8567b6706..5634e907e 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -251,7 +251,7 @@ def __init__( sampling_config, worker_options, channel, - sampler_options=sampler_options, + sampler_options, ) else: producer = GiglDistServer.create_sampling_producer diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 68642d615..fdfffd305 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -477,7 +477,7 @@ def create_sampling_producer( sampling_config, worker_options, buffer, - sampler_options=sampler_options, + sampler_options, ) producer.init() self._producer_pool[producer_id] = producer From 
8ac9ab8db159d649de8cd7d8a4bc7f4b111d38e6 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 00:53:34 +0000 Subject: [PATCH 16/46] Make num_neighbors optional, resolve from sampler_options --- gigl/distributed/dist_ablp_neighborloader.py | 12 +++-- .../distributed/distributed_neighborloader.py | 12 +++-- gigl/distributed/sampler_options.py | 48 ++++++++++++++++++- 3 files changed, 65 insertions(+), 7 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 99322ae56..779c53340 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -27,7 +27,7 @@ ABLPNodeSamplerInput, metadata_key_with_prefix, ) -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import SamplerOptions, resolve_sampler_options from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, @@ -63,7 +63,7 @@ class DistABLPLoader(BaseDistLoader): def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], - num_neighbors: Union[list[int], dict[EdgeType, list[int]]], + num_neighbors: Optional[Union[list[int], dict[EdgeType, list[int]]]] = None, input_nodes: Optional[ Union[ torch.Tensor, @@ -138,11 +138,13 @@ def __init__( Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. - num_neighbors (list[int] or dict[tuple[str, str, str], list[int]]): + num_neighbors (Optional[list[int] or dict[tuple[str, str, str], list[int]]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. + Either ``num_neighbors`` or ``sampler_options`` must be provided. 
+ If both are provided with ``KHopNeighborSamplerOptions``, they must match. input_nodes: Indices of seed nodes to start sampling from. For Colocated mode: `torch.Tensor` or `tuple[NodeType, torch.Tensor]`. If set to `None` for homogeneous settings, all nodes will be considered. @@ -204,6 +206,10 @@ def __init__( # then we can properly clean up and don't get extraneous error messages. self._shutdowned = True + num_neighbors, sampler_options = resolve_sampler_options( + num_neighbors, sampler_options + ) + # Determine sampling cluster setup based on dataset type if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 5634e907e..52718c8dc 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -21,7 +21,7 @@ from gigl.distributed.dist_sampling_producer import DistSamplingProducer from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import SamplerOptions, resolve_sampler_options from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, @@ -60,7 +60,7 @@ class DistNeighborLoader(BaseDistLoader): def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], - num_neighbors: Union[list[int], dict[EdgeType, list[int]]], + num_neighbors: Optional[Union[list[int], dict[EdgeType, list[int]]]] = None, input_nodes: Optional[ Union[ torch.Tensor, @@ -99,11 +99,13 @@ def __init__( Args: dataset (DistDataset | RemoteDistDataset): The dataset to sample from. If this is a `RemoteDistDataset`, then we assumed to be in "Graph Store" mode. 
- num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): + num_neighbors (Optional[list[int] or dict[Tuple[str, str, str], list[int]]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. + Either ``num_neighbors`` or ``sampler_options`` must be provided. + If both are provided with ``KHopNeighborSamplerOptions``, they must match. context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. @@ -158,6 +160,10 @@ def __init__( # then we can properly clean up and don't get extraneous error messages. 
self._shutdowned = True + num_neighbors, sampler_options = resolve_sampler_options( + num_neighbors, sampler_options + ) + # Resolve distributed context runtime = BaseDistLoader.resolve_runtime( context, local_process_rank, local_process_world_size diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 1466d3829..faddcc51e 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Union +from typing import Any, Optional, Union from graphlearn_torch.typing import EdgeType @@ -46,3 +46,49 @@ class CustomSamplerOptions: SamplerOptions = Union[KHopNeighborSamplerOptions, CustomSamplerOptions] + + +def resolve_sampler_options( + num_neighbors: Optional[Union[list[int], dict[EdgeType, list[int]]]], + sampler_options: Optional[SamplerOptions], +) -> tuple[Union[list[int], dict[EdgeType, list[int]]], SamplerOptions]: + """Resolve num_neighbors and sampler_options from user-provided values. + + Handles backwards compatibility: callers can provide ``num_neighbors`` + alone (old API), ``sampler_options`` alone (new API), or both (validated + for consistency). + + Args: + num_neighbors: Fanout per hop, or None. + sampler_options: Sampler configuration, or None. + + Returns: + A tuple of (resolved_num_neighbors, resolved_sampler_options). + + Raises: + ValueError: If neither is provided, or if both provide conflicting + num_neighbors values. 
+ """ + if num_neighbors is None and sampler_options is None: + raise ValueError("Either num_neighbors or sampler_options must be provided.") + + if sampler_options is None: + assert num_neighbors is not None + return num_neighbors, KHopNeighborSamplerOptions(num_neighbors) + + if isinstance(sampler_options, KHopNeighborSamplerOptions): + if num_neighbors is None: + return sampler_options.num_neighbors, sampler_options + if num_neighbors != sampler_options.num_neighbors: + raise ValueError( + f"num_neighbors ({num_neighbors}) does not match " + f"sampler_options.num_neighbors ({sampler_options.num_neighbors}). " + f"Provide one or the other, not both with different values." + ) + return num_neighbors, sampler_options + + # CustomSamplerOptions — num_neighbors is not meaningful, default to [] + assert isinstance(sampler_options, CustomSamplerOptions) + if num_neighbors is None: + return [], sampler_options + return num_neighbors, sampler_options From 7a51864eefd2ca1817cf8c48c04a55feb83762cb Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 01:19:27 +0000 Subject: [PATCH 17/46] Add resolve_sampler_options tests, remove redundant num_neighbors from test helpers --- .../dist_ablp_neighborloader_test.py | 1 - .../distributed_neighborloader_test.py | 1 - .../unit/distributed/sampler_options_test.py | 57 +++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 tests/unit/distributed/sampler_options_test.py diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 84b58ec78..6879d27ba 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -321,7 +321,6 @@ def _run_distributed_ablp_loader_with_sampler_options( create_test_process_group() loader = DistABLPLoader( dataset=dataset, - num_neighbors=[2, 2], input_nodes=to_homogeneous(dataset.train_node_ids), 
pin_memory_device=torch.device("cpu"), sampler_options=sampler_options, diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 24ea12160..82cb4ae48 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -289,7 +289,6 @@ def _run_distributed_neighbor_loader_with_sampler_options( create_test_process_group() loader = DistNeighborLoader( dataset=dataset, - num_neighbors=[2, 2], pin_memory_device=torch.device("cpu"), sampler_options=sampler_options, ) diff --git a/tests/unit/distributed/sampler_options_test.py b/tests/unit/distributed/sampler_options_test.py new file mode 100644 index 000000000..c7793ae83 --- /dev/null +++ b/tests/unit/distributed/sampler_options_test.py @@ -0,0 +1,57 @@ +from gigl.distributed.sampler_options import ( + CustomSamplerOptions, + KHopNeighborSamplerOptions, + resolve_sampler_options, +) +from tests.test_assets.test_case import TestCase + + +class ResolveSamplerOptionsTest(TestCase): + def test_both_none_raises(self): + with self.assertRaises(ValueError): + resolve_sampler_options(num_neighbors=None, sampler_options=None) + + def test_num_neighbors_only(self): + num_neighbors, options = resolve_sampler_options( + num_neighbors=[2, 2], sampler_options=None + ) + self.assertEqual(num_neighbors, [2, 2]) + assert isinstance(options, KHopNeighborSamplerOptions) + self.assertEqual(options.num_neighbors, [2, 2]) + + def test_khop_options_only(self): + opts = KHopNeighborSamplerOptions(num_neighbors=[3, 3]) + num_neighbors, options = resolve_sampler_options( + num_neighbors=None, sampler_options=opts + ) + self.assertEqual(num_neighbors, [3, 3]) + self.assertIs(options, opts) + + def test_khop_options_matching_num_neighbors(self): + opts = KHopNeighborSamplerOptions(num_neighbors=[2, 2]) + num_neighbors, options = resolve_sampler_options( + num_neighbors=[2, 2], sampler_options=opts + ) 
+ self.assertEqual(num_neighbors, [2, 2]) + self.assertIs(options, opts) + + def test_khop_options_conflicting_num_neighbors_raises(self): + opts = KHopNeighborSamplerOptions(num_neighbors=[3, 3]) + with self.assertRaises(ValueError): + resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) + + def test_custom_options_without_num_neighbors(self): + opts = CustomSamplerOptions(class_path="my.module.MySampler") + num_neighbors, options = resolve_sampler_options( + num_neighbors=None, sampler_options=opts + ) + self.assertEqual(num_neighbors, []) + self.assertIs(options, opts) + + def test_custom_options_with_num_neighbors(self): + opts = CustomSamplerOptions(class_path="my.module.MySampler") + num_neighbors, options = resolve_sampler_options( + num_neighbors=[2, 2], sampler_options=opts + ) + self.assertEqual(num_neighbors, [2, 2]) + self.assertIs(options, opts) From c924bb679561349e30fab4f7a8eade5742d9d244 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 01:44:04 +0000 Subject: [PATCH 18/46] Remove redundant KHop test cases, keep only CustomSamplerOptions tests --- .../dist_ablp_neighborloader_test.py | 30 +++------------- .../distributed_neighborloader_test.py | 34 +++++-------------- 2 files changed, 14 insertions(+), 50 deletions(-) diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 6879d27ba..a1292358e 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -15,11 +15,7 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner -from gigl.distributed.sampler_options import ( - CustomSamplerOptions, - KHopNeighborSamplerOptions, - SamplerOptions, -) +from gigl.distributed.sampler_options import CustomSamplerOptions, SamplerOptions from 
gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, ) @@ -946,25 +942,7 @@ def test_ablp_dataloder_multiple_supervision_edge_types( ), ), - @parameterized.expand( - [ - param( - "KHopNeighborSamplerOptions", - sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2, 2]), - ), - param( - "CustomSamplerOptions with DistNeighborSampler", - sampler_options=CustomSamplerOptions( - class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", - ), - ), - ] - ) - def test_ablp_loader_with_sampler_options( - self, - _: str, - sampler_options: SamplerOptions, - ): + def test_ablp_loader_with_custom_sampler_options(self): create_test_process_group() cora_supervised_info = get_mocked_dataset_artifact_metadata()[ CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name @@ -999,7 +977,9 @@ def test_ablp_loader_with_sampler_options( args=( dataset, to_homogeneous(dataset.train_node_ids).numel(), - sampler_options, + CustomSamplerOptions( + class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", + ), ), ) diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 82cb4ae48..880922014 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -11,11 +11,7 @@ from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.distributed_neighborloader import DistNeighborLoader -from gigl.distributed.sampler_options import ( - CustomSamplerOptions, - KHopNeighborSamplerOptions, - SamplerOptions, -) +from gigl.distributed.sampler_options import CustomSamplerOptions, SamplerOptions from gigl.distributed.utils import get_free_port from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, @@ -605,25 +601,7 @@ def 
test_isolated_homogeneous_neighbor_loader( args=(dataset, 18), ) - @parameterized.expand( - [ - param( - "KHopNeighborSamplerOptions", - sampler_options=KHopNeighborSamplerOptions(num_neighbors=[2, 2]), - ), - param( - "CustomSamplerOptions with DistNeighborSampler", - sampler_options=CustomSamplerOptions( - class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", - ), - ), - ] - ) - def test_distributed_neighbor_loader_with_sampler_options( - self, - _: str, - sampler_options: SamplerOptions, - ): + def test_distributed_neighbor_loader_with_custom_sampler_options(self): expected_data_count = 2708 dataset = run_distributed_dataset( rank=0, @@ -633,7 +611,13 @@ def test_distributed_neighbor_loader_with_sampler_options( ) mp.spawn( fn=_run_distributed_neighbor_loader_with_sampler_options, - args=(dataset, expected_data_count, sampler_options), + args=( + dataset, + expected_data_count, + CustomSamplerOptions( + class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", + ), + ), ) @parameterized.expand( From 455cb24cb4670a7d1dce65c447a989dfafc7f24d Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 01:49:03 +0000 Subject: [PATCH 19/46] Require num_neighbors always, remove silent [] default for CustomSamplerOptions --- gigl/distributed/dist_ablp_neighborloader.py | 2 +- .../distributed/distributed_neighborloader.py | 2 +- gigl/distributed/sampler_options.py | 22 ++++++++++++------- .../dist_ablp_neighborloader_test.py | 1 + .../distributed_neighborloader_test.py | 1 + .../unit/distributed/sampler_options_test.py | 9 +++----- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 779c53340..b760fa9f3 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -143,7 +143,7 @@ def __init__( If an entry is set to `-1`, all neighbors will be included. 
In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - Either ``num_neighbors`` or ``sampler_options`` must be provided. + Required — either directly or via ``KHopNeighborSamplerOptions``. If both are provided with ``KHopNeighborSamplerOptions``, they must match. input_nodes: Indices of seed nodes to start sampling from. For Colocated mode: `torch.Tensor` or `tuple[NodeType, torch.Tensor]`. diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 52718c8dc..86b5f8b08 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -104,7 +104,7 @@ def __init__( If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - Either ``num_neighbors`` or ``sampler_options`` must be provided. + Required — either directly or via ``KHopNeighborSamplerOptions``. If both are provided with ``KHopNeighborSamplerOptions``, they must match. context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index faddcc51e..ba27193df 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -58,6 +58,10 @@ def resolve_sampler_options( alone (old API), ``sampler_options`` alone (new API), or both (validated for consistency). + ``num_neighbors`` is always required — either directly or via + ``KHopNeighborSamplerOptions.num_neighbors``. This follows the PyG + convention where neighbor fanout must be explicitly specified. 
+ Args: num_neighbors: Fanout per hop, or None. sampler_options: Sampler configuration, or None. @@ -66,14 +70,15 @@ def resolve_sampler_options( A tuple of (resolved_num_neighbors, resolved_sampler_options). Raises: - ValueError: If neither is provided, or if both provide conflicting - num_neighbors values. + ValueError: If num_neighbors cannot be determined, or if both provide + conflicting num_neighbors values. """ - if num_neighbors is None and sampler_options is None: - raise ValueError("Either num_neighbors or sampler_options must be provided.") - if sampler_options is None: - assert num_neighbors is not None + if num_neighbors is None: + raise ValueError( + "num_neighbors must be provided, either directly or via " + "KHopNeighborSamplerOptions." + ) return num_neighbors, KHopNeighborSamplerOptions(num_neighbors) if isinstance(sampler_options, KHopNeighborSamplerOptions): @@ -87,8 +92,9 @@ def resolve_sampler_options( ) return num_neighbors, sampler_options - # CustomSamplerOptions — num_neighbors is not meaningful, default to [] assert isinstance(sampler_options, CustomSamplerOptions) if num_neighbors is None: - return [], sampler_options + raise ValueError( + "num_neighbors must be provided when using CustomSamplerOptions." 
+ ) return num_neighbors, sampler_options diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index a1292358e..437b7ce37 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -317,6 +317,7 @@ def _run_distributed_ablp_loader_with_sampler_options( create_test_process_group() loader = DistABLPLoader( dataset=dataset, + num_neighbors=[2, 2], input_nodes=to_homogeneous(dataset.train_node_ids), pin_memory_device=torch.device("cpu"), sampler_options=sampler_options, diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 880922014..2a85a810a 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -285,6 +285,7 @@ def _run_distributed_neighbor_loader_with_sampler_options( create_test_process_group() loader = DistNeighborLoader( dataset=dataset, + num_neighbors=[2, 2], pin_memory_device=torch.device("cpu"), sampler_options=sampler_options, ) diff --git a/tests/unit/distributed/sampler_options_test.py b/tests/unit/distributed/sampler_options_test.py index c7793ae83..4a7c5c806 100644 --- a/tests/unit/distributed/sampler_options_test.py +++ b/tests/unit/distributed/sampler_options_test.py @@ -40,13 +40,10 @@ def test_khop_options_conflicting_num_neighbors_raises(self): with self.assertRaises(ValueError): resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) - def test_custom_options_without_num_neighbors(self): + def test_custom_options_without_num_neighbors_raises(self): opts = CustomSamplerOptions(class_path="my.module.MySampler") - num_neighbors, options = resolve_sampler_options( - num_neighbors=None, sampler_options=opts - ) - self.assertEqual(num_neighbors, []) - self.assertIs(options, opts) + with self.assertRaises(ValueError): + 
resolve_sampler_options(num_neighbors=None, sampler_options=opts) def test_custom_options_with_num_neighbors(self): opts = CustomSamplerOptions(class_path="my.module.MySampler") From 096493edd652d64684060a0b262fe662a2849e07 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 02:49:31 +0000 Subject: [PATCH 20/46] Make num_neighbors required on loaders, simplify resolve_sampler_options --- gigl/distributed/dist_ablp_neighborloader.py | 11 ++--- .../distributed/distributed_neighborloader.py | 11 ++--- gigl/distributed/sampler_options.py | 43 ++++++------------- .../unit/distributed/sampler_options_test.py | 32 ++------------ 4 files changed, 24 insertions(+), 73 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index b760fa9f3..ac7410dcb 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -63,7 +63,7 @@ class DistABLPLoader(BaseDistLoader): def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], - num_neighbors: Optional[Union[list[int], dict[EdgeType, list[int]]]] = None, + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ Union[ torch.Tensor, @@ -138,13 +138,12 @@ def __init__( Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. - num_neighbors (Optional[list[int] or dict[tuple[str, str, str], list[int]]]): + num_neighbors (list[int] or dict[tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - Required — either directly or via ``KHopNeighborSamplerOptions``. - If both are provided with ``KHopNeighborSamplerOptions``, they must match. 
+ If ``KHopNeighborSamplerOptions`` is also provided, they must match. input_nodes: Indices of seed nodes to start sampling from. For Colocated mode: `torch.Tensor` or `tuple[NodeType, torch.Tensor]`. If set to `None` for homogeneous settings, all nodes will be considered. @@ -206,9 +205,7 @@ def __init__( # then we can properly clean up and don't get extraneous error messages. self._shutdowned = True - num_neighbors, sampler_options = resolve_sampler_options( - num_neighbors, sampler_options - ) + sampler_options = resolve_sampler_options(num_neighbors, sampler_options) # Determine sampling cluster setup based on dataset type if isinstance(dataset, RemoteDistDataset): diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 86b5f8b08..33ffc4fe4 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -60,7 +60,7 @@ class DistNeighborLoader(BaseDistLoader): def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], - num_neighbors: Optional[Union[list[int], dict[EdgeType, list[int]]]] = None, + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ Union[ torch.Tensor, @@ -99,13 +99,12 @@ def __init__( Args: dataset (DistDataset | RemoteDistDataset): The dataset to sample from. If this is a `RemoteDistDataset`, then we assumed to be in "Graph Store" mode. - num_neighbors (Optional[list[int] or dict[Tuple[str, str, str], list[int]]]): + num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - Required — either directly or via ``KHopNeighborSamplerOptions``. - If both are provided with ``KHopNeighborSamplerOptions``, they must match. 
+ If ``KHopNeighborSamplerOptions`` is also provided, they must match. context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. @@ -160,9 +159,7 @@ def __init__( # then we can properly clean up and don't get extraneous error messages. self._shutdowned = True - num_neighbors, sampler_options = resolve_sampler_options( - num_neighbors, sampler_options - ) + sampler_options = resolve_sampler_options(num_neighbors, sampler_options) # Resolve distributed context runtime = BaseDistLoader.resolve_runtime( diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index ba27193df..e1efb7602 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -49,52 +49,35 @@ class CustomSamplerOptions: def resolve_sampler_options( - num_neighbors: Optional[Union[list[int], dict[EdgeType, list[int]]]], + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], sampler_options: Optional[SamplerOptions], -) -> tuple[Union[list[int], dict[EdgeType, list[int]]], SamplerOptions]: - """Resolve num_neighbors and sampler_options from user-provided values. +) -> SamplerOptions: + """Resolve sampler_options from user-provided values. - Handles backwards compatibility: callers can provide ``num_neighbors`` - alone (old API), ``sampler_options`` alone (new API), or both (validated - for consistency). - - ``num_neighbors`` is always required — either directly or via - ``KHopNeighborSamplerOptions.num_neighbors``. This follows the PyG - convention where neighbor fanout must be explicitly specified. 
+ If ``sampler_options`` is ``None``, wraps ``num_neighbors`` in a + ``KHopNeighborSamplerOptions``. If ``KHopNeighborSamplerOptions`` is + provided, validates that its ``num_neighbors`` matches the explicit value. Args: - num_neighbors: Fanout per hop, or None. + num_neighbors: Fanout per hop (always required). sampler_options: Sampler configuration, or None. Returns: - A tuple of (resolved_num_neighbors, resolved_sampler_options). + The resolved SamplerOptions. Raises: - ValueError: If num_neighbors cannot be determined, or if both provide - conflicting num_neighbors values. + ValueError: If ``KHopNeighborSamplerOptions.num_neighbors`` conflicts + with the explicit ``num_neighbors``. """ if sampler_options is None: - if num_neighbors is None: - raise ValueError( - "num_neighbors must be provided, either directly or via " - "KHopNeighborSamplerOptions." - ) - return num_neighbors, KHopNeighborSamplerOptions(num_neighbors) + return KHopNeighborSamplerOptions(num_neighbors) if isinstance(sampler_options, KHopNeighborSamplerOptions): - if num_neighbors is None: - return sampler_options.num_neighbors, sampler_options if num_neighbors != sampler_options.num_neighbors: raise ValueError( f"num_neighbors ({num_neighbors}) does not match " f"sampler_options.num_neighbors ({sampler_options.num_neighbors}). " f"Provide one or the other, not both with different values." ) - return num_neighbors, sampler_options - - assert isinstance(sampler_options, CustomSamplerOptions) - if num_neighbors is None: - raise ValueError( - "num_neighbors must be provided when using CustomSamplerOptions." 
- ) - return num_neighbors, sampler_options + + return sampler_options diff --git a/tests/unit/distributed/sampler_options_test.py b/tests/unit/distributed/sampler_options_test.py index 4a7c5c806..c8cde12c2 100644 --- a/tests/unit/distributed/sampler_options_test.py +++ b/tests/unit/distributed/sampler_options_test.py @@ -7,32 +7,14 @@ class ResolveSamplerOptionsTest(TestCase): - def test_both_none_raises(self): - with self.assertRaises(ValueError): - resolve_sampler_options(num_neighbors=None, sampler_options=None) - def test_num_neighbors_only(self): - num_neighbors, options = resolve_sampler_options( - num_neighbors=[2, 2], sampler_options=None - ) - self.assertEqual(num_neighbors, [2, 2]) + options = resolve_sampler_options(num_neighbors=[2, 2], sampler_options=None) assert isinstance(options, KHopNeighborSamplerOptions) self.assertEqual(options.num_neighbors, [2, 2]) - def test_khop_options_only(self): - opts = KHopNeighborSamplerOptions(num_neighbors=[3, 3]) - num_neighbors, options = resolve_sampler_options( - num_neighbors=None, sampler_options=opts - ) - self.assertEqual(num_neighbors, [3, 3]) - self.assertIs(options, opts) - def test_khop_options_matching_num_neighbors(self): opts = KHopNeighborSamplerOptions(num_neighbors=[2, 2]) - num_neighbors, options = resolve_sampler_options( - num_neighbors=[2, 2], sampler_options=opts - ) - self.assertEqual(num_neighbors, [2, 2]) + options = resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) self.assertIs(options, opts) def test_khop_options_conflicting_num_neighbors_raises(self): @@ -40,15 +22,7 @@ def test_khop_options_conflicting_num_neighbors_raises(self): with self.assertRaises(ValueError): resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) - def test_custom_options_without_num_neighbors_raises(self): - opts = CustomSamplerOptions(class_path="my.module.MySampler") - with self.assertRaises(ValueError): - resolve_sampler_options(num_neighbors=None, sampler_options=opts) - def 
test_custom_options_with_num_neighbors(self): opts = CustomSamplerOptions(class_path="my.module.MySampler") - num_neighbors, options = resolve_sampler_options( - num_neighbors=[2, 2], sampler_options=opts - ) - self.assertEqual(num_neighbors, [2, 2]) + options = resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) self.assertIs(options, opts) From b6a359792b1e87c77786acafdfa59c8c0e3e04cc Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 02:52:19 +0000 Subject: [PATCH 21/46] Fix stale error message in resolve_sampler_options --- gigl/distributed/sampler_options.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index e1efb7602..9f6f41ba9 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -76,8 +76,7 @@ def resolve_sampler_options( if num_neighbors != sampler_options.num_neighbors: raise ValueError( f"num_neighbors ({num_neighbors}) does not match " - f"sampler_options.num_neighbors ({sampler_options.num_neighbors}). " - f"Provide one or the other, not both with different values." + f"sampler_options.num_neighbors ({sampler_options.num_neighbors})." 
) return sampler_options From 2e49c514e2704e1abb7372004cfd07c52b5900db Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 19:44:55 +0000 Subject: [PATCH 22/46] Remove CustomSamplerOptions, keep SamplerOptions plumbing for future extension --- gigl/distributed/dist_sampling_producer.py | 13 +--- gigl/distributed/sampler_options.py | 42 +++--------- .../dist_ablp_neighborloader_test.py | 66 ------------------- .../distributed_neighborloader_test.py | 44 ------------- .../unit/distributed/sampler_options_test.py | 6 -- 5 files changed, 12 insertions(+), 159 deletions(-) diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index b76a866fc..f1ff751b7 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -4,7 +4,6 @@ # neighbor sampling and ABLP) instead of GLT's DistNeighborSampler. import datetime -import importlib import queue from threading import Barrier from typing import Optional, Union, cast @@ -36,7 +35,7 @@ from torch.utils.data.dataset import Dataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler -from gigl.distributed.sampler_options import CustomSamplerOptions, SamplerOptions +from gigl.distributed.sampler_options import SamplerOptions def _sampling_worker_loop( @@ -91,14 +90,7 @@ def _sampling_worker_loop( seed_everything(sampling_config.seed) # Resolve sampler class from options - extra_kwargs: dict[str, object] = {} - if isinstance(sampler_options, CustomSamplerOptions): - module_path, class_name = sampler_options.class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - sampler_cls = getattr(module, class_name) - extra_kwargs = sampler_options.class_args - else: - sampler_cls = DistNeighborSampler + sampler_cls = DistNeighborSampler dist_sampler = sampler_cls( data, @@ -113,7 +105,6 @@ def _sampling_worker_loop( worker_options.worker_concurrency, current_device, seed=sampling_config.seed, - 
**extra_kwargs, ) dist_sampler.start_loop() diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 9f6f41ba9..521df59b2 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -1,15 +1,13 @@ """Sampler option types for configuring which sampler class to use in distributed loading. -Provides two options: -- ``KHopNeighborSamplerOptions``: Uses GiGL's built-in ``DistNeighborSampler``. -- ``CustomSamplerOptions``: Dynamically imports and uses a user-provided sampler class. +Provides ``KHopNeighborSamplerOptions`` for using GiGL's built-in ``DistNeighborSampler``. -Both are frozen dataclasses so they are safe to pickle across RPC boundaries +Frozen dataclass so it is safe to pickle across RPC boundaries (required for Graph Store mode). """ -from dataclasses import dataclass, field -from typing import Any, Optional, Union +from dataclasses import dataclass +from typing import Optional, Union from graphlearn_torch.typing import EdgeType @@ -26,26 +24,7 @@ class KHopNeighborSamplerOptions: num_neighbors: Union[list[int], dict[EdgeType, list[int]]] -@dataclass(frozen=True) -class CustomSamplerOptions: - """Custom sampler options that dynamically import a user-provided sampler class. - - The class at ``class_path`` must conform to the same interface as - ``DistNeighborSampler`` (extend ``GLTDistNeighborSampler`` or at minimum - support ``start_loop``, ``sample_from_nodes``, etc.). - - Attributes: - class_path: Fully qualified Python import path, e.g. - ``"my.module.MySampler"``. - class_args: Additional keyword arguments passed to the sampler - constructor (on top of the standard GLT arguments). 
- """ - - class_path: str - class_args: dict[str, Any] = field(default_factory=dict) - - -SamplerOptions = Union[KHopNeighborSamplerOptions, CustomSamplerOptions] +SamplerOptions = KHopNeighborSamplerOptions def resolve_sampler_options( @@ -72,11 +51,10 @@ def resolve_sampler_options( if sampler_options is None: return KHopNeighborSamplerOptions(num_neighbors) - if isinstance(sampler_options, KHopNeighborSamplerOptions): - if num_neighbors != sampler_options.num_neighbors: - raise ValueError( - f"num_neighbors ({num_neighbors}) does not match " - f"sampler_options.num_neighbors ({sampler_options.num_neighbors})." - ) + if num_neighbors != sampler_options.num_neighbors: + raise ValueError( + f"num_neighbors ({num_neighbors}) does not match " + f"sampler_options.num_neighbors ({sampler_options.num_neighbors})." + ) return sampler_options diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 437b7ce37..fd14aadd5 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -15,7 +15,6 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner -from gigl.distributed.sampler_options import CustomSamplerOptions, SamplerOptions from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, ) @@ -308,30 +307,6 @@ def _run_toy_heterogeneous_ablp( shutdown_rpc() -def _run_distributed_ablp_loader_with_sampler_options( - _, - dataset: DistDataset, - expected_data_count: int, - sampler_options: SamplerOptions, -): - create_test_process_group() - loader = DistABLPLoader( - dataset=dataset, - num_neighbors=[2, 2], - input_nodes=to_homogeneous(dataset.train_node_ids), - pin_memory_device=torch.device("cpu"), - sampler_options=sampler_options, - ) - 
count = 0 - for datum in loader: - assert isinstance(datum, Data) - assert hasattr(datum, "y_positive") - count += 1 - assert count == expected_data_count - - shutdown_rpc() - - def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( _, input_nodes: tuple[NodeType, torch.Tensor], @@ -943,47 +918,6 @@ def test_ablp_dataloder_multiple_supervision_edge_types( ), ), - def test_ablp_loader_with_custom_sampler_options(self): - create_test_process_group() - cora_supervised_info = get_mocked_dataset_artifact_metadata()[ - CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name - ] - - gbml_config_pb_wrapper = ( - GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( - gbml_config_uri=cora_supervised_info.frozen_gbml_config_uri - ) - ) - - serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( - preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, - graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, - tfrecord_uri_pattern=".*.tfrecord(.gz)?$", - ) - - splitter = DistNodeAnchorLinkSplitter( - sampling_direction="in", should_convert_labels_to_edges=True - ) - - dataset = build_dataset( - serialized_graph_metadata=serialized_graph_metadata, - sample_edge_direction="in", - splitter=splitter, - ) - - assert dataset.train_node_ids is not None, "Train node ids must exist." 
- - mp.spawn( - fn=_run_distributed_ablp_loader_with_sampler_options, - args=( - dataset, - to_homogeneous(dataset.train_node_ids).numel(), - CustomSamplerOptions( - class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", - ), - ), - ) - @parameterized.expand( [ param( diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index 2a85a810a..30ef1d910 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -11,7 +11,6 @@ from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.distributed_neighborloader import DistNeighborLoader -from gigl.distributed.sampler_options import CustomSamplerOptions, SamplerOptions from gigl.distributed.utils import get_free_port from gigl.distributed.utils.serialized_graph_metadata_translator import ( convert_pb_to_serialized_graph_metadata, @@ -276,30 +275,6 @@ def _run_distributed_neighbor_loader_with_node_labels_heterogeneous( shutdown_rpc() -def _run_distributed_neighbor_loader_with_sampler_options( - _, - dataset: DistDataset, - expected_data_count: int, - sampler_options: SamplerOptions, -): - create_test_process_group() - loader = DistNeighborLoader( - dataset=dataset, - num_neighbors=[2, 2], - pin_memory_device=torch.device("cpu"), - sampler_options=sampler_options, - ) - - count = 0 - for datum in loader: - assert isinstance(datum, Data) - count += 1 - - assert count == expected_data_count - - shutdown_rpc() - - def _run_cora_supervised_node_classification( _, dataset: DistDataset, @@ -602,25 +577,6 @@ def test_isolated_homogeneous_neighbor_loader( args=(dataset, 18), ) - def test_distributed_neighbor_loader_with_custom_sampler_options(self): - expected_data_count = 2708 - dataset = run_distributed_dataset( - rank=0, - world_size=self._world_size, - 
mocked_dataset_info=CORA_NODE_ANCHOR_MOCKED_DATASET_INFO, - _port=get_free_port(), - ) - mp.spawn( - fn=_run_distributed_neighbor_loader_with_sampler_options, - args=( - dataset, - expected_data_count, - CustomSamplerOptions( - class_path="gigl.distributed.dist_neighbor_sampler.DistNeighborSampler", - ), - ), - ) - @parameterized.expand( [ param( diff --git a/tests/unit/distributed/sampler_options_test.py b/tests/unit/distributed/sampler_options_test.py index c8cde12c2..9c4ae172a 100644 --- a/tests/unit/distributed/sampler_options_test.py +++ b/tests/unit/distributed/sampler_options_test.py @@ -1,5 +1,4 @@ from gigl.distributed.sampler_options import ( - CustomSamplerOptions, KHopNeighborSamplerOptions, resolve_sampler_options, ) @@ -21,8 +20,3 @@ def test_khop_options_conflicting_num_neighbors_raises(self): opts = KHopNeighborSamplerOptions(num_neighbors=[3, 3]) with self.assertRaises(ValueError): resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) - - def test_custom_options_with_num_neighbors(self): - opts = CustomSamplerOptions(class_path="my.module.MySampler") - options = resolve_sampler_options(num_neighbors=[2, 2], sampler_options=opts) - self.assertIs(options, opts) From 8a4f3ef2aa615b0ea09d5393725241222fe0e865 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 6 Mar 2026 19:59:10 +0000 Subject: [PATCH 23/46] Use kwargs on DistSamplingProducer calls; isinstance dispatch for sampler_cls; make sampler_options required on dist_server --- gigl/distributed/dist_ablp_neighborloader.py | 12 ++++++------ gigl/distributed/dist_sampling_producer.py | 13 +++++++++---- .../distributed/distributed_neighborloader.py | 12 ++++++------ gigl/distributed/graph_store/dist_server.py | 19 +++++++++---------- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index ac7410dcb..0bc482225 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ 
b/gigl/distributed/dist_ablp_neighborloader.py @@ -341,12 +341,12 @@ def __init__( producer: Union[ DistSamplingProducer, Callable[..., int] ] = DistSamplingProducer( - dataset, - sampler_input, - sampling_config, - worker_options, - channel, - sampler_options, + data=dataset, + sampler_input=sampler_input, + sampling_config=sampling_config, + worker_options=worker_options, + channel=channel, + sampler_options=sampler_options, ) else: producer = DistServer.create_sampling_producer diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index f1ff751b7..7b07ff006 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -35,7 +35,7 @@ from torch.utils.data.dataset import Dataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler -from gigl.distributed.sampler_options import SamplerOptions +from gigl.distributed.sampler_options import KHopNeighborSamplerOptions, SamplerOptions def _sampling_worker_loop( @@ -49,7 +49,7 @@ def _sampling_worker_loop( task_queue: mp.Queue, sampling_completed_worker_count, # mp.Value mp_barrier: Barrier, - sampler_options: Optional[SamplerOptions] = None, + sampler_options: SamplerOptions, ): dist_sampler = None try: @@ -90,7 +90,12 @@ def _sampling_worker_loop( seed_everything(sampling_config.seed) # Resolve sampler class from options - sampler_cls = DistNeighborSampler + if isinstance(sampler_options, KHopNeighborSamplerOptions): + sampler_cls = DistNeighborSampler + else: + raise NotImplementedError( + f"Unsupported sampler options type: {type(sampler_options)}" + ) dist_sampler = sampler_cls( data, @@ -181,7 +186,7 @@ def __init__( sampling_config: SamplingConfig, worker_options: MpDistSamplingWorkerOptions, channel: ChannelBase, - sampler_options: Optional[SamplerOptions] = None, + sampler_options: SamplerOptions, ): super().__init__(data, sampler_input, sampling_config, worker_options, channel) self._sampler_options = 
sampler_options diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 33ffc4fe4..d6de64ab9 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -249,12 +249,12 @@ def __init__( producer: Union[ DistSamplingProducer, Callable[..., int] ] = DistSamplingProducer( - dataset, - input_data, - sampling_config, - worker_options, - channel, - sampler_options, + data=dataset, + sampler_input=input_data, + sampling_config=sampling_config, + worker_options=worker_options, + channel=channel, + sampler_options=sampler_options, ) else: producer = GiglDistServer.create_sampling_producer diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index fdfffd305..0530e288b 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -432,7 +432,7 @@ def create_sampling_producer( ], sampling_config: SamplingConfig, worker_options: RemoteDistSamplingWorkerOptions, - sampler_options: Optional[SamplerOptions] = None, + sampler_options: SamplerOptions, ) -> int: """Create and initialize an instance of ``DistSamplingProducer`` with a group of subprocesses for distributed sampling. @@ -446,9 +446,8 @@ def create_sampling_producer( sampling_config (SamplingConfig): Configuration of sampling meta info. worker_options (RemoteDistSamplingWorkerOptions): Options for launching remote sampling workers by this server. - sampler_options (Optional[SamplerOptions]): Controls which sampler class - is instantiated. If ``None``, defaults to the built-in - ``DistNeighborSampler``. + sampler_options (SamplerOptions): Controls which sampler class + is instantiated. Returns: int: A unique id of created sampling producer on this server. 
@@ -472,12 +471,12 @@ def create_sampling_producer( worker_options.buffer_capacity, worker_options.buffer_size ) producer = DistSamplingProducer( - self.dataset, - sampler_input, - sampling_config, - worker_options, - buffer, - sampler_options, + data=self.dataset, + sampler_input=sampler_input, + sampling_config=sampling_config, + worker_options=worker_options, + channel=buffer, + sampler_options=sampler_options, ) producer.init() self._producer_pool[producer_id] = producer From f2be444f1a6785601b65c47a825a8538842adeed Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 9 Mar 2026 17:40:30 +0000 Subject: [PATCH 24/46] Update --- gigl/distributed/base_dist_loader.py | 2 +- gigl/distributed/sampler_options.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 6db2385e0..980f064ed 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -208,7 +208,7 @@ def __init__( device: torch.device, runtime: DistributedRuntimeInfo, producer: Union[DistSamplingProducer, Callable[..., int]], - sampler_options: Optional[SamplerOptions] = None, + sampler_options: SamplerOptions, process_start_gap_seconds: float = 60.0, ): # Set right away so __del__ can clean up if we throw during init. diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 521df59b2..f7bbb2e4b 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -11,6 +11,10 @@ from graphlearn_torch.typing import EdgeType +from gigl.common.logger import Logger + +logger = Logger() + @dataclass(frozen=True) class KHopNeighborSamplerOptions: @@ -56,5 +60,6 @@ def resolve_sampler_options( f"num_neighbors ({num_neighbors}) does not match " f"sampler_options.num_neighbors ({sampler_options.num_neighbors})." 
) + logger.info(f"Using sampler options: {sampler_options}") return sampler_options From d47b1b9111bcaf1fd8644073ef7305b5cb6eb628 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 9 Mar 2026 18:39:48 +0000 Subject: [PATCH 25/46] Add PPRSamplerOptions and DistPPRNeighborSampler with ABLP support and dataset degree tensors --- gigl/distributed/dist_ppr_sampler.py | 532 +++++++++++++++++++++ gigl/distributed/dist_sampling_producer.py | 21 +- gigl/distributed/sampler_options.py | 51 +- 3 files changed, 595 insertions(+), 9 deletions(-) create mode 100644 gigl/distributed/dist_ppr_sampler.py diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py new file mode 100644 index 000000000..97ff99f15 --- /dev/null +++ b/gigl/distributed/dist_ppr_sampler.py @@ -0,0 +1,532 @@ +import heapq +from collections import defaultdict +from typing import Optional, Union + +import torch +from graphlearn_torch.channel import SampleMessage +from graphlearn_torch.sampler import ( + HeteroSamplerOutput, + NeighborOutput, + NodeSamplerInput, + SamplerOutput, +) +from graphlearn_torch.typing import EdgeType, NodeType + +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler + +_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" +_PPR_HOMOGENEOUS_EDGE_TYPE = ( + _PPR_HOMOGENEOUS_NODE_TYPE, + "to", + _PPR_HOMOGENEOUS_NODE_TYPE, +) + + +class DistPPRNeighborSampler(DistNeighborSampler): + """ + Personalized PageRank (PPR) based neighbor sampler that inherits from GLT DistNeighborSampler. + + Instead of uniform random sampling, this sampler uses PPR scores to select the most + relevant neighbors for each seed node. The PPR algorithm approximates the stationary + distribution of a random walk with restart probability alpha. + + This sampler supports both homogeneous and heterogeneous graphs. 
For heterogeneous graphs, + the PPR algorithm traverses across all edge types, switching edge types based on the + current node type and the configured edge direction. + + Degree tensors are sourced automatically from the dataset at initialization time. + + Args: + alpha: Restart probability (teleport probability back to seed). Higher values + keep samples closer to seeds. Typical values: 0.15-0.25. + eps: Convergence threshold. Smaller values give more accurate PPR scores + but require more computation. Typical values: 1e-4 to 1e-6. + max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. + default_node_id: Node ID to use when fewer than max_ppr_nodes are found. + default_weight: Weight to assign to padding nodes. + num_nbrs_per_hop: Maximum number of neighbors to fetch per hop. + """ + + def __init__( + self, + *args, + alpha: float = 0.5, + eps: float = 1e-4, + max_ppr_nodes: int = 50, + default_node_id: int = -1, + default_weight: float = 0.0, + num_nbrs_per_hop: int = 100000, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._alpha = alpha + self._eps = eps + self._max_ppr_nodes = max_ppr_nodes + self._default_node_id = default_node_id + self._default_weight = default_weight + self._alpha_eps = alpha * eps + self._num_nbrs_per_hop = num_nbrs_per_hop + + assert isinstance( + self.data, DistDataset + ), "DistPPRNeighborSampler requires a GiGL DistDataset to access degree tensors." + self._degree_tensors = self.data.degree_tensor + + # Build mapping from node type to edge types that can be traversed from that node type. 
+ self._node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict( + list + ) + + if hasattr(self, "edge_types") and self.edge_types is not None: + self._is_homogeneous = False + # Heterogeneous case: map each node type to its outgoing/incoming edge types + for etype in self.edge_types: + if self.edge_dir == "in": + # For incoming edges, we traverse FROM the destination node type + anchor_type = etype[-1] + else: # "out" + # For outgoing edges, we traverse FROM the source node type + anchor_type = etype[0] + + self._node_type_to_edge_types[anchor_type].append(etype) + else: + self._node_type_to_edge_types[_PPR_HOMOGENEOUS_NODE_TYPE] = [ + _PPR_HOMOGENEOUS_EDGE_TYPE + ] + self._is_homogeneous = True + + def _get_degree_from_tensor(self, node_id: int, edge_type: EdgeType) -> int: + """ + Look up the TRUE degree of a node for a specific edge type from in-memory tensors. + + This returns the actual node degree (not capped), which is mathematically correct + for PPR algorithm calculations. + + Args: + node_id: The ID of the node to look up. + edge_type: The edge type to get the degree for. + + Returns: + The true degree of the node for the given edge type. 
+ """ + if self._is_homogeneous: + # For homogeneous graphs, degree_tensors is a single tensor + assert isinstance(self._degree_tensors, torch.Tensor) + if node_id >= len(self._degree_tensors): + return 0 + return int(self._degree_tensors[node_id].item()) + else: + # For heterogeneous graphs, degree_tensors is a dict keyed by edge type + assert isinstance(self._degree_tensors, dict) + if edge_type not in self._degree_tensors: + return 0 + degree_tensor = self._degree_tensors[edge_type] + if node_id >= len(degree_tensor): + return 0 + return int(degree_tensor[node_id].item()) + + def _get_neighbor_type(self, edge_type: EdgeType) -> NodeType: + """Get the node type of neighbors reached via an edge type.""" + return edge_type[0] if self.edge_dir == "in" else edge_type[-1] + + async def _get_neighbors_for_nodes( + self, + nodes: torch.Tensor, + edge_type: EdgeType, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fetch neighbors for a batch of nodes. + + Returns: + tuple of (neighbors, neighbor_counts) where neighbors is a flattened tensor + and neighbor_counts[i] gives the number of neighbors for nodes[i]. + """ + # Use the underlying sampling infrastructure to get all neighbors + # We request a large number to effectively get all neighbors + output: NeighborOutput = await self._sample_one_hop( + srcs=nodes, + num_nbr=self._num_nbrs_per_hop, + etype=edge_type if edge_type is not _PPR_HOMOGENEOUS_EDGE_TYPE else None, + ) + return output.nbr, output.nbr_num + + async def _batch_fetch_neighbors( + self, + nodes_by_edge_type: dict[EdgeType, set[int]], + neighbor_target: dict[tuple[int, EdgeType], list[int]], + device: torch.device, + ) -> int: + """ + Batch fetch neighbors for nodes grouped by edge type. + + Fetches neighbors for all nodes in nodes_by_edge_type, populating + neighbor_target with neighbor lists. Degrees are looked up separately + from the in-memory degree_tensors. 
+ + Args: + nodes_by_edge_type: Dict mapping edge type to set of node IDs to fetch + neighbor_target: Dict to populate with (node_id, edge_type) -> neighbor list + device: Torch device for tensor creation + + Returns: + Number of neighbor lookup calls made + """ + num_lookups = 0 + for etype, node_ids in nodes_by_edge_type.items(): + if not node_ids: + continue + nodes_list = list(node_ids) + lookup_tensor = torch.tensor(nodes_list, dtype=torch.long, device=device) + + neighbors, neighbor_counts = await self._get_neighbors_for_nodes( + lookup_tensor, + etype, + ) + num_lookups += 1 + + neighbors_list = neighbors.tolist() + counts_list = neighbor_counts.tolist() + del neighbors, neighbor_counts + + # neighbors_list is a flat concatenation of all neighbors for all looked-up nodes. + # We use offset to slice out each node's neighbors: node i's neighbors are at + # neighbors_list[offset : offset + count], then we advance offset by count. + offset = 0 + for node_id, count in zip(nodes_list, counts_list): + cache_key = (node_id, etype) + neighbor_target[cache_key] = neighbors_list[offset : offset + count] + offset += count + + return num_lookups + + async def _compute_ppr_scores( + self, + seed_nodes: torch.Tensor, + seed_node_type: Optional[NodeType] = None, + ) -> tuple[ + Union[torch.Tensor, dict[NodeType, torch.Tensor]], + Union[torch.Tensor, dict[NodeType, torch.Tensor]], + ]: + """ + Compute PPR scores for seed nodes using the push-based approximation algorithm. + + This implements the Forward Push algorithm (Andersen et al., 2006) which + iteratively pushes probability mass from nodes with high residual to their + neighbors. For heterogeneous graphs, the algorithm traverses across all + edge types, switching based on the current node type. + + Algorithm Overview (each iteration of the main loop): + 1. Fetch neighbors: Drain all nodes from the queue, group by edge type, + and perform a batched neighbor lookup to populate neighbor/degree caches. + 2. 
Push residual: For each queued node, add its residual to its PPR score, + reset its residual to zero, then distribute (1-alpha) * residual to + all neighbors proportionally by degree. + 3. Batch fetch degrees: Group all neighbors that received residual and + perform a batched lookup to get their degrees (needed for threshold check). + 4. Re-queue high-residual nodes: For each neighbor that received residual, + check if residual >= alpha * eps * total_degree. If so, add to queue + for processing in the next iteration. + + Args: + seed_nodes: Tensor of seed node IDs [batch_size] + seed_node_type: Node type of seed nodes. Should be None for homogeneous graphs. + + Returns: + tuple of (neighbor_ids_by_type, ppr_weights_by_type) where: + - neighbor_ids_by_type: Union[torch.Tensor, dict mapping node type -> [batch_size, max_ppr_nodes]] + - ppr_weights_by_type: Union[torch.Tensor, dict mapping node type -> [batch_size, max_ppr_nodes]] + """ + if seed_node_type is None: + seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE + device = seed_nodes.device + batch_size = seed_nodes.size(0) + + # PPR scores: p[i][(node_id, node_type)] = score + p: list[dict[tuple[int, NodeType], float]] = [ + defaultdict(float) for _ in range(batch_size) + ] + # Residuals: r[i][(node_id, node_type)] = residual + r: list[dict[tuple[int, NodeType], float]] = [ + defaultdict(float) for _ in range(batch_size) + ] + + # Queue stores (node_id, node_type) tuples + q: list[set[tuple[int, NodeType]]] = [set() for _ in range(batch_size)] + + seed_list = seed_nodes.tolist() + + # Initialize residuals: r[i][(seed, seed_type)] = alpha for each seed + for i, seed in enumerate(seed_list): + r[i][(seed, seed_node_type)] = self._alpha + q[i].add((seed, seed_node_type)) + + # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type + neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} + + num_nodes_in_queue = batch_size + + while num_nodes_in_queue > 0: + # Drain all nodes 
from all queues and group by edge type for batched lookups + nodes_to_process: list[set[tuple[int, NodeType]]] = [ + set() for _ in range(batch_size) + ] + nodes_by_edge_type: dict[EdgeType, set[int]] = defaultdict(set) + + for i in range(batch_size): + if q[i]: + nodes_to_process[i] = q[i] + q[i] = set() + num_nodes_in_queue -= len(nodes_to_process[i]) + + # Group nodes by edge type for batched lookups + for node_id, node_type in nodes_to_process[i]: + edge_types_for_node = self._node_type_to_edge_types[node_type] + for etype in edge_types_for_node: + cache_key = (node_id, etype) + if cache_key not in neighbor_cache: + nodes_by_edge_type[etype].add(node_id) + + # Batch fetch neighbors per edge type + await self._batch_fetch_neighbors( + nodes_by_edge_type, neighbor_cache, device + ) + + # Process nodes and push residual + for i in range(batch_size): + for u_node, u_type in nodes_to_process[i]: + key_u = (u_node, u_type) + res_u = r[i].get(key_u, 0.0) + + # Push to PPR score and reset residual + p[i][key_u] += res_u + r[i][key_u] = 0.0 + + # For each edge type from this node type, push residual to neighbors + edge_types_for_node = self._node_type_to_edge_types[u_type] + + # Calculate total degree across all edge types for proper probability distribution + # Degrees are looked up directly from in-memory tensors + total_degree = sum( + self._get_degree_from_tensor(u_node, etype) + for etype in edge_types_for_node + ) + + if total_degree == 0: + continue + + # Push residual proportionally based on degree per edge type + for etype in edge_types_for_node: + cache_key = (u_node, etype) + neighbor_list = neighbor_cache[cache_key] + neighbor_count = self._get_degree_from_tensor(u_node, etype) + + if neighbor_count == 0: + continue + + # Determine the type of the neighbors + v_type = self._get_neighbor_type(etype) + + # Distribute residual to neighbors, weighted by edge type contribution + push_value = (1 - self._alpha) * res_u / total_degree + + for v_node in 
neighbor_list: + key_v = (v_node, v_type) + r[i][key_v] += push_value + + # Add high-residual neighbors to queue + for i in range(batch_size): + for u_node, u_type in nodes_to_process[i]: + edge_types_for_node = self._node_type_to_edge_types.get(u_type, []) + for etype in edge_types_for_node: + cache_key = (u_node, etype) + neighbor_list = neighbor_cache[cache_key] + v_type = self._get_neighbor_type(etype) + + for v_node in neighbor_list: + key_v = (v_node, v_type) + + if key_v in q[i]: + continue + + res_v = r[i].get(key_v, 0.0) + if res_v == 0.0: + continue + + # Sum degrees across all edge types from v_type for threshold check + edge_types_for_v = self._node_type_to_edge_types.get( + v_type, [] + ) + total_v_degree = sum( + self._get_degree_from_tensor(v_node, v_etype) + for v_etype in edge_types_for_v + ) + + if res_v >= self._alpha_eps * total_v_degree: + q[i].add(key_v) + num_nodes_in_queue += 1 + + # Extract top-k nodes by PPR score, grouped by node type + all_node_types: set[NodeType] = set() + for i in range(batch_size): + for node_id, node_type in p[i].keys(): + all_node_types.add(node_type) + + out_neighbor_ids_dict: dict[NodeType, torch.Tensor] = {} + out_weights_dict: dict[NodeType, torch.Tensor] = {} + + for ntype in all_node_types: + ntype_neighbor_ids = torch.full( + (batch_size, self._max_ppr_nodes), + self._default_node_id, + dtype=torch.long, + device=device, + ) + ntype_weights = torch.full( + (batch_size, self._max_ppr_nodes), + self._default_weight, + dtype=torch.float, + device=device, + ) + + for i in range(batch_size): + # Filter to nodes of this type + type_scores = { + node_id: score + for (node_id, node_type), score in p[i].items() + if node_type == ntype + } + top_k = heapq.nlargest( + self._max_ppr_nodes, type_scores.items(), key=lambda x: x[1] + ) + + for j, (node_id, weight) in enumerate(top_k): + ntype_neighbor_ids[i, j] = node_id + ntype_weights[i, j] = weight + + out_neighbor_ids_dict[ntype] = ntype_neighbor_ids + 
out_weights_dict[ntype] = ntype_weights + + out_neighbor_ids: Union[torch.Tensor, dict[NodeType, torch.Tensor]] + out_weights: Union[torch.Tensor, dict[NodeType, torch.Tensor]] + if self._is_homogeneous: + assert ( + len(all_node_types) == 1 + and _PPR_HOMOGENEOUS_NODE_TYPE in all_node_types + ) + out_neighbor_ids = out_neighbor_ids_dict[_PPR_HOMOGENEOUS_NODE_TYPE] + out_weights = out_weights_dict[_PPR_HOMOGENEOUS_NODE_TYPE] + else: + out_neighbor_ids = out_neighbor_ids_dict + out_weights = out_weights_dict + + return out_neighbor_ids, out_weights + + async def _sample_from_nodes( + self, + inputs: NodeSamplerInput, + ) -> Optional[SampleMessage]: + """ + Override the base sampling method to use PPR-based neighbor selection. + + Supports both NodeSamplerInput and ABLPNodeSamplerInput. For ABLP, PPR + scores are computed from both anchor and supervision nodes, so the sampled + subgraph includes neighbors relevant to all seed types. + + For heterogeneous graphs, PPR traverses across all edge types, switching + edge types based on the current node type. PPR weights are stored in + metadata keyed as ``ppr_weights_{seed_type}_{neighbor_type}`` and + ``ppr_neighbor_ids_{seed_type}_{neighbor_type}``. 
+ """ + sample_loop_inputs = self._prepare_sample_loop_inputs(inputs) + input_seeds = inputs.node.to(self.device) + input_type = inputs.input_type + is_hetero = self.dist_graph.data_cls == "hetero" + metadata = sample_loop_inputs.metadata + nodes_to_sample = sample_loop_inputs.nodes_to_sample + + if is_hetero: + assert isinstance(nodes_to_sample, dict) + assert input_type is not None + + inducer = self._acquire_inducer() + out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = defaultdict(list) + + # All seed nodes (anchors + supervision) are included in the output + for seed_type, seed_nodes in nodes_to_sample.items(): + out_nodes_hetero[seed_type].append(seed_nodes) + + # Compute PPR separately for each seed type and store per-(seed_type, neighbor_type) + for seed_type, seed_nodes in nodes_to_sample.items(): + nbr_ids_by_type, nbr_weights_by_type = await self._compute_ppr_scores( + seed_nodes, seed_type + ) + assert isinstance(nbr_ids_by_type, dict) + assert isinstance(nbr_weights_by_type, dict) + + for ntype, neighbor_ids in nbr_ids_by_type.items(): + weights = nbr_weights_by_type[ntype] + flat_neighbors = neighbor_ids.flatten() + valid_mask = flat_neighbors != self._default_node_id + valid_neighbors = flat_neighbors[valid_mask] + + if valid_neighbors.numel() > 0: + out_nodes_hetero[ntype].append(valid_neighbors.unique()) + + metadata[f"ppr_weights_{seed_type}_{ntype}"] = weights + metadata[f"ppr_neighbor_ids_{seed_type}_{ntype}"] = neighbor_ids + + node_dict = { + ntype: torch.cat(nodes).unique() + for ntype, nodes in out_nodes_hetero.items() + if nodes + } + + sample_output = HeteroSamplerOutput( + node=node_dict, + row={}, # PPR doesn't maintain edge structure + col={}, + edge={}, # Empty dict — GLT SampleQueue requires all values to be tensors + batch={input_type: input_seeds}, + num_sampled_nodes={ + ntype: [nodes.size(0)] for ntype, nodes in node_dict.items() + }, + num_sampled_edges={}, + input_type=input_type, + metadata=metadata, + ) + + 
self.inducer_pool.put(inducer) + + else: + assert isinstance(nodes_to_sample, torch.Tensor) + + homo_neighbor_ids, homo_weights = await self._compute_ppr_scores( + nodes_to_sample, None + ) + assert isinstance(homo_neighbor_ids, torch.Tensor) + assert isinstance(homo_weights, torch.Tensor) + + flat_neighbors = homo_neighbor_ids.flatten() + valid_mask = flat_neighbors != self._default_node_id + valid_neighbors = flat_neighbors[valid_mask].unique() + + all_nodes = torch.cat([nodes_to_sample, valid_neighbors]) + + metadata["ppr_weights"] = homo_weights + metadata["ppr_neighbor_ids"] = homo_neighbor_ids + + sample_output = SamplerOutput( + node=all_nodes, + row=torch.tensor([], dtype=torch.long, device=self.device), + col=torch.tensor([], dtype=torch.long, device=self.device), + edge=torch.tensor( + [], dtype=torch.long, device=self.device + ), # Empty tensor — GLT SampleQueue requires all values to be tensors + batch=input_seeds, + num_sampled_nodes=[input_seeds.size(0), valid_neighbors.size(0)], + num_sampled_edges=[], + metadata=metadata, + ) + + return sample_output diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 7b07ff006..f2e085736 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -35,7 +35,12 @@ from torch.utils.data.dataset import Dataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler -from gigl.distributed.sampler_options import KHopNeighborSamplerOptions, SamplerOptions +from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler +from gigl.distributed.sampler_options import ( + KHopNeighborSamplerOptions, + PPRSamplerOptions, + SamplerOptions, +) def _sampling_worker_loop( @@ -89,9 +94,20 @@ def _sampling_worker_loop( if sampling_config.seed is not None: seed_everything(sampling_config.seed) - # Resolve sampler class from options + # Resolve sampler class and any extra kwargs from options + 
extra_sampler_kwargs: dict[str, object] = {} if isinstance(sampler_options, KHopNeighborSamplerOptions): sampler_cls = DistNeighborSampler + elif isinstance(sampler_options, PPRSamplerOptions): + sampler_cls = DistPPRNeighborSampler + extra_sampler_kwargs = { + "alpha": sampler_options.alpha, + "eps": sampler_options.eps, + "max_ppr_nodes": sampler_options.max_ppr_nodes, + "default_node_id": sampler_options.default_node_id, + "default_weight": sampler_options.default_weight, + "num_nbrs_per_hop": sampler_options.num_nbrs_per_hop, + } else: raise NotImplementedError( f"Unsupported sampler options type: {type(sampler_options)}" @@ -110,6 +126,7 @@ def _sampling_worker_loop( worker_options.worker_concurrency, current_device, seed=sampling_config.seed, + **extra_sampler_kwargs, ) dist_sampler.start_loop() diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index f7bbb2e4b..22d0abf4b 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -1,8 +1,9 @@ """Sampler option types for configuring which sampler class to use in distributed loading. -Provides ``KHopNeighborSamplerOptions`` for using GiGL's built-in ``DistNeighborSampler``. +Provides ``KHopNeighborSamplerOptions`` for using GiGL's built-in ``DistNeighborSampler``, +and ``PPRSamplerOptions`` for PPR-based sampling using ``DistPPRNeighborSampler``. -Frozen dataclass so it is safe to pickle across RPC boundaries +Frozen dataclasses so they are safe to pickle across RPC boundaries (required for Graph Store mode). """ @@ -28,7 +29,38 @@ class KHopNeighborSamplerOptions: num_neighbors: Union[list[int], dict[EdgeType, list[int]]] -SamplerOptions = KHopNeighborSamplerOptions +@dataclass(frozen=True) +class PPRSamplerOptions: + """Sampler options for PPR-based neighbor sampling using DistPPRNeighborSampler. + + Degree tensors are sourced automatically from the dataset at sampler + initialization time and do not need to be provided here. 
+
+    Attributes:
+        alpha: Restart probability (teleport probability back to seed). Higher
+            values keep samples closer to seeds (default: 0.5; 0.15-0.25 is common).
+        eps: Convergence threshold for the Forward Push algorithm. Smaller
+            values give more accurate PPR scores but require more computation.
+            Typical values: 1e-4 to 1e-6.
+        max_ppr_nodes: Maximum number of nodes to return per seed based on PPR
+            scores.
+        default_node_id: Node ID used to pad results when fewer than
+            max_ppr_nodes are found.
+        default_weight: PPR weight assigned to padding nodes.
+        num_nbrs_per_hop: Maximum number of neighbors fetched per node per edge
+            type during PPR traversal. Set large to approximate fetching all
+            neighbors.
+    """
+
+    alpha: float = 0.5
+    eps: float = 1e-4
+    max_ppr_nodes: int = 50
+    default_node_id: int = -1
+    default_weight: float = 0.0
+    num_nbrs_per_hop: int = 100000
+
+
+SamplerOptions = Union[KHopNeighborSamplerOptions, PPRSamplerOptions]
 
 
 def resolve_sampler_options(
@@ -37,12 +69,14 @@
 ) -> SamplerOptions:
     """Resolve sampler_options from user-provided values.
 
-    If ``sampler_options`` is ``None``, wraps ``num_neighbors`` in a
-    ``KHopNeighborSamplerOptions``. If ``KHopNeighborSamplerOptions`` is
-    provided, validates that its ``num_neighbors`` matches the explicit value.
+    If ``sampler_options`` is a ``PPRSamplerOptions``, returns it directly
+    (``num_neighbors`` is unused for PPR). If ``sampler_options`` is ``None``,
+    wraps ``num_neighbors`` in a ``KHopNeighborSamplerOptions``. If
+    ``KHopNeighborSamplerOptions`` is provided, validates that its
+    ``num_neighbors`` matches the explicit value.
 
     Args:
-        num_neighbors: Fanout per hop (always required).
+        num_neighbors: Fanout per hop (required for KHop; ignored for PPR).
        sampler_options: Sampler configuration, or None.
 
     Returns:
@@ -52,6 +86,9 @@
         ValueError: If ``KHopNeighborSamplerOptions.num_neighbors`` conflicts
             with the explicit ``num_neighbors``.
""" + if isinstance(sampler_options, PPRSamplerOptions): + return sampler_options + if sampler_options is None: return KHopNeighborSamplerOptions(num_neighbors) From ec4bdbc72a97cfb750e94dfa0f296662a4ed4c3a Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 9 Mar 2026 19:13:54 +0000 Subject: [PATCH 26/46] Use inducer for local indexing in DistPPRNeighborSampler --- gigl/distributed/dist_ppr_sampler.py | 115 +++++++++++++++++++++------ 1 file changed, 89 insertions(+), 26 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 97ff99f15..a2b82becb 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -11,6 +11,7 @@ SamplerOutput, ) from graphlearn_torch.typing import EdgeType, NodeType +from graphlearn_torch.utils import merge_dict from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler @@ -437,6 +438,14 @@ async def _sample_from_nodes( edge types based on the current node type. PPR weights are stored in metadata keyed as ``ppr_weights_{seed_type}_{neighbor_type}`` and ``ppr_neighbor_ids_{seed_type}_{neighbor_type}``. + + The ``ppr_neighbor_ids`` tensors are locally indexed — each value is a + 0-based index into ``data[ntype].node`` (the global-ID array produced by + GLT's collate step), so downstream models can directly index into + ``data[ntype].x`` without a separate global→local remapping step. + + The inducer is used to perform deduplication and local-index assignment + in-place during sampling, avoiding a post-hoc lookup pass. 
""" sample_loop_inputs = self._prepare_sample_loop_inputs(inputs) input_seeds = inputs.node.to(self.device) @@ -445,18 +454,21 @@ async def _sample_from_nodes( metadata = sample_loop_inputs.metadata nodes_to_sample = sample_loop_inputs.nodes_to_sample + inducer = self._acquire_inducer() + if is_hetero: assert isinstance(nodes_to_sample, dict) assert input_type is not None - inducer = self._acquire_inducer() - out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = defaultdict(list) + # Register all seeds with the inducer; src_dict maps NodeType -> global IDs + src_dict = inducer.init_node(nodes_to_sample) - # All seed nodes (anchors + supervision) are included in the output - for seed_type, seed_nodes in nodes_to_sample.items(): - out_nodes_hetero[seed_type].append(seed_nodes) + # Compute PPR for each seed type; build nbr_dict for a single inducer.induce_next + # call using virtual edge types (seed_type, 'ppr', ntype). + nbr_dict: dict[EdgeType, list[torch.Tensor]] = {} + valid_counts_per_pair: dict[tuple[NodeType, NodeType], torch.Tensor] = {} + all_ppr_neighbor_ids: dict[tuple[NodeType, NodeType], torch.Tensor] = {} - # Compute PPR separately for each seed type and store per-(seed_type, neighbor_type) for seed_type, seed_nodes in nodes_to_sample.items(): nbr_ids_by_type, nbr_weights_by_type = await self._compute_ppr_scores( seed_nodes, seed_type @@ -465,23 +477,57 @@ async def _sample_from_nodes( assert isinstance(nbr_weights_by_type, dict) for ntype, neighbor_ids in nbr_ids_by_type.items(): - weights = nbr_weights_by_type[ntype] - flat_neighbors = neighbor_ids.flatten() - valid_mask = flat_neighbors != self._default_node_id - valid_neighbors = flat_neighbors[valid_mask] - - if valid_neighbors.numel() > 0: - out_nodes_hetero[ntype].append(valid_neighbors.unique()) - - metadata[f"ppr_weights_{seed_type}_{ntype}"] = weights - metadata[f"ppr_neighbor_ids_{seed_type}_{ntype}"] = neighbor_ids - + valid_mask = neighbor_ids != self._default_node_id + valid_counts = 
valid_mask.sum(dim=1) + flat_valid_nbrs = neighbor_ids[valid_mask] + + valid_counts_per_pair[(seed_type, ntype)] = valid_counts + all_ppr_neighbor_ids[(seed_type, ntype)] = neighbor_ids + metadata[f"ppr_weights_{seed_type}_{ntype}"] = nbr_weights_by_type[ + ntype + ] + + # Only add to nbr_dict if there are actual neighbors; induce_next + # will deduplicate across seed types automatically. + if flat_valid_nbrs.numel() > 0: + virtual_etype: EdgeType = (seed_type, "ppr", ntype) + nbr_dict[virtual_etype] = [ + src_dict[seed_type], + flat_valid_nbrs, + valid_counts, + ] + + new_nodes_dict, _rows_dict, cols_dict = inducer.induce_next(nbr_dict) + + # node_dict = seeds + newly discovered PPR neighbors (no duplicates) + out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = defaultdict(list) + merge_dict(src_dict, out_nodes_hetero) + merge_dict(new_nodes_dict, out_nodes_hetero) node_dict = { - ntype: torch.cat(nodes).unique() + ntype: torch.cat(nodes) for ntype, nodes in out_nodes_hetero.items() if nodes } + # Reconstruct locally-indexed ppr_neighbor_ids from inducer cols. + # cols_dict[(seed_type, 'ppr', ntype)] holds local destination indices + # for all edges, in the same flat order as flat_valid_nbrs was built. 
+ for ( + seed_type, + ntype, + ), original_neighbor_ids in all_ppr_neighbor_ids.items(): + valid_counts = valid_counts_per_pair[(seed_type, ntype)] + ppr_ids_local = torch.full_like(original_neighbor_ids, -1) + virtual_etype = (seed_type, "ppr", ntype) + cols = cols_dict.get(virtual_etype) + if cols is not None: + offset = 0 + for i, count in enumerate(valid_counts.tolist()): + count = int(count) + ppr_ids_local[i, :count] = cols[offset : offset + count] + offset += count + metadata[f"ppr_neighbor_ids_{seed_type}_{ntype}"] = ppr_ids_local + sample_output = HeteroSamplerOutput( node=node_dict, row={}, # PPR doesn't maintain edge structure @@ -496,25 +542,41 @@ async def _sample_from_nodes( metadata=metadata, ) - self.inducer_pool.put(inducer) - else: assert isinstance(nodes_to_sample, torch.Tensor) + # Register seeds; srcs holds their global IDs (local indices 0..N-1 assigned internally) + srcs = inducer.init_node(nodes_to_sample) + homo_neighbor_ids, homo_weights = await self._compute_ppr_scores( nodes_to_sample, None ) assert isinstance(homo_neighbor_ids, torch.Tensor) assert isinstance(homo_weights, torch.Tensor) - flat_neighbors = homo_neighbor_ids.flatten() - valid_mask = flat_neighbors != self._default_node_id - valid_neighbors = flat_neighbors[valid_mask].unique() + valid_mask = homo_neighbor_ids != self._default_node_id + valid_counts = valid_mask.sum(dim=1) + flat_valid_nbrs = homo_neighbor_ids[valid_mask] - all_nodes = torch.cat([nodes_to_sample, valid_neighbors]) + # induce_next deduplicates flat_valid_nbrs against already-seen nodes + # and returns local destination indices (cols) for each neighbor edge. + new_nodes, _rows, cols = inducer.induce_next( + srcs, flat_valid_nbrs, valid_counts + ) + all_nodes = torch.cat([srcs, new_nodes]) + + # Reconstruct ppr_neighbor_ids_local with shape [batch_size, max_ppr_nodes]. + # cols is flat; we slice it per seed using valid_counts to get each seed's + # local neighbor indices. 
+ ppr_neighbor_ids_local = torch.full_like(homo_neighbor_ids, -1) + offset = 0 + for i, count in enumerate(valid_counts.tolist()): + count = int(count) + ppr_neighbor_ids_local[i, :count] = cols[offset : offset + count] + offset += count metadata["ppr_weights"] = homo_weights - metadata["ppr_neighbor_ids"] = homo_neighbor_ids + metadata["ppr_neighbor_ids"] = ppr_neighbor_ids_local sample_output = SamplerOutput( node=all_nodes, @@ -524,9 +586,10 @@ async def _sample_from_nodes( [], dtype=torch.long, device=self.device ), # Empty tensor — GLT SampleQueue requires all values to be tensors batch=input_seeds, - num_sampled_nodes=[input_seeds.size(0), valid_neighbors.size(0)], + num_sampled_nodes=[srcs.size(0), new_nodes.size(0)], num_sampled_edges=[], metadata=metadata, ) + self.inducer_pool.put(inducer) return sample_output From b05d9c60a0d968537a62a0cb95c3200f10a254cc Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 9 Mar 2026 20:38:05 +0000 Subject: [PATCH 27/46] Switch PPR output to edge-index format, remove default_node_id/default_weight --- gigl/distributed/dist_ppr_sampler.py | 284 ++++++++++++--------- gigl/distributed/dist_sampling_producer.py | 2 - gigl/distributed/sampler_options.py | 5 - 3 files changed, 167 insertions(+), 124 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index a2b82becb..416525d9e 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -44,8 +44,6 @@ class DistPPRNeighborSampler(DistNeighborSampler): eps: Convergence threshold. Smaller values give more accurate PPR scores but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. - default_node_id: Node ID to use when fewer than max_ppr_nodes are found. - default_weight: Weight to assign to padding nodes. num_nbrs_per_hop: Maximum number of neighbors to fetch per hop. 
""" @@ -55,8 +53,6 @@ def __init__( alpha: float = 0.5, eps: float = 1e-4, max_ppr_nodes: int = 50, - default_node_id: int = -1, - default_weight: float = 0.0, num_nbrs_per_hop: int = 100000, **kwargs, ): @@ -64,8 +60,6 @@ def __init__( self._alpha = alpha self._eps = eps self._max_ppr_nodes = max_ppr_nodes - self._default_node_id = default_node_id - self._default_weight = default_weight self._alpha_eps = alpha * eps self._num_nbrs_per_hop = num_nbrs_per_hop @@ -208,6 +202,7 @@ async def _compute_ppr_scores( ) -> tuple[ Union[torch.Tensor, dict[NodeType, torch.Tensor]], Union[torch.Tensor, dict[NodeType, torch.Tensor]], + Union[torch.Tensor, dict[NodeType, torch.Tensor]], ]: """ Compute PPR scores for seed nodes using the push-based approximation algorithm. @@ -234,9 +229,13 @@ async def _compute_ppr_scores( seed_node_type: Node type of seed nodes. Should be None for homogeneous graphs. Returns: - tuple of (neighbor_ids_by_type, ppr_weights_by_type) where: - - neighbor_ids_by_type: Union[torch.Tensor, dict mapping node type -> [batch_size, max_ppr_nodes]] - - ppr_weights_by_type: Union[torch.Tensor, dict mapping node type -> [batch_size, max_ppr_nodes]] + tuple of (flat_neighbor_ids, flat_weights, valid_counts) where each is either + a 1-D tensor (homogeneous) or a dict mapping NodeType to a 1-D tensor + (heterogeneous): + - flat_neighbor_ids: global neighbor IDs in top-k order, concatenated + across all seeds. Length equals sum(valid_counts). + - flat_weights: corresponding PPR scores, same length as flat_neighbor_ids. + - valid_counts: number of PPR neighbors found per seed [batch_size]. """ if seed_node_type is None: seed_node_type = _PPR_HOMOGENEOUS_NODE_TYPE @@ -367,31 +366,24 @@ async def _compute_ppr_scores( q[i].add(key_v) num_nodes_in_queue += 1 - # Extract top-k nodes by PPR score, grouped by node type + # Extract top-k nodes by PPR score, grouped by node type. 
+ # Build flat tensors directly (no padding) — valid_counts[i] records how many + # neighbors seed i actually has, so callers can recover per-seed slices. all_node_types: set[NodeType] = set() for i in range(batch_size): - for node_id, node_type in p[i].keys(): + for _node_id, node_type in p[i].keys(): all_node_types.add(node_type) - out_neighbor_ids_dict: dict[NodeType, torch.Tensor] = {} - out_weights_dict: dict[NodeType, torch.Tensor] = {} + out_flat_ids_dict: dict[NodeType, torch.Tensor] = {} + out_flat_weights_dict: dict[NodeType, torch.Tensor] = {} + out_valid_counts_dict: dict[NodeType, torch.Tensor] = {} for ntype in all_node_types: - ntype_neighbor_ids = torch.full( - (batch_size, self._max_ppr_nodes), - self._default_node_id, - dtype=torch.long, - device=device, - ) - ntype_weights = torch.full( - (batch_size, self._max_ppr_nodes), - self._default_weight, - dtype=torch.float, - device=device, - ) + flat_ids: list[int] = [] + flat_weights: list[float] = [] + valid_counts: list[int] = [] for i in range(batch_size): - # Filter to nodes of this type type_scores = { node_id: score for (node_id, node_type), score in p[i].items() @@ -400,28 +392,38 @@ async def _compute_ppr_scores( top_k = heapq.nlargest( self._max_ppr_nodes, type_scores.items(), key=lambda x: x[1] ) + for node_id, weight in top_k: + flat_ids.append(node_id) + flat_weights.append(weight) + valid_counts.append(len(top_k)) - for j, (node_id, weight) in enumerate(top_k): - ntype_neighbor_ids[i, j] = node_id - ntype_weights[i, j] = weight - - out_neighbor_ids_dict[ntype] = ntype_neighbor_ids - out_weights_dict[ntype] = ntype_weights + out_flat_ids_dict[ntype] = torch.tensor( + flat_ids, dtype=torch.long, device=device + ) + out_flat_weights_dict[ntype] = torch.tensor( + flat_weights, dtype=torch.float, device=device + ) + out_valid_counts_dict[ntype] = torch.tensor( + valid_counts, dtype=torch.long, device=device + ) - out_neighbor_ids: Union[torch.Tensor, dict[NodeType, torch.Tensor]] - 
out_weights: Union[torch.Tensor, dict[NodeType, torch.Tensor]] + out_flat_ids: Union[torch.Tensor, dict[NodeType, torch.Tensor]] + out_flat_weights: Union[torch.Tensor, dict[NodeType, torch.Tensor]] + out_valid_counts: Union[torch.Tensor, dict[NodeType, torch.Tensor]] if self._is_homogeneous: assert ( len(all_node_types) == 1 and _PPR_HOMOGENEOUS_NODE_TYPE in all_node_types ) - out_neighbor_ids = out_neighbor_ids_dict[_PPR_HOMOGENEOUS_NODE_TYPE] - out_weights = out_weights_dict[_PPR_HOMOGENEOUS_NODE_TYPE] + out_flat_ids = out_flat_ids_dict[_PPR_HOMOGENEOUS_NODE_TYPE] + out_flat_weights = out_flat_weights_dict[_PPR_HOMOGENEOUS_NODE_TYPE] + out_valid_counts = out_valid_counts_dict[_PPR_HOMOGENEOUS_NODE_TYPE] else: - out_neighbor_ids = out_neighbor_ids_dict - out_weights = out_weights_dict + out_flat_ids = out_flat_ids_dict + out_flat_weights = out_flat_weights_dict + out_valid_counts = out_valid_counts_dict - return out_neighbor_ids, out_weights + return out_flat_ids, out_flat_weights, out_valid_counts async def _sample_from_nodes( self, @@ -435,17 +437,51 @@ async def _sample_from_nodes( subgraph includes neighbors relevant to all seed types. For heterogeneous graphs, PPR traverses across all edge types, switching - edge types based on the current node type. PPR weights are stored in - metadata keyed as ``ppr_weights_{seed_type}_{neighbor_type}`` and - ``ppr_neighbor_ids_{seed_type}_{neighbor_type}``. - - The ``ppr_neighbor_ids`` tensors are locally indexed — each value is a - 0-based index into ``data[ntype].node`` (the global-ID array produced by - GLT's collate step), so downstream models can directly index into - ``data[ntype].x`` without a separate global→local remapping step. - - The inducer is used to perform deduplication and local-index assignment - in-place during sampling, avoiding a post-hoc lookup pass. + edge types based on the current node type. 
+ + Output format (PyG edge-index style, no padding): + + - ``ppr_neighbor_ids`` (homo) / ``ppr_neighbor_ids_{seed_type}_{ntype}`` (hetero): + shape ``[2, num_edges]`` — row 0 is local seed indices, row 1 is local + neighbor indices. Both index into ``data[ntype].node``. + - ``ppr_weights`` (homo) / ``ppr_weights_{seed_type}_{ntype}`` (hetero): + shape ``[num_edges]`` — PPR score for each edge, aligned with the columns + of ``ppr_neighbor_ids``. + + Local indices are produced by the inducer (see below), so row 1 of + ``ppr_neighbor_ids`` directly indexes into ``data[ntype].x`` without any + additional global→local remapping. + + **Why the inducer is used for local-index assignment:** + + The inducer is GLT's C++ data structure (backed by a per-node-type hash map) + that maintains a single global-ID → local-index mapping for the entire + subgraph being built. We use it here instead of a Python dict for two reasons: + + 1. **Consistency across seed types.** For heterogeneous ABLP inputs, + ``_compute_ppr_scores`` is called once per seed type (anchors, supervision + nodes, …). A node reachable from multiple seed types must receive the + *same* local index in ``node_dict[ntype]`` regardless of which seed type + discovered it. The inducer is shared across all those calls, so it + guarantees this automatically. + + 2. **Performance.** The inducer's C++ hash map is faster than a Python dict + for per-node lookups on large graphs, and its lifecycle is already managed + by GLT's inducer pool (``_acquire_inducer`` / ``inducer_pool.put``). + + The API used here mirrors GLT's own ``DistNeighborSampler._sample_from_nodes``: + + - ``inducer.init_node(seeds)`` registers seed nodes and returns their global + IDs (local indices 0, 1, … are assigned internally). 
+ - ``inducer.induce_next(srcs, flat_nbrs, counts)`` (homo) or + ``inducer.induce_next(nbr_dict)`` (hetero) deduplicates neighbors against + all previously seen nodes and returns: + + - ``new_nodes``: global IDs of nodes not yet registered. + - ``cols``: flat local destination indices for *every* neighbor edge, + in the same order as the input ``flat_nbrs``. Combined with + ``repeat_interleave``-expanded seed indices, this forms the + ``[2, num_edges]`` edge-index tensor directly. """ sample_loop_inputs = self._prepare_sample_loop_inputs(inputs) input_seeds = inputs.node.to(self.device) @@ -454,52 +490,63 @@ async def _sample_from_nodes( metadata = sample_loop_inputs.metadata nodes_to_sample = sample_loop_inputs.nodes_to_sample + # Acquired once per sample; returned to the pool at the end. The inducer + # maintains the shared global→local index map for this entire subgraph. inducer = self._acquire_inducer() if is_hetero: assert isinstance(nodes_to_sample, dict) assert input_type is not None - # Register all seeds with the inducer; src_dict maps NodeType -> global IDs + # Register all seeds (anchors + supervision nodes for ABLP) with the + # inducer first, so they occupy the lowest local indices. src_dict maps + # NodeType -> global IDs (same values as nodes_to_sample). src_dict = inducer.init_node(nodes_to_sample) - # Compute PPR for each seed type; build nbr_dict for a single inducer.induce_next - # call using virtual edge types (seed_type, 'ppr', ntype). + # Compute PPR for each seed type, collecting flat global neighbor IDs, + # weights, and per-seed counts. Build nbr_dict for a single + # inducer.induce_next call using virtual edge types (seed_type, 'ppr', ntype) + # — the inducer only cares about etype[0] and etype[-1] as source/dest + # node types, so the relation name is arbitrary. 
nbr_dict: dict[EdgeType, list[torch.Tensor]] = {} - valid_counts_per_pair: dict[tuple[NodeType, NodeType], torch.Tensor] = {} - all_ppr_neighbor_ids: dict[tuple[NodeType, NodeType], torch.Tensor] = {} + all_flat_weights: dict[tuple[NodeType, NodeType], torch.Tensor] = {} + all_valid_counts: dict[tuple[NodeType, NodeType], torch.Tensor] = {} for seed_type, seed_nodes in nodes_to_sample.items(): - nbr_ids_by_type, nbr_weights_by_type = await self._compute_ppr_scores( - seed_nodes, seed_type - ) - assert isinstance(nbr_ids_by_type, dict) - assert isinstance(nbr_weights_by_type, dict) - - for ntype, neighbor_ids in nbr_ids_by_type.items(): - valid_mask = neighbor_ids != self._default_node_id - valid_counts = valid_mask.sum(dim=1) - flat_valid_nbrs = neighbor_ids[valid_mask] - - valid_counts_per_pair[(seed_type, ntype)] = valid_counts - all_ppr_neighbor_ids[(seed_type, ntype)] = neighbor_ids - metadata[f"ppr_weights_{seed_type}_{ntype}"] = nbr_weights_by_type[ - ntype - ] - - # Only add to nbr_dict if there are actual neighbors; induce_next - # will deduplicate across seed types automatically. - if flat_valid_nbrs.numel() > 0: + ( + flat_ids_by_type, + flat_weights_by_type, + valid_counts_by_type, + ) = await self._compute_ppr_scores(seed_nodes, seed_type) + assert isinstance(flat_ids_by_type, dict) + assert isinstance(flat_weights_by_type, dict) + assert isinstance(valid_counts_by_type, dict) + + for ntype, flat_ids in flat_ids_by_type.items(): + valid_counts = valid_counts_by_type[ntype] + all_flat_weights[(seed_type, ntype)] = flat_weights_by_type[ntype] + all_valid_counts[(seed_type, ntype)] = valid_counts + + # Skip empty pairs; induce_next handles deduplication across + # seed types so a neighbor reachable from multiple seed types + # gets one consistent local index in node_dict[ntype]. 
+ if flat_ids.numel() > 0: virtual_etype: EdgeType = (seed_type, "ppr", ntype) nbr_dict[virtual_etype] = [ src_dict[seed_type], - flat_valid_nbrs, + flat_ids, valid_counts, ] + # induce_next assigns local indices to all neighbors not yet registered, + # deduplicating across all virtual edge types in one pass. + # new_nodes_dict: newly discovered global IDs per node type. + # cols_dict: flat local destination indices per virtual edge type, + # in the same order the flat neighbors were provided. new_nodes_dict, _rows_dict, cols_dict = inducer.induce_next(nbr_dict) - # node_dict = seeds + newly discovered PPR neighbors (no duplicates) + # node_dict = seeds (already in src_dict) + newly discovered PPR + # neighbors. merge_dict appends tensors into lists; cat collapses them. out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = defaultdict(list) merge_dict(src_dict, out_nodes_hetero) merge_dict(new_nodes_dict, out_nodes_hetero) @@ -509,24 +556,27 @@ async def _sample_from_nodes( if nodes } - # Reconstruct locally-indexed ppr_neighbor_ids from inducer cols. - # cols_dict[(seed_type, 'ppr', ntype)] holds local destination indices - # for all edges, in the same flat order as flat_valid_nbrs was built. - for ( - seed_type, - ntype, - ), original_neighbor_ids in all_ppr_neighbor_ids.items(): - valid_counts = valid_counts_per_pair[(seed_type, ntype)] - ppr_ids_local = torch.full_like(original_neighbor_ids, -1) + # Build PyG-style edge-index output per (seed_type, ntype) pair. + # cols_dict[(seed_type, 'ppr', ntype)] gives flat local dst indices in + # the same order as the flat neighbors passed to induce_next. + # repeat_interleave expands seed local indices to match. 
+ for (seed_type, ntype), flat_weights in all_flat_weights.items(): + valid_counts = all_valid_counts[(seed_type, ntype)] virtual_etype = (seed_type, "ppr", ntype) cols = cols_dict.get(virtual_etype) if cols is not None: - offset = 0 - for i, count in enumerate(valid_counts.tolist()): - count = int(count) - ppr_ids_local[i, :count] = cols[offset : offset + count] - offset += count - metadata[f"ppr_neighbor_ids_{seed_type}_{ntype}"] = ppr_ids_local + seed_batch_size = nodes_to_sample[seed_type].size(0) + src_indices = torch.repeat_interleave( + torch.arange(seed_batch_size, device=self.device), valid_counts + ) + ppr_edge_index = torch.stack([src_indices, cols]) + else: + ppr_edge_index = torch.zeros( + 2, 0, dtype=torch.long, device=self.device + ) + flat_weights = torch.zeros(0, dtype=torch.float, device=self.device) + metadata[f"ppr_neighbor_ids_{seed_type}_{ntype}"] = ppr_edge_index + metadata[f"ppr_weights_{seed_type}_{ntype}"] = flat_weights sample_output = HeteroSamplerOutput( node=node_dict, @@ -545,38 +595,38 @@ async def _sample_from_nodes( else: assert isinstance(nodes_to_sample, torch.Tensor) - # Register seeds; srcs holds their global IDs (local indices 0..N-1 assigned internally) + # Register seeds; local indices 0..N-1 are assigned internally. + # srcs holds their global IDs (same values as nodes_to_sample). srcs = inducer.init_node(nodes_to_sample) - homo_neighbor_ids, homo_weights = await self._compute_ppr_scores( - nodes_to_sample, None - ) - assert isinstance(homo_neighbor_ids, torch.Tensor) - assert isinstance(homo_weights, torch.Tensor) - - valid_mask = homo_neighbor_ids != self._default_node_id - valid_counts = valid_mask.sum(dim=1) - flat_valid_nbrs = homo_neighbor_ids[valid_mask] - - # induce_next deduplicates flat_valid_nbrs against already-seen nodes - # and returns local destination indices (cols) for each neighbor edge. 
+ homo_ppr_result = await self._compute_ppr_scores(nodes_to_sample, None) + assert isinstance(homo_ppr_result[0], torch.Tensor) + assert isinstance(homo_ppr_result[1], torch.Tensor) + assert isinstance(homo_ppr_result[2], torch.Tensor) + homo_flat_ids: torch.Tensor = homo_ppr_result[0] + homo_flat_weights: torch.Tensor = homo_ppr_result[1] + homo_valid_counts: torch.Tensor = homo_ppr_result[2] + + # induce_next deduplicates homo_flat_ids against already-seen nodes + # (the seeds registered above) and returns: + # new_nodes: global IDs of nodes not yet registered. + # cols: flat local destination indices for every neighbor, in the + # same order as homo_flat_ids. new_nodes, _rows, cols = inducer.induce_next( - srcs, flat_valid_nbrs, valid_counts + srcs, homo_flat_ids, homo_valid_counts ) all_nodes = torch.cat([srcs, new_nodes]) - # Reconstruct ppr_neighbor_ids_local with shape [batch_size, max_ppr_nodes]. - # cols is flat; we slice it per seed using valid_counts to get each seed's - # local neighbor indices. - ppr_neighbor_ids_local = torch.full_like(homo_neighbor_ids, -1) - offset = 0 - for i, count in enumerate(valid_counts.tolist()): - count = int(count) - ppr_neighbor_ids_local[i, :count] = cols[offset : offset + count] - offset += count + # Build PyG-style edge-index: row 0 = local seed indices (expanded via + # repeat_interleave), row 1 = local neighbor indices from inducer cols. 
+ src_indices = torch.repeat_interleave( + torch.arange(nodes_to_sample.size(0), device=self.device), + homo_valid_counts, + ) + ppr_edge_index = torch.stack([src_indices, cols]) - metadata["ppr_weights"] = homo_weights - metadata["ppr_neighbor_ids"] = ppr_neighbor_ids_local + metadata["ppr_neighbor_ids"] = ppr_edge_index + metadata["ppr_weights"] = homo_flat_weights sample_output = SamplerOutput( node=all_nodes, diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index f2e085736..860d38842 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -104,8 +104,6 @@ def _sampling_worker_loop( "alpha": sampler_options.alpha, "eps": sampler_options.eps, "max_ppr_nodes": sampler_options.max_ppr_nodes, - "default_node_id": sampler_options.default_node_id, - "default_weight": sampler_options.default_weight, "num_nbrs_per_hop": sampler_options.num_nbrs_per_hop, } else: diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 22d0abf4b..52d21aa26 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -44,9 +44,6 @@ class PPRSamplerOptions: Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. - default_node_id: Node ID used to pad results when fewer than - max_ppr_nodes are found. - default_weight: PPR weight assigned to padding nodes. num_nbrs_per_hop: Maximum number of neighbors fetched per node per edge type during PPR traversal. Set large to approximate fetching all neighbors. 
@@ -55,8 +52,6 @@ class PPRSamplerOptions: alpha: float = 0.5 eps: float = 1e-4 max_ppr_nodes: int = 50 - default_node_id: int = -1 - default_weight: float = 0.0 num_nbrs_per_hop: int = 100000 From 3506b3aabda5ff45d1a7f84e4707b5c9b1608236 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Mon, 9 Mar 2026 22:50:20 +0000 Subject: [PATCH 28/46] comments --- gigl/distributed/base_dist_loader.py | 3 +-- gigl/distributed/dist_ablp_neighborloader.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 980f064ed..f9e5b512c 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -101,8 +101,7 @@ class BaseDistLoader(DistLoader): runtime: Resolved distributed runtime information. producer: Either a pre-constructed ``DistSamplingProducer`` (colocated mode) or a callable to dispatch on the ``DistServer`` (graph store mode). - sampler_options: Controls which sampler class is instantiated. If ``None``, - falls back to the default ``KHopNeighborSamplerOptions``. + sampler_options: Controls which sampler class is instantiated. process_start_gap_seconds: Delay between each process for staggered colocated init. """ diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 0bc482225..6220e99e4 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -193,9 +193,8 @@ def __init__( shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). sampler_options (Optional[SamplerOptions]): Controls which sampler class is - instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, - or ``CustomSamplerOptions`` to dynamically import a custom sampler class. - If ``None``, defaults to ``KHopNeighborSamplerOptions(num_neighbors)``. 
+ instantiated. Defaults to `KHopNeighborSamplerOptions`, which will use the num_neighbors argument + to instantiate the sampler. context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node. From 43ae1a9a882f32034fe0dd0531fb4cbcd60af7d6 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 11 Mar 2026 20:22:46 +0000 Subject: [PATCH 29/46] Optimize PPR: merge push+requeue into single pass, cache total degree --- gigl/distributed/dist_ppr_sampler.py | 113 ++++++++++++++------------- 1 file changed, 59 insertions(+), 54 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 416525d9e..08bb69548 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -215,14 +215,13 @@ async def _compute_ppr_scores( Algorithm Overview (each iteration of the main loop): 1. Fetch neighbors: Drain all nodes from the queue, group by edge type, and perform a batched neighbor lookup to populate neighbor/degree caches. - 2. Push residual: For each queued node, add its residual to its PPR score, - reset its residual to zero, then distribute (1-alpha) * residual to - all neighbors proportionally by degree. - 3. Batch fetch degrees: Group all neighbors that received residual and - perform a batched lookup to get their degrees (needed for threshold check). - 4. Re-queue high-residual nodes: For each neighbor that received residual, - check if residual >= alpha * eps * total_degree. If so, add to queue - for processing in the next iteration. + 2. 
Push residual + re-queue (single pass): For each queued node, add its + residual to its PPR score, reset its residual to zero, then distribute + (1-alpha) * residual to all neighbors proportionally by degree. After + each push, immediately check if the neighbor's accumulated residual + exceeds alpha * eps * total_degree; if so, add it to the queue for + the next iteration. Total degree lookups are cached across the entire + PPR computation to avoid redundant summation. Args: seed_nodes: Tensor of seed node IDs [batch_size] @@ -264,7 +263,28 @@ async def _compute_ppr_scores( # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} + # Cache for total degree (sum across all edge types for a node type). + # The per-edge-type degree is already O(1) via degree_tensors, but the + # *sum* across edge types is recomputed each time a node appears as a + # neighbor — which can be many times across seeds and iterations. + # Caching the sum avoids redundant _get_degree_from_tensor calls and + # the per-call Python overhead (method dispatch, isinstance, .item()). + total_degree_cache: dict[tuple[int, NodeType], int] = {} + + def _get_total_degree(node_id: int, node_type: NodeType) -> int: + key = (node_id, node_type) + cached = total_degree_cache.get(key) + if cached is not None: + return cached + total = sum( + self._get_degree_from_tensor(node_id, et) + for et in self._node_type_to_edge_types.get(node_type, []) + ) + total_degree_cache[key] = total + return total + num_nodes_in_queue = batch_size + one_minus_alpha = 1 - self._alpha while num_nodes_in_queue > 0: # Drain all nodes from all queues and group by edge type for batched lookups @@ -292,7 +312,20 @@ async def _compute_ppr_scores( nodes_by_edge_type, neighbor_cache, device ) - # Process nodes and push residual + # Push residual to neighbors and re-queue in a single pass. 
+ # + # Previously these were two separate loops over the same neighbor + # lists — one to push residual, one to check thresholds. Merging + # them halves the total neighbor-list iteration. + # + # This is safe because each seed's state (p, r, q) is independent. + # If node v receives residual from multiple frontier nodes (u1, u2) + # of the same seed, v's threshold is checked after each push. The + # last frontier node to push to v sees the same accumulated residual + # that the original two-pass version would see. Since push values + # are always positive (residual monotonically increases), the merged + # version can never miss a re-queue that the two-pass version would + # catch. for i in range(batch_size): for u_node, u_type in nodes_to_process[i]: key_u = (u_node, u_type) @@ -305,66 +338,38 @@ async def _compute_ppr_scores( # For each edge type from this node type, push residual to neighbors edge_types_for_node = self._node_type_to_edge_types[u_type] - # Calculate total degree across all edge types for proper probability distribution - # Degrees are looked up directly from in-memory tensors - total_degree = sum( - self._get_degree_from_tensor(u_node, etype) - for etype in edge_types_for_node - ) + total_degree = _get_total_degree(u_node, u_type) if total_degree == 0: continue - # Push residual proportionally based on degree per edge type + push_value = one_minus_alpha * res_u / total_degree + + # Push residual proportionally based on degree per edge type. + # Per-edge-type degree is retrieved from _get_total_degree's + # cached sum path — the individual lookups are only needed + # to detect zero-degree edge types. 
for etype in edge_types_for_node: cache_key = (u_node, etype) neighbor_list = neighbor_cache[cache_key] - neighbor_count = self._get_degree_from_tensor(u_node, etype) - - if neighbor_count == 0: + if not neighbor_list: continue - # Determine the type of the neighbors v_type = self._get_neighbor_type(etype) - # Distribute residual to neighbors, weighted by edge type contribution - push_value = (1 - self._alpha) * res_u / total_degree - for v_node in neighbor_list: key_v = (v_node, v_type) r[i][key_v] += push_value - # Add high-residual neighbors to queue - for i in range(batch_size): - for u_node, u_type in nodes_to_process[i]: - edge_types_for_node = self._node_type_to_edge_types.get(u_type, []) - for etype in edge_types_for_node: - cache_key = (u_node, etype) - neighbor_list = neighbor_cache[cache_key] - v_type = self._get_neighbor_type(etype) - - for v_node in neighbor_list: - key_v = (v_node, v_type) - - if key_v in q[i]: - continue - - res_v = r[i].get(key_v, 0.0) - if res_v == 0.0: - continue - - # Sum degrees across all edge types from v_type for threshold check - edge_types_for_v = self._node_type_to_edge_types.get( - v_type, [] - ) - total_v_degree = sum( - self._get_degree_from_tensor(v_node, v_etype) - for v_etype in edge_types_for_v - ) - - if res_v >= self._alpha_eps * total_v_degree: - q[i].add(key_v) - num_nodes_in_queue += 1 + # Inline re-queue check: if v is not already queued + # and its accumulated residual exceeds the threshold, + # add it to the queue for the next iteration. + if key_v not in q[i]: + if r[i][key_v] >= self._alpha_eps * _get_total_degree( + v_node, v_type + ): + q[i].add(key_v) + num_nodes_in_queue += 1 # Extract top-k nodes by PPR score, grouped by node type. 
# Build flat tensors directly (no padding) — valid_counts[i] records how many From 81c1bfb92ac4e450c8bfbfafee4528dae4feb101 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 11 Mar 2026 22:17:28 +0000 Subject: [PATCH 30/46] Add PPR sampler tests and fix ABLP metadata propagation --- gigl/distributed/dist_ablp_neighborloader.py | 33 ++ gigl/distributed/dist_sampling_producer.py | 21 + .../unit/distributed/dist_ppr_sampler_test.py | 555 ++++++++++++++++++ 3 files changed, 609 insertions(+) create mode 100644 tests/unit/distributed/dist_ppr_sampler_test.py diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 92059df87..d43bc4b63 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -925,6 +925,10 @@ def _set_labels( return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: + # _get_labels strips ALL #META. keys from the message to work around a + # GLT bug in to_hetero_data. Collect non-label metadata beforehand so + # we can re-apply it to the output data after conversion. + non_label_metadata = self._extract_non_label_metadata(msg) msg, positive_labels, negative_labels = self._get_labels(msg) data = super()._collate_fn(msg) data = set_missing_features( @@ -941,5 +945,34 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" ) data = labeled_to_homogeneous(self._supervision_edge_types[0], data) + for key, value in non_label_metadata.items(): + data[key] = value data = self._set_labels(data, positive_labels, negative_labels) return data + + def _extract_non_label_metadata( + self, msg: SampleMessage + ) -> dict[str, torch.Tensor]: + """Extract non-label metadata from a SampleMessage before _get_labels strips it. + + _get_labels removes ALL ``#META.`` keys from the message to avoid a GLT + bug in ``to_hetero_data``. 
This method reads non-label metadata (e.g. + PPR scores) so it can be re-applied to the output Data/HeteroData after + conversion. + + Args: + msg: The SampleMessage to scan (not modified). + + Returns: + Dict mapping metadata key (without ``#META.`` prefix) to tensor. + """ + meta_prefix = "#META." + label_prefixes = ( + metadata_key_with_prefix(POSITIVE_LABEL_METADATA_KEY), + metadata_key_with_prefix(NEGATIVE_LABEL_METADATA_KEY), + ) + result: dict[str, torch.Tensor] = {} + for k in msg.keys(): + if k.startswith(meta_prefix) and not k.startswith(label_prefixes): + result[k[len(meta_prefix) :]] = msg[k].to(self.to_device) + return result diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 860d38842..4960b756f 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -34,6 +34,8 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset +from gigl.common.logger import Logger +from gigl.distributed.dist_dataset import DistDataset as GiglDistDataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler from gigl.distributed.dist_ppr_sampler import DistPPRNeighborSampler from gigl.distributed.sampler_options import ( @@ -42,6 +44,8 @@ SamplerOptions, ) +logger = Logger() + def _sampling_worker_loop( rank: int, @@ -208,6 +212,23 @@ def __init__( def init(self): r"""Create the subprocess pool. Init samplers and rpc server.""" + # PPR sampling requires degree tensors in the sampler __init__. + # Worker subprocesses only initialize RPC (not torch.distributed), + # so the lazy degree computation would fail there. Eagerly compute + # here — where torch.distributed IS initialized — so the cached + # tensor is shared to workers via IPC. 
+ if isinstance(self._sampler_options, PPRSamplerOptions): + assert isinstance(self.data, GiglDistDataset) + degree_tensor = self.data.degree_tensor + if isinstance(degree_tensor, dict): + logger.info( + f"Pre-computed degree tensors for PPR sampling across {len(degree_tensor)} edge types." + ) + else: + logger.info( + f"Pre-computed degree tensor for PPR sampling with {degree_tensor.size(0)} nodes." + ) + if self.sampling_config.seed is not None: seed_everything(self.sampling_config.seed) if not self.sampling_config.shuffle: diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py new file mode 100644 index 000000000..b75519e75 --- /dev/null +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -0,0 +1,555 @@ +"""Unit tests for DistPPRNeighborSampler correctness via DistNeighborLoader. + +Verifies that the PPR scores produced by the distributed sampler match +NetworkX's ``pagerank`` with personalization — a well-tested, independent +PPR implementation. 
+""" + +import heapq +from collections import defaultdict + +import networkx as nx +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from torch_geometric.data import Data, HeteroData + +from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.sampler_options import PPRSamplerOptions +from tests.test_assets.distributed.test_dataset import ( + STORY, + STORY_TO_USER, + USER, + USER_TO_STORY, + create_heterogeneous_dataset, + create_heterogeneous_dataset_for_ablp, + create_homogeneous_dataset, +) +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +# --------------------------------------------------------------------------- +# Homogeneous test graph (5 nodes, undirected edges stored as bidirectional) +# +# 0 --- 1 --- 3 +# | | +# 2 --- + +# | +# 4 +# +# Undirected edges: {0-1, 0-2, 1-2, 1-3, 2-4} +# --------------------------------------------------------------------------- +_TEST_EDGE_INDEX = torch.tensor( + [ + [0, 1, 0, 2, 1, 2, 1, 3, 2, 4], + [1, 0, 2, 0, 2, 1, 3, 1, 4, 2], + ] +) +_NUM_TEST_NODES = 5 + +# --------------------------------------------------------------------------- +# Heterogeneous bipartite test graph (3 users, 3 stories) +# USER_TO_STORY: user 0 -> {story 0, story 1} +# user 1 -> {story 1, story 2} +# user 2 -> {story 0, story 2} +# STORY_TO_USER: reverse of USER_TO_STORY +# --------------------------------------------------------------------------- +_TEST_HETERO_EDGE_INDICES = { + USER_TO_STORY: torch.tensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 2, 0, 2]]), + STORY_TO_USER: torch.tensor([[0, 0, 1, 1, 2, 2], [0, 2, 0, 1, 1, 2]]), +} +_NUM_TEST_USERS = 3 +_NUM_TEST_STORIES = 3 + +_TEST_ALPHA = 0.5 +_TEST_EPS = 1e-6 
+_TEST_MAX_PPR_NODES = 5 +_TEST_NUM_NBRS_PER_HOP = 100000 + + +# --------------------------------------------------------------------------- +# Reference PPR implementations (NetworkX-based) +# --------------------------------------------------------------------------- +def _build_reference_graph() -> nx.DiGraph: + """Build a NetworkX DiGraph matching the homogeneous test edge_index with edge_dir="in". + + With edge_dir="in", the PPR walk from node v follows incoming edges — + i.e., it moves to nodes u where (u, v) exists. NetworkX follows outgoing + edges, so we add edges dst->src so that nx.pagerank traverses the same + neighbors as the sampler. + """ + graph = nx.DiGraph() + graph.add_nodes_from(range(_NUM_TEST_NODES)) + src = _TEST_EDGE_INDEX[0].tolist() + dst = _TEST_EDGE_INDEX[1].tolist() + # Reverse direction: dst->src so outgoing edges in nx match incoming in GLT + graph.add_edges_from(zip(dst, src)) + return graph + + +def _reference_ppr( + graph: nx.DiGraph, + seed: int, + alpha: float, + max_ppr_nodes: int, +) -> dict[int, float]: + """Compute reference PPR scores for a homogeneous graph using NetworkX. + + Args: + graph: NetworkX DiGraph with edges oriented for the sampling direction. + seed: Seed node ID. + alpha: Restart probability (our convention). Mapped to NetworkX's + damping factor as ``nx_alpha = 1 - alpha``, since NetworkX's alpha + is the follow-edge probability while ours is the teleport + probability. + max_ppr_nodes: Maximum number of top-scoring nodes to return. + + Returns: + Dict mapping node_id -> PPR score for the top-k nodes. 
+ """ + personalization = {n: 0.0 for n in graph.nodes()} + personalization[seed] = 1.0 + + # NetworkX alpha = follow probability = 1 - our restart probability + scores = nx.pagerank( + graph, alpha=1 - alpha, personalization=personalization, tol=1e-12 + ) + top_k = heapq.nlargest(max_ppr_nodes, scores.items(), key=lambda x: x[1]) + return dict(top_k) + + +def _build_hetero_reference_graph(edge_dir: str = "in") -> nx.DiGraph: + """Build a NetworkX DiGraph for the heterogeneous test graph. + + Nodes are ``(type_str, id)`` tuples. For edge_dir="in", edges are reversed + (dst->src) so that NetworkX's outgoing-edge traversal matches GLT's + incoming-edge PPR walk. For edge_dir="out", edges keep their original + direction (src->dst). + """ + graph = nx.DiGraph() + for i in range(_NUM_TEST_USERS): + graph.add_node((str(USER), i)) + for i in range(_NUM_TEST_STORIES): + graph.add_node((str(STORY), i)) + + for edge_type, edge_index in _TEST_HETERO_EDGE_INDICES.items(): + src_type, _, dst_type = edge_type + src = edge_index[0].tolist() + dst = edge_index[1].tolist() + if edge_dir == "in": + for s, d in zip(src, dst): + graph.add_edge((str(dst_type), d), (str(src_type), s)) + else: + for s, d in zip(src, dst): + graph.add_edge((str(src_type), s), (str(dst_type), d)) + + return graph + + +def _reference_ppr_hetero( + graph: nx.DiGraph, + seed: int, + seed_type: str, + alpha: float, + max_ppr_nodes: int, +) -> dict[str, dict[int, float]]: + """Compute reference PPR scores for a heterogeneous graph using NetworkX. + + Args: + graph: NetworkX DiGraph with ``(type_str, id)`` tuple nodes. + seed: Seed node ID. + seed_type: Node type string of the seed. + alpha: Restart probability (our convention). + max_ppr_nodes: Maximum top-scoring nodes to return per node type. + + Returns: + Dict mapping node_type_str -> {node_id: PPR score} for top-k per type. 
+ """ + personalization = {n: 0.0 for n in graph.nodes()} + personalization[(seed_type, seed)] = 1.0 + + scores = nx.pagerank( + graph, alpha=1 - alpha, personalization=personalization, tol=1e-12 + ) + + type_to_scores: dict[str, dict[int, float]] = defaultdict(dict) + for (ntype, nid), score in scores.items(): + type_to_scores[ntype][nid] = score + + result: dict[str, dict[int, float]] = {} + for ntype, type_scores in type_to_scores.items(): + top_k = heapq.nlargest(max_ppr_nodes, type_scores.items(), key=lambda x: x[1]) + result[ntype] = dict(top_k) + + return result + + +# --------------------------------------------------------------------------- +# Spawned process functions +# --------------------------------------------------------------------------- +def _run_ppr_loader_correctness_check( + _: int, + dataset: DistDataset, + alpha: float, + max_ppr_nodes: int, +) -> None: + """Iterate homogeneous PPR loader and verify each batch against NetworkX PPR.""" + create_test_process_group() + + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[10], # Unused by PPR sampler; required by interface + sampler_options=PPRSamplerOptions( + alpha=alpha, + eps=_TEST_EPS, + max_ppr_nodes=max_ppr_nodes, + num_nbrs_per_hop=_TEST_NUM_NBRS_PER_HOP, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + reference_graph = _build_reference_graph() + + batches_checked = 0 + for datum in loader: + assert isinstance(datum, Data) + + # GLT's to_data() unpacks metadata dict keys directly onto the Data + # object (data[k] = v), so PPR results are top-level attributes. 
+ assert hasattr(datum, "ppr_neighbor_ids"), "Missing ppr_neighbor_ids on Data" + assert hasattr(datum, "ppr_weights"), "Missing ppr_weights on Data" + + ppr_edge_index = datum.ppr_neighbor_ids + ppr_weights = datum.ppr_weights + + assert ( + ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + ), f"Expected [2, X] edge_index, got shape {list(ppr_edge_index.shape)}" + assert ppr_weights.dim() == 1, f"Expected 1D weights, got {ppr_weights.dim()}D" + assert ppr_edge_index.size(1) == ppr_weights.size( + 0 + ), f"Edge count mismatch: {ppr_edge_index.size(1)} vs {ppr_weights.size(0)}" + assert (ppr_weights > 0).all(), "PPR weights must be positive" + assert ( + ppr_edge_index[0] == 0 + ).all(), "All src indices must be 0 for batch_size=1" + + # Map local indices to global IDs + global_node_ids = datum.node + seed_global_id = datum.batch[0].item() + + sampler_ppr: dict[int, float] = {} + for j in range(ppr_edge_index.size(1)): + local_dst = ppr_edge_index[1, j].item() + global_dst = global_node_ids[local_dst].item() + sampler_ppr[global_dst] = ppr_weights[j].item() + + # Compute reference PPR + reference_ppr = _reference_ppr( + graph=reference_graph, + seed=seed_global_id, + alpha=alpha, + max_ppr_nodes=max_ppr_nodes, + ) + + # Verify same top-k node set + assert set(sampler_ppr.keys()) == set(reference_ppr.keys()), ( + f"Seed {seed_global_id}: top-k node sets differ.\n" + f" Sampler: {sorted(sampler_ppr.keys())}\n" + f" Reference: {sorted(reference_ppr.keys())}" + ) + + # Forward push is an approximation; with eps=1e-6 the per-node error + # is bounded by O(alpha * eps * degree), so atol=1e-3 is generous. 
+ for node_id in reference_ppr: + ref_score = reference_ppr[node_id] + sam_score = sampler_ppr[node_id] + assert abs(sam_score - ref_score) < 1e-3, ( + f"Seed {seed_global_id}, node {node_id}: " + f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" + ) + + batches_checked += 1 + + assert ( + batches_checked == _NUM_TEST_NODES + ), f"Expected {_NUM_TEST_NODES} batches, got {batches_checked}" + shutdown_rpc() + + +def _run_ppr_hetero_loader_correctness_check( + _: int, + dataset: DistDataset, + alpha: float, + max_ppr_nodes: int, +) -> None: + """Iterate heterogeneous PPR loader and verify each batch against NetworkX PPR.""" + create_test_process_group() + + node_ids = dataset.node_ids + assert isinstance(node_ids, dict) + + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(USER, node_ids[USER]), + num_neighbors=[10], # Unused by PPR sampler; required by interface + sampler_options=PPRSamplerOptions( + alpha=alpha, + eps=_TEST_EPS, + max_ppr_nodes=max_ppr_nodes, + num_nbrs_per_hop=_TEST_NUM_NBRS_PER_HOP, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + reference_graph = _build_hetero_reference_graph() + + batches_checked = 0 + for datum in loader: + assert isinstance(datum, HeteroData) + + seed_global_id = datum[USER].batch[0].item() + + # Collect sampler PPR scores per node type + sampler_ppr_by_type: dict[str, dict[int, float]] = {} + for ntype in [USER, STORY]: + key_ids = f"ppr_neighbor_ids_{USER}_{ntype}" + key_weights = f"ppr_weights_{USER}_{ntype}" + + assert hasattr(datum, key_ids), f"Missing {key_ids}" + assert hasattr(datum, key_weights), f"Missing {key_weights}" + + ppr_edge_index = getattr(datum, key_ids) + ppr_weights = getattr(datum, key_weights) + + assert ( + ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + ), f"Expected [2, X] edge_index, got shape {list(ppr_edge_index.shape)}" + assert ppr_weights.dim() == 1 + assert ppr_edge_index.size(1) == ppr_weights.size(0) + assert (ppr_weights > 0).all(), f"PPR 
weights for {ntype} must be positive" + assert ( + ppr_edge_index[0] == 0 + ).all(), "All src indices must be 0 for batch_size=1" + + global_node_ids = datum[ntype].node + type_ppr: dict[int, float] = {} + for j in range(ppr_edge_index.size(1)): + local_dst = ppr_edge_index[1, j].item() + global_dst = global_node_ids[local_dst].item() + type_ppr[global_dst] = ppr_weights[j].item() + sampler_ppr_by_type[str(ntype)] = type_ppr + + # Compute reference PPR + reference_ppr = _reference_ppr_hetero( + graph=reference_graph, + seed=seed_global_id, + seed_type=str(USER), + alpha=alpha, + max_ppr_nodes=max_ppr_nodes, + ) + + # Verify per node type + for ntype_str in [str(USER), str(STORY)]: + assert set(sampler_ppr_by_type[ntype_str].keys()) == set( + reference_ppr[ntype_str].keys() + ), ( + f"Seed {seed_global_id}, type {ntype_str}: top-k node sets differ.\n" + f" Sampler: {sorted(sampler_ppr_by_type[ntype_str].keys())}\n" + f" Reference: {sorted(reference_ppr[ntype_str].keys())}" + ) + + for node_id in reference_ppr[ntype_str]: + ref_score = reference_ppr[ntype_str][node_id] + sam_score = sampler_ppr_by_type[ntype_str][node_id] + assert abs(sam_score - ref_score) < 1e-3, ( + f"Seed {seed_global_id}, type {ntype_str}, node {node_id}: " + f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" + ) + + batches_checked += 1 + + assert ( + batches_checked == _NUM_TEST_USERS + ), f"Expected {_NUM_TEST_USERS} batches, got {batches_checked}" + shutdown_rpc() + + +def _run_ppr_ablp_loader_correctness_check( + _: int, + alpha: float, + max_ppr_nodes: int, +) -> None: + """Iterate ABLP PPR loader and verify anchor-seed PPR against NetworkX reference. + + Checks both anchor (USER) seed PPR scores for correctness against NetworkX, + and verifies that supervision (STORY) seed PPR metadata is present with + valid shapes. Also confirms that ABLP-specific output (y_positive) is + produced alongside PPR metadata. 
+ + The ABLP dataset is created inside this spawned process because the + splitter requires torch.distributed to be initialized. + """ + create_test_process_group() + + dataset = create_heterogeneous_dataset_for_ablp( + positive_labels={0: [0, 1], 1: [1, 2], 2: [0, 2]}, + train_node_ids=[0, 1], + val_node_ids=[2], + test_node_ids=[], + edge_indices=_TEST_HETERO_EDGE_INDICES, + edge_dir="out", + ) + + train_node_ids = dataset.train_node_ids + assert isinstance(train_node_ids, dict) + + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[10], # Unused by PPR sampler; required by interface + input_nodes=(USER, train_node_ids[USER]), + supervision_edge_type=USER_TO_STORY, + sampler_options=PPRSamplerOptions( + alpha=alpha, + eps=_TEST_EPS, + max_ppr_nodes=max_ppr_nodes, + num_nbrs_per_hop=_TEST_NUM_NBRS_PER_HOP, + ), + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + + reference_graph = _build_hetero_reference_graph(edge_dir="out") + + batches_checked = 0 + for datum in loader: + assert isinstance(datum, HeteroData) + + # ABLP should produce positive labels alongside PPR metadata + assert hasattr(datum, "y_positive"), "Missing y_positive on HeteroData" + + seed_global_id = datum[USER].batch[0].item() + + # --- Verify anchor (USER) seed PPR correctness against NetworkX --- + sampler_ppr_by_type: dict[str, dict[int, float]] = {} + for ntype in [USER, STORY]: + key_ids = f"ppr_neighbor_ids_{USER}_{ntype}" + key_weights = f"ppr_weights_{USER}_{ntype}" + + assert hasattr(datum, key_ids), f"Missing {key_ids}" + assert hasattr(datum, key_weights), f"Missing {key_weights}" + + ppr_edge_index = getattr(datum, key_ids) + ppr_weights = getattr(datum, key_weights) + + assert ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + assert ppr_weights.dim() == 1 + assert ppr_edge_index.size(1) == ppr_weights.size(0) + assert (ppr_weights > 0).all() + assert (ppr_edge_index[0] == 0).all() # batch_size=1 + + global_node_ids = datum[ntype].node + type_ppr: 
dict[int, float] = {} + for j in range(ppr_edge_index.size(1)): + local_dst = ppr_edge_index[1, j].item() + global_dst = global_node_ids[local_dst].item() + type_ppr[global_dst] = ppr_weights[j].item() + sampler_ppr_by_type[str(ntype)] = type_ppr + + reference_ppr = _reference_ppr_hetero( + graph=reference_graph, + seed=seed_global_id, + seed_type=str(USER), + alpha=alpha, + max_ppr_nodes=max_ppr_nodes, + ) + + for ntype_str in [str(USER), str(STORY)]: + assert set(sampler_ppr_by_type[ntype_str].keys()) == set( + reference_ppr[ntype_str].keys() + ), ( + f"ABLP seed {seed_global_id}, type {ntype_str}: top-k node sets differ.\n" + f" Sampler: {sorted(sampler_ppr_by_type[ntype_str].keys())}\n" + f" Reference: {sorted(reference_ppr[ntype_str].keys())}" + ) + + for node_id in reference_ppr[ntype_str]: + ref_score = reference_ppr[ntype_str][node_id] + sam_score = sampler_ppr_by_type[ntype_str][node_id] + assert abs(sam_score - ref_score) < 1e-3, ( + f"ABLP seed {seed_global_id}, type {ntype_str}, node {node_id}: " + f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" + ) + + # --- Verify supervision (STORY) seed PPR metadata --- + # ABLP adds supervision nodes as additional seeds, producing PPR metadata + # keyed by the STORY seed type. 
+ for ntype in [USER, STORY]: + key_ids = f"ppr_neighbor_ids_{STORY}_{ntype}" + key_weights = f"ppr_weights_{STORY}_{ntype}" + + assert hasattr(datum, key_ids), f"Missing {key_ids}" + assert hasattr(datum, key_weights), f"Missing {key_weights}" + + ppr_edge_index = getattr(datum, key_ids) + ppr_weights = getattr(datum, key_weights) + + assert ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + assert ppr_weights.dim() == 1 + assert ppr_edge_index.size(1) == ppr_weights.size(0) + if ppr_weights.numel() > 0: + assert (ppr_weights > 0).all() + assert (ppr_edge_index[1] >= 0).all() + assert (ppr_edge_index[1] < datum[ntype].node.size(0)).all() + + batches_checked += 1 + + assert batches_checked > 0, "Expected at least one ABLP batch" + shutdown_rpc() + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- +class DistPPRSamplerTest(TestCase): + def setUp(self) -> None: + super().setUp() + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + super().tearDown() + + def test_ppr_sampler_correctness_homogeneous(self) -> None: + """Verify PPR scores match NetworkX pagerank on a small homogeneous graph.""" + dataset = create_homogeneous_dataset(edge_index=_TEST_EDGE_INDEX, edge_dir="in") + mp.spawn( + fn=_run_ppr_loader_correctness_check, + args=(dataset, _TEST_ALPHA, _TEST_MAX_PPR_NODES), + ) + + def test_ppr_sampler_correctness_heterogeneous(self) -> None: + """Verify PPR scores match NetworkX pagerank on a heterogeneous bipartite graph.""" + dataset = create_heterogeneous_dataset( + edge_indices=_TEST_HETERO_EDGE_INDICES, edge_dir="in" + ) + mp.spawn( + fn=_run_ppr_hetero_loader_correctness_check, + args=(dataset, _TEST_ALPHA, _TEST_MAX_PPR_NODES), + ) + + def test_ppr_sampler_ablp_correctness(self) -> None: + """Verify PPR scores through DistABLPLoader on a heterogeneous graph.""" + 
mp.spawn( + fn=_run_ppr_ablp_loader_correctness_check, + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES), + ) + + +if __name__ == "__main__": + absltest.main() From 53b228483e20959484f03b62e89f65a65cca8b47 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 11 Mar 2026 23:09:48 +0000 Subject: [PATCH 31/46] Clean up PPR tests and fix metadata stripping in DistNeighborLoader --- .../distributed/distributed_neighborloader.py | 31 ++ .../unit/distributed/dist_ppr_sampler_test.py | 296 ++++++++++-------- 2 files changed, 204 insertions(+), 123 deletions(-) diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 9c72500b1..f59b732ee 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -558,6 +558,12 @@ def _setup_for_colocated( ) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: + # Extract user-defined metadata (e.g. PPR scores) before + # super()._collate_fn, which calls GLT's to_hetero_data. + # to_hetero_data misinterprets #META. keys as edge types and + # fails when edge_dir="out" (tries to reverse_edge_type on them). + # We strip them here and re-apply after conversion. + non_edge_metadata = self._extract_metadata(msg) data = super()._collate_fn(msg) data = set_missing_features( data=data, @@ -569,4 +575,29 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = strip_label_edges(data) if self._is_homogeneous_with_labeled_edge_type: data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data) + for key, value in non_edge_metadata.items(): + data[key] = value return data + + def _extract_metadata(self, msg: SampleMessage) -> dict[str, torch.Tensor]: + """Extract and remove user-defined metadata from a SampleMessage. 
+
+        GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as
+        edge types, causing failures with ``edge_dir="out"`` (it tries to call
+        ``reverse_edge_type`` on metadata key strings). This method strips
+        those keys so the conversion succeeds, returning them for re-application
+        onto the output Data/HeteroData.
+
+        Args:
+            msg: The SampleMessage to modify in-place.
+
+        Returns:
+            Dict mapping metadata key (without ``#META.`` prefix) to tensor.
+        """
+        meta_prefix = "#META."
+        result: dict[str, torch.Tensor] = {}
+        for k in list(msg.keys()):
+            if k.startswith(meta_prefix):
+                result[k[len(meta_prefix) :]] = msg[k].to(self.to_device)
+                del msg[k]
+        return result
diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py
index b75519e75..d0617793c 100644
--- a/tests/unit/distributed/dist_ppr_sampler_test.py
+++ b/tests/unit/distributed/dist_ppr_sampler_test.py
@@ -3,20 +3,42 @@
 Verifies that the PPR scores produced by the distributed sampler match
 NetworkX's ``pagerank`` with personalization — a well-tested, independent
 PPR implementation.
+
+Note on compatibility with NetworkX:
+
+Both our forward push algorithm (Andersen et al., 2006) and NetworkX's
+``pagerank`` (power iteration) compute Personalized PageRank — they are
+different solvers for the same quantity. With a small residual tolerance
+(eps=1e-6), forward push converges close enough that per-node scores match
+NetworkX within atol=1e-3.
+
+Another note is that our ``alpha`` is the *restart* (teleport) probability — the probability of
+jumping back to the seed at each step. NetworkX's ``alpha`` is the *damping
+factor* — the probability of following an edge. These are complements::
+
+    nx_alpha = 1 - our_alpha
+
+Finally, with ``edge_dir="in"``, the PPR walk from node v follows *incoming* edges —
+it moves to nodes u where edge (u, v) exists in the graph. NetworkX's
+``pagerank`` follows *outgoing* edges.
To make NetworkX traverse the same +neighbors as the sampler, we reverse the edges when building the reference +graph (add dst→src instead of src→dst). When ``edge_dir="out"``, no +reversal is needed since both follow the original edge direction. """ import heapq from collections import defaultdict +from typing import Literal import networkx as nx import torch import torch.multiprocessing as mp from absl.testing import absltest from graphlearn_torch.distributed import shutdown_rpc +from parameterized import param, parameterized from torch_geometric.data import Data, HeteroData from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader -from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.sampler_options import PPRSamplerOptions from tests.test_assets.distributed.test_dataset import ( @@ -73,20 +95,23 @@ # --------------------------------------------------------------------------- # Reference PPR implementations (NetworkX-based) # --------------------------------------------------------------------------- -def _build_reference_graph() -> nx.DiGraph: - """Build a NetworkX DiGraph matching the homogeneous test edge_index with edge_dir="in". +def _build_reference_graph(edge_dir: Literal["in", "out"] = "in") -> nx.DiGraph: + """Build a NetworkX DiGraph matching the homogeneous test edge_index. + + With ``edge_dir="in"``, edges are reversed (dst→src) so that NetworkX's + outgoing-edge traversal matches GLT's incoming-edge PPR walk. With + ``edge_dir="out"``, edges keep their original direction (src→dst). - With edge_dir="in", the PPR walk from node v follows incoming edges — - i.e., it moves to nodes u where (u, v) exists. NetworkX follows outgoing - edges, so we add edges dst->src so that nx.pagerank traverses the same - neighbors as the sampler. + See the module docstring for a full explanation of why reversal is needed. 
""" graph = nx.DiGraph() graph.add_nodes_from(range(_NUM_TEST_NODES)) src = _TEST_EDGE_INDEX[0].tolist() dst = _TEST_EDGE_INDEX[1].tolist() - # Reverse direction: dst->src so outgoing edges in nx match incoming in GLT - graph.add_edges_from(zip(dst, src)) + if edge_dir == "in": + graph.add_edges_from(zip(dst, src)) + else: + graph.add_edges_from(zip(src, dst)) return graph @@ -98,13 +123,12 @@ def _reference_ppr( ) -> dict[int, float]: """Compute reference PPR scores for a homogeneous graph using NetworkX. + See the module docstring for the alpha mapping rationale. + Args: graph: NetworkX DiGraph with edges oriented for the sampling direction. seed: Seed node ID. - alpha: Restart probability (our convention). Mapped to NetworkX's - damping factor as ``nx_alpha = 1 - alpha``, since NetworkX's alpha - is the follow-edge probability while ours is the teleport - probability. + alpha: Restart probability (our convention). max_ppr_nodes: Maximum number of top-scoring nodes to return. Returns: @@ -113,7 +137,6 @@ def _reference_ppr( personalization = {n: 0.0 for n in graph.nodes()} personalization[seed] = 1.0 - # NetworkX alpha = follow probability = 1 - our restart probability scores = nx.pagerank( graph, alpha=1 - alpha, personalization=personalization, tol=1e-12 ) @@ -121,13 +144,12 @@ def _reference_ppr( return dict(top_k) -def _build_hetero_reference_graph(edge_dir: str = "in") -> nx.DiGraph: +def _build_hetero_reference_graph(edge_dir: Literal["in", "out"] = "in") -> nx.DiGraph: """Build a NetworkX DiGraph for the heterogeneous test graph. - Nodes are ``(type_str, id)`` tuples. For edge_dir="in", edges are reversed - (dst->src) so that NetworkX's outgoing-edge traversal matches GLT's - incoming-edge PPR walk. For edge_dir="out", edges keep their original - direction (src->dst). + Nodes are ``(type_str, id)`` tuples. 
Edge direction is handled the same + way as :func:`_build_reference_graph` — see the module docstring for the + full explanation of why reversal is needed for ``edge_dir="in"``. """ graph = nx.DiGraph() for i in range(_NUM_TEST_USERS): @@ -158,6 +180,8 @@ def _reference_ppr_hetero( ) -> dict[str, dict[int, float]]: """Compute reference PPR scores for a heterogeneous graph using NetworkX. + See the module docstring for the alpha mapping rationale. + Args: graph: NetworkX DiGraph with ``(type_str, id)`` tuple nodes. seed: Seed node ID. @@ -187,18 +211,110 @@ def _reference_ppr_hetero( return result +# --------------------------------------------------------------------------- +# Shared verification helpers +# --------------------------------------------------------------------------- +def _extract_hetero_ppr_scores( + datum: HeteroData, + seed_type: str, + node_types: list[str], +) -> dict[str, dict[int, float]]: + """Extract and validate PPR metadata from a HeteroData batch. + + Verifies tensor shapes and invariants (positive weights, valid indices), + maps local indices to global IDs, and returns scores grouped by node type. + + Args: + datum: A single HeteroData batch (batch_size=1). + seed_type: The seed node type used to key PPR metadata attributes. + node_types: Node types to extract PPR scores for. + + Returns: + Dict mapping node_type_str -> {global_node_id: ppr_score}. 
+ """ + sampler_ppr_by_type: dict[str, dict[int, float]] = {} + for ntype in node_types: + key_ids = f"ppr_neighbor_ids_{seed_type}_{ntype}" + key_weights = f"ppr_weights_{seed_type}_{ntype}" + + assert hasattr(datum, key_ids), f"Missing {key_ids}" + assert hasattr(datum, key_weights), f"Missing {key_weights}" + + ppr_edge_index = getattr(datum, key_ids) + ppr_weights = getattr(datum, key_weights) + + assert ( + ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 + ), f"Expected [2, X] edge_index, got shape {list(ppr_edge_index.shape)}" + assert ppr_weights.dim() == 1 + assert ppr_edge_index.size(1) == ppr_weights.size(0) + assert (ppr_weights > 0).all(), f"PPR weights for {ntype} must be positive" + assert ( + ppr_edge_index[0] == 0 + ).all(), "All src indices must be 0 for batch_size=1" + + global_node_ids = datum[ntype].node + type_ppr: dict[int, float] = {} + for j in range(ppr_edge_index.size(1)): + local_dst = ppr_edge_index[1, j].item() + global_dst = global_node_ids[local_dst].item() + type_ppr[global_dst] = ppr_weights[j].item() + sampler_ppr_by_type[str(ntype)] = type_ppr + + return sampler_ppr_by_type + + +def _assert_ppr_scores_match_reference( + sampler_ppr_by_type: dict[str, dict[int, float]], + reference_ppr: dict[str, dict[int, float]], + seed_id: int, + context_label: str = "", +) -> None: + """Assert sampler PPR scores match reference scores per node type. + + Checks that top-k node sets are identical and that per-node scores + are within atol=1e-3. The forward push error per node is bounded by + O(alpha * eps * degree), so atol=1e-3 is generous for eps=1e-6. + + Args: + sampler_ppr_by_type: Sampler output from :func:`_extract_hetero_ppr_scores`. + reference_ppr: Reference output from :func:`_reference_ppr_hetero`. + seed_id: Global seed node ID (for error messages). + context_label: Optional prefix for error messages (e.g. "ABLP"). 
+ """ + prefix = f"{context_label} seed" if context_label else f"Seed" + for ntype_str in reference_ppr: + assert set(sampler_ppr_by_type[ntype_str].keys()) == set( + reference_ppr[ntype_str].keys() + ), ( + f"{prefix} {seed_id}, type {ntype_str}: top-k node sets differ.\n" + f" Sampler: {sorted(sampler_ppr_by_type[ntype_str].keys())}\n" + f" Reference: {sorted(reference_ppr[ntype_str].keys())}" + ) + + for node_id in reference_ppr[ntype_str]: + ref_score = reference_ppr[ntype_str][node_id] + sam_score = sampler_ppr_by_type[ntype_str][node_id] + assert abs(sam_score - ref_score) < 1e-3, ( + f"{prefix} {seed_id}, type {ntype_str}, node {node_id}: " + f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" + ) + + # --------------------------------------------------------------------------- # Spawned process functions # --------------------------------------------------------------------------- def _run_ppr_loader_correctness_check( _: int, - dataset: DistDataset, alpha: float, max_ppr_nodes: int, + edge_dir: Literal["in", "out"], ) -> None: """Iterate homogeneous PPR loader and verify each batch against NetworkX PPR.""" create_test_process_group() + dataset = create_homogeneous_dataset(edge_index=_TEST_EDGE_INDEX, edge_dir=edge_dir) + loader = DistNeighborLoader( dataset=dataset, num_neighbors=[10], # Unused by PPR sampler; required by interface @@ -212,7 +328,7 @@ def _run_ppr_loader_correctness_check( batch_size=1, ) - reference_graph = _build_reference_graph() + reference_graph = _build_reference_graph(edge_dir) batches_checked = 0 for datum in loader: @@ -283,13 +399,17 @@ def _run_ppr_loader_correctness_check( def _run_ppr_hetero_loader_correctness_check( _: int, - dataset: DistDataset, alpha: float, max_ppr_nodes: int, + edge_dir: Literal["in", "out"], ) -> None: """Iterate heterogeneous PPR loader and verify each batch against NetworkX PPR.""" create_test_process_group() + dataset = create_heterogeneous_dataset( + edge_indices=_TEST_HETERO_EDGE_INDICES, 
edge_dir=edge_dir + ) + node_ids = dataset.node_ids assert isinstance(node_ids, dict) @@ -307,7 +427,7 @@ def _run_ppr_hetero_loader_correctness_check( batch_size=1, ) - reference_graph = _build_hetero_reference_graph() + reference_graph = _build_hetero_reference_graph(edge_dir) batches_checked = 0 for datum in loader: @@ -315,37 +435,10 @@ def _run_ppr_hetero_loader_correctness_check( seed_global_id = datum[USER].batch[0].item() - # Collect sampler PPR scores per node type - sampler_ppr_by_type: dict[str, dict[int, float]] = {} - for ntype in [USER, STORY]: - key_ids = f"ppr_neighbor_ids_{USER}_{ntype}" - key_weights = f"ppr_weights_{USER}_{ntype}" - - assert hasattr(datum, key_ids), f"Missing {key_ids}" - assert hasattr(datum, key_weights), f"Missing {key_weights}" - - ppr_edge_index = getattr(datum, key_ids) - ppr_weights = getattr(datum, key_weights) - - assert ( - ppr_edge_index.dim() == 2 and ppr_edge_index.size(0) == 2 - ), f"Expected [2, X] edge_index, got shape {list(ppr_edge_index.shape)}" - assert ppr_weights.dim() == 1 - assert ppr_edge_index.size(1) == ppr_weights.size(0) - assert (ppr_weights > 0).all(), f"PPR weights for {ntype} must be positive" - assert ( - ppr_edge_index[0] == 0 - ).all(), "All src indices must be 0 for batch_size=1" - - global_node_ids = datum[ntype].node - type_ppr: dict[int, float] = {} - for j in range(ppr_edge_index.size(1)): - local_dst = ppr_edge_index[1, j].item() - global_dst = global_node_ids[local_dst].item() - type_ppr[global_dst] = ppr_weights[j].item() - sampler_ppr_by_type[str(ntype)] = type_ppr + sampler_ppr_by_type = _extract_hetero_ppr_scores( + datum, str(USER), [USER, STORY] + ) - # Compute reference PPR reference_ppr = _reference_ppr_hetero( graph=reference_graph, seed=seed_global_id, @@ -354,23 +447,9 @@ def _run_ppr_hetero_loader_correctness_check( max_ppr_nodes=max_ppr_nodes, ) - # Verify per node type - for ntype_str in [str(USER), str(STORY)]: - assert set(sampler_ppr_by_type[ntype_str].keys()) == set( - 
reference_ppr[ntype_str].keys() - ), ( - f"Seed {seed_global_id}, type {ntype_str}: top-k node sets differ.\n" - f" Sampler: {sorted(sampler_ppr_by_type[ntype_str].keys())}\n" - f" Reference: {sorted(reference_ppr[ntype_str].keys())}" - ) - - for node_id in reference_ppr[ntype_str]: - ref_score = reference_ppr[ntype_str][node_id] - sam_score = sampler_ppr_by_type[ntype_str][node_id] - assert abs(sam_score - ref_score) < 1e-3, ( - f"Seed {seed_global_id}, type {ntype_str}, node {node_id}: " - f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" - ) + _assert_ppr_scores_match_reference( + sampler_ppr_by_type, reference_ppr, seed_global_id + ) batches_checked += 1 @@ -384,6 +463,7 @@ def _run_ppr_ablp_loader_correctness_check( _: int, alpha: float, max_ppr_nodes: int, + edge_dir: Literal["in", "out"], ) -> None: """Iterate ABLP PPR loader and verify anchor-seed PPR against NetworkX reference. @@ -403,7 +483,7 @@ def _run_ppr_ablp_loader_correctness_check( val_node_ids=[2], test_node_ids=[], edge_indices=_TEST_HETERO_EDGE_INDICES, - edge_dir="out", + edge_dir=edge_dir, ) train_node_ids = dataset.train_node_ids @@ -424,7 +504,7 @@ def _run_ppr_ablp_loader_correctness_check( batch_size=1, ) - reference_graph = _build_hetero_reference_graph(edge_dir="out") + reference_graph = _build_hetero_reference_graph(edge_dir=edge_dir) batches_checked = 0 for datum in loader: @@ -436,30 +516,9 @@ def _run_ppr_ablp_loader_correctness_check( seed_global_id = datum[USER].batch[0].item() # --- Verify anchor (USER) seed PPR correctness against NetworkX --- - sampler_ppr_by_type: dict[str, dict[int, float]] = {} - for ntype in [USER, STORY]: - key_ids = f"ppr_neighbor_ids_{USER}_{ntype}" - key_weights = f"ppr_weights_{USER}_{ntype}" - - assert hasattr(datum, key_ids), f"Missing {key_ids}" - assert hasattr(datum, key_weights), f"Missing {key_weights}" - - ppr_edge_index = getattr(datum, key_ids) - ppr_weights = getattr(datum, key_weights) - - assert ppr_edge_index.dim() == 2 and 
ppr_edge_index.size(0) == 2 - assert ppr_weights.dim() == 1 - assert ppr_edge_index.size(1) == ppr_weights.size(0) - assert (ppr_weights > 0).all() - assert (ppr_edge_index[0] == 0).all() # batch_size=1 - - global_node_ids = datum[ntype].node - type_ppr: dict[int, float] = {} - for j in range(ppr_edge_index.size(1)): - local_dst = ppr_edge_index[1, j].item() - global_dst = global_node_ids[local_dst].item() - type_ppr[global_dst] = ppr_weights[j].item() - sampler_ppr_by_type[str(ntype)] = type_ppr + sampler_ppr_by_type = _extract_hetero_ppr_scores( + datum, str(USER), [USER, STORY] + ) reference_ppr = _reference_ppr_hetero( graph=reference_graph, @@ -469,22 +528,9 @@ def _run_ppr_ablp_loader_correctness_check( max_ppr_nodes=max_ppr_nodes, ) - for ntype_str in [str(USER), str(STORY)]: - assert set(sampler_ppr_by_type[ntype_str].keys()) == set( - reference_ppr[ntype_str].keys() - ), ( - f"ABLP seed {seed_global_id}, type {ntype_str}: top-k node sets differ.\n" - f" Sampler: {sorted(sampler_ppr_by_type[ntype_str].keys())}\n" - f" Reference: {sorted(reference_ppr[ntype_str].keys())}" - ) - - for node_id in reference_ppr[ntype_str]: - ref_score = reference_ppr[ntype_str][node_id] - sam_score = sampler_ppr_by_type[ntype_str][node_id] - assert abs(sam_score - ref_score) < 1e-3, ( - f"ABLP seed {seed_global_id}, type {ntype_str}, node {node_id}: " - f"sampler={sam_score:.6f} vs reference={ref_score:.6f}" - ) + _assert_ppr_scores_match_reference( + sampler_ppr_by_type, reference_ppr, seed_global_id, context_label="ABLP" + ) # --- Verify supervision (STORY) seed PPR metadata --- # ABLP adds supervision nodes as additional seeds, producing PPR metadata @@ -525,29 +571,33 @@ def tearDown(self) -> None: torch.distributed.destroy_process_group() super().tearDown() - def test_ppr_sampler_correctness_homogeneous(self) -> None: + @parameterized.expand([param("in"), param("out")]) + def test_ppr_sampler_correctness_homogeneous(self, edge_dir: str) -> None: """Verify PPR scores match 
NetworkX pagerank on a small homogeneous graph.""" - dataset = create_homogeneous_dataset(edge_index=_TEST_EDGE_INDEX, edge_dir="in") mp.spawn( fn=_run_ppr_loader_correctness_check, - args=(dataset, _TEST_ALPHA, _TEST_MAX_PPR_NODES), + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), ) - def test_ppr_sampler_correctness_heterogeneous(self) -> None: + @parameterized.expand([param("in"), param("out")]) + def test_ppr_sampler_correctness_heterogeneous(self, edge_dir: str) -> None: """Verify PPR scores match NetworkX pagerank on a heterogeneous bipartite graph.""" - dataset = create_heterogeneous_dataset( - edge_indices=_TEST_HETERO_EDGE_INDICES, edge_dir="in" - ) mp.spawn( fn=_run_ppr_hetero_loader_correctness_check, - args=(dataset, _TEST_ALPHA, _TEST_MAX_PPR_NODES), + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), ) - def test_ppr_sampler_ablp_correctness(self) -> None: - """Verify PPR scores through DistABLPLoader on a heterogeneous graph.""" + @parameterized.expand([param("out")]) + def test_ppr_sampler_ablp_correctness(self, edge_dir: str) -> None: + """Verify PPR scores through DistABLPLoader on a heterogeneous graph. + + Only tests ``edge_dir="out"`` because ``DistNodeAnchorLinkSplitter`` + with ``edge_dir="in"`` reverses the supervision edge type, requiring + a reversed labeled edge type that the test dataset does not include. 
+ """ mp.spawn( fn=_run_ppr_ablp_loader_correctness_check, - args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES), + args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), ) From 2d0bc65b3b7fc65aa19f25f3d4b5da21ef1476f6 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Wed, 11 Mar 2026 23:20:33 +0000 Subject: [PATCH 32/46] Improve PPR sampler readability: rename variables, add comments --- gigl/distributed/dist_ppr_sampler.py | 89 +++++++------------ .../unit/distributed/dist_ppr_sampler_test.py | 8 +- 2 files changed, 37 insertions(+), 60 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 08bb69548..4ac5ae6b8 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -16,6 +16,9 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_neighbor_sampler import DistNeighborSampler +# Sentinel type names for homogeneous graphs. The PPR algorithm uses +# dict[NodeType, ...] internally for both homo and hetero graphs; these +# sentinels let the homogeneous path reuse the same dict-based code. 
_PPR_HOMOGENEOUS_NODE_TYPE = "ppr_homogeneous_node_type" _PPR_HOMOGENEOUS_EDGE_TYPE = ( _PPR_HOMOGENEOUS_NODE_TYPE, @@ -60,7 +63,7 @@ def __init__( self._alpha = alpha self._eps = eps self._max_ppr_nodes = max_ppr_nodes - self._alpha_eps = alpha * eps + self._requeue_threshold_factor = alpha * eps self._num_nbrs_per_hop = num_nbrs_per_hop assert isinstance( @@ -241,24 +244,19 @@ async def _compute_ppr_scores( device = seed_nodes.device batch_size = seed_nodes.size(0) - # PPR scores: p[i][(node_id, node_type)] = score - p: list[dict[tuple[int, NodeType], float]] = [ + ppr_scores: list[dict[tuple[int, NodeType], float]] = [ defaultdict(float) for _ in range(batch_size) ] - # Residuals: r[i][(node_id, node_type)] = residual - r: list[dict[tuple[int, NodeType], float]] = [ + residuals: list[dict[tuple[int, NodeType], float]] = [ defaultdict(float) for _ in range(batch_size) ] - - # Queue stores (node_id, node_type) tuples - q: list[set[tuple[int, NodeType]]] = [set() for _ in range(batch_size)] + queue: list[set[tuple[int, NodeType]]] = [set() for _ in range(batch_size)] seed_list = seed_nodes.tolist() - # Initialize residuals: r[i][(seed, seed_type)] = alpha for each seed for i, seed in enumerate(seed_list): - r[i][(seed, seed_node_type)] = self._alpha - q[i].add((seed, seed_node_type)) + residuals[i][(seed, seed_node_type)] = self._alpha + queue[i].add((seed, seed_node_type)) # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} @@ -294,12 +292,11 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: nodes_by_edge_type: dict[EdgeType, set[int]] = defaultdict(set) for i in range(batch_size): - if q[i]: - nodes_to_process[i] = q[i] - q[i] = set() + if queue[i]: + nodes_to_process[i] = queue[i] + queue[i] = set() num_nodes_in_queue -= len(nodes_to_process[i]) - # Group nodes by edge type for batched lookups for node_id, node_type in nodes_to_process[i]: 
edge_types_for_node = self._node_type_to_edge_types[node_type] for etype in edge_types_for_node: @@ -307,35 +304,21 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: if cache_key not in neighbor_cache: nodes_by_edge_type[etype].add(node_id) - # Batch fetch neighbors per edge type await self._batch_fetch_neighbors( nodes_by_edge_type, neighbor_cache, device ) - # Push residual to neighbors and re-queue in a single pass. - # - # Previously these were two separate loops over the same neighbor - # lists — one to push residual, one to check thresholds. Merging - # them halves the total neighbor-list iteration. - # - # This is safe because each seed's state (p, r, q) is independent. - # If node v receives residual from multiple frontier nodes (u1, u2) - # of the same seed, v's threshold is checked after each push. The - # last frontier node to push to v sees the same accumulated residual - # that the original two-pass version would see. Since push values - # are always positive (residual monotonically increases), the merged - # version can never miss a re-queue that the two-pass version would - # catch. + # Push residual to neighbors and re-queue in a single pass. This + # is safe because each seed's state is independent, and residuals + # are always positive so the merged loop can never miss a re-queue. 
for i in range(batch_size): for u_node, u_type in nodes_to_process[i]: key_u = (u_node, u_type) - res_u = r[i].get(key_u, 0.0) + res_u = residuals[i].get(key_u, 0.0) - # Push to PPR score and reset residual - p[i][key_u] += res_u - r[i][key_u] = 0.0 + ppr_scores[i][key_u] += res_u + residuals[i][key_u] = 0.0 - # For each edge type from this node type, push residual to neighbors edge_types_for_node = self._node_type_to_edge_types[u_type] total_degree = _get_total_degree(u_node, u_type) @@ -345,10 +328,6 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: push_value = one_minus_alpha * res_u / total_degree - # Push residual proportionally based on degree per edge type. - # Per-edge-type degree is retrieved from _get_total_degree's - # cached sum path — the individual lookups are only needed - # to detect zero-degree edge types. for etype in edge_types_for_node: cache_key = (u_node, etype) neighbor_list = neighbor_cache[cache_key] @@ -359,16 +338,15 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: for v_node in neighbor_list: key_v = (v_node, v_type) - r[i][key_v] += push_value + residuals[i][key_v] += push_value - # Inline re-queue check: if v is not already queued - # and its accumulated residual exceeds the threshold, - # add it to the queue for the next iteration. - if key_v not in q[i]: - if r[i][key_v] >= self._alpha_eps * _get_total_degree( + if key_v not in queue[i]: + if residuals[i][ + key_v + ] >= self._requeue_threshold_factor * _get_total_degree( v_node, v_type ): - q[i].add(key_v) + queue[i].add(key_v) num_nodes_in_queue += 1 # Extract top-k nodes by PPR score, grouped by node type. @@ -376,7 +354,7 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: # neighbors seed i actually has, so callers can recover per-seed slices. 
all_node_types: set[NodeType] = set() for i in range(batch_size): - for _node_id, node_type in p[i].keys(): + for _node_id, node_type in ppr_scores[i].keys(): all_node_types.add(node_type) out_flat_ids_dict: dict[NodeType, torch.Tensor] = {} @@ -391,7 +369,7 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: for i in range(batch_size): type_scores = { node_id: score - for (node_id, node_type), score in p[i].items() + for (node_id, node_type), score in ppr_scores[i].items() if node_type == ntype } top_k = heapq.nlargest( @@ -604,13 +582,14 @@ async def _sample_from_nodes( # srcs holds their global IDs (same values as nodes_to_sample). srcs = inducer.init_node(nodes_to_sample) - homo_ppr_result = await self._compute_ppr_scores(nodes_to_sample, None) - assert isinstance(homo_ppr_result[0], torch.Tensor) - assert isinstance(homo_ppr_result[1], torch.Tensor) - assert isinstance(homo_ppr_result[2], torch.Tensor) - homo_flat_ids: torch.Tensor = homo_ppr_result[0] - homo_flat_weights: torch.Tensor = homo_ppr_result[1] - homo_valid_counts: torch.Tensor = homo_ppr_result[2] + ( + homo_flat_ids, + homo_flat_weights, + homo_valid_counts, + ) = await self._compute_ppr_scores(nodes_to_sample, None) + assert isinstance(homo_flat_ids, torch.Tensor) + assert isinstance(homo_flat_weights, torch.Tensor) + assert isinstance(homo_valid_counts, torch.Tensor) # induce_next deduplicates homo_flat_ids against already-seen nodes # (the seeds registered above) and returns: diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index d0617793c..f3538ea78 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -89,7 +89,6 @@ _TEST_ALPHA = 0.5 _TEST_EPS = 1e-6 _TEST_MAX_PPR_NODES = 5 -_TEST_NUM_NBRS_PER_HOP = 100000 # --------------------------------------------------------------------------- @@ -322,7 +321,6 @@ def _run_ppr_loader_correctness_check( 
alpha=alpha, eps=_TEST_EPS, max_ppr_nodes=max_ppr_nodes, - num_nbrs_per_hop=_TEST_NUM_NBRS_PER_HOP, ), pin_memory_device=torch.device("cpu"), batch_size=1, @@ -421,7 +419,6 @@ def _run_ppr_hetero_loader_correctness_check( alpha=alpha, eps=_TEST_EPS, max_ppr_nodes=max_ppr_nodes, - num_nbrs_per_hop=_TEST_NUM_NBRS_PER_HOP, ), pin_memory_device=torch.device("cpu"), batch_size=1, @@ -498,7 +495,6 @@ def _run_ppr_ablp_loader_correctness_check( alpha=alpha, eps=_TEST_EPS, max_ppr_nodes=max_ppr_nodes, - num_nbrs_per_hop=_TEST_NUM_NBRS_PER_HOP, ), pin_memory_device=torch.device("cpu"), batch_size=1, @@ -534,7 +530,9 @@ def _run_ppr_ablp_loader_correctness_check( # --- Verify supervision (STORY) seed PPR metadata --- # ABLP adds supervision nodes as additional seeds, producing PPR metadata - # keyed by the STORY seed type. + # keyed by the STORY seed type. We only check shapes here (not correctness + # against NetworkX) because the supervision seeds vary per batch depending + # on the label edges, making deterministic reference computation complex. 
for ntype in [USER, STORY]: key_ids = f"ppr_neighbor_ids_{STORY}_{ntype}" key_weights = f"ppr_weights_{STORY}_{ntype}" From 9f0a67203f0fe23548cee8126453003f4626d3c3 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 12 Mar 2026 23:14:17 +0000 Subject: [PATCH 33/46] Unify metadata extraction into BaseDistLoader --- gigl/distributed/base_dist_loader.py | 25 +++- gigl/distributed/dist_ablp_neighborloader.py | 120 ++++++------------ .../distributed/distributed_neighborloader.py | 23 ---- 3 files changed, 64 insertions(+), 104 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 17c73905a..3ff1e6ab7 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel +from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel from graphlearn_torch.distributed import ( DistLoader, MpDistSamplingWorkerOptions, @@ -677,6 +677,29 @@ def shutdown(self) -> None: torch.futures.wait_all(rpc_futures) self._shutdowned = True + def _extract_metadata(self, msg: SampleMessage) -> dict[str, torch.Tensor]: + """Extract and remove user-defined metadata from a SampleMessage. + + GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as + edge types, causing failures with ``edge_dir="out"`` (it tries to call + ``reverse_edge_type`` on metadata key strings). This method strips + those keys so the conversion succeeds, returning them for re-application + onto the output Data/HeteroData. + + Args: + msg: The SampleMessage to modify in-place. + + Returns: + Dict mapping metadata key (without ``#META.`` prefix) to tensor. + """ + meta_prefix = "#META." 
+ result: dict[str, torch.Tensor] = {} + for k in list(msg.keys()): + if k.startswith(meta_prefix): + result[k[len(meta_prefix) :]] = msg[k].to(self.to_device) + del msg[k] + return result + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: self._num_recv = 0 diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index d43bc4b63..63e9c0657 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -25,7 +25,6 @@ NEGATIVE_LABEL_METADATA_KEY, POSITIVE_LABEL_METADATA_KEY, ABLPNodeSamplerInput, - metadata_key_with_prefix, ) from gigl.distributed.sampler_options import SamplerOptions, resolve_sampler_options from gigl.distributed.utils.neighborloader import ( @@ -772,72 +771,58 @@ def _setup_for_graph_store( ), ) - def _get_labels( - self, msg: SampleMessage + def _extract_labels( + self, metadata: dict[str, torch.Tensor] ) -> tuple[ - SampleMessage, dict[EdgeType, torch.Tensor], dict[EdgeType, torch.Tensor], + dict[str, torch.Tensor], ]: - # TODO (mkolodner-sc): Remove the need to modify metadata once GLT's `to_hetero_data` function is fixed - f""" - Gets the labels from the output SampleMessage and removes them from the metadata. We need to remove the labels from GLT's metadata since the - `to_hetero_data` function strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of - building the HeteroData object. + """Partition pre-extracted metadata into labels and remaining metadata. + + Takes the metadata dict already extracted by ``_extract_metadata`` (keys + without the ``#META.`` prefix) and separates label entries from + non-label entries. + + Label keys use ``POSITIVE_LABEL_METADATA_KEY`` / ``NEGATIVE_LABEL_METADATA_KEY`` + prefixes followed by a string-encoded edge type tuple. 
If ``edge_dir`` + is ``"in"``, the edge type is reversed because GLT swaps src/dst + internally. Args: - msg (SampleMessage): All possible results from a sampler, including subgraph data, features, and used defined metadata + metadata: Dict of metadata keys (without ``#META.`` prefix) to tensors, + as returned by ``_extract_metadata``. + Returns: - SampleMessage: Updated sample messsage with the label fields removed - dict[EdgeType, torch.Tensor]: Dict[positive label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. - dict[EdgeType, torch.Tensor]: Dict[negative label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. - May be empty if no negative labels are present. + A 3-tuple of: + - positive_labels: Dict[label edge type, label ID tensor] + - negative_labels: Dict[label edge type, label ID tensor] (may be empty) + - remaining_metadata: Non-label metadata entries """ - metadata: dict[str, torch.Tensor] = {} positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} - # We update metadata with sepcial POSITIVE_LABEL_METADATA_KEY and NEGATIVE_LABEL_METADATA_KEY keys - # in gigl/distributed/dist_neighbor_sampler.py. - # We need to encode the tuples as strings because GLT requires the keys to be strings. - # As such, we decode the strings back into tuples, - # And then pop those keys out of the metadata as they are not needed otherwise. - # If edge_dir is "in", we need to reverse the edge type because GLT swaps src/dst for edge_dir = "out". - # NOTE: GLT *prepends* the keys with "#META." 
- positive_label_metadata_key_prefix = metadata_key_with_prefix( - POSITIVE_LABEL_METADATA_KEY - ) - negative_label_metadata_key_prefix = metadata_key_with_prefix( - NEGATIVE_LABEL_METADATA_KEY - ) - for k in list(msg.keys()): - if k.startswith(positive_label_metadata_key_prefix): - edge_type_str = k[len(positive_label_metadata_key_prefix) :] + remaining_metadata: dict[str, torch.Tensor] = {} + + for key, value in metadata.items(): + if key.startswith(POSITIVE_LABEL_METADATA_KEY): + edge_type_str = key[len(POSITIVE_LABEL_METADATA_KEY) :] edge_type = ast.literal_eval(edge_type_str) if self.edge_dir == "in": edge_type = reverse_edge_type(edge_type) - positive_labels_by_label_edge_type[edge_type] = msg[k].to( - self.to_device - ) - del msg[k] - elif k.startswith(negative_label_metadata_key_prefix): - edge_type_str = k[len(negative_label_metadata_key_prefix) :] + positive_labels_by_label_edge_type[edge_type] = value + elif key.startswith(NEGATIVE_LABEL_METADATA_KEY): + edge_type_str = key[len(NEGATIVE_LABEL_METADATA_KEY) :] edge_type = ast.literal_eval(edge_type_str) if self.edge_dir == "in": edge_type = reverse_edge_type(edge_type) - negative_labels_by_label_edge_type[edge_type] = msg[k].to( - self.to_device - ) - del msg[k] - elif k.startswith("#META."): - meta_key = str(k[len("#META.") :]) - metadata[meta_key] = msg[k].to(self.to_device) - del msg[k] + negative_labels_by_label_edge_type[edge_type] = value + else: + remaining_metadata[key] = value + return ( - msg, positive_labels_by_label_edge_type, negative_labels_by_label_edge_type, + remaining_metadata, ) def _set_labels( @@ -925,11 +910,13 @@ def _set_labels( return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: - # _get_labels strips ALL #META. keys from the message to work around a - # GLT bug in to_hetero_data. Collect non-label metadata beforehand so - # we can re-apply it to the output data after conversion. 
- non_label_metadata = self._extract_non_label_metadata(msg) - msg, positive_labels, negative_labels = self._get_labels(msg) + # _extract_metadata strips ALL #META. keys from the message to work + # around a GLT bug in to_hetero_data. _extract_labels then partitions + # the result into labels vs remaining non-label metadata. + all_metadata = self._extract_metadata(msg) + positive_labels, negative_labels, non_label_metadata = self._extract_labels( + all_metadata + ) data = super()._collate_fn(msg) data = set_missing_features( data=data, @@ -949,30 +936,3 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data[key] = value data = self._set_labels(data, positive_labels, negative_labels) return data - - def _extract_non_label_metadata( - self, msg: SampleMessage - ) -> dict[str, torch.Tensor]: - """Extract non-label metadata from a SampleMessage before _get_labels strips it. - - _get_labels removes ALL ``#META.`` keys from the message to avoid a GLT - bug in ``to_hetero_data``. This method reads non-label metadata (e.g. - PPR scores) so it can be re-applied to the output Data/HeteroData after - conversion. - - Args: - msg: The SampleMessage to scan (not modified). - - Returns: - Dict mapping metadata key (without ``#META.`` prefix) to tensor. - """ - meta_prefix = "#META." 
- label_prefixes = ( - metadata_key_with_prefix(POSITIVE_LABEL_METADATA_KEY), - metadata_key_with_prefix(NEGATIVE_LABEL_METADATA_KEY), - ) - result: dict[str, torch.Tensor] = {} - for k in msg.keys(): - if k.startswith(meta_prefix) and not k.startswith(label_prefixes): - result[k[len(meta_prefix) :]] = msg[k].to(self.to_device) - return result diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index f59b732ee..bdb03717c 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -578,26 +578,3 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: for key, value in non_edge_metadata.items(): data[key] = value return data - - def _extract_metadata(self, msg: SampleMessage) -> dict[str, torch.Tensor]: - """Extract and remove user-defined metadata from a SampleMessage. - - GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as - edge types, causing failures with ``edge_dir="out"`` (it tries to call - ``reverse_edge_type`` on metadata key strings). This method strips - those keys so the conversion succeeds, returning them for re-application - onto the output Data/HeteroData. - - Args: - msg: The SampleMessage to modify in-place. - - Returns: - Dict mapping metadata key (without ``#META.`` prefix) to tensor. - """ - meta_prefix = "#META." 
- result: dict[str, torch.Tensor] = {} - for k in list(msg.keys()): - if k.startswith(meta_prefix): - result[k[len(meta_prefix) :]] = msg[k].to(self.to_device) - del msg[k] - return result From 85f553f6127d77134fc5cde85a081ce6b39f38ad Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 12 Mar 2026 23:20:01 +0000 Subject: [PATCH 34/46] Remove unused metadata_key_with_prefix --- gigl/distributed/sampler.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/gigl/distributed/sampler.py b/gigl/distributed/sampler.py index 5d0d63fa9..e99dd65dc 100644 --- a/gigl/distributed/sampler.py +++ b/gigl/distributed/sampler.py @@ -10,14 +10,6 @@ NEGATIVE_LABEL_METADATA_KEY: Final[str] = "gigl_negative_labels." -def metadata_key_with_prefix(key: str) -> str: - """Prefixes the key with "#META - Do this as GLT also does this. - https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L714 - """ - return f"#META.{key}" - - class ABLPNodeSamplerInput(NodeSamplerInput): """ Sampler input specific for ABLP use case. 
Contains additional information about positive labels, negative labels, and the corresponding From 549c4305893c5fb79df7cc27079b56f84a0690d2 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Thu, 12 Mar 2026 23:37:27 +0000 Subject: [PATCH 35/46] Update --- gigl/distributed/base_dist_loader.py | 25 ++++- gigl/distributed/dist_ablp_neighborloader.py | 91 +++++++++---------- .../distributed/distributed_neighborloader.py | 8 ++ gigl/distributed/sampler.py | 8 -- 4 files changed, 74 insertions(+), 58 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 17c73905a..3ff1e6ab7 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel +from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel from graphlearn_torch.distributed import ( DistLoader, MpDistSamplingWorkerOptions, @@ -677,6 +677,29 @@ def shutdown(self) -> None: torch.futures.wait_all(rpc_futures) self._shutdowned = True + def _extract_metadata(self, msg: SampleMessage) -> dict[str, torch.Tensor]: + """Extract and remove user-defined metadata from a SampleMessage. + + GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as + edge types, causing failures with ``edge_dir="out"`` (it tries to call + ``reverse_edge_type`` on metadata key strings). This method strips + those keys so the conversion succeeds, returning them for re-application + onto the output Data/HeteroData. + + Args: + msg: The SampleMessage to modify in-place. + + Returns: + Dict mapping metadata key (without ``#META.`` prefix) to tensor. + """ + meta_prefix = "#META." 
+ result: dict[str, torch.Tensor] = {} + for k in list(msg.keys()): + if k.startswith(meta_prefix): + result[k[len(meta_prefix) :]] = msg[k].to(self.to_device) + del msg[k] + return result + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: self._num_recv = 0 diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 92059df87..63e9c0657 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -25,7 +25,6 @@ NEGATIVE_LABEL_METADATA_KEY, POSITIVE_LABEL_METADATA_KEY, ABLPNodeSamplerInput, - metadata_key_with_prefix, ) from gigl.distributed.sampler_options import SamplerOptions, resolve_sampler_options from gigl.distributed.utils.neighborloader import ( @@ -772,72 +771,58 @@ def _setup_for_graph_store( ), ) - def _get_labels( - self, msg: SampleMessage + def _extract_labels( + self, metadata: dict[str, torch.Tensor] ) -> tuple[ - SampleMessage, dict[EdgeType, torch.Tensor], dict[EdgeType, torch.Tensor], + dict[str, torch.Tensor], ]: - # TODO (mkolodner-sc): Remove the need to modify metadata once GLT's `to_hetero_data` function is fixed - f""" - Gets the labels from the output SampleMessage and removes them from the metadata. We need to remove the labels from GLT's metadata since the - `to_hetero_data` function strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of - building the HeteroData object. + """Partition pre-extracted metadata into labels and remaining metadata. + + Takes the metadata dict already extracted by ``_extract_metadata`` (keys + without the ``#META.`` prefix) and separates label entries from + non-label entries. + + Label keys use ``POSITIVE_LABEL_METADATA_KEY`` / ``NEGATIVE_LABEL_METADATA_KEY`` + prefixes followed by a string-encoded edge type tuple. 
If ``edge_dir`` + is ``"in"``, the edge type is reversed because GLT swaps src/dst + internally. Args: - msg (SampleMessage): All possible results from a sampler, including subgraph data, features, and used defined metadata + metadata: Dict of metadata keys (without ``#META.`` prefix) to tensors, + as returned by ``_extract_metadata``. + Returns: - SampleMessage: Updated sample messsage with the label fields removed - dict[EdgeType, torch.Tensor]: Dict[positive label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. - dict[EdgeType, torch.Tensor]: Dict[negative label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. - May be empty if no negative labels are present. + A 3-tuple of: + - positive_labels: Dict[label edge type, label ID tensor] + - negative_labels: Dict[label edge type, label ID tensor] (may be empty) + - remaining_metadata: Non-label metadata entries """ - metadata: dict[str, torch.Tensor] = {} positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} - # We update metadata with sepcial POSITIVE_LABEL_METADATA_KEY and NEGATIVE_LABEL_METADATA_KEY keys - # in gigl/distributed/dist_neighbor_sampler.py. - # We need to encode the tuples as strings because GLT requires the keys to be strings. - # As such, we decode the strings back into tuples, - # And then pop those keys out of the metadata as they are not needed otherwise. - # If edge_dir is "in", we need to reverse the edge type because GLT swaps src/dst for edge_dir = "out". - # NOTE: GLT *prepends* the keys with "#META." 
- positive_label_metadata_key_prefix = metadata_key_with_prefix( - POSITIVE_LABEL_METADATA_KEY - ) - negative_label_metadata_key_prefix = metadata_key_with_prefix( - NEGATIVE_LABEL_METADATA_KEY - ) - for k in list(msg.keys()): - if k.startswith(positive_label_metadata_key_prefix): - edge_type_str = k[len(positive_label_metadata_key_prefix) :] + remaining_metadata: dict[str, torch.Tensor] = {} + + for key, value in metadata.items(): + if key.startswith(POSITIVE_LABEL_METADATA_KEY): + edge_type_str = key[len(POSITIVE_LABEL_METADATA_KEY) :] edge_type = ast.literal_eval(edge_type_str) if self.edge_dir == "in": edge_type = reverse_edge_type(edge_type) - positive_labels_by_label_edge_type[edge_type] = msg[k].to( - self.to_device - ) - del msg[k] - elif k.startswith(negative_label_metadata_key_prefix): - edge_type_str = k[len(negative_label_metadata_key_prefix) :] + positive_labels_by_label_edge_type[edge_type] = value + elif key.startswith(NEGATIVE_LABEL_METADATA_KEY): + edge_type_str = key[len(NEGATIVE_LABEL_METADATA_KEY) :] edge_type = ast.literal_eval(edge_type_str) if self.edge_dir == "in": edge_type = reverse_edge_type(edge_type) - negative_labels_by_label_edge_type[edge_type] = msg[k].to( - self.to_device - ) - del msg[k] - elif k.startswith("#META."): - meta_key = str(k[len("#META.") :]) - metadata[meta_key] = msg[k].to(self.to_device) - del msg[k] + negative_labels_by_label_edge_type[edge_type] = value + else: + remaining_metadata[key] = value + return ( - msg, positive_labels_by_label_edge_type, negative_labels_by_label_edge_type, + remaining_metadata, ) def _set_labels( @@ -925,7 +910,13 @@ def _set_labels( return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: - msg, positive_labels, negative_labels = self._get_labels(msg) + # _extract_metadata strips ALL #META. keys from the message to work + # around a GLT bug in to_hetero_data. _extract_labels then partitions + # the result into labels vs remaining non-label metadata. 
+ all_metadata = self._extract_metadata(msg) + positive_labels, negative_labels, non_label_metadata = self._extract_labels( + all_metadata + ) data = super()._collate_fn(msg) data = set_missing_features( data=data, @@ -941,5 +932,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" ) data = labeled_to_homogeneous(self._supervision_edge_types[0], data) + for key, value in non_label_metadata.items(): + data[key] = value data = self._set_labels(data, positive_labels, negative_labels) return data diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 9c72500b1..bdb03717c 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -558,6 +558,12 @@ def _setup_for_colocated( ) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: + # Extract user-defined metadata (e.g. PPR scores) before + # super()._collate_fn, which calls GLT's to_hetero_data. + # to_hetero_data misinterprets #META. keys as edge types and + # fails when edge_dir="out" (tries to reverse_edge_type on them). + # We strip them here and re-apply after conversion. + non_edge_metadata = self._extract_metadata(msg) data = super()._collate_fn(msg) data = set_missing_features( data=data, @@ -569,4 +575,6 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = strip_label_edges(data) if self._is_homogeneous_with_labeled_edge_type: data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data) + for key, value in non_edge_metadata.items(): + data[key] = value return data diff --git a/gigl/distributed/sampler.py b/gigl/distributed/sampler.py index 5d0d63fa9..e99dd65dc 100644 --- a/gigl/distributed/sampler.py +++ b/gigl/distributed/sampler.py @@ -10,14 +10,6 @@ NEGATIVE_LABEL_METADATA_KEY: Final[str] = "gigl_negative_labels." 
-def metadata_key_with_prefix(key: str) -> str: - """Prefixes the key with "#META - Do this as GLT also does this. - https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_neighbor_sampler.py#L714 - """ - return f"#META.{key}" - - class ABLPNodeSamplerInput(NodeSamplerInput): """ Sampler input specific for ABLP use case. Contains additional information about positive labels, negative labels, and the corresponding From 6b4db900e1307f6d1eaf931daabd0427b9b53d60 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 13 Mar 2026 20:42:42 +0000 Subject: [PATCH 36/46] small update --- gigl/distributed/dist_ablp_neighborloader.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 63e9c0657..1efb1d03e 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -780,9 +780,13 @@ def _extract_labels( ]: """Partition pre-extracted metadata into labels and remaining metadata. + # TODO (mkolodner-sc): Remove the need to modify metadata once GLT's `to_hetero_data` function is fixed + Takes the metadata dict already extracted by ``_extract_metadata`` (keys without the ``#META.`` prefix) and separates label entries from - non-label entries. + non-label entries. We need to remove the labels from GLT's metadata since the `to_hetero_data` function + strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of + building the HeteroData object. Label keys use ``POSITIVE_LABEL_METADATA_KEY`` / ``NEGATIVE_LABEL_METADATA_KEY`` prefixes followed by a string-encoded edge type tuple. If ``edge_dir`` @@ -794,10 +798,12 @@ def _extract_labels( as returned by ``_extract_metadata``. 
Returns: - A 3-tuple of: - - positive_labels: Dict[label edge type, label ID tensor] - - negative_labels: Dict[label edge type, label ID tensor] (may be empty) - - remaining_metadata: Non-label metadata entries + dict[EdgeType, torch.Tensor]: Dict[positive label edge type, label ID tensor], + where the ith row of the tensor corresponds to the ith anchor node ID. + dict[EdgeType, torch.Tensor]: Dict[negative label edge type, label ID tensor], + where the ith row of the tensor corresponds to the ith anchor node ID. + May be empty if no negative labels are present. + dict[str, torch.Tensor]: Non-label metadata entries """ positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} From c9b3f80d24172ada183900aa9e955d1f925eb9d7 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 13 Mar 2026 21:56:18 +0000 Subject: [PATCH 37/46] Move extract_metadata to utility function with tests --- gigl/distributed/base_dist_loader.py | 25 +------- gigl/distributed/dist_ablp_neighborloader.py | 9 +-- .../distributed/distributed_neighborloader.py | 5 +- gigl/distributed/utils/neighborloader.py | 34 +++++++++++ .../distributed/utils/neighborloader_test.py | 60 +++++++++++++++++++ 5 files changed, 103 insertions(+), 30 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 3ff1e6ab7..17c73905a 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel +from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel from graphlearn_torch.distributed import ( DistLoader, MpDistSamplingWorkerOptions, @@ -677,29 +677,6 @@ def shutdown(self) -> None: torch.futures.wait_all(rpc_futures) self._shutdowned = True - def _extract_metadata(self, msg: SampleMessage) 
-> dict[str, torch.Tensor]: - """Extract and remove user-defined metadata from a SampleMessage. - - GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as - edge types, causing failures with ``edge_dir="out"`` (it tries to call - ``reverse_edge_type`` on metadata key strings). This method strips - those keys so the conversion succeeds, returning them for re-application - onto the output Data/HeteroData. - - Args: - msg: The SampleMessage to modify in-place. - - Returns: - Dict mapping metadata key (without ``#META.`` prefix) to tensor. - """ - meta_prefix = "#META." - result: dict[str, torch.Tensor] = {} - for k in list(msg.keys()): - if k.startswith(meta_prefix): - result[k[len(meta_prefix) :]] = msg[k].to(self.to_device) - del msg[k] - return result - # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: self._num_recv = 0 diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 1efb1d03e..c52081c86 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -30,6 +30,7 @@ from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, + extract_metadata, labeled_to_homogeneous, set_missing_features, shard_nodes_by_process, @@ -916,14 +917,14 @@ def _set_labels( return data def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: - # _extract_metadata strips ALL #META. keys from the message to work + # _extract_metadata separates #META. keys from the message to work # around a GLT bug in to_hetero_data. _extract_labels then partitions - # the result into labels vs remaining non-label metadata. - all_metadata = self._extract_metadata(msg) + # the metadata into labels vs remaining non-label metadata. 
+ all_metadata, stripped_msg = extract_metadata(msg, self.to_device) positive_labels, negative_labels, non_label_metadata = self._extract_labels( all_metadata ) - data = super()._collate_fn(msg) + data = super()._collate_fn(stripped_msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index bdb03717c..ab3a5bbc1 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -25,6 +25,7 @@ from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, + extract_metadata, labeled_to_homogeneous, set_missing_features, shard_nodes_by_process, @@ -563,8 +564,8 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: # to_hetero_data misinterprets #META. keys as edge types and # fails when edge_dir="out" (tries to reverse_edge_type on them). # We strip them here and re-apply after conversion. 
- non_edge_metadata = self._extract_metadata(msg) - data = super()._collate_fn(msg) + non_edge_metadata, stripped_msg = extract_metadata(msg, self.to_device) + data = super()._collate_fn(stripped_msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, diff --git a/gigl/distributed/utils/neighborloader.py b/gigl/distributed/utils/neighborloader.py index fdac550bc..2faf2bf67 100644 --- a/gigl/distributed/utils/neighborloader.py +++ b/gigl/distributed/utils/neighborloader.py @@ -6,6 +6,7 @@ from typing import Literal, Optional, TypeVar, Union import torch +from graphlearn_torch.channel import SampleMessage from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType, NodeType @@ -264,3 +265,36 @@ def set_missing_features( ) return data + + +def extract_metadata( + msg: SampleMessage, device: torch.device +) -> tuple[dict[str, torch.Tensor], SampleMessage]: + """Separate user-defined metadata from a SampleMessage. + + GLT's ``to_hetero_data`` misinterprets ``#META.``-prefixed keys as + edge types, causing failures with ``edge_dir="out"`` (it tries to call + ``reverse_edge_type`` on metadata key strings). This function separates + metadata from the sampling data so the stripped message can be passed to + GLT's ``_collate_fn`` without triggering the bug. + + The original ``msg`` is not modified. + + Args: + msg: The SampleMessage to extract metadata from. + device: The device to move metadata tensors to. + + Returns: + A 2-tuple of: + - metadata: Dict mapping metadata key (without ``#META.`` prefix) to tensor. + - stripped_msg: A new SampleMessage with ``#META.``-prefixed keys removed. + """ + meta_prefix = "#META." 
+ metadata: dict[str, torch.Tensor] = {} + stripped_msg: SampleMessage = {} + for k, v in msg.items(): + if k.startswith(meta_prefix): + metadata[k[len(meta_prefix) :]] = v.to(device) + else: + stripped_msg[k] = v + return metadata, stripped_msg diff --git a/tests/unit/distributed/utils/neighborloader_test.py b/tests/unit/distributed/utils/neighborloader_test.py index 603b2dadb..51a156f0c 100644 --- a/tests/unit/distributed/utils/neighborloader_test.py +++ b/tests/unit/distributed/utils/neighborloader_test.py @@ -7,6 +7,7 @@ from torch_geometric.typing import EdgeType from gigl.distributed.utils.neighborloader import ( + extract_metadata, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -490,5 +491,64 @@ def test_set_custom_features_heterogeneous(self): ) +class ExtractMetadataTest(TestCase): + def setUp(self): + self._device = torch.device("cpu") + super().setUp() + + def test_separates_metadata_from_sampling_data(self): + msg = { + "#META.ppr_scores": torch.tensor([1.0, 2.0]), + "#META.custom_key": torch.tensor([3]), + "user.ids": torch.tensor([10, 20]), + "user__to__item.rows": torch.tensor([0, 1]), + } + metadata, stripped_msg = extract_metadata(msg, self._device) + + self.assertEqual(set(metadata.keys()), {"ppr_scores", "custom_key"}) + self.assert_tensor_equality(metadata["ppr_scores"], torch.tensor([1.0, 2.0])) + self.assert_tensor_equality(metadata["custom_key"], torch.tensor([3])) + + self.assertEqual(set(stripped_msg.keys()), {"user.ids", "user__to__item.rows"}) + self.assert_tensor_equality(stripped_msg["user.ids"], torch.tensor([10, 20])) + + def test_no_metadata_keys(self): + msg = { + "user.ids": torch.tensor([10, 20]), + "#IS_HETERO": torch.tensor([1]), + } + metadata, stripped_msg = extract_metadata(msg, self._device) + + self.assertEqual(metadata, {}) + self.assertEqual(set(stripped_msg.keys()), {"user.ids", "#IS_HETERO"}) + + def test_only_metadata_keys(self): + msg = { + "#META.scores": torch.tensor([1.0]), + } + 
metadata, stripped_msg = extract_metadata(msg, self._device) + + self.assertEqual(set(metadata.keys()), {"scores"}) + self.assertEqual(stripped_msg, {}) + + def test_does_not_modify_original_message(self): + original_tensor = torch.tensor([1.0, 2.0]) + msg = { + "#META.scores": original_tensor, + "user.ids": torch.tensor([10]), + } + original_keys = set(msg.keys()) + + extract_metadata(msg, self._device) + + self.assertEqual(set(msg.keys()), original_keys) + self.assertIn("#META.scores", msg) + + def test_empty_message(self): + metadata, stripped_msg = extract_metadata({}, self._device) + self.assertEqual(metadata, {}) + self.assertEqual(stripped_msg, {}) + + if __name__ == "__main__": absltest.main() From 23b686b53c68eacc5e724ceffa096aee3cd75398 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 13 Mar 2026 23:20:15 +0000 Subject: [PATCH 38/46] Improve variable names in _compute_ppr_scores --- gigl/distributed/dist_ppr_sampler.py | 73 +++++++++++++--------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 4ac5ae6b8..423b7f648 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -286,18 +286,18 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: while num_nodes_in_queue > 0: # Drain all nodes from all queues and group by edge type for batched lookups - nodes_to_process: list[set[tuple[int, NodeType]]] = [ + queued_nodes: list[set[tuple[int, NodeType]]] = [ set() for _ in range(batch_size) ] nodes_by_edge_type: dict[EdgeType, set[int]] = defaultdict(set) for i in range(batch_size): if queue[i]: - nodes_to_process[i] = queue[i] + queued_nodes[i] = queue[i] queue[i] = set() - num_nodes_in_queue -= len(nodes_to_process[i]) + num_nodes_in_queue -= len(queued_nodes[i]) - for node_id, node_type in nodes_to_process[i]: + for node_id, node_type in queued_nodes[i]: edge_types_for_node = 
self._node_type_to_edge_types[node_type] for etype in edge_types_for_node: cache_key = (node_id, etype) @@ -312,41 +312,43 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: # is safe because each seed's state is independent, and residuals # are always positive so the merged loop can never miss a re-queue. for i in range(batch_size): - for u_node, u_type in nodes_to_process[i]: - key_u = (u_node, u_type) - res_u = residuals[i].get(key_u, 0.0) + for source_node, source_type in queued_nodes[i]: + source_key = (source_node, source_type) + source_residual = residuals[i].get(source_key, 0.0) - ppr_scores[i][key_u] += res_u - residuals[i][key_u] = 0.0 + ppr_scores[i][source_key] += source_residual + residuals[i][source_key] = 0.0 - edge_types_for_node = self._node_type_to_edge_types[u_type] + edge_types_for_node = self._node_type_to_edge_types[source_type] - total_degree = _get_total_degree(u_node, u_type) + total_degree = _get_total_degree(source_node, source_type) if total_degree == 0: continue - push_value = one_minus_alpha * res_u / total_degree + residual_per_neighbor = ( + one_minus_alpha * source_residual / total_degree + ) for etype in edge_types_for_node: - cache_key = (u_node, etype) + cache_key = (source_node, etype) neighbor_list = neighbor_cache[cache_key] if not neighbor_list: continue - v_type = self._get_neighbor_type(etype) + neighbor_type = self._get_neighbor_type(etype) - for v_node in neighbor_list: - key_v = (v_node, v_type) - residuals[i][key_v] += push_value + for neighbor_node in neighbor_list: + neighbor_key = (neighbor_node, neighbor_type) + residuals[i][neighbor_key] += residual_per_neighbor - if key_v not in queue[i]: + if neighbor_key not in queue[i]: if residuals[i][ - key_v + neighbor_key ] >= self._requeue_threshold_factor * _get_total_degree( - v_node, v_type + neighbor_node, neighbor_type ): - queue[i].add(key_v) + queue[i].add(neighbor_key) num_nodes_in_queue += 1 # Extract top-k nodes by PPR score, grouped by node 
type. @@ -357,9 +359,9 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: for _node_id, node_type in ppr_scores[i].keys(): all_node_types.add(node_type) - out_flat_ids_dict: dict[NodeType, torch.Tensor] = {} - out_flat_weights_dict: dict[NodeType, torch.Tensor] = {} - out_valid_counts_dict: dict[NodeType, torch.Tensor] = {} + flat_ids_by_ntype: dict[NodeType, torch.Tensor] = {} + flat_weights_by_ntype: dict[NodeType, torch.Tensor] = {} + valid_counts_by_ntype: dict[NodeType, torch.Tensor] = {} for ntype in all_node_types: flat_ids: list[int] = [] @@ -380,33 +382,28 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: flat_weights.append(weight) valid_counts.append(len(top_k)) - out_flat_ids_dict[ntype] = torch.tensor( + flat_ids_by_ntype[ntype] = torch.tensor( flat_ids, dtype=torch.long, device=device ) - out_flat_weights_dict[ntype] = torch.tensor( + flat_weights_by_ntype[ntype] = torch.tensor( flat_weights, dtype=torch.float, device=device ) - out_valid_counts_dict[ntype] = torch.tensor( + valid_counts_by_ntype[ntype] = torch.tensor( valid_counts, dtype=torch.long, device=device ) - out_flat_ids: Union[torch.Tensor, dict[NodeType, torch.Tensor]] - out_flat_weights: Union[torch.Tensor, dict[NodeType, torch.Tensor]] - out_valid_counts: Union[torch.Tensor, dict[NodeType, torch.Tensor]] if self._is_homogeneous: assert ( len(all_node_types) == 1 and _PPR_HOMOGENEOUS_NODE_TYPE in all_node_types ) - out_flat_ids = out_flat_ids_dict[_PPR_HOMOGENEOUS_NODE_TYPE] - out_flat_weights = out_flat_weights_dict[_PPR_HOMOGENEOUS_NODE_TYPE] - out_valid_counts = out_valid_counts_dict[_PPR_HOMOGENEOUS_NODE_TYPE] + return ( + flat_ids_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], + flat_weights_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], + valid_counts_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], + ) else: - out_flat_ids = out_flat_ids_dict - out_flat_weights = out_flat_weights_dict - out_valid_counts = out_valid_counts_dict - - return out_flat_ids, out_flat_weights, 
out_valid_counts + return flat_ids_by_ntype, flat_weights_by_ntype, valid_counts_by_ntype async def _sample_from_nodes( self, From 29abf6111ea2e4b576c48e336c92cdc8d29fa848 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 13 Mar 2026 23:27:53 +0000 Subject: [PATCH 39/46] Clean up PPR sampler: remove unused return, fix identity check, simplify requeue logic --- gigl/distributed/dist_ppr_sampler.py | 45 ++++++++++++---------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 423b7f648..96c4e01c2 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -95,18 +95,15 @@ def __init__( self._is_homogeneous = True def _get_degree_from_tensor(self, node_id: int, edge_type: EdgeType) -> int: - """ - Look up the TRUE degree of a node for a specific edge type from in-memory tensors. - - This returns the actual node degree (not capped), which is mathematically correct - for PPR algorithm calculations. + """Look up the degree of a node for a specific edge type from in-memory tensors. Args: node_id: The ID of the node to look up. edge_type: The edge type to get the degree for. Returns: - The true degree of the node for the given edge type. + The degree of the node for the given edge type, or 0 if the node + or edge type is not found. 
""" if self._is_homogeneous: # For homogeneous graphs, degree_tensors is a single tensor @@ -124,8 +121,8 @@ def _get_degree_from_tensor(self, node_id: int, edge_type: EdgeType) -> int: return 0 return int(degree_tensor[node_id].item()) - def _get_neighbor_type(self, edge_type: EdgeType) -> NodeType: - """Get the node type of neighbors reached via an edge type.""" + def _get_destination_type(self, edge_type: EdgeType) -> NodeType: + """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] async def _get_neighbors_for_nodes( @@ -145,7 +142,7 @@ async def _get_neighbors_for_nodes( output: NeighborOutput = await self._sample_one_hop( srcs=nodes, num_nbr=self._num_nbrs_per_hop, - etype=edge_type if edge_type is not _PPR_HOMOGENEOUS_EDGE_TYPE else None, + etype=edge_type if edge_type != _PPR_HOMOGENEOUS_EDGE_TYPE else None, ) return output.nbr, output.nbr_num @@ -154,7 +151,7 @@ async def _batch_fetch_neighbors( nodes_by_edge_type: dict[EdgeType, set[int]], neighbor_target: dict[tuple[int, EdgeType], list[int]], device: torch.device, - ) -> int: + ) -> None: """ Batch fetch neighbors for nodes grouped by edge type. 
@@ -166,11 +163,7 @@ async def _batch_fetch_neighbors( nodes_by_edge_type: Dict mapping edge type to set of node IDs to fetch neighbor_target: Dict to populate with (node_id, edge_type) -> neighbor list device: Torch device for tensor creation - - Returns: - Number of neighbor lookup calls made """ - num_lookups = 0 for etype, node_ids in nodes_by_edge_type.items(): if not node_ids: continue @@ -181,7 +174,6 @@ async def _batch_fetch_neighbors( lookup_tensor, etype, ) - num_lookups += 1 neighbors_list = neighbors.tolist() counts_list = neighbor_counts.tolist() @@ -196,8 +188,6 @@ async def _batch_fetch_neighbors( neighbor_target[cache_key] = neighbors_list[offset : offset + count] offset += count - return num_lookups - async def _compute_ppr_scores( self, seed_nodes: torch.Tensor, @@ -336,20 +326,23 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: if not neighbor_list: continue - neighbor_type = self._get_neighbor_type(etype) + neighbor_type = self._get_destination_type(etype) for neighbor_node in neighbor_list: neighbor_key = (neighbor_node, neighbor_type) residuals[i][neighbor_key] += residual_per_neighbor - if neighbor_key not in queue[i]: - if residuals[i][ - neighbor_key - ] >= self._requeue_threshold_factor * _get_total_degree( - neighbor_node, neighbor_type - ): - queue[i].add(neighbor_key) - num_nodes_in_queue += 1 + requeue_threshold = ( + self._requeue_threshold_factor + * _get_total_degree(neighbor_node, neighbor_type) + ) + should_requeue = ( + neighbor_key not in queue[i] + and residuals[i][neighbor_key] >= requeue_threshold + ) + if should_requeue: + queue[i].add(neighbor_key) + num_nodes_in_queue += 1 # Extract top-k nodes by PPR score, grouped by node type. 
# Build flat tensors directly (no padding) — valid_counts[i] records how many From e97fe3b520ceac552634ac4f0e47a43bf06a34c8 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 13 Mar 2026 23:50:07 +0000 Subject: [PATCH 40/46] Restructure PPR state by node type, inline _get_neighbors_for_nodes, fail-fast on invalid degree lookups --- gigl/distributed/dist_ppr_sampler.py | 192 ++++++++++++++------------- 1 file changed, 99 insertions(+), 93 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 96c4e01c2..b9ee47d56 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -102,50 +102,40 @@ def _get_degree_from_tensor(self, node_id: int, edge_type: EdgeType) -> int: edge_type: The edge type to get the degree for. Returns: - The degree of the node for the given edge type, or 0 if the node - or edge type is not found. + The degree of the node for the given edge type. + + Raises: + ValueError: If the edge type is missing from the degree tensors or + the node ID is out of range. Both indicate corrupted graph data + or a sampler bug. """ if self._is_homogeneous: - # For homogeneous graphs, degree_tensors is a single tensor assert isinstance(self._degree_tensors, torch.Tensor) if node_id >= len(self._degree_tensors): - return 0 + raise ValueError( + f"Node ID {node_id} exceeds degree tensor length " + f"({len(self._degree_tensors)})." + ) return int(self._degree_tensors[node_id].item()) else: - # For heterogeneous graphs, degree_tensors is a dict keyed by edge type assert isinstance(self._degree_tensors, dict) if edge_type not in self._degree_tensors: - return 0 + raise ValueError( + f"Edge type {edge_type} not found in degree tensors. 
" + f"Available: {list(self._degree_tensors.keys())}" + ) degree_tensor = self._degree_tensors[edge_type] if node_id >= len(degree_tensor): - return 0 + raise ValueError( + f"Node ID {node_id} exceeds degree tensor length " + f"({len(degree_tensor)}) for edge type {edge_type}." + ) return int(degree_tensor[node_id].item()) def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" return edge_type[0] if self.edge_dir == "in" else edge_type[-1] - async def _get_neighbors_for_nodes( - self, - nodes: torch.Tensor, - edge_type: EdgeType, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Fetch neighbors for a batch of nodes. - - Returns: - tuple of (neighbors, neighbor_counts) where neighbors is a flattened tensor - and neighbor_counts[i] gives the number of neighbors for nodes[i]. - """ - # Use the underlying sampling infrastructure to get all neighbors - # We request a large number to effectively get all neighbors - output: NeighborOutput = await self._sample_one_hop( - srcs=nodes, - num_nbr=self._num_nbrs_per_hop, - etype=edge_type if edge_type != _PPR_HOMOGENEOUS_EDGE_TYPE else None, - ) - return output.nbr, output.nbr_num - async def _batch_fetch_neighbors( self, nodes_by_edge_type: dict[EdgeType, set[int]], @@ -170,10 +160,14 @@ async def _batch_fetch_neighbors( nodes_list = list(node_ids) lookup_tensor = torch.tensor(nodes_list, dtype=torch.long, device=device) - neighbors, neighbor_counts = await self._get_neighbors_for_nodes( - lookup_tensor, - etype, + # _sample_one_hop expects None for homogeneous graphs, not the PPR sentinel. 
+ output: NeighborOutput = await self._sample_one_hop( + srcs=lookup_tensor, + num_nbr=self._num_nbrs_per_hop, + etype=etype if etype != _PPR_HOMOGENEOUS_EDGE_TYPE else None, ) + neighbors = output.nbr + neighbor_counts = output.nbr_num neighbors_list = neighbors.tolist() counts_list = neighbor_counts.tolist() @@ -234,19 +228,34 @@ async def _compute_ppr_scores( device = seed_nodes.device batch_size = seed_nodes.size(0) - ppr_scores: list[dict[tuple[int, NodeType], float]] = [ - defaultdict(float) for _ in range(batch_size) + # Per-seed PPR state, nested by node type for efficient type-grouped access. + # + # ppr_scores[i][node_type][node_id] = accumulated PPR score for node_id + # of type node_type, relative to seed i. Updated each iteration by + # absorbing the node's residual. + # + # residuals[i][node_type][node_id] = unconverged probability mass at node_id + # of type node_type for seed i. Each iteration, a node's residual is + # absorbed into its PPR score and then distributed to its neighbors. + # + # queue[i][node_type] = set of node IDs whose residual exceeds the + # convergence threshold (alpha * eps * total_degree). The algorithm + # terminates when all queues are empty. 
+ ppr_scores: list[dict[NodeType, dict[int, float]]] = [ + defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) + ] + residuals: list[dict[NodeType, dict[int, float]]] = [ + defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) ] - residuals: list[dict[tuple[int, NodeType], float]] = [ - defaultdict(float) for _ in range(batch_size) + queue: list[dict[NodeType, set[int]]] = [ + defaultdict(set) for _ in range(batch_size) ] - queue: list[set[tuple[int, NodeType]]] = [set() for _ in range(batch_size)] seed_list = seed_nodes.tolist() for i, seed in enumerate(seed_list): - residuals[i][(seed, seed_node_type)] = self._alpha - queue[i].add((seed, seed_node_type)) + residuals[i][seed_node_type][seed] = self._alpha + queue[i][seed_node_type].add(seed) # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} @@ -276,23 +285,23 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: while num_nodes_in_queue > 0: # Drain all nodes from all queues and group by edge type for batched lookups - queued_nodes: list[set[tuple[int, NodeType]]] = [ - set() for _ in range(batch_size) + queued_nodes: list[dict[NodeType, set[int]]] = [ + defaultdict(set) for _ in range(batch_size) ] nodes_by_edge_type: dict[EdgeType, set[int]] = defaultdict(set) for i in range(batch_size): if queue[i]: queued_nodes[i] = queue[i] - queue[i] = set() - num_nodes_in_queue -= len(queued_nodes[i]) - - for node_id, node_type in queued_nodes[i]: + queue[i] = defaultdict(set) + for node_type, node_ids in queued_nodes[i].items(): + num_nodes_in_queue -= len(node_ids) edge_types_for_node = self._node_type_to_edge_types[node_type] - for etype in edge_types_for_node: - cache_key = (node_id, etype) - if cache_key not in neighbor_cache: - nodes_by_edge_type[etype].add(node_id) + for node_id in node_ids: + for etype in edge_types_for_node: + cache_key = (node_id, etype) + if cache_key 
not in neighbor_cache: + nodes_by_edge_type[etype].add(node_id) await self._batch_fetch_neighbors( nodes_by_edge_type, neighbor_cache, device @@ -302,55 +311,56 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: # is safe because each seed's state is independent, and residuals # are always positive so the merged loop can never miss a re-queue. for i in range(batch_size): - for source_node, source_type in queued_nodes[i]: - source_key = (source_node, source_type) - source_residual = residuals[i].get(source_key, 0.0) - - ppr_scores[i][source_key] += source_residual - residuals[i][source_key] = 0.0 + for source_type, source_nodes in queued_nodes[i].items(): + for source_node in source_nodes: + source_residual = residuals[i][source_type].get( + source_node, 0.0 + ) - edge_types_for_node = self._node_type_to_edge_types[source_type] + ppr_scores[i][source_type][source_node] += source_residual + residuals[i][source_type][source_node] = 0.0 - total_degree = _get_total_degree(source_node, source_type) + edge_types_for_node = self._node_type_to_edge_types[source_type] - if total_degree == 0: - continue + total_degree = _get_total_degree(source_node, source_type) - residual_per_neighbor = ( - one_minus_alpha * source_residual / total_degree - ) - - for etype in edge_types_for_node: - cache_key = (source_node, etype) - neighbor_list = neighbor_cache[cache_key] - if not neighbor_list: + if total_degree == 0: continue - neighbor_type = self._get_destination_type(etype) + residual_per_neighbor = ( + one_minus_alpha * source_residual / total_degree + ) - for neighbor_node in neighbor_list: - neighbor_key = (neighbor_node, neighbor_type) - residuals[i][neighbor_key] += residual_per_neighbor - - requeue_threshold = ( - self._requeue_threshold_factor - * _get_total_degree(neighbor_node, neighbor_type) - ) - should_requeue = ( - neighbor_key not in queue[i] - and residuals[i][neighbor_key] >= requeue_threshold - ) - if should_requeue: - queue[i].add(neighbor_key) - 
num_nodes_in_queue += 1 + for etype in edge_types_for_node: + cache_key = (source_node, etype) + neighbor_list = neighbor_cache[cache_key] + if not neighbor_list: + continue + + neighbor_type = self._get_destination_type(etype) + + for neighbor_node in neighbor_list: + residuals[i][neighbor_type][ + neighbor_node + ] += residual_per_neighbor + + requeue_threshold = ( + self._requeue_threshold_factor + * _get_total_degree(neighbor_node, neighbor_type) + ) + should_requeue = ( + neighbor_node not in queue[i][neighbor_type] + and residuals[i][neighbor_type][neighbor_node] + >= requeue_threshold + ) + if should_requeue: + queue[i][neighbor_type].add(neighbor_node) + num_nodes_in_queue += 1 # Extract top-k nodes by PPR score, grouped by node type. # Build flat tensors directly (no padding) — valid_counts[i] records how many # neighbors seed i actually has, so callers can recover per-seed slices. - all_node_types: set[NodeType] = set() - for i in range(batch_size): - for _node_id, node_type in ppr_scores[i].keys(): - all_node_types.add(node_type) + all_node_types = self._node_type_to_edge_types.keys() flat_ids_by_ntype: dict[NodeType, torch.Tensor] = {} flat_weights_by_ntype: dict[NodeType, torch.Tensor] = {} @@ -362,11 +372,7 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: valid_counts: list[int] = [] for i in range(batch_size): - type_scores = { - node_id: score - for (node_id, node_type), score in ppr_scores[i].items() - if node_type == ntype - } + type_scores = ppr_scores[i].get(ntype, {}) top_k = heapq.nlargest( self._max_ppr_nodes, type_scores.items(), key=lambda x: x[1] ) @@ -387,8 +393,8 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: if self._is_homogeneous: assert ( - len(all_node_types) == 1 - and _PPR_HOMOGENEOUS_NODE_TYPE in all_node_types + len(flat_ids_by_ntype) == 1 + and _PPR_HOMOGENEOUS_NODE_TYPE in flat_ids_by_ntype ) return ( flat_ids_by_ntype[_PPR_HOMOGENEOUS_NODE_TYPE], From 
f88cd4cee1b1a2017feaa9845fcf8b1296b99185 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Fri, 13 Mar 2026 23:53:09 +0000 Subject: [PATCH 41/46] Document why queue uses a set --- gigl/distributed/dist_ppr_sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index b9ee47d56..5fff35511 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -240,7 +240,10 @@ async def _compute_ppr_scores( # # queue[i][node_type] = set of node IDs whose residual exceeds the # convergence threshold (alpha * eps * total_degree). The algorithm - # terminates when all queues are empty. + # terminates when all queues are empty. A set is used because multiple + # neighbors can push residual to the same node in one iteration — + # deduplication avoids redundant processing, and the O(1) membership + # check matters since it runs in the innermost loop. ppr_scores: list[dict[NodeType, dict[int, float]]] = [ defaultdict(lambda: defaultdict(float)) for _ in range(batch_size) ] From 03fc3546d7f3f80a3dc218fdfc15126c28484b72 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Sat, 14 Mar 2026 00:04:44 +0000 Subject: [PATCH 42/46] Precompute total degree tensors at init, remove per-call caching --- gigl/distributed/dist_ppr_sampler.py | 125 ++++++++++++++++----------- 1 file changed, 73 insertions(+), 52 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 5fff35511..1a3a93e4e 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -48,6 +48,9 @@ class DistPPRNeighborSampler(DistNeighborSampler): but require more computation. Typical values: 1e-4 to 1e-6. max_ppr_nodes: Maximum number of nodes to return per seed based on PPR scores. num_nbrs_per_hop: Maximum number of neighbors to fetch per hop. + total_degree_dtype: Dtype for precomputed total-degree tensors. 
Defaults to + ``torch.int32``, which supports total degrees up to ~2 billion. Use a + larger dtype if nodes have exceptionally high aggregate degrees. """ def __init__( @@ -57,6 +60,7 @@ def __init__( eps: float = 1e-4, max_ppr_nodes: int = 50, num_nbrs_per_hop: int = 100000, + total_degree_dtype: torch.dtype = torch.int32, **kwargs, ): super().__init__(*args, **kwargs) @@ -69,7 +73,7 @@ def __init__( assert isinstance( self.data, DistDataset ), "DistPPRNeighborSampler requires a GiGL DistDataset to access degree tensors." - self._degree_tensors = self.data.degree_tensor + degree_tensors = self.data.degree_tensor # Build mapping from node type to edge types that can be traversed from that node type. self._node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict( @@ -94,43 +98,78 @@ def __init__( ] self._is_homogeneous = True - def _get_degree_from_tensor(self, node_id: int, edge_type: EdgeType) -> int: - """Look up the degree of a node for a specific edge type from in-memory tensors. + # Precompute total degree per node type: the sum of degrees across all + # edge types traversable from that node type. This is a graph-level + # property used on every PPR iteration, so computing it once at init + # avoids per-node summation and cache lookups in the hot loop. + self._total_degree_by_node_type: dict[ + NodeType, torch.Tensor + ] = self._build_total_degree_tensors(degree_tensors, total_degree_dtype) + + def _build_total_degree_tensors( + self, + degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], + dtype: torch.dtype, + ) -> dict[NodeType, torch.Tensor]: + """Build total-degree tensors by summing per-edge-type degrees for each node type. + + For homogeneous graphs, the total degree is just the single degree tensor. + For heterogeneous graphs, it sums degree tensors across all edge types + traversable from each node type, padding shorter tensors with zeros. + + Args: + degree_tensors: Per-edge-type degree tensors from the dataset. 
+ dtype: Dtype for the output tensors. + + Returns: + Dict mapping node type to a 1-D tensor of total degrees. + """ + result: dict[NodeType, torch.Tensor] = {} + + if self._is_homogeneous: + assert isinstance(degree_tensors, torch.Tensor) + result[_PPR_HOMOGENEOUS_NODE_TYPE] = degree_tensors.to(dtype) + else: + assert isinstance(degree_tensors, dict) + for node_type, edge_types in self._node_type_to_edge_types.items(): + max_len = 0 + for et in edge_types: + if et not in degree_tensors: + raise ValueError( + f"Edge type {et} not found in degree tensors. " + f"Available: {list(degree_tensors.keys())}" + ) + max_len = max(max_len, len(degree_tensors[et])) + + summed = torch.zeros(max_len, dtype=dtype) + for et in edge_types: + et_degrees = degree_tensors[et] + summed[: len(et_degrees)] += et_degrees.to(dtype) + result[node_type] = summed + + return result + + def _get_total_degree(self, node_id: int, node_type: NodeType) -> int: + """Look up the precomputed total degree of a node. Args: node_id: The ID of the node to look up. - edge_type: The edge type to get the degree for. + node_type: The node type. Returns: - The degree of the node for the given edge type. + The total degree (sum across all edge types) for the node. Raises: - ValueError: If the edge type is missing from the degree tensors or - the node ID is out of range. Both indicate corrupted graph data - or a sampler bug. + ValueError: If the node ID is out of range, indicating corrupted + graph data or a sampler bug. """ - if self._is_homogeneous: - assert isinstance(self._degree_tensors, torch.Tensor) - if node_id >= len(self._degree_tensors): - raise ValueError( - f"Node ID {node_id} exceeds degree tensor length " - f"({len(self._degree_tensors)})." - ) - return int(self._degree_tensors[node_id].item()) - else: - assert isinstance(self._degree_tensors, dict) - if edge_type not in self._degree_tensors: - raise ValueError( - f"Edge type {edge_type} not found in degree tensors. 
" - f"Available: {list(self._degree_tensors.keys())}" - ) - degree_tensor = self._degree_tensors[edge_type] - if node_id >= len(degree_tensor): - raise ValueError( - f"Node ID {node_id} exceeds degree tensor length " - f"({len(degree_tensor)}) for edge type {edge_type}." - ) - return int(degree_tensor[node_id].item()) + degree_tensor = self._total_degree_by_node_type[node_type] + if node_id >= len(degree_tensor): + raise ValueError( + f"Node ID {node_id} exceeds total degree tensor length " + f"({len(degree_tensor)}) for node type {node_type}." + ) + return int(degree_tensor[node_id].item()) def _get_destination_type(self, edge_type: EdgeType) -> NodeType: """Get the node type at the destination end of an edge type.""" @@ -263,26 +302,6 @@ async def _compute_ppr_scores( # Cache keyed by (node_id, edge_type) since same node can have different neighbors per edge type neighbor_cache: dict[tuple[int, EdgeType], list[int]] = {} - # Cache for total degree (sum across all edge types for a node type). - # The per-edge-type degree is already O(1) via degree_tensors, but the - # *sum* across edge types is recomputed each time a node appears as a - # neighbor — which can be many times across seeds and iterations. - # Caching the sum avoids redundant _get_degree_from_tensor calls and - # the per-call Python overhead (method dispatch, isinstance, .item()). 
- total_degree_cache: dict[tuple[int, NodeType], int] = {} - - def _get_total_degree(node_id: int, node_type: NodeType) -> int: - key = (node_id, node_type) - cached = total_degree_cache.get(key) - if cached is not None: - return cached - total = sum( - self._get_degree_from_tensor(node_id, et) - for et in self._node_type_to_edge_types.get(node_type, []) - ) - total_degree_cache[key] = total - return total - num_nodes_in_queue = batch_size one_minus_alpha = 1 - self._alpha @@ -325,7 +344,7 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: edge_types_for_node = self._node_type_to_edge_types[source_type] - total_degree = _get_total_degree(source_node, source_type) + total_degree = self._get_total_degree(source_node, source_type) if total_degree == 0: continue @@ -349,7 +368,9 @@ def _get_total_degree(node_id: int, node_type: NodeType) -> int: requeue_threshold = ( self._requeue_threshold_factor - * _get_total_degree(neighbor_node, neighbor_type) + * self._get_total_degree( + neighbor_node, neighbor_type + ) ) should_requeue = ( neighbor_node not in queue[i][neighbor_type] From 141c4b1a247d033a961d15820f6a676e4e01102d Mon Sep 17 00:00:00 2001 From: mkolodner Date: Sat, 14 Mar 2026 00:10:41 +0000 Subject: [PATCH 43/46] Add TODO comment on total degree memory tradeoff --- gigl/distributed/dist_ppr_sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 1a3a93e4e..1feca46bf 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -102,6 +102,11 @@ def __init__( # edge types traversable from that node type. This is a graph-level # property used on every PPR iteration, so computing it once at init # avoids per-node summation and cache lookups in the hot loop. 
+ # TODO (mkolodner-sc): This trades memory for throughput — we + # materialize a tensor per node type to avoid recomputing total degree + # on every neighbor during sampling. Computing it here (rather than in + # the dataset) also keeps the door open for edge-specific degree + # strategies. If memory becomes a bottleneck, revisit this. self._total_degree_by_node_type: dict[ NodeType, torch.Tensor ] = self._build_total_degree_tensors(degree_tensors, total_degree_dtype) From e481eb51646f788c3a2992ec978dbc4f85be04b7 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Sat, 14 Mar 2026 00:17:38 +0000 Subject: [PATCH 44/46] Expose total_degree_dtype in PPRSamplerOptions, document valid_counts and inducer usage --- gigl/distributed/dist_ppr_sampler.py | 40 +++++++++++++++++++--- gigl/distributed/dist_sampling_producer.py | 1 + gigl/distributed/sampler_options.py | 5 +++ 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 1feca46bf..24cc91f02 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -387,8 +387,27 @@ async def _compute_ppr_scores( num_nodes_in_queue += 1 # Extract top-k nodes by PPR score, grouped by node type. - # Build flat tensors directly (no padding) — valid_counts[i] records how many - # neighbors seed i actually has, so callers can recover per-seed slices. + # Results are three flat tensors per node type (no padding): + # - flat_ids: [id_seed0_0, id_seed0_1, ..., id_seed1_0, ...] + # - flat_weights: [wt_seed0_0, wt_seed0_1, ..., wt_seed1_0, ...] + # - valid_counts: [count_seed0, count_seed1, ...] + # + # valid_counts[i] records how many top-k neighbors seed i contributed. 
+ # Callers use it to slice flat_ids/flat_weights back into per-seed + # groups and to build PyG edge-index tensors via repeat_interleave: + # + # Example: 3 seeds, valid_counts = [2, 3, 1] + # flat_dst = [dst_0a, dst_0b, dst_1a, dst_1b, dst_1c, dst_2a] + # + # src_indices = repeat_interleave(arange(3), valid_counts) + # = [0, 0, 1, 1, 1, 2] + # + # edge_index = stack([src_indices, flat_dst]) + # = [[0, 0, 1, 1, 1, 2], + # [dst_0a, dst_0b, dst_1a, dst_1b, dst_1c, dst_2a]] + # + # Column j means "edge from seed src_indices[j] to neighbor flat_dst[j]" + # with PPR weight flat_weights[j]. all_node_types = self._node_type_to_edge_types.keys() flat_ids_by_ntype: dict[NodeType, torch.Tensor] = {} @@ -498,8 +517,21 @@ async def _sample_from_nodes( metadata = sample_loop_inputs.metadata nodes_to_sample = sample_loop_inputs.nodes_to_sample - # Acquired once per sample; returned to the pool at the end. The inducer - # maintains the shared global→local index map for this entire subgraph. + # The inducer is GLT's C++ data structure that maintains a global-ID → + # local-index mapping for the subgraph being built. It serves two roles: + # + # 1. Deduplication: when the same global node ID appears from multiple + # seeds or seed types, induce_next assigns it a single local index. + # This ensures node_dict[ntype] has no duplicates. + # + # 2. Local index assignment: init_node registers seeds at local indices + # 0..N-1. induce_next then assigns the next available indices to + # newly discovered neighbors. The returned "cols" tensor contains + # the local destination index for every neighbor (including those + # that were already registered), which we use directly as row 1 of + # the PyG edge-index tensor. + # + # Acquired once per sample call; returned to the pool at the end. 
inducer = self._acquire_inducer() if is_hetero: diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 4960b756f..c07e8caaf 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -109,6 +109,7 @@ def _sampling_worker_loop( "eps": sampler_options.eps, "max_ppr_nodes": sampler_options.max_ppr_nodes, "num_nbrs_per_hop": sampler_options.num_nbrs_per_hop, + "total_degree_dtype": sampler_options.total_degree_dtype, } else: raise NotImplementedError( diff --git a/gigl/distributed/sampler_options.py b/gigl/distributed/sampler_options.py index 52d21aa26..678756795 100644 --- a/gigl/distributed/sampler_options.py +++ b/gigl/distributed/sampler_options.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import Optional, Union +import torch from graphlearn_torch.typing import EdgeType from gigl.common.logger import Logger @@ -47,12 +48,16 @@ class PPRSamplerOptions: num_nbrs_per_hop: Maximum number of neighbors fetched per node per edge type during PPR traversal. Set large to approximate fetching all neighbors. + total_degree_dtype: Dtype for precomputed total-degree tensors. Defaults + to ``torch.int32``, which supports total degrees up to ~2 billion. + Use a larger dtype if nodes have exceptionally high aggregate degrees. 
""" alpha: float = 0.5 eps: float = 1e-4 max_ppr_nodes: int = 50 num_nbrs_per_hop: int = 100000 + total_degree_dtype: torch.dtype = torch.int32 SamplerOptions = Union[KHopNeighborSamplerOptions, PPRSamplerOptions] From 04b5e2bf05b2a6f887edb1c62a41f97560297eb4 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Sat, 14 Mar 2026 00:24:30 +0000 Subject: [PATCH 45/46] Small comment adjustment --- gigl/distributed/dist_ppr_sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 24cc91f02..453ca1f72 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -479,8 +479,6 @@ async def _sample_from_nodes( ``ppr_neighbor_ids`` directly indexes into ``data[ntype].x`` without any additional global→local remapping. - **Why the inducer is used for local-index assignment:** - The inducer is GLT's C++ data structure (backed by a per-node-type hash map) that maintains a single global-ID → local-index mapping for the entire subgraph being built. 
We use it here instead of a Python dict for two reasons: From d6b77fd901c0d2eb31e6d00da923b93741f85ea6 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Sat, 14 Mar 2026 00:26:50 +0000 Subject: [PATCH 46/46] Reformat @param to use named keyword args matching codebase convention --- .../unit/distributed/dist_ppr_sampler_test.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index f3538ea78..dc2eb78a5 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -569,24 +569,38 @@ def tearDown(self) -> None: torch.distributed.destroy_process_group() super().tearDown() - @parameterized.expand([param("in"), param("out")]) - def test_ppr_sampler_correctness_homogeneous(self, edge_dir: str) -> None: + @parameterized.expand( + [ + param("edge_dir_in", edge_dir="in"), + param("edge_dir_out", edge_dir="out"), + ] + ) + def test_ppr_sampler_correctness_homogeneous(self, _, edge_dir: str) -> None: """Verify PPR scores match NetworkX pagerank on a small homogeneous graph.""" mp.spawn( fn=_run_ppr_loader_correctness_check, args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), ) - @parameterized.expand([param("in"), param("out")]) - def test_ppr_sampler_correctness_heterogeneous(self, edge_dir: str) -> None: + @parameterized.expand( + [ + param("edge_dir_in", edge_dir="in"), + param("edge_dir_out", edge_dir="out"), + ] + ) + def test_ppr_sampler_correctness_heterogeneous(self, _, edge_dir: str) -> None: """Verify PPR scores match NetworkX pagerank on a heterogeneous bipartite graph.""" mp.spawn( fn=_run_ppr_hetero_loader_correctness_check, args=(_TEST_ALPHA, _TEST_MAX_PPR_NODES, edge_dir), ) - @parameterized.expand([param("out")]) - def test_ppr_sampler_ablp_correctness(self, edge_dir: str) -> None: + @parameterized.expand( + [ + param("edge_dir_out", edge_dir="out"), + ] + ) + def 
test_ppr_sampler_ablp_correctness(self, _, edge_dir: str) -> None: """Verify PPR scores through DistABLPLoader on a heterogeneous graph. Only tests ``edge_dir="out"`` because ``DistNodeAnchorLinkSplitter``