Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions examples/nhp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from torch import nn
from torchmetrics import MetricCollection
from torchmetrics.classification import (
BinaryAUROC,
BinaryAveragePrecision,
BinaryPrecision,
BinaryRecall,
)
from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy
from hyperbench.hlp import NHPHlpModule
from hyperbench.nn import Node2VecEnricher
from hyperbench.train import MultiModelTrainer, RandomNegativeSampler
from hyperbench.types import HData, ModelConfig


if __name__ == "__main__":
    # --- Experiment settings -------------------------------------------------
    # `verbose` gates the optional dataset printouts below.
    verbose = False
    num_workers = 8
    num_features = 128
    sampling_strategy = SamplingStrategy.HYPEREDGE
    # Binary-classification metrics shared by both NHP variants below.
    metrics = MetricCollection(
        {
            "auc": BinaryAUROC(),
            "avg_precision": BinaryAveragePrecision(),
            "precision": BinaryPrecision(),
            "recall": BinaryRecall(),
        }
    )

    print("Loading and preparing dataset...")

    dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True)
    if verbose:
        print(f"Dataset:\n {dataset.hdata}\n")

    # --- Splits --------------------------------------------------------------
    # Two-stage split gives 70/10/20 train/val/test overall
    # (0.8 * 0.875 = 0.7 train, 0.8 * 0.125 = 0.1 val, 0.2 test).
    train_dataset, test_dataset = dataset.split(
        ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive"
    )
    train_dataset, val_dataset = train_dataset.split(
        ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive"
    )

    # --- Negative sampling ---------------------------------------------------
    # Augment each split with random negative hyperedges, then shuffle
    # positives and negatives together. Train/val get one negative per
    # positive; test gets 0.6 negatives per positive.
    for name, ds in [("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)]:
        num_negative_samples = (
            ds.hdata.num_hyperedges
            if name in ["Train", "Val"]
            else int(ds.hdata.num_hyperedges * 0.6)
        )
        # Negatives are sized to the split's average hyperedge degree.
        negative_sampler = RandomNegativeSampler(
            num_negative_samples=num_negative_samples,
            num_nodes_per_sample=int(ds.stats()["avg_degree_hyperedge"]),
        )
        neg_hdata = negative_sampler.sample(ds.hdata)
        shuffled_hdata = HData.cat_same_node_space([ds.hdata, neg_hdata]).shuffle(seed=42)
        ds_with_negatives = ds.update_from_hdata(shuffled_hdata)

        # Write the augmented split back to the corresponding outer variable
        # (safe: the list iterated above was materialized before the loop).
        if name == "Train":
            train_dataset = ds_with_negatives
        elif name == "Val":
            val_dataset = ds_with_negatives
        else:
            test_dataset = ds_with_negatives

        if verbose:
            print(f"{name} dataset after adding negative samples: {shuffled_hdata}\n")

    print("Enriching node features...")

    # --- Feature enrichment --------------------------------------------------
    # Learn Node2Vec embeddings on the training split; val/test then reuse the
    # training split's enriched features via `enrich_node_features_from`.
    node2vec_enricher = Node2VecEnricher(
        num_features=num_features,
        context_size=10,
        walk_length=20,
        num_walks_per_node=10,
        num_negative_samples=1,
        num_nodes=dataset.hdata.num_nodes,
        num_epochs=10,
        learning_rate=0.01,
        batch_size=128,
        sparse=False,
        verbose=verbose,
    )

    train_dataset.enrich_node_features(
        enricher=node2vec_enricher,
        enrichment_mode="replace",
    )
    val_dataset.enrich_node_features_from(train_dataset)
    test_dataset.enrich_node_features_from(train_dataset)

    print("Creating dataloaders...")

    # --- Dataloaders ---------------------------------------------------------
    # Train uses mini-batches of 64; val/test sample the full hypergraph.
    # NOTE(review): shuffle=False for training — presumably acceptable because
    # the data was already shuffled above, but confirm this is intentional.
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_dataset,
        sample_full_hypergraph=True,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )
    test_loader = DataLoader(
        test_dataset,
        sample_full_hypergraph=True,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=True,
    )

    # --- Models --------------------------------------------------------------
    # Two NHP variants differing only in the hyperedge aggregation scheme.
    maxmin_nhp_module = NHPHlpModule(
        encoder_config={
            "in_channels": num_features,
            "hidden_channels": 512,
            "aggregation": "maxmin",
        },
        lr=0.001,
        weight_decay=5e-4,
        metrics=metrics,
    )

    mean_nhp_module = NHPHlpModule(
        encoder_config={
            "in_channels": num_features,
            "hidden_channels": 512,
            "aggregation": "mean",
        },
        lr=0.001,
        weight_decay=5e-4,
        metrics=metrics,
    )

    # Both configs share the same name and dataloaders; they differ only in
    # their version tag and model instance.
    configs = [
        ModelConfig(
            name="nhp",
            version="maxmin",
            model=maxmin_nhp_module,
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            test_dataloader=test_loader,
        ),
        ModelConfig(
            name="nhp",
            version="mean",
            model=mean_nhp_module,
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            test_dataloader=test_loader,
        ),
    ]

    print("Starting training and evaluation...")

    # --- Training and evaluation ---------------------------------------------
    with MultiModelTrainer(
        model_configs=configs,
        max_epochs=50,
        accelerator="auto",
        log_every_n_steps=1,
        enable_checkpointing=False,
        auto_start_tensorboard=True,
        auto_wait=True,
    ) as trainer:
        trainer.fit_all(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            verbose=True,
        )
        trainer.test_all(dataloader=test_loader, verbose=True)

    print("Complete!")
4 changes: 4 additions & 0 deletions hyperbench/hlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .hlp import HlpModule
from .hypergcn_hlp import HyperGCNHlpModule, HyperGCNEncoderConfig
from .mlp_hlp import MLPHlpModule, MlpEncoderConfig
from .nhp_hlp import NHPEncoderConfig, NHPHlpModule, NHPRankingLoss
from .node2vec_common import (
NODE2VEC_JOINT_MODE,
NODE2VEC_PRECOMPUTED_MODE,
Expand Down Expand Up @@ -34,6 +35,9 @@
"HyperGCNHlpModule",
"MlpEncoderConfig",
"MLPHlpModule",
"NHPEncoderConfig",
"NHPHlpModule",
"NHPRankingLoss",
"Node2VecHlpConfig",
"Node2VecGCNEncoderConfig",
"Node2VecGCNHlpModule",
Expand Down
167 changes: 167 additions & 0 deletions hyperbench/hlp/nhp_hlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import torch
import torch.nn.functional as F

from torch import Tensor, nn, optim
from typing import Dict, Literal, Optional, TypedDict
from torchmetrics import MetricCollection
from typing_extensions import NotRequired
from hyperbench.models import NHP
from hyperbench.types import HData
from hyperbench.utils import ActivationFn, Stage

from hyperbench.hlp.hlp import HlpModule


class NHPEncoderConfig(TypedDict):
    """
    Configuration for the NHP encoder/scorer to be used for hyperedge link prediction.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Number of hidden channels for incidence embeddings. Defaults to ``512``.
        activation_fn: Optional activation function for the encoder; ``None`` uses
            the encoder's default.
        activation_fn_kwargs: Optional keyword arguments forwarded to ``activation_fn``.
        aggregation: Hyperedge scoring aggregation. ``"maxmin"`` uses the paper's
            element-wise range representation; ``"mean"`` uses mean pooling.
        bias: Whether to include bias terms. Defaults to ``True``.
    """

    in_channels: int
    hidden_channels: NotRequired[int]
    activation_fn: NotRequired[Optional[ActivationFn]]
    activation_fn_kwargs: NotRequired[Optional[Dict]]
    aggregation: NotRequired[Literal["mean", "maxmin"]]
    bias: NotRequired[bool]


class NHPRankingLoss(nn.Module):
    """
    Ranking loss that pushes positive hyperedges above sampled negatives.

    Each positive candidate's sigmoid score is compared against the mean
    sigmoid score of the negatives; violations are penalized smoothly via
    softplus and averaged over all positives.

    Examples:
        >>> logits = torch.tensor([2.0, 1.0, -1.0])
        >>> labels = torch.tensor([1.0, 1.0, 0.0])
        >>> loss = NHPRankingLoss()(logits, labels)
        >>> loss.ndim
        0
    """

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """
        Compute the ranking loss.

        Args:
            logits: Logit scores for each candidate hyperedge, of shape ``(num_hyperedges,)``.
            labels: Binary labels indicating positive (1) and negative (0) hyperedges, of shape ``(num_hyperedges,)``.

        Returns:
            Scalar loss value.

        Raises:
            ValueError: If the batch lacks either positive or negative hyperedges.
        """
        # Split logits by label as we need to compare positive scores against negative scores.
        # Example: logits = [2.0, 1.0, -1.0]
        #          labels = [1.0, 1.0, 0.0]
        #          -> positive_logits = [2.0, 1.0]
        #          -> negative_logits = [-1.0]
        positive_logits = logits[labels == 1]
        negative_logits = logits[labels == 0]

        # Validate before any further computation: the ranking objective is
        # undefined unless both classes are present in the batch.
        if positive_logits.numel() == 0 or negative_logits.numel() == 0:
            raise ValueError("NHPRankingLoss requires both positive and negative hyperedges.")

        positive_scores = torch.sigmoid(positive_logits)
        negative_scores = torch.sigmoid(negative_logits)

        # Objective: enforce that each positive score is higher than the average negative score.
        # For each positive score pos_i:
        #     margin_i = mean(negative_scores) - pos_i
        # We interpret margin_i as follows:
        #     - If pos_i > mean(negatives), then margin_i < 0   -> desirable
        #     - If pos_i <= mean(negatives), then margin_i >= 0 -> violation
        margins = negative_scores.mean() - positive_scores

        # Then softplus(margin_i):
        #     - Is ~0 when margin_i is strongly negative (good ranking).
        #     - Grows smoothly when margin_i > 0 (penalizing violations).
        # Final loss is the average over all positive samples.
        return F.softplus(margins).mean()


class NHPHlpModule(HlpModule):
    """
    LightningModule for undirected NHP hyperedge link prediction.

    NHP encodes and scores candidate hyperedges in a single pass: rather than
    producing reusable global node embeddings, it builds candidate-specific
    incidence embeddings that are pooled and scored per hyperedge.

    Args:
        encoder_config: Configuration for the NHP encoder/scorer.
        loss_fn: Loss function. Defaults to :class:`NHPRankingLoss`.
        lr: Learning rate for the optimizer. Defaults to ``0.001``.
        weight_decay: L2 regularization. Defaults to ``5e-4``.
        metrics: Optional metric collection for evaluation.
    """

    def __init__(
        self,
        encoder_config: NHPEncoderConfig,
        loss_fn: Optional[nn.Module] = None,
        lr: float = 0.001,
        weight_decay: float = 5e-4,
        metrics: Optional[MetricCollection] = None,
    ):
        # Build the NHP encoder from the (partially optional) config dict,
        # filling in defaults for any missing keys.
        nhp_encoder = NHP(
            in_channels=encoder_config["in_channels"],
            hidden_channels=encoder_config.get("hidden_channels", 512),
            activation_fn=encoder_config.get("activation_fn"),
            activation_fn_kwargs=encoder_config.get("activation_fn_kwargs"),
            aggregation=encoder_config.get("aggregation", "maxmin"),
            bias=encoder_config.get("bias", True),
        )

        # NHP scores hyperedges directly, so no separate decoder is needed.
        super().__init__(
            encoder=nhp_encoder,
            decoder=nn.Identity(),
            loss_fn=NHPRankingLoss() if loss_fn is None else loss_fn,
            metrics=metrics,
        )

        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor:
        """
        Encode and score each candidate hyperedge.

        Args:
            x: Node feature matrix of shape ``(num_nodes, in_channels)``.
            hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``.

        Returns:
            Scores of shape ``(num_hyperedges,)``.

        Raises:
            ValueError: If no encoder is configured on this module.
        """
        encoder = self.encoder
        if encoder is None:
            raise ValueError("Encoder is not defined for this HLP module.")
        return encoder(x, hyperedge_index)

    def training_step(self, batch: HData, batch_idx: int) -> Tensor:
        return self._shared_step(batch, Stage.TRAIN)

    def validation_step(self, batch: HData, batch_idx: int) -> Tensor:
        return self._shared_step(batch, Stage.VAL)

    def test_step(self, batch: HData, batch_idx: int) -> Tensor:
        return self._shared_step(batch, Stage.TEST)

    def predict_step(self, batch: HData, batch_idx: int) -> Tensor:
        return self.forward(batch.x, batch.hyperedge_index)

    def configure_optimizers(self):
        # Plain Adam driven by the module-level hyperparameters.
        return optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def _shared_step(self, batch: HData, stage: Stage) -> Tensor:
        # Common train/val/test step: score all candidates in the batch, then
        # delegate loss and metric computation to the HlpModule helpers.
        logits = self.forward(batch.x, batch.hyperedge_index)
        targets = batch.y
        n_candidates = batch.num_hyperedges

        loss = self._compute_loss(logits, targets, n_candidates, stage)
        self._compute_metrics(logits, targets, n_candidates, stage)
        return loss
3 changes: 3 additions & 0 deletions hyperbench/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .hgnnp import HGNNP
from .hypergcn import HyperGCN
from .mlp import MLP, SLP
from .nhp import NHP
from .node2vec import Node2Vec, Node2VecConfig, Node2VecGCN

__all__ = [
Expand All @@ -16,6 +17,8 @@
"HNHN",
"HyperGCN",
"MLP",
"NHP",
"NHPAggregation",
"Node2Vec",
"Node2VecConfig",
"Node2VecGCN",
Expand Down
Loading
Loading