From 71c4077e3db509246789b6d763ff123664709aab Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 11 Feb 2026 16:36:59 +0000 Subject: [PATCH] WIP Draft for adding base dataloader to gigl --- gigl/distributed/__init__.py | 10 + gigl/distributed/base_dist_loader.py | 283 ++++++++++++++++++ gigl/distributed/dist_ablp_neighborloader.py | 150 +++------- .../distributed/distributed_neighborloader.py | 91 ++++-- gigl/distributed/sampling_engine.py | 211 +++++++++++++ 5 files changed, 612 insertions(+), 133 deletions(-) create mode 100644 gigl/distributed/base_dist_loader.py create mode 100644 gigl/distributed/sampling_engine.py diff --git a/gigl/distributed/__init__.py b/gigl/distributed/__init__.py index 15f04e8cf..b4c28c9fa 100644 --- a/gigl/distributed/__init__.py +++ b/gigl/distributed/__init__.py @@ -3,15 +3,20 @@ """ __all__ = [ + "BaseDistLoader", + "ColocatedSamplingEngine", "DistNeighborLoader", "DistDataset", "DistributedContext", "DistPartitioner", "DistRangePartitioner", + "GraphStoreSamplingEngine", + "SamplingEngine", "build_dataset", "build_dataset_from_task_config_uri", ] +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dataset_factory import ( build_dataset, build_dataset_from_task_config_uri, @@ -22,3 +27,8 @@ from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.sampling_engine import ( + ColocatedSamplingEngine, + GraphStoreSamplingEngine, + SamplingEngine, +) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py new file mode 100644 index 000000000..77bd800d1 --- /dev/null +++ b/gigl/distributed/base_dist_loader.py @@ -0,0 +1,283 @@ +""" +GiGL's mode-agnostic distributed data loader. + +Replaces GLT's DistLoader. 
Delegates all sampling lifecycle to a SamplingEngine, +and provides a composable ``_base_collate`` method for converting SampleMessages +into PyG Data/HeteroData objects. +""" + +from typing import List, Optional, Union + +import torch +from graphlearn_torch.channel import SampleMessage +from graphlearn_torch.loader import to_data, to_hetero_data +from graphlearn_torch.sampler import ( + HeteroSamplerOutput, + SamplerOutput, + SamplingConfig, + SamplingType, +) +from graphlearn_torch.typing import EdgeType, NodeType, as_str, reverse_edge_type +from graphlearn_torch.utils import ensure_device, python_exit_status +from torch_geometric.data import Data, HeteroData + +from gigl.distributed.sampling_engine import SamplingEngine + + +class BaseDistLoader(object): + """GiGL's mode-agnostic distributed data loader. + + Replaces GLT's DistLoader. Delegates all sampling lifecycle to a + :class:`SamplingEngine` instance. + + Subclasses override ``_collate_fn`` and use ``_base_collate`` for the core + SampleMessage-to-PyG conversion. This composable pattern allows each + subclass to control the collation pipeline explicitly without relying on + ``super()._collate_fn()``. + + Args: + engine: A :class:`SamplingEngine` that handles epoch start, sample + retrieval, and shutdown. + sampling_config: Configuration for sampling (batch size, neighbors, etc.). + to_device: Target device for collated results. + input_type: The node type of the input seeds (for heterogeneous graphs). + node_types: List of node types in the graph. + edge_types: List of edge types in the graph. 
+    """
+
+    def __init__(
+        self,
+        engine: SamplingEngine,
+        sampling_config: SamplingConfig,
+        to_device: torch.device,
+        input_type: Optional[Union[str, NodeType]] = None,
+        node_types: Optional[List[NodeType]] = None,
+        edge_types: Optional[List[EdgeType]] = None,
+    ):
+        self._engine = engine
+        self.sampling_config = sampling_config
+        # Unpack commonly used fields for _base_collate compatibility
+        self.sampling_type = sampling_config.sampling_type
+        self.batch_size = sampling_config.batch_size
+        self.edge_dir = sampling_config.edge_dir
+        self.to_device = to_device
+        self._input_type = input_type
+        self._epoch = 0  # passed to engine.start_epoch(), then incremented
+        self._num_recv = 0  # batches returned so far in the current epoch
+        self._shutdowned = False
+        self._set_ntypes_and_etypes(node_types, edge_types)
+
+    def __del__(self):
+        if python_exit_status is True or python_exit_status is None:
+            return  # exit-status flag is True/None: skip cleanup
+        self.shutdown()
+
+    def shutdown(self):
+        """Release all resources held by the sampling engine (idempotent)."""
+        engine = getattr(self, "_engine", None)  # unset if __init__ raised early
+        if engine is None or getattr(self, "_shutdowned", False):
+            return
+        engine.shutdown()
+        self._shutdowned = True
+
+    def __iter__(self):
+        """Reset the per-epoch counter, start a new engine epoch, return self."""
+        self._num_recv = 0
+        self._engine.start_epoch(self._epoch)
+        self._epoch += 1
+        return self
+
+    def __next__(self):
+        """Return the next collated batch; raise StopIteration at epoch end."""
+        if self._num_recv >= self._engine.num_expected:  # >= tolerates overshoot
+            raise StopIteration
+        msg = self._engine.get_sample()
+        if msg is None:
+            raise StopIteration  # Graph store mode: server signals end of epoch
+        result = self._collate_fn(msg)
+        self._num_recv += 1
+        return result
+
+    def _set_ntypes_and_etypes(
+        self,
+        node_types: Optional[List[NodeType]],
+        edge_types: Optional[List[EdgeType]],
+    ):
+        """Set node/edge type metadata used by ``_base_collate``.
+
+        Ported from GLT DistLoader._set_ntypes_and_etypes.
+        """
+        self._node_types = node_types or []  # None -> empty list
+        self._edge_types = edge_types  # None -> both maps below stay empty
+        self._reversed_edge_types: List[EdgeType] = []
+        self._etype_str_to_rev: dict[str, EdgeType] = {}  # msg prefix -> output etype
+        if self._edge_types is not None:
+            for etype in self._edge_types:
+                rev_etype = reverse_edge_type(etype)
+                if self.edge_dir == "out":  # keys use etype, output uses its reverse
+                    self._reversed_edge_types.append(rev_etype)
+                    self._etype_str_to_rev[as_str(etype)] = rev_etype
+                elif self.edge_dir == "in":  # keys use the reverse, output uses etype
+                    self._reversed_edge_types.append(etype)
+                    self._etype_str_to_rev[as_str(rev_etype)] = etype
+
+    def _base_collate(self, msg: SampleMessage) -> Union[Data, HeteroData]:
+        """Core collation: converts a SampleMessage into PyG Data/HeteroData.
+
+        Ported verbatim from GLT DistLoader._collate_fn. This is a standalone
+        method so subclasses can compose collation steps explicitly::
+
+            def _collate_fn(self, msg):
+                data = self._base_collate(msg)
+                data = my_custom_transform(data)
+                return data
+        """
+        ensure_device(self.to_device)
+        is_hetero = bool(msg["#IS_HETERO"])  # the message carries its own graph kind
+
+        # Extract metadata: every "#META.<name>" entry, moved to the target device
+        _metadata_dict: dict[str, torch.Tensor] = {}
+        for k in msg.keys():
+            if k.startswith("#META."):
+                meta_key = str(k[6:])  # drop the 6-character "#META." prefix
+                _metadata_dict[meta_key] = msg[k].to(self.to_device)
+        metadata: Optional[dict[str, torch.Tensor]] = (
+            _metadata_dict if _metadata_dict else None
+        )
+
+        # Heterogeneous sampling results
+        if is_hetero:
+            node_dict, row_dict, col_dict, edge_dict = {}, {}, {}, {}
+            nfeat_dict, efeat_dict = {}, {}
+            num_sampled_nodes_dict, num_sampled_edges_dict = {}, {}
+
+            for ntype in self._node_types:  # every per-type key is optional
+                ids_key = f"{as_str(ntype)}.ids"
+                if ids_key in msg:
+                    node_dict[ntype] = msg[ids_key].to(self.to_device)
+                nfeat_key = f"{as_str(ntype)}.nfeats"
+                if nfeat_key in msg:
+                    nfeat_dict[ntype] = msg[nfeat_key].to(self.to_device)
+                num_sampled_nodes_key = f"{as_str(ntype)}.num_sampled_nodes"
+                if num_sampled_nodes_key in msg:
+                    num_sampled_nodes_dict[ntype] = msg[num_sampled_nodes_key]
+
+            for etype_str, rev_etype in self._etype_str_to_rev.items():
+                rows_key = f"{etype_str}.rows"
+                cols_key = f"{etype_str}.cols"
+                if rows_key in msg:
+                    # The edge index should be reversed: rows <- ".cols", cols <- ".rows".
+                    row_dict[rev_etype] = msg[cols_key].to(self.to_device)
+                    col_dict[rev_etype] = msg[rows_key].to(self.to_device)
+                eids_key = f"{etype_str}.eids"
+                if eids_key in msg:
+                    edge_dict[rev_etype] = msg[eids_key].to(self.to_device)
+                num_sampled_edges_key = f"{etype_str}.num_sampled_edges"
+                if num_sampled_edges_key in msg:
+                    num_sampled_edges_dict[rev_etype] = msg[num_sampled_edges_key]
+                efeat_key = f"{etype_str}.efeats"
+                if efeat_key in msg:
+                    efeat_dict[rev_etype] = msg[efeat_key].to(self.to_device)
+
+            nfeat_dict_or_none = nfeat_dict if len(nfeat_dict) > 0 else None
+            efeat_dict_or_none = efeat_dict if len(efeat_dict) > 0 else None
+
+            if self.sampling_config.sampling_type in [
+                SamplingType.NODE,
+                SamplingType.SUBGRAPH,
+            ]:
+                batch_key = f"{self._input_type}.batch"
+                if msg.get(batch_key) is not None:
+                    batch_dict = {  # NOTE(review): the lookup below rebuilds batch_key
+                        self._input_type: msg[f"{self._input_type}.batch"].to(
+                            self.to_device
+                        )
+                    }
+                else:  # no explicit batch: use the first batch_size sampled ids
+                    batch_dict = {
+                        self._input_type: node_dict[self._input_type][: self.batch_size]
+                    }
+                batch_labels_key = f"{self._input_type}.nlabels"
+                if batch_labels_key in msg:
+                    batch_labels = msg[batch_labels_key].to(self.to_device)
+                else:
+                    batch_labels = None
+                batch_label_dict = {self._input_type: batch_labels}
+            else:  # other sampling types: no batch and no labels
+                batch_dict = {}
+                batch_label_dict = {}
+
+            output = HeteroSamplerOutput(
+                node_dict,
+                row_dict,
+                col_dict,
+                edge_dict if len(edge_dict) else None,  # None when no eids were sent
+                batch_dict,
+                num_sampled_nodes=num_sampled_nodes_dict,
+                num_sampled_edges=num_sampled_edges_dict,
+                edge_types=self._reversed_edge_types,
+                input_type=self._input_type,
+                device=self.to_device,
+                metadata=metadata,
+            )
+            res_data = to_hetero_data(
+                output,
+                batch_label_dict,
+                nfeat_dict_or_none,
+                efeat_dict_or_none,
+                self.edge_dir,
+            )
+
+        # Homogeneous sampling results
+        else:
+            ids = msg["ids"].to(self.to_device)
+            rows = msg["rows"].to(self.to_device)
+            cols = msg["cols"].to(self.to_device)
+            eids = msg["eids"].to(self.to_device) if "eids" in msg else None
+            num_sampled_nodes = (
+                msg["num_sampled_nodes"] if "num_sampled_nodes" in msg else None
+            )
+            num_sampled_edges = (
+                msg["num_sampled_edges"] if "num_sampled_edges" in msg else None
+            )
+
+            nfeats = msg["nfeats"].to(self.to_device) if "nfeats" in msg else None
+            efeats = msg["efeats"].to(self.to_device) if "efeats" in msg else None
+
+            if self.sampling_config.sampling_type in [
+                SamplingType.NODE,
+                SamplingType.SUBGRAPH,
+            ]:
+                if msg.get("batch") is not None:
+                    batch = msg["batch"].to(self.to_device)
+                else:  # no explicit batch: use the first batch_size sampled ids
+                    batch = ids[: self.batch_size]
+                batch_labels = (
+                    msg["nlabels"].to(self.to_device) if "nlabels" in msg else None
+                )
+            else:  # other sampling types: no batch and no labels
+                batch = None
+                batch_labels = None
+
+            # The edge index should be reversed: cols is passed before rows.
+            output = SamplerOutput(
+                ids,
+                cols,
+                rows,
+                eids,
+                batch,
+                num_sampled_nodes,
+                num_sampled_edges,
+                device=self.to_device,
+                metadata=metadata,
+            )
+            res_data = to_data(output, batch_labels, nfeats, efeats)
+
+        return res_data
+
+    def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
+        """Default collation. Subclasses override this to add post-processing.
+
+        The default implementation simply calls ``_base_collate``.
+ """ + return self._base_collate(msg) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 86987c507..69180ebed 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -5,11 +5,7 @@ import torch from graphlearn_torch.channel import SampleMessage, ShmChannel -from graphlearn_torch.distributed import ( - DistLoader, - MpDistSamplingWorkerOptions, - get_context, -) +from graphlearn_torch.distributed import MpDistSamplingWorkerOptions from graphlearn_torch.sampler import SamplingConfig, SamplingType from graphlearn_torch.utils import reverse_edge_type from torch_geometric.data import Data, HeteroData @@ -17,6 +13,7 @@ import gigl.distributed.utils from gigl.common.logger import Logger +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset @@ -28,6 +25,7 @@ ABLPNodeSamplerInput, metadata_key_with_prefix, ) +from gigl.distributed.sampling_engine import ColocatedSamplingEngine from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, @@ -52,7 +50,7 @@ logger = Logger() -class DistABLPLoader(DistLoader): +class DistABLPLoader(BaseDistLoader): def __init__( self, dataset: DistDataset, @@ -324,14 +322,46 @@ def __init__( ) if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: - self._start_colocated_producers( - dataset=dataset, - rank=rank, - local_rank=local_rank, - process_start_gap_seconds=process_start_gap_seconds, - sampler_input=sampler_input, + # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. + # The current hypothesis is making connections across machines require a lot of memory. 
+ # If we start all data loaders in all processes simultaneously, the spike of memory + # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group + # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. + logger.info( + f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" + ) + time.sleep(process_start_gap_seconds * local_rank) + + channel = ShmChannel( + worker_options.channel_capacity, worker_options.channel_size + ) + if worker_options.pin_memory: + channel.pin_memory() + + producer = DistABLPSamplingProducer( + dataset, + sampler_input[0], + sampling_config, + worker_options, + channel, + ) + producer.init() + + engine = ColocatedSamplingEngine( + producer=producer, + channel=channel, + input_len=len(sampler_input[0]), + batch_size=sampling_config.batch_size, + drop_last=sampling_config.drop_last, + ) + + super().__init__( + engine=engine, sampling_config=sampling_config, - worker_options=worker_options, + to_device=self.to_device, + input_type=sampler_input[0].input_type, + node_types=dataset.get_node_types(), + edge_types=dataset.get_edge_types(), ) def _setup_for_colocated( @@ -562,98 +592,6 @@ def _setup_for_colocated( ), ) - def _start_colocated_producers( - self, - dataset: DistDataset, - rank: int, - local_rank: int, - process_start_gap_seconds: float, - sampler_input: list[ABLPNodeSamplerInput], - sampling_config: SamplingConfig, - worker_options: MpDistSamplingWorkerOptions, - ) -> None: - # Code below this point is taken from the GLT DistNeighborLoader.__init__() function (graphlearn_torch/python/distributed/dist_neighbor_loader.py). - # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. 
- - self.data = dataset - self.input_data = sampler_input[0] - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.worker_options = worker_options - - # We can set shutdowned to false now - self._shutdowned = False - - self._is_mp_worker = True - self._is_collocated_worker = False - self._is_remote_worker = False - - self.num_data_partitions = self.data.num_partitions - self.data_partition_idx = self.data.partition_idx - self._set_ntypes_and_etypes( - self.data.get_node_types(), self.data.get_edge_types() - ) - - self._num_recv = 0 - self._epoch = 0 - - current_ctx = get_context() - - self._input_len = len(self.input_data) - self._input_type = self.input_data.input_type - self._num_expected = self._input_len // self.batch_size - if not self.drop_last and self._input_len % self.batch_size != 0: - self._num_expected += 1 - - if not current_ctx.is_worker(): - raise RuntimeError( - f"'{self.__class__.__name__}': only supports " - f"launching multiprocessing sampling workers with " - f"a non-server distribution mode, current role of " - f"distributed context is {current_ctx.role}." - ) - if self.data is None: - raise ValueError( - f"'{self.__class__.__name__}': missing input dataset " - f"when launching multiprocessing sampling workers." 
- ) - - # Launch multiprocessing sampling workers - self._with_channel = True - self.worker_options._set_worker_ranks(current_ctx) - - self._channel = ShmChannel( - self.worker_options.channel_capacity, self.worker_options.channel_size - ) - if self.worker_options.pin_memory: - self._channel.pin_memory() - - self._mp_producer = DistABLPSamplingProducer( - self.data, - self.input_data, - self.sampling_config, - self.worker_options, - self._channel, - ) - # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. - # The current hypothesis is making connections across machines require a lot of memory. - # If we start all data loaders in all processes simultaneously, the spike of memory - # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group - # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. - logger.info( - f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" - ) - time.sleep(process_start_gap_seconds * local_rank) - self._mp_producer.init() - def _get_labels( self, msg: SampleMessage ) -> tuple[ @@ -808,7 +746,7 @@ def _set_labels( def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: msg, positive_labels, negative_labels = self._get_labels(msg) - data = super()._collate_fn(msg) + data = self._base_collate(msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 6732a5b94..64701f035 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -3,22 +3,28 @@ from typing import Optional, Tuple, Union import torch -from graphlearn_torch.channel import SampleMessage +from graphlearn_torch.channel import SampleMessage, ShmChannel from graphlearn_torch.distributed 
import ( - DistLoader, MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions, ) +from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType import gigl.distributed.utils from gigl.common.logger import Logger +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.dist_server import DistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.sampling_engine import ( + ColocatedSamplingEngine, + GraphStoreSamplingEngine, +) from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, @@ -42,7 +48,7 @@ DEFAULT_NUM_CPU_THREADS = 2 -class DistNeighborLoader(DistLoader): +class DistNeighborLoader(BaseDistLoader): def __init__( self, dataset: Union[DistDataset, RemoteDistDataset], @@ -124,9 +130,6 @@ def __init__( # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. - # We set to `True` as we don't need to cleanup right away, and this will get set - # to `False` in super().__init__()` e.g. 
- # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1 self._shutdowned = True node_world_size: int @@ -285,26 +288,49 @@ def __init__( f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" ) time.sleep(process_start_gap_seconds * local_rank) + + channel = ShmChannel( + worker_options.channel_capacity, worker_options.channel_size + ) + if worker_options.pin_memory: + channel.pin_memory() + + producer = DistMpSamplingProducer( + dataset, input_data, sampling_config, worker_options, channel + ) + producer.init() + + colocated_engine = ColocatedSamplingEngine( + producer=producer, + channel=channel, + input_len=len(input_data), + batch_size=batch_size, + drop_last=drop_last, + ) + input_type = input_data.input_type + assert isinstance(dataset, DistDataset) + node_types = dataset.get_node_types() + edge_types_for_collation = dataset.get_edge_types() + super().__init__( - dataset, # Pass in the dataset for colocated mode. - input_data, - sampling_config, - device, - worker_options, + engine=colocated_engine, + sampling_config=sampling_config, + to_device=device, + input_type=input_type, + node_types=node_types, + edge_types=edge_types_for_collation, ) else: - # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. - # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. + # For Graph Store mode, we need to start the communication between compute and storage nodes sequentially, by compute node. + # E.g. initialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. # Note that each compute node may have multiple connections to each storage node, once per compute process. 
# It's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). # Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. # E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. - # We need to this because if we don't, then there is a race condition when initalizing the samplers on the storage nodes [1] + # We need to do this because if we don't, then there is a race condition when initializing the samplers on the storage nodes [1] # Where since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0, and compute node 1 # Then we deadlock and fail. - # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. [2] # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167 - # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193 # See below for a connection setup. 
# ╔═══════════════════════════════════════════════════════════════════════════════════════╗ @@ -338,22 +364,33 @@ def __init__( # │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ # └─────────────────────────────────────────────────────────────────────────────┘ node_rank = dataset.cluster_info.compute_node_rank + + gs_engine = GraphStoreSamplingEngine( + server_ranks=list(range(dataset.cluster_info.num_storage_nodes)), + input_data_list=input_data, + sampling_config=sampling_config, + worker_options=worker_options, + server_create_fn=DistServer.create_sampling_producer, + ) + + # Barrier loop: only setup_rpc() needs barrier protection for target_node_rank in range(dataset.cluster_info.num_compute_nodes): if node_rank == target_node_rank: - # TODO: (kmontemayor2-sc) Evaluate if we need to stagger the initialization of the data loaders - # to smooth the memory usage. - super().__init__( - None, # Pass in None for Graph Store mode. - input_data, - sampling_config, - device, - worker_options, - ) + gs_engine.setup_rpc() logger.info(f"node_rank {node_rank} initialized the dist loader") torch.distributed.barrier() torch.distributed.barrier() logger.info("All node ranks initialized the dist loader") + super().__init__( + engine=gs_engine, + sampling_config=sampling_config, + to_device=device, + input_type=gs_engine._input_type, + node_types=gs_engine.node_types, + edge_types=gs_engine.edge_types, + ) + def _setup_for_graph_store( self, input_nodes: Optional[ @@ -366,7 +403,7 @@ def _setup_for_graph_store( ], dataset: RemoteDistDataset, num_workers: int, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetSchema]: + ) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -617,7 +654,7 @@ def _setup_for_colocated( ) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: - 
data = super()._collate_fn(msg) + data = self._base_collate(msg) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, diff --git a/gigl/distributed/sampling_engine.py b/gigl/distributed/sampling_engine.py new file mode 100644 index 000000000..a67a47da1 --- /dev/null +++ b/gigl/distributed/sampling_engine.py @@ -0,0 +1,211 @@ +""" +SamplingEngine abstraction for GiGL's distributed data loading. + +Provides a clean interface over the two sampling modes GiGL uses: +- Colocated: graph data and compute on the same machines (MpDistSamplingWorkerOptions) +- Graph Store: graph data on separate storage nodes (RemoteDistSamplingWorkerOptions) +""" + +import concurrent.futures +from abc import ABC, abstractmethod +from typing import Callable, List, Union + +import torch +from graphlearn_torch.channel import SampleMessage, ShmChannel +from graphlearn_torch.distributed.dist_client import request_server +from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer +from graphlearn_torch.distributed.rpc import rpc_is_initialized +from graphlearn_torch.sampler import ( + EdgeSamplerInput, + NodeSamplerInput, + RemoteSamplerInput, + SamplingConfig, +) + +from gigl.common.logger import Logger +from gigl.distributed.graph_store.dist_server import DistServer + +logger = Logger() + + +class SamplingEngine(ABC): + """Abstracts the lifecycle of distributed sample production and consumption. + + Concrete implementations handle the two sampling modes used by GiGL: + colocated (Mp workers) and graph store (remote workers). + """ + + @abstractmethod + def start_epoch(self, epoch: int) -> None: + """Signal the start of a new epoch. Implementations trigger sampling.""" + ... + + @abstractmethod + def get_sample(self) -> SampleMessage: + """Return the next sampled message (blocking).""" + ... + + @abstractmethod + def shutdown(self) -> None: + """Release all resources (channels, subprocesses, RPC connections).""" + ... 
+
+    @property
+    @abstractmethod
+    def num_expected(self) -> Union[int, float]:
+        """Number of batches expected per epoch. float('inf') for graph store mode."""
+        ...
+
+
+class ColocatedSamplingEngine(SamplingEngine):
+    """Sampling engine for GiGL's colocated mode.
+
+    Wraps a DistMpSamplingProducer (or any subclass, e.g. DistABLPSamplingProducer)
+    and a ShmChannel. Used when graph data and compute live on the same machines.
+    """
+
+    def __init__(
+        self,
+        producer: DistMpSamplingProducer,
+        channel: ShmChannel,
+        input_len: int,
+        batch_size: int,
+        drop_last: bool,
+    ):
+        self._producer = producer
+        self._channel = channel
+        # Full batches, plus one trailing partial batch unless drop_last is set.
+        full_batches, remainder = divmod(input_len, batch_size)
+        self._num_expected = full_batches + int(remainder != 0 and not drop_last)
+
+    def start_epoch(self, epoch: int) -> None:
+        """Call ``produce_all()`` on the producer; ``epoch`` is not used here."""
+        self._producer.produce_all()
+
+    def get_sample(self) -> SampleMessage:
+        """Return the next message received from the wrapped channel."""
+        return self._channel.recv()
+
+    def shutdown(self) -> None:
+        """Shut down the wrapped sampling producer."""
+        self._producer.shutdown()
+
+    @property
+    def num_expected(self) -> int:
+        """Batches per epoch, derived from input_len, batch_size and drop_last."""
+        return self._num_expected
+
+
+class GraphStoreSamplingEngine(SamplingEngine):
+    """Sampling engine for GiGL's graph store mode.
+
+    Manages server-side producer lifecycle via RPC. Used when graph data lives
+    on separate storage nodes.
+
+    The initialization is split into two phases:
+    - ``__init__``: Stores config and fetches dataset metadata (no barrier needed).
+    - ``setup_rpc()``: Dispatches RPCs to create producers on each server.
+      This must be called inside the per-compute-node barrier loop to avoid
+      race conditions in TensorPipe rendezvous.
+    """
+
+    def __init__(
+        self,
+        server_ranks: List[int],
+        input_data_list: List[
+            Union[NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput]
+        ],
+        sampling_config: SamplingConfig,
+        worker_options,  # RemoteDistSamplingWorkerOptions
+        server_create_fn: Callable = DistServer.create_sampling_producer,
+    ):
+        self._server_ranks = server_ranks
+        self._input_data_list = input_data_list
+        self._sampling_config = sampling_config
+        self._worker_options = worker_options
+        self._server_create_fn = server_create_fn
+        self._producer_ids: List[int] = []
+        self._channel = None  # Will be set by setup_rpc()
+        self._shutdowned = False
+        # Fetch dataset metadata from the first server
+        (
+            self.num_data_partitions,
+            self.data_partition_idx,
+            self.node_types,
+            self.edge_types,
+        ) = request_server(self._server_ranks[0], DistServer.get_dataset_meta)
+        # Determine input_type from the first input data
+        self._input_type = self._input_data_list[0].input_type
+
+    def setup_rpc(self) -> None:
+        """Dispatch RPCs to create producers on each server.
+
+        This method must be called inside the per-compute-node barrier loop
+        to avoid race conditions in TensorPipe rendezvous.
+        """
+        # Move input data to CPU for serialization; keep what ``.to()`` returns.
+        cpu_inputs = [
+            inp if isinstance(inp, RemoteSamplerInput) else inp.to(torch.device("cpu"))
+            for inp in self._input_data_list
+        ]
+        # Dispatch RPCs to all servers concurrently
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(
+                    request_server,
+                    server_rank,
+                    self._server_create_fn,
+                    input_data,
+                    self._sampling_config,
+                    self._worker_options,
+                )
+                for server_rank, input_data in zip(self._server_ranks, cpu_inputs)
+            ]
+        for future in futures:  # submission order == server_ranks order
+            self._producer_ids.append(future.result())
+
+        # Import here to avoid circular import issues with compiled GLT channel module
+        from graphlearn_torch.channel import RemoteReceivingChannel
+
+        self._channel = RemoteReceivingChannel(
+            self._server_ranks,
+            self._producer_ids,
+            self._worker_options.prefetch_size,
+        )
+
+    def start_epoch(self, epoch: int) -> None:
+        """Start ``epoch`` on every server-side producer, then reset the channel."""
+        if self._channel is None:  # fail fast, before any RPC is sent
+            raise RuntimeError("setup_rpc() must be called before start_epoch()")
+        for server_rank, producer_id in zip(self._server_ranks, self._producer_ids):
+            request_server(
+                server_rank,
+                DistServer.start_new_epoch_sampling,
+                producer_id,
+                epoch,
+            )
+        self._channel.reset()
+
+    def get_sample(self) -> SampleMessage:
+        """Return the next message from the channel created by ``setup_rpc()``."""
+        if self._channel is None:
+            raise RuntimeError("setup_rpc() must be called before get_sample()")
+        return self._channel.recv()
+
+    def shutdown(self) -> None:
+        """Destroy the server-side producers once; later calls are no-ops."""
+        if self._shutdowned:
+            return
+        if rpc_is_initialized() is True:
+            for server_rank, producer_id in zip(self._server_ranks, self._producer_ids):
+                request_server(
+                    server_rank,
+                    DistServer.destroy_sampling_producer,
+                    producer_id,
+                )
+        self._shutdowned = True
+
+    @property
+    def num_expected(self) -> float:
+        """Always ``float("inf")``: this engine imposes no batch-count limit."""
+        return float("inf")