Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
2e870f0
update
mkolodner-sc Mar 4, 2026
9f2331e
initial changes
mkolodner-sc Mar 4, 2026
d75f96e
Update
mkolodner-sc Mar 4, 2026
7a36925
Update
mkolodner-sc Mar 4, 2026
21d68eb
Update
mkolodner-sc Mar 4, 2026
f114822
Update
mkolodner-sc Mar 4, 2026
7af2d0c
Update
mkolodner-sc Mar 5, 2026
6498544
Fix
mkolodner-sc Mar 5, 2026
ec43f27
Update
mkolodner-sc Mar 5, 2026
5129dd0
fix: empty tensor dtype mismatch, docstring cleanup, and del guard
mkolodner-sc Mar 5, 2026
2377963
Update format
mkolodner-sc Mar 5, 2026
e59ca07
Add SamplerOptions to enable configurable sampler classes
mkolodner-sc Mar 6, 2026
c0fdb03
Rename to KHopNeighborSamplerOptions, inline resolve logic, clean up …
mkolodner-sc Mar 6, 2026
e41b957
Pass sampler_options positionally in DistABLPLoader
mkolodner-sc Mar 6, 2026
4b00573
Pass sampler_options positionally in DistNeighborLoader and DistServer
mkolodner-sc Mar 6, 2026
8ac9ab8
Make num_neighbors optional, resolve from sampler_options
mkolodner-sc Mar 6, 2026
7a51864
Add resolve_sampler_options tests, remove redundant num_neighbors fro…
mkolodner-sc Mar 6, 2026
8c36de9
Merge main into custom_sampler_options
mkolodner-sc Mar 6, 2026
c924bb6
Remove redundant KHop test cases, keep only CustomSamplerOptions tests
mkolodner-sc Mar 6, 2026
455cb24
Require num_neighbors always, remove silent [] default for CustomSamp…
mkolodner-sc Mar 6, 2026
096493e
Make num_neighbors required on loaders, simplify resolve_sampler_options
mkolodner-sc Mar 6, 2026
b6a3597
Fix stale error message in resolve_sampler_options
mkolodner-sc Mar 6, 2026
2e49c51
Remove CustomSamplerOptions, keep SamplerOptions plumbing for future …
mkolodner-sc Mar 6, 2026
8a4f3ef
Use kwargs on DistSamplingProducer calls; isinstance dispatch for sam…
mkolodner-sc Mar 6, 2026
f2be444
Update
mkolodner-sc Mar 9, 2026
d47b1b9
Add PPRSamplerOptions and DistPPRNeighborSampler with ABLP support an…
mkolodner-sc Mar 9, 2026
ec4bdbc
Use inducer for local indexing in DistPPRNeighborSampler
mkolodner-sc Mar 9, 2026
b05d9c6
Switch PPR output to edge-index format, remove default_node_id/defaul…
mkolodner-sc Mar 9, 2026
3506b3a
comments
mkolodner-sc Mar 9, 2026
1469a20
Merge branch 'mkolodner-sc/custom_sampler_options' into mkolodner-sc/…
mkolodner-sc Mar 9, 2026
8c62e37
Resolve merge conflict: add sampler_options to leader RPC dispatch
mkolodner-sc Mar 11, 2026
7fdd22d
Merge branch 'mkolodner-sc/custom_sampler_options' into mkolodner-sc/…
mkolodner-sc Mar 11, 2026
9fc9ac8
Resolve merge conflicts: keep PPR sampler types and dispatch
mkolodner-sc Mar 11, 2026
43ae1a9
Optimize PPR: merge push+requeue into single pass, cache total degree
mkolodner-sc Mar 11, 2026
81c1bfb
Add PPR sampler tests and fix ABLP metadata propagation
mkolodner-sc Mar 11, 2026
53b2284
Clean up PPR tests and fix metadata stripping in DistNeighborLoader
mkolodner-sc Mar 11, 2026
2d0bc65
Improve PPR sampler readability: rename variables, add comments
mkolodner-sc Mar 11, 2026
9f0a672
Unify metadata extraction into BaseDistLoader
mkolodner-sc Mar 12, 2026
85f553f
Remove unused metadata_key_with_prefix
mkolodner-sc Mar 12, 2026
549c430
Update
mkolodner-sc Mar 12, 2026
264f2d8
Merge branch 'mkolodner-sc/unify_metadata_extract_in_collate' into mk…
mkolodner-sc Mar 12, 2026
6b4db90
small update
mkolodner-sc Mar 13, 2026
e79ea63
Merge branch 'mkolodner-sc/unify_metadata_extract_in_collate' into mk…
mkolodner-sc Mar 13, 2026
c9b3f80
Move extract_metadata to utility function with tests
mkolodner-sc Mar 13, 2026
ee88a65
Merge branch 'mkolodner-sc/unify_metadata_extract_in_collate' into mk…
mkolodner-sc Mar 13, 2026
23b686b
Improve variable names in _compute_ppr_scores
mkolodner-sc Mar 13, 2026
29abf61
Clean up PPR sampler: remove unused return, fix identity check, simpl…
mkolodner-sc Mar 13, 2026
e97fe3b
Restructure PPR state by node type, inline _get_neighbors_for_nodes, …
mkolodner-sc Mar 13, 2026
f88cd4c
Document why queue uses a set
mkolodner-sc Mar 13, 2026
03fc354
Precompute total degree tensors at init, remove per-call caching
mkolodner-sc Mar 14, 2026
141c4b1
Add TODO comment on total degree memory tradeoff
mkolodner-sc Mar 14, 2026
e481eb5
Expose total_degree_dtype in PPRSamplerOptions, document valid_counts…
mkolodner-sc Mar 14, 2026
04b5e2b
Small comment adjustment
mkolodner-sc Mar 14, 2026
d6b77fd
Reformat @param to use named keyword args matching codebase convention
mkolodner-sc Mar 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
NEGATIVE_LABEL_METADATA_KEY,
POSITIVE_LABEL_METADATA_KEY,
ABLPNodeSamplerInput,
metadata_key_with_prefix,
)
from gigl.distributed.sampler_options import SamplerOptions, resolve_sampler_options
from gigl.distributed.utils.neighborloader import (
DatasetSchema,
SamplingClusterSetup,
extract_metadata,
labeled_to_homogeneous,
set_missing_features,
shard_nodes_by_process,
Expand Down Expand Up @@ -772,72 +772,64 @@ def _setup_for_graph_store(
),
)

