diff --git a/gigl/transforms/__init__.py b/gigl/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gigl/transforms/add_positional_encodings.py b/gigl/transforms/add_positional_encodings.py new file mode 100644 index 000000000..6bea1eaa6 --- /dev/null +++ b/gigl/transforms/add_positional_encodings.py @@ -0,0 +1,510 @@ +from typing import Optional + +import torch + +from torch_geometric.data import HeteroData +from torch_geometric.data.datapipes import functional_transform +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import to_torch_sparse_tensor + +from gigl.transforms.utils import add_node_attr + +r""" +Positional and Structural Encodings for Heterogeneous Graphs. + +This module provides PyG-compatible transforms for adding positional and structural +encodings to HeteroData objects. All transforms follow the PyG BaseTransform interface +and can be composed using `torch_geometric.transforms.Compose`. + +Available Transforms: + - AddHeteroRandomWalkPE: Random walk positional encoding (column sum of non-diagonal) + - AddHeteroRandomWalkSE: Random walk structural encoding (diagonal elements) + - AddHeteroHopDistanceEncoding: Shortest path distance encoding + +Example Usage: + >>> from torch_geometric.data import HeteroData + >>> from torch_geometric.transforms import Compose + >>> from gigl.transforms.add_positional_encodings import ( + ... AddHeteroRandomWalkPE, + ... AddHeteroRandomWalkSE, + ... AddHeteroHopDistanceEncoding, + ... 
) + >>> + >>> # Create a heterogeneous graph + >>> data = HeteroData() + >>> data['user'].x = torch.randn(5, 16) + >>> data['item'].x = torch.randn(3, 16) + >>> data['user', 'buys', 'item'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]]) + >>> data['item', 'bought_by', 'user'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]]) + >>> + >>> # Apply single transform + >>> transform = AddHeteroRandomWalkPE(walk_length=8) + >>> data = transform(data) + >>> print(data['user'].random_walk_pe.shape) # (5, 8) + >>> + >>> # Compose multiple transforms + >>> transform = Compose([ + ... AddHeteroRandomWalkPE(walk_length=8), + ... AddHeteroRandomWalkSE(walk_length=8), + ... AddHeteroHopDistanceEncoding(h_max=3), + ... ]) + >>> data = transform(data) + >>> + >>> # For Graph Transformers, use hop distance encoding for attention bias + >>> # Returns sparse matrix (0 for unreachable, 1-h_max for reachable pairs) + >>> transform = AddHeteroHopDistanceEncoding(h_max=5) + >>> data = transform(data) + >>> print(data.hop_distance.shape) # (num_total_nodes, num_total_nodes) sparse + >>> print(data.hop_distance.is_sparse) # True +""" + + +@functional_transform('add_hetero_random_walk_pe') +class AddHeteroRandomWalkPE(BaseTransform): + r"""Adds the random walk positional encoding to the given heterogeneous graph + (functional name: :obj:`add_hetero_random_walk_pe`). + + For each node j, computes the sum of transition probabilities from all other + nodes to j after k steps of a random walk, for k = 1, 2, ..., walk_length. + This captures how "reachable" or "central" a node is from the rest of the graph. + + The encoding is the column sum of non-diagonal elements of the k-step + random walk matrix: + PE[j, k] = Σ_{i≠j} (P^k)[i, j] + + where P is the transition matrix. This measures the probability mass flowing + into node j from all other nodes at step k. + + Args: + walk_length (int): The number of random walk steps. 
+ attr_name (str, optional): The attribute name of the positional + encoding. (default: :obj:`"random_walk_pe"`) + is_undirected (bool, optional): If set to :obj:`True`, the graph is + assumed to be undirected, and the adjacency matrix will be made + symmetric. (default: :obj:`False`) + attach_to_x (bool, optional): If set to :obj:`True`, the encoding is + concatenated directly to :obj:`data[node_type].x` for each node type + instead of being stored as a separate attribute. (default: :obj:`False`) + """ + def __init__( + self, + walk_length: int, + attr_name: Optional[str] = 'random_walk_pe', + is_undirected: bool = False, + attach_to_x: bool = False, + ) -> None: + self.walk_length = walk_length + self.attr_name = attr_name + self.is_undirected = is_undirected + self.attach_to_x = attach_to_x + + def forward(self, data: HeteroData) -> HeteroData: + assert isinstance(data, HeteroData), ( + f"'{self.__class__.__name__}' only supports 'HeteroData' " + f"(got '{type(data)}')" + ) + + # Convert to homogeneous + homo_data = data.to_homogeneous() + edge_index = homo_data.edge_index + num_nodes = homo_data.num_nodes + + if num_nodes == 0: + for node_type in data.node_types: + empty_pe = torch.zeros( + (data[node_type].num_nodes, self.walk_length), + dtype=torch.float, + ) + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, {node_type: empty_pe}, effective_attr_name) + return data + + # Compute transition matrix (row-stochastic) using sparse operations + adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + + if self.is_undirected: + # Make symmetric for undirected graphs + adj = (adj + adj.t()).coalesce() + + # Compute degree for row normalization + adj_coalesced = adj.coalesce() + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, adj_coalesced.indices()[0], adj_coalesced.values().float()) + deg = torch.clamp(deg, min=1) # Avoid division by zero + + # Create row-normalized transition 
matrix (sparse) + # P[i,j] = A[i,j] / deg[i] + row_indices = adj_coalesced.indices()[0] + normalized_values = adj_coalesced.values().float() / deg[row_indices] + transition = torch.sparse_coo_tensor( + adj_coalesced.indices(), + normalized_values, + size=(num_nodes, num_nodes), + ).coalesce() + + # Compute random walk positional encoding using sparse operations + # PE[j, k] = sum of column j excluding diagonal = Σ_{i≠j} (P^k)[i, j] + pe = torch.zeros((num_nodes, self.walk_length), dtype=torch.float, device=edge_index.device) + + # Start with identity matrix (sparse) + identity_indices = torch.arange(num_nodes, device=edge_index.device) + current = torch.sparse_coo_tensor( + torch.stack([identity_indices, identity_indices]), + torch.ones(num_nodes, device=edge_index.device), + size=(num_nodes, num_nodes), + ).coalesce() + + for k in range(self.walk_length): + current = torch.sparse.mm(current, transition).coalesce() + # Column sum = sum over rows for each column + col_sum = torch.zeros(num_nodes, device=edge_index.device) + col_sum.scatter_add_(0, current.indices()[1], current.values()) + # Extract diagonal elements + diag = torch.zeros(num_nodes, device=edge_index.device) + diag_mask = current.indices()[0] == current.indices()[1] + if diag_mask.any(): + diag.scatter_add_(0, current.indices()[0][diag_mask], current.values()[diag_mask]) + pe[:, k] = col_sum - diag + + # Map back to HeteroData node types + # If attach_to_x is True, pass None as attr_name to concatenate to x directly + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, pe, effective_attr_name) + + return data + + def __repr__(self) -> str: + return ( + f'{self.__class__.__name__}(walk_length={self.walk_length}, ' + f'attach_to_x={self.attach_to_x})' + ) + + +@functional_transform('add_hetero_random_walk_se') +class AddHeteroRandomWalkSE(BaseTransform): + r"""Adds the random walk structural encoding from the + `"Graph Neural Networks with Learnable Structural and 
Positional + Representations" `_ paper to the given + heterogeneous graph (functional name: :obj:`add_hetero_random_walk_se`). + + For each node, computes the probability of returning to itself after k steps + of a random walk, for k = 1, 2, ..., walk_length. This captures the local + structural role of each node (e.g., cycles, clustering coefficient). + + The encoding is the diagonal of the k-step random walk matrix: + SE[i, k] = (P^k)[i, i] + + where P is the transition matrix. + + Args: + walk_length (int): The number of random walk steps. + attr_name (str, optional): The attribute name of the structural + encoding. (default: :obj:`"random_walk_se"`) + is_undirected (bool, optional): If set to :obj:`True`, the graph is + assumed to be undirected, and the adjacency matrix will be made + symmetric. (default: :obj:`False`) + attach_to_x (bool, optional): If set to :obj:`True`, the encoding is + concatenated directly to :obj:`data[node_type].x` for each node type + instead of being stored as a separate attribute. 
(default: :obj:`False`) + """ + def __init__( + self, + walk_length: int, + attr_name: Optional[str] = 'random_walk_se', + is_undirected: bool = False, + attach_to_x: bool = False, + ) -> None: + self.walk_length = walk_length + self.attr_name = attr_name + self.is_undirected = is_undirected + self.attach_to_x = attach_to_x + + def forward(self, data: HeteroData) -> HeteroData: + assert isinstance(data, HeteroData), ( + f"'{self.__class__.__name__}' only supports 'HeteroData' " + f"(got '{type(data)}')" + ) + + # Convert to homogeneous + homo_data = data.to_homogeneous() + edge_index = homo_data.edge_index + num_nodes = homo_data.num_nodes + + if num_nodes == 0: + for node_type in data.node_types: + empty_se = torch.zeros( + (data[node_type].num_nodes, self.walk_length), + dtype=torch.float, + ) + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, {node_type: empty_se}, effective_attr_name) + return data + + # Compute transition matrix (row-stochastic) using sparse operations + adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + + if self.is_undirected: + # Make symmetric for undirected graphs + adj = (adj + adj.t()).coalesce() + + # Compute degree for row normalization + adj_coalesced = adj.coalesce() + deg = torch.zeros(num_nodes, device=edge_index.device) + deg.scatter_add_(0, adj_coalesced.indices()[0], adj_coalesced.values().float()) + deg = torch.clamp(deg, min=1) # Avoid division by zero + + # Create row-normalized transition matrix (sparse) + # P[i,j] = A[i,j] / deg[i] + row_indices = adj_coalesced.indices()[0] + normalized_values = adj_coalesced.values().float() / deg[row_indices] + transition = torch.sparse_coo_tensor( + adj_coalesced.indices(), + normalized_values, + size=(num_nodes, num_nodes), + ).coalesce() + + # Compute random walk return probabilities (diagonal elements) using sparse operations + se = torch.zeros((num_nodes, self.walk_length), dtype=torch.float, device=edge_index.device) + + 
# Start with identity matrix (sparse) + identity_indices = torch.arange(num_nodes, device=edge_index.device) + current = torch.sparse_coo_tensor( + torch.stack([identity_indices, identity_indices]), + torch.ones(num_nodes, device=edge_index.device), + size=(num_nodes, num_nodes), + ).coalesce() + + for k in range(self.walk_length): + current = torch.sparse.mm(current, transition).coalesce() + # Extract diagonal elements: probability of returning to the same node + diag_mask = current.indices()[0] == current.indices()[1] + if diag_mask.any(): + se[:, k].scatter_add_(0, current.indices()[0][diag_mask], current.values()[diag_mask]) + + # Map back to HeteroData node types + # If attach_to_x is True, pass None as attr_name to concatenate to x directly + effective_attr_name = None if self.attach_to_x else self.attr_name + add_node_attr(data, se, effective_attr_name) + + return data + + def __repr__(self) -> str: + return ( + f'{self.__class__.__name__}(walk_length={self.walk_length}, ' + f'attach_to_x={self.attach_to_x})' + ) + + + +@functional_transform('add_hetero_hop_distance_encoding') +class AddHeteroHopDistanceEncoding(BaseTransform): + r"""Adds hop distance positional encoding as relative encoding (sparse). + + For each pair of nodes (vi, vj), computes the shortest path distance p(vi, vj). + This captures structural proximity and can be used with a learnable embedding + matrix: + + h_hop(vi, vj) = W_hop · onehot(p(vi, vj)) + + Based on the approach from `"Do Transformers Really Perform Bad for Graph + Representation?" `_ (Graphormer). + + The output is a **sparse matrix** where: + - Reachable pairs (i, j) within h_max hops have value = hop distance (1 to h_max) + - Unreachable pairs have value = 0 (not stored in sparse tensor) + - Self-loops (diagonal) are not stored (distance to self is implicitly 0) + + This sparse representation avoids GPU memory blowup for large graphs. + + Args: + h_max (int): Maximum hop distance to consider. 
Distances > h_max + are treated as unreachable (value 0 in sparse matrix). + Set to 2-3 for 2-hop sampled subgraphs. + Set to min(walk_length // 2, 10) for random walk sampled subgraphs. + attr_name (str, optional): The attribute name of the positional + encoding. (default: :obj:`"hop_distance"`) + is_undirected (bool, optional): If set to :obj:`True`, the graph is + assumed to be undirected for distance computation. + (default: :obj:`False`) + """ + def __init__( + self, + h_max: int, + attr_name: Optional[str] = 'hop_distance', + is_undirected: bool = False, + ) -> None: + self.h_max = h_max + self.attr_name = attr_name + self.is_undirected = is_undirected + + def forward(self, data: HeteroData) -> HeteroData: + assert isinstance(data, HeteroData), ( + f"'{self.__class__.__name__}' only supports 'HeteroData' " + f"(got '{type(data)}')" + ) + + # Convert to homogeneous to compute shortest paths + homo_data = data.to_homogeneous() + edge_index = homo_data.edge_index + num_nodes = homo_data.num_nodes + num_edges = edge_index.size(1) + + if num_nodes == 0 or num_edges == 0: + # Handle empty graph case - return empty sparse tensor + empty_sparse = torch.sparse_coo_tensor( + torch.zeros((2, 0), dtype=torch.long), + torch.zeros(0, dtype=torch.float), + size=(num_nodes, num_nodes), + ).coalesce() + data[self.attr_name] = empty_sparse + return data + + device = edge_index.device + + # Build sparse adjacency matrix for shortest path computation + adj = to_torch_sparse_tensor(edge_index, size=(num_nodes, num_nodes)) + + if self.is_undirected: + # Make symmetric for undirected graphs + adj = (adj + adj.t()).coalesce() + + # Binarize adjacency (sparse) + adj_coalesced = adj.coalesce() + adj = torch.sparse_coo_tensor( + adj_coalesced.indices(), + torch.ones(adj_coalesced.indices().size(1), device=device), + size=(num_nodes, num_nodes), + ).coalesce() + + # Memory-optimized BFS for computing shortest path distances + # + # Key memory optimizations: + # 1. 
Use sorted linear indices with searchsorted for O(n log n) membership test + # (more memory efficient than torch.isin which may create hash tables) + # 2. Store distances as int8 (h_max typically < 127) + # 3. Avoid tensor concatenation in hot loop - use pre-sorted merge instead + # 4. Explicit del statements to trigger garbage collection + # 5. CSR format for sparse matmul (more memory efficient than COO) + # + # Memory complexity: O(nnz_frontier + nnz_visited) per iteration + # where nnz_frontier can grow up to O(n^2) for dense graphs at large hop + + dist_matrix_rows = [] + dist_matrix_cols = [] + dist_matrix_vals = [] + + # Choose tracking strategy based on graph size + # For small graphs (n < 10000), bitmap is faster with O(1) lookup + # For large graphs, sorted indices use less memory O(visited) vs O(n^2/8) + USE_BITMAP = num_nodes < 10000 + + if USE_BITMAP: + # Dense bitmap: O(n^2 / 8) bytes, O(1) lookup + # For n=10000, this is ~12.5 MB + visited_bitmap = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=device) + visited_bitmap.fill_diagonal_(True) # Mark diagonal as visited + else: + # Sorted linear indices: O(visited pairs) memory, O(log n) lookup + identity_indices = torch.arange(num_nodes, device=device, dtype=torch.long) + visited_linear = identity_indices * num_nodes + identity_indices # Diagonal + visited_linear = visited_linear.sort()[0] + + # Adjacency matrix in CSR format (more memory efficient for matmul) + adj_csr = adj.to_sparse_csr() + del adj # Free COO adjacency + + # Current frontier (reachable pairs at current hop distance) + frontier = adj_csr.to_sparse_coo().coalesce() + + for hop in range(1, self.h_max + 1): + if hop > 1: + # frontier = frontier @ adj (sparse matmul) + # CSR @ CSR is most efficient + frontier_csr = frontier.to_sparse_csr() + del frontier + frontier = torch.sparse.mm(frontier_csr, adj_csr).to_sparse_coo().coalesce() + del frontier_csr + + frontier_indices = frontier.indices() + num_frontier = 
frontier_indices.size(1) + if num_frontier == 0: + break + + reach_i, reach_j = frontier_indices[0], frontier_indices[1] + + if USE_BITMAP: + # O(1) lookup using dense bitmap + is_visited = visited_bitmap[reach_i, reach_j] + is_new = ~is_visited + del is_visited + else: + # O(log n) lookup using sorted searchsorted + frontier_linear = reach_i.long() * num_nodes + reach_j.long() + insert_pos = torch.searchsorted(visited_linear, frontier_linear) + insert_pos_clamped = insert_pos.clamp(max=visited_linear.size(0) - 1) + is_visited = (visited_linear[insert_pos_clamped] == frontier_linear) + is_new = ~is_visited + del frontier_linear, insert_pos, insert_pos_clamped, is_visited + + num_new = is_new.sum().item() + if num_new > 0: + new_i = reach_i[is_new] + new_j = reach_j[is_new] + + dist_matrix_rows.append(new_i) + dist_matrix_cols.append(new_j) + # Use int8 for hop distance (saves 4x memory vs float32) + dist_matrix_vals.append( + torch.full((num_new,), hop, device=device, dtype=torch.int8) + ) + + # Update visited + if USE_BITMAP: + visited_bitmap[new_i, new_j] = True + else: + new_linear = new_i.long() * num_nodes + new_j.long() + visited_linear = torch.cat([visited_linear, new_linear]).sort()[0] + del new_linear + + del is_new, reach_i, reach_j + + # Clean up + if USE_BITMAP: + del visited_bitmap + else: + del visited_linear + del adj_csr, frontier + + # Build sparse distance matrix + if dist_matrix_rows: + dist_rows = torch.cat(dist_matrix_rows) + dist_cols = torch.cat(dist_matrix_cols) + # Convert int8 to float for downstream compatibility + dist_vals = torch.cat(dist_matrix_vals).float() + # Free intermediate lists + del dist_matrix_rows, dist_matrix_cols, dist_matrix_vals + else: + dist_rows = torch.zeros(0, dtype=torch.long, device=device) + dist_cols = torch.zeros(0, dtype=torch.long, device=device) + dist_vals = torch.zeros(0, dtype=torch.float, device=device) + + # Create sparse distance matrix + # Unreachable pairs have value 0 (not stored) + # Reachable 
pairs have value = hop distance (1 to h_max) + dist_sparse = torch.sparse_coo_tensor( + torch.stack([dist_rows, dist_cols]), + dist_vals, + size=(num_nodes, num_nodes), + ).coalesce() + + # Store sparse pairwise distance matrix as graph-level attribute + # Access via: data.hop_distance or data['hop_distance'] + # Usage in attention: dist = data.hop_distance.to_dense() for small graphs, + # or use sparse indexing for memory efficiency + # Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically) + data[self.attr_name] = dist_sparse + + return data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(h_max={self.h_max})' diff --git a/gigl/transforms/utils.py b/gigl/transforms/utils.py new file mode 100644 index 000000000..ada69687c --- /dev/null +++ b/gigl/transforms/utils.py @@ -0,0 +1,167 @@ +from typing import Dict, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torch_geometric.data import HeteroData + +# Type alias for edge types in PyG HeteroData +EdgeType = Tuple[str, str, str] + + +def add_node_attr( + data: HeteroData, + values: Union[Tensor, Dict[str, Tensor]], + attr_name: Optional[str] = None, + node_type_to_idx: Optional[Dict[str, Tuple[int, int]]] = None, +) -> HeteroData: + """Helper function to add node attributes to a HeteroData object. + + Args: + data: The HeteroData object to modify. + values: Either: + - A tensor of values in homogeneous node order (requires node_type_to_idx + or will be computed from data.node_types), OR + - A dictionary mapping node types to tensors of values for each type. + attr_name: The name of the attribute to add. If None, concatenates to + existing `x` attribute for each node type (or creates it). + node_type_to_idx: Optional mapping from node type to (start, end) indices. + Only used when values is a tensor. If None, it will be computed from + data.node_types. + + Returns: + The modified HeteroData object. 
+ """ + # If values is a dictionary, directly assign to each node type + if isinstance(values, dict): + for node_type, value in values.items(): + if node_type not in data.node_types: + continue + _set_node_attr_for_type(data, node_type, value, attr_name) + return data + + # Otherwise, values is a tensor in homogeneous order - split by node type + if node_type_to_idx is None: + # Build mapping from node type to (start, end) indices in homogeneous tensor + # When HeteroData is converted to homogeneous, nodes are ordered by node type. + # This mapping lets us slice the homogeneous tensor to get values for each type. + # Example: if data has 3 'user' nodes and 2 'item' nodes: + # node_type_to_idx = {'user': (0, 3), 'item': (3, 5)} + node_type_to_idx = {} + start_idx = 0 + for node_type in data.node_types: + num_type_nodes = data[node_type].num_nodes + node_type_to_idx[node_type] = (start_idx, start_idx + num_type_nodes) + start_idx += num_type_nodes + + for node_type in data.node_types: + start, end = node_type_to_idx[node_type] + value = values[start:end] + _set_node_attr_for_type(data, node_type, value, attr_name) + + return data + + +def _set_node_attr_for_type( + data: HeteroData, + node_type: str, + value: Tensor, + attr_name: Optional[str], +) -> None: + """Helper to set node attribute for a single node type.""" + if attr_name is None: + # Concatenate to existing x or create new x + # Use getattr to safely get x attribute, returns None if not present + x = getattr(data[node_type], "x", None) + if x is not None: + # Existing features found: concatenate new values to them + # Reshape 1D tensor [num_nodes] to 2D [num_nodes, 1] for concatenation + x = x.view(-1, 1) if x.dim() == 1 else x + # Move value to same device/dtype as x, then concatenate along feature dim + data[node_type].x = torch.cat( + [x, value.to(x.device, x.dtype)], dim=-1 + ) + else: + # No existing features: use new values as x directly + data[node_type].x = value + else: + data[node_type][attr_name] 
= value + + +def add_edge_attr( + data: HeteroData, + values: Union[Tensor, Dict[EdgeType, Tensor]], + attr_name: Optional[str] = None, + edge_type_to_idx: Optional[Dict[EdgeType, Tuple[int, int]]] = None, +) -> HeteroData: + """Helper function to add edge attributes to a HeteroData object. + + Args: + data: The HeteroData object to modify. + values: Either: + - A tensor of values in homogeneous edge order (requires edge_type_to_idx + or will be computed from data.edge_types), OR + - A dictionary mapping edge types to tensors of values for each type. + attr_name: The name of the attribute to add. If None, concatenates to + existing `edge_attr` attribute for each edge type (or creates it). + edge_type_to_idx: Optional mapping from edge type to (start, end) indices. + Only used when values is a tensor. If None, it will be computed from + data.edge_types. + + Returns: + The modified HeteroData object. + """ + # If values is a dictionary, directly assign to each edge type + if isinstance(values, dict): + for edge_type, value in values.items(): + if edge_type not in data.edge_types: + continue + _set_edge_attr_for_type(data, edge_type, value, attr_name) + return data + + # Otherwise, values is a tensor in homogeneous order - split by edge type + if edge_type_to_idx is None: + # Build mapping from edge type to (start, end) indices in homogeneous tensor + # When HeteroData is converted to homogeneous, edges are ordered by edge type. + # This mapping lets us slice the homogeneous tensor to get values for each type. 
+ # Example: if data has 3 'buys' edges and 2 'views' edges: + # edge_type_to_idx = {('user', 'buys', 'item'): (0, 3), ('user', 'views', 'item'): (3, 5)} + edge_type_to_idx = {} + start_idx = 0 + for edge_type in data.edge_types: + num_type_edges = data[edge_type].num_edges + edge_type_to_idx[edge_type] = (start_idx, start_idx + num_type_edges) + start_idx += num_type_edges + + for edge_type in data.edge_types: + start, end = edge_type_to_idx[edge_type] + value = values[start:end] + _set_edge_attr_for_type(data, edge_type, value, attr_name) + + return data + + +def _set_edge_attr_for_type( + data: HeteroData, + edge_type: EdgeType, + value: Tensor, + attr_name: Optional[str], +) -> None: + """Helper to set edge attribute for a single edge type.""" + if attr_name is None: + # Concatenate to existing edge_attr or create new edge_attr + # Use getattr to safely get edge_attr attribute, returns None if not present + edge_attr = getattr(data[edge_type], "edge_attr", None) + if edge_attr is not None: + # Existing features found: concatenate new values to them + # Reshape 1D tensor [num_edges] to 2D [num_edges, 1] for concatenation + edge_attr = edge_attr.view(-1, 1) if edge_attr.dim() == 1 else edge_attr + # Move value to same device/dtype as edge_attr, then concatenate along feature dim + data[edge_type].edge_attr = torch.cat( + [edge_attr, value.to(edge_attr.device, edge_attr.dtype)], dim=-1 + ) + else: + # No existing features: use new values as edge_attr directly + data[edge_type].edge_attr = value + else: + data[edge_type][attr_name] = value diff --git a/tests/unit/transforms/__init__.py b/tests/unit/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/transforms/add_positional_encodings_test.py b/tests/unit/transforms/add_positional_encodings_test.py new file mode 100644 index 000000000..0224ceaa3 --- /dev/null +++ b/tests/unit/transforms/add_positional_encodings_test.py @@ -0,0 +1,376 @@ +import torch +from absl.testing 
import absltest +from torch_geometric.data import HeteroData + +from gigl.transforms.add_positional_encodings import ( + AddHeteroHopDistanceEncoding, + AddHeteroRandomWalkPE, + AddHeteroRandomWalkSE, +) +from tests.test_assets.test_case import TestCase + + +def create_simple_hetero_data() -> HeteroData: + """Create a simple heterogeneous graph for testing. + + Graph structure: + - 3 'user' nodes + - 2 'item' nodes + - Edges: user -> item (bipartite) + """ + data = HeteroData() + + # Node features + data['user'].x = torch.randn(3, 4) + data['item'].x = torch.randn(2, 4) + + # Edges: user -> item + data['user', 'buys', 'item'].edge_index = torch.tensor([ + [0, 1, 2], # source (user) + [0, 0, 1], # target (item) + ]) + + # Edges: item -> user (reverse) + data['item', 'bought_by', 'user'].edge_index = torch.tensor([ + [0, 0, 1], # source (item) + [0, 1, 2], # target (user) + ]) + + return data + + +def create_empty_hetero_data() -> HeteroData: + """Create an empty heterogeneous graph for testing edge cases.""" + data = HeteroData() + data['user'].x = torch.zeros(0, 4) + data['item'].x = torch.zeros(0, 4) + return data + + + +class TestAddHeteroRandomWalkPE(TestCase): + """Tests for AddHeteroRandomWalkPE (Positional Encoding - column sum of non-diagonal).""" + + def test_forward_basic(self): + """Test basic forward pass.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=4) + + result = transform(data) + + # Check that PE was added to both node types + self.assertTrue(hasattr(result['user'], 'random_walk_pe')) + self.assertTrue(hasattr(result['item'], 'random_walk_pe')) + + # Check shapes + self.assertEqual(result['user'].random_walk_pe.shape, (3, 4)) + self.assertEqual(result['item'].random_walk_pe.shape, (2, 4)) + + def test_forward_with_custom_attr_name(self): + """Test forward pass with custom attribute name.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=3, attr_name='rw_pe') + + result 
= transform(data) + + self.assertTrue(hasattr(result['user'], 'rw_pe')) + self.assertTrue(hasattr(result['item'], 'rw_pe')) + self.assertFalse(hasattr(result['user'], 'random_walk_pe')) + + def test_forward_undirected(self): + """Test forward pass with undirected graph setting.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=3, is_undirected=True) + + result = transform(data) + + self.assertEqual(result['user'].random_walk_pe.shape, (3, 3)) + self.assertEqual(result['item'].random_walk_pe.shape, (2, 3)) + + def test_forward_empty_graph(self): + """Test forward pass with empty graph.""" + data = create_empty_hetero_data() + transform = AddHeteroRandomWalkPE(walk_length=3) + + result = transform(data) + + self.assertEqual(result['user'].random_walk_pe.shape, (0, 3)) + self.assertEqual(result['item'].random_walk_pe.shape, (0, 3)) + + def test_repr(self): + """Test string representation.""" + transform = AddHeteroRandomWalkPE(walk_length=10) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10, attach_to_x=False)') + + def test_forward_attach_to_x(self): + """Test forward pass with attach_to_x=True concatenates PE to node features.""" + data = create_simple_hetero_data() + original_user_dim = data['user'].x.shape[1] # 4 + original_item_dim = data['item'].x.shape[1] # 4 + walk_length = 3 + transform = AddHeteroRandomWalkPE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check that PE was NOT added as separate attribute + self.assertFalse(hasattr(result['user'], 'random_walk_pe')) + self.assertFalse(hasattr(result['item'], 'random_walk_pe')) + + # Check that x was expanded with PE dimensions + self.assertEqual(result['user'].x.shape, (3, original_user_dim + walk_length)) + self.assertEqual(result['item'].x.shape, (2, original_item_dim + walk_length)) + + def test_forward_attach_to_x_no_existing_features(self): + """Test forward pass with attach_to_x=True when nodes have no existing 
features.""" + data = HeteroData() + data['user'].num_nodes = 3 + data['item'].num_nodes = 2 + data['user', 'buys', 'item'].edge_index = torch.tensor([ + [0, 1, 2], + [0, 0, 1], + ]) + data['item', 'bought_by', 'user'].edge_index = torch.tensor([ + [0, 0, 1], + [0, 1, 2], + ]) + + walk_length = 4 + transform = AddHeteroRandomWalkPE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check that x was created with PE as features + self.assertTrue(hasattr(result['user'], 'x')) + self.assertTrue(hasattr(result['item'], 'x')) + self.assertEqual(result['user'].x.shape, (3, walk_length)) + self.assertEqual(result['item'].x.shape, (2, walk_length)) + + def test_forward_attach_to_x_empty_graph(self): + """Test forward pass with attach_to_x=True on empty graph.""" + data = create_empty_hetero_data() + original_user_dim = data['user'].x.shape[1] # 4 + original_item_dim = data['item'].x.shape[1] # 4 + walk_length = 3 + transform = AddHeteroRandomWalkPE(walk_length=walk_length, attach_to_x=True) + + result = transform(data) + + # Check shapes on empty graph + self.assertEqual(result['user'].x.shape, (0, original_user_dim + walk_length)) + self.assertEqual(result['item'].x.shape, (0, original_item_dim + walk_length)) + + def test_repr_attach_to_x(self): + """Test string representation with attach_to_x=True.""" + transform = AddHeteroRandomWalkPE(walk_length=10, attach_to_x=True) + self.assertEqual(repr(transform), 'AddHeteroRandomWalkPE(walk_length=10, attach_to_x=True)') + + +class TestAddHeteroRandomWalkSE(TestCase): + """Tests for AddHeteroRandomWalkSE (Structural Encoding - diagonal elements).""" + + def test_forward_basic(self): + """Test basic forward pass.""" + data = create_simple_hetero_data() + transform = AddHeteroRandomWalkSE(walk_length=4) + + result = transform(data) + + # Check that SE was added to both node types + self.assertTrue(hasattr(result['user'], 'random_walk_se')) + self.assertTrue(hasattr(result['item'], 'random_walk_se')) + + 
+        # Check shapes
+        self.assertEqual(result['user'].random_walk_se.shape, (3, 4))
+        self.assertEqual(result['item'].random_walk_se.shape, (2, 4))
+
+        # Values should be probabilities (between 0 and 1)
+        self.assertTrue((result['user'].random_walk_se >= 0).all())
+        self.assertTrue((result['user'].random_walk_se <= 1).all())
+
+    def test_forward_with_custom_attr_name(self):
+        """Test forward pass with custom attribute name."""
+        data = create_simple_hetero_data()
+        transform = AddHeteroRandomWalkSE(walk_length=3, attr_name='rw_se')
+
+        result = transform(data)
+
+        self.assertTrue(hasattr(result['user'], 'rw_se'))
+        self.assertTrue(hasattr(result['item'], 'rw_se'))
+        self.assertFalse(hasattr(result['user'], 'random_walk_se'))
+
+    def test_forward_undirected(self):
+        """Test forward pass with undirected graph setting."""
+        data = create_simple_hetero_data()
+        transform = AddHeteroRandomWalkSE(walk_length=3, is_undirected=True)
+
+        result = transform(data)
+
+        self.assertEqual(result['user'].random_walk_se.shape, (3, 3))
+        self.assertEqual(result['item'].random_walk_se.shape, (2, 3))
+
+    def test_forward_empty_graph(self):
+        """Test forward pass with empty graph."""
+        data = create_empty_hetero_data()
+        transform = AddHeteroRandomWalkSE(walk_length=3)
+
+        result = transform(data)
+
+        self.assertEqual(result['user'].random_walk_se.shape, (0, 3))
+        self.assertEqual(result['item'].random_walk_se.shape, (0, 3))
+
+    def test_repr(self):
+        """Test string representation."""
+        transform = AddHeteroRandomWalkSE(walk_length=10)
+        self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10, attach_to_x=False)')
+
+    def test_forward_attach_to_x(self):
+        """Test forward pass with attach_to_x=True concatenates SE to node features."""
+        data = create_simple_hetero_data()
+        original_user_dim = data['user'].x.shape[1]  # 4
+        original_item_dim = data['item'].x.shape[1]  # 4
+        walk_length = 3
+        transform = AddHeteroRandomWalkSE(walk_length=walk_length, attach_to_x=True)
+
+        result = transform(data)
+
+        # Check that SE was NOT added as separate attribute
+        self.assertFalse(hasattr(result['user'], 'random_walk_se'))
+        self.assertFalse(hasattr(result['item'], 'random_walk_se'))
+
+        # Check that x was expanded with SE dimensions
+        self.assertEqual(result['user'].x.shape, (3, original_user_dim + walk_length))
+        self.assertEqual(result['item'].x.shape, (2, original_item_dim + walk_length))
+
+        # The appended values should be valid probabilities (between 0 and 1)
+        # Extract the SE portion (last walk_length columns)
+        user_se = result['user'].x[:, -walk_length:]
+        item_se = result['item'].x[:, -walk_length:]
+        self.assertTrue((user_se >= 0).all())
+        self.assertTrue((user_se <= 1).all())
+        self.assertTrue((item_se >= 0).all())
+        self.assertTrue((item_se <= 1).all())
+
+    def test_forward_attach_to_x_no_existing_features(self):
+        """Test forward pass with attach_to_x=True when nodes have no existing features."""
+        data = HeteroData()
+        data['user'].num_nodes = 3
+        data['item'].num_nodes = 2
+        data['user', 'buys', 'item'].edge_index = torch.tensor([
+            [0, 1, 2],
+            [0, 0, 1],
+        ])
+        data['item', 'bought_by', 'user'].edge_index = torch.tensor([
+            [0, 0, 1],
+            [0, 1, 2],
+        ])
+
+        walk_length = 4
+        transform = AddHeteroRandomWalkSE(walk_length=walk_length, attach_to_x=True)
+
+        result = transform(data)
+
+        # Check that x was created with SE as features
+        self.assertTrue(hasattr(result['user'], 'x'))
+        self.assertTrue(hasattr(result['item'], 'x'))
+        self.assertEqual(result['user'].x.shape, (3, walk_length))
+        self.assertEqual(result['item'].x.shape, (2, walk_length))
+
+    def test_forward_attach_to_x_empty_graph(self):
+        """Test forward pass with attach_to_x=True on empty graph."""
+        data = create_empty_hetero_data()
+        original_user_dim = data['user'].x.shape[1]  # 4
+        original_item_dim = data['item'].x.shape[1]  # 4
+        walk_length = 3
+        transform = AddHeteroRandomWalkSE(walk_length=walk_length, attach_to_x=True)
+
+        result = transform(data)
+
+        # Check shapes on empty graph
+        self.assertEqual(result['user'].x.shape, (0, original_user_dim + walk_length))
+        self.assertEqual(result['item'].x.shape, (0, original_item_dim + walk_length))
+
+    def test_repr_attach_to_x(self):
+        """Test string representation with attach_to_x=True."""
+        transform = AddHeteroRandomWalkSE(walk_length=10, attach_to_x=True)
+        self.assertEqual(repr(transform), 'AddHeteroRandomWalkSE(walk_length=10, attach_to_x=True)')
+
+
+class TestAddHeteroHopDistanceEncoding(TestCase):
+    def test_forward_basic(self):
+        """Test basic forward pass returns sparse matrix."""
+        data = create_simple_hetero_data()
+        transform = AddHeteroHopDistanceEncoding(h_max=3)
+
+        result = transform(data)
+
+        # Check that sparse pairwise distance matrix is stored as graph-level attribute
+        self.assertTrue(hasattr(result, 'hop_distance'))
+        # Total nodes: 3 users + 2 items = 5 nodes
+        self.assertEqual(result.hop_distance.shape, (5, 5))
+        # Should be sparse
+        self.assertTrue(result.hop_distance.is_sparse)
+
+    def test_forward_sparse_values(self):
+        """Test that sparse matrix has correct values (0 for unreachable, 1-h_max for reachable)."""
+        data = create_simple_hetero_data()
+        transform = AddHeteroHopDistanceEncoding(h_max=3)
+
+        result = transform(data)
+
+        # Convert to dense for easier testing
+        dense = result.hop_distance.to_dense()
+
+        # Diagonal should be 0 (distance to self, not stored in sparse = 0)
+        self.assertTrue((dense.diag() == 0).all())
+
+        # Non-zero values (reachable pairs) should be in [1, h_max]
+        nonzero_vals = result.hop_distance.values()
+        if nonzero_vals.numel() > 0:
+            self.assertTrue((nonzero_vals >= 1).all())
+            self.assertTrue((nonzero_vals <= 3).all())
+
+    def test_forward_with_custom_attr_name(self):
+        """Test forward pass with custom attribute name."""
+        data = create_simple_hetero_data()
+        transform = AddHeteroHopDistanceEncoding(h_max=2, attr_name='custom_hop')
+
+        result = transform(data)
+
+        self.assertTrue(hasattr(result, 'custom_hop'))
+        self.assertFalse(hasattr(result, 'hop_distance'))
+        self.assertTrue(result.custom_hop.is_sparse)
+
+    def test_forward_undirected(self):
+        """Test forward pass with undirected graph setting."""
+        data = create_simple_hetero_data()
+        transform = AddHeteroHopDistanceEncoding(h_max=2, is_undirected=True)
+
+        result = transform(data)
+
+        self.assertTrue(result.hop_distance.is_sparse)
+        self.assertEqual(result.hop_distance.shape, (5, 5))
+
+    def test_forward_empty_graph(self):
+        """Test forward pass with empty graph."""
+        data = create_empty_hetero_data()
+        # Add empty edge types
+        data['user', 'buys', 'item'].edge_index = torch.zeros((2, 0), dtype=torch.long)
+        transform = AddHeteroHopDistanceEncoding(h_max=3)
+
+        result = transform(data)
+
+        self.assertTrue(hasattr(result, 'hop_distance'))
+        self.assertTrue(result.hop_distance.is_sparse)
+        self.assertEqual(result.hop_distance.shape, (0, 0))
+
+    def test_repr(self):
+        """Test string representation."""
+        transform = AddHeteroHopDistanceEncoding(h_max=5)
+        self.assertEqual(
+            repr(transform),
+            'AddHeteroHopDistanceEncoding(h_max=5)'
+        )
+
+
+if __name__ == '__main__':
+    absltest.main()
diff --git a/tests/unit/transforms/utils_test.py b/tests/unit/transforms/utils_test.py
new file mode 100644
index 000000000..5ad1219bb
--- /dev/null
+++ b/tests/unit/transforms/utils_test.py
@@ -0,0 +1,208 @@
+import torch
+from absl.testing import absltest
+from torch_geometric.data import HeteroData
+
+from gigl.transforms.utils import add_node_attr, add_edge_attr
+from tests.test_assets.test_case import TestCase
+
+
+def create_simple_hetero_data() -> HeteroData:
+    """Create a simple heterogeneous graph for testing.
+
+    Graph structure:
+    - 3 'user' nodes
+    - 2 'item' nodes
+    - Edges: user -> item (bipartite)
+    """
+    data = HeteroData()
+
+    # Node features
+    data['user'].x = torch.randn(3, 4)
+    data['item'].x = torch.randn(2, 4)
+
+    # Edges: user -> item
+    data['user', 'buys', 'item'].edge_index = torch.tensor([
+        [0, 1, 2],  # source (user)
+        [0, 0, 1],  # target (item)
+    ])
+
+    # Edges: item -> user (reverse)
+    data['item', 'bought_by', 'user'].edge_index = torch.tensor([
+        [0, 0, 1],  # source (item)
+        [0, 1, 2],  # target (user)
+    ])
+
+    return data
+
+
+class TestAddNodeAttr(TestCase):
+    def test_add_node_attr_with_attr_name(self):
+        """Test adding a node attribute with a specific attribute name."""
+        data = create_simple_hetero_data()
+
+        # Create values in homogeneous order (3 users + 2 items = 5 nodes)
+        values = torch.randn(5, 8)
+
+        add_node_attr(data, values, attr_name='test_attr')
+
+        self.assertEqual(data['user'].test_attr.shape, (3, 8))
+        self.assertEqual(data['item'].test_attr.shape, (2, 8))
+        self.assert_tensor_equality(data['user'].test_attr, values[:3])
+        self.assert_tensor_equality(data['item'].test_attr, values[3:])
+
+    def test_add_node_attr_with_dict(self):
+        """Test adding a node attribute with dictionary input."""
+        data = create_simple_hetero_data()
+
+        # Create values as dictionary per node type
+        values = {
+            'user': torch.randn(3, 8),
+            'item': torch.randn(2, 8),
+        }
+
+        add_node_attr(data, values, attr_name='test_attr')
+
+        self.assertEqual(data['user'].test_attr.shape, (3, 8))
+        self.assertEqual(data['item'].test_attr.shape, (2, 8))
+        self.assert_tensor_equality(data['user'].test_attr, values['user'])
+        self.assert_tensor_equality(data['item'].test_attr, values['item'])
+
+    def test_add_node_attr_with_dict_partial(self):
+        """Test adding a node attribute with dictionary containing only some node types."""
+        data = create_simple_hetero_data()
+
+        # Only provide values for 'user' node type
+        values = {
+            'user': torch.randn(3, 8),
+        }
+
+        add_node_attr(data, values, attr_name='test_attr')
+
+        self.assertTrue(hasattr(data['user'], 'test_attr'))
+        self.assertEqual(data['user'].test_attr.shape, (3, 8))
+        self.assertFalse(hasattr(data['item'], 'test_attr'))
+
+    def test_add_node_attr_concatenate_to_x(self):
+        """Test adding a node attribute by concatenating to existing x."""
+        data = create_simple_hetero_data()
+        original_user_x = data['user'].x.clone()
+        original_item_x = data['item'].x.clone()
+
+        # Create values in homogeneous order
+        values = torch.randn(5, 8)
+
+        add_node_attr(data, values, attr_name=None)
+
+        # Check that x was concatenated
+        self.assertEqual(data['user'].x.shape, (3, 12))  # 4 + 8
+        self.assertEqual(data['item'].x.shape, (2, 12))  # 4 + 8
+
+        # Check original features are preserved
+        self.assert_tensor_equality(data['user'].x[:, :4], original_user_x)
+        self.assert_tensor_equality(data['item'].x[:, :4], original_item_x)
+
+    def test_add_node_attr_create_x_if_none(self):
+        """Test creating x attribute if it doesn't exist."""
+        data = HeteroData()
+        data['user'].num_nodes = 3
+        data['item'].num_nodes = 2
+
+        values = torch.randn(5, 8)
+
+        add_node_attr(data, values, attr_name=None)
+
+        self.assertEqual(data['user'].x.shape, (3, 8))
+        self.assertEqual(data['item'].x.shape, (2, 8))
+
+
+class TestAddEdgeAttr(TestCase):
+    def test_add_edge_attr_with_attr_name(self):
+        """Test adding an edge attribute with a specific attribute name."""
+        data = create_simple_hetero_data()
+
+        # Create values in homogeneous order (3 + 3 = 6 edges)
+        values = torch.randn(6, 8)
+
+        add_edge_attr(data, values, attr_name='test_attr')
+
+        self.assertEqual(data['user', 'buys', 'item'].test_attr.shape, (3, 8))
+        self.assertEqual(data['item', 'bought_by', 'user'].test_attr.shape, (3, 8))
+        self.assert_tensor_equality(data['user', 'buys', 'item'].test_attr, values[:3])
+        self.assert_tensor_equality(data['item', 'bought_by', 'user'].test_attr, values[3:])
+
+    def test_add_edge_attr_with_dict(self):
+        """Test adding an edge attribute with dictionary input."""
+        data = create_simple_hetero_data()
+
+        # Create values as dictionary per edge type
+        values = {
+            ('user', 'buys', 'item'): torch.randn(3, 8),
+            ('item', 'bought_by', 'user'): torch.randn(3, 8),
+        }
+
+        add_edge_attr(data, values, attr_name='test_attr')
+
+        self.assertEqual(data['user', 'buys', 'item'].test_attr.shape, (3, 8))
+        self.assertEqual(data['item', 'bought_by', 'user'].test_attr.shape, (3, 8))
+        self.assert_tensor_equality(data['user', 'buys', 'item'].test_attr, values[('user', 'buys', 'item')])
+        self.assert_tensor_equality(data['item', 'bought_by', 'user'].test_attr, values[('item', 'bought_by', 'user')])
+
+    def test_add_edge_attr_with_dict_partial(self):
+        """Test adding an edge attribute with dictionary containing only some edge types."""
+        data = create_simple_hetero_data()
+
+        # Only provide values for one edge type
+        values = {
+            ('user', 'buys', 'item'): torch.randn(3, 8),
+        }
+
+        add_edge_attr(data, values, attr_name='test_attr')
+
+        self.assertTrue(hasattr(data['user', 'buys', 'item'], 'test_attr'))
+        self.assertEqual(data['user', 'buys', 'item'].test_attr.shape, (3, 8))
+        self.assertFalse(hasattr(data['item', 'bought_by', 'user'], 'test_attr'))
+
+    def test_add_edge_attr_concatenate_to_edge_attr(self):
+        """Test adding an edge attribute by concatenating to existing edge_attr."""
+        data = create_simple_hetero_data()
+
+        # Add initial edge attributes
+        data['user', 'buys', 'item'].edge_attr = torch.randn(3, 4)
+        data['item', 'bought_by', 'user'].edge_attr = torch.randn(3, 4)
+
+        original_buys_attr = data['user', 'buys', 'item'].edge_attr.clone()
+        original_bought_by_attr = data['item', 'bought_by', 'user'].edge_attr.clone()
+
+        # Create values in homogeneous order
+        values = torch.randn(6, 8)
+
+        add_edge_attr(data, values, attr_name=None)
+
+        # Check that edge_attr was concatenated
+        self.assertEqual(data['user', 'buys', 'item'].edge_attr.shape, (3, 12))  # 4 + 8
+        self.assertEqual(data['item', 'bought_by', 'user'].edge_attr.shape, (3, 12))  # 4 + 8
+
+        # Check original features are preserved
+        self.assert_tensor_equality(data['user', 'buys', 'item'].edge_attr[:, :4], original_buys_attr)
+        self.assert_tensor_equality(data['item', 'bought_by', 'user'].edge_attr[:, :4], original_bought_by_attr)
+
+    def test_add_edge_attr_create_edge_attr_if_none(self):
+        """Test creating edge_attr attribute if it doesn't exist."""
+        data = create_simple_hetero_data()
+
+        # Ensure no edge_attr exists
+        if hasattr(data['user', 'buys', 'item'], 'edge_attr'):
+            del data['user', 'buys', 'item'].edge_attr
+        if hasattr(data['item', 'bought_by', 'user'], 'edge_attr'):
+            del data['item', 'bought_by', 'user'].edge_attr
+
+        values = torch.randn(6, 8)
+
+        add_edge_attr(data, values, attr_name=None)
+
+        self.assertEqual(data['user', 'buys', 'item'].edge_attr.shape, (3, 8))
+        self.assertEqual(data['item', 'bought_by', 'user'].edge_attr.shape, (3, 8))
+
+
+if __name__ == '__main__':
+    absltest.main()