Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,8 @@ def shutdown(self) -> None:
torch.futures.wait_all(rpc_futures)
self._shutdowned = True

# Upper bound on catch-up iterations in _request_new_epoch_production when
# the client's epoch lags the server's; exceeding it raises RuntimeError.
_MAX_EPOCH_CATCH_UP_RETRIES: int = 10

# Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls
def __iter__(self) -> Self:
self._num_recv = 0
Expand All @@ -570,7 +572,32 @@ def __iter__(self) -> Self:
elif self._is_mp_worker:
self._mp_producer.produce_all()
else:
rpc_futures: list[torch.futures.Future[None]] = []
self._request_new_epoch_production()
self._channel.reset()
self._epoch += 1
return self

def _request_new_epoch_production(self) -> None:
"""Request production from all servers, retrying only on genuine epoch skew.

In graph store mode, multiple GPUs on the same compute node share a
producer per server (same ``worker_key``). Only the first GPU to call
``start_new_epoch_sampling`` for a given epoch triggers
``produce_all()``; subsequent calls at the same epoch are no-ops
because the data is already flowing through the shared buffer.

Two distinct cases are handled:

* **Same epoch** (``self._epoch >= max_server_epoch``): another GPU
already triggered production for this epoch. Data is in the shared
buffer — return immediately without retrying.
* **Behind** (``self._epoch < max_server_epoch``): our epoch is
genuinely stale. Fast-forward past the server's epoch and retry so
``produce_all()`` is guaranteed to fire. This typically resolves in
two iterations (first detects staleness, second triggers).
"""
for attempt in range(self._MAX_EPOCH_CATCH_UP_RETRIES):
rpc_futures: list[torch.futures.Future[tuple[int, bool]]] = []
for server_rank, producer_id in zip(
self._server_rank_list, self._producer_id_list
):
Expand All @@ -581,7 +608,33 @@ def __iter__(self) -> Self:
self._epoch,
)
rpc_futures.append(fut)
torch.futures.wait_all(rpc_futures)
self._channel.reset()
self._epoch += 1
return self

results = [fut.wait() for fut in rpc_futures]
any_produced = any(produced for _, produced in results)

if any_produced:
return

# No server produced — check whether we are genuinely behind or
# another GPU sharing the same producer simply beat us.
max_server_epoch = max(server_epoch for server_epoch, _ in results)

if self._epoch >= max_server_epoch:
# Another GPU already triggered production for this epoch.
# Data is flowing through the shared buffer — nothing to do.
return

# Our epoch is genuinely behind the server's. Fast-forward and
# retry so the next RPC has epoch > max_server_epoch.
logger.warning(
f"Epoch skew detected: client epoch {self._epoch} behind "
f"server epoch {max_server_epoch}. Retrying with epoch "
f"{max_server_epoch + 1} (attempt {attempt + 1})."
)
self._epoch = max_server_epoch + 1

raise RuntimeError(
f"Failed to trigger production after "
f"{self._MAX_EPOCH_CATCH_UP_RETRIES} attempts. "
f"This indicates a persistent epoch skew."
)
19 changes: 16 additions & 3 deletions gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,17 +551,30 @@ def destroy_sampling_producer(self, producer_id: int) -> None:
self._msg_buffer_pool.pop(producer_id)
self._epoch.pop(producer_id)

def start_new_epoch_sampling(self, producer_id: int, epoch: int) -> None:
r"""Start a new epoch sampling tasks for a specific sampling producer
with its producer id.
def start_new_epoch_sampling(
    self, producer_id: int, epoch: int
) -> tuple[int, bool]:
    """Advance a producer's epoch and trigger sampling if the client is ahead.

    Only a call whose requested ``epoch`` exceeds the producer's recorded
    epoch triggers ``produce_all()``; calls at the same or a lower epoch
    are no-ops (production for that epoch was already started by an
    earlier caller sharing this producer).

    Args:
        producer_id: The unique id of the sampling producer.
        epoch: The epoch requested by the client.

    Returns:
        A ``(server_epoch, produced)`` pair: the producer's epoch after
        this call, and whether this call actually triggered
        ``produce_all()``.
    """
    with self._producer_lock[producer_id]:
        produced = False
        if self._epoch[producer_id] < epoch:
            # First request at this epoch: record it, then kick off
            # production if the producer still exists in the pool.
            self._epoch[producer_id] = epoch
            maybe_producer = self._producer_pool.get(producer_id)
            if maybe_producer is not None:
                maybe_producer.produce_all()
                produced = True
        return self._epoch[producer_id], produced

def fetch_one_sampled_message(
self, producer_id: int
Expand Down
Loading