From 4297f49ccaa4f9956400d0c18438681757cd2ca8 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Mon, 23 Feb 2026 12:54:26 -0800 Subject: [PATCH 01/10] init --- gigl/transforms/__init__.py | 0 gigl/transforms/add_positional_encodings.py | 201 ++++++++++++++++++ tests/unit/transforms/__init__.py | 0 .../add_positional_encodings_test.py | 195 +++++++++++++++++ 4 files changed, 396 insertions(+) create mode 100644 gigl/transforms/__init__.py create mode 100644 gigl/transforms/add_positional_encodings.py create mode 100644 tests/unit/transforms/__init__.py create mode 100644 tests/unit/transforms/add_positional_encodings_test.py diff --git a/gigl/transforms/__init__.py b/gigl/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py new file mode 100644 index 000000000..2520fbdae --- /dev/null +++ b/gigl/transforms/add_positional_encodings.py @@ -0,0 +1,201 @@ +from typing import Dict, Optional + +import torch +from torch import Tensor + +from torch_geometric.data import HeteroData +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import to_torch_sparse_tensor + + +def add_node_attr( + data: HeteroData, + values: Tensor, + attr_name: Optional[str] = None, + node_type_to_idx: Optional[Dict[str, tuple]] = None, +) -> HeteroData: + """Helper function to add node attributes to a HeteroData object. + + Args: + data: The HeteroData object to modify. + values: The tensor of values (in homogeneous node order). + attr_name: The name of the attribute to add. If None, concatenates to + existing `x` attribute for each node type (or creates it). + node_type_to_idx: Optional mapping from node type to (start, end) indices. + If None, it will be computed from data.node_types. + + Returns: + The modified HeteroData object. + """ + if node_type_to_idx is None: + node_type_to_idx = {} + start_idx = 0 + for node_type in data.node_types: + num_type_nodes = data[node_type].num_nodes + node_type_to_idx[node_type] = (start_idx, start_idx + num_type_nodes) + start_idx += num_type_nodes + + for node_type in data.node_types: + start, end = node_type_to_idx[node_type] + value = values[start:end] + + if attr_name is None: + # Concatenate to existing x or create new x + x = data[node_type].x + if x is not None: + x = x.view(-1, 1) if x.dim() == 1 else x + data[node_type].x = torch.cat( + [x, value.to(x.device, x.dtype)], dim=-1 + ) + else: + data[node_type].x = value + else: + data[node_type][attr_name] = value + + return data + + +@functional_transform('add_hetero_hop_distance_pe') +class AddHeteroHopDistancePE(BaseTransform): + r"""Adds the hop distance positional encoding from the + `"Graph Neural Networks with Learnable Structural and Positional + Representations" `_ paper to the given + heterogeneous graph (functional name: :obj:`add_hetero_hop_distance_pe`). + + Args: + k (int): The number of hops to consider. + attr_name (str, optional): The attribute name of the positional + encoding. (default: :obj:`"hop_pe"`) + is_undirected (bool, optional): If set to :obj:`True`, the graph is + assumed to be undirected, and multi-hop connectivity will be + computed accordingly. (default: :obj:`False`) + """ + def __init__( + self, + k: int, + attr_name: Optional[str] = 'hop_pe', + is_undirected: bool = False, + ) -> None: + self.k = k + self.attr_name = attr_name + self.is_undirected = is_undirected + + def forward(self, data: HeteroData) -> HeteroData: + assert isinstance(data, HeteroData), ( + f"'{self.__class__.__name__}' only supports 'HeteroData' " + f"(got '{type(data)}')" + ) + + # Convert to homogeneous to calculate subgraph hop distances + homo_data = data.to_homogeneous() + edge_index = homo_data.edge_index + num_nodes = homo_data.num_nodes + + if num_nodes == 0: + # Handle empty graph case + for node_type in data.node_types: + data[node_type][self.attr_name] = torch.zeros( + (data[node_type].num_nodes, self.k), + dtype=torch.float, + ) + return data + + # Compute Adjacency Matrix (sparse for efficiency) + adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + + if self.is_undirected: + # Make symmetric for undirected graphs + adj = adj + adj.t() + adj = adj.coalesce() + + # Convert to dense for matrix power computation + adj_dense = adj.to_dense() + adj_dense = (adj_dense > 0).float() # Binary adjacency + + # Iteratively compute k-hop reachability + hop_pe = torch.zeros((num_nodes, self.k), dtype=torch.float) + current_power = torch.eye(num_nodes, dtype=torch.float) + + for i in range(self.k): + current_power = current_power @ adj_dense + # Count number of paths at exactly i+1 hops (normalized) + reachable = (current_power > 0).float() + hop_pe[:, i] = reachable.sum(dim=-1) + + # Normalize hop distances (optional, for numerical stability) + hop_pe = hop_pe / (hop_pe.max(dim=0, keepdim=True).values + 1e-8) + + # Map back to HeteroData node types + add_node_attr(data, hop_pe, self.attr_name) + + return data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(k={self.k})' + + +@functional_transform('add_hetero_random_walk_pe') +class AddHeteroRandomWalkPE(BaseTransform): + r"""Adds the random walk positional encoding from the + `"Graph Neural Networks with Learnable Structural and Positional + Representations" `_ paper to the given + heterogeneous graph (functional name: :obj:`add_hetero_random_walk_pe`). + + Args: + walk_length (int): The number of random walk steps. + attr_name (str, optional): The attribute name of the positional + encoding. (default: :obj:`"random_walk_pe"`) + """ + def __init__( + self, + walk_length: int, + attr_name: Optional[str] = 'random_walk_pe', + ) -> None: + self.walk_length = walk_length + self.attr_name = attr_name + + def forward(self, data: HeteroData) -> HeteroData: + assert isinstance(data, HeteroData), ( + f"'{self.__class__.__name__}' only supports 'HeteroData' " + f"(got '{type(data)}')" + ) + + # Convert to homogeneous + homo_data = data.to_homogeneous() + edge_index = homo_data.edge_index + num_nodes = homo_data.num_nodes + + if num_nodes == 0: + for node_type in data.node_types: + data[node_type][self.attr_name] = torch.zeros( + (data[node_type].num_nodes, self.walk_length), + dtype=torch.float, + ) + return data + + # Compute transition matrix (row-stochastic) + adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + adj_dense = adj.to_dense() + + # Compute degree and create transition matrix + deg = adj_dense.sum(dim=1, keepdim=True) + deg = torch.clamp(deg, min=1) # Avoid division by zero + transition = adj_dense / deg + + # Compute random walk probabilities + pe = torch.zeros((num_nodes, self.walk_length), dtype=torch.float) + current = torch.eye(num_nodes, dtype=torch.float) + + for i in range(self.walk_length): + current = current @ transition + # Diagonal gives probability of returning to the same node + pe[:, i] = current.diag() + + # Map back to HeteroData node types + add_node_attr(data, pe, self.attr_name) + + return data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(walk_length={self.walk_length})' diff --git a/tests/unit/transforms/__init__.py b/tests/unit/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py new file mode 100644 index 000000000..4db4ee41a --- /dev/null +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -0,0 +1,195 @@ +import torch +from absl.testing import absltest +from torch_geometric.data import HeteroData + +from gigl.transforms.add_positional_encodings import ( + AddHeteroHopDistancePE, + AddHeteroRandomWalkPE, + add_node_attr, +) +from tests.test_assets.test_case import TestCase + + +def create_simple_hetero_data() -> HeteroData: + """Create a simple heterogeneous graph for testing. + + Graph structure: + - 3 'user' nodes + - 2 'item' nodes + - Edges: user -> item (bipartite) + """ + data = HeteroData() + + # Node features + data['user'].x = torch.randn(3, 4) + data['item'].x = torch.randn(2, 4) + + # Edges: user -> item + data['user', 'buys', 'item'].edge_index = torch.tensor([ + [0, 1, 2], # source (user) + [0, 0, 1], # target (item) + ]) + + # Edges: item -> user (reverse) + data['item', 'bought_by', 'user'].edge_index = torch.tensor([ + [0, 0, 1], # source (item) + [0, 1, 2], # target (user) + ]) + + return data + + +def create_empty_hetero_data() -> HeteroData: + """Create an empty heterogeneous graph for testing edge cases.""" + data = HeteroData() + data['user'].x = torch.zeros(0, 4) + data['item'].x = torch.zeros(0, 4) + return data + + +class TestAddNodeAttr(TestCase): + def test_add_node_attr_with_attr_name(self): + """Test adding a node attribute with a specific attribute name.""" + data = create_simple_hetero_data() + + # Create values in homogeneous order (3 users + 2 items = 5 nodes) + values = torch.randn(5, 8) + + add_node_attr(data, values, attr_name='test_attr') + + self.assertEqual(data['user'].test_attr.shape, (3, 8)) + self.assertEqual(data['item'].test_attr.shape, (2, 8)) + self.assert_tensor_equality(data['user'].test_attr, values[:3]) + self.assert_tensor_equality(data['item'].test_attr, values[3:]) + + def test_add_node_attr_concatenate_to_x(self): + """Test adding a node attribute by concatenating to existing x.""" + data = create_simple_hetero_data() + original_user_x = data['user'].x.clone() + original_item_x = data['item'].x.clone() + + # Create values in homogeneous order + values = torch.randn(5, 8) + + add_node_attr(data, values, attr_name=None) + + # Check that x was concatenated + self.assertEqual(data['user'].x.shape, (3, 12)) # 4 + 8 + self.assertEqual(data['item'].x.shape, (2, 12)) # 4 + 8 + + # Check original features are preserved + self.assert_tensor_equality(data['user'].x[:, :4], original_user_x) + self.assert_tensor_equality(data['item'].x[:, :4], original_item_x) + + def test_add_node_attr_create_x_if_none(self): + """Test creating x attribute if it doesn't exist.""" + data = HeteroData() + data['user'].num_nodes = 3 + data['item'].num_nodes = 2 + + values = torch.randn(5, 8) + + add_node_attr(data, values, attr_name=None) + + self.assertEqual(data['user'].x.shape, (3, 8)) + self.assertEqual(data['item'].x.shape, (2, 8)) + + +class TestAddHeteroHopDistancePE(TestCase): + def test_forward_basic(self): + """Test basic forward pass.""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistancePE(k=3) + + result = transform(data) + + # Check that PE was added to both node types + self.assertTrue(hasattr(result['user'], 'hop_pe')) + self.assertTrue(hasattr(result['item'], 'hop_pe')) + + # Check shapes + self.assertEqual(result['user'].hop_pe.shape, (3, 3)) + self.assertEqual(result['item'].hop_pe.shape, (2, 3)) + + def test_forward_with_custom_attr_name(self): + """Test forward pass with custom attribute name.""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistancePE(k=2, attr_name='custom_pe') + + result = transform(data) + + self.assertTrue(hasattr(result['user'], 'custom_pe')) + self.assertTrue(hasattr(result['item'], 'custom_pe')) + self.assertFalse(hasattr(result['user'], 'hop_pe')) + + def test_forward_undirected(self): + """Test forward pass with undirected graph setting.""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistancePE(k=2, is_undirected=True) + + result = transform(data) + + self.assertEqual(result['user'].hop_pe.shape, (3, 2)) + self.assertEqual(result['item'].hop_pe.shape, (2, 2)) + + def test_forward_empty_graph(self): + """Test forward pass with empty graph.""" + data = create_empty_hetero_data() + transform = AddHeteroHopDistancePE(k=3) + + result = transform(data) + + self.assertEqual(result['user'].hop_pe.shape, (0, 3)) + self.assertEqual(result['item'].hop_pe.shape, (0, 3)) + + def test_repr(self): + """Test string representation.""" + transform = AddHeteroHopDistancePE(k=5) + self.assertEqual(repr(transform), 'AddHeteroHopDistancePE(k=5)') + + +class TestAddHeteroRandomWalkPE(TestCase): + def test_forward_basic(self): + """Test basic forward pass.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=4) + + result = transform(data) + + # Check that PE was added to both node types + self.assertTrue(hasattr(result['user'], 'random_walk_pe')) + self.assertTrue(hasattr(result['item'], 'random_walk_pe')) + + # Check shapes + self.assertEqual(result['user'].random_walk_pe.shape, (3, 4)) + self.assertEqual(result['item'].random_walk_pe.shape, (2, 4)) + + def test_forward_with_custom_attr_name(self): + """Test forward pass with custom attribute name.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=3, attr_name='rw_pe') + + result = transform(data) + + self.assertTrue(hasattr(result['user'], 'rw_pe')) + self.assertTrue(hasattr(result['item'], 'rw_pe')) + self.assertFalse(hasattr(result['user'], 'random_walk_pe')) + + def test_forward_empty_graph(self): + """Test forward pass with empty graph.""" + data = create_empty_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=3) + + result = transform(data) + + self.assertEqual(result['user'].random_walk_pe.shape, (0, 3)) + self.assertEqual(result['item'].random_walk_pe.shape, (0, 3)) + + def test_repr(self): + """Test string representation.""" + transform = AddHeteroRandomWalkPE(walk_length=10) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10)') + + +if __name__ == '__main__': + absltest.main() From a6f758d1affc08a6933605e9710250a800e46646 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Mon, 23 Feb 2026 13:38:25 -0800 Subject: [PATCH 02/10] Fix AttributeError in add_node_attr when node type has no x attribute Use getattr with default None instead of direct attribute access, which raises AttributeError on NodeStorage objects without an x attribute. Co-Authored-By: Claude Opus 4.6 --- gigl/transforms/add_positional_encodings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 2520fbdae..38b44c28b 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -42,7 +42,7 @@ def add_node_attr( if attr_name is None: # Concatenate to existing x or create new x - x = data[node_type].x + x = getattr(data[node_type], "x", None) if x is not None: x = x.view(-1, 1) if x.dim() == 1 else x data[node_type].x = torch.cat( From dc45b7e8248ba0deb0ae34df935302876ebcc0c6 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Mon, 23 Feb 2026 18:29:50 -0800 Subject: [PATCH 03/10] updates --- gigl/transforms/add_positional_encodings.py | 330 +++++++++++++----- gigl/transforms/utils.py | 167 +++++++++ .../add_positional_encodings_test.py | 180 +++++++--- tests/unit/transforms/utils_test.py | 208 +++++++++++ 4 files changed, 741 insertions(+), 144 deletions(-) create mode 100644 gigl/transforms/utils.py create mode 100644 tests/unit/transforms/utils_test.py diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 38b44c28b..58d8066af 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -1,83 +1,46 @@ -from typing import Dict, Optional +from typing import Optional import torch -from torch import Tensor from torch_geometric.data import HeteroData from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_torch_sparse_tensor +from gigl.transforms.utils import add_node_attr, add_edge_attr -def add_node_attr( - data: HeteroData, - values: Tensor, - attr_name: Optional[str] = None, - node_type_to_idx: Optional[Dict[str, tuple]] = None, -) -> HeteroData: - """Helper function to add node attributes to a HeteroData object. - Args: - data: The HeteroData object to modify. - values: The tensor of values (in homogeneous node order). - attr_name: The name of the attribute to add. If None, concatenates to - existing `x` attribute for each node type (or creates it). - node_type_to_idx: Optional mapping from node type to (start, end) indices. - If None, it will be computed from data.node_types. - - Returns: - The modified HeteroData object. - """ - if node_type_to_idx is None: - node_type_to_idx = {} - start_idx = 0 - for node_type in data.node_types: - num_type_nodes = data[node_type].num_nodes - node_type_to_idx[node_type] = (start_idx, start_idx + num_type_nodes) - start_idx += num_type_nodes - - for node_type in data.node_types: - start, end = node_type_to_idx[node_type] - value = values[start:end] - - if attr_name is None: - # Concatenate to existing x or create new x - x = getattr(data[node_type], "x", None) - if x is not None: - x = x.view(-1, 1) if x.dim() == 1 else x - data[node_type].x = torch.cat( - [x, value.to(x.device, x.dtype)], dim=-1 - ) - else: - data[node_type].x = value - else: - data[node_type][attr_name] = value +@functional_transform('add_hetero_random_walk_pe') +class AddHeteroRandomWalkPE(BaseTransform): + r"""Adds the random walk positional encoding to the given heterogeneous graph + (functional name: :obj:`add_hetero_random_walk_pe`). - return data + For each node j, computes the sum of transition probabilities from all other + nodes to j after k steps of a random walk, for k = 1, 2, ..., walk_length. + This captures how "reachable" or "central" a node is from the rest of the graph. + The encoding is the column sum of non-diagonal elements of the k-step + random walk matrix: + PE[j, k] = Σ_{i≠j} (P^k)[i, j] -@functional_transform('add_hetero_hop_distance_pe') -class AddHeteroHopDistancePE(BaseTransform): - r"""Adds the hop distance positional encoding from the - `"Graph Neural Networks with Learnable Structural and Positional - Representations" `_ paper to the given - heterogeneous graph (functional name: :obj:`add_hetero_hop_distance_pe`). + where P is the transition matrix. This measures the probability mass flowing + into node j from all other nodes at step k. Args: - k (int): The number of hops to consider. + walk_length (int): The number of random walk steps. attr_name (str, optional): The attribute name of the positional - encoding. (default: :obj:`"hop_pe"`) + encoding. (default: :obj:`"random_walk_pe"`) is_undirected (bool, optional): If set to :obj:`True`, the graph is - assumed to be undirected, and multi-hop connectivity will be - computed accordingly. (default: :obj:`False`) + assumed to be undirected, and the adjacency matrix will be made + symmetric. (default: :obj:`False`) """ def __init__( self, - k: int, - attr_name: Optional[str] = 'hop_pe', + walk_length: int, + attr_name: Optional[str] = 'random_walk_pe', is_undirected: bool = False, ) -> None: - self.k = k + self.walk_length = walk_length self.attr_name = attr_name self.is_undirected = is_undirected @@ -87,73 +50,89 @@ def forward(self, data: HeteroData) -> HeteroData: f"(got '{type(data)}')" ) - # Convert to homogeneous to calculate subgraph hop distances + # Convert to homogeneous homo_data = data.to_homogeneous() edge_index = homo_data.edge_index num_nodes = homo_data.num_nodes if num_nodes == 0: - # Handle empty graph case for node_type in data.node_types: data[node_type][self.attr_name] = torch.zeros( - (data[node_type].num_nodes, self.k), + (data[node_type].num_nodes, self.walk_length), dtype=torch.float, ) return data - # Compute Adjacency Matrix (sparse for efficiency) + # Compute transition matrix (row-stochastic) adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + adj_dense = adj.to_dense() if self.is_undirected: # Make symmetric for undirected graphs - adj = adj + adj.t() - adj = adj.coalesce() - - # Convert to dense for matrix power computation - adj_dense = adj.to_dense() - adj_dense = (adj_dense > 0).float() # Binary adjacency + adj_dense = adj_dense + adj_dense.t() - # Iteratively compute k-hop reachability - hop_pe = torch.zeros((num_nodes, self.k), dtype=torch.float) - current_power = torch.eye(num_nodes, dtype=torch.float) + # Compute degree and create transition matrix + deg = adj_dense.sum(dim=1, keepdim=True) + deg = torch.clamp(deg, min=1) # Avoid division by zero + transition = adj_dense / deg - for i in range(self.k): - current_power = current_power @ adj_dense - # Count number of paths at exactly i+1 hops (normalized) - reachable = (current_power > 0).float() - hop_pe[:, i] = reachable.sum(dim=-1) + # Compute random walk positional encoding + # PE[j, k] = sum of column j excluding diagonal = Σ_{i≠j} (P^k)[i, j] + pe = torch.zeros((num_nodes, self.walk_length), dtype=torch.float) + current = torch.eye(num_nodes, dtype=torch.float) - # Normalize hop distances (optional, for numerical stability) - hop_pe = hop_pe / (hop_pe.max(dim=0, keepdim=True).values + 1e-8) + for k in range(self.walk_length): + current = current @ transition + # Sum each column, excluding diagonal elements + # column_sum[j] = Σ_i current[i, j] + # diagonal[j] = current[j, j] + # non_diagonal_column_sum[j] = column_sum[j] - diagonal[j] + column_sum = current.sum(dim=0) # Sum along rows for each column + diagonal = current.diag() + pe[:, k] = column_sum - diagonal # Map back to HeteroData node types - add_node_attr(data, hop_pe, self.attr_name) + add_node_attr(data, pe, self.attr_name) return data def __repr__(self) -> str: - return f'{self.__class__.__name__}(k={self.k})' + return f'{self.__class__.__name__}(walk_length={self.walk_length})' -@functional_transform('add_hetero_random_walk_pe') -class AddHeteroRandomWalkPE(BaseTransform): - r"""Adds the random walk positional encoding from the +@functional_transform('add_hetero_random_walk_se') +class AddHeteroRandomWalkSE(BaseTransform): + r"""Adds the random walk structural encoding from the `"Graph Neural Networks with Learnable Structural and Positional Representations" `_ paper to the given - heterogeneous graph (functional name: :obj:`add_hetero_random_walk_pe`). + heterogeneous graph (functional name: :obj:`add_hetero_random_walk_se`). + + For each node, computes the probability of returning to itself after k steps + of a random walk, for k = 1, 2, ..., walk_length. This captures the local + structural role of each node (e.g., cycles, clustering coefficient). + + The encoding is the diagonal of the k-step random walk matrix: + SE[i, k] = (P^k)[i, i] + + where P is the transition matrix. Args: walk_length (int): The number of random walk steps. - attr_name (str, optional): The attribute name of the positional - encoding. (default: :obj:`"random_walk_pe"`) + attr_name (str, optional): The attribute name of the structural + encoding. (default: :obj:`"random_walk_se"`) + is_undirected (bool, optional): If set to :obj:`True`, the graph is + assumed to be undirected, and the adjacency matrix will be made + symmetric. (default: :obj:`False`) """ def __init__( self, walk_length: int, - attr_name: Optional[str] = 'random_walk_pe', + attr_name: Optional[str] = 'random_walk_se', + is_undirected: bool = False, ) -> None: self.walk_length = walk_length self.attr_name = attr_name + self.is_undirected = is_undirected def forward(self, data: HeteroData) -> HeteroData: assert isinstance(data, HeteroData), ( @@ -178,24 +157,193 @@ def forward(self, data: HeteroData) -> HeteroData: adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) adj_dense = adj.to_dense() + if self.is_undirected: + # Make symmetric for undirected graphs + adj_dense = adj_dense + adj_dense.t() + # Compute degree and create transition matrix deg = adj_dense.sum(dim=1, keepdim=True) deg = torch.clamp(deg, min=1) # Avoid division by zero transition = adj_dense / deg - # Compute random walk probabilities - pe = torch.zeros((num_nodes, self.walk_length), dtype=torch.float) + # Compute random walk return probabilities (diagonal elements) + se = torch.zeros((num_nodes, self.walk_length), dtype=torch.float) current = torch.eye(num_nodes, dtype=torch.float) for i in range(self.walk_length): current = current @ transition # Diagonal gives probability of returning to the same node - pe[:, i] = current.diag() + se[:, i] = current.diag() # Map back to HeteroData node types - add_node_attr(data, pe, self.attr_name) + add_node_attr(data, se, self.attr_name) return data def __repr__(self) -> str: return f'{self.__class__.__name__}(walk_length={self.walk_length})' + + + +@functional_transform('add_hetero_hop_distance_encoding') +class AddHeteroHopDistanceEncoding(BaseTransform): + r"""Adds hop distance positional encoding as relative encoding. + + For each pair of nodes (vi, vj), computes the shortest path distance p(vi, vj). + This captures structural proximity and can be used with a learnable embedding + matrix: + + h_hop(vi, vj) = W_hop · onehot(p(vi, vj)) + + Based on the approach from `"Do Transformers Really Perform Bad for Graph + Representation?" `_ (Graphormer). + + Args: + h_max (int): Maximum hop distance to consider. Distances > h_max + are clipped to h_max (representing "far" or "unreachable" nodes). + Set to 2 - 3 for 2hop sampled subgraphs. + Set to min(walk_length // 2, 10) for random walk sampled subgraphs. + attr_name (str, optional): The attribute name of the positional + encoding. (default: :obj:`"hop_distance"`) + is_undirected (bool, optional): If set to :obj:`True`, the graph is + assumed to be undirected for distance computation. + (default: :obj:`False`) + full_matrix (bool, optional): If set to :obj:`True`, stores the full + pairwise distance matrix as a graph-level attribute (for use in + Graph Transformers with fully-connected attention). If :obj:`False`, + stores hop distances only for existing edges. Note that when + :obj:`full_matrix=False`, the hop distance for existing edges is + always 1 (direct connection), which may be redundant. Use + :obj:`full_matrix=True` for attention bias in Graph Transformers. + (default: :obj:`False`) + """ + def __init__( + self, + h_max: int, + attr_name: Optional[str] = 'hop_distance', + is_undirected: bool = False, + full_matrix: bool = False, + ) -> None: + self.h_max = h_max + self.attr_name = attr_name + self.is_undirected = is_undirected + self.full_matrix = full_matrix + + def forward(self, data: HeteroData) -> HeteroData: + assert isinstance(data, HeteroData), ( + f"'{self.__class__.__name__}' only supports 'HeteroData' " + f"(got '{type(data)}')" + ) + + # Convert to homogeneous to compute shortest paths + homo_data = data.to_homogeneous() + edge_index = homo_data.edge_index + num_nodes = homo_data.num_nodes + num_edges = edge_index.size(1) + + if num_nodes == 0 or num_edges == 0: + # Handle empty graph case + if self.full_matrix: + data[self.attr_name] = torch.zeros( + (num_nodes, num_nodes), dtype=torch.long + ) + else: + for edge_type in data.edge_types: + num_type_edges = data[edge_type].num_edges + data[edge_type][self.attr_name] = torch.zeros( + num_type_edges, dtype=torch.long + ) + return data + + # Build adjacency matrix for shortest path computation + adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + adj_dense = adj.to_dense() + + if self.is_undirected: + # Make symmetric for undirected graphs + adj_dense = adj_dense + adj_dense.t() + + adj_dense = (adj_dense > 0).float() # Binary adjacency + + # Compute shortest path distances using BFS via matrix powers + # dist_matrix[i, j] = shortest path distance from node i to node j + dist_matrix = torch.full( + (num_nodes, num_nodes), self.h_max, dtype=torch.long + ) + dist_matrix.fill_diagonal_(0) # Distance to self is 0 + + # BFS: track which nodes are reachable and at what distance + reachable = torch.eye(num_nodes, dtype=torch.bool) + current_frontier = adj_dense.bool() + + for hop in range(1, self.h_max + 1): + # Nodes newly reachable at this hop (not previously seen) + newly_reachable = current_frontier & ~reachable + # Set distance for newly reachable nodes + dist_matrix[newly_reachable] = hop + # Update reachable set + reachable = reachable | current_frontier + + # Early exit if all nodes are reachable (no need to continue BFS) + if reachable.all(): + break + + # Expand frontier to next hop neighbors + current_frontier = (current_frontier.float() @ adj_dense) > 0 + + # Example + # Graph: 0 -> 1 -> 2 (3 nodes, chain) + # edge_index = [[0, 1], [1, 2]] (edges: 0→1, 1→2) + # h_max = 3 + + # Initial state: + # dist_matrix = [[0, 3, 3], # 3 = h_max (unreachable) + # [3, 0, 3], + # [3, 3, 0]] + # reachable = eye(3) = [[T,F,F], [F,T,F], [F,F,T]] + # current_frontier = adj_dense = [[F,T,F], [F,F,T], [F,F,F]] + + # Hop 1: + # newly_reachable = [[F,T,F], [F,F,T], [F,F,F]] (0→1, 1→2) + # dist_matrix[0,1] = 1, dist_matrix[1,2] = 1 + # dist_matrix = [[0, 1, 3], + # [3, 0, 1], + # [3, 3, 0]] + + # Hop 2: + # current_frontier = frontier @ adj = [[F,F,T], [F,F,F], [F,F,F]] + # newly_reachable = [[F,F,T], ...] (0→2 in 2 hops) + # dist_matrix[0,2] = 2 + # dist_matrix = [[0, 1, 2], + # [3, 0, 1], + # [3, 3, 0]] + + if self.full_matrix: + # Store full pairwise distance matrix as graph-level attribute on HeteroData + # Shape: (num_nodes, num_nodes) - for use in Graph Transformers + # Access via: data.hop_distance or data['hop_distance'] + # Can be used as attention bias: bias = learnable_embedding[data.hop_distance.long()] + # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) + data[self.attr_name] = dist_matrix.float() + else: + # Extract hop distances for each edge in edge_index + # Example: + # Graph: 0 -> 1 -> 2 (3 nodes, chain) + # edge_index = [[0, 1], [1, 2]] (edges: 0→1, 1→2) + # h_max = 3 + # edge_index = [[0, 1], [1, 2]] + # src_nodes = [0, 1], dst_nodes = [1, 2] + # edge_hop_distances = [dist_matrix[0,1], dist_matrix[1,2]] = [1, 1] + + # Output: tensor([[1.], [1.]]) # Both edges are direct (1-hop) + src_nodes = edge_index[0] # Source nodes + dst_nodes = edge_index[1] # Destination nodes + edge_hop_distances = dist_matrix[src_nodes, dst_nodes] + # Store hop distances only for existing edges + # Map back to HeteroData edge types + add_edge_attr(data, edge_hop_distances.unsqueeze(-1).float(), self.attr_name) + + return data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(h_max={self.h_max}, full_matrix={self.full_matrix})' diff --git a/gigl/transforms/utils.py b/gigl/transforms/utils.py new file mode 100644 index 000000000..ada69687c --- /dev/null +++ b/gigl/transforms/utils.py @@ -0,0 +1,167 @@ +from typing import Dict, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torch_geometric.data import HeteroData + +# Type alias for edge types in PyG HeteroData +EdgeType = Tuple[str, str, str] + + +def add_node_attr( + data: HeteroData, + values: Union[Tensor, Dict[str, Tensor]], + attr_name: Optional[str] = None, + node_type_to_idx: Optional[Dict[str, Tuple[int, int]]] = None, +) -> HeteroData: + """Helper function to add node attributes to a HeteroData object. + + Args: + data: The HeteroData object to modify. + values: Either: + - A tensor of values in homogeneous node order (requires node_type_to_idx + or will be computed from data.node_types), OR + - A dictionary mapping node types to tensors of values for each type. + attr_name: The name of the attribute to add. If None, concatenates to + existing `x` attribute for each node type (or creates it). + node_type_to_idx: Optional mapping from node type to (start, end) indices. + Only used when values is a tensor. If None, it will be computed from + data.node_types. + + Returns: + The modified HeteroData object. + """ + # If values is a dictionary, directly assign to each node type + if isinstance(values, dict): + for node_type, value in values.items(): + if node_type not in data.node_types: + continue + _set_node_attr_for_type(data, node_type, value, attr_name) + return data + + # Otherwise, values is a tensor in homogeneous order - split by node type + if node_type_to_idx is None: + # Build mapping from node type to (start, end) indices in homogeneous tensor + # When HeteroData is converted to homogeneous, nodes are ordered by node type. + # This mapping lets us slice the homogeneous tensor to get values for each type. + # Example: if data has 3 'user' nodes and 2 'item' nodes: + # node_type_to_idx = {'user': (0, 3), 'item': (3, 5)} + node_type_to_idx = {} + start_idx = 0 + for node_type in data.node_types: + num_type_nodes = data[node_type].num_nodes + node_type_to_idx[node_type] = (start_idx, start_idx + num_type_nodes) + start_idx += num_type_nodes + + for node_type in data.node_types: + start, end = node_type_to_idx[node_type] + value = values[start:end] + _set_node_attr_for_type(data, node_type, value, attr_name) + + return data + + +def _set_node_attr_for_type( + data: HeteroData, + node_type: str, + value: Tensor, + attr_name: Optional[str], +) -> None: + """Helper to set node attribute for a single node type.""" + if attr_name is None: + # Concatenate to existing x or create new x + # Use getattr to safely get x attribute, returns None if not present + x = getattr(data[node_type], "x", None) + if x is not None: + # Existing features found: concatenate new values to them + # Reshape 1D tensor [num_nodes] to 2D [num_nodes, 1] for concatenation + x = x.view(-1, 1) if x.dim() == 1 else x + # Move value to same device/dtype as x, then concatenate along feature dim + data[node_type].x = torch.cat( + [x, value.to(x.device, x.dtype)], dim=-1 + ) + else: + # No existing features: use new values as x directly + data[node_type].x = value + else: + data[node_type][attr_name] = value + + +def add_edge_attr( + data: HeteroData, + values: Union[Tensor, Dict[EdgeType, Tensor]], + attr_name: Optional[str] = None, + edge_type_to_idx: Optional[Dict[EdgeType, Tuple[int, int]]] = None, +) -> HeteroData: + """Helper function to add edge attributes to a HeteroData object. + + Args: + data: The HeteroData object to modify. + values: Either: + - A tensor of values in homogeneous edge order (requires edge_type_to_idx + or will be computed from data.edge_types), OR + - A dictionary mapping edge types to tensors of values for each type. + attr_name: The name of the attribute to add. If None, concatenates to + existing `edge_attr` attribute for each edge type (or creates it). + edge_type_to_idx: Optional mapping from edge type to (start, end) indices. + Only used when values is a tensor. If None, it will be computed from + data.edge_types. + + Returns: + The modified HeteroData object. + """ + # If values is a dictionary, directly assign to each edge type + if isinstance(values, dict): + for edge_type, value in values.items(): + if edge_type not in data.edge_types: + continue + _set_edge_attr_for_type(data, edge_type, value, attr_name) + return data + + # Otherwise, values is a tensor in homogeneous order - split by edge type + if edge_type_to_idx is None: + # Build mapping from edge type to (start, end) indices in homogeneous tensor + # When HeteroData is converted to homogeneous, edges are ordered by edge type. + # This mapping lets us slice the homogeneous tensor to get values for each type. + # Example: if data has 3 'buys' edges and 2 'views' edges: + # edge_type_to_idx = {('user', 'buys', 'item'): (0, 3), ('user', 'views', 'item'): (3, 5)} + edge_type_to_idx = {} + start_idx = 0 + for edge_type in data.edge_types: + num_type_edges = data[edge_type].num_edges + edge_type_to_idx[edge_type] = (start_idx, start_idx + num_type_edges) + start_idx += num_type_edges + + for edge_type in data.edge_types: + start, end = edge_type_to_idx[edge_type] + value = values[start:end] + _set_edge_attr_for_type(data, edge_type, value, attr_name) + + return data + + +def _set_edge_attr_for_type( + data: HeteroData, + edge_type: EdgeType, + value: Tensor, + attr_name: Optional[str], +) -> None: + """Helper to set edge attribute for a single edge type.""" + if attr_name is None: + # Concatenate to existing edge_attr or create new edge_attr + # Use getattr to safely get edge_attr attribute, returns None if not present + edge_attr = getattr(data[edge_type], "edge_attr", None) + if edge_attr is not None: + # Existing features found: concatenate new values to them + # Reshape 1D tensor [num_edges] to 2D [num_edges, 1] for concatenation + edge_attr = edge_attr.view(-1, 1) if edge_attr.dim() == 1 else edge_attr + # Move value to same device/dtype as edge_attr, then concatenate along feature dim + data[edge_type].edge_attr = torch.cat( + [edge_attr, value.to(edge_attr.device, edge_attr.dtype)], dim=-1 + ) + else: + # No existing features: use new values as edge_attr directly + data[edge_type].edge_attr = value + else: + data[edge_type][attr_name] = value diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py index 4db4ee41a..713fdcd49 100644 --- a/tests/unit/transforms/add_positional_encodings_test.py +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -3,9 +3,9 @@ from torch_geometric.data import HeteroData from gigl.transforms.add_positional_encodings import ( - AddHeteroHopDistancePE, + AddHeteroHopDistanceEncoding, AddHeteroRandomWalkPE, - add_node_attr, + AddHeteroRandomWalkSE, ) from tests.test_assets.test_case import TestCase @@ -47,108 +47,155 @@ def create_empty_hetero_data() -> HeteroData: return data -class TestAddNodeAttr(TestCase): - def test_add_node_attr_with_attr_name(self): - """Test adding a node attribute with a specific attribute name.""" +class TestAddHeteroHopDistanceEncoding(TestCase): + def test_forward_basic(self): + """Test basic forward pass.""" data = create_simple_hetero_data() + transform = AddHeteroHopDistanceEncoding(h_max=3) + + result = transform(data) + + # Check that PE was added to both edge types + self.assertTrue(hasattr(result['user', 'buys', 'item'], 'hop_distance')) + self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'hop_distance')) + + # Check shapes (3 edges each) + self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) + self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) - # Create values in homogeneous order (3 users + 2 items = 5 nodes) - values = torch.randn(5, 8) + # Direct edges should have distance <= h_max + self.assertTrue((result['user', 'buys', 'item'].hop_distance <= 3).all()) - add_node_attr(data, values, attr_name='test_attr') + def test_forward_with_custom_attr_name(self): + """Test forward pass with custom attribute name.""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistanceEncoding(h_max=2, attr_name='custom_hop') + + result = transform(data) - self.assertEqual(data['user'].test_attr.shape, (3, 8)) - self.assertEqual(data['item'].test_attr.shape, (2, 8)) - self.assert_tensor_equality(data['user'].test_attr, values[:3]) - self.assert_tensor_equality(data['item'].test_attr, values[3:]) + self.assertTrue(hasattr(result['user', 'buys', 'item'], 'custom_hop')) + self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'custom_hop')) + self.assertFalse(hasattr(result['user', 'buys', 'item'], 'hop_distance')) - def test_add_node_attr_concatenate_to_x(self): - """Test adding a node attribute by concatenating to existing x.""" + def test_forward_undirected(self): + """Test forward pass with undirected graph setting.""" data = create_simple_hetero_data() - original_user_x = data['user'].x.clone() - original_item_x = data['item'].x.clone() + transform = AddHeteroHopDistanceEncoding(h_max=2, is_undirected=True) + + result = transform(data) + + self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) + self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) + + def test_forward_empty_graph(self): + """Test forward pass with empty graph.""" + data = create_empty_hetero_data() + # Add empty edge types + data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) + transform = AddHeteroHopDistanceEncoding(h_max=3) + + result = transform(data) - # Create values in homogeneous order - values = torch.randn(5, 8) + self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (0,)) - add_node_attr(data, values, attr_name=None) + def test_forward_full_matrix(self): + """Test forward pass with full_matrix=True for Graph Transformer use.""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) - # Check that x was concatenated - self.assertEqual(data['user'].x.shape, (3, 12)) # 4 + 8 - self.assertEqual(data['item'].x.shape, (2, 12)) # 4 + 8 + result = transform(data) - # Check original features are preserved - self.assert_tensor_equality(data['user'].x[:, :4], original_user_x) - self.assert_tensor_equality(data['item'].x[:, :4], original_item_x) + # Check that full pairwise distance matrix is stored as graph-level attribute + self.assertTrue(hasattr(result, 'hop_distance')) + # Total nodes: 3 users + 2 items = 5 nodes + self.assertEqual(result.hop_distance.shape, (5, 5)) + # Diagonal should be 0 (distance to self) + self.assertTrue((result.hop_distance.diag() == 0).all()) + # All distances should be <= h_max + self.assertTrue((result.hop_distance <= 3).all()) + + def test_forward_full_matrix_empty_graph(self): + """Test forward pass with full_matrix=True on empty graph.""" + data = create_empty_hetero_data() + data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) - def test_add_node_attr_create_x_if_none(self): - """Test creating x attribute if it doesn't exist.""" - data = HeteroData() - data['user'].num_nodes = 3 - data['item'].num_nodes = 2 + result = transform(data) - values = torch.randn(5, 8) + self.assertTrue(hasattr(result, 'hop_distance')) + self.assertEqual(result.hop_distance.shape, (0, 0)) - add_node_attr(data, values, attr_name=None) + def test_repr(self): + """Test string representation.""" + transform = AddHeteroHopDistanceEncoding(h_max=5) + self.assertEqual(repr(transform), 'AddHeteroHopDistanceEncoding(h_max=5, full_matrix=False)') - self.assertEqual(data['user'].x.shape, (3, 8)) - self.assertEqual(data['item'].x.shape, (2, 8)) + transform_full = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) + self.assertEqual(repr(transform_full), 'AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True)') -class TestAddHeteroHopDistancePE(TestCase): +class TestAddHeteroRandomWalkSE(TestCase): + """Tests for AddHeteroRandomWalkSE (Structural Encoding - diagonal elements).""" + def test_forward_basic(self): """Test basic forward pass.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistancePE(k=3) + transform = AddHeteroRandomWalkSE(walk_length=4) result = transform(data) - # Check that PE was added to both node types - self.assertTrue(hasattr(result['user'], 'hop_pe')) - self.assertTrue(hasattr(result['item'], 'hop_pe')) + # Check that SE was added to both node types + self.assertTrue(hasattr(result['user'], 'random_walk_se')) + self.assertTrue(hasattr(result['item'], 'random_walk_se')) # Check shapes - self.assertEqual(result['user'].hop_pe.shape, (3, 3)) - self.assertEqual(result['item'].hop_pe.shape, (2, 3)) + self.assertEqual(result['user'].random_walk_se.shape, (3, 4)) + self.assertEqual(result['item'].random_walk_se.shape, (2, 4)) + + # Values should be probabilities (between 0 and 1) + self.assertTrue((result['user'].random_walk_se >= 0).all()) + self.assertTrue((result['user'].random_walk_se <= 1).all()) def test_forward_with_custom_attr_name(self): """Test forward pass with custom attribute name.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistancePE(k=2, attr_name='custom_pe') + transform = AddHeteroRandomWalkSE(walk_length=3, attr_name='rw_se') result = transform(data) - self.assertTrue(hasattr(result['user'], 'custom_pe')) - self.assertTrue(hasattr(result['item'], 'custom_pe')) - self.assertFalse(hasattr(result['user'], 'hop_pe')) + self.assertTrue(hasattr(result['user'], 'rw_se')) + self.assertTrue(hasattr(result['item'], 'rw_se')) + self.assertFalse(hasattr(result['user'], 'random_walk_se')) def test_forward_undirected(self): """Test forward pass with undirected graph setting.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistancePE(k=2, is_undirected=True) + transform = AddHeteroRandomWalkSE(walk_length=3, is_undirected=True) result = transform(data) - self.assertEqual(result['user'].hop_pe.shape, (3, 2)) - self.assertEqual(result['item'].hop_pe.shape, (2, 2)) + self.assertEqual(result['user'].random_walk_se.shape, (3, 3)) + self.assertEqual(result['item'].random_walk_se.shape, (2, 3)) def test_forward_empty_graph(self): """Test forward pass with empty graph.""" data = create_empty_hetero_data() - transform = AddHeteroHopDistancePE(k=3) + transform = AddHeteroRandomWalkSE(walk_length=3) result = transform(data) - self.assertEqual(result['user'].hop_pe.shape, (0, 3)) - self.assertEqual(result['item'].hop_pe.shape, (0, 3)) + self.assertEqual(result['user'].random_walk_se.shape, (0, 3)) + self.assertEqual(result['item'].random_walk_se.shape, (0, 3)) def test_repr(self): """Test string representation.""" - transform = AddHeteroHopDistancePE(k=5) - self.assertEqual(repr(transform), 'AddHeteroHopDistancePE(k=5)') + transform = AddHeteroRandomWalkSE(walk_length=10) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10)') class TestAddHeteroRandomWalkPE(TestCase): + """Tests for AddHeteroRandomWalkPE (Positional Encoding - column sum of non-diagonal).""" + def test_forward_basic(self): """Test basic forward pass.""" data = create_simple_hetero_data() @@ -175,6 +222,16 @@ def test_forward_with_custom_attr_name(self): self.assertTrue(hasattr(result['item'], 'rw_pe')) self.assertFalse(hasattr(result['user'], 'random_walk_pe')) + def test_forward_undirected(self): + """Test forward pass with undirected graph setting.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=3, is_undirected=True) + + result = transform(data) + + self.assertEqual(result['user'].random_walk_pe.shape, (3, 3)) + self.assertEqual(result['item'].random_walk_pe.shape, (2, 3)) + def test_forward_empty_graph(self): """Test forward pass with empty graph.""" data = create_empty_hetero_data() @@ -185,6 +242,23 @@ def test_forward_empty_graph(self): self.assertEqual(result['user'].random_walk_pe.shape, (0, 3)) self.assertEqual(result['item'].random_walk_pe.shape, (0, 3)) + def test_pe_differs_from_se(self): + """Test that PE (column sum) differs from SE (diagonal).""" + data = create_simple_hetero_data() + transform_pe = AddHeteroRandomWalkPE(walk_length=4) + transform_se = AddHeteroRandomWalkSE(walk_length=4) + + result_pe = transform_pe(data.clone()) + result_se = transform_se(data.clone()) + + # PE and SE should have different values (column sum vs diagonal) + # They may occasionally match for specific graphs, but generally differ + pe_values = result_pe['user'].random_walk_pe + se_values = result_se['user'].random_walk_se + + # Check shapes are the same + self.assertEqual(pe_values.shape, se_values.shape) + def test_repr(self): """Test string representation.""" transform = AddHeteroRandomWalkPE(walk_length=10) diff --git a/tests/unit/transforms/utils_test.py b/tests/unit/transforms/utils_test.py new file mode 100644 index 000000000..5ad1219bb --- /dev/null +++ b/tests/unit/transforms/utils_test.py @@ -0,0 +1,208 @@ +import torch +from absl.testing import absltest +from torch_geometric.data import HeteroData + +from gigl.transforms.utils import add_node_attr, add_edge_attr +from tests.test_assets.test_case import TestCase + + +def create_simple_hetero_data() -> HeteroData: + """Create a simple heterogeneous graph for testing. + + Graph structure: + - 3 'user' nodes + - 2 'item' nodes + - Edges: user -> item (bipartite) + """ + data = HeteroData() + + # Node features + data['user'].x = torch.randn(3, 4) + data['item'].x = torch.randn(2, 4) + + # Edges: user -> item + data['user', 'buys', 'item'].edge_index = torch.tensor([ + [0, 1, 2], # source (user) + [0, 0, 1], # target (item) + ]) + + # Edges: item -> user (reverse) + data['item', 'bought_by', 'user'].edge_index = torch.tensor([ + [0, 0, 1], # source (item) + [0, 1, 2], # target (user) + ]) + + return data + + +class TestAddNodeAttr(TestCase): + def test_add_node_attr_with_attr_name(self): + """Test adding a node attribute with a specific attribute name.""" + data = create_simple_hetero_data() + + # Create values in homogeneous order (3 users + 2 items = 5 nodes) + values = torch.randn(5, 8) + + add_node_attr(data, values, attr_name='test_attr') + + self.assertEqual(data['user'].test_attr.shape, (3, 8)) + self.assertEqual(data['item'].test_attr.shape, (2, 8)) + self.assert_tensor_equality(data['user'].test_attr, values[:3]) + self.assert_tensor_equality(data['item'].test_attr, values[3:]) + + def test_add_node_attr_with_dict(self): + """Test adding a node attribute with dictionary input.""" + data = create_simple_hetero_data() + + # Create values as dictionary per node type + values = { + 'user': torch.randn(3, 8), + 'item': torch.randn(2, 8), + } + + add_node_attr(data, values, attr_name='test_attr') + + self.assertEqual(data['user'].test_attr.shape, (3, 8)) + self.assertEqual(data['item'].test_attr.shape, (2, 8)) + self.assert_tensor_equality(data['user'].test_attr, values['user']) + self.assert_tensor_equality(data['item'].test_attr, values['item']) + + def test_add_node_attr_with_dict_partial(self): + """Test adding a node attribute with dictionary containing only some node types.""" + data = create_simple_hetero_data() + + # Only provide values for 'user' node type + values = { + 'user': torch.randn(3, 8), + } + + add_node_attr(data, values, attr_name='test_attr') + + self.assertTrue(hasattr(data['user'], 'test_attr')) + self.assertEqual(data['user'].test_attr.shape, (3, 8)) + self.assertFalse(hasattr(data['item'], 'test_attr')) + + def test_add_node_attr_concatenate_to_x(self): + """Test adding a node attribute by concatenating to existing x.""" + data = create_simple_hetero_data() + original_user_x = data['user'].x.clone() + original_item_x = data['item'].x.clone() + + # Create values in homogeneous order + values = torch.randn(5, 8) + + add_node_attr(data, values, attr_name=None) + + # Check that x was concatenated + self.assertEqual(data['user'].x.shape, (3, 12)) # 4 + 8 + self.assertEqual(data['item'].x.shape, (2, 12)) # 4 + 8 + + # Check original features are preserved + self.assert_tensor_equality(data['user'].x[:, :4], original_user_x) + self.assert_tensor_equality(data['item'].x[:, :4], original_item_x) + + def test_add_node_attr_create_x_if_none(self): + """Test creating x attribute if it doesn't exist.""" + data = HeteroData() + data['user'].num_nodes = 3 + data['item'].num_nodes = 2 + + values = torch.randn(5, 8) + + add_node_attr(data, values, attr_name=None) + + self.assertEqual(data['user'].x.shape, (3, 8)) + self.assertEqual(data['item'].x.shape, (2, 8)) + + +class TestAddEdgeAttr(TestCase): + def test_add_edge_attr_with_attr_name(self): + """Test adding an edge attribute with a specific attribute name.""" + data = create_simple_hetero_data() + + # Create values in homogeneous order (3 + 3 = 6 edges) + values = torch.randn(6, 8) + + add_edge_attr(data, values, attr_name='test_attr') + + self.assertEqual(data['user', 'buys', 'item'].test_attr.shape, (3, 8)) + self.assertEqual(data['item', 'bought_by', 'user'].test_attr.shape, (3, 8)) + self.assert_tensor_equality(data['user', 'buys', 'item'].test_attr, values[:3]) + self.assert_tensor_equality(data['item', 'bought_by', 'user'].test_attr, values[3:]) + + def test_add_edge_attr_with_dict(self): + """Test adding an edge attribute with dictionary input.""" + data = create_simple_hetero_data() + + # Create values as dictionary per edge type + values = { + ('user', 'buys', 'item'): torch.randn(3, 8), + ('item', 'bought_by', 'user'): torch.randn(3, 8), + } + + add_edge_attr(data, values, attr_name='test_attr') + + self.assertEqual(data['user', 'buys', 'item'].test_attr.shape, (3, 8)) + self.assertEqual(data['item', 'bought_by', 'user'].test_attr.shape, (3, 8)) + self.assert_tensor_equality(data['user', 'buys', 'item'].test_attr, values[('user', 'buys', 'item')]) + self.assert_tensor_equality(data['item', 'bought_by', 'user'].test_attr, values[('item', 'bought_by', 'user')]) + + def test_add_edge_attr_with_dict_partial(self): + """Test adding an edge attribute with dictionary containing only some edge types.""" + data = create_simple_hetero_data() + + # Only provide values for one edge type + values = { + ('user', 'buys', 'item'): torch.randn(3, 8), + } + + add_edge_attr(data, values, attr_name='test_attr') + + self.assertTrue(hasattr(data['user', 'buys', 'item'], 'test_attr')) + self.assertEqual(data['user', 'buys', 'item'].test_attr.shape, (3, 8)) + self.assertFalse(hasattr(data['item', 'bought_by', 'user'], 'test_attr')) + + def test_add_edge_attr_concatenate_to_edge_attr(self): + """Test adding an edge attribute by concatenating to existing edge_attr.""" + data = create_simple_hetero_data() + + # Add initial edge attributes + data['user', 'buys', 'item'].edge_attr = torch.randn(3, 4) + data['item', 'bought_by', 'user'].edge_attr = torch.randn(3, 4) + + original_buys_attr = data['user', 'buys', 'item'].edge_attr.clone() + original_bought_by_attr = data['item', 'bought_by', 'user'].edge_attr.clone() + + # Create values in homogeneous order + values = torch.randn(6, 8) + + add_edge_attr(data, values, attr_name=None) + + # Check that edge_attr was concatenated + self.assertEqual(data['user', 'buys', 'item'].edge_attr.shape, (3, 12)) # 4 + 8 + self.assertEqual(data['item', 'bought_by', 'user'].edge_attr.shape, (3, 12)) # 4 + 8 + + # Check original features are preserved + self.assert_tensor_equality(data['user', 'buys', 'item'].edge_attr[:, :4], original_buys_attr) + self.assert_tensor_equality(data['item', 'bought_by', 'user'].edge_attr[:, :4], original_bought_by_attr) + + def test_add_edge_attr_create_edge_attr_if_none(self): + """Test creating edge_attr attribute if it doesn't exist.""" + data = create_simple_hetero_data() + + # Ensure no edge_attr exists + if hasattr(data['user', 'buys', 'item'], 'edge_attr'): + del data['user', 'buys', 'item'].edge_attr + if hasattr(data['item', 'bought_by', 'user'], 'edge_attr'): + del data['item', 'bought_by', 'user'].edge_attr + + values = torch.randn(6, 8) + + add_edge_attr(data, values, attr_name=None) + + self.assertEqual(data['user', 'buys', 'item'].edge_attr.shape, (3, 8)) + self.assertEqual(data['item', 'bought_by', 'user'].edge_attr.shape, (3, 8)) + + +if __name__ == '__main__': + absltest.main() From cc90c6f922fd5ae85aa737ecdad5677aff0381ec Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Tue, 24 Feb 2026 11:14:12 -0800 Subject: [PATCH 04/10] updates --- gigl/.claude/worktrees/loving-noyce | 1 + gigl/.claude/worktrees/zen-bhaskara | 1 + gigl/transforms/add_positional_encodings.py | 146 ++++++++--- .../add_positional_encodings_test.py | 247 +++++++++++------- 4 files changed, 268 insertions(+), 127 deletions(-) create mode 160000 gigl/.claude/worktrees/loving-noyce create mode 160000 gigl/.claude/worktrees/zen-bhaskara diff --git a/gigl/.claude/worktrees/loving-noyce b/gigl/.claude/worktrees/loving-noyce new file mode 160000 index 000000000..817294a12 --- /dev/null +++ b/gigl/.claude/worktrees/loving-noyce @@ -0,0 +1 @@ +Subproject commit 817294a12143f513ce396b50344ba3a767d84c6f diff --git a/gigl/.claude/worktrees/zen-bhaskara b/gigl/.claude/worktrees/zen-bhaskara new file mode 160000 index 000000000..4297f49cc --- /dev/null +++ b/gigl/.claude/worktrees/zen-bhaskara @@ -0,0 +1 @@ +Subproject commit 4297f49ccaa4f9956400d0c18438681757cd2ca8 diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 58d8066af..ddfbf4e7f 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -9,6 +9,65 @@ from gigl.transforms.utils import add_node_attr, add_edge_attr +r""" +Positional and Structural Encodings for Heterogeneous Graphs. + +This module provides PyG-compatible transforms for adding positional and structural +encodings to HeteroData objects. All transforms follow the PyG BaseTransform interface +and can be composed using `torch_geometric.transforms.Compose`. + +Available Transforms: + - AddHeteroRandomWalkPE: Random walk positional encoding (column sum of non-diagonal) + - AddHeteroRandomWalkSE: Random walk structural encoding (diagonal elements) + - AddHeteroHopDistanceEncoding: Shortest path distance encoding + +Example Usage: + >>> from torch_geometric.data import HeteroData + >>> from torch_geometric.transforms import Compose + >>> from gigl.transforms.add_positional_encodings import ( + ... AddHeteroRandomWalkPE, + ... AddHeteroRandomWalkSE, + ... AddHeteroHopDistanceEncoding, + ... ) + >>> + >>> # Create a heterogeneous graph + >>> data = HeteroData() + >>> data['user'].x = torch.randn(5, 16) + >>> data['item'].x = torch.randn(3, 16) + >>> data['user', 'buys', 'item'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]]) + >>> data['item', 'bought_by', 'user'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]]) + >>> + >>> # Apply single transform + >>> transform = AddHeteroRandomWalkPE(walk_length=8) + >>> data = transform(data) + >>> print(data['user'].random_walk_pe.shape) # (5, 8) + >>> + >>> # Compose multiple transforms + >>> transform = Compose([ + ... AddHeteroRandomWalkPE(walk_length=8), + ... AddHeteroRandomWalkSE(walk_length=8), + ... AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True), + ... ]) + >>> data = transform(data) + >>> + >>> # For Graph Transformers, use full_matrix=True to get full pairwise distances + >>> transform = AddHeteroHopDistanceEncoding(h_max=5, full_matrix=True) + >>> data = transform(data) + >>> print(data.hop_distance.shape) # (num_total_nodes, num_total_nodes) + >>> + >>> # For heterogeneous Graph Transformers, use node_type_aware=True to preserve + >>> # node type information for type-aware attention bias + >>> transform = AddHeteroHopDistanceEncoding(h_max=5, full_matrix=True, node_type_aware=True) + >>> data = transform(data) + >>> print(data.hop_distance.shape) # (8, 8) - pairwise distances + >>> print(data.node_type_ids.shape) # (8,) - type ID for each node + >>> print(data.node_type_pair.shape) # (8, 8) - encodes (src_type, dst_type) pairs + >>> print(data.node_type_names) # ['item', 'user'] - sorted alphabetically + >>> + >>> # In a Graph Transformer, combine hop distance and node type for attention bias: + >>> # bias = hop_embedding[data.hop_distance.long()] + type_pair_embedding[data.node_type_pair] +""" + @functional_transform('add_hetero_random_walk_pe') class AddHeteroRandomWalkPE(BaseTransform): @@ -198,6 +257,15 @@ class AddHeteroHopDistanceEncoding(BaseTransform): Based on the approach from `"Do Transformers Really Perform Bad for Graph Representation?" `_ (Graphormer). + For heterogeneous graphs, when `full_matrix=True`, additional node type + information can be preserved by setting `node_type_aware=True`. This stores: + - `data.hop_distance`: (num_nodes, num_nodes) distance matrix + - `data.node_type_ids`: (num_nodes,) node type ID for each node + - `data.node_type_pair`: (num_nodes, num_nodes) encodes (src_type, dst_type) pairs + + This allows Graph Transformers to use both structural (hop distance) and + semantic (node type) information in attention bias computation. + Args: h_max (int): Maximum hop distance to consider. Distances > h_max are clipped to h_max (representing "far" or "unreachable" nodes). @@ -215,6 +283,13 @@ class AddHeteroHopDistanceEncoding(BaseTransform): :obj:`full_matrix=False`, the hop distance for existing edges is always 1 (direct connection), which may be redundant. Use :obj:`full_matrix=True` for attention bias in Graph Transformers. + (default: :obj:`True`) + node_type_aware (bool, optional): If set to :obj:`True` (only effective + when `full_matrix=True`), also stores node type information: + - `node_type_ids`: (num_nodes,) tensor mapping each node to its type ID + - `node_type_pair`: (num_nodes, num_nodes) tensor encoding the + (src_type, dst_type) pair as `src_type * num_node_types + dst_type` + This enables type-aware attention bias in heterogeneous Graph Transformers. (default: :obj:`False`) """ def __init__( @@ -222,12 +297,14 @@ def __init__( h_max: int, attr_name: Optional[str] = 'hop_distance', is_undirected: bool = False, - full_matrix: bool = False, + full_matrix: bool = True, + node_type_aware: bool = False, ) -> None: self.h_max = h_max self.attr_name = attr_name self.is_undirected = is_undirected self.full_matrix = full_matrix + self.node_type_aware = node_type_aware def forward(self, data: HeteroData) -> HeteroData: assert isinstance(data, HeteroData), ( @@ -247,6 +324,11 @@ def forward(self, data: HeteroData) -> HeteroData: data[self.attr_name] = torch.zeros( (num_nodes, num_nodes), dtype=torch.long ) + if self.node_type_aware: + data['node_type_ids'] = torch.zeros(num_nodes, dtype=torch.long) + data['node_type_pair'] = torch.zeros( + (num_nodes, num_nodes), dtype=torch.long + ) else: for edge_type in data.edge_types: num_type_edges = data[edge_type].num_edges @@ -291,33 +373,6 @@ def forward(self, data: HeteroData) -> HeteroData: # Expand frontier to next hop neighbors current_frontier = (current_frontier.float() @ adj_dense) > 0 - # Example - # Graph: 0 -> 1 -> 2 (3 nodes, chain) - # edge_index = [[0, 1], [1, 2]] (edges: 0→1, 1→2) - # h_max = 3 - - # Initial state: - # dist_matrix = [[0, 3, 3], # 3 = h_max (unreachable) - # [3, 0, 3], - # [3, 3, 0]] - # reachable = eye(3) = [[T,F,F], [F,T,F], [F,F,T]] - # current_frontier = adj_dense = [[F,T,F], [F,F,T], [F,F,F]] - - # Hop 1: - # newly_reachable = [[F,T,F], [F,F,T], [F,F,F]] (0→1, 1→2) - # dist_matrix[0,1] = 1, dist_matrix[1,2] = 1 - # dist_matrix = [[0, 1, 3], - # [3, 0, 1], - # [3, 3, 0]] - - # Hop 2: - # current_frontier = frontier @ adj = [[F,F,T], [F,F,F], [F,F,F]] - # newly_reachable = [[F,F,T], ...] (0→2 in 2 hops) - # dist_matrix[0,2] = 2 - # dist_matrix = [[0, 1, 2], - # [3, 0, 1], - # [3, 3, 0]] - if self.full_matrix: # Store full pairwise distance matrix as graph-level attribute on HeteroData # Shape: (num_nodes, num_nodes) - for use in Graph Transformers @@ -325,17 +380,29 @@ def forward(self, data: HeteroData) -> HeteroData: # Can be used as attention bias: bias = learnable_embedding[data.hop_distance.long()] # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) data[self.attr_name] = dist_matrix.float() + + if self.node_type_aware: + # Store node type information for heterogeneous-aware attention + # homo_data.node_type contains the type ID for each node after to_homogeneous() + node_type_ids = homo_data.node_type # Shape: (num_nodes,) + data['node_type_ids'] = node_type_ids + + # Compute pairwise node type encoding: (src_type, dst_type) -> single ID + # node_type_pair[i, j] = node_type_ids[i] * num_node_types + node_type_ids[j] + # This allows looking up type-specific attention biases + num_node_types = len(data.node_types) + # Outer product style: src_types[:, None] * num_types + dst_types[None, :] + node_type_pair = ( + node_type_ids.unsqueeze(1) * num_node_types + + node_type_ids.unsqueeze(0) + ) + data['node_type_pair'] = node_type_pair + + # Also store the mapping from type ID to type name for reference + # Node types are sorted alphabetically in to_homogeneous() + data['node_type_names'] = sorted(data.node_types) else: # Extract hop distances for each edge in edge_index - # Example: - # Graph: 0 -> 1 -> 2 (3 nodes, chain) - # edge_index = [[0, 1], [1, 2]] (edges: 0→1, 1→2) - # h_max = 3 - # edge_index = [[0, 1], [1, 2]] - # src_nodes = [0, 1], dst_nodes = [1, 2] - # edge_hop_distances = [dist_matrix[0,1], dist_matrix[1,2]] = [1, 1] - - # Output: tensor([[1.], [1.]]) # Both edges are direct (1-hop) src_nodes = edge_index[0] # Source nodes dst_nodes = edge_index[1] # Destination nodes edge_hop_distances = dist_matrix[src_nodes, dst_nodes] @@ -346,4 +413,7 @@ def forward(self, data: HeteroData) -> HeteroData: return data def __repr__(self) -> str: - return f'{self.__class__.__name__}(h_max={self.h_max}, full_matrix={self.full_matrix})' + return ( + f'{self.__class__.__name__}(h_max={self.h_max}, ' + f'full_matrix={self.full_matrix}, node_type_aware={self.node_type_aware})' + ) diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py index 713fdcd49..137f7a96d 100644 --- a/tests/unit/transforms/add_positional_encodings_test.py +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -47,91 +47,60 @@ def create_empty_hetero_data() -> HeteroData: return data -class TestAddHeteroHopDistanceEncoding(TestCase): + +class TestAddHeteroRandomWalkPE(TestCase): + """Tests for AddHeteroRandomWalkPE (Positional Encoding - column sum of non-diagonal).""" + def test_forward_basic(self): """Test basic forward pass.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=3) + transform = AddHeteroRandomWalkPE(walk_length=4) result = transform(data) - # Check that PE was added to both edge types - self.assertTrue(hasattr(result['user', 'buys', 'item'], 'hop_distance')) - self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'hop_distance')) - - # Check shapes (3 edges each) - self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) - self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) + # Check that PE was added to both node types + self.assertTrue(hasattr(result['user'], 'random_walk_pe')) + self.assertTrue(hasattr(result['item'], 'random_walk_pe')) - # Direct edges should have distance <= h_max - self.assertTrue((result['user', 'buys', 'item'].hop_distance <= 3).all()) + # Check shapes + self.assertEqual(result['user'].random_walk_pe.shape, (3, 4)) + self.assertEqual(result['item'].random_walk_pe.shape, (2, 4)) def test_forward_with_custom_attr_name(self): """Test forward pass with custom attribute name.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=2, attr_name='custom_hop') + transform = AddHeteroRandomWalkPE(walk_length=3, attr_name='rw_pe') result = transform(data) - self.assertTrue(hasattr(result['user', 'buys', 'item'], 'custom_hop')) - self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'custom_hop')) - self.assertFalse(hasattr(result['user', 'buys', 'item'], 'hop_distance')) + self.assertTrue(hasattr(result['user'], 'rw_pe')) + self.assertTrue(hasattr(result['item'], 'rw_pe')) + self.assertFalse(hasattr(result['user'], 'random_walk_pe')) def test_forward_undirected(self): """Test forward pass with undirected graph setting.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=2, is_undirected=True) + transform = AddHeteroRandomWalkPE(walk_length=3, is_undirected=True) result = transform(data) - self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) - self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) + self.assertEqual(result['user'].random_walk_pe.shape, (3, 3)) + self.assertEqual(result['item'].random_walk_pe.shape, (2, 3)) def test_forward_empty_graph(self): """Test forward pass with empty graph.""" data = create_empty_hetero_data() - # Add empty edge types - data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) - transform = AddHeteroHopDistanceEncoding(h_max=3) - - result = transform(data) - - self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (0,)) - - def test_forward_full_matrix(self): - """Test forward pass with full_matrix=True for Graph Transformer use.""" - data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) - - result = transform(data) - - # Check that full pairwise distance matrix is stored as graph-level attribute - self.assertTrue(hasattr(result, 'hop_distance')) - # Total nodes: 3 users + 2 items = 5 nodes - self.assertEqual(result.hop_distance.shape, (5, 5)) - # Diagonal should be 0 (distance to self) - self.assertTrue((result.hop_distance.diag() == 0).all()) - # All distances should be <= h_max - self.assertTrue((result.hop_distance <= 3).all()) - - def test_forward_full_matrix_empty_graph(self): - """Test forward pass with full_matrix=True on empty graph.""" - data = create_empty_hetero_data() - data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) + transform = AddHeteroRandomWalkPE(walk_length=3) result = transform(data) - self.assertTrue(hasattr(result, 'hop_distance')) - self.assertEqual(result.hop_distance.shape, (0, 0)) + self.assertEqual(result['user'].random_walk_pe.shape, (0, 3)) + self.assertEqual(result['item'].random_walk_pe.shape, (0, 3)) def test_repr(self): """Test string representation.""" - transform = AddHeteroHopDistanceEncoding(h_max=5) - self.assertEqual(repr(transform), 'AddHeteroHopDistanceEncoding(h_max=5, full_matrix=False)') - - transform_full = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) - self.assertEqual(repr(transform_full), 'AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True)') + transform = AddHeteroRandomWalkPE(walk_length=10) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10)') class TestAddHeteroRandomWalkSE(TestCase): @@ -193,76 +162,176 @@ def test_repr(self): self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10)') -class TestAddHeteroRandomWalkPE(TestCase): - """Tests for AddHeteroRandomWalkPE (Positional Encoding - column sum of non-diagonal).""" - +class TestAddHeteroHopDistanceEncoding(TestCase): def test_forward_basic(self): - """Test basic forward pass.""" + """Test basic forward pass with full_matrix=True (default).""" data = create_simple_hetero_data() - transform = AddHeteroRandomWalkPE(walk_length=4) + transform = AddHeteroHopDistanceEncoding(h_max=3) result = transform(data) - # Check that PE was added to both node types - self.assertTrue(hasattr(result['user'], 'random_walk_pe')) - self.assertTrue(hasattr(result['item'], 'random_walk_pe')) + # Check that full pairwise distance matrix is stored as graph-level attribute + self.assertTrue(hasattr(result, 'hop_distance')) + # Total nodes: 3 users + 2 items = 5 nodes + self.assertEqual(result.hop_distance.shape, (5, 5)) - # Check shapes - self.assertEqual(result['user'].random_walk_pe.shape, (3, 4)) - self.assertEqual(result['item'].random_walk_pe.shape, (2, 4)) + def test_forward_full_matrix_false(self): + """Test forward pass with full_matrix=False (edge-level).""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=False) + + result = transform(data) + + # Check that PE was added to both edge types + self.assertTrue(hasattr(result['user', 'buys', 'item'], 'hop_distance')) + self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'hop_distance')) + + # Check shapes (3 edges each) + self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) + self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) + + # Direct edges should have distance <= h_max + self.assertTrue((result['user', 'buys', 'item'].hop_distance <= 3).all()) def test_forward_with_custom_attr_name(self): """Test forward pass with custom attribute name.""" data = create_simple_hetero_data() - transform = AddHeteroRandomWalkPE(walk_length=3, attr_name='rw_pe') + transform = AddHeteroHopDistanceEncoding(h_max=2, attr_name='custom_hop', full_matrix=False) result = transform(data) - self.assertTrue(hasattr(result['user'], 'rw_pe')) - self.assertTrue(hasattr(result['item'], 'rw_pe')) - self.assertFalse(hasattr(result['user'], 'random_walk_pe')) + self.assertTrue(hasattr(result['user', 'buys', 'item'], 'custom_hop')) + self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'custom_hop')) + self.assertFalse(hasattr(result['user', 'buys', 'item'], 'hop_distance')) def test_forward_undirected(self): """Test forward pass with undirected graph setting.""" data = create_simple_hetero_data() - transform = AddHeteroRandomWalkPE(walk_length=3, is_undirected=True) + transform = AddHeteroHopDistanceEncoding(h_max=2, is_undirected=True, full_matrix=False) result = transform(data) - self.assertEqual(result['user'].random_walk_pe.shape, (3, 3)) - self.assertEqual(result['item'].random_walk_pe.shape, (2, 3)) + self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) + self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) def test_forward_empty_graph(self): """Test forward pass with empty graph.""" data = create_empty_hetero_data() - transform = AddHeteroRandomWalkPE(walk_length=3) + # Add empty edge types + data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=False) result = transform(data) - self.assertEqual(result['user'].random_walk_pe.shape, (0, 3)) - self.assertEqual(result['item'].random_walk_pe.shape, (0, 3)) + self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (0,)) + + def test_forward_full_matrix(self): + """Test forward pass with full_matrix=True for Graph Transformer use.""" + data = create_simple_hetero_data() + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) + + result = transform(data) + + # Check that full pairwise distance matrix is stored as graph-level attribute + self.assertTrue(hasattr(result, 'hop_distance')) + # Total nodes: 3 users + 2 items = 5 nodes + self.assertEqual(result.hop_distance.shape, (5, 5)) + # Diagonal should be 0 (distance to self) + self.assertTrue((result.hop_distance.diag() == 0).all()) + # All distances should be <= h_max + self.assertTrue((result.hop_distance <= 3).all()) + + def test_forward_full_matrix_empty_graph(self): + """Test forward pass with full_matrix=True on empty graph.""" + data = create_empty_hetero_data() + data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) + + result = transform(data) - def test_pe_differs_from_se(self): - """Test that PE (column sum) differs from SE (diagonal).""" + self.assertTrue(hasattr(result, 'hop_distance')) + self.assertEqual(result.hop_distance.shape, (0, 0)) + + def test_forward_node_type_aware(self): + """Test forward pass with node_type_aware=True for heterogeneous Graph Transformers.""" data = create_simple_hetero_data() - transform_pe = AddHeteroRandomWalkPE(walk_length=4) - transform_se = AddHeteroRandomWalkSE(walk_length=4) + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True) + + result = transform(data) - result_pe = transform_pe(data.clone()) - result_se = transform_se(data.clone()) + # Check that hop distance matrix is stored + self.assertTrue(hasattr(result, 'hop_distance')) + self.assertEqual(result.hop_distance.shape, (5, 5)) - # PE and SE should have different values (column sum vs diagonal) - # They may occasionally match for specific graphs, but generally differ - pe_values = result_pe['user'].random_walk_pe - se_values = result_se['user'].random_walk_se + # Check that node type information is stored + self.assertTrue(hasattr(result, 'node_type_ids')) + self.assertEqual(result.node_type_ids.shape, (5,)) - # Check shapes are the same - self.assertEqual(pe_values.shape, se_values.shape) + # Check that node type pair matrix is stored + self.assertTrue(hasattr(result, 'node_type_pair')) + self.assertEqual(result.node_type_pair.shape, (5, 5)) + + # Check that node type names are stored + self.assertTrue(hasattr(result, 'node_type_names')) + self.assertEqual(result.node_type_names, ['item', 'user']) # Sorted alphabetically + + # Verify node_type_ids values are valid (0 or 1 for 2 node types) + self.assertTrue((result.node_type_ids >= 0).all()) + self.assertTrue((result.node_type_ids < 2).all()) + + # Verify node_type_pair encodes (src_type, dst_type) correctly + # For 2 node types, pair values should be in [0, 3] (0*2+0, 0*2+1, 1*2+0, 1*2+1) + self.assertTrue((result.node_type_pair >= 0).all()) + self.assertTrue((result.node_type_pair < 4).all()) + + def test_forward_node_type_aware_empty_graph(self): + """Test forward pass with node_type_aware=True on empty graph.""" + data = create_empty_hetero_data() + data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True) + + result = transform(data) + + self.assertTrue(hasattr(result, 'hop_distance')) + self.assertEqual(result.hop_distance.shape, (0, 0)) + self.assertTrue(hasattr(result, 'node_type_ids')) + self.assertEqual(result.node_type_ids.shape, (0,)) + self.assertTrue(hasattr(result, 'node_type_pair')) + self.assertEqual(result.node_type_pair.shape, (0, 0)) + + def test_node_type_aware_only_with_full_matrix(self): + """Test that node_type_aware only takes effect when full_matrix=True.""" + data = create_simple_hetero_data() + # node_type_aware=True but full_matrix=False should not add node type info + transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=False, node_type_aware=True) + + result = transform(data) + + # Should have edge-level hop_distance + self.assertTrue(hasattr(result['user', 'buys', 'item'], 'hop_distance')) + # Should NOT have node type information (only added when full_matrix=True) + self.assertFalse(hasattr(result, 'node_type_ids')) + self.assertFalse(hasattr(result, 'node_type_pair')) def test_repr(self): """Test string representation.""" - transform = AddHeteroRandomWalkPE(walk_length=10) - self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10)') + transform = AddHeteroHopDistanceEncoding(h_max=5, full_matrix=False) + self.assertEqual( + repr(transform), + 'AddHeteroHopDistanceEncoding(h_max=5, full_matrix=False, node_type_aware=False)' + ) + + transform_full = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) + self.assertEqual( + repr(transform_full), + 'AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=False)' + ) + + transform_type_aware = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True) + self.assertEqual( + repr(transform_type_aware), + 'AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True)' + ) if __name__ == '__main__': From d4e97a15126cde7807203a742a250b1b84d8a86c Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Tue, 24 Feb 2026 11:16:25 -0800 Subject: [PATCH 05/10] rm claude --- gigl/.claude/worktrees/loving-noyce | 1 - gigl/.claude/worktrees/zen-bhaskara | 1 - 2 files changed, 2 deletions(-) delete mode 160000 gigl/.claude/worktrees/loving-noyce delete mode 160000 gigl/.claude/worktrees/zen-bhaskara diff --git a/gigl/.claude/worktrees/loving-noyce b/gigl/.claude/worktrees/loving-noyce deleted file mode 160000 index 817294a12..000000000 --- a/gigl/.claude/worktrees/loving-noyce +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 817294a12143f513ce396b50344ba3a767d84c6f diff --git a/gigl/.claude/worktrees/zen-bhaskara b/gigl/.claude/worktrees/zen-bhaskara deleted file mode 160000 index 4297f49cc..000000000 --- a/gigl/.claude/worktrees/zen-bhaskara +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4297f49ccaa4f9956400d0c18438681757cd2ca8 From 8d366a30d74b98bd0fe034c290f2208da33185cb Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Tue, 24 Feb 2026 17:34:54 -0800 Subject: [PATCH 06/10] update to sparse operations --- gigl/transforms/add_positional_encodings.py | 297 ++++++++++++------ .../add_positional_encodings_test.py | 137 +++++++- 2 files changed, 338 insertions(+), 96 deletions(-) diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index ddfbf4e7f..90e378ad8 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -92,16 +92,21 @@ class AddHeteroRandomWalkPE(BaseTransform): is_undirected (bool, optional): If set to :obj:`True`, the graph is assumed to be undirected, and the adjacency matrix will be made symmetric. (default: :obj:`False`) + attach_to_x (bool, optional): If set to :obj:`True`, the encoding is + concatenated directly to :obj:`data[node_type].x` for each node type + instead of being stored as a separate attribute. (default: :obj:`False`) """ def __init__( self, walk_length: int, attr_name: Optional[str] = 'random_walk_pe', is_undirected: bool = False, + attach_to_x: bool = False, ) -> None: self.walk_length = walk_length self.attr_name = attr_name self.is_undirected = is_undirected + self.attach_to_x = attach_to_x def forward(self, data: HeteroData) -> HeteroData: assert isinstance(data, HeteroData), ( @@ -116,47 +121,73 @@ def forward(self, data: HeteroData) -> HeteroData: if num_nodes == 0: for node_type in data.node_types: - data[node_type][self.attr_name] = torch.zeros( + empty_pe = torch.zeros( (data[node_type].num_nodes, self.walk_length), dtype=torch.float, ) + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, {node_type: empty_pe}, effective_attr_name) return data - # Compute transition matrix (row-stochastic) + # Compute transition matrix (row-stochastic) using sparse operations adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) - adj_dense = adj.to_dense() if self.is_undirected: # Make symmetric for undirected graphs - adj_dense = adj_dense + adj_dense.t() + adj = (adj + adj.t()).coalesce() - # Compute degree and create transition matrix - deg = adj_dense.sum(dim=1, keepdim=True) + # Compute degree for row normalization + adj_coalesced = adj.coalesce() + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, adj_coalesced.indices()[0], adj_coalesced.values().float()) deg = torch.clamp(deg, min=1) # Avoid division by zero - transition = adj_dense / deg - # Compute random walk positional encoding + # Create row-normalized transition matrix (sparse) + # P[i,j] = A[i,j] / deg[i] + row_indices = adj_coalesced.indices()[0] + normalized_values = adj_coalesced.values().float() / deg[row_indices] + transition = torch.sparse_coo_tensor( + adj_coalesced.indices(), + normalized_values, + size=(num_nodes, num_nodes), + ).coalesce() + + # Compute random walk positional encoding using sparse operations # PE[j, k] = sum of column j excluding diagonal = Σ_{i≠j} (P^k)[i, j] - pe = torch.zeros((num_nodes, self.walk_length), dtype=torch.float) - current = torch.eye(num_nodes, dtype=torch.float) + pe = torch.zeros((num_nodes, self.walk_length), dtype=torch.float, device=edge_index.device) + + # Start with identity matrix (sparse) + identity_indices = torch.arange(num_nodes, device=edge_index.device) + current = torch.sparse_coo_tensor( + torch.stack([identity_indices, identity_indices]), + torch.ones(num_nodes, device=edge_index.device), + size=(num_nodes, num_nodes), + ).coalesce() for k in range(self.walk_length): - current = current @ transition - # Sum each column, excluding diagonal elements - # column_sum[j] = Σ_i current[i, j] - # diagonal[j] = current[j, j] - # non_diagonal_column_sum[j] = column_sum[j] - diagonal[j] - column_sum = current.sum(dim=0) # Sum along rows for each column - diagonal = current.diag() - pe[:, k] = column_sum - diagonal + current = torch.sparse.mm(current, transition).coalesce() + # Column sum = sum over rows for each column + col_sum = torch.zeros(num_nodes, device=edge_index.device) + col_sum.scatter_add_(0, current.indices()[1], current.values()) + # Extract diagonal elements + diag = torch.zeros(num_nodes, device=edge_index.device) + diag_mask = current.indices()[0] == current.indices()[1] + if diag_mask.any(): + diag.scatter_add_(0, current.indices()[0][diag_mask], current.values()[diag_mask]) + pe[:, k] = col_sum - diag # Map back to HeteroData node types - add_node_attr(data, pe, self.attr_name) + # If attach_to_x is True, pass None as attr_name to concatenate to x directly + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, pe, effective_attr_name) return data def __repr__(self) -> str: - return f'{self.__class__.__name__}(walk_length={self.walk_length})' + return ( + f'{self.__class__.__name__}(walk_length={self.walk_length}, ' + f'attach_to_x={self.attach_to_x})' + ) @functional_transform('add_hetero_random_walk_se') @@ -182,16 +213,21 @@ class AddHeteroRandomWalkSE(BaseTransform): is_undirected (bool, optional): If set to :obj:`True`, the graph is assumed to be undirected, and the adjacency matrix will be made symmetric. (default: :obj:`False`) + attach_to_x (bool, optional): If set to :obj:`True`, the encoding is + concatenated directly to :obj:`data[node_type].x` for each node type + instead of being stored as a separate attribute. (default: :obj:`False`) """ def __init__( self, walk_length: int, attr_name: Optional[str] = 'random_walk_se', is_undirected: bool = False, + attach_to_x: bool = False, ) -> None: self.walk_length = walk_length self.attr_name = attr_name self.is_undirected = is_undirected + self.attach_to_x = attach_to_x def forward(self, data: HeteroData) -> HeteroData: assert isinstance(data, HeteroData), ( @@ -206,41 +242,67 @@ def forward(self, data: HeteroData) -> HeteroData: if num_nodes == 0: for node_type in data.node_types: - data[node_type][self.attr_name] = torch.zeros( + empty_se = torch.zeros( (data[node_type].num_nodes, self.walk_length), dtype=torch.float, ) + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, {node_type: empty_se}, effective_attr_name) return data - # Compute transition matrix (row-stochastic) + # Compute transition matrix (row-stochastic) using sparse operations adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) - adj_dense = adj.to_dense() if self.is_undirected: # Make symmetric for undirected graphs - adj_dense = adj_dense + adj_dense.t() + adj = (adj + adj.t()).coalesce() - # Compute degree and create transition matrix - deg = adj_dense.sum(dim=1, keepdim=True) + # Compute degree for row normalization + adj_coalesced = adj.coalesce() + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, adj_coalesced.indices()[0], adj_coalesced.values().float()) deg = torch.clamp(deg, min=1) # Avoid division by zero - transition = adj_dense / deg - # Compute random walk return probabilities (diagonal elements) - se = torch.zeros((num_nodes, self.walk_length), dtype=torch.float) - current = torch.eye(num_nodes, dtype=torch.float) + # Create row-normalized transition matrix (sparse) + # P[i,j] = A[i,j] / deg[i] + row_indices = adj_coalesced.indices()[0] + normalized_values = adj_coalesced.values().float() / deg[row_indices] + transition = torch.sparse_coo_tensor( + adj_coalesced.indices(), + normalized_values, + size=(num_nodes, num_nodes), + ).coalesce() + + # Compute random walk return probabilities (diagonal elements) using sparse operations + se = torch.zeros((num_nodes, self.walk_length), dtype=torch.float, device=edge_index.device) + + # Start with identity matrix (sparse) + identity_indices = torch.arange(num_nodes, device=edge_index.device) + current = torch.sparse_coo_tensor( + torch.stack([identity_indices, identity_indices]), + torch.ones(num_nodes, device=edge_index.device), + size=(num_nodes, num_nodes), + ).coalesce() - for i in range(self.walk_length): - current = current @ transition - # Diagonal gives probability of returning to the same node - se[:, i] = current.diag() + for k in range(self.walk_length): + current = torch.sparse.mm(current, transition).coalesce() + # Extract diagonal elements: probability of returning to the same node + diag_mask = current.indices()[0] == current.indices()[1] + if diag_mask.any(): + se[:, k].scatter_add_(0, current.indices()[0][diag_mask], current.values()[diag_mask]) # Map back to HeteroData node types - add_node_attr(data, se, self.attr_name) + # If attach_to_x is True, pass None as attr_name to concatenate to x directly + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, se, effective_attr_name) return data def __repr__(self) -> str: - return f'{self.__class__.__name__}(walk_length={self.walk_length})' + return ( + f'{self.__class__.__name__}(walk_length={self.walk_length}, ' + f'attach_to_x={self.attach_to_x})' + ) @@ -337,78 +399,125 @@ def forward(self, data: HeteroData) -> HeteroData: ) return data - # Build adjacency matrix for shortest path computation + # Build sparse adjacency matrix for shortest path computation adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) - adj_dense = adj.to_dense() if self.is_undirected: # Make symmetric for undirected graphs - adj_dense = adj_dense + adj_dense.t() + adj = (adj + adj.t()).coalesce() + + # Binarize adjacency (sparse) + adj_coalesced = adj.coalesce() + adj = torch.sparse_coo_tensor( + adj_coalesced.indices(), + torch.ones(adj_coalesced.indices().size(1), device=edge_index.device), + size=(num_nodes, num_nodes), + ).coalesce() + + if not self.full_matrix: + # For edge-level distances only, use sparse BFS from source nodes + # This avoids materializing the full distance matrix + src_nodes = edge_index[0] + dst_nodes = edge_index[1] + edge_hop_distances = torch.full((num_edges,), self.h_max, dtype=torch.long, device=edge_index.device) + + # For direct edges, distance is 1 + edge_hop_distances[:] = 1 - adj_dense = (adj_dense > 0).float() # Binary adjacency + # Map back to HeteroData edge types + add_edge_attr(data, edge_hop_distances.unsqueeze(-1).float(), self.attr_name) + return data - # Compute shortest path distances using BFS via matrix powers - # dist_matrix[i, j] = shortest path distance from node i to node j + # For full_matrix=True, we need the complete distance matrix + # Use sparse BFS but accumulate into dense distance matrix + device = edge_index.device dist_matrix = torch.full( - (num_nodes, num_nodes), self.h_max, dtype=torch.long + (num_nodes, num_nodes), self.h_max, dtype=torch.long, device=device ) dist_matrix.fill_diagonal_(0) # Distance to self is 0 - # BFS: track which nodes are reachable and at what distance - reachable = torch.eye(num_nodes, dtype=torch.bool) - current_frontier = adj_dense.bool() + # Track reachability using sparse tensors + # reachable[i,j] = 1 if j is reachable from i + identity_indices = torch.arange(num_nodes, device=device) + reachable_indices = torch.stack([identity_indices, identity_indices]) + reachable_values = torch.ones(num_nodes, device=device, dtype=torch.bool) + + # Current frontier (sparse): nodes reachable at current hop + frontier = adj.coalesce() for hop in range(1, self.h_max + 1): - # Nodes newly reachable at this hop (not previously seen) - newly_reachable = current_frontier & ~reachable - # Set distance for newly reachable nodes - dist_matrix[newly_reachable] = hop - # Update reachable set - reachable = reachable | current_frontier - - # Early exit if all nodes are reachable (no need to continue BFS) - if reachable.all(): + frontier_indices = frontier.indices() + frontier_values = frontier.values() + + if frontier_indices.size(1) == 0: break - # Expand frontier to next hop neighbors - current_frontier = (current_frontier.float() @ adj_dense) > 0 - - if self.full_matrix: - # Store full pairwise distance matrix as graph-level attribute on HeteroData - # Shape: (num_nodes, num_nodes) - for use in Graph Transformers - # Access via: data.hop_distance or data['hop_distance'] - # Can be used as attention bias: bias = learnable_embedding[data.hop_distance.long()] - # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) - data[self.attr_name] = dist_matrix.float() - - if self.node_type_aware: - # Store node type information for heterogeneous-aware attention - # homo_data.node_type contains the type ID for each node after to_homogeneous() - node_type_ids = homo_data.node_type # Shape: (num_nodes,) - data['node_type_ids'] = node_type_ids - - # Compute pairwise node type encoding: (src_type, dst_type) -> single ID - # node_type_pair[i, j] = node_type_ids[i] * num_node_types + node_type_ids[j] - # This allows looking up type-specific attention biases - num_node_types = len(data.node_types) - # Outer product style: src_types[:, None] * num_types + dst_types[None, :] - node_type_pair = ( - node_type_ids.unsqueeze(1) * num_node_types + - node_type_ids.unsqueeze(0) - ) - data['node_type_pair'] = node_type_pair - - # Also store the mapping from type ID to type name for reference - # Node types are sorted alphabetically in to_homogeneous() - data['node_type_names'] = sorted(data.node_types) - else: - # Extract hop distances for each edge in edge_index - src_nodes = edge_index[0] # Source nodes - dst_nodes = edge_index[1] # Destination nodes - edge_hop_distances = dist_matrix[src_nodes, dst_nodes] - # Store hop distances only for existing edges - # Map back to HeteroData edge types - add_edge_attr(data, edge_hop_distances.unsqueeze(-1).float(), self.attr_name) + # Find newly reachable: in frontier but not in reachable + # Check each (i,j) in frontier against reachable + frontier_i = frontier_indices[0] + frontier_j = frontier_indices[1] + + # Create a set of reachable pairs for fast lookup + # Convert to linear indices for comparison + reachable_linear = reachable_indices[0] * num_nodes + reachable_indices[1] + frontier_linear = frontier_i * num_nodes + frontier_j + + # Find which frontier edges are not yet reachable + # Use searchsorted for efficiency + reachable_linear_sorted, sort_idx = reachable_linear.sort() + insert_pos = torch.searchsorted(reachable_linear_sorted, frontier_linear) + insert_pos = insert_pos.clamp(max=reachable_linear_sorted.size(0) - 1) + is_new = reachable_linear_sorted[insert_pos] != frontier_linear + + if is_new.any(): + new_i = frontier_i[is_new] + new_j = frontier_j[is_new] + dist_matrix[new_i, new_j] = hop + + # Update reachable set + reachable_indices = torch.cat([ + reachable_indices, + torch.stack([new_i, new_j]) + ], dim=1) + reachable_values = torch.cat([ + reachable_values, + torch.ones(new_i.size(0), device=device, dtype=torch.bool) + ]) + + # Check if all pairs are reachable + if reachable_indices.size(1) >= num_nodes * num_nodes: + break + + # Expand frontier: frontier = frontier @ adj (sparse matmul) + frontier = torch.sparse.mm(frontier, adj).coalesce() + + # Store full pairwise distance matrix as graph-level attribute on HeteroData + # Shape: (num_nodes, num_nodes) - for use in Graph Transformers + # Access via: data.hop_distance or data['hop_distance'] + # Can be used as attention bias: bias = learnable_embedding[data.hop_distance.long()] + # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) + data[self.attr_name] = dist_matrix.float() + + if self.node_type_aware: + # Store node type information for heterogeneous-aware attention + # homo_data.node_type contains the type ID for each node after to_homogeneous() + node_type_ids = homo_data.node_type # Shape: (num_nodes,) + data['node_type_ids'] = node_type_ids + + # Compute pairwise node type encoding: (src_type, dst_type) -> single ID + # node_type_pair[i, j] = node_type_ids[i] * num_node_types + node_type_ids[j] + # This allows looking up type-specific attention biases + num_node_types = len(data.node_types) + # Outer product style: src_types[:, None] * num_types + dst_types[None, :] + node_type_pair = ( + node_type_ids.unsqueeze(1) * num_node_types + + node_type_ids.unsqueeze(0) + ) + data['node_type_pair'] = node_type_pair + + # Also store the mapping from type ID to type name for reference + # Node types are sorted alphabetically in to_homogeneous() + data['node_type_names'] = sorted(data.node_types) return data diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py index 137f7a96d..f6d1a1c62 100644 --- a/tests/unit/transforms/add_positional_encodings_test.py +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -100,7 +100,69 @@ def test_forward_empty_graph(self): def test_repr(self): """Test string representation.""" transform = AddHeteroRandomWalkPE(walk_length=10) - self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10)') + self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10, attach_to_x=False)') + + def test_forward_attach_to_x(self): + """Test forward pass with attach_to_x=True concatenates PE to node features.""" + data = create_simple_hetero_data() + original_user_dim = data['user'].x.shape[1] # 4 + original_item_dim = data['item'].x.shape[1] # 4 + walk_length = 3 + transform = AddHeteroRandomWalkPE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check that PE was NOT added as separate attribute + self.assertFalse(hasattr(result['user'], 'random_walk_pe')) + self.assertFalse(hasattr(result['item'], 'random_walk_pe')) + + # Check that x was expanded with PE dimensions + self.assertEqual(result['user'].x.shape, (3, original_user_dim + walk_length)) + self.assertEqual(result['item'].x.shape, (2, original_item_dim + walk_length)) + + def test_forward_attach_to_x_no_existing_features(self): + """Test forward pass with attach_to_x=True when nodes have no existing features.""" + data = HeteroData() + data['user'].num_nodes = 3 + data['item'].num_nodes = 2 + data['user', 'buys', 'item'].edge_index = torch.tensor([ + [0, 1, 2], + [0, 0, 1], + ]) + data['item', 'bought_by', 'user'].edge_index = torch.tensor([ + [0, 0, 1], + [0, 1, 2], + ]) + + walk_length = 4 + transform = AddHeteroRandomWalkPE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check that x was created with PE as features + self.assertTrue(hasattr(result['user'], 'x')) + self.assertTrue(hasattr(result['item'], 'x')) + self.assertEqual(result['user'].x.shape, (3, walk_length)) + self.assertEqual(result['item'].x.shape, (2, walk_length)) + + def test_forward_attach_to_x_empty_graph(self): + """Test forward pass with attach_to_x=True on empty graph.""" + data = create_empty_hetero_data() + original_user_dim = data['user'].x.shape[1] # 4 + original_item_dim = data['item'].x.shape[1] # 4 + walk_length = 3 + transform = AddHeteroRandomWalkPE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check shapes on empty graph + self.assertEqual(result['user'].x.shape, (0, original_user_dim + walk_length)) + self.assertEqual(result['item'].x.shape, (0, original_item_dim + walk_length)) + + def test_repr_attach_to_x(self): + """Test string representation with attach_to_x=True.""" + transform = AddHeteroRandomWalkPE(walk_length=10, attach_to_x=True) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10, attach_to_x=True)') class TestAddHeteroRandomWalkSE(TestCase): @@ -159,7 +221,78 @@ def test_forward_empty_graph(self): def test_repr(self): """Test string representation.""" transform = AddHeteroRandomWalkSE(walk_length=10) - self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10)') + self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10, attach_to_x=False)') + + def test_forward_attach_to_x(self): + """Test forward pass with attach_to_x=True concatenates SE to node features.""" + data = create_simple_hetero_data() + original_user_dim = data['user'].x.shape[1] # 4 + original_item_dim = data['item'].x.shape[1] # 4 + walk_length = 3 + transform = AddHeteroRandomWalkSE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check that SE was NOT added as separate attribute + self.assertFalse(hasattr(result['user'], 'random_walk_se')) + self.assertFalse(hasattr(result['item'], 'random_walk_se')) + + # Check that x was expanded with SE dimensions + self.assertEqual(result['user'].x.shape, (3, original_user_dim + walk_length)) + self.assertEqual(result['item'].x.shape, (2, original_item_dim + walk_length)) + + # The appended values should be valid probabilities (between 0 and 1) + # Extract the SE portion (last walk_length columns) + user_se = result['user'].x[:, -walk_length:] + item_se = result['item'].x[:, -walk_length:] + self.assertTrue((user_se >= 0).all()) + self.assertTrue((user_se <= 1).all()) + self.assertTrue((item_se >= 0).all()) + self.assertTrue((item_se <= 1).all()) + + def test_forward_attach_to_x_no_existing_features(self): + """Test forward pass with attach_to_x=True when nodes have no existing features.""" + data = HeteroData() + data['user'].num_nodes = 3 + data['item'].num_nodes = 2 + data['user', 'buys', 'item'].edge_index = torch.tensor([ + [0, 1, 2], + [0, 0, 1], + ]) + data['item', 'bought_by', 'user'].edge_index = torch.tensor([ + [0, 0, 1], + [0, 1, 2], + ]) + + walk_length = 4 + transform = AddHeteroRandomWalkSE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check that x was created with SE as features + self.assertTrue(hasattr(result['user'], 'x')) + self.assertTrue(hasattr(result['item'], 'x')) + self.assertEqual(result['user'].x.shape, (3, walk_length)) + self.assertEqual(result['item'].x.shape, (2, walk_length)) + + def test_forward_attach_to_x_empty_graph(self): + """Test forward pass with attach_to_x=True on empty graph.""" + data = create_empty_hetero_data() + original_user_dim = data['user'].x.shape[1] # 4 + original_item_dim = data['item'].x.shape[1] # 4 + walk_length = 3 + transform = AddHeteroRandomWalkSE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check shapes on empty graph + self.assertEqual(result['user'].x.shape, (0, original_user_dim + walk_length)) + self.assertEqual(result['item'].x.shape, (0, original_item_dim + walk_length)) + + def test_repr_attach_to_x(self): + """Test string representation with attach_to_x=True.""" + transform = AddHeteroRandomWalkSE(walk_length=10, attach_to_x=True) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10, attach_to_x=True)') class TestAddHeteroHopDistanceEncoding(TestCase): From 00667eef506df490913d3c0dc90e27974e0c0f59 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Wed, 25 Feb 2026 17:07:51 -0800 Subject: [PATCH 07/10] update hop distance --- gigl/transforms/add_positional_encodings.py | 227 +++++++++--------- .../add_positional_encodings_test.py | 118 ++++----- 2 files changed, 148 insertions(+), 197 deletions(-) diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 90e378ad8..0e0baebde 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -7,7 +7,7 @@ from torch_geometric.transforms import BaseTransform from torch_geometric.utils import to_torch_sparse_tensor -from gigl.transforms.utils import add_node_attr, add_edge_attr +from gigl.transforms.utils import add_node_attr r""" Positional and Structural Encodings for Heterogeneous Graphs. @@ -46,22 +46,24 @@ >>> transform = Compose([ ... AddHeteroRandomWalkPE(walk_length=8), ... AddHeteroRandomWalkSE(walk_length=8), - ... AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True), + ... AddHeteroHopDistanceEncoding(h_max=3), ... ]) >>> data = transform(data) >>> - >>> # For Graph Transformers, use full_matrix=True to get full pairwise distances - >>> transform = AddHeteroHopDistanceEncoding(h_max=5, full_matrix=True) + >>> # For Graph Transformers, use hop distance encoding for attention bias + >>> # Returns sparse matrix (0 for unreachable, 1-h_max for reachable pairs) + >>> transform = AddHeteroHopDistanceEncoding(h_max=5) >>> data = transform(data) - >>> print(data.hop_distance.shape) # (num_total_nodes, num_total_nodes) + >>> print(data.hop_distance.shape) # (num_total_nodes, num_total_nodes) sparse + >>> print(data.hop_distance.is_sparse) # True >>> >>> # For heterogeneous Graph Transformers, use node_type_aware=True to preserve >>> # node type information for type-aware attention bias - >>> transform = AddHeteroHopDistanceEncoding(h_max=5, full_matrix=True, node_type_aware=True) + >>> transform = AddHeteroHopDistanceEncoding(h_max=5, node_type_aware=True) >>> data = transform(data) - >>> print(data.hop_distance.shape) # (8, 8) - pairwise distances + >>> print(data.hop_distance.shape) # (8, 8) - sparse pairwise distances >>> print(data.node_type_ids.shape) # (8,) - type ID for each node - >>> print(data.node_type_pair.shape) # (8, 8) - encodes (src_type, dst_type) pairs + >>> print(data.node_type_pair.shape) # (8, 8) - sparse (src_type, dst_type) pairs >>> print(data.node_type_names) # ['item', 'user'] - sorted alphabetically >>> >>> # In a Graph Transformer, combine hop distance and node type for attention bias: @@ -308,7 +310,7 @@ def __repr__(self) -> str: @functional_transform('add_hetero_hop_distance_encoding') class AddHeteroHopDistanceEncoding(BaseTransform): - r"""Adds hop distance positional encoding as relative encoding. + r"""Adds hop distance positional encoding as relative encoding (sparse). For each pair of nodes (vi, vj), computes the shortest path distance p(vi, vj). This captures structural proximity and can be used with a learnable embedding @@ -319,37 +321,36 @@ class AddHeteroHopDistanceEncoding(BaseTransform): Based on the approach from `"Do Transformers Really Perform Bad for Graph Representation?" `_ (Graphormer). - For heterogeneous graphs, when `full_matrix=True`, additional node type - information can be preserved by setting `node_type_aware=True`. This stores: - - `data.hop_distance`: (num_nodes, num_nodes) distance matrix + The output is a **sparse matrix** where: + - Reachable pairs (i, j) within h_max hops have value = hop distance (1 to h_max) + - Unreachable pairs have value = 0 (not stored in sparse tensor) + - Self-loops (diagonal) are not stored (distance to self is implicitly 0) + + This sparse representation avoids GPU memory blowup for large graphs. + + For heterogeneous graphs, additional node type information can be preserved + by setting `node_type_aware=True`. This stores: + - `data.hop_distance`: sparse (num_nodes, num_nodes) distance matrix - `data.node_type_ids`: (num_nodes,) node type ID for each node - - `data.node_type_pair`: (num_nodes, num_nodes) encodes (src_type, dst_type) pairs + - `data.node_type_pair`: sparse (num_nodes, num_nodes) encodes (src_type, dst_type) pairs This allows Graph Transformers to use both structural (hop distance) and semantic (node type) information in attention bias computation. Args: h_max (int): Maximum hop distance to consider. Distances > h_max - are clipped to h_max (representing "far" or "unreachable" nodes). - Set to 2 - 3 for 2hop sampled subgraphs. + are treated as unreachable (value 0 in sparse matrix). + Set to 2-3 for 2-hop sampled subgraphs. Set to min(walk_length // 2, 10) for random walk sampled subgraphs. attr_name (str, optional): The attribute name of the positional encoding. (default: :obj:`"hop_distance"`) is_undirected (bool, optional): If set to :obj:`True`, the graph is assumed to be undirected for distance computation. (default: :obj:`False`) - full_matrix (bool, optional): If set to :obj:`True`, stores the full - pairwise distance matrix as a graph-level attribute (for use in - Graph Transformers with fully-connected attention). If :obj:`False`, - stores hop distances only for existing edges. Note that when - :obj:`full_matrix=False`, the hop distance for existing edges is - always 1 (direct connection), which may be redundant. Use - :obj:`full_matrix=True` for attention bias in Graph Transformers. - (default: :obj:`True`) - node_type_aware (bool, optional): If set to :obj:`True` (only effective - when `full_matrix=True`), also stores node type information: + node_type_aware (bool, optional): If set to :obj:`True`, also stores + node type information: - `node_type_ids`: (num_nodes,) tensor mapping each node to its type ID - - `node_type_pair`: (num_nodes, num_nodes) tensor encoding the + - `node_type_pair`: sparse (num_nodes, num_nodes) tensor encoding the (src_type, dst_type) pair as `src_type * num_node_types + dst_type` This enables type-aware attention bias in heterogeneous Graph Transformers. (default: :obj:`False`) @@ -359,13 +360,11 @@ def __init__( h_max: int, attr_name: Optional[str] = 'hop_distance', is_undirected: bool = False, - full_matrix: bool = True, node_type_aware: bool = False, ) -> None: self.h_max = h_max self.attr_name = attr_name self.is_undirected = is_undirected - self.full_matrix = full_matrix self.node_type_aware = node_type_aware def forward(self, data: HeteroData) -> HeteroData: @@ -381,24 +380,20 @@ def forward(self, data: HeteroData) -> HeteroData: num_edges = edge_index.size(1) if num_nodes == 0 or num_edges == 0: - # Handle empty graph case - if self.full_matrix: - data[self.attr_name] = torch.zeros( - (num_nodes, num_nodes), dtype=torch.long - ) - if self.node_type_aware: - data['node_type_ids'] = torch.zeros(num_nodes, dtype=torch.long) - data['node_type_pair'] = torch.zeros( - (num_nodes, num_nodes), dtype=torch.long - ) - else: - for edge_type in data.edge_types: - num_type_edges = data[edge_type].num_edges - data[edge_type][self.attr_name] = torch.zeros( - num_type_edges, dtype=torch.long - ) + # Handle empty graph case - return empty sparse tensor + empty_sparse = torch.sparse_coo_tensor( + torch.zeros((2, 0), dtype=torch.long), + torch.zeros(0, dtype=torch.float), + size=(num_nodes, num_nodes), + ).coalesce() + data[self.attr_name] = empty_sparse + if self.node_type_aware: + data['node_type_ids'] = torch.zeros(num_nodes, dtype=torch.long) + data['node_type_pair'] = empty_sparse return data + device = edge_index.device + # Build sparse adjacency matrix for shortest path computation adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) @@ -410,113 +405,105 @@ def forward(self, data: HeteroData) -> HeteroData: adj_coalesced = adj.coalesce() adj = torch.sparse_coo_tensor( adj_coalesced.indices(), - torch.ones(adj_coalesced.indices().size(1), device=edge_index.device), + torch.ones(adj_coalesced.indices().size(1), device=device), size=(num_nodes, num_nodes), ).coalesce() - if not self.full_matrix: - # For edge-level distances only, use sparse BFS from source nodes - # This avoids materializing the full distance matrix - src_nodes = edge_index[0] - dst_nodes = edge_index[1] - edge_hop_distances = torch.full((num_edges,), self.h_max, dtype=torch.long, device=edge_index.device) - - # For direct edges, distance is 1 - edge_hop_distances[:] = 1 + # Compute sparse BFS to find all reachable pairs within h_max hops + # Store (row, col, distance) for all reachable pairs + all_rows = [] + all_cols = [] + all_dists = [] - # Map back to HeteroData edge types - add_edge_attr(data, edge_hop_distances.unsqueeze(-1).float(), self.attr_name) - return data - - # For full_matrix=True, we need the complete distance matrix - # Use sparse BFS but accumulate into dense distance matrix - device = edge_index.device - dist_matrix = torch.full( - (num_nodes, num_nodes), self.h_max, dtype=torch.long, device=device - ) - dist_matrix.fill_diagonal_(0) # Distance to self is 0 - - # Track reachability using sparse tensors - # reachable[i,j] = 1 if j is reachable from i + # Track which pairs have been visited (using set of linear indices) + # Start with diagonal (self-loops) as visited but don't include in output identity_indices = torch.arange(num_nodes, device=device) - reachable_indices = torch.stack([identity_indices, identity_indices]) - reachable_values = torch.ones(num_nodes, device=device, dtype=torch.bool) + visited_linear = set((identity_indices * num_nodes + identity_indices).tolist()) - # Current frontier (sparse): nodes reachable at current hop + # Current frontier (sparse): edges reachable at current hop frontier = adj.coalesce() for hop in range(1, self.h_max + 1): frontier_indices = frontier.indices() - frontier_values = frontier.values() if frontier_indices.size(1) == 0: break - # Find newly reachable: in frontier but not in reachable - # Check each (i,j) in frontier against reachable frontier_i = frontier_indices[0] frontier_j = frontier_indices[1] - - # Create a set of reachable pairs for fast lookup - # Convert to linear indices for comparison - reachable_linear = reachable_indices[0] * num_nodes + reachable_indices[1] - frontier_linear = frontier_i * num_nodes + frontier_j - - # Find which frontier edges are not yet reachable - # Use searchsorted for efficiency - reachable_linear_sorted, sort_idx = reachable_linear.sort() - insert_pos = torch.searchsorted(reachable_linear_sorted, frontier_linear) - insert_pos = insert_pos.clamp(max=reachable_linear_sorted.size(0) - 1) - is_new = reachable_linear_sorted[insert_pos] != frontier_linear - - if is_new.any(): - new_i = frontier_i[is_new] - new_j = frontier_j[is_new] - dist_matrix[new_i, new_j] = hop - - # Update reachable set - reachable_indices = torch.cat([ - reachable_indices, - torch.stack([new_i, new_j]) - ], dim=1) - reachable_values = torch.cat([ - reachable_values, - torch.ones(new_i.size(0), device=device, dtype=torch.bool) - ]) - - # Check if all pairs are reachable - if reachable_indices.size(1) >= num_nodes * num_nodes: - break + frontier_linear = (frontier_i * num_nodes + frontier_j).tolist() + + # Find newly reachable pairs (not in visited set) + new_mask = [] + new_pairs_linear = [] + for idx, lin_idx in enumerate(frontier_linear): + if lin_idx not in visited_linear: + new_mask.append(idx) + new_pairs_linear.append(lin_idx) + visited_linear.add(lin_idx) + + if new_mask: + new_mask = torch.tensor(new_mask, device=device, dtype=torch.long) + new_i = frontier_i[new_mask] + new_j = frontier_j[new_mask] + + all_rows.append(new_i) + all_cols.append(new_j) + all_dists.append(torch.full((new_i.size(0),), hop, device=device, dtype=torch.float)) # Expand frontier: frontier = frontier @ adj (sparse matmul) frontier = torch.sparse.mm(frontier, adj).coalesce() - # Store full pairwise distance matrix as graph-level attribute on HeteroData - # Shape: (num_nodes, num_nodes) - for use in Graph Transformers + # Build sparse distance matrix + if all_rows: + dist_rows = torch.cat(all_rows) + dist_cols = torch.cat(all_cols) + dist_vals = torch.cat(all_dists) + else: + dist_rows = torch.zeros(0, dtype=torch.long, device=device) + dist_cols = torch.zeros(0, dtype=torch.long, device=device) + dist_vals = torch.zeros(0, dtype=torch.float, device=device) + + # Create sparse distance matrix + # Unreachable pairs have value 0 (not stored) + # Reachable pairs have value = hop distance (1 to h_max) + dist_sparse = torch.sparse_coo_tensor( + torch.stack([dist_rows, dist_cols]), + dist_vals, + size=(num_nodes, num_nodes), + ).coalesce() + + # Store sparse pairwise distance matrix as graph-level attribute # Access via: data.hop_distance or data['hop_distance'] - # Can be used as attention bias: bias = learnable_embedding[data.hop_distance.long()] + # Usage in attention: dist = data.hop_distance.to_dense() for small graphs, + # or use sparse indexing for memory efficiency # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) - data[self.attr_name] = dist_matrix.float() + data[self.attr_name] = dist_sparse if self.node_type_aware: # Store node type information for heterogeneous-aware attention - # homo_data.node_type contains the type ID for each node after to_homogeneous() node_type_ids = homo_data.node_type # Shape: (num_nodes,) data['node_type_ids'] = node_type_ids - # Compute pairwise node type encoding: (src_type, dst_type) -> single ID + # Compute sparse pairwise node type encoding for reachable pairs only # node_type_pair[i, j] = node_type_ids[i] * num_node_types + node_type_ids[j] - # This allows looking up type-specific attention biases num_node_types = len(data.node_types) - # Outer product style: src_types[:, None] * num_types + dst_types[None, :] - node_type_pair = ( - node_type_ids.unsqueeze(1) * num_node_types + - node_type_ids.unsqueeze(0) - ) - data['node_type_pair'] = node_type_pair + if dist_rows.size(0) > 0: + type_pair_vals = ( + node_type_ids[dist_rows] * num_node_types + + node_type_ids[dist_cols] + ).float() + else: + type_pair_vals = torch.zeros(0, dtype=torch.float, device=device) + + node_type_pair_sparse = torch.sparse_coo_tensor( + torch.stack([dist_rows, dist_cols]), + type_pair_vals, + size=(num_nodes, num_nodes), + ).coalesce() + data['node_type_pair'] = node_type_pair_sparse # Also store the mapping from type ID to type name for reference - # Node types are sorted alphabetically in to_homogeneous() data['node_type_names'] = sorted(data.node_types) return data @@ -524,5 +511,5 @@ def forward(self, data: HeteroData) -> HeteroData: def __repr__(self) -> str: return ( f'{self.__class__.__name__}(h_max={self.h_max}, ' - f'full_matrix={self.full_matrix}, node_type_aware={self.node_type_aware})' + f'node_type_aware={self.node_type_aware})' ) diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py index f6d1a1c62..6d7da4443 100644 --- a/tests/unit/transforms/add_positional_encodings_test.py +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -297,111 +297,91 @@ def test_repr_attach_to_x(self): class TestAddHeteroHopDistanceEncoding(TestCase): def test_forward_basic(self): - """Test basic forward pass with full_matrix=True (default).""" + """Test basic forward pass returns sparse matrix.""" data = create_simple_hetero_data() transform = AddHeteroHopDistanceEncoding(h_max=3) result = transform(data) - # Check that full pairwise distance matrix is stored as graph-level attribute + # Check that sparse pairwise distance matrix is stored as graph-level attribute self.assertTrue(hasattr(result, 'hop_distance')) # Total nodes: 3 users + 2 items = 5 nodes self.assertEqual(result.hop_distance.shape, (5, 5)) + # Should be sparse + self.assertTrue(result.hop_distance.is_sparse) - def test_forward_full_matrix_false(self): - """Test forward pass with full_matrix=False (edge-level).""" + def test_forward_sparse_values(self): + """Test that sparse matrix has correct values (0 for unreachable, 1-h_max for reachable).""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=False) + transform = AddHeteroHopDistanceEncoding(h_max=3) result = transform(data) - # Check that PE was added to both edge types - self.assertTrue(hasattr(result['user', 'buys', 'item'], 'hop_distance')) - self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'hop_distance')) + # Convert to dense for easier testing + dense = result.hop_distance.to_dense() - # Check shapes (3 edges each) - self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) - self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) + # Diagonal should be 0 (distance to self, not stored in sparse = 0) + self.assertTrue((dense.diag() == 0).all()) - # Direct edges should have distance <= h_max - self.assertTrue((result['user', 'buys', 'item'].hop_distance <= 3).all()) + # Non-zero values (reachable pairs) should be in [1, h_max] + nonzero_vals = result.hop_distance.values() + if nonzero_vals.numel() > 0: + self.assertTrue((nonzero_vals >= 1).all()) + self.assertTrue((nonzero_vals <= 3).all()) def test_forward_with_custom_attr_name(self): """Test forward pass with custom attribute name.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=2, attr_name='custom_hop', full_matrix=False) + transform = AddHeteroHopDistanceEncoding(h_max=2, attr_name='custom_hop') result = transform(data) - self.assertTrue(hasattr(result['user', 'buys', 'item'], 'custom_hop')) - self.assertTrue(hasattr(result['item', 'bought_by', 'user'], 'custom_hop')) - self.assertFalse(hasattr(result['user', 'buys', 'item'], 'hop_distance')) + self.assertTrue(hasattr(result, 'custom_hop')) + self.assertFalse(hasattr(result, 'hop_distance')) + self.assertTrue(result.custom_hop.is_sparse) def test_forward_undirected(self): """Test forward pass with undirected graph setting.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=2, is_undirected=True, full_matrix=False) + transform = AddHeteroHopDistanceEncoding(h_max=2, is_undirected=True) result = transform(data) - self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (3, 1)) - self.assertEqual(result['item', 'bought_by', 'user'].hop_distance.shape, (3, 1)) + self.assertTrue(result.hop_distance.is_sparse) + self.assertEqual(result.hop_distance.shape, (5, 5)) def test_forward_empty_graph(self): """Test forward pass with empty graph.""" data = create_empty_hetero_data() # Add empty edge types data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=False) - - result = transform(data) - - self.assertEqual(result['user', 'buys', 'item'].hop_distance.shape, (0,)) - - def test_forward_full_matrix(self): - """Test forward pass with full_matrix=True for Graph Transformer use.""" - data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) - - result = transform(data) - - # Check that full pairwise distance matrix is stored as graph-level attribute - self.assertTrue(hasattr(result, 'hop_distance')) - # Total nodes: 3 users + 2 items = 5 nodes - self.assertEqual(result.hop_distance.shape, (5, 5)) - # Diagonal should be 0 (distance to self) - self.assertTrue((result.hop_distance.diag() == 0).all()) - # All distances should be <= h_max - self.assertTrue((result.hop_distance <= 3).all()) - - def test_forward_full_matrix_empty_graph(self): - """Test forward pass with full_matrix=True on empty graph.""" - data = create_empty_hetero_data() - data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) + transform = AddHeteroHopDistanceEncoding(h_max=3) result = transform(data) self.assertTrue(hasattr(result, 'hop_distance')) + self.assertTrue(result.hop_distance.is_sparse) self.assertEqual(result.hop_distance.shape, (0, 0)) def test_forward_node_type_aware(self): """Test forward pass with node_type_aware=True for heterogeneous Graph Transformers.""" data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True) + transform = AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True) result = transform(data) - # Check that hop distance matrix is stored + # Check that hop distance matrix is stored (sparse) self.assertTrue(hasattr(result, 'hop_distance')) + self.assertTrue(result.hop_distance.is_sparse) self.assertEqual(result.hop_distance.shape, (5, 5)) # Check that node type information is stored self.assertTrue(hasattr(result, 'node_type_ids')) self.assertEqual(result.node_type_ids.shape, (5,)) - # Check that node type pair matrix is stored + # Check that node type pair matrix is stored (sparse) self.assertTrue(hasattr(result, 'node_type_pair')) + self.assertTrue(result.node_type_pair.is_sparse) self.assertEqual(result.node_type_pair.shape, (5, 5)) # Check that node type names are stored @@ -414,56 +394,40 @@ def test_forward_node_type_aware(self): # Verify node_type_pair encodes (src_type, dst_type) correctly # For 2 node types, pair values should be in [0, 3] (0*2+0, 0*2+1, 1*2+0, 1*2+1) - self.assertTrue((result.node_type_pair >= 0).all()) - self.assertTrue((result.node_type_pair < 4).all()) + type_pair_vals = result.node_type_pair.values() + if type_pair_vals.numel() > 0: + self.assertTrue((type_pair_vals >= 0).all()) + self.assertTrue((type_pair_vals < 4).all()) def test_forward_node_type_aware_empty_graph(self): """Test forward pass with node_type_aware=True on empty graph.""" data = create_empty_hetero_data() data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True) + transform = AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True) result = transform(data) self.assertTrue(hasattr(result, 'hop_distance')) + self.assertTrue(result.hop_distance.is_sparse) self.assertEqual(result.hop_distance.shape, (0, 0)) self.assertTrue(hasattr(result, 'node_type_ids')) self.assertEqual(result.node_type_ids.shape, (0,)) self.assertTrue(hasattr(result, 'node_type_pair')) + self.assertTrue(result.node_type_pair.is_sparse) self.assertEqual(result.node_type_pair.shape, (0, 0)) - def test_node_type_aware_only_with_full_matrix(self): - """Test that node_type_aware only takes effect when full_matrix=True.""" - data = create_simple_hetero_data() - # node_type_aware=True but full_matrix=False should not add node type info - transform = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=False, node_type_aware=True) - - result = transform(data) - - # Should have edge-level hop_distance - self.assertTrue(hasattr(result['user', 'buys', 'item'], 'hop_distance')) - # Should NOT have node type information (only added when full_matrix=True) - self.assertFalse(hasattr(result, 'node_type_ids')) - self.assertFalse(hasattr(result, 'node_type_pair')) - def test_repr(self): """Test string representation.""" - transform = AddHeteroHopDistanceEncoding(h_max=5, full_matrix=False) + transform = AddHeteroHopDistanceEncoding(h_max=5) self.assertEqual( repr(transform), - 'AddHeteroHopDistanceEncoding(h_max=5, full_matrix=False, node_type_aware=False)' - ) - - transform_full = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True) - self.assertEqual( - repr(transform_full), - 'AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=False)' + 'AddHeteroHopDistanceEncoding(h_max=5, node_type_aware=False)' ) - transform_type_aware = AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True) + transform_type_aware = AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True) self.assertEqual( repr(transform_type_aware), - 'AddHeteroHopDistanceEncoding(h_max=3, full_matrix=True, node_type_aware=True)' + 'AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True)' ) From c34d0d45e379bd62d5c0a52dee372dfdc385af00 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Wed, 25 Feb 2026 17:14:08 -0800 Subject: [PATCH 08/10] simplify --- gigl/transforms/add_positional_encodings.py | 64 +------------------ .../add_positional_encodings_test.py | 61 +----------------- 2 files changed, 2 insertions(+), 123 deletions(-) diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 0e0baebde..58ae217b0 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -56,18 +56,6 @@ >>> data = transform(data) >>> print(data.hop_distance.shape) # (num_total_nodes, num_total_nodes) sparse >>> print(data.hop_distance.is_sparse) # True - >>> - >>> # For heterogeneous Graph Transformers, use node_type_aware=True to preserve - >>> # node type information for type-aware attention bias - >>> transform = AddHeteroHopDistanceEncoding(h_max=5, node_type_aware=True) - >>> data = transform(data) - >>> print(data.hop_distance.shape) # (8, 8) - sparse pairwise distances - >>> print(data.node_type_ids.shape) # (8,) - type ID for each node - >>> print(data.node_type_pair.shape) # (8, 8) - sparse (src_type, dst_type) pairs - >>> print(data.node_type_names) # ['item', 'user'] - sorted alphabetically - >>> - >>> # In a Graph Transformer, combine hop distance and node type for attention bias: - >>> # bias = hop_embedding[data.hop_distance.long()] + type_pair_embedding[data.node_type_pair] """ @@ -328,15 +316,6 @@ class AddHeteroHopDistanceEncoding(BaseTransform): This sparse representation avoids GPU memory blowup for large graphs. - For heterogeneous graphs, additional node type information can be preserved - by setting `node_type_aware=True`. This stores: - - `data.hop_distance`: sparse (num_nodes, num_nodes) distance matrix - - `data.node_type_ids`: (num_nodes,) node type ID for each node - - `data.node_type_pair`: sparse (num_nodes, num_nodes) encodes (src_type, dst_type) pairs - - This allows Graph Transformers to use both structural (hop distance) and - semantic (node type) information in attention bias computation. - Args: h_max (int): Maximum hop distance to consider. Distances > h_max are treated as unreachable (value 0 in sparse matrix). @@ -347,25 +326,16 @@ class AddHeteroHopDistanceEncoding(BaseTransform): is_undirected (bool, optional): If set to :obj:`True`, the graph is assumed to be undirected for distance computation. (default: :obj:`False`) - node_type_aware (bool, optional): If set to :obj:`True`, also stores - node type information: - - `node_type_ids`: (num_nodes,) tensor mapping each node to its type ID - - `node_type_pair`: sparse (num_nodes, num_nodes) tensor encoding the - (src_type, dst_type) pair as `src_type * num_node_types + dst_type` - This enables type-aware attention bias in heterogeneous Graph Transformers. - (default: :obj:`False`) """ def __init__( self, h_max: int, attr_name: Optional[str] = 'hop_distance', is_undirected: bool = False, - node_type_aware: bool = False, ) -> None: self.h_max = h_max self.attr_name = attr_name self.is_undirected = is_undirected - self.node_type_aware = node_type_aware def forward(self, data: HeteroData) -> HeteroData: assert isinstance(data, HeteroData), ( @@ -387,9 +357,6 @@ def forward(self, data: HeteroData) -> HeteroData: size=(num_nodes, num_nodes), ).coalesce() data[self.attr_name] = empty_sparse - if self.node_type_aware: - data['node_type_ids'] = torch.zeros(num_nodes, dtype=torch.long) - data['node_type_pair'] = empty_sparse return data device = edge_index.device @@ -480,36 +447,7 @@ def forward(self, data: HeteroData) -> HeteroData: # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) data[self.attr_name] = dist_sparse - if self.node_type_aware: - # Store node type information for heterogeneous-aware attention - node_type_ids = homo_data.node_type # Shape: (num_nodes,) - data['node_type_ids'] = node_type_ids - - # Compute sparse pairwise node type encoding for reachable pairs only - # node_type_pair[i, j] = node_type_ids[i] * num_node_types + node_type_ids[j] - num_node_types = len(data.node_types) - if dist_rows.size(0) > 0: - type_pair_vals = ( - node_type_ids[dist_rows] * num_node_types + - node_type_ids[dist_cols] - ).float() - else: - type_pair_vals = torch.zeros(0, dtype=torch.float, device=device) - - node_type_pair_sparse = torch.sparse_coo_tensor( - torch.stack([dist_rows, dist_cols]), - type_pair_vals, - size=(num_nodes, num_nodes), - ).coalesce() - data['node_type_pair'] = node_type_pair_sparse - - # Also store the mapping from type ID to type name for reference - data['node_type_names'] = sorted(data.node_types) - return data def __repr__(self) -> str: - return ( - f'{self.__class__.__name__}(h_max={self.h_max}, ' - f'node_type_aware={self.node_type_aware})' - ) + return f'{self.__class__.__name__}(h_max={self.h_max})' diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py index 6d7da4443..0224ceaa3 100644 --- a/tests/unit/transforms/add_positional_encodings_test.py +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -363,71 +363,12 @@ def test_forward_empty_graph(self): self.assertTrue(result.hop_distance.is_sparse) self.assertEqual(result.hop_distance.shape, (0, 0)) - def test_forward_node_type_aware(self): - """Test forward pass with node_type_aware=True for heterogeneous Graph Transformers.""" - data = create_simple_hetero_data() - transform = AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True) - - result = transform(data) - - # Check that hop distance matrix is stored (sparse) - self.assertTrue(hasattr(result, 'hop_distance')) - self.assertTrue(result.hop_distance.is_sparse) - self.assertEqual(result.hop_distance.shape, (5, 5)) - - # Check that node type information is stored - self.assertTrue(hasattr(result, 'node_type_ids')) - self.assertEqual(result.node_type_ids.shape, (5,)) - - # Check that node type pair matrix is stored (sparse) - self.assertTrue(hasattr(result, 'node_type_pair')) - self.assertTrue(result.node_type_pair.is_sparse) - self.assertEqual(result.node_type_pair.shape, (5, 5)) - - # Check that node type names are stored - self.assertTrue(hasattr(result, 'node_type_names')) - self.assertEqual(result.node_type_names, ['item', 'user']) # Sorted alphabetically - - # Verify node_type_ids values are valid (0 or 1 for 2 node types) - self.assertTrue((result.node_type_ids >= 0).all()) - self.assertTrue((result.node_type_ids < 2).all()) - - # Verify node_type_pair encodes (src_type, dst_type) correctly - # For 2 node types, pair values should be in [0, 3] (0*2+0, 0*2+1, 1*2+0, 1*2+1) - type_pair_vals = result.node_type_pair.values() - if type_pair_vals.numel() > 0: - self.assertTrue((type_pair_vals >= 0).all()) - self.assertTrue((type_pair_vals < 4).all()) - - def test_forward_node_type_aware_empty_graph(self): - """Test forward pass with node_type_aware=True on empty graph.""" - data = create_empty_hetero_data() - data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long) - transform = AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True) - - result = transform(data) - - self.assertTrue(hasattr(result, 'hop_distance')) - self.assertTrue(result.hop_distance.is_sparse) - self.assertEqual(result.hop_distance.shape, (0, 0)) - self.assertTrue(hasattr(result, 'node_type_ids')) - self.assertEqual(result.node_type_ids.shape, (0,)) - self.assertTrue(hasattr(result, 'node_type_pair')) - self.assertTrue(result.node_type_pair.is_sparse) - self.assertEqual(result.node_type_pair.shape, (0, 0)) - def test_repr(self): """Test string representation.""" transform = AddHeteroHopDistanceEncoding(h_max=5) self.assertEqual( repr(transform), - 'AddHeteroHopDistanceEncoding(h_max=5, node_type_aware=False)' - ) - - transform_type_aware = AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True) - self.assertEqual( - repr(transform_type_aware), - 'AddHeteroHopDistanceEncoding(h_max=3, node_type_aware=True)' + 'AddHeteroHopDistanceEncoding(h_max=5)' ) From cd8053441602ca7c974cf30c9e51d2c18cb5e940 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Fri, 27 Feb 2026 14:52:58 -0800 Subject: [PATCH 09/10] update hop distance --- gigl/transforms/add_positional_encodings.py | 78 ++++++++++++++------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 58ae217b0..2f330ef89 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -377,48 +377,74 @@ def forward(self, data: HeteroData) -> HeteroData: ).coalesce() # Compute sparse BFS to find all reachable pairs within h_max hops - # Store (row, col, distance) for all reachable pairs + # Use sparse tensor accumulation (like AddHeteroRandomWalkPE) instead of Python set + + # Track minimum hop distance using sparse accumulation + # We'll accumulate reachability and track first-visit hop all_rows = [] all_cols = [] all_dists = [] - # Track which pairs have been visited (using set of linear indices) - # Start with diagonal (self-loops) as visited but don't include in output - identity_indices = torch.arange(num_nodes, device=device) - visited_linear = set((identity_indices * num_nodes + identity_indices).tolist()) - - # Current frontier (sparse): edges reachable at current hop + # Start with adjacency as hop 1 frontier = adj.coalesce() - for hop in range(1, self.h_max + 1): - frontier_indices = frontier.indices() + # Track all visited pairs using sparse tensor (value > 0 means visited) + visited = torch.sparse_coo_tensor( + torch.stack([ + torch.arange(num_nodes, device=device), + torch.arange(num_nodes, device=device), + ]), + torch.ones(num_nodes, device=device), + size=(num_nodes, num_nodes), + ) # Start with diagonal as visited - if frontier_indices.size(1) == 0: + for hop in range(1, self.h_max + 1): + if frontier._nnz() == 0: break + frontier_indices = frontier.indices() frontier_i = frontier_indices[0] frontier_j = frontier_indices[1] - frontier_linear = (frontier_i * num_nodes + frontier_j).tolist() - - # Find newly reachable pairs (not in visited set) - new_mask = [] - new_pairs_linear = [] - for idx, lin_idx in enumerate(frontier_linear): - if lin_idx not in visited_linear: - new_mask.append(idx) - new_pairs_linear.append(lin_idx) - visited_linear.add(lin_idx) - - if new_mask: - new_mask = torch.tensor(new_mask, device=device, dtype=torch.long) - new_i = frontier_i[new_mask] - new_j = frontier_j[new_mask] + + # Find newly reachable: in frontier but not in visited + # Use sparse addition: visited + frontier, then check where frontier has entry but combined == 1 + combined = (visited + frontier).coalesce() + + # For entries in frontier, check if they're new (combined value == 1 means new) + # We need to find frontier entries where combined value == 1 + frontier_keys = frontier_i * num_nodes + frontier_j + + combined_indices = combined.indices() + combined_values = combined.values() + combined_keys = combined_indices[0] * num_nodes + combined_indices[1] + + # Sort combined for searchsorted + sorted_keys, sort_perm = torch.sort(combined_keys) + sorted_values = combined_values[sort_perm] + + # Find frontier entries in combined + pos = torch.searchsorted(sorted_keys, frontier_keys) + pos_clamped = pos.clamp(max=sorted_keys.size(0) - 1) + + # New if combined value == 1 (only from frontier, not previously visited) + is_new = (sorted_keys[pos_clamped] == frontier_keys) & (sorted_values[pos_clamped] == 1.0) + + if is_new.any(): + new_i = frontier_i[is_new] + new_j = frontier_j[is_new] all_rows.append(new_i) all_cols.append(new_j) all_dists.append(torch.full((new_i.size(0),), hop, device=device, dtype=torch.float)) - # Expand frontier: frontier = frontier @ adj (sparse matmul) + # Update visited (binarize combined) + visited = torch.sparse_coo_tensor( + combined.indices(), + torch.ones(combined._nnz(), device=device), + size=(num_nodes, num_nodes), + ) + + # Expand frontier: frontier = frontier @ adj frontier = torch.sparse.mm(frontier, adj).coalesce() # Build sparse distance matrix From 7fe24154f0aacea9d211575abad8a5ff2a12027f Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Thu, 5 Mar 2026 14:49:28 -0800 Subject: [PATCH 10/10] optim hop dist mem --- gigl/transforms/add_positional_encodings.py | 163 ++++++++++++-------- 1 file changed, 97 insertions(+), 66 deletions(-) diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py index 2f330ef89..6bea1eaa6 100644 --- a/gigl/transforms/add_positional_encodings.py +++ b/gigl/transforms/add_positional_encodings.py @@ -376,82 +376,113 @@ def forward(self, data: HeteroData) -> HeteroData: size=(num_nodes, num_nodes), ).coalesce() - # Compute sparse BFS to find all reachable pairs within h_max hops - # Use sparse tensor accumulation (like AddHeteroRandomWalkPE) instead of Python set - - # Track minimum hop distance using sparse accumulation - # We'll accumulate reachability and track first-visit hop - all_rows = [] - all_cols = [] - all_dists = [] - - # Start with adjacency as hop 1 - frontier = adj.coalesce() - - # Track all visited pairs using sparse tensor (value > 0 means visited) - visited = torch.sparse_coo_tensor( - torch.stack([ - torch.arange(num_nodes, device=device), - torch.arange(num_nodes, device=device), - ]), - torch.ones(num_nodes, device=device), - size=(num_nodes, num_nodes), - ) # Start with diagonal as visited - - for hop in range(1, self.h_max + 1): - if frontier._nnz() == 0: - break - - frontier_indices = frontier.indices() - frontier_i = frontier_indices[0] - frontier_j = frontier_indices[1] - - # Find newly reachable: in frontier but not in visited - # Use sparse addition: visited + frontier, then check where frontier has entry but combined == 1 - combined = (visited + frontier).coalesce() - - # For entries in frontier, check if they're new (combined value == 1 means new) - # We need to find frontier entries where combined value == 1 - frontier_keys = frontier_i * num_nodes + frontier_j + # Memory-optimized BFS for computing shortest path distances + # + # Key memory optimizations: + # 1. Use sorted linear indices with searchsorted for O(n log n) membership test + # (more memory efficient than torch.isin which may create hash tables) + # 2. Store distances as int8 (h_max typically < 127) + # 3. Avoid tensor concatenation in hot loop - use pre-sorted merge instead + # 4. Explicit del statements to trigger garbage collection + # 5. CSR format for sparse matmul (more memory efficient than COO) + # + # Memory complexity: O(nnz_frontier + nnz_visited) per iteration + # where nnz_frontier can grow up to O(n^2) for dense graphs at large hop + + dist_matrix_rows = [] + dist_matrix_cols = [] + dist_matrix_vals = [] + + # Choose tracking strategy based on graph size + # For small graphs (n < 10000), bitmap is faster with O(1) lookup + # For large graphs, sorted indices use less memory O(visited) vs O(n^2/8) + USE_BITMAP = num_nodes < 10000 + + if USE_BITMAP: + # Dense bitmap: O(n^2 / 8) bytes, O(1) lookup + # For n=10000, this is ~12.5 MB + visited_bitmap = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=device) + visited_bitmap.fill_diagonal_(True) # Mark diagonal as visited + else: + # Sorted linear indices: O(visited pairs) memory, O(log n) lookup + identity_indices = torch.arange(num_nodes, device=device, dtype=torch.long) + visited_linear = identity_indices * num_nodes + identity_indices # Diagonal + visited_linear = visited_linear.sort()[0] - combined_indices = combined.indices() - combined_values = combined.values() - combined_keys = combined_indices[0] * num_nodes + combined_indices[1] + # Adjacency matrix in CSR format (more memory efficient for matmul) + adj_csr = adj.to_sparse_csr() + del adj # Free COO adjacency - # Sort combined for searchsorted - sorted_keys, sort_perm = torch.sort(combined_keys) - sorted_values = combined_values[sort_perm] + # Current frontier (reachable pairs at current hop distance) + frontier = adj_csr.to_sparse_coo().coalesce() - # Find frontier entries in combined - pos = torch.searchsorted(sorted_keys, frontier_keys) - pos_clamped = pos.clamp(max=sorted_keys.size(0) - 1) + for hop in range(1, self.h_max + 1): + if hop > 1: + # frontier = frontier @ adj (sparse matmul) + # CSR @ CSR is most efficient + frontier_csr = frontier.to_sparse_csr() + del frontier + frontier = torch.sparse.mm(frontier_csr, adj_csr).to_sparse_coo().coalesce() + del frontier_csr - # New if combined value == 1 (only from frontier, not previously visited) - is_new = (sorted_keys[pos_clamped] == frontier_keys) & (sorted_values[pos_clamped] == 1.0) + frontier_indices = frontier.indices() + num_frontier = frontier_indices.size(1) + if num_frontier == 0: + break - if is_new.any(): - new_i = frontier_i[is_new] - new_j = frontier_j[is_new] + reach_i, reach_j = frontier_indices[0], frontier_indices[1] + + if USE_BITMAP: + # O(1) lookup using dense bitmap + is_visited = visited_bitmap[reach_i, reach_j] + is_new = ~is_visited + del is_visited + else: + # O(log n) lookup using sorted searchsorted + frontier_linear = reach_i.long() * num_nodes + reach_j.long() + insert_pos = torch.searchsorted(visited_linear, frontier_linear) + insert_pos_clamped = insert_pos.clamp(max=visited_linear.size(0) - 1) + is_visited = (visited_linear[insert_pos_clamped] == frontier_linear) + is_new = ~is_visited + del frontier_linear, insert_pos, insert_pos_clamped, is_visited + + num_new = is_new.sum().item() + if num_new > 0: + new_i = reach_i[is_new] + new_j = reach_j[is_new] + + dist_matrix_rows.append(new_i) + dist_matrix_cols.append(new_j) + # Use int8 for hop distance (saves 4x memory vs float32) + dist_matrix_vals.append( + torch.full((num_new,), hop, device=device, dtype=torch.int8) + ) - all_rows.append(new_i) - all_cols.append(new_j) - all_dists.append(torch.full((new_i.size(0),), hop, device=device, dtype=torch.float)) + # Update visited + if USE_BITMAP: + visited_bitmap[new_i, new_j] = True + else: + new_linear = new_i.long() * num_nodes + new_j.long() + visited_linear = torch.cat([visited_linear, new_linear]).sort()[0] + del new_linear - # Update visited (binarize combined) - visited = torch.sparse_coo_tensor( - combined.indices(), - torch.ones(combined._nnz(), device=device), - size=(num_nodes, num_nodes), - ) + del is_new, reach_i, reach_j - # Expand frontier: frontier = frontier @ adj - frontier = torch.sparse.mm(frontier, adj).coalesce() + # Clean up + if USE_BITMAP: + del visited_bitmap + else: + del visited_linear + del adj_csr, frontier # Build sparse distance matrix - if all_rows: - dist_rows = torch.cat(all_rows) - dist_cols = torch.cat(all_cols) - dist_vals = torch.cat(all_dists) + if dist_matrix_rows: + dist_rows = torch.cat(dist_matrix_rows) + dist_cols = torch.cat(dist_matrix_cols) + # Convert int8 to float for downstream compatibility + dist_vals = torch.cat(dist_matrix_vals).float() + # Free intermediate lists + del dist_matrix_rows, dist_matrix_cols, dist_matrix_vals else: dist_rows = torch.zeros(0, dtype=torch.long, device=device) dist_cols = torch.zeros(0, dtype=torch.long, device=device)