diff --git a/examples/gcn.py b/examples/gcn.py new file mode 100644 index 0000000..887fafb --- /dev/null +++ b/examples/gcn.py @@ -0,0 +1,164 @@ +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + BinaryAUROC, + BinaryAveragePrecision, + BinaryPrecision, + BinaryRecall, +) +from hyperbench.hlp import GCNHlpModule +from hyperbench.nn import LaplacianPositionalEncodingEnricher +from hyperbench.train import MultiModelTrainer, RandomNegativeSampler +from hyperbench.types import HData, ModelConfig +from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy + + +if __name__ == "__main__": + verbose = False + num_workers = 8 + num_features = 32 + sampling_strategy = SamplingStrategy.HYPEREDGE + metrics = MetricCollection( + { + "auc": BinaryAUROC(), + "avg_precision": BinaryAveragePrecision(), + "precision": BinaryPrecision(), + "recall": BinaryRecall(), + } + ) + + print("Loading and preparing dataset...") + + dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset.remove_hyperedges_with_fewer_than_k_nodes(k=2) + if verbose: + print(f"Dataset:\n {dataset.hdata}\n") + + # Split dataset into train and test (80/20) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) + + # Split train into train and val (87.5/12.5 of train = 70/10 of total) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) + if verbose: + print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") + print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") + print(f"Val dataset:\n {val_dataset.hdata}\n") + print(f"Test dataset:\n {test_dataset.hdata}\n") + + # Add negative samples to all splits + for name, ds in [("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)]: + num_negative_samples = ( + 
ds.hdata.num_hyperedges + if name in ["Train", "Val"] # 1:1 ratio of pos:neg samples + else int(ds.hdata.num_hyperedges * 0.6) # 60% negatives for test set + ) + negative_sampler = RandomNegativeSampler( + num_negative_samples=num_negative_samples, + num_nodes_per_sample=int(ds.stats()["avg_degree_hyperedge"]), + ) + neg_hdata = negative_sampler.sample(ds.hdata) + combined_hdata = HData.cat_same_node_space([ds.hdata, neg_hdata]) + shuffled_hdata = combined_hdata.shuffle(seed=42) + ds_with_negatives = ds.update_from_hdata(shuffled_hdata) + + if name == "Train": + train_dataset = ds_with_negatives + elif name == "Val": + val_dataset = ds_with_negatives + else: + test_dataset = ds_with_negatives + + if verbose: + print(f"{name} dataset after adding negative samples: {shuffled_hdata}\n") + + print("Enriching node features...") + + train_dataset.enrich_node_features( + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In transductive setting, use total number of nodes to ensure consistent encoding across splits + # as the train dataset contain all nodes but may have no hyperedges where they appear + num_nodes=train_dataset.hdata.num_nodes, + ), + enrichment_mode="replace", + ) + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) + + print("Creating dataloaders...") + + train_loader_full_hypergraph = DataLoader( + train_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + val_loader_full_hypergraph = DataLoader( + val_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + test_loader_full_hypergraph = DataLoader( + test_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + + mean_gcn_module = GCNHlpModule( + encoder_config={ + "in_channels": num_features, + "hidden_channels": 16, + "out_channels": 
16, + "num_layers": 2, + "drop_rate": 0.1, + "bias": True, + "improved": False, + "add_self_loops": True, + "normalize": True, + "cached": False, + "graph_reduction_strategy": "clique_expansion", + }, + aggregation="mean", + lr=0.001, + weight_decay=5e-4, + metrics=metrics, + ) + + configs = [ + ModelConfig( + name="gcn", + version="mean", + model=mean_gcn_module, + train_dataloader=train_loader_full_hypergraph, + val_dataloader=val_loader_full_hypergraph, + test_dataloader=test_loader_full_hypergraph, + ), + ] + + print("Starting training and evaluation...") + + with MultiModelTrainer( + model_configs=configs, + max_epochs=60, + accelerator="auto", + log_every_n_steps=1, + enable_checkpointing=False, + auto_start_tensorboard=True, + auto_wait=True, + ) as trainer: + trainer.fit_all( + train_dataloader=train_loader_full_hypergraph, + val_dataloader=val_loader_full_hypergraph, + verbose=True, + ) + trainer.test_all(dataloader=test_loader_full_hypergraph, verbose=True) + + print("Complete!") diff --git a/examples/node2vecgcn.py b/examples/node2vecgcn.py new file mode 100644 index 0000000..828efd1 --- /dev/null +++ b/examples/node2vecgcn.py @@ -0,0 +1,213 @@ +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + BinaryAUROC, + BinaryAveragePrecision, + BinaryPrecision, + BinaryRecall, +) +from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy +from hyperbench.hlp import Node2VecGCNHlpModule, Node2VecGCNHlpConfig +from hyperbench.nn import Node2VecEnricher +from hyperbench.train import MultiModelTrainer, RandomNegativeSampler +from hyperbench.types import HData, ModelConfig + + +if __name__ == "__main__": + verbose = False + num_workers = 8 + num_features = 32 + sampling_strategy = SamplingStrategy.HYPEREDGE + metrics = MetricCollection( + { + "auc": BinaryAUROC(), + "avg_precision": BinaryAveragePrecision(), + "precision": BinaryPrecision(), + "recall": BinaryRecall(), + } + ) + + print("Loading and preparing 
dataset...") + + dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset.remove_hyperedges_with_fewer_than_k_nodes(k=2) + if verbose: + print(f"Dataset:\n {dataset.hdata}\n") + + # Split dataset into train and test (80/20) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) + + # Split train into train and val (87.5/12.5 of train = 70/10 of total) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) + if verbose: + print(f"Train dataset:\n {train_dataset.hdata}\n") + print(f"Val dataset:\n {val_dataset.hdata}\n") + print(f"Test dataset:\n {test_dataset.hdata}\n") + + print("Adding negative samples...") + + for name, ds in [("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)]: + num_negative_samples = ( + ds.hdata.num_hyperedges + if name in ["Train", "Val"] + else int(ds.hdata.num_hyperedges * 0.6) + ) + negative_sampler = RandomNegativeSampler( + num_negative_samples=num_negative_samples, + num_nodes_per_sample=int(ds.stats()["avg_degree_hyperedge"]), + ) + neg_hdata = negative_sampler.sample(ds.hdata) + combined_hdata = HData.cat_same_node_space([ds.hdata, neg_hdata]) + shuffled_hdata = combined_hdata.shuffle(seed=42) + ds_with_negatives = ds.update_from_hdata(shuffled_hdata) + + if name == "Train": + train_dataset = ds_with_negatives + elif name == "Val": + val_dataset = ds_with_negatives + else: + test_dataset = ds_with_negatives + + if verbose: + print(f"{name} dataset after adding negative samples: {shuffled_hdata}\n") + + print("Computing Node2Vec embeddings from the train graph...") + + node2vec_enricher = Node2VecEnricher( + num_features=num_features, + context_size=10, + walk_length=20, + num_walks_per_node=10, + num_negative_samples=1, + num_nodes=dataset.hdata.num_nodes, + num_epochs=10, + learning_rate=0.01, + batch_size=128, + sparse=False, + 
verbose=verbose, + ) + train_dataset.enrich_node_features( + enricher=node2vec_enricher, + enrichment_mode="replace", + ) + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) + + print("Creating dataloaders...") + + train_loader = DataLoader( + train_dataset, + batch_size=128, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + val_loader = DataLoader( + val_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + test_loader = DataLoader( + test_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + + gcn_config: Node2VecGCNHlpConfig = { + "out_channels": num_features, + "hidden_channels": num_features, + "num_layers": 2, + "drop_rate": 0.1, + "bias": True, + "improved": False, + "add_self_loops": True, + "normalize": True, + "cached": False, + "graph_reduction_strategy": "clique_expansion", + } + + precomputed_node2vecgcn_module = Node2VecGCNHlpModule( + encoder_config={ + "mode": "precomputed", + "num_features": num_features, + "node2vec_config": {}, + "gcn_config": gcn_config, + }, + aggregation="mean", + lr=0.001, + weight_decay=0.0, + metrics=metrics, + ) + + train_hyperedge_index = train_dataset.hdata.hyperedge_index + joint_node2vecgcn_module = Node2VecGCNHlpModule( + encoder_config={ + "mode": "joint", + "num_features": num_features, + "node2vec_config": { + "context_size": 10, + "walk_length": 20, + "num_walks_per_node": 10, + "p": 1.0, + "q": 1.0, + "num_negative_samples": 1, + "train_hyperedge_index": train_hyperedge_index, + "num_nodes": dataset.hdata.num_nodes, + "graph_reduction_strategy": "clique_expansion", + "random_walk_batch_size": 128, + # We count the node2vec loss as 40% of the total loss (the rest is the SLP loss) + "node2vec_loss_weight": 0.4, + }, + "gcn_config": gcn_config, + }, + aggregation="mean", + lr=0.001, + weight_decay=0.0, + 
metrics=metrics, + ) + + configs = [ + ModelConfig( + name="node2vecgcn", + version="precomputed", + model=precomputed_node2vecgcn_module, + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + ), + ModelConfig( + name="node2vecgcn", + version="joint", + model=joint_node2vecgcn_module, + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + ), + ] + + print("Starting training and evaluation...") + + with MultiModelTrainer( + model_configs=configs, + max_epochs=60, + accelerator="auto", + log_every_n_steps=1, + enable_checkpointing=False, + auto_start_tensorboard=True, + auto_wait=True, + ) as trainer: + trainer.fit_all( + train_dataloader=train_loader, + val_dataloader=val_loader, + verbose=True, + ) + trainer.test_all(dataloader=test_loader, verbose=True) + + print("Complete!") diff --git a/examples/node2vecslp.py b/examples/node2vecslp.py index 30d608e..5749404 100644 --- a/examples/node2vecslp.py +++ b/examples/node2vecslp.py @@ -124,6 +124,7 @@ encoder_config={ "mode": "precomputed", "num_features": num_features, + "node2vec_config": {}, }, aggregation="mean", lr=0.001, @@ -136,17 +137,20 @@ encoder_config={ "mode": "joint", "num_features": num_features, - "context_size": 10, - "walk_length": 20, - "num_walks_per_node": 10, - "p": 1.0, - "q": 1.0, - "num_negative_samples": 1, - "train_hyperedge_index": train_hyperedge_index, - "num_nodes": dataset.hdata.num_nodes, - "graph_reduction_strategy": "clique_expansion", - "random_walk_batch_size": 128, - "node2vec_loss_weight": 1.0, + "node2vec_config": { + "context_size": 10, + "walk_length": 20, + "num_walks_per_node": 10, + "p": 1.0, + "q": 1.0, + "num_negative_samples": 1, + "train_hyperedge_index": train_hyperedge_index, + "num_nodes": dataset.hdata.num_nodes, + "graph_reduction_strategy": "clique_expansion", + "random_walk_batch_size": 128, + # We count the node2vec loss as 40% of the total loss (the rest is the SLP loss) + 
"node2vec_loss_weight": 0.4, + }, }, aggregation="mean", lr=0.001, diff --git a/hyperbench/hlp/__init__.py b/hyperbench/hlp/__init__.py index 7307230..10c5751 100644 --- a/hyperbench/hlp/__init__.py +++ b/hyperbench/hlp/__init__.py @@ -1,14 +1,28 @@ from .common_neighbors_hlp import CommonNeighborsHlpModule +from .gcn_hlp import GCNEncoderConfig, GCNHlpModule from .hgnn_hlp import HGNNHlpModule, HGNNEncoderConfig from .hnhn_hlp import HNHNEncoderConfig, HNHNHlpModule from .hgnnp_hlp import HGNNPEncoderConfig, HGNNPHlpModule from .hlp import HlpModule from .hypergcn_hlp import HyperGCNHlpModule, HyperGCNEncoderConfig from .mlp_hlp import MLPHlpModule, MlpEncoderConfig -from .node2vec_hlp import Node2VecEncoderConfig, Node2VecSLPHlpModule +from .node2vec_common import ( + NODE2VEC_JOINT_MODE, + NODE2VEC_PRECOMPUTED_MODE, + Node2VecGCNHlpConfig, + Node2VecHlpConfig, + Node2VecMode, +) +from .node2vecgcn_hlp import Node2VecGCNEncoderConfig, Node2VecGCNHlpModule +from .node2vecslp_hlp import Node2VecSLPEncoderConfig, Node2VecSLPHlpModule __all__ = [ + "NODE2VEC_JOINT_MODE", + "NODE2VEC_PRECOMPUTED_MODE", "CommonNeighborsHlpModule", + "GCNEncoderConfig", + "GCNHlpModule", + "Node2VecGCNHlpConfig", "HGNNEncoderConfig", "HGNNHlpModule", "HNHNEncoderConfig", @@ -20,6 +34,10 @@ "HyperGCNHlpModule", "MlpEncoderConfig", "MLPHlpModule", - "Node2VecEncoderConfig", + "Node2VecHlpConfig", + "Node2VecGCNEncoderConfig", + "Node2VecGCNHlpModule", + "Node2VecMode", + "Node2VecSLPEncoderConfig", "Node2VecSLPHlpModule", ] diff --git a/hyperbench/hlp/gcn_hlp.py b/hyperbench/hlp/gcn_hlp.py new file mode 100644 index 0000000..a97ae6a --- /dev/null +++ b/hyperbench/hlp/gcn_hlp.py @@ -0,0 +1,153 @@ +from torch import Tensor, nn, optim +from typing import Dict, Literal, Optional, TypedDict +from typing_extensions import NotRequired +from torchmetrics import MetricCollection +from hyperbench.models import GCN, SLP +from hyperbench.nn import HyperedgeAggregator +from hyperbench.types import 
EdgeIndex, HData, HyperedgeIndex +from hyperbench.utils import ActivationFn, Stage + +from hyperbench.hlp.hlp import HlpModule + + +class GCNEncoderConfig(TypedDict): + """ + Configuration for the GCN encoder in GCNHlpModule. + + Args: + in_channels: Number of input features per node. + out_channels: Number of output features (embedding size) per node. + hidden_channels: Number of hidden units in the intermediate GCN layers. + num_layers: Number of GCN layers. Defaults to ``2``. + drop_rate: Dropout rate applied after each hidden GCN layer. Defaults to ``0.0``. + bias: Whether to include bias terms. Defaults to ``True``. + improved: Whether to use the improved GCN normalization. Defaults to ``False``. + add_self_loops: Whether to add self-loops before convolution. Defaults to ``True``. + normalize: Whether to normalize the adjacency matrix in ``GCNConv``. Defaults to ``True``. + cached: Whether to cache the normalized graph in ``GCNConv``. Defaults to ``False``. + graph_reduction_strategy: Strategy for reducing the hypergraph to a graph. Defaults to ``"clique_expansion"``. + activation_fn: Activation function to use after each hidden layer. Defaults to ``nn.ReLU``. + activation_fn_kwargs: Keyword arguments for the activation function. Defaults to empty dict. + """ + + in_channels: int + out_channels: int + hidden_channels: NotRequired[int] + num_layers: NotRequired[int] + drop_rate: NotRequired[float] + bias: NotRequired[bool] + improved: NotRequired[bool] + add_self_loops: NotRequired[bool] + normalize: NotRequired[bool] + cached: NotRequired[bool] + graph_reduction_strategy: NotRequired[Literal["clique_expansion"]] + activation_fn: NotRequired[ActivationFn] + activation_fn_kwargs: NotRequired[Dict] + + +class GCNHlpModule(HlpModule): + """ + A LightningModule for GCN-based HLP. + + Uses a graph reduction of the input hypergraph to run GCN over nodes, + aggregates node embeddings per hyperedge, and scores each hyperedge with a linear decoder. 
+ + Args: + encoder_config: Configuration for the GCN encoder. + aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``. + loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``. + lr: Learning rate for the optimizer. Defaults to ``0.001``. + weight_decay: L2 regularization. Defaults to ``0.0``. + metrics: Optional metric collection for evaluation. + """ + + def __init__( + self, + encoder_config: GCNEncoderConfig, + aggregation: Literal["mean", "max", "min", "sum"] = "mean", + loss_fn: Optional[nn.Module] = None, + lr: float = 0.001, + weight_decay: float = 0.0, + metrics: Optional[MetricCollection] = None, + ): + encoder = GCN( + in_channels=encoder_config["in_channels"], + out_channels=encoder_config["out_channels"], + hidden_channels=encoder_config.get("hidden_channels"), + num_layers=encoder_config.get("num_layers", 2), + drop_rate=encoder_config.get("drop_rate", 0.0), + bias=encoder_config.get("bias", True), + activation_fn=encoder_config.get("activation_fn"), + activation_fn_kwargs=encoder_config.get("activation_fn_kwargs"), + improved=encoder_config.get("improved", False), + add_self_loops=encoder_config.get("add_self_loops", True), + normalize=encoder_config.get("normalize", True), + cached=encoder_config.get("cached", False), + ) + decoder = SLP(in_channels=encoder_config["out_channels"], out_channels=1) + + super().__init__( + encoder=encoder, + decoder=decoder, + loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(), + metrics=metrics, + ) + + self.encoder_config = encoder_config + self.aggregation = aggregation + self.lr = lr + self.weight_decay = weight_decay + + def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: + """ + Reduce the hypergraph to a graph, encode nodes with GCN, aggregate per hyperedge, and score. + + Args: + x: Node feature matrix of shape ``(num_nodes, in_channels)``. + hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``. 
+ + Returns: + Logit scores of shape ``(num_hyperedges,)``. + """ + if self.encoder is None: + raise ValueError("Encoder is not defined for this HLP module.") + + # Reduce hypergraph to graph and remove self-loops + reduced_edge_index = HyperedgeIndex(hyperedge_index).reduce( + strategy=self.encoder_config.get("graph_reduction_strategy", "clique_expansion") + ) + edge_index = EdgeIndex(reduced_edge_index).remove_selfloops().item + + # Encode nodes with GCN + node_embeddings: Tensor = self.encoder(x, edge_index) + + # Aggregate node embeddings per hyperedge + hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool( + self.aggregation + ) + + return self.decoder(hyperedge_embeddings).squeeze(-1) + + def training_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.TRAIN) + + def validation_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.VAL) + + def test_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.TEST) + + def predict_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.forward(batch.x, batch.hyperedge_index) + + def configure_optimizers(self): + return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + + def __eval_step(self, batch: HData, stage: Stage) -> Tensor: + scores = self.forward(batch.x, batch.hyperedge_index) + labels = batch.y + batch_size = batch.num_hyperedges + + loss = self._compute_loss(scores, labels, batch_size, stage) + self._compute_metrics(scores, labels, batch_size, stage) + return loss diff --git a/hyperbench/hlp/node2vec_common.py b/hyperbench/hlp/node2vec_common.py new file mode 100644 index 0000000..2badd30 --- /dev/null +++ b/hyperbench/hlp/node2vec_common.py @@ -0,0 +1,221 @@ +from torch import Tensor, nn +from torch.utils.data import DataLoader +from typing import Dict, Iterator, Literal, Optional, Tuple, TypeAlias, TypedDict +from 
typing_extensions import NotRequired +from hyperbench.models import GCNConfig, Node2Vec, Node2VecGCN +from hyperbench.types import EdgeIndex, HyperedgeIndex +from hyperbench.utils import ActivationFn + + +NODE2VEC_JOINT_MODE = "joint" +NODE2VEC_PRECOMPUTED_MODE = "precomputed" + +Node2VecMode: TypeAlias = Literal["precomputed", "joint"] + + +class Node2VecGCNHlpConfig(TypedDict): + """ + Configuration for the GCN model. + + Args: + out_channels: Dimension of the output node embeddings from the GCN layers. + hidden_channels: Dimension of the hidden node embeddings in the GCN layers. + num_layers: Number of GCN layers. Must be at least 1. Defaults to ``2``. + drop_rate: Dropout rate applied after each GCN layer (except the last one). Defaults to ``0.0`` (no dropout). + bias: Whether to include a bias term in the GCN layers. Defaults to ``True``. + improved: Whether to use the improved version of GCNConv. Defaults to ``False``. + add_self_loops: Whether to add self-loops to the input graph. Defaults to ``True``. + normalize: Whether to symmetrically normalize the adjacency matrix in GCNConv. Defaults to ``True``. + cached: Whether to cache the normalized adjacency matrix in GCNConv. + Only applicable if the graph structure does not change between epochs. Defaults to ``False``. + graph_reduction_strategy: Strategy for reducing the hyperedge graph. Defaults to ``clique_expansion``. + activation_fn: Activation function to use after each hidden layer. Defaults to ``nn.ReLU``. + activation_fn_kwargs: Keyword arguments for the activation function. Defaults to empty dict. 
+ """ + + out_channels: int + hidden_channels: NotRequired[int] + num_layers: NotRequired[int] + drop_rate: NotRequired[float] + bias: NotRequired[bool] + improved: NotRequired[bool] + add_self_loops: NotRequired[bool] + normalize: NotRequired[bool] + cached: NotRequired[bool] + graph_reduction_strategy: NotRequired[Literal["clique_expansion"]] + activation_fn: NotRequired[ActivationFn] + activation_fn_kwargs: NotRequired[Dict] + + +class Node2VecHlpConfig(TypedDict): + """ + Configuration for the Node2Vec encoder. + + Args: + context_size: Skip-gram context size for Node2Vec. + For example, if ``context_size=2`` and ``walk_length=5``, then for a random walk ``[v0, v1, v2, v3, v4]``, + the context for ``v2`` would be ``[v0, v1, v3, v4]`` as we take neighbors within distance 2 in the walk. + The pairs generated by skip-gram would be ``[(v2, v0), (v2, v1), (v2, v3), (v2, v4)]``. + Rule of thumb: Graphs with strong local structure (5-10), Graphs with communities/long-range patterns (10-20). + Defaults to ``10``. + walk_length: Length of each random walk. + num_walks_per_node: Number of walks sampled per node. + p: Node2Vec return parameter. Controls the probability of stepping back to the node visited + in the previous step. Lower values of ``p`` make immediate backtracking more likely, + while higher values discourage returning to the previous node. + q: Node2Vec in-out parameter. Controls whether walks stay near the source node or explore + further outward. Lower values of ``q`` bias the walk toward DFS-like exploration and + structural similarity, while higher values bias it toward BFS-like exploration and + local community structure and homophily. + num_negative_samples: Number of negative samples per positive walk context. + If set to ``X``, then for each positive pair ``(u, v)`` generated from the random walks, + ``X`` negative pairs ``(u, v_neg)`` will be generated, + where ``v_neg`` is a node sampled uniformly at random from all nodes in the graph. 
+ Defaults to ``1``, meaning one negative sample per positive pair. + num_nodes: Number of nodes in the stable node space. Defaults to the number of nodes in the ``hyperedge_index`` if not provided. + train_hyperedge_index: Training hyperedge index used to build the Node2Vec walk graph. Required in ``joint`` mode. + graph_reduction_strategy: Strategy for reducing the hyperedge graph. Defaults to ``clique_expansion``. + random_walk_batch_size: Batch size used by the walk sampler in joint mode. + node2vec_loss_weight: Weight applied to the Node2Vec walk loss in joint mode. + This is to decide how much the loss of Node2Vec contributes to the overall loss in joint training, relative to the HLP loss. + Defaults to ``1.0`` (equal weighting). Set to a higher value to prioritize learning good node embeddings, + or a lower value to prioritize the HLP loss. Ignored in precomputed mode. + sparse: Whether to use sparse gradients in the Node2Vec encoder. Defaults to ``False``. + """ + + context_size: NotRequired[int] + walk_length: NotRequired[int] + num_walks_per_node: NotRequired[int] + p: NotRequired[float] + q: NotRequired[float] + num_negative_samples: NotRequired[int] + num_nodes: NotRequired[int] + train_hyperedge_index: NotRequired[Tensor] + graph_reduction_strategy: NotRequired[Literal["clique_expansion"]] + random_walk_batch_size: NotRequired[int] + node2vec_loss_weight: NotRequired[float] + sparse: NotRequired[bool] + + +class Node2VecWalkLoaderState: + """ + State object to hold the walk loader and its iterator for joint Node2Vec training. + + Args: + walk_loader: The DataLoader that provides batches of random walks from the Node2Vec encoder during joint training. + Initialized lazily when first needed. + cached_walk_loader_iterator: An iterator over the walk_loader, cached to allow + fetching the next batch of walks at each training step without reinitializing.
+ """ + + walk_loader: Optional[DataLoader] = None + cached_walk_loader_iterator: Optional[Iterator] = None + + +Node2VecEncoder: TypeAlias = Node2Vec | Node2VecGCN + + +def _get_walk_loader( + mode: Node2VecMode, + encoder: Optional[nn.Module], + batch_size: int, + state: Node2VecWalkLoaderState, +) -> Optional[DataLoader]: + if mode != NODE2VEC_JOINT_MODE: + return None + + if state.walk_loader is None: + state.walk_loader = _to_node2vec_encoder(encoder, mode).loader( + batch_size=batch_size, + shuffle=True, + ) + state.cached_walk_loader_iterator = iter(state.walk_loader) + + return state.walk_loader + + +def _next_walk_batch( + mode: Node2VecMode, + encoder: Optional[nn.Module], + batch_size: int, + state: Node2VecWalkLoaderState, +): + _get_walk_loader(mode, encoder, batch_size, state) + if state.walk_loader is None or state.cached_walk_loader_iterator is None: + raise ValueError("Joint Node2Vec mode could not initialize the walk loader.") + + try: + return next(state.cached_walk_loader_iterator) + except StopIteration: + state.cached_walk_loader_iterator = iter(state.walk_loader) + return next(state.cached_walk_loader_iterator) + + +def _to_node2vec_edge_index( + node2vec_config: Node2VecHlpConfig, + mode: Node2VecMode, +) -> Tuple[Tensor, int]: + if "train_hyperedge_index" not in node2vec_config: + raise ValueError(f"Node2Vec in mode {mode} requires train_hyperedge_index.") + + reduced_edge_index = HyperedgeIndex(node2vec_config["train_hyperedge_index"]).reduce( + strategy=node2vec_config.get("graph_reduction_strategy", "clique_expansion") + ) + edge_index_wrapper = EdgeIndex(reduced_edge_index).remove_selfloops() + num_nodes = node2vec_config.get("num_nodes", edge_index_wrapper.num_nodes) + return edge_index_wrapper.item, num_nodes + + +def _to_gcn_config(embedding_dim: int, gcn_hlp_config: Node2VecGCNHlpConfig) -> GCNConfig: + gcn_config: GCNConfig = { + "in_channels": embedding_dim, + "out_channels": gcn_hlp_config["out_channels"], + "hidden_channels": 
gcn_hlp_config.get("hidden_channels", embedding_dim), + "num_layers": gcn_hlp_config.get("num_layers", 2), + "drop_rate": gcn_hlp_config.get("drop_rate", 0.0), + "bias": gcn_hlp_config.get("bias", True), + "improved": gcn_hlp_config.get("improved", False), + "add_self_loops": gcn_hlp_config.get("add_self_loops", True), + "normalize": gcn_hlp_config.get("normalize", True), + "cached": gcn_hlp_config.get("cached", False), + } + if "activation_fn" in gcn_hlp_config: + gcn_config["activation_fn"] = gcn_hlp_config["activation_fn"] + if "activation_fn_kwargs" in gcn_hlp_config: + gcn_config["activation_fn_kwargs"] = gcn_hlp_config["activation_fn_kwargs"] + return gcn_config + + +def _to_node2vec_encoder(encoder: Optional[nn.Module], mode: Node2VecMode) -> Node2VecEncoder: + if encoder is None or not isinstance(encoder, (Node2Vec, Node2VecGCN)): + raise ValueError(f"Node2Vec in mode {mode} requires an encoder, but none was provided.") + return encoder + + +def _validate_global_node_ids( + num_embeddings: int, + global_node_ids: Optional[Tensor], + mode: Node2VecMode, +) -> None: + if global_node_ids is None or len(global_node_ids) < 1: + raise ValueError(f"Node2Vec in mode {mode} requires batch.global_node_ids.") + + min_global_node_id = int(global_node_ids.min().item()) + max_global_node_id = int(global_node_ids.max().item()) + + max_acceptable_node_id = num_embeddings - 1 + if min_global_node_id < 0 or max_global_node_id > max_acceptable_node_id: + raise ValueError( + f"Node2Vec in mode {mode} cannot index the provided batch.global_node_ids. " + f"Expected IDs in [0, {max_acceptable_node_id}], got " + f"[{min_global_node_id}, {max_global_node_id}]. " + "Set encoder_config['node2vec_config']['num_nodes'] for the full stable node space." 
+ ) + + +def _validate_walk_length_and_context_size(walk_length: int, context_size: int) -> None: + if walk_length < context_size: + raise ValueError( + f"Expected walk_length >= context_size, got " + f"walk_length={walk_length}, context_size={context_size}." + ) diff --git a/hyperbench/hlp/node2vec_hlp.py b/hyperbench/hlp/node2vec_hlp.py deleted file mode 100644 index 6004cbf..0000000 --- a/hyperbench/hlp/node2vec_hlp.py +++ /dev/null @@ -1,283 +0,0 @@ -from torch import Tensor, nn, optim -from torch.utils.data import DataLoader -from typing import Iterator, Literal, Optional, TypedDict -from typing_extensions import NotRequired -from torchmetrics import MetricCollection -from hyperbench.models import Node2Vec, SLP -from hyperbench.types import EdgeIndex, HData, HyperedgeIndex -from hyperbench.utils import Stage -from hyperbench.nn import HyperedgeAggregator - -from hyperbench.hlp.hlp import HlpModule - - -class Node2VecEncoderConfig(TypedDict): - """ - Configuration for the Node2Vec encoder in ``Node2VecHlpModule``. - - Args: - mode: Whether to use precomputed node embeddings from ``x`` or train a Node2Vec encoder jointly inside the module. - num_features: Dimension of the node embeddings consumed by the decoder. - walk_length: Length of each random walk. - context_size: Skip-gram context size for Node2Vec. - For example, if ``context_size=2`` and ``walk_length=5``, then for a random walk ``[v0, v1, v2, v3, v4]``, - the context for ``v2`` would be ``[v0, v1, v3, v4]`` as we take neighbors within distance 2 in the walk. - The pairs generated by skip-gram would be ``[(v2, v0), (v2, v1), (v2, v3), (v2, v4)]``. - Rule of thumb: Graphs with strong local structure (5-10), Graphs with communities/long-range patterns (10-20). - Defaults to ``10``. - num_walks_per_node: Number of walks sampled per node. - p: Node2Vec return parameter. Controls the probability of stepping back to the node visited - in the previous step. 
Lower values of ``p`` make immediate backtracking more likely, - while higher values discourage returning to the previous node. - q: Node2Vec in-out parameter. Controls whether walks stay near the source node or explore - further outward. Lower values of ``q`` bias the walk toward DFS-like exploration and - structural similarity, while higher values bias it toward BFS-like exploration and - local community structure and homophily. - num_negative_samples: Number of negative samples per positive walk context. - If set to ``X``, then for each positive pair ``(u, v)`` generated from the random walks, - ``X`` negative pairs ``(u, v_neg)`` will be generated, - where ``v_neg`` is a node sampled uniformly at random from all nodes in the graph. - Defaults to ``1``, meaning one negative sample per positive pair. - num_nodes: Number of nodes in the stable node space. Defaults to the number of nodes in the ``hyperedge_index`` if not provided. - train_hyperedge_index: Training hypereddge index used to build the Node2Vec walk graph. Required in ``joint`` mode. - graph_reduction_strategy: Strategy for reducing the hyperedge graph. Defaults to ``clique_expansion``. - random_walk_batch_size: Batch size used by the walk sampler in joint mode. - node2vec_loss_weight: Weight applied to the Node2Vec walk loss in joint mode. - This is to decide how much the loss of Node2Vec contributes to the overall loss in joint training, relative to the HLP loss. - Defaults to ``1.0`` (equal weighting). Set to a higher value to prioritize learning good node embeddings, - or a lower value to prioritize the HLP loss. Ignored in precomputed mode. 
- """ - - mode: NotRequired[Literal["precomputed", "joint"]] - num_features: int - context_size: NotRequired[int] - walk_length: NotRequired[int] - num_walks_per_node: NotRequired[int] - p: NotRequired[float] - q: NotRequired[float] - num_negative_samples: NotRequired[int] - num_nodes: NotRequired[int] - train_hyperedge_index: NotRequired[Tensor] - graph_reduction_strategy: NotRequired[Literal["clique_expansion"]] - random_walk_batch_size: NotRequired[int] - node2vec_loss_weight: NotRequired[float] - - -class Node2VecSLPHlpModule(HlpModule): - """ - A LightningModule for Node2Vec-based Hyperedge Link Prediction. - - Supports two modes: - - ``precomputed``: use node embeddings already stored in ``batch.x``. - - ``joint``: train a Node2Vec encoder jointly with the hyperedge decoder. - - Args: - encoder_config: Configuration for the Node2Vec encoder. - aggregation: Method to aggregate node embeddings per hyperedge. - loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``. - lr: Learning rate for the optimizer. Defaults to ``0.001``. - weight_decay: Weight decay (L2 regularization) for the optimizer. Defaults to ``0.0`` (no weight decay). - metrics: Optional dictionary of metric functions. 
- """ - - JOINT_MODE = "joint" - PRECOMPUTED_MODE = "precomputed" - - def __init__( - self, - encoder_config: Node2VecEncoderConfig, - aggregation: Literal["mean", "max", "min", "sum"] = "mean", - loss_fn: Optional[nn.Module] = None, - lr: float = 0.001, - weight_decay: float = 0.0, - metrics: Optional[MetricCollection] = None, - ): - self.mode = encoder_config.get("mode", self.JOINT_MODE) - self.embedding_dim = encoder_config["num_features"] - self.num_nodes = encoder_config.get("num_nodes") - - if self.mode == self.JOINT_MODE: - if "train_hyperedge_index" not in encoder_config: - raise ValueError(f"Node2Vec in mode {self.mode} requires train_hyperedge_index.") - - walk_length = encoder_config.get("walk_length", 20) - context_size = encoder_config.get("context_size", 10) - if walk_length < context_size: - raise ValueError( - f"Expected walk_length >= context_size, got " - f"walk_length={walk_length}, context_size={context_size}." - ) - - reduced_edge_index = HyperedgeIndex(encoder_config["train_hyperedge_index"]).reduce( - encoder_config.get("graph_reduction_strategy", "clique_expansion") - ) - edge_index_wrapper = EdgeIndex(reduced_edge_index).remove_selfloops() - - encoder = Node2Vec( - edge_index=edge_index_wrapper.item, - embedding_dim=self.embedding_dim, - walk_length=walk_length, - context_size=context_size, - num_walks_per_node=encoder_config.get("num_walks_per_node", 10), - p=encoder_config.get("p", 1.0), - q=encoder_config.get("q", 1.0), - num_negative_samples=encoder_config.get("num_negative_samples", 1), - num_nodes=self.num_nodes - if self.num_nodes is not None - else edge_index_wrapper.num_nodes, - sparse=False, - ) - else: - # We don't need an encoder in precomputed mode - # since node features are used as node embeddings - encoder = None - - decoder = SLP(in_channels=self.embedding_dim, out_channels=1) - - super().__init__( - encoder=encoder, - decoder=decoder, - loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(), - metrics=metrics, 
- ) - - self.aggregation = aggregation - self.lr = lr - self.weight_decay = weight_decay - self.random_walk_batch_size = encoder_config.get("random_walk_batch_size", 128) - self.node2vec_loss_weight = encoder_config.get("node2vec_loss_weight", 1.0) - - self.__walk_loader: Optional[DataLoader] = None - self.__cached_walk_loader_iterator: Optional[Iterator] = None - - def forward( - self, - x: Tensor, - hyperedge_index: Tensor, - global_node_ids: Optional[Tensor] = None, - ) -> Tensor: - # Encode: get node embeddings from precomputation or joint encoder - if self.mode == self.JOINT_MODE: - encoder = self.__node2vec_encoder() - self.__validate_global_node_ids(encoder.num_embeddings, global_node_ids) - node_embeddings = encoder(batch=global_node_ids) - else: - if x.size(1) != self.embedding_dim: - raise ValueError( - f"Expected precomputed node embeddings with dimension " - f"{self.embedding_dim}, got {x.size(1)}." - ) - node_embeddings = x - - # Aggregate: pool node embeddings per hyperedge - # shape: (num_hyperedges, embedding_dim) - hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool( - self.aggregation - ) - - # Decode: linear projection to scalar score per hyperedge - # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,) - scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1) - return scores - - def training_step(self, batch: HData, batch_idx: int) -> Tensor: - scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) - labels = batch.y - batch_size = batch.num_hyperedges - - if self.mode == self.JOINT_MODE: - # Node2Vec.loss() is already a stochastic objective over sampled walks, - # so one walk batch is a standard SGD estimate, not a logically different loss, - # meaning we can optimize training by using a single walk batch per training step, - # instead of averaging over multiple walk batches. 
- positive_random_walk, negative_random_walk = self.__next_walk_batch() - positive_random_walk = positive_random_walk.to(self.device) - negative_random_walk = negative_random_walk.to(self.device) - - hlp_loss = self.loss_fn(scores, labels) - node2vec_loss = self.__node2vec_encoder().loss( - positive_random_walk, - negative_random_walk, - ) - loss = hlp_loss + (self.node2vec_loss_weight * node2vec_loss) - - loss_prefix = Stage.TRAIN.value - self.log(f"{loss_prefix}_hlp_loss", hlp_loss, prog_bar=True, batch_size=batch_size) - self.log( - f"{loss_prefix}_node2vec_loss", node2vec_loss, prog_bar=True, batch_size=batch_size - ) - self.log(f"{loss_prefix}_loss", loss, prog_bar=True, batch_size=batch_size) - else: - loss = self._compute_loss(scores, labels, batch_size, Stage.TRAIN) - - self._compute_metrics(scores, labels, batch_size, Stage.TRAIN) - return loss - - def validation_step(self, batch: HData, batch_idx: int) -> Tensor: - return self.__eval_step(batch, Stage.VAL) - - def test_step(self, batch: HData, batch_idx: int) -> Tensor: - return self.__eval_step(batch, Stage.TEST) - - def predict_step(self, batch: HData, batch_idx: int) -> Tensor: - return self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) - - def configure_optimizers(self): - return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) - - def __eval_step(self, batch: HData, stage: Stage) -> Tensor: - scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) - labels = batch.y - batch_size = batch.num_hyperedges - - loss = self._compute_loss(scores, labels, batch_size, stage) - self._compute_metrics(scores, labels, batch_size, stage) - return loss - - def __get_walk_loader(self) -> Optional[DataLoader]: - if self.mode != "joint": # Walk loader is only needed in joint mode - return None - - if self.__walk_loader is None: - self.__walk_loader = self.__node2vec_encoder().loader( - batch_size=self.random_walk_batch_size, - shuffle=True, - ) - 
self.__cached_walk_loader_iterator = iter(self.__walk_loader) - - return self.__walk_loader - - def __next_walk_batch(self): - self.__get_walk_loader() - if self.__walk_loader is None or self.__cached_walk_loader_iterator is None: - raise ValueError("Joint Node2Vec mode could not initialize the walk loader.") - - try: - return next(self.__cached_walk_loader_iterator) - except StopIteration: - # Create a new iterator as we exhausted the existing one - self.__cached_walk_loader_iterator = iter(self.__walk_loader) - return next(self.__cached_walk_loader_iterator) - - def __node2vec_encoder(self) -> Node2Vec: - if self.encoder is None or not isinstance(self.encoder, Node2Vec): - raise ValueError( - f"Node2Vec in mode {self.mode} requires an encoder, but none was provided." - ) - return self.encoder - - def __validate_global_node_ids(self, num_embeddings: int, global_node_ids: Optional[Tensor]): - if global_node_ids is None or len(global_node_ids) < 1: - raise ValueError(f"Node2Vec in mode {self.mode} requires batch.global_node_ids.") - - min_global_node_id = int(global_node_ids.min().item()) - max_global_node_id = int(global_node_ids.max().item()) - - max_acceptable_node_id = num_embeddings - 1 - if min_global_node_id < 0 or max_global_node_id > max_acceptable_node_id: - raise ValueError( - f"Node2Vec in mode {self.mode} cannot index the provided batch.global_node_ids. " - f"Expected IDs in [0, {max_acceptable_node_id}], got " - f"[{min_global_node_id}, {max_global_node_id}]. " - "Set encoder_config['num_nodes'] for the full stable node space." 
- ) diff --git a/hyperbench/hlp/node2vecgcn_hlp.py b/hyperbench/hlp/node2vecgcn_hlp.py new file mode 100644 index 0000000..944cb52 --- /dev/null +++ b/hyperbench/hlp/node2vecgcn_hlp.py @@ -0,0 +1,230 @@ +from torch import Tensor, nn, optim +from typing import Literal, Optional, TypedDict +from typing_extensions import NotRequired +from torchmetrics import MetricCollection +from hyperbench.models import GCN, Node2VecGCN, Node2VecConfig, SLP +from hyperbench.types import EdgeIndex, HData, HyperedgeIndex +from hyperbench.nn import HyperedgeAggregator +from hyperbench.utils import Stage + +from hyperbench.hlp.hlp import HlpModule +from hyperbench.hlp.node2vec_common import ( + NODE2VEC_JOINT_MODE, + NODE2VEC_PRECOMPUTED_MODE, + Node2VecGCNHlpConfig, + Node2VecHlpConfig, + Node2VecMode, + Node2VecWalkLoaderState, + _next_walk_batch, + _to_gcn_config, + _to_node2vec_encoder, + _to_node2vec_edge_index, + _validate_global_node_ids, + _validate_walk_length_and_context_size, +) + + +class Node2VecGCNEncoderConfig(TypedDict): + """ + Configuration for the Node2Vec encoder in ``Node2VecGCNHlpModule``. + + Args: + mode: Whether to use precomputed node embeddings from ``x`` or train a Node2Vec encoder jointly inside the module. + num_features: Dimension of the node embeddings consumed by the decoder. + node2vec_config: Shared Node2Vec configuration used in joint mode, or metadata for validating precomputed embeddings. + gcn_config: Configuration for the GCN layers. + """ + + mode: NotRequired[Node2VecMode] + num_features: int + node2vec_config: Node2VecHlpConfig + gcn_config: Node2VecGCNHlpConfig + + +class Node2VecGCNHlpModule(HlpModule): + """ + A LightningModule for Node2Vec-based Hyperedge Link Prediction with GCN encoder. + + Supports two modes: + - ``precomputed``: use node embeddings already stored in ``batch.x``. + - ``joint``: train a Node2Vec encoder jointly with the GCN layers and hyperedge decoder. 
+ + Args: + encoder_config: Configuration for the Node2Vec encoder and GCN layers. + aggregation: Method to aggregate node embeddings per hyperedge. + loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``. + lr: Learning rate for the optimizer. Defaults to ``0.001``. + weight_decay: Weight decay (L2 regularization) for the optimizer. Defaults to ``0.0`` (no weight decay). + metrics: Optional dictionary of metric functions. + """ + + def __init__( + self, + encoder_config: Node2VecGCNEncoderConfig, + aggregation: Literal["mean", "max", "min", "sum"] = "mean", + loss_fn: Optional[nn.Module] = None, + lr: float = 0.001, + weight_decay: float = 0.0, + metrics: Optional[MetricCollection] = None, + ): + self.mode = encoder_config.get("mode", NODE2VEC_JOINT_MODE) + self.embedding_dim = encoder_config["num_features"] + + self.node2vec_hlp_config = encoder_config["node2vec_config"] + self.gcn_hlp_config = encoder_config["gcn_config"] + + node2vecgcn_encoder = ( + self.__build_node2vecgcn_encoder( + embedding_dim=self.embedding_dim, + node2vec_config=self.node2vec_hlp_config, + gcn_config=self.gcn_hlp_config, + mode=self.mode, + ) + if self.mode == NODE2VEC_JOINT_MODE + else None + ) + + decoder = SLP(in_channels=self.gcn_hlp_config["out_channels"], out_channels=1) + + super().__init__( + encoder=node2vecgcn_encoder, + decoder=decoder, + loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(), + metrics=metrics, + ) + + self.precomputed_gcn_encoder = ( + self.__build_gcn_encoder(self.embedding_dim, self.gcn_hlp_config) + if self.mode == NODE2VEC_PRECOMPUTED_MODE + else None + ) + + self.aggregation = aggregation + self.lr = lr + self.weight_decay = weight_decay + self.random_walk_batch_size = self.node2vec_hlp_config.get("random_walk_batch_size", 128) + self.node2vec_loss_weight = self.node2vec_hlp_config.get("node2vec_loss_weight", 1.0) + + self.__walk_loader_state = Node2VecWalkLoaderState() + + def forward( + self, + x: Tensor, + hyperedge_index: Tensor, + 
global_node_ids: Optional[Tensor] = None, + ) -> Tensor: + gcn_edge_index = self.__to_gcn_edge_index(hyperedge_index) + + if self.mode == NODE2VEC_JOINT_MODE: + encoder = _to_node2vec_encoder(self.encoder, self.mode) + _validate_global_node_ids(encoder.num_embeddings, global_node_ids, self.mode) + node_embeddings = encoder(batch=global_node_ids, edge_index=gcn_edge_index) + else: + if x.size(1) != self.embedding_dim: + raise ValueError( + f"Expected precomputed node embeddings with dimension " + f"{self.embedding_dim}, got {x.size(1)}." + ) + if self.precomputed_gcn_encoder is None: + raise ValueError("Precomputed GCN encoder is not initialized.") + node_embeddings = self.precomputed_gcn_encoder(x, gcn_edge_index) + + hyperedge_embeddings = HyperedgeAggregator( + hyperedge_index, + node_embeddings, + ).pool(self.aggregation) + + return self.decoder(hyperedge_embeddings).squeeze(-1) + + def training_step(self, batch: HData, batch_idx: int) -> Tensor: + scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) + labels = batch.y + batch_size = batch.num_hyperedges + + if self.mode == NODE2VEC_JOINT_MODE: + positive_random_walk, negative_random_walk = _next_walk_batch( + mode=self.mode, + encoder=self.encoder, + batch_size=self.random_walk_batch_size, + state=self.__walk_loader_state, + ) + positive_random_walk = positive_random_walk.to(self.device) + negative_random_walk = negative_random_walk.to(self.device) + + hlp_loss = self.loss_fn(scores, labels) + node2vec_loss = _to_node2vec_encoder(self.encoder, self.mode).loss( + positive_random_walk, negative_random_walk + ) + loss = hlp_loss + (self.node2vec_loss_weight * node2vec_loss) + + self.log("train_hlp_loss", hlp_loss, prog_bar=True, batch_size=batch_size) + self.log("train_node2vec_loss", node2vec_loss, prog_bar=True, batch_size=batch_size) + self.log("train_loss", loss, prog_bar=True, batch_size=batch_size) + else: + loss = self._compute_loss(scores, labels, batch_size, Stage.TRAIN) + + 
self._compute_metrics(scores, labels, batch_size, Stage.TRAIN) + return loss + + def validation_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.VAL) + + def test_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.TEST) + + def predict_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) + + def configure_optimizers(self): + return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + + def __build_gcn_encoder(self, embedding_dim: int, gcn_config: Node2VecGCNHlpConfig) -> GCN: + return GCN(**_to_gcn_config(embedding_dim, gcn_config)) + + def __build_node2vecgcn_encoder( + self, + embedding_dim: int, + node2vec_config: Node2VecHlpConfig, + gcn_config: Node2VecGCNHlpConfig, + mode: Node2VecMode, + ) -> Node2VecGCN: + _validate_walk_length_and_context_size( + walk_length=node2vec_config.get("walk_length", 20), + context_size=node2vec_config.get("context_size", 10), + ) + + edge_index, num_nodes = _to_node2vec_edge_index(node2vec_config, mode) + + model_node2vec_config: Node2VecConfig = { + "edge_index": edge_index, + "embedding_dim": embedding_dim, + "walk_length": node2vec_config.get("walk_length", 20), + "context_size": node2vec_config.get("context_size", 10), + "num_walks_per_node": node2vec_config.get("num_walks_per_node", 10), + "p": node2vec_config.get("p", 1.0), + "q": node2vec_config.get("q", 1.0), + "num_negative_samples": node2vec_config.get("num_negative_samples", 1), + "num_nodes": num_nodes, + "sparse": node2vec_config.get("sparse", False), + } + + return Node2VecGCN( + node2vec_config=model_node2vec_config, + gcn_config=_to_gcn_config(embedding_dim, gcn_config), + ) + + def __eval_step(self, batch: HData, stage: Stage) -> Tensor: + scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) + labels = batch.y + batch_size = batch.num_hyperedges + + loss = 
self._compute_loss(scores, labels, batch_size, stage) + self._compute_metrics(scores, labels, batch_size, stage) + return loss + + def __to_gcn_edge_index(self, hyperedge_index: Tensor) -> Tensor: + graph_reduction_strategy = self.gcn_hlp_config.get( + "graph_reduction_strategy", "clique_expansion" + ) + reduced_gcn_edge_index = HyperedgeIndex(hyperedge_index).reduce(graph_reduction_strategy) + return EdgeIndex(reduced_gcn_edge_index).remove_selfloops().item diff --git a/hyperbench/hlp/node2vecslp_hlp.py b/hyperbench/hlp/node2vecslp_hlp.py new file mode 100644 index 0000000..9c31785 --- /dev/null +++ b/hyperbench/hlp/node2vecslp_hlp.py @@ -0,0 +1,205 @@ +from torch import Tensor, nn, optim +from typing import Literal, Optional, TypedDict +from typing_extensions import NotRequired +from torchmetrics import MetricCollection +from hyperbench.models import Node2Vec, SLP +from hyperbench.types import HData +from hyperbench.utils import Stage +from hyperbench.nn import HyperedgeAggregator + +from hyperbench.hlp.hlp import HlpModule +from hyperbench.hlp.node2vec_common import ( + NODE2VEC_JOINT_MODE, + Node2VecHlpConfig, + Node2VecMode, + Node2VecWalkLoaderState, + _next_walk_batch, + _to_node2vec_edge_index, + _to_node2vec_encoder, + _validate_global_node_ids, + _validate_walk_length_and_context_size, +) + + +class Node2VecSLPEncoderConfig(TypedDict): + """ + Configuration for the Node2Vec encoder in ``Node2VecSLPHlpModule``. + + Args: + mode: Whether to use precomputed node embeddings from ``x`` or train a Node2Vec encoder jointly inside the module. + num_features: Dimension of the node embeddings consumed by the decoder. + node2vec_config: Shared Node2Vec configuration used in joint mode, or metadata for validating precomputed embeddings. + """ + + mode: NotRequired[Node2VecMode] + num_features: int + node2vec_config: Node2VecHlpConfig + + +class Node2VecSLPHlpModule(HlpModule): + """ + A LightningModule for Node2Vec-based Hyperedge Link Prediction. 
+ + Supports two modes: + - ``precomputed``: use node embeddings already stored in ``batch.x``. + - ``joint``: train a Node2Vec encoder jointly with the hyperedge decoder. + + Args: + encoder_config: Configuration for the Node2Vec encoder. + aggregation: Method to aggregate node embeddings per hyperedge. + loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``. + lr: Learning rate for the optimizer. Defaults to ``0.001``. + weight_decay: Weight decay (L2 regularization) for the optimizer. Defaults to ``0.0`` (no weight decay). + metrics: Optional dictionary of metric functions. + """ + + def __init__( + self, + encoder_config: Node2VecSLPEncoderConfig, + aggregation: Literal["mean", "max", "min", "sum"] = "mean", + loss_fn: Optional[nn.Module] = None, + lr: float = 0.001, + weight_decay: float = 0.0, + metrics: Optional[MetricCollection] = None, + ): + self.mode = encoder_config.get("mode", NODE2VEC_JOINT_MODE) + self.embedding_dim = encoder_config["num_features"] + node2vec_config = encoder_config["node2vec_config"] + + encoder = ( + self.__build_node2vec_encoder(self.embedding_dim, node2vec_config, self.mode) + if self.mode == NODE2VEC_JOINT_MODE + else None + ) + + decoder = SLP(in_channels=self.embedding_dim, out_channels=1) + + super().__init__( + encoder=encoder, + decoder=decoder, + loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(), + metrics=metrics, + ) + + self.aggregation = aggregation + self.lr = lr + self.weight_decay = weight_decay + self.random_walk_batch_size = node2vec_config.get("random_walk_batch_size", 128) + self.node2vec_loss_weight = node2vec_config.get("node2vec_loss_weight", 1.0) + + self.__walk_loader_state = Node2VecWalkLoaderState() + + def forward( + self, + x: Tensor, + hyperedge_index: Tensor, + global_node_ids: Optional[Tensor] = None, + ) -> Tensor: + # Encode: get node embeddings from precomputation or joint encoder + if self.mode == NODE2VEC_JOINT_MODE: + encoder = _to_node2vec_encoder(self.encoder, self.mode) 
+ _validate_global_node_ids(encoder.num_embeddings, global_node_ids, self.mode) + node_embeddings = encoder(batch=global_node_ids) + else: + if x.size(1) != self.embedding_dim: + raise ValueError( + f"Expected precomputed node embeddings with dimension " + f"{self.embedding_dim}, got {x.size(1)}." + ) + node_embeddings = x + + # Aggregate: pool node embeddings per hyperedge + # shape: (num_hyperedges, embedding_dim) + hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool( + self.aggregation + ) + + # Decode: linear projection to scalar score per hyperedge + # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,) + scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1) + return scores + + def training_step(self, batch: HData, batch_idx: int) -> Tensor: + scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) + labels = batch.y + batch_size = batch.num_hyperedges + + if self.mode == NODE2VEC_JOINT_MODE: + # Node2Vec.loss() is already a stochastic objective over sampled walks, + # so one walk batch is a standard SGD estimate, not a logically different loss, + # meaning we can optimize training by using a single walk batch per training step, + # instead of averaging over multiple walk batches. 
+ positive_random_walk, negative_random_walk = _next_walk_batch( + mode=self.mode, + encoder=self.encoder, + batch_size=self.random_walk_batch_size, + state=self.__walk_loader_state, + ) + positive_random_walk = positive_random_walk.to(self.device) + negative_random_walk = negative_random_walk.to(self.device) + + hlp_loss = self.loss_fn(scores, labels) + node2vec_loss = _to_node2vec_encoder(self.encoder, self.mode).loss( + positive_random_walk, + negative_random_walk, + ) + loss = hlp_loss + (self.node2vec_loss_weight * node2vec_loss) + + loss_prefix = Stage.TRAIN.value + self.log(f"{loss_prefix}_hlp_loss", hlp_loss, prog_bar=True, batch_size=batch_size) + self.log( + f"{loss_prefix}_node2vec_loss", node2vec_loss, prog_bar=True, batch_size=batch_size + ) + self.log(f"{loss_prefix}_loss", loss, prog_bar=True, batch_size=batch_size) + else: + loss = self._compute_loss(scores, labels, batch_size, Stage.TRAIN) + + self._compute_metrics(scores, labels, batch_size, Stage.TRAIN) + return loss + + def validation_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.VAL) + + def test_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.TEST) + + def predict_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) + + def configure_optimizers(self): + return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + + def __build_node2vec_encoder( + self, + embedding_dim: int, + node2vec_config: Node2VecHlpConfig, + mode: Node2VecMode, + ) -> Node2Vec: + _validate_walk_length_and_context_size( + walk_length=node2vec_config.get("walk_length", 20), + context_size=node2vec_config.get("context_size", 10), + ) + + edge_index, num_nodes = _to_node2vec_edge_index(node2vec_config, mode) + + return Node2Vec( + edge_index=edge_index, + embedding_dim=embedding_dim, + walk_length=node2vec_config.get("walk_length", 20), + 
context_size=node2vec_config.get("context_size", 10), + num_walks_per_node=node2vec_config.get("num_walks_per_node", 10), + p=node2vec_config.get("p", 1.0), + q=node2vec_config.get("q", 1.0), + num_negative_samples=node2vec_config.get("num_negative_samples", 1), + num_nodes=num_nodes, + sparse=node2vec_config.get("sparse", False), + ) + + def __eval_step(self, batch: HData, stage: Stage) -> Tensor: + scores = self.forward(batch.x, batch.hyperedge_index, batch.global_node_ids) + labels = batch.y + batch_size = batch.num_hyperedges + + loss = self._compute_loss(scores, labels, batch_size, stage) + self._compute_metrics(scores, labels, batch_size, stage) + return loss diff --git a/hyperbench/models/__init__.py b/hyperbench/models/__init__.py index 53b2ecf..080b280 100644 --- a/hyperbench/models/__init__.py +++ b/hyperbench/models/__init__.py @@ -1,9 +1,23 @@ from .common_neighbors import CommonNeighbors +from .gcn import GCN, GCNConfig from .hgnn import HGNN from .hnhn import HNHN from .hgnnp import HGNNP from .hypergcn import HyperGCN from .mlp import MLP, SLP -from .node2vec import Node2Vec +from .node2vec import Node2Vec, Node2VecConfig, Node2VecGCN -__all__ = ["CommonNeighbors", "HGNN", "HGNNP", "HNHN", "HyperGCN", "MLP", "Node2Vec", "SLP"] +__all__ = [ + "CommonNeighbors", + "GCN", + "GCNConfig", + "HGNN", + "HGNNP", + "HNHN", + "HyperGCN", + "MLP", + "Node2Vec", + "Node2VecConfig", + "Node2VecGCN", + "SLP", +] diff --git a/hyperbench/models/gcn.py b/hyperbench/models/gcn.py new file mode 100644 index 0000000..7a6148e --- /dev/null +++ b/hyperbench/models/gcn.py @@ -0,0 +1,144 @@ +from torch import Tensor, nn +from typing import Dict, Optional, TypedDict +from typing_extensions import NotRequired +from torch_geometric.nn import GCNConv +from hyperbench.utils import ActivationFn, is_layer + + +class GCNConfig(TypedDict): + """ + Configuration for the GCN model. + + Args: + in_channels: Dimension of the input node embeddings to the GCN layers. 
+ out_channels: Dimension of the output node embeddings from the GCN layers. + hidden_channels: Dimension of the hidden node embeddings in the GCN layers. + num_layers: Number of GCN layers. Must be at least 1. Defaults to ``2``. + drop_rate: Dropout rate applied after each GCN layer (except the last one). Defaults to ``0.0`` (no dropout). + activation_fn: Activation function to use after each hidden layer. Defaults to ``nn.ReLU``. + activation_fn_kwargs: Keyword arguments for the activation function. Defaults to empty dict. + bias: Whether to include a bias term in the GCN layers. Defaults to ``True``. + improved: Whether to use the improved version of GCNConv. Defaults to ``False``. + add_self_loops: Whether to add self-loops to the input graph. Defaults to ``True``. + normalize: Whether to symmetrically normalize the adjacency matrix in GCNConv. Defaults to ``True``. + cached: Whether to cache the normalized adjacency matrix in GCNConv. + Only applicable if the graph structure does not change between epochs. Defaults to ``False``. + """ + + in_channels: int + out_channels: int + hidden_channels: NotRequired[int] + num_layers: NotRequired[int] + drop_rate: NotRequired[float] + bias: NotRequired[bool] + activation_fn: NotRequired[ActivationFn] + activation_fn_kwargs: NotRequired[Dict] + improved: NotRequired[bool] + add_self_loops: NotRequired[bool] + normalize: NotRequired[bool] + cached: NotRequired[bool] + + +class GCN(nn.Module): + """ + A reusable multi-layer GCN stack built from ``torch_geometric.nn.GCNConv``. + + Args: + in_channels: Dimension of the input node embeddings to the GCN layers. + out_channels: Dimension of the output node embeddings from the GCN layers. + hidden_channels: Dimension of the hidden node embeddings in the GCN layers. + Defaults to ``in_channels``. + num_layers: Number of GCN layers. Must be at least 1. Defaults to ``2``. + drop_rate: Dropout rate applied after each GCN layer except the last one. 
+ bias: Whether to include a bias term in the GCN layers. + activation_fn: Activation function to use after each hidden layer. Defaults to ``nn.ReLU``. + activation_fn_kwargs: Keyword arguments for the activation function. Defaults to empty dict. + improved: Whether to use the improved version of ``GCNConv``. + add_self_loops: Whether to add self-loops to the input graph. + normalize: Whether to symmetrically normalize the adjacency matrix in ``GCNConv``. + cached: Whether to cache the normalized adjacency matrix in ``GCNConv``. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Optional[int] = None, + num_layers: int = 2, + drop_rate: float = 0.0, + bias: bool = True, + activation_fn: Optional[ActivationFn] = None, + activation_fn_kwargs: Optional[Dict] = None, + improved: bool = False, + add_self_loops: bool = True, + normalize: bool = True, + cached: bool = False, + ): + super().__init__() + activation_fn = activation_fn if activation_fn is not None else nn.ReLU + activation_fn_kwargs = activation_fn_kwargs if activation_fn_kwargs is not None else {} + + self.dropout = nn.Dropout(drop_rate) + self.activation = activation_fn(**activation_fn_kwargs) + self.layers = self.__build_layers( + in_channels=in_channels, + out_channels=out_channels, + hidden_channels=hidden_channels, + num_layers=num_layers, + bias=bias, + improved=improved, + add_self_loops=add_self_loops, + normalize=normalize, + cached=cached, + ) + + def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: + num_layers = len(self.layers) + for idx, layer in enumerate(self.layers): + x = layer(x, edge_index) + + is_not_last_layer = not is_layer(idx, num_layers - 1) + if is_not_last_layer: + x = self.activation(x) + x = self.dropout(x) + + return x + + def __build_layers( + self, + in_channels: int, + out_channels: int, + hidden_channels: Optional[int], + num_layers: int, + bias: bool, + improved: bool, + add_self_loops: bool, + normalize: bool, + cached: bool, 
+ ) -> nn.ModuleList: + if num_layers < 1: + raise ValueError(f"Expected num_layers >= 1 for GCN, got {num_layers}.") + + hidden_channels = hidden_channels if hidden_channels is not None else 0 + if num_layers > 1 and hidden_channels <= 0: + raise ValueError( + f"Expected positive hidden_channels for GCN with multiple layers, got {hidden_channels}." + ) + + common_kwargs: Dict[str, bool] = { + "bias": bias, + "improved": improved, + "add_self_loops": add_self_loops, + "normalize": normalize, + "cached": cached, + } + + if num_layers == 1: + return nn.ModuleList([GCNConv(in_channels, out_channels, **common_kwargs)]) + + layers = [GCNConv(in_channels, hidden_channels, **common_kwargs)] + for _ in range(num_layers - 2): + layers.append(GCNConv(hidden_channels, hidden_channels, **common_kwargs)) + layers.append(GCNConv(hidden_channels, out_channels, **common_kwargs)) + + return nn.ModuleList(layers) diff --git a/hyperbench/models/node2vec.py b/hyperbench/models/node2vec.py index ef80003..5144375 100644 --- a/hyperbench/models/node2vec.py +++ b/hyperbench/models/node2vec.py @@ -1,7 +1,10 @@ from torch import Tensor, nn -from typing import Optional +from typing import Optional, TypedDict +from typing_extensions import NotRequired from torch_geometric.nn import Node2Vec as PyGNode2Vec +from hyperbench.models.gcn import GCN, GCNConfig + class Node2Vec(nn.Module): """ @@ -82,3 +85,89 @@ def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor: def loader(self, batch_size: int = 128, shuffle: bool = True): return self.model.loader(batch_size=batch_size, shuffle=shuffle) + + +class Node2VecConfig(TypedDict): + """ + Configuration for the Node2Vec model. + + Args: + edge_index: Edge index representing the graph structure. Size ``(2, num_edges)``. + embedding_dim: Dimension of the node embeddings to learn. + walk_length: Length of each random walk. + context_size: Window size for the skip-gram model (number of neighbors in the walk considered as context). 
+ For example, if ``context_size=2`` and ``walk_length=5``, then for a random walk ``[v0, v1, v2, v3, v4]``,
+ the context for ``v2`` would be ``[v0, v1, v3, v4]`` as we take neighbors within distance 2 in the walk.
+ The pairs generated by skip-gram would be ``[(v2, v0), (v2, v1), (v2, v3), (v2, v4)]``.
+ Rule of thumb: Graphs with strong local structure (5-10), Graphs with communities/long-range patterns (10-20).
+ Defaults to ``10``.
+ num_walks_per_node: Number of random walks to start at each node.
+ p: Return hyperparameter for Node2Vec. Default is ``1.0`` (unbiased).
+ This controls the probability of stepping back to the node visited in the previous step.
+ Lower values of ``p`` make immediate backtracking more likely, which keeps walks closer to the
+ local neighborhood. Higher values of ``p`` discourage returning to the previous node, so walks
+ are less likely to bounce back and forth across the same edge.
+ q: In-out hyperparameter for Node2Vec. Default is ``1.0`` (unbiased).
+ This controls whether walks stay near the source node or explore further outward.
+ Lower values of ``q`` bias the walk toward outward exploration, behaving more like DFS and
+ emphasizing community structure and homophily. Higher values of ``q`` bias the walk toward nearby nodes,
+ behaving more like BFS and emphasizing structural roles.
+ num_negative_samples: Number of negative samples to use for training the skip-gram model.
+ If set to ``X``, then for each positive pair ``(u, v)`` generated from the random walks, ``X`` negative pairs ``(u, v_neg)`` will be generated,
+ where ``v_neg`` is a node sampled uniformly at random from all nodes in the graph.
+ Defaults to ``1``, meaning one negative sample per positive pair.
+ num_nodes: Total number of nodes in the graph. If not provided, it will be inferred from the edge_index.
+ This is only needed if the edge_index does not include all nodes (e.g., some isolated nodes are missing).
+ sparse: Whether Node2Vec embeddings should use sparse gradients. + """ + + edge_index: Tensor + embedding_dim: int + context_size: NotRequired[int] + walk_length: NotRequired[int] + num_walks_per_node: NotRequired[int] + p: NotRequired[float] + q: NotRequired[float] + num_negative_samples: NotRequired[int] + num_nodes: NotRequired[int] + sparse: NotRequired[bool] + + +class Node2VecGCN(nn.Module): + """ + A joint encoder that first learns Node2Vec embeddings and then refines them with GCN layers. + + Args: + node2vec_config: Model-side configuration for the internal ``Node2Vec`` encoder. + gcn_config: Model-side configuration for the GCN stack applied to the Node2Vec embeddings. + """ + + def __init__( + self, + node2vec_config: Node2VecConfig, + gcn_config: GCNConfig, + ): + super().__init__() + self.node2vec = Node2Vec(**node2vec_config) + self.gcn = GCN(**gcn_config) + + def forward( + self, + batch: Optional[Tensor] = None, + edge_index: Optional[Tensor] = None, + ) -> Tensor: + if edge_index is None: + raise ValueError("Node2VecGCN requires edge_index in forward().") + + node_embeddings = self.node2vec(batch) + return self.gcn(node_embeddings, edge_index) + + @property + def num_embeddings(self) -> int: + return self.node2vec.num_embeddings + + def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor: + return self.node2vec.loss(pos_rw, neg_rw) + + def loader(self, batch_size: int = 128, shuffle: bool = True): + return self.node2vec.loader(batch_size=batch_size, shuffle=shuffle)