From 7d0d8be91b4752b304ac50b27b494a061017899d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 11 Feb 2026 18:35:49 +0000 Subject: [PATCH 01/30] Setup DistABLPLoader for GraphStore mode --- gigl/distributed/dist_ablp_neighborloader.py | 613 ++++++++++++++---- .../distributed/distributed_neighborloader.py | 8 +- .../graph_store_integration_test.py | 65 ++ 3 files changed, 560 insertions(+), 126 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 86987c507..0bd653523 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,14 +1,17 @@ import ast +import concurrent.futures import time from collections import Counter, abc, defaultdict from typing import Optional, Union import torch -from graphlearn_torch.channel import SampleMessage, ShmChannel +from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel from graphlearn_torch.distributed import ( DistLoader, MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, get_context, + request_server, ) from graphlearn_torch.sampler import SamplingConfig, SamplingType from graphlearn_torch.utils import reverse_edge_type @@ -22,6 +25,8 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistABLPSamplingProducer from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS +from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler import ( NEGATIVE_LABEL_METADATA_KEY, POSITIVE_LABEL_METADATA_KEY, @@ -55,12 +60,20 @@ class DistABLPLoader(DistLoader): def __init__( self, - dataset: DistDataset, + dataset: Union[DistDataset, RemoteDistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ Union[ torch.Tensor, tuple[NodeType, torch.Tensor], + # Graph Store mode inputs + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + tuple[ + NodeType, + dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ], ] ] = None, supervision_edge_type: Optional[Union[EdgeType, list[EdgeType]]] = None, @@ -125,24 +138,29 @@ def __init__( - `y_negative`: {(a, to, b): {0: torch.tensor([3])}, (a, to, c): {0: torch.tensor([4])}} Args: - dataset (DistDataset): The dataset to sample from. + dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. + If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. num_neighbors (list[int] or dict[tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - context (DistributedContext): Distributed context information of the current process. - input_nodes (Optional[torch.Tensor, tuple[NodeType, torch.Tensor]]): - Indices of seed nodes to start sampling from. - If set to `None` for homogeneous settings, all nodes will be considered. - In heterogeneous graphs, this flag must be passed in as a tuple that holds - the node type and node indices. (default: `None`) + input_nodes: Indices of seed nodes to start sampling from. + For Colocated mode: `torch.Tensor` or `tuple[NodeType, torch.Tensor]`. + If set to `None` for homogeneous settings, all nodes will be considered. + In heterogeneous graphs, this flag must be passed in as a tuple that holds + the node type and node indices. + For Graph Store mode: `dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]` + or `tuple[NodeType, dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]]`. + The dict maps server_rank to (anchor_nodes, positive_labels, negative_labels). + This is the return type of `RemoteDistDataset.get_ablp_input()`. supervision_edge_type (Optional[Union[EdgeType, list[EdgeType]]]): The edge type(s) to use for supervision. Must be None iff the dataset is labeled homogeneous. If set to a single EdgeType, the positive and negative labels will be stored in the `y_positive` and `y_negative` fields of the Data object. If set to a list of EdgeTypes, the positive and negative labels will be stored in the `y_positive` and `y_negative` fields of the Data object, with the key being the EdgeType. (default: `None`) + NOTE: Graph Store mode currently only supports a single supervision edge type. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -189,6 +207,13 @@ def __init__( master_ip_address: str should_cleanup_distributed_context: bool = False + # Determine sampling cluster setup based on dataset type + if isinstance(dataset, RemoteDistDataset): + self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE + else: + self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED + logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") + if supervision_edge_type is None: self._supervision_edge_types: list[EdgeType] = [ DEFAULT_HOMOGENEOUS_EDGE_TYPE @@ -201,9 +226,16 @@ def __init__( self._supervision_edge_types = supervision_edge_type else: self._supervision_edge_types = [supervision_edge_type] - del supervision_edge_type - self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED + # TODO(kmonte): Support multiple supervision edge types in Graph Store mode + if self._sampling_cluster_setup == SamplingClusterSetup.GRAPH_STORE: + if len(self._supervision_edge_types) > 1: + raise ValueError( + "Graph Store mode currently only supports a single supervision edge type. " + f"Received {len(self._supervision_edge_types)} edge types: {self._supervision_edge_types}" + ) + + del supervision_edge_type if context: assert ( @@ -266,34 +298,78 @@ def __init__( local_process_world_size, ) # delete deprecated vars so we don't accidentally use them. - self.to_device = ( + device = ( pin_memory_device if pin_memory_device else gigl.distributed.utils.get_available_device( local_process_rank=local_rank ) ) + self.to_device = device - ( - sampler_input, - worker_options, - dataset_metadata, - ) = self._setup_for_colocated( - input_nodes=input_nodes, - dataset=dataset, - local_rank=local_rank, - local_world_size=local_world_size, - device=self.to_device, - master_ip_address=master_ip_address, - node_rank=node_rank, - node_world_size=node_world_size, - num_workers=num_workers, - worker_concurrency=worker_concurrency, - channel_size=channel_size, - num_cpu_threads=num_cpu_threads, - ) + # Call appropriate setup method based on sampling cluster setup + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: + assert isinstance( + dataset, DistDataset + ), "When using colocated mode, dataset must be a DistDataset." + # Validate input_nodes type for colocated mode + if isinstance(input_nodes, abc.Mapping) or ( + isinstance(input_nodes, tuple) + and isinstance(input_nodes[1], abc.Mapping) + ): + raise ValueError( + f"When using Colocated mode, input_nodes must be of type " + f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " + f"received Graph Store format: {type(input_nodes)}" + ) + ( + sampler_input, + worker_options, + dataset_metadata, + ) = self._setup_for_colocated( + input_nodes=input_nodes, + dataset=dataset, + local_rank=local_rank, + local_world_size=local_world_size, + device=device, + master_ip_address=master_ip_address, + node_rank=node_rank, + node_world_size=node_world_size, + num_workers=num_workers, + worker_concurrency=worker_concurrency, + channel_size=channel_size, + num_cpu_threads=num_cpu_threads, + ) + else: # Graph Store mode + assert isinstance( + dataset, RemoteDistDataset + ), "When using Graph Store mode, dataset must be a RemoteDistDataset." + # Validate input_nodes type for Graph Store mode + if ( + input_nodes is None + or isinstance(input_nodes, torch.Tensor) + or ( + isinstance(input_nodes, tuple) + and isinstance(input_nodes[1], torch.Tensor) + ) + ): + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[...]] | tuple[NodeType, dict[int, tuple[...]]]), " + f"received Colocated format: {type(input_nodes)}" + ) + ( + sampler_input, + worker_options, + dataset_metadata, + ) = self._setup_for_graph_store( + input_nodes=input_nodes, + dataset=dataset, + supervision_edge_type=self._supervision_edge_types[0], + num_workers=num_workers, + ) - self._is_input_labeled_homogeneous = ( + self.is_homogeneous_with_labeled_edge_type = ( dataset_metadata.is_homogeneous_with_labeled_edge_type ) self._node_feature_info = dataset_metadata.node_feature_info @@ -324,19 +400,132 @@ def __init__( ) if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: - self._start_colocated_producers( - dataset=dataset, - rank=rank, - local_rank=local_rank, - process_start_gap_seconds=process_start_gap_seconds, - sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, + # Code below this point is taken from the GLT DistNeighborLoader.__init__() function + # (graphlearn_torch/python/distributed/dist_neighbor_loader.py). + # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. + + # Type narrowing for colocated mode + + self.data = dataset + self.input_data = sampler_input[0] + del dataset, sampler_input + assert isinstance(self.data, DistDataset) + assert isinstance(self.input_data, ABLPNodeSamplerInput) + + self.sampling_type = sampling_config.sampling_type + self.num_neighbors = sampling_config.num_neighbors + self.batch_size = sampling_config.batch_size + self.shuffle = sampling_config.shuffle + self.drop_last = sampling_config.drop_last + self.with_edge = sampling_config.with_edge + self.with_weight = sampling_config.with_weight + self.collect_features = sampling_config.collect_features + self.edge_dir = sampling_config.edge_dir + self.sampling_config = sampling_config + self.worker_options = worker_options + + # We can set shutdowned to false now + self._shutdowned = False + + self._is_mp_worker = True + self._is_collocated_worker = False + self._is_remote_worker = False + + self.num_data_partitions = self.data.num_partitions + self.data_partition_idx = self.data.partition_idx + self._set_ntypes_and_etypes( + self.data.get_node_types(), self.data.get_edge_types() + ) + + self._num_recv = 0 + self._epoch = 0 + + current_ctx = get_context() + + self._input_len = len(self.input_data) + self._input_type = self.input_data.input_type + self._num_expected = self._input_len // self.batch_size + if not self.drop_last and self._input_len % self.batch_size != 0: + self._num_expected += 1 + + if not current_ctx.is_worker(): + raise RuntimeError( + f"'{self.__class__.__name__}': only supports " + f"launching multiprocessing sampling workers with " + f"a non-server distribution mode, current role of " + f"distributed context is {current_ctx.role}." + ) + if self.data is None: + raise ValueError( + f"'{self.__class__.__name__}': missing input dataset " + f"when launching multiprocessing sampling workers." + ) + + # Launch multiprocessing sampling workers + self._with_channel = True + self.worker_options._set_worker_ranks(current_ctx) + + self._channel = ShmChannel( + self.worker_options.channel_capacity, self.worker_options.channel_size + ) + if self.worker_options.pin_memory: + self._channel.pin_memory() + + self._mp_producer = DistABLPSamplingProducer( + self.data, + self.input_data, + self.sampling_config, + self.worker_options, + self._channel, + ) + # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. + # The current hypothesis is making connections across machines require a lot of memory. + # If we start all data loaders in all processes simultaneously, the spike of memory + # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group + # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. + logger.info( + f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" + ) + time.sleep(process_start_gap_seconds * local_rank) + self._mp_producer.init() + else: + # Graph Store mode - re-implement remote worker setup + # Use sequential initialization per compute node to avoid race conditions + # when initializing the samplers on the storage nodes. + node_rank = dataset.cluster_info.compute_node_rank + for target_node_rank in range(dataset.cluster_info.num_compute_nodes): + if node_rank == target_node_rank: + self._init_remote_worker( + dataset=dataset, + sampler_input=sampler_input, + sampling_config=sampling_config, + worker_options=worker_options, + dataset_metadata=dataset_metadata, + ) + logger.info( + f"node_rank {node_rank} / {dataset.cluster_info.num_compute_nodes} initialized the dist loader" + ) + torch.distributed.barrier() + torch.distributed.barrier() + logger.info( + f"node_rank {node_rank} / {dataset.cluster_info.num_compute_nodes} finished initializing the dist loader" ) def _setup_for_colocated( self, - input_nodes: Optional[Union[torch.Tensor, tuple[NodeType, torch.Tensor]]], + input_nodes: Optional[ + Union[ + torch.Tensor, + tuple[NodeType, torch.Tensor], + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + tuple[ + NodeType, + dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ], + ] + ], dataset: DistDataset, local_rank: int, local_world_size: int, @@ -367,19 +556,17 @@ def _setup_for_colocated( num_cpu_threads: Number of CPU threads for PyTorch. Returns: - Tuple of (list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema). + Tuple of (ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema). """ # Validate input format - should not be Graph Store format if isinstance(input_nodes, abc.Mapping): raise ValueError( - f"When using Colocated mode, input_nodes must be of type " - f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " + f"When using Colocated mode, input_nodes must be of type (torch.Tensor | tuple[NodeType, torch.Tensor]), " f"received {type(input_nodes)}" ) elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping): raise ValueError( - f"When using Colocated mode, input_nodes must be of type " - f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " + f"When using Colocated mode, input_nodes must be of type (torch.Tensor | tuple[NodeType, torch.Tensor]), " f"received tuple with second element of type {type(input_nodes[1])}" ) @@ -398,15 +585,15 @@ def _setup_for_colocated( anchor_node_type, anchor_node_ids = input_nodes # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if # this assumption is no longer valid and/or is too opinionated - for supervision_edge_type in self._supervision_edge_types: + for sup_edge_type in self._supervision_edge_types: assert ( - supervision_edge_type[0] == anchor_node_type + sup_edge_type[0] == anchor_node_type ), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ - got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" + got supervision edge type {sup_edge_type} with anchor node type {anchor_node_type}" if dataset.edge_dir == "in": self._supervision_edge_types = [ - reverse_edge_type(supervision_edge_type) - for supervision_edge_type in self._supervision_edge_types + reverse_edge_type(sup_edge_type) + for sup_edge_type in self._supervision_edge_types ] elif isinstance(input_nodes, torch.Tensor): if self._supervision_edge_types != [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: @@ -459,11 +646,11 @@ def _setup_for_colocated( self._negative_label_edge_types: list[EdgeType] = [] positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} - for supervision_edge_type in self._supervision_edge_types: + for sup_edge_type in self._supervision_edge_types: ( positive_label_edge_type, negative_label_edge_type, - ) = select_label_edge_types(supervision_edge_type, dataset.graph.keys()) + ) = select_label_edge_types(sup_edge_type, dataset.graph.keys()) self._positive_label_edge_types.append(positive_label_edge_type) if negative_label_edge_type is not None: self._negative_label_edge_types.append(negative_label_edge_type) @@ -516,7 +703,6 @@ def _setup_for_colocated( master_worker_port=neighbor_loader_port_for_current_rank, device=device, should_use_cpu_workers=should_use_cpu_workers, - # Lever to explore tuning for CPU based inference num_cpu_threads=num_cpu_threads, ) logger.info( @@ -532,16 +718,8 @@ def _setup_for_colocated( num_workers=num_workers, worker_devices=[torch.device("cpu") for _ in range(num_workers)], worker_concurrency=worker_concurrency, - # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group - # need to be connected. Thus, we need master ip address and master port to - # initate the connection. - # Note that different groups of workers are independent, and thus - # the sampling processes in different groups should be independent, and should - # use different master ports. master_addr=master_ip_address, master_port=dist_sampling_port_for_current_rank, - # Load testing shows that when num_rpc_threads exceed 16, the performance - # will degrade. num_rpc_threads=min(dataset.num_partitions, 16), rpc_timeout=600, channel_size=channel_size, @@ -562,21 +740,200 @@ def _setup_for_colocated( ), ) - def _start_colocated_producers( + def _setup_for_graph_store( self, - dataset: DistDataset, - rank: int, - local_rank: int, - process_start_gap_seconds: float, + input_nodes: Optional[ + Union[ + torch.Tensor, + tuple[NodeType, torch.Tensor], + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + tuple[ + NodeType, + dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ], + ] + ], + dataset: RemoteDistDataset, + supervision_edge_type: EdgeType, + num_workers: int, + ) -> tuple[ + list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema + ]: + """ + Setup method for Graph Store mode. + + Args: + input_nodes: ABLP input from RemoteDistDataset.get_ablp_input(). + Format: dict[server_rank, (anchors, positive_labels, negative_labels)] + or tuple[NodeType, dict[server_rank, (anchors, positive_labels, negative_labels)]]. + dataset: The RemoteDistDataset to sample from. + supervision_edge_type: The single supervision edge type to use. + num_workers: Number of sampling workers. + + Returns: + Tuple of (list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema). + """ + # Validate input format - must be Graph Store format + if input_nodes is None: + raise ValueError( + f"When using Graph Store mode, input_nodes must be provided, received {input_nodes}" + ) + elif isinstance(input_nodes, torch.Tensor): + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[Tensor, Tensor, Optional[Tensor]]] | " + f"tuple[NodeType, dict[int, tuple[Tensor, Tensor, Optional[Tensor]]]]), " + f"received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance( + input_nodes[1], torch.Tensor + ): + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[Tensor, Tensor, Optional[Tensor]]] | " + f"tuple[NodeType, dict[int, tuple[Tensor, Tensor, Optional[Tensor]]]]), " + f"received tuple with second element of type {type(input_nodes[1])}" + ) + + is_homogeneous_with_labeled_edge_type = False + node_feature_info = dataset.get_node_feature_info() + edge_feature_info = dataset.get_edge_feature_info() + edge_types = dataset.get_edge_types() + node_rank = dataset.cluster_info.compute_node_rank + + # Get sampling ports for compute-storage connections. + sampling_ports = dataset.get_free_ports_on_storage_cluster( + num_ports=dataset.cluster_info.num_compute_nodes + ) + sampling_port = sampling_ports[node_rank] + + # TODO(kmonte) - We need to be able to differentiate between differnt instance of the same loader. + # e.g. if we have two different DistABLPLoaders, then they will have conflicting worker keys. + # And they will share each others data. Therefor, the second loader will not load the data it's expecting. + # Probably, we can just keep track of the insantiations on the server-side and include the count in the worker key. + worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for _ in range(num_workers)], + master_addr=dataset.cluster_info.storage_cluster_master_ip, + master_port=sampling_port, + worker_key=f"compute_ablp_loader_rank_{node_rank}", + ) + logger.info( + f"Rank {torch.distributed.get_rank()}! init for sampling rpc: " + f"tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}" + ) + + # Determine input type based on input_nodes structure + if isinstance(input_nodes, abc.Mapping): + # Labeled homogeneous: dict[int, tuple[...]] + nodes_dict = input_nodes + input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping): + # Heterogeneous: (NodeType, dict[int, tuple[...]]) + input_type = input_nodes[0] + nodes_dict = input_nodes[1] + is_homogeneous_with_labeled_edge_type = True + else: + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[...]] | tuple[NodeType, dict[int, tuple[...]]]), " + f"received {type(input_nodes)}" + ) + + # Validate server ranks + servers = nodes_dict.keys() + if len(servers) > 0: + if ( + max(servers) >= dataset.cluster_info.num_storage_nodes + or min(servers) < 0 + ): + raise ValueError( + f"When using Graph Store mode, the server ranks must be in range " + f"[0, {dataset.cluster_info.num_storage_nodes}), " + f"received inputs for servers: {list(servers)}" + ) + + # Get label edge types for building ABLPNodeSamplerInput + # TODO(kmonte): Support multiple supervision edge types in Graph Store mode + ( + positive_label_edge_type, + negative_label_edge_type, + ) = select_label_edge_types(supervision_edge_type, edge_types or []) + logger.info(f"Positive label edge type: {positive_label_edge_type}") + logger.info(f"Negative label edge type: {negative_label_edge_type}") + self._positive_label_edge_types = [positive_label_edge_type] + self._negative_label_edge_types = ( + [negative_label_edge_type] if negative_label_edge_type else [] + ) + + # Convert from dict format to list of ABLPNodeSamplerInput + input_data: list[ABLPNodeSamplerInput] = [] + for server_rank in range(dataset.cluster_info.num_storage_nodes): + if server_rank in nodes_dict: + anchors, positive_labels, negative_labels = nodes_dict[server_rank] + else: + # Empty input for servers with no data for this rank + anchors = torch.empty(0, dtype=torch.long) + positive_labels = torch.empty(0, 0, dtype=torch.long) + negative_labels = None + + # Build label dicts keyed by label edge type + positive_label_by_edge_types = {positive_label_edge_type: positive_labels} + negative_label_by_edge_types: dict[EdgeType, torch.Tensor] = {} + if negative_labels is not None and negative_label_edge_type is not None: + negative_label_by_edge_types[negative_label_edge_type] = negative_labels + + logger.info( + f"Rank: {torch.distributed.get_rank()}! Building ABLPNodeSamplerInput for server rank: {server_rank} with input type: {input_type}. anchors: {anchors.shape}, positive_labels: {positive_labels.shape}, negative_labels: {negative_labels.shape if negative_labels is not None else None}" + ) + ablp_input = ABLPNodeSamplerInput( + node=anchors, + input_type=input_type, + positive_label_by_edge_types=positive_label_by_edge_types, + negative_label_by_edge_types=negative_label_by_edge_types, + ) + input_data.append(ablp_input) + + return ( + input_data, + worker_options, + DatasetSchema( + is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type, + edge_types=edge_types, + node_feature_info=node_feature_info, + edge_feature_info=edge_feature_info, + edge_dir=dataset.get_edge_dir(), + ), + ) + + def _init_remote_worker( + self, + dataset: RemoteDistDataset, sampler_input: list[ABLPNodeSamplerInput], sampling_config: SamplingConfig, - worker_options: MpDistSamplingWorkerOptions, + worker_options: RemoteDistSamplingWorkerOptions, + dataset_metadata: DatasetSchema, ) -> None: - # Code below this point is taken from the GLT DistNeighborLoader.__init__() function (graphlearn_torch/python/distributed/dist_neighbor_loader.py). - # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. + """ + Initialize the remote worker code path for Graph Store mode. + + This re-implements GLT's DistLoader remote worker setup but uses GiGL's DistServer. - self.data = dataset - self.input_data = sampler_input[0] + Args: + dataset: The RemoteDistDataset to sample from. + sampler_input: List of ABLPNodeSamplerInput, one per server. + sampling_config: Configuration for sampling. + worker_options: Options for remote sampling workers. + dataset_metadata: Metadata about the dataset schema. + """ + # Set instance variables (like DistLoader does) + # Note: We assign to self.data and self.input_data which are also set in the colocated + # branch. For Graph Store mode, data is None and input_data is a list. + object.__setattr__(self, "data", None) # No local data in Graph Store mode + object.__setattr__(self, "input_data", sampler_input) self.sampling_type = sampling_config.sampling_type self.num_neighbors = sampling_config.num_neighbors self.batch_size = sampling_config.batch_size @@ -589,70 +946,78 @@ def _start_colocated_producers( self.sampling_config = sampling_config self.worker_options = worker_options - # We can set shutdowned to false now self._shutdowned = False - self._is_mp_worker = True + # Set worker type flags + self._is_mp_worker = False self._is_collocated_worker = False - self._is_remote_worker = False + self._is_remote_worker = True - self.num_data_partitions = self.data.num_partitions - self.data_partition_idx = self.data.partition_idx - self._set_ntypes_and_etypes( - self.data.get_node_types(), self.data.get_edge_types() - ) + # For remote worker, end of epoch is determined by server + self._num_expected = float("inf") + self._with_channel = True self._num_recv = 0 self._epoch = 0 - current_ctx = get_context() - - self._input_len = len(self.input_data) - self._input_type = self.input_data.input_type - self._num_expected = self._input_len // self.batch_size - if not self.drop_last and self._input_len % self.batch_size != 0: - self._num_expected += 1 - - if not current_ctx.is_worker(): - raise RuntimeError( - f"'{self.__class__.__name__}': only supports " - f"launching multiprocessing sampling workers with " - f"a non-server distribution mode, current role of " - f"distributed context is {current_ctx.role}." - ) - if self.data is None: - raise ValueError( - f"'{self.__class__.__name__}': missing input dataset " - f"when launching multiprocessing sampling workers." + # Get server rank list from worker_options + self._server_rank_list = ( + worker_options.server_rank + if isinstance(worker_options.server_rank, list) + else [worker_options.server_rank] + ) + self._input_data_list = sampler_input # Already a list (one per server) + + # Get input type from first input + self._input_type = self._input_data_list[0].input_type + + # Get dataset metadata from cluster_info (not via RPC) + self.num_data_partitions = dataset.cluster_info.num_storage_nodes + self.data_partition_idx = dataset.cluster_info.compute_node_rank + + # Derive node types from edge types + # For labeled homogeneous: edge_types contains DEFAULT_HOMOGENEOUS_EDGE_TYPE + # For heterogeneous: extract unique src/dst types from edge types + edge_types = dataset_metadata.edge_types or [] + if edge_types: + node_types = list( + set([et[0] for et in edge_types] + [et[2] for et in edge_types]) ) + else: + node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] + self._set_ntypes_and_etypes(node_types, edge_types) + + # Create sampling producers on each server (concurrently) + # Move input data to CPU before sending to server + for input_data in self._input_data_list: + input_data.to(torch.device("cpu")) + + self._producer_id_list = [] + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + request_server, + server_rank, + DistServer.create_sampling_ablp_producer, + input_data, + self.sampling_config, + self.worker_options, + ) + for server_rank, input_data in zip( + self._server_rank_list, self._input_data_list + ) + ] - # Launch multiprocessing sampling workers - self._with_channel = True - self.worker_options._set_worker_ranks(current_ctx) + for future in futures: + producer_id = future.result() + self._producer_id_list.append(producer_id) - self._channel = ShmChannel( - self.worker_options.channel_capacity, self.worker_options.channel_size - ) - if self.worker_options.pin_memory: - self._channel.pin_memory() - - self._mp_producer = DistABLPSamplingProducer( - self.data, - self.input_data, - self.sampling_config, - self.worker_options, - self._channel, - ) - # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. - # The current hypothesis is making connections across machines require a lot of memory. - # If we start all data loaders in all processes simultaneously, the spike of memory - # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group - # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. - logger.info( - f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" + # Create remote receiving channel for cross-machine message passing + self._channel = RemoteReceivingChannel( + self._server_rank_list, + self._producer_id_list, + self.worker_options.prefetch_size, ) - time.sleep(process_start_gap_seconds * local_rank) - self._mp_producer.init() def _get_labels( self, msg: SampleMessage @@ -817,7 +1182,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: ) if isinstance(data, HeteroData): data = strip_label_edges(data) - if not self._is_input_labeled_homogeneous: + if not self.is_homogeneous_with_labeled_edge_type: if len(self._supervision_edge_types) != 1: raise ValueError( f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 6732a5b94..e24101461 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -394,13 +394,17 @@ def _setup_for_graph_store( ) sampling_port = sampling_ports[node_rank] + # TODO(kmonte) - We need to be able to differentiate between differnt instance of the same loader. + # e.g. if we have two different DistNeighborLoaders, then they will have conflicting worker keys. + # And they will share each others data. Therefor, the second loader will not load the data it's expecting. + # Probably, we can just keep track of the insantiations on the server-side and include the count in the worker key. worker_options = RemoteDistSamplingWorkerOptions( server_rank=list(range(dataset.cluster_info.num_storage_nodes)), num_workers=num_workers, worker_devices=[torch.device("cpu") for i in range(num_workers)], master_addr=dataset.cluster_info.storage_cluster_master_ip, master_port=sampling_port, - worker_key=f"compute_rank_{node_rank}", + worker_key=f"compute_loader_rank_{node_rank}", ) logger.info( f"Rank {torch.distributed.get_rank()}! init for sampling rpc: {f'tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}'}" @@ -424,7 +428,7 @@ def _setup_for_graph_store( # Determine input_type based on edge_feature_info if isinstance(edge_types, list): - if edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: + if DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_types: input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE is_homogeneous_with_labeled_edge_type = True else: diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 955d6239d..40b629e97 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -15,6 +15,7 @@ from gigl.common import Uri from gigl.common.logger import Logger +from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.graph_store.compute import ( init_compute_process, @@ -207,6 +208,70 @@ def _run_compute_train_tests( ) _assert_ablp_input(cluster_info, ablp_result) + # For labeled homogeneous, pass the dict directly (not as tuple) + input_nodes = ablp_result + + ablp_loader = DistABLPLoader( + dataset=remote_dist_dataset, + num_neighbors=[2, 2], + input_nodes=input_nodes, + supervision_edge_type=supervision_edge_type, + pin_memory_device=torch.device("cpu"), + num_workers=2, + worker_concurrency=2, + ) + + random_negative_input = remote_dist_dataset.get_node_ids( + split="train", + node_type=test_node_type, + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + ) + + # Test that two loaders can both be initialized and sampled from simultaneously. + random_negative_loader = DistNeighborLoader( + dataset=remote_dist_dataset, + num_neighbors=[2, 2], + input_nodes=random_negative_input, + pin_memory_device=torch.device("cpu"), + num_workers=2, + worker_concurrency=2, + ) + count = 0 + for i, (ablp_batch, random_negative_batch) in enumerate( + zip(ablp_loader, random_negative_loader) + ): + # Verify batch structure + assert hasattr(ablp_batch, "y_positive"), "Batch should have y_positive labels" + # y_positive should be dict mapping local anchor idx -> local label indices + assert isinstance( + ablp_batch.y_positive, dict + ), f"y_positive should be dict, got {type(ablp_batch.y_positive)}" + count += 1 + + torch.distributed.barrier() + logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} ABLP batches") + + # Verify total count across all ranks + count_tensor = torch.tensor(count, dtype=torch.int64) + torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM) + + # Calculate expected total anchors by summing across all compute nodes + # Each process on the same compute node has the same anchor count, so we sum + # across all processes and divide by num_processes_per_compute to get the true total + local_total_anchors = sum( + ablp_result[server_rank][0].shape[0] for server_rank in ablp_result + ) + expected_anchors_tensor = torch.tensor(local_total_anchors, dtype=torch.int64) + torch.distributed.all_reduce( + expected_anchors_tensor, op=torch.distributed.ReduceOp.SUM + ) + expected_batches = ( + expected_anchors_tensor.item() // cluster_info.num_processes_per_compute + ) + assert ( + count_tensor.item() == expected_batches + ), f"Expected {expected_batches} total batches, got {count_tensor.item()}" shutdown_compute_proccess() From 3c3bf1b8307fa9f1ac59d8efb0ed9a8e6996dfa8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 11 Feb 2026 19:58:17 +0000 Subject: [PATCH 02/30] fixes --- gigl/distributed/dist_ablp_neighborloader.py | 33 +++++++++++-------- .../distributed/distributed_neighborloader.py | 2 +- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 0bd653523..af3d87ce4 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -236,6 +236,9 @@ def __init__( ) del supervision_edge_type + self.data: Optional[Union[DistDataset, RemoteDistDataset]] = None + if isinstance(dataset, DistDataset): + self.data = dataset if context: assert ( @@ -406,9 +409,8 @@ def __init__( # Type narrowing for colocated mode - self.data = dataset self.input_data = sampler_input[0] - del dataset, sampler_input + del sampler_input assert isinstance(self.data, DistDataset) assert isinstance(self.input_data, ABLPNodeSamplerInput) @@ -556,7 +558,7 @@ def _setup_for_colocated( num_cpu_threads: Number of CPU threads for PyTorch. Returns: - Tuple of (ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema). + Tuple of (list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema). """ # Validate input format - should not be Graph Store format if isinstance(input_nodes, abc.Mapping): @@ -585,11 +587,11 @@ def _setup_for_colocated( anchor_node_type, anchor_node_ids = input_nodes # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if # this assumption is no longer valid and/or is too opinionated - for sup_edge_type in self._supervision_edge_types: + for supervision_edge_type in self._supervision_edge_types: assert ( - sup_edge_type[0] == anchor_node_type + supervision_edge_type[0] == anchor_node_type ), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ - got supervision edge type {sup_edge_type} with anchor node type {anchor_node_type}" + got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" if dataset.edge_dir == "in": self._supervision_edge_types = [ reverse_edge_type(sup_edge_type) @@ -646,11 +648,11 @@ def _setup_for_colocated( self._negative_label_edge_types: list[EdgeType] = [] positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} - for sup_edge_type in self._supervision_edge_types: + for supervision_edge_type in self._supervision_edge_types: ( positive_label_edge_type, negative_label_edge_type, - ) = select_label_edge_types(sup_edge_type, dataset.graph.keys()) + ) = select_label_edge_types(supervision_edge_type, dataset.graph.keys()) self._positive_label_edge_types.append(positive_label_edge_type) if negative_label_edge_type is not None: self._negative_label_edge_types.append(negative_label_edge_type) @@ -703,6 +705,7 @@ def _setup_for_colocated( master_worker_port=neighbor_loader_port_for_current_rank, device=device, should_use_cpu_workers=should_use_cpu_workers, + # Lever to explore tuning for CPU based inference num_cpu_threads=num_cpu_threads, ) logger.info( @@ -716,10 +719,18 @@ def _setup_for_colocated( dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank] worker_options = MpDistSamplingWorkerOptions( num_workers=num_workers, + # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group + # need to be connected. Thus, we need master ip address and master port to + # initate the connection. + # Note that different groups of workers are independent, and thus + # the sampling processes in different groups should be independent, and should + # use different master ports. worker_devices=[torch.device("cpu") for _ in range(num_workers)], worker_concurrency=worker_concurrency, master_addr=master_ip_address, master_port=dist_sampling_port_for_current_rank, + # Load testing shows that when num_rpc_threads exceed 16, the performance + # will degrade. num_rpc_threads=min(dataset.num_partitions, 16), rpc_timeout=600, channel_size=channel_size, @@ -809,7 +820,7 @@ def _setup_for_graph_store( ) sampling_port = sampling_ports[node_rank] - # TODO(kmonte) - We need to be able to differentiate between differnt instance of the same loader. + # TODO(kmonte) - We need to be able to differentiate between different instances of the same loader. # e.g. if we have two different DistABLPLoaders, then they will have conflicting worker keys. # And they will share each others data. Therefor, the second loader will not load the data it's expecting. # Probably, we can just keep track of the insantiations on the server-side and include the count in the worker key. @@ -930,10 +941,6 @@ def _init_remote_worker( dataset_metadata: Metadata about the dataset schema. """ # Set instance variables (like DistLoader does) - # Note: We assign to self.data and self.input_data which are also set in the colocated - # branch. For Graph Store mode, data is None and input_data is a list. - object.__setattr__(self, "data", None) # No local data in Graph Store mode - object.__setattr__(self, "input_data", sampler_input) self.sampling_type = sampling_config.sampling_type self.num_neighbors = sampling_config.num_neighbors self.batch_size = sampling_config.batch_size diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index e24101461..d5ce47831 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -394,7 +394,7 @@ def _setup_for_graph_store( ) sampling_port = sampling_ports[node_rank] - # TODO(kmonte) - We need to be able to differentiate between differnt instance of the same loader. + # TODO(kmonte) - We need to be able to differentiate between different instance of the same loader. # e.g. if we have two different DistNeighborLoaders, then they will have conflicting worker keys. # And they will share each others data. Therefor, the second loader will not load the data it's expecting. # Probably, we can just keep track of the insantiations on the server-side and include the count in the worker key. From 48c26f5d09f36bd681a5d2f8c2633a22395500ce Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 11 Feb 2026 21:59:20 +0000 Subject: [PATCH 03/30] attempt at e2e gs trainer --- .../configs/e2e_glt_gs_resource_config.yaml | 17 +- .../e2e_het_dblp_sup_gs_task_config.yaml | 11 +- .../e2e_hom_cora_sup_gs_task_config.yaml | 10 +- .../configs/example_resource_config.yaml | 17 +- .../graph_store/heterogeneous_training.py | 938 ++++++++++++++++++ .../graph_store/homogeneous_training.py | 861 ++++++++++++++++ .../graph_store/storage_main.py | 92 +- 7 files changed, 1920 insertions(+), 26 deletions(-) create mode 100644 examples/link_prediction/graph_store/heterogeneous_training.py create mode 100644 examples/link_prediction/graph_store/homogeneous_training.py diff --git a/deployment/configs/e2e_glt_gs_resource_config.yaml b/deployment/configs/e2e_glt_gs_resource_config.yaml index 097cf60d9..4839c8f48 100644 --- a/deployment/configs/e2e_glt_gs_resource_config.yaml +++ b/deployment/configs/e2e_glt_gs_resource_config.yaml @@ -1,5 +1,6 @@ # Diffs from e2e_glt_resource_config.yaml # - Swap vertex_ai_inferencer_config for vertex_ai_graph_store_inferencer_config +# - Swap vertex_ai_trainer_config for vertex_ai_graph_store_trainer_config shared_resource_config: resource_labels: cost_resource_group_tag: dev_experiments_COMPONENT @@ -26,11 +27,17 @@ preprocessor_config: machine_type: "n2d-highmem-64" disk_size_gb: 300 trainer_resource_config: - vertex_ai_trainer_config: - machine_type: n1-highmem-32 - gpu_type: NVIDIA_TESLA_T4 - gpu_limit: 2 - num_replicas: 2 + vertex_ai_graph_store_trainer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 inferencer_resource_config: vertex_ai_graph_store_inferencer_config: graph_store_pool: diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml index 1ebf9acb7..aa4a86a32 100644 --- a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml @@ -29,7 +29,6 @@ datasetConfig: dataPreprocessorArgs: # This argument is specific for the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using mocked_dataset_name: 'dblp_node_anchor_edge_features_lp' -# TODO(kmonte): Add GS trainer trainerConfig: trainerArgs: # Example argument to trainer @@ -49,7 +48,15 @@ trainerConfig: ("paper", "to", "author"): [15, 15], ("author", "to", "paper"): [20, 20] } - command: python -m examples.link_prediction.heterogeneous_training + command: python -m examples.link_prediction.graph_store.heterogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1, "supervision_edge_types": [("paper", "to", "author")]}' + ssl_positive_label_percentage: "0.05" + num_server_sessions: "1" # TODO(kmonte): Move to user-defined server code inferencerConfig: inferencerArgs: diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml index 84e0badef..d694e5cf5 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -15,13 +15,19 @@ datasetConfig: dataPreprocessorArgs: # This argument is specific for the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' -# TODO(kmonte): Add GS trainer trainerConfig: trainerArgs: # Example argument to trainer log_every_n_batch: "50" # Frequency in which we log batch information num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case - command: python -m examples.link_prediction.homogeneous_training + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1}' + num_server_sessions: "1" # TODO(kmonte): Move to user-defined server code inferencerConfig: inferencerArgs: diff --git a/examples/link_prediction/graph_store/configs/example_resource_config.yaml b/examples/link_prediction/graph_store/configs/example_resource_config.yaml index aa6618fd7..869f627ca 100644 --- a/examples/link_prediction/graph_store/configs/example_resource_config.yaml +++ b/examples/link_prediction/graph_store/configs/example_resource_config.yaml @@ -46,13 +46,18 @@ preprocessor_config: max_num_workers: 4 machine_type: "n2-standard-16" disk_size_gb: 300 -# TODO(kmonte): Update trainer_resource_config: - vertex_ai_trainer_config: - machine_type: n1-standard-16 - gpu_type: NVIDIA_TESLA_T4 - gpu_limit: 2 - num_replicas: 2 + vertex_ai_graph_store_trainer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 inferencer_resource_config: vertex_ai_graph_store_inferencer_config: graph_store_pool: diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py new file mode 100644 index 000000000..786597386 --- /dev/null +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -0,0 +1,938 @@ +""" +This file contains an example for how to run heterogeneous training in **graph store mode** using GiGL. + +Graph Store Mode vs Standard Mode: +---------------------------------- +Graph store mode uses a heterogeneous cluster architecture with two distinct sub-clusters: + 1. **Storage Cluster (graph_store_pool)**: Dedicated machines for storing and serving the graph + data. These are typically high-memory machines without GPUs (e.g., n2-highmem-32). + 2. **Compute Cluster (compute_pool)**: Dedicated machines for running model training. + These typically have GPUs attached (e.g., n1-standard-16 with NVIDIA_TESLA_T4). + +This separation allows for: + - Independent scaling of storage and compute resources + - Better memory utilization (graph data stays on storage nodes) + - Cost optimization by using appropriate hardware for each role + +In contrast, the standard training mode (see `examples/link_prediction/heterogeneous_training.py`) +uses a homogeneous cluster where each machine handles both graph storage and computation. + +Key Implementation Differences: +------------------------------- +This file (graph store mode): + - Uses `RemoteDistDataset` to connect to a remote graph store cluster + - Uses `init_compute_process` to initialize the compute node connection to storage + - Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo` + - Uses `mp_sharing_dict` for efficient tensor sharing between local processes + - Fetches ABLP input via `RemoteDistDataset.get_ablp_input()` for the train/val/test splits + - Fetches random negative node IDs via `RemoteDistDataset.get_node_ids()` + +Standard mode (`heterogeneous_training.py`): + - Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition + - Manually manages distributed process groups with master IP and port + - Each machine stores its own partition of the graph data + +To run this file with GiGL orchestration, set the fields similar to below: + +trainerConfig: + trainerArgs: + log_every_n_batch: "50" + ssl_positive_label_percentage: "0.05" + command: python -m examples.link_prediction.graph_store.heterogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": true, "num_val": 0.1, "num_test": 0.1}' + ssl_positive_label_percentage: "0.05" + num_server_sessions: "1" +featureFlags: + should_run_glt_backend: 'True' + +Note: Ensure you use a resource config with `vertex_ai_graph_store_trainer_config` when +running in graph store mode. + +You can run this example in a full pipeline with `make run_het_dblp_sup_gs_e2e_test` from GiGL root. + +Note that the DBLP Dataset does not have specified labeled edges so we use the `ssl_positive_label_percentage` +field in the config to indicate what percentage of edges we should select as self-supervised labeled edges. +""" + +import argparse +import gc +import os +import statistics +import sys +import time +from collections.abc import Iterator, MutableMapping +from dataclasses import dataclass +from typing import Literal, Optional, Union + +import torch +import torch.distributed +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_heterogeneous_model +from torch_geometric.data import HeteroData + +import gigl.distributed +import gigl.distributed.utils +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import DistABLPLoader +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.graph_store.compute import ( + init_compute_process, + shutdown_compute_proccess, +) +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils import get_available_device, get_graph_store_info +from gigl.env.distributed import GraphStoreInfo +from gigl.nn import LinkPredictionGNN, RetrievalLoss +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +# We don't see logs for graph store mode for whatever reason. +# TODO(#442): Revert this once the GCP issues are resolved. +def flush(): + sys.stdout.write("\n") + sys.stdout.flush() + sys.stderr.write("\n") + sys.stderr.flush() + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function requires DDP to be initialized. + Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced across multiple processes + Returns: + float: The average of the provided metric across all training processes + """ + assert is_distributed_available_and_initialized(), "DDP is not initialized" + # Make a copy of the local loss tensor + loss_tensor = metric.detach().clone() + torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) + return loss_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloaders( + dataset: RemoteDistDataset, + split: Literal["train", "val", "test"], + cluster_info: GraphStoreInfo, + supervision_edge_type: EdgeType, + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> tuple[DistABLPLoader, DistNeighborLoader]: + """ + Sets up main and random dataloaders for training and testing purposes using a remote graph store dataset. + Args: + dataset (RemoteDistDataset): Remote dataset connected to the graph store cluster. + split (Literal["train", "val", "test"]): The current split which we are loading data for. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + supervision_edge_type (EdgeType): The supervision edge type to use for training. + num_neighbors: Fanout for subgraph sampling, where the ith item corresponds to the number of items to sample for the ith hop. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + main_batch_size (int): Batch size for main dataloader with query and labeled nodes. + random_batch_size (int): Batch size for random negative dataloader. + device (torch.device): Device to put loaded subgraphs on. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) allocated for the channel during sampling. + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. + Returns: + DistABLPLoader: Dataloader for loading main batch data with query and labeled nodes. + DistNeighborLoader: Dataloader for loading random negative data. + """ + rank = torch.distributed.get_rank() + + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + + shuffle = split == "train" + + # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. + # This returns dict[server_rank, (anchors, pos_labels, neg_labels)] which the DistABLPLoader knows how to handle. + logger.info(f"---Rank {rank} fetching ABLP input for split={split}") + flush() + ablp_input = dataset.get_ablp_input( + split=split, + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + node_type=query_node_type, + supervision_edge_type=supervision_edge_type, + ) + + main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=(query_node_type, ablp_input), + num_workers=sampling_workers_per_process, + batch_size=main_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up main loader for split={split}") + flush() + + # We need to wait for all processes to finish initializing the main_loader before creating the + # random_negative_loader so that its initialization doesn't compete for memory with the main_loader. + torch.distributed.barrier() + + # For the random negative loader, we get all node IDs of the labeled node type from the storage cluster. + all_node_ids = dataset.get_node_ids( + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + node_type=labeled_node_type, + ) + + random_negative_loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=(labeled_node_type, all_node_ids), + num_workers=sampling_workers_per_process, + batch_size=random_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up random negative loader for split={split}") + flush() + + # Wait for all processes to finish initializing the random_loader + torch.distributed.barrier() + + return main_loader, random_negative_loader + + +def _compute_loss( + model: LinkPredictionGNN, + main_data: HeteroData, + random_negative_data: HeteroData, + loss_fn: RetrievalLoss, + supervision_edge_type: EdgeType, + device: torch.device, +) -> torch.Tensor: + """ + With the provided model and loss function, computes the forward pass on the main batch data and random negative data. + Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_data (HeteroData): The batch of data containing query nodes, positive nodes, and hard negative nodes + random_negative_data (HeteroData): The batch of data containing random negative nodes + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + supervision_edge_type (EdgeType): The supervision edge type to use for training in format query_node -> relation -> labeled_node + device (torch.device): Device for training or validation + Returns: + torch.Tensor: Final loss for the current batch on the current process + """ + # Extract relevant node types from the supervision edge + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + + if query_node_type == labeled_node_type: + inference_node_types = [query_node_type] + else: + inference_node_types = [query_node_type, labeled_node_type] + + # Forward pass through encoder + + main_embeddings = model( + data=main_data, output_node_types=inference_node_types, device=device + ) + random_negative_embeddings = model( + data=random_negative_data, + output_node_types=inference_node_types, + device=device, + ) + + # Extracting local query, random negative, positive, hard_negative, and random_negative indices. + query_node_idx: torch.Tensor = torch.arange( + main_data[query_node_type].batch_size + ).to(device) + random_negative_batch_size = random_negative_data[labeled_node_type].batch_size + + positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( + device + ) + repeated_query_node_idx = query_node_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) + ) + if hasattr(main_data, "y_negative"): + hard_negative_idx: torch.Tensor = torch.cat( + list(main_data.y_negative.values()) + ).to(device) + else: + hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) + + # Use local IDs to get the corresponding embeddings in the tensors + + repeated_query_embeddings = main_embeddings[query_node_type][ + repeated_query_node_idx + ] + positive_node_embeddings = main_embeddings[labeled_node_type][positive_idx] + hard_negative_embeddings = main_embeddings[labeled_node_type][hard_negative_idx] + random_negative_embeddings = random_negative_embeddings[labeled_node_type][ + :random_negative_batch_size + ] + + # Decode the query embeddings and the candidate embeddings + + repeated_candidate_scores = model.decode( + query_embeddings=repeated_query_embeddings, + candidate_embeddings=torch.cat( + [ + positive_node_embeddings, + hard_negative_embeddings, + random_negative_embeddings, + ], + dim=0, + ), + ) + + # Compute the global candidate ids and concatenate into a single tensor + + global_candidate_ids = torch.cat( + ( + main_data[labeled_node_type].node[positive_idx], + main_data[labeled_node_type].node[hard_negative_idx], + random_negative_data[labeled_node_type].node[:random_negative_batch_size], + ) + ) + + global_repeated_query_ids = main_data[query_node_type].node[repeated_query_node_idx] + + # Feed scores and ids into the RetrievalLoss forward pass to get the final loss + + loss = loss_fn( + repeated_candidate_scores=repeated_candidate_scores, + candidate_ids=global_candidate_ids, + repeated_query_ids=global_repeated_query_ids, + device=device, + ) + + return loss + + +@dataclass(frozen=True) +class TrainingProcessArgs: + """ + Arguments for the heterogeneous training process in graph store mode. + + Attributes: + local_world_size (int): Number of training processes spawned by each machine. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor + sharing between local processes. + supervision_edge_type (EdgeType): The supervision edge type for training. + model_uri (Uri): URI to save/load the trained model state dict. + hid_dim (int): Hidden dimension of the model. + out_dim (int): Output dimension of the model. + node_type_to_feature_dim (dict[NodeType, int]): Mapping of node types to their feature dimensions. + edge_type_to_feature_dim (dict[EdgeType, int]): Mapping of edge types to their feature dimensions. + num_neighbors (Union[list[int], dict[EdgeType, list[int]]]): Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. + process_start_gap_seconds (int): Time to sleep between dataloader initializations. + main_batch_size (int): Batch size for main dataloader. + random_batch_size (int): Batch size for random negative dataloader. + learning_rate (float): Learning rate for the optimizer. + weight_decay (float): Weight decay for the optimizer. + num_max_train_batches (int): Maximum number of training batches across all processes. + num_val_batches (int): Number of validation batches across all processes. + val_every_n_batch (int): Frequency to run validation during training. + log_every_n_batch (int): Frequency to log batch information during training. + should_skip_training (bool): If True, skip training and only run testing. + """ + + # Distributed context + local_world_size: int + cluster_info: GraphStoreInfo + mp_sharing_dict: MutableMapping[str, torch.Tensor] + + # Data + supervision_edge_type: EdgeType + + # Model + model_uri: Uri + hid_dim: int + out_dim: int + node_type_to_feature_dim: dict[NodeType, int] + edge_type_to_feature_dim: dict[EdgeType, int] + + # Sampling config + num_neighbors: Union[list[int], dict[EdgeType, list[int]]] + sampling_workers_per_process: int + sampling_worker_shared_channel_size: str + process_start_gap_seconds: int + + # Training hyperparameters + main_batch_size: int + random_batch_size: int + learning_rate: float + weight_decay: float + num_max_train_batches: int + num_val_batches: int + val_every_n_batch: int + log_every_n_batch: int + should_skip_training: bool + + +def _training_process( + local_rank: int, + args: TrainingProcessArgs, +) -> None: + """ + This function is spawned by each machine for training a GNN model using graph store mode. + Args: + local_rank (int): Process number on the current machine + args (TrainingProcessArgs): Dataclass containing all training process arguments + """ + + # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster + # and sets up torch.distributed with the appropriate backend (NCCL if CUDA available, gloo otherwise). + logger.info( + f"Initializing compute process for local_rank {local_rank} in machine {args.cluster_info.compute_node_rank}" + ) + flush() + init_compute_process(local_rank, args.cluster_info) + dataset = RemoteDistDataset( + args.cluster_info, local_rank, mp_sharing_dict=args.mp_sharing_dict + ) + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + logger.info( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + flush() + + # We use one training device for each local process + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + logger.info(f"---Rank {rank} training process set device {device}") + + loss_fn = RetrievalLoss( + loss=torch.nn.CrossEntropyLoss(reduction="mean"), + temperature=0.07, + remove_accidental_hits=True, + ) + + if not args.should_skip_training: + train_main_loader, train_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="train", + cluster_info=args.cluster_info, + supervision_edge_type=args.supervision_edge_type, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + train_main_loader_iter = InfiniteIterator(train_main_loader) + train_random_negative_loader_iter = InfiniteIterator( + train_random_negative_loader + ) + + val_main_loader, val_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="val", + cluster_info=args.cluster_info, + supervision_edge_type=args.supervision_edge_type, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + val_main_loader_iter = InfiniteIterator(val_main_loader) + val_random_negative_loader_iter = InfiniteIterator(val_random_negative_loader) + + model = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=args.node_type_to_feature_dim, + edge_type_to_feature_dim=args.edge_type_to_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + ) + optimizer = torch.optim.AdamW( + params=model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + flush() + + # We add a barrier to wait for all processes to finish preparing the dataloader and initializing the model + torch.distributed.barrier() + + # Entering the training loop + training_start_time = time.time() + batch_idx = 0 + avg_train_loss = 0.0 + last_n_batch_avg_loss: list[float] = [] + last_n_batch_time: list[float] = [] + num_max_train_batches_per_process = args.num_max_train_batches // world_size + num_val_batches_per_process = args.num_val_batches // world_size + logger.info( + f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" + ) + + model.train() + + batch_start = time.time() + for main_data, random_data in zip( + train_main_loader_iter, train_random_negative_loader_iter + ): + if batch_idx >= num_max_train_batches_per_process: + logger.info( + f"num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " + f"stopping training on machine {args.cluster_info.compute_node_rank} local rank {local_rank}" + ) + break + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + supervision_edge_type=args.supervision_edge_type, + device=device, + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + avg_train_loss = _sync_metric_across_processes(metric=loss) + last_n_batch_avg_loss.append(avg_train_loss) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % args.log_every_n_batch == 0: + logger.info( + f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + logger.info( + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + ) + last_n_batch_avg_loss.clear() + flush() + + if batch_idx % args.val_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, validating...") + model.eval() + _run_validation_loops( + model=model, + main_loader=val_main_loader_iter, + random_negative_loader=val_random_negative_loader_iter, + loss_fn=loss_fn, + supervision_edge_type=args.supervision_edge_type, + device=device, + log_every_n_batch=args.log_every_n_batch, + num_batches=num_val_batches_per_process, + ) + model.train() + + logger.info(f"---Rank {rank} finished training") + flush() + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + # We explicitly shutdown all the dataloaders to reduce their memory footprint. + train_main_loader.shutdown() + train_random_negative_loader.shutdown() + val_main_loader.shutdown() + val_random_negative_loader.shutdown() + + # We save the model on the process with rank 0. + if torch.distributed.get_rank() == 0: + logger.info( + f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, saving model to {args.model_uri}" + ) + save_state_dict( + model=model.unwrap_from_ddp(), save_to_path_uri=args.model_uri + ) + flush() + + else: # should_skip_training is True, meaning we should only run testing + state_dict = load_state_dict_from_uri( + load_from_uri=args.model_uri, device=device + ) + model = init_example_gigl_heterogeneous_model( + node_type_to_feature_dim=args.node_type_to_feature_dim, + edge_type_to_feature_dim=args.edge_type_to_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + state_dict=state_dict, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + + logger.info(f"---Rank {rank} started testing") + flush() + testing_start_time = time.time() + + model.eval() + + test_main_loader, test_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="test", + cluster_info=args.cluster_info, + supervision_edge_type=args.supervision_edge_type, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + # Since we are doing testing, we only want to go through the data once. + test_main_loader_iter = iter(test_main_loader) + test_random_negative_loader_iter = iter(test_random_negative_loader) + + _run_validation_loops( + model=model, + main_loader=test_main_loader_iter, + random_negative_loader=test_random_negative_loader_iter, + loss_fn=loss_fn, + supervision_edge_type=args.supervision_edge_type, + device=device, + log_every_n_batch=args.log_every_n_batch, + ) + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + test_main_loader.shutdown() + test_random_negative_loader.shutdown() + + logger.info( + f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" + ) + flush() + + # Graph store mode cleanup: shutdown the compute process connection to the storage cluster. + shutdown_compute_proccess() + gc.collect() + + logger.info( + f"---Rank {rank} finished all training and testing, shut down compute process" + ) + flush() + + +@torch.inference_mode() +def _run_validation_loops( + model: LinkPredictionGNN, + main_loader: Iterator[HeteroData], + random_negative_loader: Iterator[HeteroData], + loss_fn: RetrievalLoss, + supervision_edge_type: EdgeType, + device: torch.device, + log_every_n_batch: int, + num_batches: Optional[int] = None, +) -> None: + """ + Runs validation using the provided models and dataloaders. + This function is shared for both validation while training and testing after training has completed. + Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_loader (Iterator[HeteroData]): Dataloader for loading main batch data with query and labeled nodes + random_negative_loader (Iterator[HeteroData]): Dataloader for loading random negative data + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + supervision_edge_type (EdgeType): The supervision edge type to use for training + device (torch.device): Device to use for training or testing + log_every_n_batch (int): The frequency we should log batch information + num_batches (Optional[int]): The number of batches to run the validation loop for. + """ + + rank = torch.distributed.get_rank() + + logger.info( + f"Running validation loop on rank={rank}, log_every_n_batch={log_every_n_batch}, num_batches={num_batches}" + ) + if num_batches is None: + if isinstance(main_loader, InfiniteIterator) or isinstance( + random_negative_loader, InfiniteIterator + ): + raise ValueError( + "Must set `num_batches` field when the provided data loaders are wrapped with InfiniteIterator" + ) + + batch_idx = 0 + batch_losses: list[float] = [] + last_n_batch_time: list[float] = [] + batch_start = time.time() + + while True: + if num_batches and batch_idx >= num_batches: + break + try: + main_data = next(main_loader) + random_data = next(random_negative_loader) + except StopIteration: + break + + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + supervision_edge_type=supervision_edge_type, + device=device, + ) + + batch_losses.append(loss.item()) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % log_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + flush() + local_avg_loss = statistics.mean(batch_losses) + logger.info( + f"rank={rank} finished validation loop, local loss: {local_avg_loss=:.6f}" + ) + global_avg_val_loss = _sync_metric_across_processes( + metric=torch.tensor(local_avg_loss, device=device) + ) + logger.info(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + flush() + + return + + +def _run_example_training( + task_config_uri: str, +): + """ + Runs an example training + testing loop using GiGL Orchestration in graph store mode. + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + """ + program_start_time = time.time() + mp.set_start_method("spawn") + logger.info(f"Starting sub process method: {mp.get_start_method()}") + + # Step 1: Initialize global process group to get cluster info + torch.distributed.init_process_group(backend="gloo") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + torch.distributed.destroy_process_group() + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) + flush() + + # Step 2: Read config + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + # Training Hyperparameters + trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + + if torch.cuda.is_available(): + default_local_world_size = torch.cuda.device_count() + else: + default_local_world_size = 2 + local_world_size = int( + trainer_args.get("local_world_size", str(default_local_world_size)) + ) + + if torch.cuda.is_available(): + if local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world size of {local_world_size} which exceeds the number of devices {torch.cuda.device_count()}" + ) + + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + sampling_workers_per_process: int = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + main_batch_size = int(trainer_args.get("main_batch_size", "16")) + random_batch_size = int(trainer_args.get("random_batch_size", "16")) + + hid_dim = int(trainer_args.get("hid_dim", "16")) + out_dim = int(trainer_args.get("out_dim", "16")) + + sampling_worker_shared_channel_size: str = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.0005")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_val_batches = int(trainer_args.get("num_val_batches", "100")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + logger.info( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + hid_dim={hid_dim}, \ + out_dim={out_dim}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + + # Step 3: Extract model/data config + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + + node_type_to_feature_dim: dict[NodeType, int] = { + graph_metadata.condensed_node_type_to_node_type_map[ + condensed_node_type + ]: node_feature_dim + for condensed_node_type, node_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map.items() + } + + edge_type_to_feature_dim: dict[EdgeType, int] = { + graph_metadata.condensed_edge_type_to_edge_type_map[ + condensed_edge_type + ]: edge_feature_dim + for condensed_edge_type, edge_feature_dim in gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map.items() + } + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + supervision_edge_types = ( + gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_types() + ) + if len(supervision_edge_types) != 1: + raise NotImplementedError( + "GiGL Training currently only supports 1 supervision edge type." + ) + supervision_edge_type = supervision_edge_types[0] + + # Step 4: Create shared dict for inter-process tensor sharing + mp_sharing_dict = mp.Manager().dict() + + # Step 5: Spawn training processes + logger.info("--- Launching training processes ...\n") + flush() + start_time = time.time() + + training_args = TrainingProcessArgs( + local_world_size=local_world_size, + cluster_info=cluster_info, + mp_sharing_dict=mp_sharing_dict, + supervision_edge_type=supervision_edge_type, + model_uri=model_uri, + hid_dim=hid_dim, + out_dim=out_dim, + node_type_to_feature_dim=node_type_to_feature_dim, + edge_type_to_feature_dim=edge_type_to_feature_dim, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + num_max_train_batches=num_max_train_batches, + num_val_batches=num_val_batches, + val_every_n_batch=val_every_n_batch, + log_every_n_batch=log_every_n_batch, + should_skip_training=should_skip_training, + ) + + torch.multiprocessing.spawn( + _training_process, + args=(training_args,), + nprocs=local_world_size, + join=True, + ) + logger.info( + f"--- Training finished, took {time.time() - start_time} seconds" + ) + logger.info( + f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" + ) + flush() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model training on VertexAI (graph store mode)" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + + _run_example_training( + task_config_uri=args.task_config_uri, + ) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py new file mode 100644 index 000000000..421ed72d5 --- /dev/null +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -0,0 +1,861 @@ +""" +This file contains an example for how to run homogeneous training in **graph store mode** using GiGL. + +Graph Store Mode vs Standard Mode: +---------------------------------- +Graph store mode uses a heterogeneous cluster architecture with two distinct sub-clusters: + 1. **Storage Cluster (graph_store_pool)**: Dedicated machines for storing and serving the graph + data. These are typically high-memory machines without GPUs (e.g., n2-highmem-32). + 2. **Compute Cluster (compute_pool)**: Dedicated machines for running model training. + These typically have GPUs attached (e.g., n1-standard-16 with NVIDIA_TESLA_T4). + +This separation allows for: + - Independent scaling of storage and compute resources + - Better memory utilization (graph data stays on storage nodes) + - Cost optimization by using appropriate hardware for each role + +In contrast, the standard training mode (see `examples/link_prediction/homogeneous_training.py`) +uses a homogeneous cluster where each machine handles both graph storage and computation. + +Key Implementation Differences: +------------------------------- +This file (graph store mode): + - Uses `RemoteDistDataset` to connect to a remote graph store cluster + - Uses `init_compute_process` to initialize the compute node connection to storage + - Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo` + - Uses `mp_sharing_dict` for efficient tensor sharing between local processes + - Fetches ABLP input via `RemoteDistDataset.get_ablp_input()` for the train/val/test splits + - Fetches random negative node IDs via `RemoteDistDataset.get_node_ids()` + +Standard mode (`homogeneous_training.py`): + - Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition + - Manually manages distributed process groups with master IP and port + - Each machine stores its own partition of the graph data + +To run this file with GiGL orchestration, set the fields similar to below: + +trainerConfig: + trainerArgs: + log_every_n_batch: "50" + num_neighbors: "[10, 10]" + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": true, "num_val": 0.1, "num_test": 0.1}' + num_server_sessions: "1" +featureFlags: + should_run_glt_backend: 'True' + +Note: Ensure you use a resource config with `vertex_ai_graph_store_trainer_config` when +running in graph store mode. + +You can run this example in a full pipeline with `make run_hom_cora_sup_gs_e2e_test` from GiGL root. +""" + +import argparse +import gc +import os +import statistics +import sys +import time +from collections.abc import Iterator, MutableMapping +from dataclasses import dataclass +from typing import Literal, Optional, Union + +import torch +import torch.distributed +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_homogeneous_model +from torch_geometric.data import Data + +import gigl.distributed +import gigl.distributed.utils +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import DistABLPLoader +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.graph_store.compute import ( + init_compute_process, + shutdown_compute_proccess, +) +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils import get_available_device, get_graph_store_info +from gigl.env.distributed import GraphStoreInfo +from gigl.nn import LinkPredictionGNN, RetrievalLoss +from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +# We don't see logs for graph store mode for whatever reason. +# TODO(#442): Revert this once the GCP issues are resolved. +def flush(): + sys.stdout.write("\n") + sys.stdout.flush() + sys.stderr.write("\n") + sys.stderr.flush() + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function requires DDP to be initialized. + Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced across multiple processes + Returns: + float: The average of the provided metric across all training processes + """ + assert is_distributed_available_and_initialized(), "DDP is not initialized" + loss_tensor = metric.detach().clone() + torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) + return loss_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloaders( + dataset: RemoteDistDataset, + split: Literal["train", "val", "test"], + cluster_info: GraphStoreInfo, + num_neighbors: list[int] | dict[EdgeType, list[int]], + sampling_workers_per_process: int, + main_batch_size: int, + random_batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> tuple[DistABLPLoader, DistNeighborLoader]: + """ + Sets up main and random dataloaders for training and testing purposes using a remote graph store dataset. + Args: + dataset (RemoteDistDataset): Remote dataset connected to the graph store cluster. + split (Literal["train", "val", "test"]): The current split which we are loading data for. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + num_neighbors: Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + main_batch_size (int): Batch size for main dataloader with query and labeled nodes. + random_batch_size (int): Batch size for random negative dataloader. + device (torch.device): Device to put loaded subgraphs on. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (bytes) for the channel during sampling. + process_start_gap_seconds (int): The amount of time to sleep for initializing each dataloader. + Returns: + DistABLPLoader: Dataloader for loading main batch data with query and labeled nodes. + DistNeighborLoader: Dataloader for loading random negative data. + """ + rank = torch.distributed.get_rank() + + shuffle = split == "train" + + # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. + # For homogeneous graphs, no node type or supervision edge type wrapper is needed. + logger.info(f"---Rank {rank} fetching ABLP input for split={split}") + flush() + ablp_input = dataset.get_ablp_input( + split=split, + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + ) + + main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=ablp_input, + num_workers=sampling_workers_per_process, + batch_size=main_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up main loader for split={split}") + flush() + + # We need to wait for all processes to finish initializing the main_loader before creating the + # random_negative_loader so that its initialization doesn't compete for memory. + torch.distributed.barrier() + + # For the random negative loader, we get all node IDs from the storage cluster. + all_node_ids = dataset.get_node_ids( + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + ) + + random_negative_loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=all_node_ids, + num_workers=sampling_workers_per_process, + batch_size=random_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up random negative loader for split={split}") + flush() + + # Wait for all processes to finish initializing the random_loader + torch.distributed.barrier() + + return main_loader, random_negative_loader + + +def _compute_loss( + model: LinkPredictionGNN, + main_data: Data, + random_negative_data: Data, + loss_fn: RetrievalLoss, + device: torch.device, +) -> torch.Tensor: + """ + With the provided model and loss function, computes the forward pass on the main batch data and random negative data. + Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_data (Data): The batch of data containing query nodes, positive nodes, and hard negative nodes + random_negative_data (Data): The batch of data containing random negative nodes + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + device (torch.device): Device for training or validation + Returns: + torch.Tensor: Final loss for the current batch on the current process + """ + # Forward pass through encoder + main_embeddings = model(data=main_data, device=device) + random_negative_embeddings = model(data=random_negative_data, device=device) + + query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) + random_negative_batch_size = random_negative_data.batch_size + + positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( + device + ) + repeated_query_node_idx = query_node_idx.repeat_interleave( + torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) + ) + if hasattr(main_data, "y_negative"): + hard_negative_idx: torch.Tensor = torch.cat( + list(main_data.y_negative.values()) + ).to(device) + else: + hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) + + # Use local IDs to get the corresponding embeddings in the tensors + + repeated_query_embeddings = main_embeddings[repeated_query_node_idx] + positive_node_embeddings = main_embeddings[positive_idx] + hard_negative_embeddings = main_embeddings[hard_negative_idx] + random_negative_embeddings = random_negative_embeddings[:random_negative_batch_size] + + repeated_candidate_scores = model.decode( + query_embeddings=repeated_query_embeddings, + candidate_embeddings=torch.cat( + [ + positive_node_embeddings, + hard_negative_embeddings, + random_negative_embeddings, + ], + dim=0, + ), + ) + + global_candidate_ids = torch.cat( + ( + main_data.node[positive_idx], + main_data.node[hard_negative_idx], + random_negative_data.node[:random_negative_batch_size], + ) + ) + + global_repeated_query_ids = main_data.node[repeated_query_node_idx] + + loss = loss_fn( + repeated_candidate_scores=repeated_candidate_scores, + candidate_ids=global_candidate_ids, + repeated_query_ids=global_repeated_query_ids, + device=device, + ) + + return loss + + +@dataclass(frozen=True) +class TrainingProcessArgs: + """ + Arguments for the homogeneous training process in graph store mode. + + Attributes: + local_world_size (int): Number of training processes spawned by each machine. + cluster_info (GraphStoreInfo): Cluster topology info for graph store mode. + mp_sharing_dict (MutableMapping[str, torch.Tensor]): Shared dictionary for efficient tensor + sharing between local processes. + model_uri (Uri): URI to save/load the trained model state dict. + hid_dim (int): Hidden dimension of the model. + out_dim (int): Output dimension of the model. + node_feature_dim (int): Input node feature dimension for the model. + edge_feature_dim (int): Input edge feature dimension for the model. + num_neighbors (list[int] | dict[EdgeType, list[int]]): Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size for the channel during sampling. + process_start_gap_seconds (int): Time to sleep between dataloader initializations. + main_batch_size (int): Batch size for main dataloader. + random_batch_size (int): Batch size for random negative dataloader. + learning_rate (float): Learning rate for the optimizer. + weight_decay (float): Weight decay for the optimizer. + num_max_train_batches (int): Maximum number of training batches across all processes. + num_val_batches (int): Number of validation batches across all processes. + val_every_n_batch (int): Frequency to run validation during training. + log_every_n_batch (int): Frequency to log batch information during training. + should_skip_training (bool): If True, skip training and only run testing. + """ + + # Distributed context + local_world_size: int + cluster_info: GraphStoreInfo + mp_sharing_dict: MutableMapping[str, torch.Tensor] + + # Model + model_uri: Uri + hid_dim: int + out_dim: int + node_feature_dim: int + edge_feature_dim: int + + # Sampling config + num_neighbors: list[int] | dict[EdgeType, list[int]] + sampling_workers_per_process: int + sampling_worker_shared_channel_size: str + process_start_gap_seconds: int + + # Training hyperparameters + main_batch_size: int + random_batch_size: int + learning_rate: float + weight_decay: float + num_max_train_batches: int + num_val_batches: int + val_every_n_batch: int + log_every_n_batch: int + should_skip_training: bool + + +def _training_process( + local_rank: int, + args: TrainingProcessArgs, +) -> None: + """ + This function is spawned by each machine for training a GNN model using graph store mode. + Args: + local_rank (int): Process number on the current machine + args (TrainingProcessArgs): Dataclass containing all training process arguments + """ + + # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster + # and sets up torch.distributed with the appropriate backend (NCCL if CUDA available, gloo otherwise). + logger.info( + f"Initializing compute process for local_rank {local_rank} in machine {args.cluster_info.compute_node_rank}" + ) + flush() + init_compute_process(local_rank, args.cluster_info) + dataset = RemoteDistDataset( + args.cluster_info, local_rank, mp_sharing_dict=args.mp_sharing_dict + ) + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + logger.info( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + flush() + + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + logger.info(f"---Rank {rank} training process set device {device}") + + loss_fn = RetrievalLoss( + loss=torch.nn.CrossEntropyLoss(reduction="mean"), + temperature=0.07, + remove_accidental_hits=True, + ) + + if not args.should_skip_training: + train_main_loader, train_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="train", + cluster_info=args.cluster_info, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + train_main_loader_iter = InfiniteIterator(train_main_loader) + train_random_negative_loader_iter = InfiniteIterator( + train_random_negative_loader + ) + + val_main_loader, val_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="val", + cluster_info=args.cluster_info, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + val_main_loader_iter = InfiniteIterator(val_main_loader) + val_random_negative_loader_iter = InfiniteIterator(val_random_negative_loader) + + model = init_example_gigl_homogeneous_model( + node_feature_dim=args.node_feature_dim, + edge_feature_dim=args.edge_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + ) + + optimizer = torch.optim.AdamW( + params=model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + flush() + + # We add a barrier to wait for all processes to finish preparing the dataloader and initializing the model + torch.distributed.barrier() + + # Entering the training loop + training_start_time = time.time() + batch_idx = 0 + avg_train_loss = 0.0 + last_n_batch_avg_loss: list[float] = [] + last_n_batch_time: list[float] = [] + num_max_train_batches_per_process = args.num_max_train_batches // world_size + num_val_batches_per_process = args.num_val_batches // world_size + logger.info( + f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" + ) + + model.train() + + batch_start = time.time() + for main_data, random_data in zip( + train_main_loader_iter, train_random_negative_loader_iter + ): + if batch_idx >= num_max_train_batches_per_process: + logger.info( + f"num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " + f"stopping training on machine {args.cluster_info.compute_node_rank} local rank {local_rank}" + ) + break + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + device=device, + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + avg_train_loss = _sync_metric_across_processes(metric=loss) + last_n_batch_avg_loss.append(avg_train_loss) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % args.log_every_n_batch == 0: + logger.info( + f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + logger.info( + f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" + ) + last_n_batch_avg_loss.clear() + flush() + + if batch_idx % args.val_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, validating...") + model.eval() + _run_validation_loops( + model=model, + main_loader=val_main_loader_iter, + random_negative_loader=val_random_negative_loader_iter, + loss_fn=loss_fn, + device=device, + log_every_n_batch=args.log_every_n_batch, + num_batches=num_val_batches_per_process, + ) + model.train() + + logger.info(f"---Rank {rank} finished training") + flush() + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + # We explicitly shutdown all the dataloaders to reduce their memory footprint. + train_main_loader.shutdown() + train_random_negative_loader.shutdown() + val_main_loader.shutdown() + val_random_negative_loader.shutdown() + + # We save the model on the process with rank 0. + if torch.distributed.get_rank() == 0: + logger.info( + f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, saving model to {args.model_uri}" + ) + save_state_dict( + model=model.unwrap_from_ddp(), save_to_path_uri=args.model_uri + ) + flush() + + else: # should_skip_training is True, meaning we should only run testing + state_dict = load_state_dict_from_uri( + load_from_uri=args.model_uri, device=device + ) + model = init_example_gigl_homogeneous_model( + node_feature_dim=args.node_feature_dim, + edge_feature_dim=args.edge_feature_dim, + hid_dim=args.hid_dim, + out_dim=args.out_dim, + device=device, + wrap_with_ddp=True, + find_unused_encoder_parameters=True, + state_dict=state_dict, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + + logger.info(f"---Rank {rank} started testing") + flush() + testing_start_time = time.time() + model.eval() + + test_main_loader, test_random_negative_loader = _setup_dataloaders( + dataset=dataset, + split="test", + cluster_info=args.cluster_info, + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + main_batch_size=args.main_batch_size, + random_batch_size=args.random_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + + # Since we are doing testing, we only want to go through the data once. + test_main_loader_iter = iter(test_main_loader) + test_random_negative_loader_iter = iter(test_random_negative_loader) + + _run_validation_loops( + model=model, + main_loader=test_main_loader_iter, + random_negative_loader=test_random_negative_loader_iter, + loss_fn=loss_fn, + device=device, + log_every_n_batch=args.log_every_n_batch, + ) + + # Memory cleanup and waiting for all processes to finish + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + test_main_loader.shutdown() + test_random_negative_loader.shutdown() + + logger.info( + f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" + ) + flush() + + # Graph store mode cleanup: shutdown the compute process connection to the storage cluster. + shutdown_compute_proccess() + gc.collect() + + logger.info( + f"---Rank {rank} finished all training and testing, shut down compute process" + ) + flush() + + +@torch.inference_mode() +def _run_validation_loops( + model: LinkPredictionGNN, + main_loader: Iterator[Data], + random_negative_loader: Iterator[Data], + loss_fn: RetrievalLoss, + device: torch.device, + log_every_n_batch: int, + num_batches: Optional[int] = None, +) -> None: + """ + Runs validation using the provided models and dataloaders. + Args: + model (LinkPredictionGNN): DDP-wrapped LinkPredictionGNN model for training and testing + main_loader (Iterator[Data]): Dataloader for loading main batch data + random_negative_loader (Iterator[Data]): Dataloader for loading random negative data + loss_fn (RetrievalLoss): Initialized class to use for loss calculation + device (torch.device): Device to use for training or testing + log_every_n_batch (int): The frequency we should log batch information + num_batches (Optional[int]): The number of batches to run the validation loop for. + """ + rank = torch.distributed.get_rank() + + logger.info( + f"Running validation loop on rank={rank}, log_every_n_batch={log_every_n_batch}, num_batches={num_batches}" + ) + if num_batches is None: + if isinstance(main_loader, InfiniteIterator) or isinstance( + random_negative_loader, InfiniteIterator + ): + raise ValueError( + "Must set `num_batches` field when the provided data loaders are wrapped with InfiniteIterator" + ) + + batch_idx = 0 + batch_losses: list[float] = [] + last_n_batch_time: list[float] = [] + batch_start = time.time() + + while True: + if num_batches and batch_idx >= num_batches: + break + try: + main_data = next(main_loader) + random_data = next(random_negative_loader) + except StopIteration: + break + + loss = _compute_loss( + model=model, + main_data=main_data, + random_negative_data=random_data, + loss_fn=loss_fn, + device=device, + ) + + batch_losses.append(loss.item()) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % log_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + flush() + local_avg_loss = statistics.mean(batch_losses) + logger.info( + f"rank={rank} finished validation loop, local loss: {local_avg_loss=:.6f}" + ) + global_avg_val_loss = _sync_metric_across_processes( + metric=torch.tensor(local_avg_loss, device=device) + ) + logger.info(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + flush() + + return + + +def _run_example_training( + task_config_uri: str, +): + """ + Runs an example training + testing loop using GiGL Orchestration in graph store mode. + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + """ + program_start_time = time.time() + mp.set_start_method("spawn") + logger.info(f"Starting sub process method: {mp.get_start_method()}") + + # Step 1: Initialize global process group to get cluster info + torch.distributed.init_process_group(backend="gloo") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + torch.distributed.destroy_process_group() + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) + flush() + + # Step 2: Read config + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + # Training Hyperparameters + trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + + if torch.cuda.is_available(): + default_local_world_size = torch.cuda.device_count() + else: + default_local_world_size = 2 + local_world_size = int( + trainer_args.get("local_world_size", str(default_local_world_size)) + ) + + if torch.cuda.is_available(): + if local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world size of {local_world_size} which exceeds the number of devices {torch.cuda.device_count()}" + ) + + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + sampling_workers_per_process: int = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + main_batch_size = int(trainer_args.get("main_batch_size", "16")) + random_batch_size = int(trainer_args.get("random_batch_size", "16")) + + hid_dim = int(trainer_args.get("hid_dim", "16")) + out_dim = int(trainer_args.get("out_dim", "16")) + + sampling_worker_shared_channel_size: str = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.0005")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "1000")) + num_val_batches = int(trainer_args.get("num_val_batches", "100")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + logger.info( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + hid_dim={hid_dim}, \ + out_dim={out_dim}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + + # Step 3: Extract model/data config + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + + node_feature_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_node_type + ] + edge_feature_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_edge_type + ] + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + # Step 4: Create shared dict for inter-process tensor sharing + mp_sharing_dict = mp.Manager().dict() + + # Step 5: Spawn training processes + logger.info("--- Launching training processes ...\n") + flush() + start_time = time.time() + + training_args = TrainingProcessArgs( + local_world_size=local_world_size, + cluster_info=cluster_info, + mp_sharing_dict=mp_sharing_dict, + model_uri=model_uri, + hid_dim=hid_dim, + out_dim=out_dim, + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + main_batch_size=main_batch_size, + random_batch_size=random_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + num_max_train_batches=num_max_train_batches, + num_val_batches=num_val_batches, + val_every_n_batch=val_every_n_batch, + log_every_n_batch=log_every_n_batch, + should_skip_training=should_skip_training, + ) + + torch.multiprocessing.spawn( + _training_process, + args=(training_args,), + nprocs=local_world_size, + join=True, + ) + logger.info( + f"--- Training finished, took {time.time() - start_time} seconds" + ) + logger.info( + f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" + ) + flush() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model training on VertexAI (graph store mode)" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + + _run_example_training( + task_config_uri=args.task_config_uri, + ) diff --git a/examples/link_prediction/graph_store/storage_main.py b/examples/link_prediction/graph_store/storage_main.py index f4016340d..4ae40537e 100644 --- a/examples/link_prediction/graph_store/storage_main.py +++ b/examples/link_prediction/graph_store/storage_main.py @@ -66,6 +66,7 @@ """ import argparse +import ast import multiprocessing.context as py_mp_context import os from distutils.util import strtobool @@ -75,6 +76,7 @@ from gigl.common import Uri, UriFactory from gigl.common.logger import Logger +from gigl.common.utils.os_utils import import_obj from gigl.distributed.dataset_factory import build_dataset from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_range_partitioner import DistRangePartitioner @@ -174,6 +176,7 @@ def storage_node_process( tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$", ssl_positive_label_percentage: Optional[float] = None, storage_world_backend: Optional[str] = None, + num_server_sessions: Optional[int] = None, ) -> None: """Run a storage node process @@ -190,6 +193,9 @@ def storage_node_process( Must be None if supervised edge labels are provided in advance. If 0.1 is provided, 10% of the edges will be selected as self-supervised labels. storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group. + num_server_sessions (Optional[int]): Number of server sessions to run. For training, this should be 1 + (a single session for the entire training + testing lifecycle). If None, defaults to one session + per inference node type (the existing inference behavior). """ init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}" logger.info( @@ -222,17 +228,29 @@ def storage_node_process( splitter=splitter, _ssl_positive_label_percentage=ssl_positive_label_percentage, ) - inference_node_types = sorted( - gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_task_root_node_types() - ) - logger.info(f"Inference node types: {inference_node_types}") + # Determine the number of server sessions. + # For inference, we default to one session per inference node type (each node type gets its own + # complete RPC lifecycle). For training, the caller should set num_server_sessions=1 since the + # compute side runs a single session for the entire training + testing lifecycle. + if num_server_sessions is None: + inference_node_types = sorted( + gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_task_root_node_types() + ) + num_sessions = len(inference_node_types) + logger.info( + f"num_server_sessions not set, defaulting to {num_sessions} sessions (one per inference node type: {inference_node_types})" + ) + else: + num_sessions = num_server_sessions + logger.info(f"num_server_sessions explicitly set to {num_sessions}") + torch_process_ports = get_free_ports_from_master_node( - num_ports=len(inference_node_types) + num_ports=num_sessions ) torch.distributed.destroy_process_group() - for i, inference_node_type in enumerate(inference_node_types): + for session_idx in range(num_sessions): logger.info( - f"Starting storage node rank {storage_rank} / {cluster_info.num_storage_nodes} for inference node type {inference_node_type} (storage process group {i} / {len(inference_node_types)})" + f"Starting storage node rank {storage_rank} / {cluster_info.num_storage_nodes} for server session {session_idx} / {num_sessions}" ) mp_context = torch.multiprocessing.get_context("spawn") server_processes: list[py_mp_context.SpawnProcess] = [] @@ -245,7 +263,7 @@ def storage_node_process( storage_rank + i, # storage_rank cluster_info, # cluster_info dataset, # dataset - torch_process_ports[i], # torch_process_port + torch_process_ports[session_idx], # torch_process_port storage_world_backend, # storage_world_backend ), ) @@ -255,12 +273,11 @@ def storage_node_process( for server_process in server_processes: server_process.join() logger.info( - f"All server processes on storage node rank {storage_rank} / {cluster_info.num_storage_nodes} joined for inference node type {inference_node_type}" + f"All server processes on storage node rank {storage_rank} / {cluster_info.num_storage_nodes} joined for server session {session_idx} / {num_sessions}" ) if __name__ == "__main__": - # TODO(kmonte): We want to expose splitter class here probably. parser = argparse.ArgumentParser() parser.add_argument("--task_config_uri", type=str, required=True) parser.add_argument("--resource_config_uri", type=str, required=True) @@ -269,9 +286,59 @@ def storage_node_process( parser.add_argument( "--should_load_tf_records_in_parallel", type=str, default="True" ) + # Splitter configuration: use import_obj to dynamically load a splitter class. + # This is needed for training (where the dataset needs train/val/test splits) but not for inference. + parser.add_argument( + "--splitter_cls_path", + type=str, + default=None, + help="Fully qualified import path to splitter class, e.g. 'gigl.utils.data_splitters.DistNodeAnchorLinkSplitter'", + ) + parser.add_argument( + "--splitter_kwargs", + type=str, + default=None, + help="Python dict literal of keyword arguments for the splitter constructor, " + "parsed with ast.literal_eval. Tuples are supported directly, e.g. " + "'supervision_edge_types': [('paper', 'to', 'author')].", + ) + parser.add_argument( + "--ssl_positive_label_percentage", + type=str, + default=None, + help="Percentage of edges to select as self-supervised labels. " + "Must be None if supervised edge labels are provided in advance.", + ) + parser.add_argument( + "--num_server_sessions", + type=str, + default=None, + help="Number of server sessions. For training use '1'. " + "If not set, defaults to one session per inference node type.", + ) args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") + # Build splitter from args if provided. + # We use ast.literal_eval instead of json.loads so that Python tuples (e.g. for EdgeType) + # can be passed directly in the splitter_kwargs string without needing custom serialization. + splitter: Optional[Union[DistNodeAnchorLinkSplitter, DistNodeSplitter]] = None + ssl_positive_label_percentage: Optional[float] = None + if args.splitter_cls_path: + splitter_cls = import_obj(args.splitter_cls_path) + splitter_kwargs = ( + ast.literal_eval(args.splitter_kwargs) if args.splitter_kwargs else {} + ) + splitter = splitter_cls(**splitter_kwargs) + logger.info(f"Built splitter: {splitter}") + + if args.ssl_positive_label_percentage: + ssl_positive_label_percentage = float(args.ssl_positive_label_percentage) + + num_server_sessions = ( + int(args.num_server_sessions) if args.num_server_sessions else None + ) + # Setup cluster-wide (e.g. storage and compute nodes) Torch Distributed process group. # This is needed so we can get the cluster information (e.g. number of storage and compute nodes) and rank/world_size. torch.distributed.init_process_group(backend="gloo") @@ -280,7 +347,7 @@ def storage_node_process( logger.info( f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" ) - # Tear down the """"global""" process group so we can have a server-specific process group. + # Tear down the "global" process group so we can have a server-specific process group. torch.distributed.destroy_process_group() storage_node_process( @@ -288,7 +355,10 @@ def storage_node_process( cluster_info=cluster_info, task_config_uri=UriFactory.create_uri(args.task_config_uri), sample_edge_direction=args.sample_edge_direction, + splitter=splitter, + ssl_positive_label_percentage=ssl_positive_label_percentage, should_load_tf_records_in_parallel=bool( strtobool(args.should_load_tf_records_in_parallel) ), + num_server_sessions=num_server_sessions, ) From 90ff055a16fcbc5dfb022a01966f6ffae73b800d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 20 Feb 2026 18:11:30 +0000 Subject: [PATCH 04/30] fix? --- .../graph_store/heterogeneous_training.py | 14 +++++------ .../graph_store/homogeneous_training.py | 12 ++++------ .../graph_store/storage_main.py | 4 +--- gigl/utils/data_splitters.py | 9 ++++--- .../graph_store_integration_test.py | 2 +- tests/unit/utils/data_splitters_test.py | 24 ++++++------------- 6 files changed, 26 insertions(+), 39 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 786597386..ff5c887a6 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -75,8 +75,6 @@ from examples.link_prediction.models import init_example_gigl_heterogeneous_model from torch_geometric.data import HeteroData -import gigl.distributed -import gigl.distributed.utils from gigl.common import Uri, UriFactory from gigl.common.logger import Logger from gigl.common.utils.torch_training import is_distributed_available_and_initialized @@ -169,14 +167,14 @@ def _setup_dataloaders( split=split, rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, - node_type=query_node_type, + anchor_node_type=query_node_type, supervision_edge_type=supervision_edge_type, ) main_loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(query_node_type, ablp_input), + input_nodes=ablp_input, num_workers=sampling_workers_per_process, batch_size=main_batch_size, pin_memory_device=device, @@ -213,7 +211,9 @@ def _setup_dataloaders( shuffle=shuffle, ) - logger.info(f"---Rank {rank} finished setting up random negative loader for split={split}") + logger.info( + f"---Rank {rank} finished setting up random negative loader for split={split}" + ) flush() # Wait for all processes to finish initializing the random_loader @@ -915,9 +915,7 @@ def _run_example_training( nprocs=local_world_size, join=True, ) - logger.info( - f"--- Training finished, took {time.time() - start_time} seconds" - ) + logger.info(f"--- Training finished, took {time.time() - start_time} seconds") logger.info( f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" ) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 421ed72d5..4a350bbcb 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -63,7 +63,7 @@ import time from collections.abc import Iterator, MutableMapping from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Literal, Optional import torch import torch.distributed @@ -71,8 +71,6 @@ from examples.link_prediction.models import init_example_gigl_homogeneous_model from torch_geometric.data import Data -import gigl.distributed -import gigl.distributed.utils from gigl.common import Uri, UriFactory from gigl.common.logger import Logger from gigl.common.utils.torch_training import is_distributed_available_and_initialized @@ -200,7 +198,9 @@ def _setup_dataloaders( shuffle=shuffle, ) - logger.info(f"---Rank {rank} finished setting up random negative loader for split={split}") + logger.info( + f"---Rank {rank} finished setting up random negative loader for split={split}" + ) flush() # Wait for all processes to finish initializing the random_loader @@ -838,9 +838,7 @@ def _run_example_training( nprocs=local_world_size, join=True, ) - logger.info( - f"--- Training finished, took {time.time() - start_time} seconds" - ) + logger.info(f"--- Training finished, took {time.time() - start_time} seconds") logger.info( f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" ) diff --git a/examples/link_prediction/graph_store/storage_main.py b/examples/link_prediction/graph_store/storage_main.py index 2065b70d2..5e8d883b4 100644 --- a/examples/link_prediction/graph_store/storage_main.py +++ b/examples/link_prediction/graph_store/storage_main.py @@ -239,9 +239,7 @@ def storage_node_process( num_sessions = num_server_sessions logger.info(f"num_server_sessions explicitly set to {num_sessions}") - torch_process_ports = get_free_ports_from_master_node( - num_ports=num_sessions - ) + torch_process_ports = get_free_ports_from_master_node(num_ports=num_sessions) torch.distributed.destroy_process_group() for session_idx in range(num_sessions): logger.info( diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 9e05ead4f..254552fe6 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -189,7 +189,7 @@ def __init__( num_val: float = 0.1, num_test: float = 0.1, hash_function: Callable[[torch.Tensor], torch.Tensor] = _fast_hash, - supervision_edge_types: Optional[list[EdgeType]] = None, + supervision_edge_types: Optional[list[Union[EdgeType, PyGEdgeType]]] = None, should_convert_labels_to_edges: bool = True, ): """Initializes the DistNodeAnchorLinkSplitter. @@ -199,7 +199,7 @@ def __init__( num_val (float): The percentage of nodes to use for training. Defaults to 0.1 (10%). num_test (float): The percentage of nodes to use for validation. Defaults to 0.1 (10%). hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): The hash function to use. Defaults to `_fast_hash`. - supervision_edge_types (Optional[list[EdgeType]]): The supervision edge types we should use for splitting. + supervision_edge_types (Optional[list[Union[EdgeType, PyGEdgeType]]]): The supervision edge types we should use for splitting. Must be provided if we are splitting a heterogeneous graph. If None, uses the default message passing edge type in the graph. should_convert_labels_to_edges (bool): Whether label should be converted into an edge type in the graph. If provided, will make `gigl.distributed.build_dataset` convert all labels into edges, and will infer positive and negative edge types based on @@ -232,7 +232,10 @@ def __init__( # also be ("user", "positive", "story"), meaning that all edges in the loaded edge index tensor with this edge type will be treated as a labeled # edge and will be used for splitting. - self._supervision_edge_types: Sequence[EdgeType] = supervision_edge_types + self._supervision_edge_types: Sequence[EdgeType] = [ + EdgeType(*supervision_edge_type) + for supervision_edge_type in supervision_edge_types + ] self._labeled_edge_types: Sequence[EdgeType] if should_convert_labels_to_edges: labeled_edge_types = [ diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 3746cc0f7..ea66c16f3 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -278,7 +278,7 @@ def _run_compute_train_tests( # Each process on the same compute node has the same anchor count, so we sum # across all processes and divide by num_processes_per_compute to get the true total local_total_anchors = sum( - ablp_result[server_rank][0].shape[0] for server_rank in ablp_result + ablp_result[server_rank].anchor_nodes.shape[0] for server_rank in ablp_result ) expected_anchors_tensor = torch.tensor(local_total_anchors, dtype=torch.int64) torch.distributed.all_reduce( diff --git a/tests/unit/utils/data_splitters_test.py b/tests/unit/utils/data_splitters_test.py index 729616f8e..1001e9da7 100644 --- a/tests/unit/utils/data_splitters_test.py +++ b/tests/unit/utils/data_splitters_test.py @@ -22,6 +22,7 @@ ) from tests.test_assets.distributed.utils import ( assert_tensor_equality, + create_test_process_group, get_process_group_init_method, ) from tests.test_assets.test_case import TestCase @@ -61,7 +62,7 @@ def _run_splitter_distributed( splitter (Union[DistNodeSplitter, DistNodeAnchorLinkSplitter]): The splitter to use for the distributed test """ torch.distributed.init_process_group( - rank=process_num, world_size=world_size, init_method=init_method + rank=process_num, world_size=world_size, init_method=init_method, backend="gloo" ) train, val, test = splitter(tensors[process_num]) expected_train, expected_val, expected_test = expected[process_num] @@ -198,9 +199,7 @@ def test_node_based_link_splitter( # train_num = 1 - val_num - test_num # From (minimum_num, maximum_num), the first train_num % of node ids will be in expected_train, the next val_num % of node ids will be in expected_val, # and the test_num % of node ids will be in test. If there are no node ids which are in the range for that split, the expected split will be empty. - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() splitter = DistNodeAnchorLinkSplitter( sampling_direction=sampling_direction, hash_function=_IdentityHash(), @@ -416,9 +415,7 @@ def test_node_based_link_splitter_heterogenous( # train_num = 1 - val_num - test_num # From (minimum_num, maximum_num), the first train_num % of node ids will be in expected_train, the next val_num % of node ids will be in expected_val, # and the test_num % of node ids will be in test. If there are no node ids which are in the range for that split, the expected split will be empty. - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() splitter = DistNodeAnchorLinkSplitter( sampling_direction="in", @@ -822,9 +819,7 @@ def test_hashed_node_splitter( # train_num = 1 - val_num - test_num # From (minimum_num, maximum_num), the first train_num % of node ids will be in expected_train, the next val_num % of node ids will be in expected_val, # and the test_num % of node ids will be in test. If there are no node ids which are in the range for that split, the expected split will be empty. - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() splitter = DistNodeSplitter( hash_function=_IdentityHash(), num_val=val_num, @@ -922,10 +917,7 @@ def test_hashed_node_splitter_heterogeneous( # train_num = 1 - val_num - test_num # From (minimum_num, maximum_num), the first train_num % of node ids will be in expected_train, the next val_num % of node ids will be in expected_val, # and the test_num % of node ids will be in test. If there are no node ids which are in the range for that split, the expected split will be empty. - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) - + create_test_process_group() splitter = DistNodeSplitter( hash_function=_IdentityHash(), num_val=val_num, @@ -968,9 +960,7 @@ def test_hashed_node_splitter_requires_process_group(self): ] ) def test_hashed_node_splitter_invalid_inputs(self, _, node_ids): - torch.distributed.init_process_group( - rank=0, world_size=1, init_method=get_process_group_init_method() - ) + create_test_process_group() splitter = DistNodeSplitter() with self.assertRaises(ValueError): splitter(node_ids) From 366ff082f3dc2b4ac6a93bbd1a2e3e8e144ba72c Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 20 Feb 2026 21:46:34 +0000 Subject: [PATCH 05/30] fix --- gigl/distributed/graph_store/dist_server.py | 12 ++- .../graph_store_integration_test.py | 73 ++----------------- 2 files changed, 16 insertions(+), 69 deletions(-) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index add152206..7e12e7e3b 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -289,9 +289,15 @@ def get_node_ids( ) nodes = nodes[node_type] elif not isinstance(nodes, torch.Tensor): - raise ValueError( - f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." - ) + if nodes is not None and DEFAULT_HOMOGENEOUS_NODE_TYPE in nodes: + logger.info( + f"Received None node type but assuming it's a homogeneous dataset (node types: {nodes.keys()}) and returning the default node type." + ) + nodes = nodes[DEFAULT_HOMOGENEOUS_NODE_TYPE] + else: + raise ValueError( + f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." + ) if rank is not None and world_size is not None: return shard_nodes_by_process(nodes, rank, world_size) diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index ea66c16f3..cbb35f7e3 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -211,10 +211,12 @@ def _run_compute_train_tests( ) # Use default types for homogeneous graph - test_node_type = ( - node_type if node_type is not None else DEFAULT_HOMOGENEOUS_NODE_TYPE - ) - supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE + # test_node_type = ( + # node_type if node_type is not None else DEFAULT_HOMOGENEOUS_NODE_TYPE + # ) + # supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE + test_node_type = None + supervision_edge_type = None # Test get_ablp_input for train split ablp_result = remote_dist_dataset.get_ablp_input( @@ -233,7 +235,6 @@ def _run_compute_train_tests( dataset=remote_dist_dataset, num_neighbors=[2, 2], input_nodes=input_nodes, - supervision_edge_type=supervision_edge_type, pin_memory_device=torch.device("cpu"), num_workers=2, worker_concurrency=2, @@ -291,66 +292,6 @@ def _run_compute_train_tests( count_tensor.item() == expected_batches ), f"Expected {expected_batches} total batches, got {count_tensor.item()}" - ablp_loader = DistABLPLoader( - dataset=remote_dist_dataset, - num_neighbors=[2, 2], - input_nodes=ablp_result, - pin_memory_device=torch.device("cpu"), - num_workers=2, - worker_concurrency=2, - ) - - random_negative_input = remote_dist_dataset.get_node_ids( - split="train", - node_type=test_node_type, - rank=cluster_info.compute_node_rank, - world_size=cluster_info.num_compute_nodes, - ) - - # Test that two loaders can both be initialized and sampled from simultaneously. - random_negative_loader = DistNeighborLoader( - dataset=remote_dist_dataset, - num_neighbors=[2, 2], - input_nodes=random_negative_input, - pin_memory_device=torch.device("cpu"), - num_workers=2, - worker_concurrency=2, - ) - count = 0 - for i, (ablp_batch, random_negative_batch) in enumerate( - zip(ablp_loader, random_negative_loader) - ): - assert hasattr(ablp_batch, "y_positive"), "Batch should have y_positive labels" - # y_positive should be dict mapping local anchor idx -> local label indices - assert isinstance( - ablp_batch.y_positive, dict - ), f"y_positive should be dict, got {type(ablp_batch.y_positive)}" - count += 1 - - torch.distributed.barrier() - logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} ABLP batches") - - # Verify total count across all ranks - count_tensor = torch.tensor(count, dtype=torch.int64) - torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM) - - # Calculate expected total anchors by summing across all compute nodes - # Each process on the same compute node has the same anchor count, so we sum - # across all processes and divide by num_processes_per_compute to get the true total - local_total_anchors = sum( - ablp_result[server_rank].anchor_nodes.shape[0] for server_rank in ablp_result - ) - expected_anchors_tensor = torch.tensor(local_total_anchors, dtype=torch.int64) - torch.distributed.all_reduce( - expected_anchors_tensor, op=torch.distributed.ReduceOp.SUM - ) - expected_batches = ( - expected_anchors_tensor.item() // cluster_info.num_processes_per_compute - ) - assert ( - count_tensor.item() == expected_batches - ), f"Expected {expected_batches} total batches, got {count_tensor.item()}" - shutdown_compute_proccess() @@ -908,7 +849,7 @@ class GraphStoreIntegrationTest(TestCase): ERROR: build step 0 "docker-img/path:tag" failed: step exited with non-zero status: 2 """ - def test_graph_store_homogeneous(self): + def _test_graph_store_homogeneous(self): # Simulating two server machine, two compute machines. # Each machine has one process. cora_supervised_info = get_mocked_dataset_artifact_metadata()[ From 7454f5372370addfb4e28814c93f59f999db0083 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 20 Feb 2026 23:37:51 +0000 Subject: [PATCH 06/30] debug --- .../configs/e2e_het_dblp_sup_gs_task_config.yaml | 2 +- .../link_prediction/graph_store/homogeneous_training.py | 4 ++++ gigl/distributed/dist_ablp_neighborloader.py | 9 +++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml index aa4a86a32..3ab56ab7c 100644 --- a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml @@ -54,7 +54,7 @@ trainerConfig: storageArgs: sample_edge_direction: "in" splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" - splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1, "supervision_edge_types": [("paper", "to", "author")]}' + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1, "supervision_edge_types": [("author", "to", "paper")]}' ssl_positive_label_percentage: "0.05" num_server_sessions: "1" # TODO(kmonte): Move to user-defined server code diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 4a350bbcb..ff4938133 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -227,6 +227,10 @@ def _compute_loss( Returns: torch.Tensor: Final loss for the current batch on the current process """ + print(f"Computing loss for main data: {main_data}") + print(f"Computing loss for random negative data: {random_negative_data}") + print(f"Using model: {model}") + flush() # Forward pass through encoder main_embeddings = model(data=main_data, device=device) random_negative_embeddings = model(data=random_negative_data, device=device) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 1db2fc277..dd5b861d6 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,5 +1,6 @@ import ast import concurrent.futures +import sys import time from collections import Counter, abc, defaultdict from itertools import count @@ -60,6 +61,9 @@ logger = Logger() +def flush(): + sys.stdout.flush() + sys.stderr.flush() class DistABLPLoader(DistLoader): # Counts instantiations of this class, per process. @@ -827,10 +831,15 @@ def _setup_for_graph_store( # Extract node type and label edge types from the ABLPInputNodes dataclass. # All entries should have the same anchor_node_type and edge type keys. first_input = next(iter(input_nodes.values())) + input_type = first_input.anchor_node_type is_homogeneous_with_labeled_edge_type = ( input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE ) + print(f"Input type: {input_type}") + print(f"Is homogeneous with labeled edge type: {is_homogeneous_with_labeled_edge_type}") + print(f"First input: {first_input}") + flush() # Extract supervision edge types and derive label edge types from the # ABLPInputNodes.labels dict (keyed by supervision edge type). From dd78d458de6f6db0d770110e3472da69f81938e7 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 21 Feb 2026 01:12:18 +0000 Subject: [PATCH 07/30] debug --- .../graph_store/heterogeneous_training.py | 11 ++++++++--- gigl/distributed/dist_ablp_neighborloader.py | 11 ++++++++--- gigl/distributed/graph_store/dist_server.py | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index ff5c887a6..ffdc3621f 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -157,6 +157,11 @@ def _setup_dataloaders( query_node_type = supervision_edge_type.src_node_type labeled_node_type = supervision_edge_type.dst_node_type + if dataset.get_edge_dir() == "in": + anchor_node_type = labeled_node_type + else: + anchor_node_type = query_node_type + shuffle = split == "train" # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. @@ -167,7 +172,7 @@ def _setup_dataloaders( split=split, rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, - anchor_node_type=query_node_type, + anchor_node_type=anchor_node_type, supervision_edge_type=supervision_edge_type, ) @@ -195,13 +200,13 @@ def _setup_dataloaders( all_node_ids = dataset.get_node_ids( rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, - node_type=labeled_node_type, + node_type=anchor_node_type, ) random_negative_loader = DistNeighborLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(labeled_node_type, all_node_ids), + input_nodes=(anchor_node_type, all_node_ids), num_workers=sampling_workers_per_process, batch_size=random_batch_size, pin_memory_device=device, diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index dd5b861d6..8abeb6139 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -61,10 +61,12 @@ logger = Logger() + def flush(): sys.stdout.flush() sys.stderr.flush() + class DistABLPLoader(DistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. @@ -588,7 +590,6 @@ def _setup_for_colocated( raise ValueError( "When using heterogeneous ABLP, you must provide supervision_edge_types." ) - is_homogeneous_with_labeled_edge_type = True anchor_node_type, anchor_node_ids = input_nodes # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if # this assumption is no longer valid and/or is too opinionated @@ -609,6 +610,7 @@ def _setup_for_colocated( ) anchor_node_ids = input_nodes anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + is_homogeneous_with_labeled_edge_type = True elif input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -624,6 +626,7 @@ def _setup_for_colocated( ) anchor_node_ids = dataset.node_ids anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + is_homogeneous_with_labeled_edge_type = True else: raise ValueError(f"Unexpected input_nodes type: {type(input_nodes)}") @@ -837,7 +840,9 @@ def _setup_for_graph_store( input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE ) print(f"Input type: {input_type}") - print(f"Is homogeneous with labeled edge type: {is_homogeneous_with_labeled_edge_type}") + print( + f"Is homogeneous with labeled edge type: {is_homogeneous_with_labeled_edge_type}" + ) print(f"First input: {first_input}") flush() @@ -1199,7 +1204,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: ) if isinstance(data, HeteroData): data = strip_label_edges(data) - if not self.is_homogeneous_with_labeled_edge_type: + if self.is_homogeneous_with_labeled_edge_type: if len(self._supervision_edge_types) != 1: raise ValueError( f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 7e12e7e3b..c38f1b63f 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -281,12 +281,12 @@ def get_node_ids( raise ValueError( f"Invalid split: {split}. Must be one of 'train', 'val', 'test', or None." ) - if node_type is not None: if not isinstance(nodes, abc.Mapping): raise ValueError( f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] (e.g. a heterogeneous dataset), got {type(nodes)}" ) + print(f"node types: {nodes.keys()}") nodes = nodes[node_type] elif not isinstance(nodes, torch.Tensor): if nodes is not None and DEFAULT_HOMOGENEOUS_NODE_TYPE in nodes: From 2dbb241f142c9531c6c826259166b9c28d39c605 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 23 Feb 2026 18:20:13 +0000 Subject: [PATCH 08/30] test --- .../link_prediction/graph_store/heterogeneous_training.py | 5 ++++- gigl/distributed/distributed_neighborloader.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index ffdc3621f..2ac7b790e 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -257,7 +257,10 @@ def _compute_loss( inference_node_types = [query_node_type, labeled_node_type] # Forward pass through encoder - + print(f"Computing loss for main data: {main_data}") + print(f"Computing loss for random negative data: {random_negative_data}") + print(f"Using model: {model}") + flush() main_embeddings = model( data=main_data, output_node_types=inference_node_types, device=device ) diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 1c0634042..20547763a 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -417,6 +417,7 @@ def _setup_for_graph_store( if isinstance(edge_types, list): if DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_types: input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE + is_homogeneous_with_labeled_edge_type = True else: input_type = fallback_input_type elif require_edge_feature_info: From 15d73263765e3f63ca2ecd1e6e71357193398c24 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Feb 2026 15:34:20 +0000 Subject: [PATCH 09/30] update with fixes --- gigl/distributed/base_dist_loader.py | 587 ++++++++++++++++++ gigl/distributed/dist_ablp_neighborloader.py | 430 ++----------- .../distributed/distributed_neighborloader.py | 452 +++----------- gigl/distributed/graph_store/dist_server.py | 25 +- .../graph_store/remote_dist_dataset.py | 19 + 5 files changed, 764 insertions(+), 749 deletions(-) create mode 100644 gigl/distributed/base_dist_loader.py diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py new file mode 100644 index 000000000..d4ae3e452 --- /dev/null +++ b/gigl/distributed/base_dist_loader.py @@ -0,0 +1,587 @@ +""" +Base distributed loader that consolidates shared initialization logic +from DistNeighborLoader and DistABLPLoader. + +Subclasses GLT's DistLoader and handles: +- Dataset metadata storage +- Colocated mode: DistLoader attribute setting + staggered producer init +- Graph Store mode: barrier loop + async RPC dispatch + channel creation +""" + +import sys +import time +from collections import Counter +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel +from graphlearn_torch.distributed import ( + DistLoader, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, + get_context, +) +from graphlearn_torch.distributed.dist_client import async_request_server +from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer +from graphlearn_torch.distributed.rpc import rpc_is_initialized +from graphlearn_torch.sampler import ( + NodeSamplerInput, + RemoteSamplerInput, + SamplingConfig, + SamplingType, +) +from torch_geometric.typing import EdgeType +from typing_extensions import Self + +import gigl.distributed.utils +from gigl.common.logger import Logger +from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT +from gigl.distributed.dist_context import DistributedContext +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.dist_server import DistServer +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils.neighborloader import ( + DatasetSchema, + patch_fanout_for_sampling, +) +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE + +logger = Logger() + + +# We don't see logs for graph store mode for whatever reason. +# TOOD(#442): Revert this once the GCP issues are resolved. +def _flush() -> None: + sys.stdout.flush() + sys.stderr.flush() + + +@dataclass(frozen=True) +class DistributedRuntimeInfo: + """Plain data container for resolved distributed context information.""" + + node_world_size: int + node_rank: int + rank: int + world_size: int + local_rank: int + local_world_size: int + master_ip_address: str + should_cleanup_distributed_context: bool + + +class BaseDistLoader(DistLoader): + """Base class for GiGL distributed loaders. + + Consolidates shared initialization logic from DistNeighborLoader and DistABLPLoader. + Subclasses GLT's DistLoader but does NOT call its ``__init__`` — instead, it + replicates the relevant attribute-setting logic to allow configurable producer classes. + + Subclasses should: + 1. Call ``resolve_runtime()`` to get runtime context. + 2. Determine mode (colocated vs graph store). + 3. Call ``create_sampling_config()`` to build the SamplingConfig. + 4. For colocated: call ``create_colocated_channel()`` and construct the + ``DistMpSamplingProducer`` (or subclass), then pass the producer as ``sampler``. + 5. For graph store: pass the RPC function (e.g. ``DistServer.create_sampling_producer``) + as ``sampler``. + 6. Call ``super().__init__()`` with the prepared data. + + Args: + dataset: ``DistDataset`` (colocated) or ``RemoteDistDataset`` (graph store). + sampler_input: Prepared by the subclass. Single input for colocated mode, + list (one per server) for graph store mode. + dataset_schema: Contains edge types, feature info, edge dir, etc. + worker_options: ``MpDistSamplingWorkerOptions`` (colocated) or + ``RemoteDistSamplingWorkerOptions`` (graph store). + sampling_config: Configuration for the sampler (created via ``create_sampling_config``). + device: Target device for sampled results. + runtime: Resolved distributed runtime information. + sampler: Either a pre-constructed ``DistMpSamplingProducer`` (colocated mode) + or a callable to dispatch on the ``DistServer`` (graph store mode). + process_start_gap_seconds: Delay between each process for staggered colocated init. + """ + + @staticmethod + def resolve_runtime( + context: Optional[DistributedContext] = None, + local_process_rank: Optional[int] = None, + local_process_world_size: Optional[int] = None, + ) -> DistributedRuntimeInfo: + """Resolves distributed context from either a DistributedContext or torch.distributed. + + Args: + context: (Deprecated) If provided, derives rank info from the DistributedContext. + Requires local_process_rank and local_process_world_size. + local_process_rank: (Deprecated) Required when context is provided. + local_process_world_size: (Deprecated) Required when context is provided. + + Returns: + A DistributedRuntimeInfo containing all resolved rank/topology information. + """ + should_cleanup_distributed_context: bool = False + + if context: + assert ( + local_process_world_size is not None + ), "context: DistributedContext provided, so local_process_world_size must be provided." + assert ( + local_process_rank is not None + ), "context: DistributedContext provided, so local_process_rank must be provided." + + master_ip_address = context.main_worker_ip_address + node_world_size = context.global_world_size + node_rank = context.global_rank + local_world_size = local_process_world_size + local_rank = local_process_rank + + rank = node_rank * local_world_size + local_rank + world_size = node_world_size * local_world_size + + if not torch.distributed.is_initialized(): + logger.info( + "process group is not available, trying to torch.distributed.init_process_group " + "to communicate necessary setup information." + ) + should_cleanup_distributed_context = True + logger.info( + f"Initializing process group with master ip address: {master_ip_address}, " + f"rank: {rank}, world size: {world_size}, " + f"local_rank: {local_rank}, local_world_size: {local_world_size}." + ) + torch.distributed.init_process_group( + backend="gloo", + init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}", + rank=rank, + world_size=world_size, + ) + else: + assert torch.distributed.is_initialized(), ( + "context: DistributedContext is None, so process group must be " + "initialized before constructing the loader." + ) + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks() + master_ip_address = rank_ip_addresses[0] + + count_ranks_per_ip_address = Counter(rank_ip_addresses) + local_world_size = count_ranks_per_ip_address[master_ip_address] + for rank_ip_address, count in count_ranks_per_ip_address.items(): + if count != local_world_size: + raise ValueError( + f"All ranks must have the same number of processes, but found " + f"{count} processes for rank {rank} on ip {rank_ip_address}, " + f"expected {local_world_size}. " + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}" + ) + + node_world_size = len(count_ranks_per_ip_address) + local_rank = rank % local_world_size + node_rank = rank // local_world_size + + return DistributedRuntimeInfo( + node_world_size=node_world_size, + node_rank=node_rank, + rank=rank, + world_size=world_size, + local_rank=local_rank, + local_world_size=local_world_size, + master_ip_address=master_ip_address, + should_cleanup_distributed_context=should_cleanup_distributed_context, + ) + + def __init__( + self, + dataset: Union[DistDataset, RemoteDistDataset], + sampler_input: Union[NodeSamplerInput, list[NodeSamplerInput]], + dataset_schema: DatasetSchema, + worker_options: Union[ + MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions + ], + sampling_config: SamplingConfig, + device: torch.device, + runtime: DistributedRuntimeInfo, + sampler: Union[DistMpSamplingProducer, Callable[..., int]], + process_start_gap_seconds: float = 60.0, + ): + # Set right away so __del__ can clean up if we throw during init. + # Will be set to False once connections are initialized. + self._shutdowned = True + + # Store dataset metadata for subclass _collate_fn usage + self._is_homogeneous_with_labeled_edge_type = ( + dataset_schema.is_homogeneous_with_labeled_edge_type + ) + self._node_feature_info = dataset_schema.node_feature_info + self._edge_feature_info = dataset_schema.edge_feature_info + + # --- Attributes shared by both modes (mirrors GLT DistLoader.__init__) --- + self.input_data = sampler_input + self.sampling_type = sampling_config.sampling_type + self.num_neighbors = sampling_config.num_neighbors + self.batch_size = sampling_config.batch_size + self.shuffle = sampling_config.shuffle + self.drop_last = sampling_config.drop_last + self.with_edge = sampling_config.with_edge + self.with_weight = sampling_config.with_weight + self.collect_features = sampling_config.collect_features + self.edge_dir = sampling_config.edge_dir + self.sampling_config = sampling_config + self.to_device = device + self.worker_options = worker_options + + self._is_collocated_worker = False + self._with_channel = True + self._num_recv = 0 + self._epoch = 0 + + # --- Mode-specific attributes and connection initialization --- + if isinstance(sampler, DistMpSamplingProducer): + assert isinstance(dataset, DistDataset) + assert isinstance(worker_options, MpDistSamplingWorkerOptions) + assert isinstance(sampler_input, NodeSamplerInput) + + self.data: Optional[DistDataset] = dataset + self._is_mp_worker = True + self._is_remote_worker = False + + self.num_data_partitions = dataset.num_partitions + self.data_partition_idx = dataset.partition_idx + self._set_ntypes_and_etypes( + dataset.get_node_types(), dataset.get_edge_types() + ) + + self._input_len = len(sampler_input) + self._input_type = sampler_input.input_type + self._num_expected = self._input_len // self.batch_size + if not self.drop_last and self._input_len % self.batch_size != 0: + self._num_expected += 1 + + self._shutdowned = False + self._init_colocated_connections( + dataset=dataset, + producer=sampler, + runtime=runtime, + process_start_gap_seconds=process_start_gap_seconds, + ) + else: + assert isinstance(dataset, RemoteDistDataset) + assert isinstance(worker_options, RemoteDistSamplingWorkerOptions) + assert isinstance(sampler_input, list) + assert callable(sampler) + + self.data = None + self._is_mp_worker = False + self._is_remote_worker = True + self._num_expected = float("inf") + + self._server_rank_list: list[int] = ( + worker_options.server_rank + if isinstance(worker_options.server_rank, list) + else [worker_options.server_rank] + ) + self._input_data_list = sampler_input + self._input_type = self._input_data_list[0].input_type + + self.num_data_partitions = dataset.cluster_info.num_storage_nodes + self.data_partition_idx = dataset.cluster_info.compute_node_rank + edge_types = dataset_schema.edge_types or [] + if edge_types: + node_types = list( + set([et[0] for et in edge_types] + [et[2] for et in edge_types]) + ) + else: + node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] + self._set_ntypes_and_etypes(node_types, edge_types) + + self._shutdowned = False + self._init_graph_store_connections( + dataset=dataset, + create_producer_fn=sampler, + ) + + @staticmethod + def create_sampling_config( + num_neighbors: Union[list[int], dict[EdgeType, list[int]]], + dataset_schema: DatasetSchema, + batch_size: int = 1, + shuffle: bool = False, + drop_last: bool = False, + ) -> SamplingConfig: + """Creates a SamplingConfig with patched fanout. + + Patches ``num_neighbors`` to zero-out label edge types, then creates + the SamplingConfig used by both colocated and graph store modes. + + Args: + num_neighbors: Fanout per hop. + dataset_schema: Contains edge types and edge dir. + batch_size: How many samples per batch. + shuffle: Whether to shuffle input nodes. + drop_last: Whether to drop the last incomplete batch. + + Returns: + A fully configured SamplingConfig. + """ + num_neighbors = patch_fanout_for_sampling( + edge_types=dataset_schema.edge_types, + num_neighbors=num_neighbors, + ) + return SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir=dataset_schema.edge_dir, + seed=None, + ) + + @staticmethod + def create_colocated_channel( + worker_options: MpDistSamplingWorkerOptions, + ) -> ShmChannel: + """Creates a ShmChannel for colocated mode. + + Creates and optionally pin-memories the shared-memory channel. + + Args: + worker_options: The colocated worker options (must already be fully configured). + + Returns: + A ShmChannel ready to be passed to a DistMpSamplingProducer. + """ + channel = ShmChannel( + worker_options.channel_capacity, worker_options.channel_size + ) + if worker_options.pin_memory: + channel.pin_memory() + return channel + + def _init_colocated_connections( + self, + dataset: DistDataset, + producer: DistMpSamplingProducer, + runtime: DistributedRuntimeInfo, + process_start_gap_seconds: float, + ) -> None: + """Initialize colocated mode connections. + + Validates the GLT distributed context, stores the pre-constructed producer, + and performs staggered initialization to avoid memory OOM. + + All DistLoader attributes are already set by ``__init__`` before this is called. + + Args: + dataset: The local DistDataset. + producer: A pre-constructed DistMpSamplingProducer (or subclass). + runtime: Resolved distributed runtime info (used for staggered sleep). + process_start_gap_seconds: Delay multiplier for staggered init. + """ + # Validate context and store the pre-constructed producer and its channel + current_ctx = get_context() + if not current_ctx.is_worker(): + raise RuntimeError( + f"'{self.__class__.__name__}': only supports " + f"launching multiprocessing sampling workers with " + f"a non-server distribution mode, current role of " + f"distributed context is {current_ctx.role}." + ) + if dataset is None: + raise ValueError( + f"'{self.__class__.__name__}': missing input dataset " + f"when launching multiprocessing sampling workers." + ) + self.worker_options._set_worker_ranks(current_ctx) + self._channel = producer.output_channel + self._mp_producer = producer + + # Staggered init — sleep proportional to local_rank to avoid + # concurrent initialization spikes that cause CPU memory OOM. + logger.info( + f"---Machine {runtime.rank} local process number {runtime.local_rank} " + f"preparing to sleep for {process_start_gap_seconds * runtime.local_rank} seconds" + ) + time.sleep(process_start_gap_seconds * runtime.local_rank) + self._mp_producer.init() + + def _init_graph_store_connections( + self, + dataset: RemoteDistDataset, + create_producer_fn: Callable[..., int], + ) -> None: + """Initialize Graph Store mode connections. + + Validates the GLT distributed context, performs a sequential barrier loop + across compute nodes, dispatches async RPCs to create sampling producers on + storage nodes, and creates a RemoteReceivingChannel. + + All DistLoader attributes are already set by ``__init__`` before this is called. + + Uses ``async_request_server`` instead of ``ThreadPoolExecutor`` to avoid + TensorPipe rendezvous deadlock with many servers. + + For Graph Store mode it's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). + Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. + E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. + + See below for a connection setup. + ╔═══════════════════════════════════════════════════════════════════════════════════════╗ + ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ + ╚═══════════════════════════════════════════════════════════════════════════════════════╝ + + COMPUTE NODES STORAGE NODES + ═════════════ ═════════════ + + ┌──────────────────────┐ (1) ┌───────────────┐ + │ COMPUTE NODE 0 │ │ │ + │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ + │ │GPU │GPU │GPU │GPU │ ╱ │ │ + │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ + │ └────┴────┴────┴────┤ (2) ╲ ╱ + └──────────────────────┘ ╲ ╱ + ╳ + (3) ╱ ╲ (4) + ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ + │ COMPUTE NODE 1 │ ╱ ╲ │ │ + │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ + │ │GPU │GPU │GPU │GPU │ │ │ + │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ + │ └────┴────┴────┴────┤ └───────────────┘ + └──────────────────────┘ + + ┌─────────────────────────────────────────────────────────────────────────────┐ + │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ + │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ + │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ + │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ + └─────────────────────────────────────────────────────────────────────────────┘ + """ + # Validate distributed context + ctx = get_context() + if ctx is None: + raise RuntimeError( + f"'{self.__class__.__name__}': the distributed context " + f"has not been initialized." + ) + if not ctx.is_client(): + raise RuntimeError( + f"'{self.__class__.__name__}': must be used on a client " + f"worker process." + ) + + # Move input to CPU before sending to server + for inp in self._input_data_list: + if not isinstance(inp, RemoteSamplerInput): + inp.to(torch.device("cpu")) + + node_rank = dataset.cluster_info.compute_node_rank + + _flush() + start_time = time.time() + rpc_futures: list[tuple[int, torch.futures.Future[int]]] = [] + # Dispatch ALL create_producer RPCs async. + # async_request_server queues the RPC in TensorPipe and returns + # immediately, allowing all storage nodes to start their worker + # rendezvous simultaneously. + logger.info( + f"node_rank={node_rank} dispatching create_sampling_producer to " + f"{len(self._server_rank_list)} servers" + ) + _flush() + t_dispatch = time.time() + for server_rank, inp_data in zip(self._server_rank_list, self._input_data_list): + fut = async_request_server( + server_rank, + create_producer_fn, + inp_data, + self.sampling_config, + self.worker_options, + ) + rpc_futures.append((server_rank, fut)) + logger.info( + f"node_rank={node_rank} all {len(rpc_futures)} RPCs dispatched in " + f"{time.time() - t_dispatch:.3f}s, waiting for responses" + ) + _flush() + + # Wait for all results + self._producer_id_list: list[int] = [] + for server_rank, fut in rpc_futures: + t_wait = time.time() + producer_id: int = fut.wait() + logger.info( + f"node_rank={node_rank} create_sampling_producer" + f"(server_rank={server_rank}) returned " + f"producer_id={producer_id} in {time.time() - t_wait:.2f}s" + ) + _flush() + self._producer_id_list.append(producer_id) + logger.info( + f"node_rank={node_rank} all {len(self._producer_id_list)} producers " + f"created in {time.time() - t_dispatch:.2f}s total" + ) + _flush() + # Create remote receiving channel for cross-machine message passing + self._channel = RemoteReceivingChannel( + self._server_rank_list, + self._producer_id_list, + self.worker_options.prefetch_size, + ) + + logger.info( + f"node_rank {node_rank} initialized the dist loader in " + f"{time.time() - start_time:.2f}s" + ) + _flush() + + # Overwrite DistLoader.shutdown to so we can use our own shutdown and rpc calls + def shutdown(self) -> None: + if self._shutdowned: + return + if self._is_collocated_worker: + self._collocated_producer.shutdown() + elif self._is_mp_worker: + self._mp_producer.shutdown() + elif rpc_is_initialized() is True: + rpc_futures: list[torch.futures.Future[None]] = [] + for server_rank, producer_id in zip( + self._server_rank_list, self._producer_id_list + ): + fut = async_request_server( + server_rank, DistServer.destroy_sampling_producer, producer_id + ) + rpc_futures.append(fut) + torch.futures.wait_all(rpc_futures) + self._shutdowned = True + + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls + def __iter__(self) -> Self: + self._num_recv = 0 + if self._is_collocated_worker: + self._collocated_producer.reset() + elif self._is_mp_worker: + self._mp_producer.produce_all() + else: + rpc_futures: list[torch.futures.Future[None]] = [] + for server_rank, producer_id in zip( + self._server_rank_list, self._producer_id_list + ): + fut = async_request_server( + server_rank, + DistServer.start_new_epoch_sampling, + producer_id, + self._epoch, + ) + rpc_futures.append(fut) + torch.futures.wait_all(rpc_futures) + self._channel.reset() + self._epoch += 1 + return self diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 8abeb6139..1ddaf0fc7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,28 +1,20 @@ import ast -import concurrent.futures -import sys -import time -from collections import Counter, abc, defaultdict +from collections import abc, defaultdict from itertools import count -from typing import Optional, Union +from typing import Callable, Optional, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel +from graphlearn_torch.channel import SampleMessage from graphlearn_torch.distributed import ( - DistLoader, MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions, - get_context, - request_server, ) -from graphlearn_torch.sampler import SamplingConfig, SamplingType -from graphlearn_torch.utils import reverse_edge_type from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType import gigl.distributed.utils from gigl.common.logger import Logger -from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistABLPSamplingProducer @@ -39,7 +31,6 @@ DatasetSchema, SamplingClusterSetup, labeled_to_homogeneous, - patch_fanout_for_sampling, set_missing_features, shard_nodes_by_process, strip_label_edges, @@ -62,12 +53,7 @@ logger = Logger() -def flush(): - sys.stdout.flush() - sys.stderr.flush() - - -class DistABLPLoader(DistLoader): +class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. # NOTE: This is per-class, not per-instance. @@ -210,21 +196,8 @@ def __init__( # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. - # We set to `True` as we don't need to cleanup right away, and this will get set - # to `False` in super().__init__()` e.g. - # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1 self._shutdowned = True - node_world_size: int - node_rank: int - rank: int - world_size: int - local_rank: int - local_world_size: int - - master_ip_address: str - should_cleanup_distributed_context: bool = False - # Determine sampling cluster setup based on dataset type if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE @@ -256,81 +229,23 @@ def __init__( del supervision_edge_type self._instance_count = next(self._counter) - self.data: Optional[Union[DistDataset, RemoteDistDataset]] = None - if isinstance(dataset, DistDataset): - self.data = dataset - - if context: - assert ( - local_process_world_size is not None - ), "context: DistributedContext provided, so local_process_world_size must be provided." - assert ( - local_process_rank is not None - ), "context: DistributedContext provided, so local_process_rank must be provided." - - master_ip_address = context.main_worker_ip_address - node_world_size = context.global_world_size - node_rank = context.global_rank - local_world_size = local_process_world_size - local_rank = local_process_rank - - rank = node_rank * local_world_size + local_rank - world_size = node_world_size * local_world_size - - if not torch.distributed.is_initialized(): - logger.info( - "process group is not available, trying to torch.distributed.init_process_group to communicate necessary setup information." - ) - should_cleanup_distributed_context = True - logger.info( - f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}" - ) - torch.distributed.init_process_group( - backend="gloo", # We just default to gloo for this temporary process group - init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}", - rank=rank, - world_size=world_size, - ) - - else: - assert ( - torch.distributed.is_initialized() - ), f"context: DistributedContext is None, so process group must be initialized before constructing this object {self.__class__.__name__}." - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks() - master_ip_address = rank_ip_addresses[0] - - count_ranks_per_ip_address = Counter(rank_ip_addresses) - local_world_size = count_ranks_per_ip_address[master_ip_address] - for rank_ip_address, count in count_ranks_per_ip_address.items(): - if count != local_world_size: - raise ValueError( - f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}." - + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}" - ) - - node_world_size = len(count_ranks_per_ip_address) - local_rank = rank % local_world_size - node_rank = rank // local_world_size - del ( - context, - local_process_rank, - local_process_world_size, - ) # delete deprecated vars so we don't accidentally use them. + # Resolve distributed context + runtime = BaseDistLoader.resolve_runtime( + context, local_process_rank, local_process_world_size + ) + del context, local_process_rank, local_process_world_size device = ( pin_memory_device if pin_memory_device else gigl.distributed.utils.get_available_device( - local_process_rank=local_rank + local_process_rank=runtime.local_rank ) ) self.to_device = device - # Call appropriate setup method based on sampling cluster setup + # Mode-specific setup if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance( dataset, DistDataset @@ -340,26 +255,29 @@ def __init__( raise ValueError( f"When using Colocated mode, input_nodes must be of type " f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " - f"received Graph Store format: dict[int, ABLPInputNodes]" + f"received {type(input_nodes)}" ) - ( - sampler_input, - worker_options, - dataset_metadata, - ) = self._setup_for_colocated( + setup_info = self._setup_for_colocated( input_nodes=input_nodes, dataset=dataset, - local_rank=local_rank, - local_world_size=local_world_size, + local_rank=runtime.local_rank, + local_world_size=runtime.local_world_size, device=device, - master_ip_address=master_ip_address, - node_rank=node_rank, - node_world_size=node_world_size, + master_ip_address=runtime.master_ip_address, + node_rank=runtime.node_rank, + node_world_size=runtime.node_world_size, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, num_cpu_threads=num_cpu_threads, ) + sampler_input: Union[ + ABLPNodeSamplerInput, list[ABLPNodeSamplerInput] + ] = setup_info[0] + worker_options: Union[ + MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions + ] = setup_info[1] + dataset_schema: DatasetSchema = setup_info[2] else: # Graph Store mode assert isinstance( dataset, RemoteDistDataset @@ -377,7 +295,7 @@ def __init__( ( sampler_input, worker_options, - dataset_metadata, + dataset_schema, ) = self._setup_for_graph_store( input_nodes=input_nodes, dataset=dataset, @@ -386,146 +304,56 @@ def __init__( prefetch_size=prefetch_size, ) - self.is_homogeneous_with_labeled_edge_type = ( - dataset_metadata.is_homogeneous_with_labeled_edge_type - ) - self._node_feature_info = dataset_metadata.node_feature_info - self._edge_feature_info = dataset_metadata.edge_feature_info - - num_neighbors = patch_fanout_for_sampling( - dataset_metadata.edge_types, num_neighbors - ) - - if should_cleanup_distributed_context and torch.distributed.is_initialized(): + # Cleanup temporary process group if needed + if ( + runtime.should_cleanup_distributed_context + and torch.distributed.is_initialized() + ): logger.info( f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." ) torch.distributed.destroy_process_group() - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, + # Create SamplingConfig (with patched fanout) + sampling_config = BaseDistLoader.create_sampling_config( num_neighbors=num_neighbors, + dataset_schema=dataset_schema, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset_metadata.edge_dir, - seed=None, # it's actually optional - None means random. ) + # Build the sampler: a pre-constructed producer for colocated mode, + # or an RPC callable for graph store mode. if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: - # Code below this point is taken from the GLT DistNeighborLoader.__init__() function - # (graphlearn_torch/python/distributed/dist_neighbor_loader.py). - # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. - - # Type narrowing for colocated mode - - self.input_data = sampler_input[0] - del sampler_input - assert isinstance(self.data, DistDataset) - assert isinstance(self.input_data, ABLPNodeSamplerInput) - - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.worker_options = worker_options - - # We can set shutdowned to false now - self._shutdowned = False - - self._is_mp_worker = True - self._is_collocated_worker = False - self._is_remote_worker = False - - self.num_data_partitions = self.data.num_partitions - self.data_partition_idx = self.data.partition_idx - self._set_ntypes_and_etypes( - self.data.get_node_types(), self.data.get_edge_types() - ) - - self._num_recv = 0 - self._epoch = 0 - - current_ctx = get_context() - - self._input_len = len(self.input_data) - self._input_type = self.input_data.input_type - self._num_expected = self._input_len // self.batch_size - if not self.drop_last and self._input_len % self.batch_size != 0: - self._num_expected += 1 - - if not current_ctx.is_worker(): - raise RuntimeError( - f"'{self.__class__.__name__}': only supports " - f"launching multiprocessing sampling workers with " - f"a non-server distribution mode, current role of " - f"distributed context is {current_ctx.role}." - ) - if self.data is None: - raise ValueError( - f"'{self.__class__.__name__}': missing input dataset " - f"when launching multiprocessing sampling workers." - ) - - # Launch multiprocessing sampling workers - self._with_channel = True - self.worker_options._set_worker_ranks(current_ctx) - - self._channel = ShmChannel( - self.worker_options.channel_capacity, self.worker_options.channel_size - ) - if self.worker_options.pin_memory: - self._channel.pin_memory() - - self._mp_producer = DistABLPSamplingProducer( - self.data, - self.input_data, - self.sampling_config, - self.worker_options, - self._channel, - ) - # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. - # The current hypothesis is making connections across machines require a lot of memory. - # If we start all data loaders in all processes simultaneously, the spike of memory - # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group - # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. - logger.info( - f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" + assert isinstance(dataset, DistDataset) + assert isinstance(worker_options, MpDistSamplingWorkerOptions) + channel = BaseDistLoader.create_colocated_channel(worker_options) + sampler: Union[ + DistABLPSamplingProducer, Callable[..., int] + ] = DistABLPSamplingProducer( + dataset, + sampler_input, + sampling_config, + worker_options, + channel, ) - time.sleep(process_start_gap_seconds * local_rank) - self._mp_producer.init() else: - # Graph Store mode - re-implement remote worker setup - # Use sequential initialization per compute node to avoid race conditions - # when initializing the samplers on the storage nodes. - node_rank = dataset.cluster_info.compute_node_rank - for target_node_rank in range(dataset.cluster_info.num_compute_nodes): - if node_rank == target_node_rank: - self._init_remote_worker( - dataset=dataset, - sampler_input=sampler_input, - sampling_config=sampling_config, - worker_options=worker_options, - dataset_metadata=dataset_metadata, - ) - logger.info( - f"node_rank {node_rank} / {dataset.cluster_info.num_compute_nodes} initialized the dist loader" - ) - torch.distributed.barrier() - torch.distributed.barrier() - logger.info( - f"node_rank {node_rank} / {dataset.cluster_info.num_compute_nodes} finished initializing the dist loader" - ) + sampler = DistServer.create_sampling_ablp_producer + + # Call base class — handles metadata storage and connection initialization + # (including staggered init for colocated mode). + super().__init__( + dataset=dataset, + sampler_input=sampler_input, + dataset_schema=dataset_schema, + worker_options=worker_options, + sampling_config=sampling_config, + device=device, + runtime=runtime, + sampler=sampler, + process_start_gap_seconds=process_start_gap_seconds, + ) def _setup_for_colocated( self, @@ -546,7 +374,7 @@ def _setup_for_colocated( worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], - ) -> tuple[list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema]: + ) -> tuple[ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]: """ Setup method for colocated (non-Graph Store) mode. @@ -565,7 +393,7 @@ def _setup_for_colocated( num_cpu_threads: Number of CPU threads for PyTorch. Returns: - Tuple of (list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema). + Tuple of (ABLPNodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema). """ # Validate input format - should not be Graph Store format if isinstance(input_nodes, abc.Mapping): @@ -584,12 +412,13 @@ def _setup_for_colocated( f"The dataset must be heterogeneous for ABLP. Received dataset with graph of type: {type(dataset.graph)}" ) - is_homogeneous_with_labeled_edge_type: bool = False + is_homogeneous_with_labeled_edge_type: bool = True if isinstance(input_nodes, tuple): if self._supervision_edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: raise ValueError( "When using heterogeneous ABLP, you must provide supervision_edge_types." ) + is_homogeneous_with_labeled_edge_type = False anchor_node_type, anchor_node_ids = input_nodes # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if # this assumption is no longer valid and/or is too opinionated @@ -610,7 +439,6 @@ def _setup_for_colocated( ) anchor_node_ids = input_nodes anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - is_homogeneous_with_labeled_edge_type = True elif input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -626,7 +454,6 @@ def _setup_for_colocated( ) anchor_node_ids = dataset.node_ids anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - is_homogeneous_with_labeled_edge_type = True else: raise ValueError(f"Unexpected input_nodes type: {type(input_nodes)}") @@ -748,7 +575,7 @@ def _setup_for_colocated( edge_types = list(dataset.graph.keys()) return ( - [sampler_input], + sampler_input, worker_options, DatasetSchema( is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type, @@ -795,10 +622,6 @@ def _setup_for_graph_store( num_ports=dataset.cluster_info.num_compute_nodes ) sampling_port = sampling_ports[node_rank] - # TODO(kmonte) - We need to be able to differentiate between different instances of the same loader. - # e.g. if we have two different DistABLPLoaders, then they will have conflicting worker keys. - # And they will share each others data. Therefor, the second loader will not load the data it's expecting. - # Probably, we can just keep track of the insantiations on the server-side and include the count in the worker key. worker_key = ( f"compute_ablp_loader_rank_{node_rank}_worker_{self._instance_count}" ) @@ -834,17 +657,10 @@ def _setup_for_graph_store( # Extract node type and label edge types from the ABLPInputNodes dataclass. # All entries should have the same anchor_node_type and edge type keys. first_input = next(iter(input_nodes.values())) - input_type = first_input.anchor_node_type is_homogeneous_with_labeled_edge_type = ( input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE ) - print(f"Input type: {input_type}") - print( - f"Is homogeneous with labeled edge type: {is_homogeneous_with_labeled_edge_type}" - ) - print(f"First input: {first_input}") - flush() # Extract supervision edge types and derive label edge types from the # ABLPInputNodes.labels dict (keyed by supervision edge type). @@ -933,114 +749,6 @@ def _setup_for_graph_store( ), ) - def _init_remote_worker( - self, - dataset: RemoteDistDataset, - sampler_input: list[ABLPNodeSamplerInput], - sampling_config: SamplingConfig, - worker_options: RemoteDistSamplingWorkerOptions, - dataset_metadata: DatasetSchema, - ) -> None: - """ - Initialize the remote worker code path for Graph Store mode. - - This re-implements GLT's DistLoader remote worker setup but uses GiGL's DistServer. - - Args: - dataset: The RemoteDistDataset to sample from. - sampler_input: List of ABLPNodeSamplerInput, one per server. - sampling_config: Configuration for sampling. - worker_options: Options for remote sampling workers. - dataset_metadata: Metadata about the dataset schema. - """ - # Set instance variables (like DistLoader does) - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.worker_options = worker_options - - self._shutdowned = False - - # Set worker type flags - self._is_mp_worker = False - self._is_collocated_worker = False - self._is_remote_worker = True - - # For remote worker, end of epoch is determined by server - self._num_expected = float("inf") - self._with_channel = True - - self._num_recv = 0 - self._epoch = 0 - - # Get server rank list from worker_options - self._server_rank_list = ( - worker_options.server_rank - if isinstance(worker_options.server_rank, list) - else [worker_options.server_rank] - ) - self._input_data_list = sampler_input # Already a list (one per server) - - # Get input type from first input - self._input_type = self._input_data_list[0].input_type - - # Get dataset metadata from cluster_info (not via RPC) - self.num_data_partitions = dataset.cluster_info.num_storage_nodes - self.data_partition_idx = dataset.cluster_info.compute_node_rank - - # Derive node types from edge types - # For labeled homogeneous: edge_types contains DEFAULT_HOMOGENEOUS_EDGE_TYPE - # For heterogeneous: extract unique src/dst types from edge types - edge_types = dataset_metadata.edge_types or [] - if edge_types: - node_types = list( - set([et[0] for et in edge_types] + [et[2] for et in edge_types]) - ) - else: - node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] - self._set_ntypes_and_etypes(node_types, edge_types) - - # Create sampling producers on each server (concurrently) - # Move input data to CPU before sending to server - for input_data in self._input_data_list: - input_data.to(torch.device("cpu")) - - self._producer_id_list = [] - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit( - request_server, - server_rank, - DistServer.create_sampling_ablp_producer, - input_data, - self.sampling_config, - self.worker_options, - ) - for server_rank, input_data in zip( - self._server_rank_list, self._input_data_list - ) - ] - - for future in futures: - producer_id = future.result() - self._producer_id_list.append(producer_id) - logger.info( - f"DistABLPLoader rank {torch.distributed.get_rank()} producers: ({[producer_id for producer_id in self._producer_id_list]})" - ) - # Create remote receiving channel for cross-machine message passing - self._channel = RemoteReceivingChannel( - self._server_rank_list, - self._producer_id_list, - self.worker_options.prefetch_size, - ) - def _get_labels( self, msg: SampleMessage ) -> tuple[ @@ -1204,7 +912,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: ) if isinstance(data, HeteroData): data = strip_label_edges(data) - if self.is_homogeneous_with_labeled_edge_type: + if self._is_homogeneous_with_labeled_edge_type: if len(self._supervision_edge_types) != 1: raise ValueError( f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}" diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 20547763a..5b96d9da2 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -1,39 +1,30 @@ import sys -import time -from collections import Counter, abc +from collections import abc from itertools import count -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch -from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage +from graphlearn_torch.channel import SampleMessage from graphlearn_torch.distributed import ( - DistLoader, MpDistSamplingWorkerOptions, RemoteDistSamplingWorkerOptions, ) -from graphlearn_torch.distributed.dist_context import get_context -from graphlearn_torch.sampler import ( - NodeSamplerInput, - RemoteSamplerInput, - SamplingConfig, - SamplingType, -) +from graphlearn_torch.distributed.dist_sampling_producer import DistMpSamplingProducer +from graphlearn_torch.sampler import NodeSamplerInput from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType import gigl.distributed.utils from gigl.common.logger import Logger -from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT +from gigl.distributed.base_dist_loader import BaseDistLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset -from gigl.distributed.graph_store.compute import async_request_server, request_server from gigl.distributed.graph_store.dist_server import DistServer as GiglDistServer from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( DatasetSchema, SamplingClusterSetup, labeled_to_homogeneous, - patch_fanout_for_sampling, set_missing_features, shard_nodes_by_process, strip_label_edges, @@ -52,12 +43,14 @@ DEFAULT_NUM_CPU_THREADS = 2 +# We don't see logs for graph store mode for whatever reason. +# TOOD(#442): Revert this once the GCP issues are resolved. def flush(): sys.stdout.flush() sys.stderr.flush() -class DistNeighborLoader(DistLoader): +class DistNeighborLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. # NOTE: This is per-class, not per-instance. @@ -90,6 +83,12 @@ def __init__( drop_last: bool = False, ): """ + Distributed Neighbor Loader. + Takes in some input nodes and samples neighbors from the dataset. + This loader should be used if you do not have any specially sampling needs, + e.g. you need to generate *training* examples for Anchor Based Link Prediction (ABLP) tasks. + Though this loader is useful for generating random negative examples for ABLP training. + Note: We try to adhere to pyg dataloader api as much as possible. See the following for reference: https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader @@ -151,81 +150,15 @@ def __init__( # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. - # We set to `True` as we don't need to cleanup right away, and this will get set - # to `False` in super().__init__()` e.g. - # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1 self._shutdowned = True - node_world_size: int - node_rank: int - rank: int - world_size: int - local_rank: int - local_world_size: int - - master_ip_address: str - should_cleanup_distributed_context: bool = False - - if context: - assert ( - local_process_world_size is not None - ), "context: DistributedContext provided, so local_process_world_size must be provided." - assert ( - local_process_rank is not None - ), "context: DistributedContext provided, so local_process_rank must be provided." - - master_ip_address = context.main_worker_ip_address - node_world_size = context.global_world_size - node_rank = context.global_rank - local_world_size = local_process_world_size - local_rank = local_process_rank - - rank = node_rank * local_world_size + local_rank - world_size = node_world_size * local_world_size - - if not torch.distributed.is_initialized(): - logger.info( - "process group is not available, trying to torch.distributed.init_process_group to communicate necessary setup information." - ) - should_cleanup_distributed_context = True - logger.info( - f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}." - ) - torch.distributed.init_process_group( - backend="gloo", # We just default to gloo for this temporary process group - init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}", - rank=rank, - world_size=world_size, - ) - - else: - assert ( - torch.distributed.is_initialized() - ), f"context: DistributedContext is None, so process group must be initialized before constructing this object {self.__class__.__name__}." - world_size = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - - rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks() - master_ip_address = rank_ip_addresses[0] - - count_ranks_per_ip_address = Counter(rank_ip_addresses) - local_world_size = count_ranks_per_ip_address[master_ip_address] - for rank_ip_address, count in count_ranks_per_ip_address.items(): - if count != local_world_size: - raise ValueError( - f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}." - + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}" - ) - - node_world_size = len(count_ranks_per_ip_address) - local_rank = rank % local_world_size - node_rank = rank // local_world_size + # Resolve distributed context + runtime = BaseDistLoader.resolve_runtime( + context, local_process_rank, local_process_world_size + ) + del context, local_process_rank, local_process_world_size - del ( - context, - local_process_rank, - local_process_world_size, - ) # delete deprecated vars so we don't accidentally use them. + # Determine mode if isinstance(dataset, RemoteDistDataset): self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE else: @@ -241,37 +174,37 @@ def __init__( pin_memory_device if pin_memory_device else gigl.distributed.utils.get_available_device( - local_process_rank=local_rank + local_process_rank=runtime.local_rank ) ) - # Determines if the node ids passed in are heterogeneous or homogeneous. + # Mode-specific setup if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: assert isinstance( dataset, DistDataset ), "When using colocated mode, dataset must be a DistDataset." - input_data, worker_options, dataset_metadata = self._setup_for_colocated( + input_data, worker_options, dataset_schema = self._setup_for_colocated( input_nodes=input_nodes, dataset=dataset, - local_rank=local_rank, - local_world_size=local_world_size, + local_rank=runtime.local_rank, + local_world_size=runtime.local_world_size, device=device, - master_ip_address=master_ip_address, - node_rank=node_rank, - node_world_size=node_world_size, + master_ip_address=runtime.master_ip_address, + node_rank=runtime.node_rank, + node_world_size=runtime.node_world_size, num_workers=num_workers, worker_concurrency=worker_concurrency, channel_size=channel_size, num_cpu_threads=num_cpu_threads, ) - else: # Graph Store mode + else: assert isinstance( dataset, RemoteDistDataset ), "When using Graph Store mode, dataset must be a RemoteDistDataset." if prefetch_size is None: logger.info(f"prefetch_size is not provided, using default of 4") prefetch_size = 4 - input_data, worker_options, dataset_metadata = self._setup_for_graph_store( + input_data, worker_options, dataset_schema = self._setup_for_graph_store( input_nodes=input_nodes, dataset=dataset, num_workers=num_workers, @@ -279,65 +212,56 @@ def __init__( channel_size=channel_size, ) - self._is_homogeneous_with_labeled_edge_type = ( - dataset_metadata.is_homogeneous_with_labeled_edge_type - ) - self._node_feature_info = dataset_metadata.node_feature_info - self._edge_feature_info = dataset_metadata.edge_feature_info + # Cleanup temporary process group if needed + if ( + runtime.should_cleanup_distributed_context + and torch.distributed.is_initialized() + ): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + ) + torch.distributed.destroy_process_group() - logger.info(f"num_neighbors before patch: {num_neighbors}") - num_neighbors = patch_fanout_for_sampling( - edge_types=dataset_metadata.edge_types, - num_neighbors=num_neighbors, - ) - logger.info( - f"num_neighbors: {num_neighbors}, edge_types: {dataset_metadata.edge_types}" - ) - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, + # Create SamplingConfig (with patched fanout) + sampling_config = BaseDistLoader.create_sampling_config( num_neighbors=num_neighbors, + dataset_schema=dataset_schema, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset_metadata.edge_dir, - seed=None, # it's actually optional - None means random. ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): - logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." - ) - torch.distributed.destroy_process_group() - + # Build the sampler: a pre-constructed producer for colocated mode, + # or an RPC callable for graph store mode. if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: - # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s. - # The current hypothesis is making connections across machines require a lot of memory. - # If we start all data loaders in all processes simultaneously, the spike of memory - # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group - # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker. - logger.info( - f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds" - ) - time.sleep(process_start_gap_seconds * local_rank) - super().__init__( - dataset, # Pass in the dataset for colocated mode. + assert isinstance(dataset, DistDataset) + assert isinstance(worker_options, MpDistSamplingWorkerOptions) + channel = BaseDistLoader.create_colocated_channel(worker_options) + sampler: Union[ + DistMpSamplingProducer, Callable[..., int] + ] = DistMpSamplingProducer( + dataset, input_data, sampling_config, - device, worker_options, + channel, ) else: - self._init_graph_store_connections( - dataset=dataset, - input_data=input_data, - sampling_config=sampling_config, - device=device, - worker_options=worker_options, - ) + sampler = GiglDistServer.create_sampling_producer + + # Call base class — handles metadata storage and connection initialization + # (including staggered init for colocated mode). + super().__init__( + dataset=dataset, + sampler_input=input_data, + dataset_schema=dataset_schema, + worker_options=worker_options, + sampling_config=sampling_config, + device=device, + runtime=runtime, + sampler=sampler, + process_start_gap_seconds=process_start_gap_seconds, + ) def _setup_for_graph_store( self, @@ -353,7 +277,7 @@ def _setup_for_graph_store( num_workers: int, prefetch_size: int, channel_size: str, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetSchema]: + ) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -369,7 +293,6 @@ def _setup_for_graph_store( f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})" ) - is_homogeneous_with_labeled_edge_type = False node_feature_info = dataset.get_node_feature_info() edge_feature_info = dataset.get_edge_feature_info() edge_types = dataset.get_edge_types() @@ -417,7 +340,6 @@ def _setup_for_graph_store( if isinstance(edge_types, list): if DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_types: input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE - is_homogeneous_with_labeled_edge_type = True else: input_type = fallback_input_type elif require_edge_feature_info: @@ -427,11 +349,15 @@ def _setup_for_graph_store( else: input_type = None + is_homogeneous_with_labeled_edge_type = ( + input_type == DEFAULT_HOMOGENEOUS_NODE_TYPE + ) + # Convert from dict to list which is what the GLT DistNeighborLoader expects. servers = nodes.keys() if max(servers) >= dataset.cluster_info.num_storage_nodes or min(servers) < 0: raise ValueError( - f"When using Graph Store mode, the server ranks must be less than the number of storage nodes and greater than 0, received inputs for servers: {list(nodes.keys())}" + f"When using Graph Store mode, the server ranks must be in range [0, num_servers ({dataset.cluster_info.num_storage_nodes})), received inputs for servers: {list(nodes.keys())}" ) input_data: list[NodeSamplerInput] = [] for server_rank in range(dataset.cluster_info.num_storage_nodes): @@ -607,236 +533,6 @@ def _setup_for_colocated( ), ) - def _init_graph_store_connections( - self, - dataset: RemoteDistDataset, - input_data: list[NodeSamplerInput], - sampling_config: SamplingConfig, - device: torch.device, - worker_options: RemoteDistSamplingWorkerOptions, - ): - # Graph Store mode — initialize DistLoader attributes directly instead of - # calling super().__init__() to avoid the ThreadPoolExecutor deadlock at scale. - # - # GLT's DistLoader.__init__() dispatches create_sampling_producer RPCs via - # ThreadPoolExecutor(max_workers=32). With 60+ servers, only 32 threads run, - # causing a TensorPipe rendezvous deadlock. Instead, we inline the DistLoader - # init code and dispatch all RPCs asynchronously in a simple loop. - - node_rank = dataset.cluster_info.compute_node_rank - num_storage_nodes = dataset.cluster_info.num_storage_nodes - - # --- Set all DistLoader attributes (mirrors GLT DistLoader.__init__) --- - # These are required by inherited methods: shutdown(), __iter__(), __next__(), - # __del__(), _collate_fn(), _set_ntypes_and_etypes(). - self.data = None # No local data in Graph Store mode - self.input_data = input_data - self.sampling_type = sampling_config.sampling_type - self.num_neighbors = sampling_config.num_neighbors - self.batch_size = sampling_config.batch_size - self.shuffle = sampling_config.shuffle - self.drop_last = sampling_config.drop_last - self.with_edge = sampling_config.with_edge - self.with_weight = sampling_config.with_weight - self.collect_features = sampling_config.collect_features - self.edge_dir = sampling_config.edge_dir - self.sampling_config = sampling_config - self.to_device = device - self.worker_options = worker_options - self._shutdowned = False - - self._is_collocated_worker = False - self._is_mp_worker = False - self._is_remote_worker = True - - self._num_recv = 0 - self._epoch = 0 - - # Context validation - ctx = get_context() - if ctx is None: - raise RuntimeError( - f"'{self.__class__.__name__}': the distributed context " - f"has not been initialized." - ) - if not ctx.is_client(): - raise RuntimeError( - f"'{self.__class__.__name__}': must be used on a client " - f"worker process." - ) - - # Remote worker attributes - self._num_expected = float("inf") - self._with_channel = True - - self._server_rank_list: list[int] = ( - self.worker_options.server_rank - if isinstance(self.worker_options.server_rank, list) - else [self.worker_options.server_rank] - ) - self._input_data_list: list[NodeSamplerInput] = ( - self.input_data if isinstance(self.input_data, list) else [self.input_data] - ) - self._input_type = self._input_data_list[0].input_type - - # --- Barrier loop: one compute node at a time --- - logger.info( - f"node_rank {node_rank} starting barrier loop with " - f"{dataset.cluster_info.num_compute_nodes} compute nodes" - ) - flush() - # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. - # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. - # Note that each compute node may have multiple connections to each storage node, once per compute process. - # It's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). - # Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. - # E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. - # We need to this because if we don't, then there is a race condition when initalizing the samplers on the storage nodes [1] - # Where since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0, and compute node 1 - # Then we deadlock and fail. - # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. [2] - # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167 - # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193 - - # See below for a connection setup. - # ╔═══════════════════════════════════════════════════════════════════════════════════════╗ - # ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ - # ╚═══════════════════════════════════════════════════════════════════════════════════════╝ - - # COMPUTE NODES STORAGE NODES - # ═════════════ ═════════════ - - # ┌──────────────────────┐ (1) ┌───────────────┐ - # │ COMPUTE NODE 0 │ │ │ - # │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ - # │ │GPU │GPU │GPU │GPU │ ╱ │ │ - # │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ - # │ └────┴────┴────┴────┤ (2) ╲ ╱ - # └──────────────────────┘ ╲ ╱ - # ╳ - # (3) ╱ ╲ (4) - # ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ - # │ COMPUTE NODE 1 │ ╱ ╲ │ │ - # │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ - # │ │GPU │GPU │GPU │GPU │ │ │ - # │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ - # │ └────┴────┴────┴────┤ └───────────────┘ - # └──────────────────────┘ - - # ┌─────────────────────────────────────────────────────────────────────────────┐ - # │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ - # │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ - # │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ - # │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ - # └─────────────────────────────────────────────────────────────────────────────┘ - for target_node_rank in range(dataset.cluster_info.num_compute_nodes): - start_time = time.time() - if node_rank == target_node_rank: - # Step 1: Get dataset metadata via RPC (single call, fast) - ( - self.num_data_partitions, - self.data_partition_idx, - ntypes, - etypes, - ) = request_server( - self._server_rank_list[0], - GiglDistServer.get_dataset_meta, - ) - self._set_ntypes_and_etypes(ntypes, etypes) - - # Step 2: Move input data to CPU if needed - for i, inp in enumerate(self._input_data_list): - if not isinstance(inp, RemoteSamplerInput): - self._input_data_list[i] = inp.to(torch.device("cpu")) - - # Step 3: Dispatch ALL create_sampling_producer RPCs async. - # - # This is the key fix: async_request_server queues the RPC - # in TensorPipe and returns immediately. By dispatching all - # N RPCs in a loop BEFORE waiting for any response, all - # storage nodes receive the RPC and start their worker - # rendezvous simultaneously. No ThreadPoolExecutor needed. - logger.info( - f"node_rank={node_rank} dispatching " - f"create_sampling_producer to " - f"{num_storage_nodes} servers" - ) - flush() - t_dispatch = time.time() - rpc_futures: list[tuple[int, torch.futures.Future[int]]] = [] - for server_rank, inp_data in zip( - self._server_rank_list, self._input_data_list - ): - fut = async_request_server( - server_rank, - GiglDistServer.create_sampling_producer, - inp_data, - self.sampling_config, - self.worker_options, - ) - rpc_futures.append((server_rank, fut)) - logger.info( - f"node_rank={node_rank} all " - f"{len(rpc_futures)} RPCs dispatched in " - f"{time.time() - t_dispatch:.3f}s, " - f"waiting for responses" - ) - flush() - - # Step 4: Wait for all results - self._producer_id_list: list[int] = [] - for server_rank, fut in rpc_futures: - t_wait = time.time() - producer_id: int = fut.wait() - logger.info( - f"node_rank={node_rank} " - f"create_sampling_producer" - f"(server_rank={server_rank}) returned " - f"producer_id={producer_id} in " - f"{time.time() - t_wait:.2f}s" - ) - flush() - self._producer_id_list.append(producer_id) - logger.info( - f"node_rank={node_rank} all " - f"{len(self._producer_id_list)} producers created " - f"in {time.time() - t_dispatch:.2f}s total" - ) - flush() - - # Step 5: Create remote receiving channel - self._channel = RemoteReceivingChannel( - self._server_rank_list, - self._producer_id_list, - self.worker_options.prefetch_size, - ) - - logger.info( - f"node_rank {node_rank} initialized the dist loader in " - f"{time.time() - start_time:.2f}s" - ) - flush() - else: - logger.info( - f"node_rank {node_rank} waiting for barrier " - f"for rank {target_node_rank}" - ) - flush() - torch.distributed.barrier(device_ids=torch.device("cpu")) - logger.info( - f"node_rank {node_rank} barrier for rank " - f"{target_node_rank} in {time.time() - start_time:.2f}s" - ) - flush() - - torch.distributed.barrier(device_ids=torch.device("cpu")) - logger.info( - f"node_rank {node_rank}: all " - f"{dataset.cluster_info.num_compute_nodes} node ranks " - f"initialized the dist loader" - ) - flush() - def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = super()._collate_fn(msg) data = set_missing_features( diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index c38f1b63f..c2989d317 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -281,23 +281,17 @@ def get_node_ids( raise ValueError( f"Invalid split: {split}. Must be one of 'train', 'val', 'test', or None." ) + if node_type is not None: if not isinstance(nodes, abc.Mapping): raise ValueError( f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] (e.g. a heterogeneous dataset), got {type(nodes)}" ) - print(f"node types: {nodes.keys()}") nodes = nodes[node_type] elif not isinstance(nodes, torch.Tensor): - if nodes is not None and DEFAULT_HOMOGENEOUS_NODE_TYPE in nodes: - logger.info( - f"Received None node type but assuming it's a homogeneous dataset (node types: {nodes.keys()}) and returning the default node type." - ) - nodes = nodes[DEFAULT_HOMOGENEOUS_NODE_TYPE] - else: - raise ValueError( - f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." - ) + raise ValueError( + f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." + ) if rank is not None and world_size is not None: return shard_nodes_by_process(nodes, rank, world_size) @@ -314,6 +308,17 @@ def get_edge_types(self) -> Optional[list[EdgeType]]: else: return None + def get_node_types(self) -> Optional[list[NodeType]]: + """Get the node types from the dataset. + + Returns: + The node types in the dataset, None if the dataset is homogeneous. + """ + if isinstance(self.dataset.graph, dict): + return list(self.dataset.get_node_types()) + else: + return None + def get_ablp_input( self, split: Union[Literal["train", "val", "test"], str], diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py index f2d97e10e..c71aa204f 100644 --- a/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/gigl/distributed/graph_store/remote_dist_dataset.py @@ -116,6 +116,14 @@ def _get_node_ids( ) -> dict[int, torch.Tensor]: """Fetches node ids from the storage nodes for the current compute node (machine).""" futures: list[torch.futures.Future[torch.Tensor]] = [] + if node_type is None: + node_types = self.get_node_types() + if node_types is not None and DEFAULT_HOMOGENEOUS_NODE_TYPE in node_types: + node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + logger.info( + f"Using default node type {node_type} for homogeneous dataset with label edge types as {DEFAULT_HOMOGENEOUS_NODE_TYPE} is in the node types: {node_types}" + ) + logger.info( f"Getting node ids for rank {rank} / {world_size} with node type {node_type} and split {split}" ) @@ -505,3 +513,14 @@ def get_edge_types(self) -> Optional[list[EdgeType]]: 0, DistServer.get_edge_types, ) + + def get_node_types(self) -> Optional[list[NodeType]]: + """Get the node types from the registered dataset. + + Returns: + The node types in the dataset, None if the dataset is homogeneous. + """ + return request_server( + 0, + DistServer.get_node_types, + ) From 0efc55d96ad8f2fbec578ef42b999d9f04ac8333 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Feb 2026 17:48:38 +0000 Subject: [PATCH 10/30] maybe fix? --- .../graph_store/heterogeneous_training.py | 32 +++++++++++++++---- .../graph_store/homogeneous_training.py | 2 ++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 2ac7b790e..76992ada9 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -154,14 +154,17 @@ def _setup_dataloaders( """ rank = torch.distributed.get_rank() - query_node_type = supervision_edge_type.src_node_type - labeled_node_type = supervision_edge_type.dst_node_type - if dataset.get_edge_dir() == "in": + query_node_type = supervision_edge_type.dst_node_type + labeled_node_type = supervision_edge_type.src_node_type anchor_node_type = labeled_node_type else: + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type anchor_node_type = query_node_type + logger.info(f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.get_edge_dir()}") + shuffle = split == "train" # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. @@ -200,13 +203,13 @@ def _setup_dataloaders( all_node_ids = dataset.get_node_ids( rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, - node_type=anchor_node_type, + node_type=labeled_node_type, ) random_negative_loader = DistNeighborLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(anchor_node_type, all_node_ids), + input_nodes=(labeled_node_type, all_node_ids), num_workers=sampling_workers_per_process, batch_size=random_batch_size, pin_memory_device=device, @@ -233,6 +236,7 @@ def _compute_loss( random_negative_data: HeteroData, loss_fn: RetrievalLoss, supervision_edge_type: EdgeType, + edge_dir: str, device: torch.device, ) -> torch.Tensor: """ @@ -243,14 +247,20 @@ def _compute_loss( random_negative_data (HeteroData): The batch of data containing random negative nodes loss_fn (RetrievalLoss): Initialized class to use for loss calculation supervision_edge_type (EdgeType): The supervision edge type to use for training in format query_node -> relation -> labeled_node + edge_dir (str): Direction of the supervision edge device (torch.device): Device for training or validation Returns: torch.Tensor: Final loss for the current batch on the current process """ # Extract relevant node types from the supervision edge - query_node_type = supervision_edge_type.src_node_type - labeled_node_type = supervision_edge_type.dst_node_type + if edge_dir == "in": + query_node_type = supervision_edge_type.src_node_type + labeled_node_type = supervision_edge_type.dst_node_type + else: + query_node_type = supervision_edge_type.dst_node_type + labeled_node_type = supervision_edge_type.src_node_type + logger.info(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") if query_node_type == labeled_node_type: inference_node_types = [query_node_type] else: @@ -532,6 +542,7 @@ def _training_process( random_negative_data=random_data, loss_fn=loss_fn, supervision_edge_type=args.supervision_edge_type, + edge_dir=dataset.get_edge_dir(), device=device, ) optimizer.zero_grad() @@ -567,6 +578,7 @@ def _training_process( random_negative_loader=val_random_negative_loader_iter, loss_fn=loss_fn, supervision_edge_type=args.supervision_edge_type, + edge_dir=dataset.get_edge_dir(), device=device, log_every_n_batch=args.log_every_n_batch, num_batches=num_val_batches_per_process, @@ -646,6 +658,7 @@ def _training_process( random_negative_loader=test_random_negative_loader_iter, loss_fn=loss_fn, supervision_edge_type=args.supervision_edge_type, + edge_dir=dataset.get_edge_dir(), device=device, log_every_n_batch=args.log_every_n_batch, ) @@ -681,6 +694,7 @@ def _run_validation_loops( random_negative_loader: Iterator[HeteroData], loss_fn: RetrievalLoss, supervision_edge_type: EdgeType, + edge_dir: str, device: torch.device, log_every_n_batch: int, num_batches: Optional[int] = None, @@ -694,6 +708,7 @@ def _run_validation_loops( random_negative_loader (Iterator[HeteroData]): Dataloader for loading random negative data loss_fn (RetrievalLoss): Initialized class to use for loss calculation supervision_edge_type (EdgeType): The supervision edge type to use for training + edge_dir (Literal["in", "out"]): Direction of the supervision edge device (torch.device): Device to use for training or testing log_every_n_batch (int): The frequency we should log batch information num_batches (Optional[int]): The number of batches to run the validation loop for. @@ -719,11 +734,13 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: + logger.info(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") break try: main_data = next(main_loader) random_data = next(random_negative_loader) except StopIteration: + logger.info(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") break loss = _compute_loss( @@ -732,6 +749,7 @@ def _run_validation_loops( random_negative_data=random_data, loss_fn=loss_fn, supervision_edge_type=supervision_edge_type, + edge_dir=edge_dir, device=device, ) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index ff4938133..0b819eb1b 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -656,11 +656,13 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: + logger.info(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") break try: main_data = next(main_loader) random_data = next(random_negative_loader) except StopIteration: + logger.info(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") break loss = _compute_loss( From 47431e733da1e978698b5f6e58ce9b984c29a8ea Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Feb 2026 19:23:33 +0000 Subject: [PATCH 11/30] idk --- .../link_prediction/graph_store/heterogeneous_training.py | 2 +- .../link_prediction/graph_store/homogeneous_training.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 76992ada9..a823018c3 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -154,7 +154,7 @@ def _setup_dataloaders( """ rank = torch.distributed.get_rank() - if dataset.get_edge_dir() == "in": + if dataset.get_edge_dir() == "out": query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type anchor_node_type = labeled_node_type diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 0b819eb1b..02c7057ec 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -656,13 +656,15 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: - logger.info(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") + print(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") + flush() break try: main_data = next(main_loader) random_data = next(random_negative_loader) except StopIteration: - logger.info(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") + print(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") + flush() break loss = _compute_loss( @@ -678,7 +680,7 @@ def _run_validation_loops( batch_start = time.time() batch_idx += 1 if batch_idx % log_every_n_batch == 0: - logger.info(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + print(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") if torch.cuda.is_available(): torch.cuda.synchronize() logger.info( From ca03373cbfa558aad90c2abfabdc46ccde9d3873 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Feb 2026 19:23:54 +0000 Subject: [PATCH 12/30] only gs --- testing/e2e_tests/e2e_tests.yaml | 36 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index 0f3691e80..c6735ce27 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -1,24 +1,24 @@ # Combined e2e test configurations for GiGL # This file contains all the test specifications that can be run via the e2e test script tests: - cora_nalp_test: - task_config_uri: "gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_snc_test: - task_config_uri: "gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_udl_test: - task_config_uri: "gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - dblp_nalp_test: - task_config_uri: "gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - hom_cora_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" - het_dblp_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # cora_nalp_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_snc_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_udl_test: + # task_config_uri: "gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # dblp_nalp_test: + # task_config_uri: "gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # hom_cora_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # het_dblp_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From dd5ffb170044a36242b60af7e3c9d3c5ef2d5d16 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Feb 2026 22:16:48 +0000 Subject: [PATCH 13/30] debug --- .../graph_store/heterogeneous_training.py | 2 +- .../graph_store/homogeneous_training.py | 47 ++++++++++++++----- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index a823018c3..625aa1231 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -157,7 +157,7 @@ def _setup_dataloaders( if dataset.get_edge_dir() == "out": query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type - anchor_node_type = labeled_node_type + anchor_node_type = query_node_type else: query_node_type = supervision_edge_type.src_node_type labeled_node_type = supervision_edge_type.dst_node_type diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 02c7057ec..4ec030e33 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -159,6 +159,14 @@ def _setup_dataloaders( world_size=cluster_info.num_compute_nodes, ) + for storage_rank, ablp_nodes in ablp_input.items(): + print( + f"Rank {rank} split={split}: storage_rank={storage_rank}, " + f"num_anchors={ablp_nodes.anchor_nodes.shape}, " + f"labels: {ablp_nodes.labels}" + ) + flush() + main_loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, @@ -185,6 +193,13 @@ def _setup_dataloaders( world_size=cluster_info.num_compute_nodes, ) + for storage_rank, node_ids_tensor in all_node_ids.items(): + print( + f"Rank {rank} split={split}: random_negative storage_rank={storage_rank}, " + f"num_node_ids={node_ids_tensor.shape}" + ) + flush() + random_negative_loader = DistNeighborLoader( dataset=dataset, num_neighbors=num_neighbors, @@ -227,9 +242,9 @@ def _compute_loss( Returns: torch.Tensor: Final loss for the current batch on the current process """ - print(f"Computing loss for main data: {main_data}") - print(f"Computing loss for random negative data: {random_negative_data}") - print(f"Using model: {model}") + # print(f"Computing loss for main data: {main_data}") + # print(f"Computing loss for random negative data: {random_negative_data}") + # print(f"Using model: {model}") flush() # Forward pass through encoder main_embeddings = model(data=main_data, device=device) @@ -656,14 +671,19 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: - print(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") + print(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop with batch_idx={batch_idx} and num_batches={num_batches}") flush() break try: main_data = next(main_loader) + except StopIteration: + print(f"Rank {torch.distributed.get_rank()} MAIN loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}") + flush() + break + try: random_data = next(random_negative_loader) except StopIteration: - print(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") + print(f"Rank {torch.distributed.get_rank()} RANDOM NEGATIVE loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}") flush() break @@ -688,14 +708,18 @@ def _run_validation_loops( ) last_n_batch_time.clear() flush() - local_avg_loss = statistics.mean(batch_losses) - logger.info( - f"rank={rank} finished validation loop, local loss: {local_avg_loss=:.6f}" - ) + if batch_losses: + local_avg_loss = statistics.mean(batch_losses) + else: + print(f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0") + flush() + local_avg_loss = 0.0 + print(f"rank={rank} finished validation loop, num_batches_processed={len(batch_losses)}, local loss: {local_avg_loss:.6f}") + flush() global_avg_val_loss = _sync_metric_across_processes( metric=torch.tensor(local_avg_loss, device=device) ) - logger.info(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + print(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") flush() return @@ -719,7 +743,8 @@ def _run_example_training( f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" ) cluster_info = get_graph_store_info() - logger.info(f"Cluster info: {cluster_info}") + print(f"Cluster info: {cluster_info}") + flush() torch.distributed.destroy_process_group() logger.info( f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" From 7742a2658afae411783065183f7d3ee0870da44a Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 24 Feb 2026 23:45:06 +0000 Subject: [PATCH 14/30] maybe fix het? --- .../link_prediction/graph_store/heterogeneous_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 625aa1231..30e2f6d3a 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -154,7 +154,7 @@ def _setup_dataloaders( """ rank = torch.distributed.get_rank() - if dataset.get_edge_dir() == "out": + if dataset.get_edge_dir() == "in": query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type anchor_node_type = query_node_type @@ -163,7 +163,7 @@ def _setup_dataloaders( labeled_node_type = supervision_edge_type.dst_node_type anchor_node_type = query_node_type - logger.info(f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.get_edge_dir()}") + print(f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.get_edge_dir()}") shuffle = split == "train" From aa09e62a774c684dcdcffb0716d6650c730143cf Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 25 Feb 2026 01:27:00 +0000 Subject: [PATCH 15/30] swap back --- .../link_prediction/graph_store/heterogeneous_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 30e2f6d3a..178df1192 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -253,14 +253,14 @@ def _compute_loss( torch.Tensor: Final loss for the current batch on the current process """ # Extract relevant node types from the supervision edge - if edge_dir == "in": + if edge_dir == "out": query_node_type = supervision_edge_type.src_node_type labeled_node_type = supervision_edge_type.dst_node_type else: query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type - logger.info(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") + print(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") if query_node_type == labeled_node_type: inference_node_types = [query_node_type] else: From 88d1b9a274fd6d1e08b52c4eb1a104edadd34e55 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 25 Feb 2026 16:37:26 +0000 Subject: [PATCH 16/30] update --- .../graph_store/heterogeneous_training.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 178df1192..cf45cef40 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -169,7 +169,7 @@ def _setup_dataloaders( # In graph store mode, we fetch ABLP input (anchors + positive/negative labels) from the storage cluster. # This returns dict[server_rank, (anchors, pos_labels, neg_labels)] which the DistABLPLoader knows how to handle. - logger.info(f"---Rank {rank} fetching ABLP input for split={split}") + print(f"---Rank {rank} fetching ABLP input for split={split}") flush() ablp_input = dataset.get_ablp_input( split=split, @@ -192,7 +192,7 @@ def _setup_dataloaders( shuffle=shuffle, ) - logger.info(f"---Rank {rank} finished setting up main loader for split={split}") + print(f"---Rank {rank} finished setting up main loader for split={split}") flush() # We need to wait for all processes to finish initializing the main_loader before creating the @@ -219,7 +219,7 @@ def _setup_dataloaders( shuffle=shuffle, ) - logger.info( + print( f"---Rank {rank} finished setting up random negative loader for split={split}" ) flush() @@ -267,9 +267,9 @@ def _compute_loss( inference_node_types = [query_node_type, labeled_node_type] # Forward pass through encoder - print(f"Computing loss for main data: {main_data}") - print(f"Computing loss for random negative data: {random_negative_data}") - print(f"Using model: {model}") + # print(f"Computing loss for main data: {main_data}") + # print(f"Computing loss for random negative data: {random_negative_data}") + # print(f"Using model: {model}") flush() main_embeddings = model( data=main_data, output_node_types=inference_node_types, device=device @@ -425,7 +425,7 @@ def _training_process( # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster # and sets up torch.distributed with the appropriate backend (NCCL if CUDA available, gloo otherwise). - logger.info( + print( f"Initializing compute process for local_rank {local_rank} in machine {args.cluster_info.compute_node_rank}" ) flush() @@ -436,7 +436,7 @@ def _training_process( rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() - logger.info( + print( f"---Current training process rank: {rank}, training process world size: {world_size}" ) flush() @@ -445,7 +445,7 @@ def _training_process( device = get_available_device(local_process_rank=local_rank) if torch.cuda.is_available(): torch.cuda.set_device(device) - logger.info(f"---Rank {rank} training process set device {device}") + print(f"---Rank {rank} training process set device {device}") loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), @@ -504,7 +504,7 @@ def _training_process( lr=args.learning_rate, weight_decay=args.weight_decay, ) - logger.info( + print( f"Model initialized on rank {rank} training device {device}\n{model}" ) flush() @@ -520,7 +520,7 @@ def _training_process( last_n_batch_time: list[float] = [] num_max_train_batches_per_process = args.num_max_train_batches // world_size num_val_batches_per_process = args.num_val_batches // world_size - logger.info( + print( f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" ) @@ -531,7 +531,7 @@ def _training_process( train_main_loader_iter, train_random_negative_loader_iter ): if batch_idx >= num_max_train_batches_per_process: - logger.info( + print( f"num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " f"stopping training on machine {args.cluster_info.compute_node_rank} local rank {local_rank}" ) @@ -554,23 +554,23 @@ def _training_process( batch_start = time.time() batch_idx += 1 if batch_idx % args.log_every_n_batch == 0: - logger.info( + print( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) if torch.cuda.is_available(): torch.cuda.synchronize() - logger.info( + print( f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" ) last_n_batch_time.clear() - logger.info( + print( f"rank={rank}, latest avg_train_loss={avg_train_loss:.6f}, last {args.log_every_n_batch} mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}" ) last_n_batch_avg_loss.clear() flush() if batch_idx % args.val_every_n_batch == 0: - logger.info(f"rank={rank}, batch={batch_idx}, validating...") + print(f"rank={rank}, batch={batch_idx}, validating...") model.eval() _run_validation_loops( model=model, @@ -585,7 +585,7 @@ def _training_process( ) model.train() - logger.info(f"---Rank {rank} finished training") + print(f"---Rank {rank} finished training") flush() # Memory cleanup and waiting for all processes to finish @@ -602,7 +602,7 @@ def _training_process( # We save the model on the process with rank 0. if torch.distributed.get_rank() == 0: - logger.info( + print( f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, saving model to {args.model_uri}" ) save_state_dict( @@ -624,11 +624,11 @@ def _training_process( find_unused_encoder_parameters=True, state_dict=state_dict, ) - logger.info( + print( f"Model initialized on rank {rank} training device {device}\n{model}" ) - logger.info(f"---Rank {rank} started testing") + print(f"---Rank {rank} started testing") flush() testing_start_time = time.time() @@ -672,7 +672,7 @@ def _training_process( test_main_loader.shutdown() test_random_negative_loader.shutdown() - logger.info( + print( f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" ) flush() @@ -681,7 +681,7 @@ def _training_process( shutdown_compute_proccess() gc.collect() - logger.info( + print( f"---Rank {rank} finished all training and testing, shut down compute process" ) flush() @@ -716,7 +716,7 @@ def _run_validation_loops( rank = torch.distributed.get_rank() - logger.info( + print( f"Running validation loop on rank={rank}, log_every_n_batch={log_every_n_batch}, num_batches={num_batches}" ) if num_batches is None: @@ -734,13 +734,13 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: - logger.info(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") + print(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") break try: main_data = next(main_loader) random_data = next(random_negative_loader) except StopIteration: - logger.info(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") + print(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") break loss = _compute_loss( @@ -758,22 +758,22 @@ def _run_validation_loops( batch_start = time.time() batch_idx += 1 if batch_idx % log_every_n_batch == 0: - logger.info(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") + print(f"rank={rank}, batch={batch_idx}, latest test_loss={loss:.6f}") if torch.cuda.is_available(): torch.cuda.synchronize() - logger.info( + print( f"rank={rank}, batch={batch_idx}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, max(batch_time)={max(last_n_batch_time):.3f} sec, min(batch_time)={min(last_n_batch_time):.3f} sec" ) last_n_batch_time.clear() flush() local_avg_loss = statistics.mean(batch_losses) - logger.info( + print( f"rank={rank} finished validation loop, local loss: {local_avg_loss=:.6f}" ) global_avg_val_loss = _sync_metric_across_processes( metric=torch.tensor(local_avg_loss, device=device) ) - logger.info(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") + print(f"rank={rank} got global validation loss {global_avg_val_loss=:.6f}") flush() return @@ -789,17 +789,17 @@ def _run_example_training( """ program_start_time = time.time() mp.set_start_method("spawn") - logger.info(f"Starting sub process method: {mp.get_start_method()}") + print(f"Starting sub process method: {mp.get_start_method()}") # Step 1: Initialize global process group to get cluster info torch.distributed.init_process_group(backend="gloo") - logger.info( + print( f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" ) cluster_info = get_graph_store_info() - logger.info(f"Cluster info: {cluster_info}") + print(f"Cluster info: {cluster_info}") torch.distributed.destroy_process_group() - logger.info( + print( f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" ) flush() @@ -852,7 +852,7 @@ def _run_example_training( num_val_batches = int(trainer_args.get("num_val_batches", "100")) val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) - logger.info( + print( f"Got training args local_world_size={local_world_size}, \ num_neighbors={num_neighbors}, \ sampling_workers_per_process={sampling_workers_per_process}, \ @@ -906,7 +906,7 @@ def _run_example_training( mp_sharing_dict = mp.Manager().dict() # Step 5: Spawn training processes - logger.info("--- Launching training processes ...\n") + print("--- Launching training processes ...\n") flush() start_time = time.time() @@ -941,8 +941,8 @@ def _run_example_training( nprocs=local_world_size, join=True, ) - logger.info(f"--- Training finished, took {time.time() - start_time} seconds") - logger.info( + print(f"--- Training finished, took {time.time() - start_time} seconds") + print( f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" ) flush() @@ -955,7 +955,7 @@ def _run_example_training( parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") args, unused_args = parser.parse_known_args() - logger.info(f"Unused arguments: {unused_args}") + print(f"Unused arguments: {unused_args}") _run_example_training( task_config_uri=args.task_config_uri, From 9527c5574205bf1969618e2cb5dca3d7e85abd86 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 25 Feb 2026 18:09:34 +0000 Subject: [PATCH 17/30] debug --- .../link_prediction/graph_store/heterogeneous_training.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index cf45cef40..91c9b412e 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -117,6 +117,7 @@ def _sync_metric_across_processes(metric: torch.Tensor) -> float: assert is_distributed_available_and_initialized(), "DDP is not initialized" # Make a copy of the local loss tensor loss_tensor = metric.detach().clone() + print(f"---Rank {torch.distributed.get_rank()} loss tensor: {loss_tensor}") torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) return loss_tensor.item() / torch.distributed.get_world_size() @@ -260,7 +261,7 @@ def _compute_loss( query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type - print(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") + #print(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") if query_node_type == labeled_node_type: inference_node_types = [query_node_type] else: @@ -553,7 +554,7 @@ def _training_process( last_n_batch_time.append(time.time() - batch_start) batch_start = time.time() batch_idx += 1 - if batch_idx % args.log_every_n_batch == 0: + if batch_idx % args.log_every_n_batch == 0 or batch_idx < 10: # Log the first 10 batches to ensure the model is initialized correctly print( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) From a2034df309b56b399dd1c642c98d80c5c0609e86 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 25 Feb 2026 19:39:35 +0000 Subject: [PATCH 18/30] update --- .../graph_store/heterogeneous_training.py | 3 +- .../remote_dist_sampling_worker_options.md | 184 ++++++++++++++++++ gigl/utils/iterator.py | 1 + 3 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 gigl/distributed/graph_store/remote_dist_sampling_worker_options.md diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 91c9b412e..a1f6120da 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -585,7 +585,8 @@ def _training_process( num_batches=num_val_batches_per_process, ) model.train() - + else: + print(f"rank={rank} ended training early - no break condition was met") print(f"---Rank {rank} finished training") flush() diff --git a/gigl/distributed/graph_store/remote_dist_sampling_worker_options.md b/gigl/distributed/graph_store/remote_dist_sampling_worker_options.md new file mode 100644 index 000000000..f3d439501 --- /dev/null +++ b/gigl/distributed/graph_store/remote_dist_sampling_worker_options.md @@ -0,0 +1,184 @@ +# `RemoteDistSamplingWorkerOptions` Deep Dive + +## Class Definition + +Defined in the installed GLT package at: +`graphlearn_torch/distributed/dist_options.py` (lines 210-291) + +It extends `_BasicDistSamplingWorkerOptions` (lines 26-117) and is designed for **Graph Store (server-client) mode**, +where sampling workers run on remote storage servers and results are sent back to compute nodes. + +## All Fields/Knobs + +### Inherited from `_BasicDistSamplingWorkerOptions` (lines 26-117) + +| Field | Type | Default | Description | +|---|---|---|---| +| `num_workers` | `int` | `1` | Number of sampling worker subprocesses to launch on the server for this client | +| `worker_devices` | `list[torch.device] \| None` | `None` | Device assignment per worker; auto-assigned if `None` | +| `worker_concurrency` | `int` | `4` | Max concurrent seed batches each worker processes simultaneously (clamped to [1, 32]) | +| `master_addr` | `str` | env `MASTER_ADDR` | Master address for RPC init of the sampling worker group | +| `master_port` | `int` | env `MASTER_PORT` + 1 | Master port for RPC init of the sampling worker group | +| `num_rpc_threads` | `int \| None` | `None` | RPC threads per sampling worker; auto-set to `min(num_partitions, 16)` if `None` | +| `rpc_timeout` | `float` | `180` | Timeout (seconds) for all RPC requests during sampling | + +### Specific to `RemoteDistSamplingWorkerOptions` (lines 210-291) + +| Field | Type | Default | Description | +|---|---|---|---| +| `server_rank` | `int \| list[int]` | auto-assigned | Which storage server(s) to create sampling workers on | +| `buffer_size` | `int \| str` | `"{num_workers * 64}MB"` | Size of server-side shared-memory buffer for sampled messages | +| `buffer_capacity` | computed | `num_workers * worker_concurrency` | Max messages the server-side buffer can hold | +| `prefetch_size` | `int` | `4` | Max prefetched messages on the **client** side (must be <= `buffer_capacity`) | +| `worker_key` | `str \| None` | `None` | Deduplication key -- same key reuses existing producer on server | +| `use_all2all` | `bool` | `False` | Use all2all collective for feature collection instead of point-to-point RPC | +| `glt_graph` | any | `None` | GraphScope only (not used by GiGL) | +| `workload_type` | `str \| None` | `None` | GraphScope only (not used by GiGL) | + +## How Each Field Is Used & Client vs. Server + +### `server_rank` -- CLIENT side + +The client reads this to know which servers to talk to: + +- `dist_loader.py:170-171` -- expanded to `_server_rank_list` +- `dist_loader.py:178` -- `request_server(self._server_rank_list[0], DistServer.get_dataset_meta)` to fetch metadata +- `dist_loader.py:188` -- loops over servers calling `DistServer.create_sampling_producer` +- `dist_loader.py:194` -- passed to `RemoteReceivingChannel` for receiving results +- `dist_loader.py:305-306` -- `request_server(server_rank, DistServer.start_new_epoch_sampling, ...)` + +### `num_workers` -- SERVER side + +Serialized and sent to the server via RPC. On the server: + +- `dist_sampling_producer.py:184` -- `self.num_workers = self.worker_options.num_workers` +- `dist_sampling_producer.py:208` -- spawns that many subprocesses via `mp_context.Process(...)` +- Also drives the defaults for `buffer_size` (line 281) and `buffer_capacity` (line 279) + +### `worker_devices` -- SERVER side + +- Auto-assigned via `_assign_worker_devices()` (`dist_options.py:113-116`) if `None` +- Used in `_sampling_worker_loop` (line 87): `current_device = worker_options.worker_devices[rank]` + +### `worker_concurrency` -- SERVER side + +Controls async parallelism within each sampling worker: + +- `_sampling_worker_loop:106` -- passed to `DistNeighborSampler(..., concurrency=worker_options.worker_concurrency)` +- In `ConcurrentEventLoop` (`event_loop.py:47`): creates a `BoundedSemaphore(concurrency)` limiting concurrent seed batches +- Also drives `buffer_capacity = num_workers * worker_concurrency` (line 279) + +### `master_addr` / `master_port` -- SERVER side + +Used by sampling worker subprocesses to form their own RPC group for cross-partition sampling: + +- `_sampling_worker_loop:93-98` -- `init_rpc(master_addr=..., master_port=..., ...)` + +### `num_rpc_threads` -- SERVER side + +- `_sampling_worker_loop:82-85` -- if `None`, auto-set to `min(data.num_partitions, 16)` +- Line 91: `torch.set_num_threads(num_rpc_threads + 1)` +- Line 93: passed to `init_rpc(...)` which sets `TensorPipeRpcBackendOptions.num_worker_threads` + +### `rpc_timeout` -- SERVER side + +- `_sampling_worker_loop:97` -- `init_rpc(..., rpc_timeout=...)` +- Sets the timeout for RPCs made by sampling workers when fetching graph partitions from other servers + +### `buffer_size` -- SERVER side + +- GLT `dist_server.py:158`: `ShmChannel(worker_options.buffer_capacity, worker_options.buffer_size)` +- GiGL `dist_server.py:456-457`: same usage +- Controls the total bytes of shared memory allocated for the message queue + +### `buffer_capacity` -- SERVER side + +- Computed as `num_workers * worker_concurrency` (`dist_options.py:279`) +- Passed as the first arg to `ShmChannel(capacity, size)` -- max messages before producers block + +### `prefetch_size` -- CLIENT side + +- `dist_loader.py:196` -- `RemoteReceivingChannel(..., prefetch_size)` +- In `remote_channel.py:47`: `self.prefetch_size = prefetch_size` +- Line 56: `queue.Queue(maxsize=self.prefetch_size * len(self.server_rank_list))` +- Lines 120-131: controls how many async RPC fetch requests are in-flight per server at any time + +### `worker_key` -- SERVER side (during producer creation) + +- GLT `dist_server.py:152`: `producer_id = self._worker_key2producer_id.get(worker_options.worker_key)` -- if already exists, reuses the producer +- GiGL `dist_server.py:444-453`: same pattern with per-producer locks + +### `use_all2all` -- SERVER side + +- `_sampling_worker_loop:73-80` -- if True, initializes `torch.distributed` process group (gloo backend) +- `dist_neighbor_sampler.py:749-753` -- switches from per-type `async_get()` to `get_all2all()` for feature collection + +## Client vs. Server Summary + +| Field | Side | Purpose | +|---|---|---| +| `server_rank` | **Client** | Which servers to send RPCs to | +| `num_workers` | **Server** | Sampling subprocesses per server | +| `worker_devices` | **Server** | Device per subprocess | +| `worker_concurrency` | **Server** | Concurrent batches per subprocess | +| `master_addr` / `master_port` | **Server** | RPC group for cross-partition sampling | +| `num_rpc_threads` | **Server** | RPC threads per sampling subprocess | +| `rpc_timeout` | **Server** | Timeout for cross-partition RPCs | +| `buffer_size` | **Server** | Shared-memory buffer bytes | +| `buffer_capacity` | **Server** | Shared-memory buffer message count | +| `prefetch_size` | **Client** | Prefetched messages per server | +| `worker_key` | **Server** | Producer deduplication | +| `use_all2all` | **Server** | Collective vs point-to-point features | + +The entire options object is **serialized and sent via RPC** from client to server (at `dist_loader.py:188` via +`DistServer.create_sampling_producer`). The server reads the server-side fields; the client reads `server_rank` and +`prefetch_size` locally. + +## How GiGL Uses It + +### `DistNeighborLoader._setup_for_graph_store()` + +**File:** `gigl/distributed/distributed_neighborloader.py:386-395` + +```python +worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for i in range(num_workers)], + master_addr=dataset.cluster_info.storage_cluster_master_ip, + buffer_size=channel_size, # defaults to "4GB" + master_port=sampling_port, + worker_key=worker_key, # unique per compute rank + loader instance + prefetch_size=prefetch_size, # default 4 +) +``` + +GiGL talks to **all** storage servers (`server_rank=list(range(num_storage_nodes))`), always uses **CPU** sampling, and +assigns a unique `worker_key` per compute rank + loader instance (`distributed_neighborloader.py:384`). + +Notably, GiGL **bypasses GLT's `DistLoader.__init__`** in `_init_graph_store_connections()` (lines 609-837), +dispatching `create_sampling_producer` RPCs sequentially per compute node to avoid GLT's `ThreadPoolExecutor` deadlock +at large scale. + +### `DistABLPLoader._setup_for_graph_store()` + +**File:** `gigl/distributed/dist_ablp_neighborloader.py:799-808` + +```python +worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for _ in range(num_workers)], + worker_concurrency=worker_concurrency, + master_addr=dataset.cluster_info.storage_cluster_master_ip, + master_port=sampling_port, + worker_key=worker_key, + prefetch_size=prefetch_size, +) +``` + +Nearly identical, except: + +- Explicitly passes `worker_concurrency` (default `4`) +- Does **not** set `buffer_size` (uses GLT default of `num_workers * 64 MB` instead of GiGL's `4GB`) +- Uses `ThreadPoolExecutor` for setup (lines 1002-1015) rather than the sequential barrier approach diff --git a/gigl/utils/iterator.py b/gigl/utils/iterator.py index 63f809083..11563d7b8 100644 --- a/gigl/utils/iterator.py +++ b/gigl/utils/iterator.py @@ -20,5 +20,6 @@ def __next__(self) -> _T: try: return next(self._iter) except StopIteration: + print(f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator") self._iter = iter(self._iterable) return next(self._iter) From 97b2e946a8d5a61d57e7d9bc4d910a16019e7ee5 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 25 Feb 2026 20:54:59 +0000 Subject: [PATCH 19/30] test fixes --- .../graph_store/heterogeneous_training.py | 10 ++++++---- testing/e2e_tests/e2e_tests.yaml | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index a1f6120da..1be1871a5 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -768,10 +768,12 @@ def _run_validation_loops( ) last_n_batch_time.clear() flush() - local_avg_loss = statistics.mean(batch_losses) - print( - f"rank={rank} finished validation loop, local loss: {local_avg_loss=:.6f}" - ) + if len(batch_losses) == 0: + print(f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0") + flush() + local_avg_loss = 0.0 + else: + local_avg_loss = statistics.mean(batch_losses) global_avg_val_loss = _sync_metric_across_processes( metric=torch.tensor(local_avg_loss, device=device) ) diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index c6735ce27..d87465ed4 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -19,9 +19,9 @@ tests: # het_dblp_sup_test: # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" - hom_cora_sup_gs_test: - task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + # hom_cora_sup_gs_test: + # task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 19207c82d0c6b2d9ff1cc96d2cbf4160e50c8592 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 25 Feb 2026 23:49:53 +0000 Subject: [PATCH 20/30] debug --- .../graph_store/heterogeneous_training.py | 14 +- .../graph_store/homogeneous_training.py | 141 ++++++++++++------ 2 files changed, 101 insertions(+), 54 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 1be1871a5..d6839d13d 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -155,7 +155,7 @@ def _setup_dataloaders( """ rank = torch.distributed.get_rank() - if dataset.get_edge_dir() == "in": + if dataset.fetch_edge_dir() == "in": query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type anchor_node_type = query_node_type @@ -164,7 +164,7 @@ def _setup_dataloaders( labeled_node_type = supervision_edge_type.dst_node_type anchor_node_type = query_node_type - print(f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.get_edge_dir()}") + print(f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.fetch_edge_dir()}") shuffle = split == "train" @@ -172,7 +172,7 @@ def _setup_dataloaders( # This returns dict[server_rank, (anchors, pos_labels, neg_labels)] which the DistABLPLoader knows how to handle. print(f"---Rank {rank} fetching ABLP input for split={split}") flush() - ablp_input = dataset.get_ablp_input( + ablp_input = dataset.fetch_ablp_input( split=split, rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, @@ -201,7 +201,7 @@ def _setup_dataloaders( torch.distributed.barrier() # For the random negative loader, we get all node IDs of the labeled node type from the storage cluster. - all_node_ids = dataset.get_node_ids( + all_node_ids = dataset.fetch_node_ids( rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, node_type=labeled_node_type, @@ -543,7 +543,7 @@ def _training_process( random_negative_data=random_data, loss_fn=loss_fn, supervision_edge_type=args.supervision_edge_type, - edge_dir=dataset.get_edge_dir(), + edge_dir=dataset.fetch_edge_dir(), device=device, ) optimizer.zero_grad() @@ -579,7 +579,7 @@ def _training_process( random_negative_loader=val_random_negative_loader_iter, loss_fn=loss_fn, supervision_edge_type=args.supervision_edge_type, - edge_dir=dataset.get_edge_dir(), + edge_dir=dataset.fetch_edge_dir(), device=device, log_every_n_batch=args.log_every_n_batch, num_batches=num_val_batches_per_process, @@ -660,7 +660,7 @@ def _training_process( random_negative_loader=test_random_negative_loader_iter, loss_fn=loss_fn, supervision_edge_type=args.supervision_edge_type, - edge_dir=dataset.get_edge_dir(), + edge_dir=dataset.fetch_edge_dir(), device=device, log_every_n_batch=args.log_every_n_batch, ) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 4ec030e33..9c7ab1a5f 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -1,58 +1,105 @@ """ -This file contains an example for how to run homogeneous training in **graph store mode** using GiGL. +This file contains an example for how to run homogeneous link prediction training in +**graph store mode** using GiGL. -Graph Store Mode vs Standard Mode: ----------------------------------- +Graph Store Mode vs Standard (Colocated) Mode +---------------------------------------------- Graph store mode uses a heterogeneous cluster architecture with two distinct sub-clusters: + 1. **Storage Cluster (graph_store_pool)**: Dedicated machines for storing and serving the graph data. These are typically high-memory machines without GPUs (e.g., n2-highmem-32). 2. **Compute Cluster (compute_pool)**: Dedicated machines for running model training. These typically have GPUs attached (e.g., n1-standard-16 with NVIDIA_TESLA_T4). -This separation allows for: - - Independent scaling of storage and compute resources - - Better memory utilization (graph data stays on storage nodes) - - Cost optimization by using appropriate hardware for each role - -In contrast, the standard training mode (see `examples/link_prediction/homogeneous_training.py`) -uses a homogeneous cluster where each machine handles both graph storage and computation. - -Key Implementation Differences: -------------------------------- -This file (graph store mode): - - Uses `RemoteDistDataset` to connect to a remote graph store cluster - - Uses `init_compute_process` to initialize the compute node connection to storage - - Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo` - - Uses `mp_sharing_dict` for efficient tensor sharing between local processes - - Fetches ABLP input via `RemoteDistDataset.get_ablp_input()` for the train/val/test splits - - Fetches random negative node IDs via `RemoteDistDataset.get_node_ids()` - -Standard mode (`homogeneous_training.py`): - - Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition - - Manually manages distributed process groups with master IP and port - - Each machine stores its own partition of the graph data - -To run this file with GiGL orchestration, set the fields similar to below: - -trainerConfig: - trainerArgs: - log_every_n_batch: "50" - num_neighbors: "[10, 10]" - command: python -m examples.link_prediction.graph_store.homogeneous_training - graphStoreStorageConfig: - command: python -m examples.link_prediction.graph_store.storage_main - storageArgs: - sample_edge_direction: "in" - splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" - splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": true, "num_val": 0.1, "num_test": 0.1}' - num_server_sessions: "1" -featureFlags: - should_run_glt_backend: 'True' - -Note: Ensure you use a resource config with `vertex_ai_graph_store_trainer_config` when +This separation allows for independent scaling of storage and compute resources, better memory +utilization (graph data stays on storage nodes), and cost optimization by using appropriate +hardware for each role. + +In contrast, the standard colocated training mode +(see ``examples/link_prediction/homogeneous_training.py``) uses a homogeneous cluster where each +machine handles both graph storage and computation. + +Key Implementation Differences +------------------------------ + ++---------------------------+----------------------------------------------+----------------------------------------------+ +| Aspect | Standard (``homogeneous_training.py``) | Graph Store (this file) | ++===========================+==============================================+==============================================+ +| **Dataset class** | ``DistDataset`` (local partition) | ``RemoteDistDataset`` (RPC to storage) | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Dataset loading** | ``build_dataset_from_task_config_uri()`` | Storage nodes build data; compute nodes | +| | loads and partitions data locally | connect via ``init_compute_process()`` | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Process group init** | Manual ``init_process_group`` with master | ``init_process_group(gloo)`` to | +| | IP/port, ``destroy_process_group``, then | ``get_graph_store_info()``, then | +| | re-init in spawned processes | ``destroy_process_group``; spawned processes | +| | | call ``init_compute_process()`` | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Split/label access** | ``dataset.train_node_ids`` / | ``dataset.fetch_ablp_input(split=...)`` | +| | ``dataset.val_node_ids`` / | fetches anchors + labels from storage via | +| | ``dataset.test_node_ids`` via | RPC | +| | ``to_homogeneous()`` | | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Random negative nodes** | ``dataset.node_ids`` via | ``dataset.fetch_node_ids()`` fetches from | +| | ``to_homogeneous()`` | storage via RPC | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Cluster info** | ``machine_rank``, ``machine_world_size``, | ``GraphStoreInfo`` dataclass from | +| | ``master_ip_address`` extracted manually | ``get_graph_store_info()`` encapsulates all | +| | | topology | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Inter-process sharing** | N/A (each process loads own partition) | ``mp_sharing_dict`` for efficient tensor | +| | | sharing between local processes | ++---------------------------+----------------------------------------------+----------------------------------------------+ +| **Cleanup** | ``torch.distributed.destroy_process_group()`` | ``shutdown_compute_proccess()`` disconnects | +| | | from storage cluster | ++---------------------------+----------------------------------------------+----------------------------------------------+ + +Data Splitting and Storage Pipeline +------------------------------------ +Before training begins, the **storage cluster** prepares the graph data including train/val/test +splits. The flow is: + +1. **Splitter configuration**: The ``splitter_cls_path`` and ``splitter_kwargs`` are specified in + the YAML config under ``graphStoreStorageConfig.storageArgs``. The storage entry point + (``storage_main.py``) parses these via ``argparse`` and dynamically imports the splitter class + using ``import_obj()``. The kwargs string is evaluated with ``ast.literal_eval`` and passed to + the splitter constructor (e.g. ``DistNodeAnchorLinkSplitter(**splitter_kwargs)``). + +2. **ABLP input fetching** (at training time): ``RemoteDistDataset.fetch_ablp_input(split=...)`` + issues an RPC to the storage cluster and returns a ``dict[int, ABLPInputNodes]`` keyed by + storage rank. Each ``ABLPInputNodes`` contains ``anchor_nodes``, ``positive_labels``, and + optional ``negative_labels`` tensors for the requested split. + +3. **Node ID fetching**: ``RemoteDistDataset.fetch_node_ids()`` similarly fetches all node IDs + from storage, used for the random negative sampling loader. + +Because the storage cluster owns the split, compute nodes see train/val/test as first-class +properties of the remote dataset. + +Config Example +-------------- +To run this file with GiGL orchestration, set the fields similar to below:: + + trainerConfig: + trainerArgs: + log_every_n_batch: "50" + num_neighbors: "[10, 10]" + command: python -m examples.link_prediction.graph_store.homogeneous_training + graphStoreStorageConfig: + command: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": true, "num_val": 0.1, "num_test": 0.1}' + num_server_sessions: "1" + featureFlags: + should_run_glt_backend: 'True' + +Note: Ensure you use a resource config with ``vertex_ai_graph_store_trainer_config`` when running in graph store mode. -You can run this example in a full pipeline with `make run_hom_cora_sup_gs_e2e_test` from GiGL root. +You can run this example in a full pipeline with ``make run_hom_cora_sup_gs_e2e_test`` from +GiGL root. """ import argparse @@ -153,7 +200,7 @@ def _setup_dataloaders( # For homogeneous graphs, no node type or supervision edge type wrapper is needed. logger.info(f"---Rank {rank} fetching ABLP input for split={split}") flush() - ablp_input = dataset.get_ablp_input( + ablp_input = dataset.fetch_ablp_input( split=split, rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, @@ -188,7 +235,7 @@ def _setup_dataloaders( torch.distributed.barrier() # For the random negative loader, we get all node IDs from the storage cluster. - all_node_ids = dataset.get_node_ids( + all_node_ids = dataset.fetch_node_ids( rank=cluster_info.compute_node_rank, world_size=cluster_info.num_compute_nodes, ) From 520bbee76ce536d2d59660b78d314ac8d4175db1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 26 Feb 2026 16:37:49 +0000 Subject: [PATCH 21/30] idk --- gigl/distributed/base_dist_loader.py | 6 ++++++ gigl/utils/iterator.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d4ae3e452..483f6cb4a 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -564,12 +564,14 @@ def shutdown(self) -> None: # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: + logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__") self._num_recv = 0 if self._is_collocated_worker: self._collocated_producer.reset() elif self._is_mp_worker: self._mp_producer.produce_all() else: + logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls") rpc_futures: list[torch.futures.Future[None]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list @@ -581,7 +583,11 @@ def __iter__(self) -> Self: self._epoch, ) rpc_futures.append(fut) + logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls") torch.futures.wait_all(rpc_futures) + logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done") self._channel.reset() + logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset") + logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch incremented to {self._epoch}") self._epoch += 1 return self diff --git a/gigl/utils/iterator.py b/gigl/utils/iterator.py index 11563d7b8..aaad1812d 100644 --- a/gigl/utils/iterator.py +++ b/gigl/utils/iterator.py @@ -1,6 +1,8 @@ from collections.abc import Iterable, Iterator from typing import TypeVar +import torch + _T = TypeVar("_T") @@ -20,6 +22,9 @@ def __next__(self) -> _T: try: return next(self._iter) except StopIteration: - print(f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator") + if torch.distributed.is_initialized(): + print(f"rank={torch.distributed.get_rank()}: InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator") + else: + print(f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator") self._iter = iter(self._iterable) return next(self._iter) From e589894bbd1206cfb41b9da263d11e708990c112 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 26 Feb 2026 18:17:08 +0000 Subject: [PATCH 22/30] debug --- gigl/distributed/base_dist_loader.py | 12 ++++++------ gigl/distributed/graph_store/dist_server.py | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index 483f6cb4a..fb430c6ea 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -564,14 +564,14 @@ def shutdown(self) -> None: # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: - logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__") + print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__") self._num_recv = 0 if self._is_collocated_worker: self._collocated_producer.reset() elif self._is_mp_worker: self._mp_producer.produce_all() else: - logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls") + print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls") rpc_futures: list[torch.futures.Future[None]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list @@ -583,11 +583,11 @@ def __iter__(self) -> Self: self._epoch, ) rpc_futures.append(fut) - logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls") + print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls") torch.futures.wait_all(rpc_futures) - logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done") + print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done") self._channel.reset() - logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset") - logger.info(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch incremented to {self._epoch}") + print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset") + print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch incremented to {self._epoch}") self._epoch += 1 return self diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 1506432b2..23ae3c32f 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -555,13 +555,19 @@ def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> None: r"""Start a new epoch sampling tasks for a specific sampling producer with its producer id. """ + logger.info(f"DistServer.start_new_epoch_sampling: producer_id={producer_id}, epoch={epoch}") with self._producer_lock[producer_id]: cur_epoch = self._epoch[producer_id] if cur_epoch < epoch: self._epoch[producer_id] = epoch producer = self._producer_pool.get(producer_id, None) if producer is not None: + logger.info(f"DistServer.start_new_epoch_sampling: producing all for producer {producer_id}") producer.produce_all() + else: + logger.warning(f"DistServer.start_new_epoch_sampling: producer {producer_id} not found") + else: + logger.info(f"DistServer.start_new_epoch_sampling: producer {producer_id} already on epoch {cur_epoch}, skipping") def fetch_one_sampled_message( self, producer_id: int From 5671c5934099805628d9b8f6eabad6515dd8bead Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 26 Feb 2026 23:58:13 +0000 Subject: [PATCH 23/30] blehg --- .../graph_store/heterogeneous_training.py | 34 +++++++++++-------- .../graph_store/homogeneous_training.py | 20 ++++++++--- gigl/distributed/base_dist_loader.py | 28 ++++++++++----- gigl/distributed/graph_store/dist_server.py | 24 ++++++++++--- gigl/utils/iterator.py | 8 +++-- 5 files changed, 79 insertions(+), 35 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index d6839d13d..15854ee5d 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -164,7 +164,9 @@ def _setup_dataloaders( labeled_node_type = supervision_edge_type.dst_node_type anchor_node_type = query_node_type - print(f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.fetch_edge_dir()}") + print( + f"---Rank {rank} query node type: {query_node_type}, labeled node type: {labeled_node_type}, anchor node type: {anchor_node_type} due to edge direction {dataset.fetch_edge_dir()}" + ) shuffle = split == "train" @@ -261,7 +263,7 @@ def _compute_loss( query_node_type = supervision_edge_type.dst_node_type labeled_node_type = supervision_edge_type.src_node_type - #print(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") + # print(f"---Rank {torch.distributed.get_rank()} query node type: {query_node_type}, labeled node type: {labeled_node_type} due to edge direction {edge_dir}") if query_node_type == labeled_node_type: inference_node_types = [query_node_type] else: @@ -505,9 +507,7 @@ def _training_process( lr=args.learning_rate, weight_decay=args.weight_decay, ) - print( - f"Model initialized on rank {rank} training device {device}\n{model}" - ) + print(f"Model initialized on rank {rank} training device {device}\n{model}") flush() # We add a barrier to wait for all processes to finish preparing the dataloader and initializing the model @@ -554,7 +554,9 @@ def _training_process( last_n_batch_time.append(time.time() - batch_start) batch_start = time.time() batch_idx += 1 - if batch_idx % args.log_every_n_batch == 0 or batch_idx < 10: # Log the first 10 batches to ensure the model is initialized correctly + if ( + batch_idx % args.log_every_n_batch == 0 or batch_idx < 10 + ): # Log the first 10 batches to ensure the model is initialized correctly print( f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}" ) @@ -626,9 +628,7 @@ def _training_process( find_unused_encoder_parameters=True, state_dict=state_dict, ) - print( - f"Model initialized on rank {rank} training device {device}\n{model}" - ) + print(f"Model initialized on rank {rank} training device {device}\n{model}") print(f"---Rank {rank} started testing") flush() @@ -736,13 +736,17 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: - print(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop") + print( + f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop" + ) break try: main_data = next(main_loader) random_data = next(random_negative_loader) except StopIteration: - print(f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop") + print( + f"Rank {torch.distributed.get_rank()} test data loader exhausted, stopping validation loop" + ) break loss = _compute_loss( @@ -769,7 +773,9 @@ def _run_validation_loops( last_n_batch_time.clear() flush() if len(batch_losses) == 0: - print(f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0") + print( + f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0" + ) flush() local_avg_loss = 0.0 else: @@ -803,9 +809,7 @@ def _run_example_training( cluster_info = get_graph_store_info() print(f"Cluster info: {cluster_info}") torch.distributed.destroy_process_group() - print( - f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" - ) + print(f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool") flush() # Step 2: Read config diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 9c7ab1a5f..816fb1edc 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -718,19 +718,25 @@ def _run_validation_loops( while True: if num_batches and batch_idx >= num_batches: - print(f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop with batch_idx={batch_idx} and num_batches={num_batches}") + print( + f"Rank {torch.distributed.get_rank()} num_batches={num_batches} reached, stopping validation loop with batch_idx={batch_idx} and num_batches={num_batches}" + ) flush() break try: main_data = next(main_loader) except StopIteration: - print(f"Rank {torch.distributed.get_rank()} MAIN loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}") + print( + f"Rank {torch.distributed.get_rank()} MAIN loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}" + ) flush() break try: random_data = next(random_negative_loader) except StopIteration: - print(f"Rank {torch.distributed.get_rank()} RANDOM NEGATIVE loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}") + print( + f"Rank {torch.distributed.get_rank()} RANDOM NEGATIVE loader exhausted at batch_idx={batch_idx}, num_batches={num_batches}" + ) flush() break @@ -758,10 +764,14 @@ def _run_validation_loops( if batch_losses: local_avg_loss = statistics.mean(batch_losses) else: - print(f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0") + print( + f"rank={rank} WARNING: 0 batches processed in validation loop, setting local loss to 0.0" + ) flush() local_avg_loss = 0.0 - print(f"rank={rank} finished validation loop, num_batches_processed={len(batch_losses)}, local loss: {local_avg_loss:.6f}") + print( + f"rank={rank} finished validation loop, num_batches_processed={len(batch_losses)}, local loss: {local_avg_loss:.6f}" + ) flush() global_avg_val_loss = _sync_metric_across_processes( metric=torch.tensor(local_avg_loss, device=device) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index fb430c6ea..e5b7b4c4b 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -568,11 +568,15 @@ def __iter__(self) -> Self: self._num_recv = 0 if self._is_collocated_worker: self._collocated_producer.reset() + self._epoch += 1 elif self._is_mp_worker: self._mp_producer.produce_all() + self._epoch += 1 else: - print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls") - rpc_futures: list[torch.futures.Future[None]] = [] + print( + f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls" + ) + rpc_futures: list[torch.futures.Future[int]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list ): @@ -583,11 +587,19 @@ def __iter__(self) -> Self: self._epoch, ) rpc_futures.append(fut) - print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls") - torch.futures.wait_all(rpc_futures) - print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done") + print( + f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls" + ) + server_epochs: list[int] = torch.futures.wait_all(rpc_futures) + print( + f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done, server_epochs={server_epochs}" + ) self._channel.reset() - print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset") - print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch incremented to {self._epoch}") - self._epoch += 1 + print( + f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset" + ) + self._epoch = max(server_epochs) + 1 + print( + f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch={self._epoch}" + ) return self diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 23ae3c32f..fa2faf7dd 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -551,23 +551,37 @@ def destroy_sampling_producer(self, producer_id: int) -> None: self._msg_buffer_pool.pop(producer_id) self._epoch.pop(producer_id) - def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> None: + def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> int: r"""Start a new epoch sampling tasks for a specific sampling producer with its producer id. + + Returns: + The server's current epoch for this producer after processing the + request. Clients should use this to synchronize their local epoch + counter, since multiple clients may share the same producer. """ - logger.info(f"DistServer.start_new_epoch_sampling: producer_id={producer_id}, epoch={epoch}") + logger.info( + f"DistServer.start_new_epoch_sampling: producer_id={producer_id}, epoch={epoch}" + ) with self._producer_lock[producer_id]: cur_epoch = self._epoch[producer_id] if cur_epoch < epoch: self._epoch[producer_id] = epoch producer = self._producer_pool.get(producer_id, None) if producer is not None: - logger.info(f"DistServer.start_new_epoch_sampling: producing all for producer {producer_id}") + logger.info( + f"DistServer.start_new_epoch_sampling: producing all for producer {producer_id}" + ) producer.produce_all() else: - logger.warning(f"DistServer.start_new_epoch_sampling: producer {producer_id} not found") + logger.warning( + f"DistServer.start_new_epoch_sampling: producer {producer_id} not found" + ) else: - logger.info(f"DistServer.start_new_epoch_sampling: producer {producer_id} already on epoch {cur_epoch}, skipping") + logger.info( + f"DistServer.start_new_epoch_sampling: producer {producer_id} already on epoch {cur_epoch}, skipping" + ) + return self._epoch[producer_id] def fetch_one_sampled_message( self, producer_id: int diff --git a/gigl/utils/iterator.py b/gigl/utils/iterator.py index aaad1812d..219cf4fdb 100644 --- a/gigl/utils/iterator.py +++ b/gigl/utils/iterator.py @@ -23,8 +23,12 @@ def __next__(self) -> _T: return next(self._iter) except StopIteration: if torch.distributed.is_initialized(): - print(f"rank={torch.distributed.get_rank()}: InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator") + print( + f"rank={torch.distributed.get_rank()}: InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" + ) else: - print(f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator") + print( + f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" + ) self._iter = iter(self._iterable) return next(self._iter) From bf35aac3b6af5aeb8d08eee813f8e58cd51de66f Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 26 Feb 2026 23:59:01 +0000 Subject: [PATCH 24/30] reenable cora --- tests/e2e_tests/e2e_tests.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index d87465ed4..c6735ce27 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -19,9 +19,9 @@ tests: # het_dblp_sup_test: # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" - # hom_cora_sup_gs_test: - # task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + hom_cora_sup_gs_test: + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From f05a90287b149797f7581753a6983830f8dc3455 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 05:07:48 +0000 Subject: [PATCH 25/30] hmmm --- gigl/distributed/base_dist_loader.py | 15 ++++++++---- gigl/distributed/graph_store/dist_server.py | 26 +++++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index e5b7b4c4b..fac6dcda6 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -576,7 +576,12 @@ def __iter__(self) -> Self: print( f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls" ) - rpc_futures: list[torch.futures.Future[int]] = [] + # When multiple GPUs on the same node share a producer, the second + # GPU to call start_new_epoch_sampling for a given epoch gets + # skipped (no produce_all). We sync the local epoch to the server's + # current epoch + 1 so that the next __iter__ call (e.g. from + # InfiniteIterator reset) is guaranteed to trigger produce_all. + rpc_futures: list[torch.futures.Future[tuple[int, bool]]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list ): @@ -590,15 +595,17 @@ def __iter__(self) -> Self: print( f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls" ) - server_epochs: list[int] = torch.futures.wait_all(rpc_futures) + results: list[tuple[int, bool]] = torch.futures.wait_all(rpc_futures) + server_epochs = [epoch for epoch, _ in results] + any_produced = any(produced for _, produced in results) print( - f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done, server_epochs={server_epochs}" + f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done, server_epochs={server_epochs}, any_produced={any_produced}" ) + self._epoch = max(server_epochs) + 1 self._channel.reset() print( f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset" ) - self._epoch = max(server_epochs) + 1 print( f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch={self._epoch}" ) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index fa2faf7dd..896b6b7f1 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -551,14 +551,19 @@ def destroy_sampling_producer(self, producer_id: int) -> None: self._msg_buffer_pool.pop(producer_id) self._epoch.pop(producer_id) - def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> int: + def start_new_epoch_sampling( + self, producer_id: int, epoch: int + ) -> tuple[int, bool]: r"""Start a new epoch sampling tasks for a specific sampling producer with its producer id. Returns: - The server's current epoch for this producer after processing the - request. Clients should use this to synchronize their local epoch - counter, since multiple clients may share the same producer. + A tuple of (server_epoch, did_produce) where server_epoch is the + server's current epoch for this producer after processing the + request, and did_produce indicates whether ``produce_all`` was + called. Clients should use server_epoch to synchronize their local + epoch counter (since multiple clients may share the same producer) + and retry with a higher epoch if no server produced data. """ logger.info( f"DistServer.start_new_epoch_sampling: producer_id={producer_id}, epoch={epoch}" @@ -569,19 +574,30 @@ def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> int: self._epoch[producer_id] = epoch producer = self._producer_pool.get(producer_id, None) if producer is not None: + # Wait for any in-flight workers from a previous + # produce_all to finish before starting the next epoch. + # This prevents corruption of sampling_completed_worker_count + # when a retry triggers produce_all in quick succession. + # Only wait if produce_all was called before (cur_epoch >= 0); + # on the very first call (cur_epoch == -1), no workers exist. + if cur_epoch >= 0: + while not producer.is_all_sampling_completed(): + time.sleep(0.01) logger.info( f"DistServer.start_new_epoch_sampling: producing all for producer {producer_id}" ) producer.produce_all() + return self._epoch[producer_id], True else: logger.warning( f"DistServer.start_new_epoch_sampling: producer {producer_id} not found" ) + return self._epoch[producer_id], False else: logger.info( f"DistServer.start_new_epoch_sampling: producer {producer_id} already on epoch {cur_epoch}, skipping" ) - return self._epoch[producer_id] + return self._epoch[producer_id], False def fetch_one_sampled_message( self, producer_id: int From e12051f062f7c194e0e3716553697fd4feea1bff Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 21:29:42 +0000 Subject: [PATCH 26/30] Use fix --- gigl/distributed/base_dist_loader.py | 80 ++++++++++++++------- gigl/distributed/graph_store/dist_server.py | 45 +++--------- 2 files changed, 65 insertions(+), 60 deletions(-) diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index fac6dcda6..b8d9ff63d 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -562,25 +562,41 @@ def shutdown(self) -> None: torch.futures.wait_all(rpc_futures) self._shutdowned = True + _MAX_EPOCH_CATCH_UP_RETRIES: int = 10 + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: - print(f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__") self._num_recv = 0 if self._is_collocated_worker: self._collocated_producer.reset() - self._epoch += 1 elif self._is_mp_worker: self._mp_producer.produce_all() - self._epoch += 1 else: - print( - f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls" - ) - # When multiple GPUs on the same node share a producer, the second - # GPU to call start_new_epoch_sampling for a given epoch gets - # skipped (no produce_all). We sync the local epoch to the server's - # current epoch + 1 so that the next __iter__ call (e.g. from - # InfiniteIterator reset) is guaranteed to trigger produce_all. + self._request_new_epoch_production() + self._channel.reset() + self._epoch += 1 + return self + + def _request_new_epoch_production(self) -> None: + """Request production from all servers, retrying only on genuine epoch skew. + + In graph store mode, multiple GPUs on the same compute node share a + producer per server (same ``worker_key``). Only the first GPU to call + ``start_new_epoch_sampling`` for a given epoch triggers + ``produce_all()``; subsequent calls at the same epoch are no-ops + because the data is already flowing through the shared buffer. + + Two distinct cases are handled: + + * **Same epoch** (``self._epoch >= max_server_epoch``): another GPU + already triggered production for this epoch. Data is in the shared + buffer — return immediately without retrying. + * **Behind** (``self._epoch < max_server_epoch``): our epoch is + genuinely stale. Fast-forward past the server's epoch and retry so + ``produce_all()`` is guaranteed to fire. This typically resolves in + two iterations (first detects staleness, second triggers). + """ + for attempt in range(self._MAX_EPOCH_CATCH_UP_RETRIES): rpc_futures: list[torch.futures.Future[tuple[int, bool]]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list @@ -592,21 +608,33 @@ def __iter__(self) -> Self: self._epoch, ) rpc_futures.append(fut) - print( - f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: waiting for {len(rpc_futures)} rpc calls" - ) - results: list[tuple[int, bool]] = torch.futures.wait_all(rpc_futures) - server_epochs = [epoch for epoch, _ in results] + + results = [fut.wait() for fut in rpc_futures] any_produced = any(produced for _, produced in results) - print( - f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: rpc calls done, server_epochs={server_epochs}, any_produced={any_produced}" - ) - self._epoch = max(server_epochs) + 1 - self._channel.reset() - print( - f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: channel reset" + + if any_produced: + return + + # No server produced — check whether we are genuinely behind or + # another GPU sharing the same producer simply beat us. + max_server_epoch = max(server_epoch for server_epoch, _ in results) + + if self._epoch >= max_server_epoch: + # Another GPU already triggered production for this epoch. + # Data is flowing through the shared buffer — nothing to do. + return + + # Our epoch is genuinely behind the server's. Fast-forward and + # retry so the next RPC has epoch > max_server_epoch. + logger.warning( + f"Epoch skew detected: client epoch {self._epoch} behind " + f"server epoch {max_server_epoch}. Retrying with epoch " + f"{max_server_epoch + 1} (attempt {attempt + 1})." ) - print( - f"rank={torch.distributed.get_rank()}: BaseDistLoader.__iter__: epoch={self._epoch}" + self._epoch = max_server_epoch + 1 + + raise RuntimeError( + f"Failed to trigger production after " + f"{self._MAX_EPOCH_CATCH_UP_RETRIES} attempts. " + f"This indicates a persistent epoch skew." ) - return self diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 896b6b7f1..47880cfdc 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -554,50 +554,27 @@ def destroy_sampling_producer(self, producer_id: int) -> None: def start_new_epoch_sampling( self, producer_id: int, epoch: int ) -> tuple[int, bool]: - r"""Start a new epoch sampling tasks for a specific sampling producer - with its producer id. + """Start a new epoch sampling for a specific sampling producer. + + Args: + producer_id: The unique id of the sampling producer. + epoch: The epoch requested by the client. Returns: - A tuple of (server_epoch, did_produce) where server_epoch is the - server's current epoch for this producer after processing the - request, and did_produce indicates whether ``produce_all`` was - called. Clients should use server_epoch to synchronize their local - epoch counter (since multiple clients may share the same producer) - and retry with a higher epoch if no server produced data. + A tuple of (server_epoch, produced) where server_epoch is the + current epoch on the server after this call and produced indicates + whether ``produce_all()`` was triggered. """ - logger.info( - f"DistServer.start_new_epoch_sampling: producer_id={producer_id}, epoch={epoch}" - ) with self._producer_lock[producer_id]: cur_epoch = self._epoch[producer_id] + produced = False if cur_epoch < epoch: self._epoch[producer_id] = epoch producer = self._producer_pool.get(producer_id, None) if producer is not None: - # Wait for any in-flight workers from a previous - # produce_all to finish before starting the next epoch. - # This prevents corruption of sampling_completed_worker_count - # when a retry triggers produce_all in quick succession. - # Only wait if produce_all was called before (cur_epoch >= 0); - # on the very first call (cur_epoch == -1), no workers exist. - if cur_epoch >= 0: - while not producer.is_all_sampling_completed(): - time.sleep(0.01) - logger.info( - f"DistServer.start_new_epoch_sampling: producing all for producer {producer_id}" - ) producer.produce_all() - return self._epoch[producer_id], True - else: - logger.warning( - f"DistServer.start_new_epoch_sampling: producer {producer_id} not found" - ) - return self._epoch[producer_id], False - else: - logger.info( - f"DistServer.start_new_epoch_sampling: producer {producer_id} already on epoch {cur_epoch}, skipping" - ) - return self._epoch[producer_id], False + produced = True + return self._epoch[producer_id], produced def fetch_one_sampled_message( self, producer_id: int From 7c887448003f9ec889aeee6c6b608c900e1d4082 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 22:37:42 +0000 Subject: [PATCH 27/30] more training examples --- .../graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml index 09414eb1b..251eea4b6 100644 --- a/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml @@ -54,8 +54,8 @@ trainerConfig: storageArgs: sample_edge_direction: "in" splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" - splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.1, "num_test": 0.1, "supervision_edge_types": [("author", "to", "paper")]}' - ssl_positive_label_percentage: "0.05" + splitter_kwargs: '{"sampling_direction": "in", "should_convert_labels_to_edges": True, "num_val": 0.3, "num_test": 0.3, "supervision_edge_types": [("author", "to", "paper")]}' + ssl_positive_label_percentage: "0.15" num_server_sessions: "1" # TODO(kmonte): Move to user-defined server code inferencerConfig: From d848843321ed81706965c3be1975ed4315e5071c Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 23:04:41 +0000 Subject: [PATCH 28/30] revert --- gigl/utils/data_splitters.py | 9 +++------ gigl/utils/iterator.py | 10 ---------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 254552fe6..9e05ead4f 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -189,7 +189,7 @@ def __init__( num_val: float = 0.1, num_test: float = 0.1, hash_function: Callable[[torch.Tensor], torch.Tensor] = _fast_hash, - supervision_edge_types: Optional[list[Union[EdgeType, PyGEdgeType]]] = None, + supervision_edge_types: Optional[list[EdgeType]] = None, should_convert_labels_to_edges: bool = True, ): """Initializes the DistNodeAnchorLinkSplitter. @@ -199,7 +199,7 @@ def __init__( num_val (float): The percentage of nodes to use for training. Defaults to 0.1 (10%). num_test (float): The percentage of nodes to use for validation. Defaults to 0.1 (10%). hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): The hash function to use. Defaults to `_fast_hash`. - supervision_edge_types (Optional[list[Union[EdgeType, PyGEdgeType]]]): The supervision edge types we should use for splitting. + supervision_edge_types (Optional[list[EdgeType]]): The supervision edge types we should use for splitting. Must be provided if we are splitting a heterogeneous graph. If None, uses the default message passing edge type in the graph. should_convert_labels_to_edges (bool): Whether label should be converted into an edge type in the graph. If provided, will make `gigl.distributed.build_dataset` convert all labels into edges, and will infer positive and negative edge types based on @@ -232,10 +232,7 @@ def __init__( # also be ("user", "positive", "story"), meaning that all edges in the loaded edge index tensor with this edge type will be treated as a labeled # edge and will be used for splitting. - self._supervision_edge_types: Sequence[EdgeType] = [ - EdgeType(*supervision_edge_type) - for supervision_edge_type in supervision_edge_types - ] + self._supervision_edge_types: Sequence[EdgeType] = supervision_edge_types self._labeled_edge_types: Sequence[EdgeType] if should_convert_labels_to_edges: labeled_edge_types = [ diff --git a/gigl/utils/iterator.py b/gigl/utils/iterator.py index 219cf4fdb..63f809083 100644 --- a/gigl/utils/iterator.py +++ b/gigl/utils/iterator.py @@ -1,8 +1,6 @@ from collections.abc import Iterable, Iterator from typing import TypeVar -import torch - _T = TypeVar("_T") @@ -22,13 +20,5 @@ def __next__(self) -> _T: try: return next(self._iter) except StopIteration: - if torch.distributed.is_initialized(): - print( - f"rank={torch.distributed.get_rank()}: InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" - ) - else: - print( - f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" - ) self._iter = iter(self._iterable) return next(self._iter) From a4691af9c363bcf7dc2084cdfc3071285ae1874a Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 28 Feb 2026 00:18:15 +0000 Subject: [PATCH 29/30] revert --- gigl/utils/data_splitters.py | 9 ++++++--- gigl/utils/iterator.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 9e05ead4f..254552fe6 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -189,7 +189,7 @@ def __init__( num_val: float = 0.1, num_test: float = 0.1, hash_function: Callable[[torch.Tensor], torch.Tensor] = _fast_hash, - supervision_edge_types: Optional[list[EdgeType]] = None, + supervision_edge_types: Optional[list[Union[EdgeType, PyGEdgeType]]] = None, should_convert_labels_to_edges: bool = True, ): """Initializes the DistNodeAnchorLinkSplitter. @@ -199,7 +199,7 @@ def __init__( num_val (float): The percentage of nodes to use for training. Defaults to 0.1 (10%). num_test (float): The percentage of nodes to use for validation. Defaults to 0.1 (10%). hash_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): The hash function to use. Defaults to `_fast_hash`. - supervision_edge_types (Optional[list[EdgeType]]): The supervision edge types we should use for splitting. + supervision_edge_types (Optional[list[Union[EdgeType, PyGEdgeType]]]): The supervision edge types we should use for splitting. Must be provided if we are splitting a heterogeneous graph. If None, uses the default message passing edge type in the graph. should_convert_labels_to_edges (bool): Whether label should be converted into an edge type in the graph. If provided, will make `gigl.distributed.build_dataset` convert all labels into edges, and will infer positive and negative edge types based on @@ -232,7 +232,10 @@ def __init__( # also be ("user", "positive", "story"), meaning that all edges in the loaded edge index tensor with this edge type will be treated as a labeled # edge and will be used for splitting. - self._supervision_edge_types: Sequence[EdgeType] = supervision_edge_types + self._supervision_edge_types: Sequence[EdgeType] = [ + EdgeType(*supervision_edge_type) + for supervision_edge_type in supervision_edge_types + ] self._labeled_edge_types: Sequence[EdgeType] if should_convert_labels_to_edges: labeled_edge_types = [ diff --git a/gigl/utils/iterator.py b/gigl/utils/iterator.py index 63f809083..219cf4fdb 100644 --- a/gigl/utils/iterator.py +++ b/gigl/utils/iterator.py @@ -1,6 +1,8 @@ from collections.abc import Iterable, Iterator from typing import TypeVar +import torch + _T = TypeVar("_T") @@ -20,5 +22,13 @@ def __next__(self) -> _T: try: return next(self._iter) except StopIteration: + if torch.distributed.is_initialized(): + print( + f"rank={torch.distributed.get_rank()}: InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" + ) + else: + print( + f"InfiniteIterator: _iterable={self._iterable} exhausted, resetting iterator" + ) self._iter = iter(self._iterable) return next(self._iter) From cda0477e33ba24bf975211af6702fb6797c2ce69 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 28 Feb 2026 01:52:36 +0000 Subject: [PATCH 30/30] debug --- .../graph_store/heterogeneous_training.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 15854ee5d..a448e2fc3 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -181,7 +181,8 @@ def _setup_dataloaders( anchor_node_type=anchor_node_type, supervision_edge_type=supervision_edge_type, ) - + pos_labels = [a.labels[supervision_edge_type][0].shape for a in ablp_input.values()] + print(f"---Rank {rank} split {split} ABLP input sizes: main_loader: {[a.anchor_nodes.shape for a in ablp_input.values()]}, pos labels: {pos_labels}") main_loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, @@ -209,6 +210,9 @@ def _setup_dataloaders( node_type=labeled_node_type, ) + print(f"---Rank {rank} split {split} all node ids sizes: {[n.shape for n in all_node_ids.values()]}") + flush() + random_negative_loader = DistNeighborLoader( dataset=dataset, num_neighbors=num_neighbors, @@ -841,8 +845,8 @@ def _run_example_training( trainer_args.get("sampling_workers_per_process", "4") ) - main_batch_size = int(trainer_args.get("main_batch_size", "16")) - random_batch_size = int(trainer_args.get("random_batch_size", "16")) + main_batch_size = int(trainer_args.get("main_batch_size", "4")) + random_batch_size = int(trainer_args.get("random_batch_size", "4")) hid_dim = int(trainer_args.get("hid_dim", "16")) out_dim = int(trainer_args.get("out_dim", "16"))