Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/hgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@
"bias": True,
"use_batch_normalization": False,
"drop_rate": 0.5,
"fast": False,
},
aggregation="mean",
lr=0.01,
Expand Down
155 changes: 155 additions & 0 deletions examples/hgnnp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from torchmetrics import MetricCollection
from torchmetrics.classification import (
BinaryAUROC,
BinaryAveragePrecision,
BinaryPrecision,
BinaryRecall,
)
from lightning.pytorch.callbacks import EarlyStopping

from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy
from hyperbench.hlp import HGNNPHlpModule
from hyperbench.nn import LaplacianPositionalEncodingEnricher
from hyperbench.train import MultiModelTrainer, RandomNegativeSampler
from hyperbench.types import HData, ModelConfig


if __name__ == "__main__":
    # Example: hyperedge link prediction with an HGNN+ encoder on AlgebraDataset.
    verbose = False
    num_workers = 8
    sampling_strategy = SamplingStrategy.HYPEREDGE
    # Binary-classification metrics computed over the hyperedge logit scores.
    metrics = MetricCollection(
        {
            "auc": BinaryAUROC(),
            "avg_precision": BinaryAveragePrecision(),
            "precision": BinaryPrecision(),
            "recall": BinaryRecall(),
        }
    )

    print("Loading and preparing dataset...")

    dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True)
    if verbose:
        print(f"Dataset:\n {dataset.hdata}\n")

    # 70/10/20 train/val/test split: 0.875 * 0.8 = 0.7 of the data for training.
    train_dataset, test_dataset = dataset.split(ratios=[0.8, 0.2], shuffle=True, seed=42)
    if verbose:
        # Printed *before* the second split so the label is accurate (the
        # original printed both messages after the split, showing identical data).
        print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n")
    train_dataset, val_dataset = train_dataset.split(ratios=[0.875, 0.125], shuffle=True, seed=42)
    if verbose:
        print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n")
        print(f"Val dataset:\n {val_dataset.hdata}\n")
        print(f"Test dataset:\n {test_dataset.hdata}\n")

    # Add random negative hyperedges to each split: train/val get a 1:1
    # positive/negative ratio, test gets 0.6 negatives per positive.
    for name, ds in [("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)]:
        num_negative_samples = (
            ds.hdata.num_hyperedges
            if name in ["Train", "Val"]
            else int(ds.hdata.num_hyperedges * 0.6)
        )
        negative_sampler = RandomNegativeSampler(
            num_negative_samples=num_negative_samples,
            # Size negatives like a typical positive hyperedge of this split.
            num_nodes_per_sample=int(ds.stats()["avg_degree_hyperedge"]),
        )
        neg_hdata = negative_sampler.sample(ds.hdata)
        # Concatenate positives and negatives over the shared node set, then
        # shuffle so batches mix both classes.
        combined_hdata = HData.cat_same_node_space([ds.hdata, neg_hdata])
        shuffled_hdata = combined_hdata.shuffle(seed=42)
        ds_with_negatives = ds.update_from_hdata(shuffled_hdata)

        if name == "Train":
            train_dataset = ds_with_negatives
        elif name == "Val":
            val_dataset = ds_with_negatives
        else:
            test_dataset = ds_with_negatives

        if verbose:
            print(f"{name} dataset after adding negative samples: {shuffled_hdata}\n")

    print("Enriching node features...")

    # Compute Laplacian positional encodings once on the training hypergraph and
    # reuse them for val/test by slicing rows (one row per node) down to each
    # split's node count.
    # NOTE(review): this assumes node IDs are aligned across the splits —
    # verify against the dataset's split implementation.
    train_dataset.enrich_node_features(
        enricher=LaplacianPositionalEncodingEnricher(num_features=32),
        enrichment_mode="replace",
    )
    val_dataset.hdata.x = train_dataset.hdata.x[: val_dataset.hdata.num_nodes]
    # Fix: slice rows (nodes), not columns. The original sliced dim 1
    # (`x[:, : num_nodes]`), which would truncate the 32 feature channels
    # instead of selecting the test split's nodes, unlike the val line above.
    test_dataset.hdata.x = train_dataset.hdata.x[: test_dataset.hdata.num_nodes]

    print("Creating dataloaders...")

    # All loaders feed the full hypergraph per step (no neighborhood sampling),
    # so shuffling is unnecessary.
    train_loader_full_hypergraph = DataLoader(
        train_dataset,
        sample_full_hypergraph=True,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )
    val_loader_full_hypergraph = DataLoader(
        val_dataset,
        sample_full_hypergraph=True,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )
    test_loader_full_hypergraph = DataLoader(
        test_dataset,
        sample_full_hypergraph=True,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )

    # in_channels must match the 32 positional-encoding features created above.
    mean_hgnnp_module = HGNNPHlpModule(
        encoder_config={
            "in_channels": 32,
            "hidden_channels": 16,
            "out_channels": 16,
            "bias": True,
            "use_batch_normalization": False,
            "drop_rate": 0.5,
        },
        aggregation="mean",
        lr=0.01,
        weight_decay=5e-4,
        metrics=metrics,
    )

    configs = [
        ModelConfig(
            name="hgnnp",
            version="mean",
            model=mean_hgnnp_module,
            train_dataloader=train_loader_full_hypergraph,
            val_dataloader=val_loader_full_hypergraph,
            test_dataloader=test_loader_full_hypergraph,
        ),
    ]

    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=30,
        mode="min",
    )

    print("Starting training and evaluation...")

    with MultiModelTrainer(
        model_configs=configs,
        max_epochs=60,
        accelerator="auto",
        log_every_n_steps=1,
        callbacks=[early_stopping],
        enable_checkpointing=False,
        auto_start_tensorboard=True,
        auto_wait=True,
    ) as trainer:
        trainer.fit_all(
            train_dataloader=train_loader_full_hypergraph,
            val_dataloader=val_loader_full_hypergraph,
            verbose=True,
        )
        trainer.test_all(dataloader=test_loader_full_hypergraph, verbose=True)

    print("Complete!")
3 changes: 3 additions & 0 deletions hyperbench/hlp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .common_neighbors_hlp import CommonNeighborsHlpModule
from .hgnn_hlp import HGNNHlpModule, HGNNEncoderConfig
from .hgnnp_hlp import HGNNPEncoderConfig, HGNNPHlpModule
from .hlp import HlpModule
from .hypergcn_hlp import HyperGCNHlpModule, HyperGCNEncoderConfig
from .mlp_hlp import MLPHlpModule, MlpEncoderConfig
Expand All @@ -9,6 +10,8 @@
"CommonNeighborsHlpModule",
"HGNNEncoderConfig",
"HGNNHlpModule",
"HGNNPEncoderConfig",
"HGNNPHlpModule",
"HlpModule",
"HyperGCNEncoderConfig",
"HyperGCNHlpModule",
Expand Down
3 changes: 0 additions & 3 deletions hyperbench/hlp/hgnn_hlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class HGNNEncoderConfig(TypedDict):
bias: Whether to include bias terms. Defaults to ``True``.
use_batch_normalization: Whether to use batch normalization. Defaults to ``False``.
drop_rate: Dropout rate. Defaults to ``0.5``.
fast: Whether to cache the HGNN Laplacian. Defaults to ``True``.
"""

in_channels: int
Expand All @@ -30,7 +29,6 @@ class HGNNEncoderConfig(TypedDict):
bias: NotRequired[bool]
use_batch_normalization: NotRequired[bool]
drop_rate: NotRequired[float]
fast: NotRequired[bool]


class HGNNHlpModule(HlpModule):
Expand Down Expand Up @@ -66,7 +64,6 @@ def __init__(
bias=encoder_config.get("bias", True),
use_batch_normalization=encoder_config.get("use_batch_normalization", False),
drop_rate=encoder_config.get("drop_rate", 0.5),
fast=encoder_config.get("fast", True),
)
decoder = SLP(in_channels=encoder_config["out_channels"], out_channels=1)

Expand Down
146 changes: 146 additions & 0 deletions hyperbench/hlp/hgnnp_hlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from torch import Tensor, nn, optim
from typing import Literal, Optional, TypedDict
from torchmetrics import MetricCollection
from typing_extensions import NotRequired
from hyperbench.models import HGNNP, SLP
from hyperbench.nn import HyperedgeAggregator
from hyperbench.types import HData
from hyperbench.utils import Stage

from hyperbench.hlp.hlp import HlpModule


class HGNNPEncoderConfig(TypedDict):
"""
Configuration for the HGNN+ encoder in HGNNPHlpModule.

Args:
in_channels: Number of input features per node.
hidden_channels: Number of hidden units in the intermediate HGNN+ layer.
out_channels: Number of output features (embedding size) per node.
bias: Whether to include bias terms. Defaults to ``True``.
use_batch_normalization: Whether to use batch normalization. Defaults to ``False``.
drop_rate: Dropout rate. Defaults to ``0.5``.
"""

in_channels: int
hidden_channels: int
out_channels: int
bias: NotRequired[bool]
use_batch_normalization: NotRequired[bool]
drop_rate: NotRequired[float]


class HGNNPHlpModule(HlpModule):
    """
    A LightningModule for HGNN+-based Hyperedge Link Prediction.

    Uses HGNN+ as an encoder to produce structure-aware node embeddings via
    row-stochastic hypergraph convolution, aggregates them per hyperedge,
    and scores each hyperedge with a linear decoder.

    Args:
        encoder_config: Configuration for the HGNN+ encoder.
        aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``.
        loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``.
        lr: Learning rate for the optimizer. Defaults to ``0.01``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: HGNNPEncoderConfig,
        aggregation: Literal["mean", "max", "min", "sum"] = "mean",
        loss_fn: Optional[nn.Module] = None,
        lr: float = 0.01,
        weight_decay: float = 5e-4,
        metrics: Optional[MetricCollection] = None,
    ):
        # Build the encoder from the config; optional keys fall back to the
        # defaults documented in HGNNPEncoderConfig.
        encoder = HGNNP(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config["hidden_channels"],
            # HGNNP names its output width "num_classes"; here it is the
            # node-embedding size, not a class count.
            num_classes=encoder_config["out_channels"],
            bias=encoder_config.get("bias", True),
            use_batch_normalization=encoder_config.get("use_batch_normalization", False),
            drop_rate=encoder_config.get("drop_rate", 0.5),
        )
        # Single linear layer mapping a hyperedge embedding to one logit.
        decoder = SLP(in_channels=encoder_config["out_channels"], out_channels=1)

        super().__init__(
            encoder=encoder,
            decoder=decoder,
            loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(),
            metrics=metrics,
        )

        self.aggregation = aggregation
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Run the full HGNN+-based hyperedge link prediction pipeline.

        The pipeline has three stages:
        1. Encode: HGNN+ applies two rounds of ``D_v^{-1} H D_e^{-1} H^T``
        smoothing to propagate information through the hypergraph topology with
        two-stage mean aggregation. The output is a structure-aware node
        embedding matrix of shape ``(num_nodes, out_channels)``.
        2. Aggregate: For each hyperedge being scored, pool the embeddings of its member
        nodes using the configured strategy (mean/max/min/sum). This produces a hyperedge
        embedding of shape ``(num_hyperedges, out_channels)``.
        3. Decode: A single linear layer projects each hyperedge embedding to a
        scalar score. Shape: ``(num_hyperedges,)``.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
                Must contain **all** nodes referenced in ``hyperedge_index``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``,
                with row 0 containing global node IDs and row 1 hyperedge IDs.

        Returns:
            Logit scores of shape ``(num_hyperedges,)``.
        """
        if self.encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")

        # Encode: produce node embeddings using HGNN+, no graph reduction is applied
        # Example: x: (num_nodes, in_channels) -> node_embeddings: (num_nodes, out_channels)
        node_embeddings: Tensor = self.encoder(x, hyperedge_index)

        # Aggregate: pool node embeddings per hyperedge
        # shape: (num_hyperedges, out_channels)
        hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool(
            self.aggregation
        )

        # Decode: linear projection to scalar score per hyperedge
        # shape: (num_hyperedges, 1) -> squeeze -> (num_hyperedges,)
        scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1)
        return scores

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and log the training loss/metrics for one batch."""
        return self.__eval_step(batch, Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and log the validation loss/metrics for one batch."""
        return self.__eval_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Compute and log the test loss/metrics for one batch."""
        return self.__eval_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        """Return raw logit scores for the batch's hyperedges (no loss/metrics)."""
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self) -> optim.Optimizer:
        """Adam with the configured learning rate and L2 weight decay."""
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def __eval_step(self, batch: HData, stage: Stage) -> Tensor:
        """Shared train/val/test step: score hyperedges, compute loss and metrics.

        batch.y holds binary labels (positive vs. sampled-negative hyperedges);
        batch_size is passed through for correct per-hyperedge metric averaging.
        """
        scores = self.forward(batch.x, batch.hyperedge_index)
        labels = batch.y
        batch_size = batch.num_hyperedges

        loss = self._compute_loss(scores, labels, batch_size, stage)
        self._compute_metrics(scores, labels, batch_size, stage)
        return loss
3 changes: 2 additions & 1 deletion hyperbench/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .common_neighbors import CommonNeighbors
from .hgnn import HGNN
from .hgnnp import HGNNP
from .hypergcn import HyperGCN
from .mlp import MLP, SLP
from .node2vec import Node2Vec

__all__ = ["CommonNeighbors", "HGNN", "HyperGCN", "MLP", "Node2Vec", "SLP"]
__all__ = ["CommonNeighbors", "HGNN", "HGNNP", "HyperGCN", "MLP", "Node2Vec", "SLP"]
Loading
Loading