diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d4ae3e452..b8d9ff63d 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -562,6 +562,8 @@ def shutdown(self) -> None: torch.futures.wait_all(rpc_futures) self._shutdowned = True + _MAX_EPOCH_CATCH_UP_RETRIES: int = 10 + # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: self._num_recv = 0 @@ -570,7 +572,32 @@ def __iter__(self) -> Self: elif self._is_mp_worker: self._mp_producer.produce_all() else: - rpc_futures: list[torch.futures.Future[None]] = [] + self._request_new_epoch_production() + self._channel.reset() + self._epoch += 1 + return self + + def _request_new_epoch_production(self) -> None: + """Request production from all servers, retrying only on genuine epoch skew. + + In graph store mode, multiple GPUs on the same compute node share a + producer per server (same ``worker_key``). Only the first GPU to call + ``start_new_epoch_sampling`` for a given epoch triggers + ``produce_all()``; subsequent calls at the same epoch are no-ops + because the data is already flowing through the shared buffer. + + Two distinct cases are handled: + + * **Same epoch** (``self._epoch >= max_server_epoch``): another GPU + already triggered production for this epoch. Data is in the shared + buffer — return immediately without retrying. + * **Behind** (``self._epoch < max_server_epoch``): our epoch is + genuinely stale. Fast-forward past the server's epoch and retry so + ``produce_all()`` is guaranteed to fire. This typically resolves in + two iterations (first detects staleness, second triggers). + """ + for attempt in range(self._MAX_EPOCH_CATCH_UP_RETRIES): + rpc_futures: list[torch.futures.Future[tuple[int, bool]]] = [] for server_rank, producer_id in zip( self._server_rank_list, self._producer_id_list ): @@ -581,7 +608,33 @@ def __iter__(self) -> Self: self._epoch, ) rpc_futures.append(fut) - torch.futures.wait_all(rpc_futures) - self._channel.reset() - self._epoch += 1 - return self + + results = [fut.wait() for fut in rpc_futures] + any_produced = any(produced for _, produced in results) + + if any_produced: + return + + # No server produced — check whether we are genuinely behind or + # another GPU sharing the same producer simply beat us. + max_server_epoch = max(server_epoch for server_epoch, _ in results) + + if self._epoch >= max_server_epoch: + # Another GPU already triggered production for this epoch. + # Data is flowing through the shared buffer — nothing to do. + return + + # Our epoch is genuinely behind the server's. Fast-forward and + # retry so the next RPC has epoch > max_server_epoch. + logger.warning( + f"Epoch skew detected: client epoch {self._epoch} behind " + f"server epoch {max_server_epoch}. Retrying with epoch " + f"{max_server_epoch + 1} (attempt {attempt + 1})." + ) + self._epoch = max_server_epoch + 1 + + raise RuntimeError( + f"Failed to trigger production after " + f"{self._MAX_EPOCH_CATCH_UP_RETRIES} attempts. " + f"This indicates a persistent epoch skew." + ) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 1506432b2..47880cfdc 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -551,17 +551,30 @@ def destroy_sampling_producer(self, producer_id: int) -> None: self._msg_buffer_pool.pop(producer_id) self._epoch.pop(producer_id) - def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> None: - r"""Start a new epoch sampling tasks for a specific sampling producer - with its producer id. + def start_new_epoch_sampling( + self, producer_id: int, epoch: int + ) -> tuple[int, bool]: + """Start a new epoch sampling for a specific sampling producer. + + Args: + producer_id: The unique id of the sampling producer. + epoch: The epoch requested by the client. + + Returns: + A tuple of (server_epoch, produced) where server_epoch is the + current epoch on the server after this call and produced indicates + whether ``produce_all()`` was triggered. """ with self._producer_lock[producer_id]: cur_epoch = self._epoch[producer_id] + produced = False if cur_epoch < epoch: self._epoch[producer_id] = epoch producer = self._producer_pool.get(producer_id, None) if producer is not None: producer.produce_all() + produced = True + return self._epoch[producer_id], produced def fetch_one_sampled_message( self, producer_id: int diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index f29a1fd97..be9c4278b 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -843,6 +843,170 @@ def _get_expected_input_nodes_by_rank( return dict(expected_sampler_input) +def _run_compute_epoch_skew_test( + client_rank: int, + cluster_info: GraphStoreInfo, + mp_sharing_dict: Optional[MutableMapping[str, torch.Tensor]], + node_type: Optional[NodeType], + expected_sampler_input: dict[int, list[torch.Tensor]], + expected_edge_types: Optional[list[EdgeType]], +) -> None: + """Test that a loader recovers when its epoch falls behind the server's. + + Reproduces the epoch skew deadlock scenario: + + 1. Creates a DistNeighborLoader and iterates through two epochs normally, + which advances the server's epoch to 1. + 2. Resets the loader's ``_epoch`` to 0 (below the server's epoch), + simulating a slow GPU that has fallen behind. + 3. Iterates the loader again and verifies it still produces data — i.e. + ``_request_new_epoch_production`` correctly detects the stale epoch, + fast-forwards past the server's epoch, and triggers production. + + Without the epoch skew fix, step 3 would silently skip production (since + the loader's epoch <= the server's epoch), return an empty buffer, and + raise ``StopIteration`` — which in a real training loop escapes + ``InfiniteIterator`` and causes an NCCL deadlock. + """ + init_compute_process(client_rank, cluster_info, compute_world_backend="gloo") + + remote_dist_dataset = RemoteDistDataset( + cluster_info=cluster_info, + local_rank=client_rank, + mp_sharing_dict=mp_sharing_dict, + ) + + sampler_input = remote_dist_dataset.fetch_node_ids( + node_type=node_type, + rank=cluster_info.compute_node_rank, + world_size=cluster_info.num_compute_nodes, + ) + _assert_sampler_input(cluster_info, sampler_input, expected_sampler_input) + + if node_type is not None: + input_nodes: Union[ + dict[int, torch.Tensor], tuple[NodeType, dict[int, torch.Tensor]] + ] = ( + node_type, + sampler_input, + ) + else: + input_nodes = sampler_input + + loader = DistNeighborLoader( + dataset=remote_dist_dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + input_nodes=input_nodes, + num_workers=2, + worker_concurrency=2, + ) + + # --- Epoch 0: normal iteration --- + epoch_0_count = 0 + for datum in loader: + if node_type is not None: + assert isinstance(datum, HeteroData) + else: + assert isinstance(datum, Data) + epoch_0_count += 1 + + torch.distributed.barrier() + logger.info( + f"Rank {torch.distributed.get_rank()} loaded {epoch_0_count} batches " + f"in epoch 0 (normal). Loader epoch={loader._epoch}" + ) + assert epoch_0_count > 0, "Should have loaded batches in normal epoch 0" + + # --- Epoch 1: normal iteration (advances server epoch) --- + epoch_1_count = 0 + for datum in loader: + if node_type is not None: + assert isinstance(datum, HeteroData) + else: + assert isinstance(datum, Data) + epoch_1_count += 1 + + torch.distributed.barrier() + logger.info( + f"Rank {torch.distributed.get_rank()} loaded {epoch_1_count} batches " + f"in epoch 1 (normal). Loader epoch={loader._epoch}" + ) + assert epoch_1_count > 0, "Should have loaded batches in normal epoch 1" + + # After two normal iterations the server's epoch is 1 and the loader's + # _epoch is 2. Reset the loader's epoch to 0 to simulate a slow GPU + # whose epoch has fallen behind the server's. + server_epoch_after_normal = loader._epoch - 1 # epoch sent in last __iter__ + loader._epoch = 0 + logger.info( + f"Rank {torch.distributed.get_rank()} simulating epoch skew: " + f"loader._epoch reset to {loader._epoch}, " + f"server epoch is {server_epoch_after_normal}" + ) + + # --- Epoch after skew: should recover --- + # When __iter__ is called, _request_new_epoch_production sends epoch 0 + # to the server, which has epoch 1. The fix detects 0 < 1, fast-forwards + # to epoch 2, retries, and triggers production (2 > 1). + skew_count = 0 + for datum in loader: + if node_type is not None: + assert isinstance(datum, HeteroData) + else: + assert isinstance(datum, Data) + skew_count += 1 + + torch.distributed.barrier() + logger.info( + f"Rank {torch.distributed.get_rank()} loaded {skew_count} batches " + f"after epoch skew recovery. Loader epoch={loader._epoch}" + ) + assert skew_count > 0, ( + f"Rank {torch.distributed.get_rank()} loaded 0 batches after epoch " + f"skew recovery — _request_new_epoch_production failed to trigger " + f"production with stale epoch 0 vs server epoch " + f"{server_epoch_after_normal}." + ) + + shutdown_compute_proccess() + + +def _client_epoch_skew_process(args: ClientProcessArgs) -> None: + """Client process wrapper for the epoch skew recovery test.""" + process_name = f"client_epoch_skew_{args.client_rank}" + try: + logger.info( + f"Initializing epoch skew client node {args.client_rank} / " + f"{args.cluster_info.num_compute_nodes}. " + f"OS rank: {os.environ['RANK']}, " + f"OS world size: {os.environ['WORLD_SIZE']}" + ) + mp_context = torch.multiprocessing.get_context("spawn") + mp_sharing_dict = torch.multiprocessing.Manager().dict() + client_processes: list[py_mp_context.SpawnProcess] = [] + for i in range(args.cluster_info.num_processes_per_compute): + client_process = mp_context.Process( + target=_run_compute_epoch_skew_test, + args=[ + i, + args.cluster_info, + mp_sharing_dict, + args.node_type, + args.expected_sampler_input, + args.expected_edge_types, + ], + ) + client_processes.append(client_process) + for client_process in client_processes: + client_process.start() + for client_process in client_processes: + client_process.join(DEFAULT_TIMEOUT_SECONDS) + except Exception: + args.exception_dict[process_name] = traceback.format_exc() + raise + + class GraphStoreIntegrationTest(TestCase): """ NOTE: Since these tests run on cloud build, @@ -953,6 +1117,116 @@ def test_graph_store_homogeneous(self): self.assert_all_processes_succeed(launched_processes, exception_dict) + def test_epoch_skew_recovery(self): + """Test that loaders recover from epoch skew without deadlocking. + + Simulates the scenario where a fast GPU advances the server's epoch + past a slow GPU's epoch. Without the epoch skew fix in + ``_request_new_epoch_production``, the slow GPU's ``__iter__`` would + silently skip production, find an empty buffer, and raise + ``StopIteration`` — terminating one rank and causing an NCCL deadlock. + + Uses 1 storage node and 1 compute node with 1 process to isolate the + epoch skew logic from shared-buffer interactions. + """ + cora_supervised_info = get_mocked_dataset_artifact_metadata()[ + CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + task_config_uri = cora_supervised_info.frozen_gbml_config_uri + ( + cluster_master_port, + storage_cluster_master_port, + compute_cluster_master_port, + master_port, + rpc_master_port, + rpc_wait_port, + ) = get_free_ports(num_ports=6) + host_ip = socket.gethostbyname(socket.gethostname()) + # Small cluster: 1 storage, 1 compute, 1 process per compute. + # This isolates the epoch skew logic from shared-buffer interactions. + cluster_info = GraphStoreInfo( + num_storage_nodes=1, + num_compute_nodes=1, + num_processes_per_compute=1, + cluster_master_ip=host_ip, + storage_cluster_master_ip=host_ip, + compute_cluster_master_ip=host_ip, + cluster_master_port=cluster_master_port, + storage_cluster_master_port=storage_cluster_master_port, + compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=rpc_master_port, + rpc_wait_port=rpc_wait_port, + ) + + num_cora_nodes = 2708 + expected_sampler_input = _get_expected_input_nodes_by_rank( + num_cora_nodes, cluster_info + ) + + ctx = mp.get_context("spawn") + manager = mp.Manager() + exception_dict = manager.dict() + launched_processes: list[py_mp_context.SpawnProcess] = [] + for i in range(cluster_info.num_compute_nodes): + with mock.patch.dict( + os.environ, + { + "MASTER_ADDR": host_ip, + "MASTER_PORT": str(master_port), + "RANK": str(i), + "WORLD_SIZE": str(cluster_info.num_cluster_nodes), + COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str( + cluster_info.num_processes_per_compute + ), + }, + clear=False, + ): + client_args = ClientProcessArgs( + client_rank=i, + cluster_info=cluster_info, + node_type=None, + expected_sampler_input=expected_sampler_input, + expected_edge_types=None, + exception_dict=exception_dict, + ) + client_process = ctx.Process( + target=_client_epoch_skew_process, + args=[client_args], + name=f"client_epoch_skew_{i}", + ) + client_process.start() + launched_processes.append(client_process) + + for i in range(cluster_info.num_storage_nodes): + with mock.patch.dict( + os.environ, + { + "MASTER_ADDR": host_ip, + "MASTER_PORT": str(master_port), + "RANK": str(i + cluster_info.num_compute_nodes), + "WORLD_SIZE": str(cluster_info.num_cluster_nodes), + COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str( + cluster_info.num_processes_per_compute + ), + }, + clear=False, + ): + server_args = ServerProcessArgs( + cluster_info=cluster_info, + task_config_uri=task_config_uri, + sample_edge_direction="in", + exception_dict=exception_dict, + ) + server_process = ctx.Process( + target=_run_storage_main_process, + args=[server_args], + name=f"server_{i}", + ) + server_process.start() + launched_processes.append(server_process) + + self.assert_all_processes_succeed(launched_processes, exception_dict) + def test_homogeneous_training(self): cora_supervised_info = get_mocked_dataset_artifact_metadata()[ CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name