def _get_labels(
self, msg: SampleMessage
def _extract_labels(
self, metadata: dict[str, torch.Tensor]
) -> tuple[
SampleMessage,
dict[EdgeType, torch.Tensor],
dict[EdgeType, torch.Tensor],
dict[str, torch.Tensor],
]:
"""Partition pre-extracted metadata into labels and remaining metadata.

# TODO (mkolodner-sc): Remove the need to modify metadata once GLT's `to_hetero_data` function is fixed
f"""
Gets the labels from the output SampleMessage and removes them from the metadata. We need to remove the labels from GLT's metadata since the
`to_hetero_data` function strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of

Takes the metadata dict already extracted by ``_extract_metadata`` (keys
without the ``#META.`` prefix) and separates label entries from
non-label entries. We need to remove the labels from GLT's metadata since the `to_hetero_data` function
strangely assumes that we are doing edge-based sampling if the metadata is not empty at the time of
building the HeteroData object.

Label keys use ``POSITIVE_LABEL_METADATA_KEY`` / ``NEGATIVE_LABEL_METADATA_KEY``
prefixes followed by a string-encoded edge type tuple. If ``edge_dir``
is ``"in"``, the edge type is reversed because GLT swaps src/dst
internally.

Args:
msg (SampleMessage): All possible results from a sampler, including subgraph data, features, and used defined metadata
metadata: Dict of metadata keys (without ``#META.`` prefix) to tensors,
as returned by ``_extract_metadata``.

Returns:
SampleMessage: Updated sample messsage with the label fields removed
dict[EdgeType, torch.Tensor]: Dict[positive label edge type, label ID tensor],
where the ith row of the tensor corresponds to the ith anchor node ID.
dict[EdgeType, torch.Tensor]: Dict[negative label edge type, label ID tensor],
where the ith row of the tensor corresponds to the ith anchor node ID.
May be empty if no negative labels are present.
dict[str, torch.Tensor]: Non-label metadata entries
"""
metadata: dict[str, torch.Tensor] = {}
positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {}
negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {}
# We update metadata with sepcial POSITIVE_LABEL_METADATA_KEY and NEGATIVE_LABEL_METADATA_KEY keys
# in gigl/distributed/dist_neighbor_sampler.py.
# We need to encode the tuples as strings because GLT requires the keys to be strings.
# As such, we decode the strings back into tuples,
# And then pop those keys out of the metadata as they are not needed otherwise.
# If edge_dir is "in", we need to reverse the edge type because GLT swaps src/dst for edge_dir = "out".
# NOTE: GLT *prepends* the keys with "#META."
positive_label_metadata_key_prefix = metadata_key_with_prefix(
POSITIVE_LABEL_METADATA_KEY
)
negative_label_metadata_key_prefix = metadata_key_with_prefix(
NEGATIVE_LABEL_METADATA_KEY
)
for k in list(msg.keys()):
if k.startswith(positive_label_metadata_key_prefix):
edge_type_str = k[len(positive_label_metadata_key_prefix) :]
remaining_metadata: dict[str, torch.Tensor] = {}

for key, value in metadata.items():
if key.startswith(POSITIVE_LABEL_METADATA_KEY):
edge_type_str = key[len(POSITIVE_LABEL_METADATA_KEY) :]
edge_type = ast.literal_eval(edge_type_str)
if self.edge_dir == "in":
edge_type = reverse_edge_type(edge_type)
positive_labels_by_label_edge_type[edge_type] = msg[k].to(
self.to_device
)
del msg[k]
elif k.startswith(negative_label_metadata_key_prefix):
edge_type_str = k[len(negative_label_metadata_key_prefix) :]
positive_labels_by_label_edge_type[edge_type] = value
elif key.startswith(NEGATIVE_LABEL_METADATA_KEY):
edge_type_str = key[len(NEGATIVE_LABEL_METADATA_KEY) :]
edge_type = ast.literal_eval(edge_type_str)
if self.edge_dir == "in":
edge_type = reverse_edge_type(edge_type)
negative_labels_by_label_edge_type[edge_type] = msg[k].to(
self.to_device
)
del msg[k]
elif k.startswith("#META."):
meta_key = str(k[len("#META.") :])
metadata[meta_key] = msg[k].to(self.to_device)
del msg[k]
negative_labels_by_label_edge_type[edge_type] = value
else:
remaining_metadata[key] = value

return (
msg,
positive_labels_by_label_edge_type,
negative_labels_by_label_edge_type,
remaining_metadata,
)

def _set_labels(
Expand Down Expand Up @@ -925,8 +917,14 @@ def _set_labels(
return data

def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
msg, positive_labels, negative_labels = self._get_labels(msg)
data = super()._collate_fn(msg)
# _extract_metadata separates #META. keys from the message to work
# around a GLT bug in to_hetero_data. _extract_labels then partitions
# the metadata into labels vs remaining non-label metadata.
all_metadata, stripped_msg = extract_metadata(msg, self.to_device)
positive_labels, negative_labels, non_label_metadata = self._extract_labels(
all_metadata
)
data = super()._collate_fn(stripped_msg)
data = set_missing_features(
data=data,
node_feature_info=self._node_feature_info,
Expand All @@ -941,5 +939,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}"
)
data = labeled_to_homogeneous(self._supervision_edge_types[0], data)
for key, value in non_label_metadata.items():
data[key] = value
data = self._set_labels(data, positive_labels, negative_labels)
return data
Loading