Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/node2vec.py → examples/node2vecslp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
BinaryRecall,
)
from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy
from hyperbench.hlp import Node2VecHlpModule
from hyperbench.hlp import Node2VecSLPHlpModule
from hyperbench.nn import Node2VecEnricher
from hyperbench.train import MultiModelTrainer, RandomNegativeSampler
from hyperbench.types import HData, ModelConfig
Expand Down Expand Up @@ -120,7 +120,7 @@
persistent_workers=True,
)

precomputed_node2vec_module = Node2VecHlpModule(
precomputed_node2vecslp_module = Node2VecSLPHlpModule(
encoder_config={
"mode": "precomputed",
"num_features": num_features,
Expand All @@ -132,7 +132,7 @@
)

train_hyperedge_index = train_dataset.hdata.hyperedge_index
joint_node2vec_module = Node2VecHlpModule(
joint_node2vecslp_module = Node2VecSLPHlpModule(
encoder_config={
"mode": "joint",
"num_features": num_features,
Expand All @@ -156,17 +156,17 @@

configs = [
ModelConfig(
name="node2vec",
name="node2vecslp",
version="precomputed",
model=precomputed_node2vec_module,
model=precomputed_node2vecslp_module,
train_dataloader=train_loader,
val_dataloader=val_loader,
test_dataloader=test_loader,
),
ModelConfig(
name="node2vec",
name="node2vecslp",
version="joint",
model=joint_node2vec_module,
model=joint_node2vecslp_module,
train_dataloader=train_loader,
val_dataloader=val_loader,
test_dataloader=test_loader,
Expand Down
4 changes: 2 additions & 2 deletions hyperbench/hlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .hlp import HlpModule
from .hypergcn_hlp import HyperGCNHlpModule, HyperGCNEncoderConfig
from .mlp_hlp import MLPHlpModule, MlpEncoderConfig
from .node2vec_hlp import Node2VecEncoderConfig, Node2VecHlpModule
from .node2vec_hlp import Node2VecEncoderConfig, Node2VecSLPHlpModule

__all__ = [
"CommonNeighborsHlpModule",
Expand All @@ -21,5 +21,5 @@
"MlpEncoderConfig",
"MLPHlpModule",
"Node2VecEncoderConfig",
"Node2VecHlpModule",
"Node2VecSLPHlpModule",
]
2 changes: 1 addition & 1 deletion hyperbench/hlp/node2vec_hlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Node2VecEncoderConfig(TypedDict):
node2vec_loss_weight: NotRequired[float]


class Node2VecHlpModule(HlpModule):
class Node2VecSLPHlpModule(HlpModule):
"""
A LightningModule for Node2Vec-based Hyperedge Link Prediction.

Expand Down
Loading