diff --git a/examples/early_stopping.py b/examples/early_stopping.py new file mode 100644 index 0000000..07e2526 --- /dev/null +++ b/examples/early_stopping.py @@ -0,0 +1,155 @@ +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + BinaryAUROC, + BinaryAveragePrecision, + BinaryPrecision, + BinaryRecall, +) +from lightning.pytorch.callbacks import EarlyStopping +from hyperbench.hlp import MLPHlpModule +from hyperbench.nn import LaplacianPositionalEncodingEnricher +from hyperbench.train import MultiModelTrainer, RandomNegativeSampler +from hyperbench.types import HData, ModelConfig +from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy + + +if __name__ == "__main__": + verbose = False + num_workers = 8 + num_features = 32 + sampling_strategy = SamplingStrategy.HYPEREDGE + metrics = MetricCollection( + { + "auc": BinaryAUROC(), + "avg_precision": BinaryAveragePrecision(), + "precision": BinaryPrecision(), + "recall": BinaryRecall(), + } + ) + + print("Loading and preparing dataset...") + + dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + if verbose: + print(f"Dataset:\n {dataset.hdata}\n") + + # Split dataset into train and test (80/20) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) + if verbose: + print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") + + # Split train into train and val (87.5/12.5 of train = 70/10 of total) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) + if verbose: + print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") + print(f"Val dataset:\n {val_dataset.hdata}\n") + print(f"Test dataset:\n {test_dataset.hdata}\n") + + # Add negative samples to all splits + for name, ds in [("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)]: + num_negative_samples = ( + ds.hdata.num_hyperedges + if name in ["Train", "Val"] # 1:1 ratio of pos:neg samples + else int(ds.hdata.num_hyperedges * 0.6) # 60% negatives for test set + ) + negative_sampler = RandomNegativeSampler( + num_negative_samples=num_negative_samples, + num_nodes_per_sample=int(ds.stats()["avg_degree_hyperedge"]), + ) + neg_hdata = negative_sampler.sample(ds.hdata) + combined_hdata = HData.cat_same_node_space([ds.hdata, neg_hdata]) + shuffled_hdata = combined_hdata.shuffle(seed=42) + ds_with_negatives = ds.update_from_hdata(shuffled_hdata) + + if name == "Train": + train_dataset = ds_with_negatives + elif name == "Val": + val_dataset = ds_with_negatives + else: + test_dataset = ds_with_negatives + + if verbose: + print(f"{name} dataset after adding negative samples: {shuffled_hdata}\n") + + print("Enriching node features...") + + train_dataset.enrich_node_features( + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In the transductive setting, use the total number of nodes to ensure consistent + # encodings across splits, as the train dataset contains all nodes but some of them + # may not appear in any hyperedge + num_nodes=train_dataset.hdata.num_nodes, + ), + enrichment_mode="replace", + ) + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) + + print("Creating dataloaders...") + + train_loader = DataLoader(
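+        # Mini-batch the training hyperedges here; the validation and test loaders below + # sample the full hypergraph instead, so evaluation metrics cover every hyperedge at once.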
train_dataset, + batch_size=128, # or 256 + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + val_loader = DataLoader( + val_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + test_loader = DataLoader( + test_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + + mean_mlp_module = MLPHlpModule( + encoder_config={ + "in_channels": num_features, + "out_channels": num_features, + "hidden_channels": 64, + "num_layers": 3, + "drop_rate": 0.3, + }, + aggregation="mean", + metrics=metrics, + ) + + configs = [ + ModelConfig(name="mlp", version="mean", model=mean_mlp_module), + ] + + early_stopping = EarlyStopping( + monitor="val_loss", + patience=10, + mode="min", + ) + + print("Starting training and evaluation...") + + with MultiModelTrainer( + model_configs=configs, + max_epochs=200, + accelerator="auto", + log_every_n_steps=10, + callbacks=[early_stopping], + enable_checkpointing=False, + auto_start_tensorboard=True, + auto_wait=True, + ) as trainer: + trainer.fit_all(train_dataloader=train_loader, val_dataloader=val_loader, verbose=True) + trainer.test_all(dataloader=test_loader, verbose=True) + + print("Complete!") diff --git a/examples/hgnn.py b/examples/hgnn.py index b34e13c..5837914 100644 --- a/examples/hgnn.py +++ b/examples/hgnn.py @@ -5,7 +5,6 @@ BinaryPrecision, BinaryRecall, ) -from lightning.pytorch.callbacks import EarlyStopping from hyperbench.hlp import HGNNHlpModule from hyperbench.nn import LaplacianPositionalEncodingEnricher from hyperbench.train import MultiModelTrainer, RandomNegativeSampler @@ -16,6 +15,7 @@ if __name__ == "__main__": verbose = False num_workers = 8 + num_features = 32 sampling_strategy = SamplingStrategy.HYPEREDGE metrics = MetricCollection( { @@ -33,10 +33,14 @@ print(f"Dataset:\n {dataset.hdata}\n") # Split dataset into train and test (80/20) - train_dataset, test_dataset = dataset.split(ratios=[0.8, 0.2], shuffle=True, seed=42) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) # Split train into train and val (87.5/12.5 of train = 70/10 of total) - train_dataset, val_dataset = train_dataset.split(ratios=[0.875, 0.125], shuffle=True, seed=42) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) if verbose: print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") print(f"Val dataset:\n {val_dataset.hdata}\n") print(f"Test dataset:\n {test_dataset.hdata}\n") @@ -75,11 +79,16 @@ print("Enriching node features...") train_dataset.enrich_node_features( - enricher=LaplacianPositionalEncodingEnricher(num_features=32), + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In the transductive setting, use the total number of nodes to ensure consistent + # encodings across splits, as the train dataset contains all nodes but some of them + # may not appear in any hyperedge + num_nodes=train_dataset.hdata.num_nodes, + ), enrichment_mode="replace", ) - val_dataset.hdata.x = train_dataset.hdata.x[: val_dataset.hdata.num_nodes] - test_dataset.hdata.x = train_dataset.hdata.x[:, : test_dataset.hdata.num_nodes] + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) print("Creating dataloaders...") @@ -107,7 +116,7 @@ mean_hgnn_module = HGNNHlpModule( encoder_config={ - "in_channels": 32, +
"in_channels": num_features, "hidden_channels": 16, "out_channels": 16, "bias": True, @@ -115,7 +124,7 @@ "drop_rate": 0.5, }, aggregation="mean", - lr=0.01, + lr=0.001, weight_decay=5e-4, metrics=metrics, ) @@ -131,12 +140,6 @@ ), ] - early_stopping = EarlyStopping( - monitor="val_loss", - patience=30, - mode="min", - ) - print("Starting training and evaluation...") with MultiModelTrainer( @@ -144,7 +147,6 @@ max_epochs=60, accelerator="auto", log_every_n_steps=1, - callbacks=[early_stopping], enable_checkpointing=False, auto_start_tensorboard=True, auto_wait=True, diff --git a/examples/hgnnp.py b/examples/hgnnp.py index daf5a33..f3e1038 100644 --- a/examples/hgnnp.py +++ b/examples/hgnnp.py @@ -5,8 +5,6 @@ BinaryPrecision, BinaryRecall, ) -from lightning.pytorch.callbacks import EarlyStopping - from hyperbench.data import AlgebraDataset, DataLoader, SamplingStrategy from hyperbench.hlp import HGNNPHlpModule from hyperbench.nn import LaplacianPositionalEncodingEnricher @@ -17,6 +15,7 @@ if __name__ == "__main__": verbose = False num_workers = 8 + num_features = 32 sampling_strategy = SamplingStrategy.HYPEREDGE metrics = MetricCollection( { @@ -33,8 +32,12 @@ if verbose: print(f"Dataset:\n {dataset.hdata}\n") - train_dataset, test_dataset = dataset.split(ratios=[0.8, 0.2], shuffle=True, seed=42) - train_dataset, val_dataset = train_dataset.split(ratios=[0.875, 0.125], shuffle=True, seed=42) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) if verbose: print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") @@ -71,11 +74,16 @@ print("Enriching node features...") train_dataset.enrich_node_features( - enricher=LaplacianPositionalEncodingEnricher(num_features=32), + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In transductive setting, use total number of nodes to ensure consistent encoding across splits + # as the train dataset contain all nodes but may have no hyperedges where they appear + num_nodes=train_dataset.hdata.num_nodes, + ), enrichment_mode="replace", ) - val_dataset.hdata.x = train_dataset.hdata.x[: val_dataset.hdata.num_nodes] - test_dataset.hdata.x = train_dataset.hdata.x[:, : test_dataset.hdata.num_nodes] + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) print("Creating dataloaders...") @@ -103,7 +111,7 @@ mean_hgnnp_module = HGNNPHlpModule( encoder_config={ - "in_channels": 32, + "in_channels": num_features, "hidden_channels": 16, "out_channels": 16, "bias": True, @@ -127,12 +135,6 @@ ), ] - early_stopping = EarlyStopping( - monitor="val_loss", - patience=30, - mode="min", - ) - print("Starting training and evaluation...") with MultiModelTrainer( @@ -140,7 +142,6 @@ max_epochs=60, accelerator="auto", log_every_n_steps=1, - callbacks=[early_stopping], enable_checkpointing=False, auto_start_tensorboard=True, auto_wait=True, diff --git a/examples/hnhn.py b/examples/hnhn.py new file mode 100644 index 0000000..cb7f62f --- /dev/null +++ b/examples/hnhn.py @@ -0,0 +1,156 @@ +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + BinaryAUROC, + BinaryAveragePrecision, + BinaryPrecision, + BinaryRecall, +) +from hyperbench.data import AlgebraDataset, 
DataLoader, SamplingStrategy +from hyperbench.hlp import HNHNHlpModule +from hyperbench.nn import LaplacianPositionalEncodingEnricher +from hyperbench.train import MultiModelTrainer, RandomNegativeSampler +from hyperbench.types import HData, ModelConfig + + +if __name__ == "__main__": + verbose = False + num_workers = 8 + num_features = 128 + sampling_strategy = SamplingStrategy.HYPEREDGE + metrics = MetricCollection( + { + "auc": BinaryAUROC(), + "avg_precision": BinaryAveragePrecision(), + "precision": BinaryPrecision(), + "recall": BinaryRecall(), + } + ) + + print("Loading and preparing dataset...") + + dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + if verbose: + print(f"Dataset:\n {dataset.hdata}\n") + + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) + if verbose: + print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) + if verbose: + print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") + print(f"Val dataset:\n {val_dataset.hdata}\n") + print(f"Test dataset:\n {test_dataset.hdata}\n") + + for name, ds in [("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)]: + num_negative_samples = ( + ds.hdata.num_hyperedges + if name in ["Train", "Val"] + else int(ds.hdata.num_hyperedges * 0.6) + ) + negative_sampler = RandomNegativeSampler( + num_negative_samples=num_negative_samples, + num_nodes_per_sample=int(ds.stats()["avg_degree_hyperedge"]), + ) + neg_hdata = negative_sampler.sample(ds.hdata) + combined_hdata = HData.cat_same_node_space([ds.hdata, neg_hdata]) + shuffled_hdata = combined_hdata.shuffle(seed=42) + ds_with_negatives = ds.update_from_hdata(shuffled_hdata) + + if name == "Train": + train_dataset = ds_with_negatives + elif name == "Val": + val_dataset = ds_with_negatives + else: + test_dataset = ds_with_negatives + + if verbose: + print(f"{name} dataset after adding negative samples: {shuffled_hdata}\n") + + print("Enriching node features...") + + train_dataset.enrich_node_features( + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In the transductive setting, use the total number of nodes to ensure consistent + # encodings across splits, as the train dataset contains all nodes but some of them + # may not appear in any hyperedge + num_nodes=train_dataset.hdata.num_nodes, + ), + enrichment_mode="replace", + ) + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) + + print("Creating dataloaders...") + + train_loader_full_hypergraph = DataLoader( + train_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + val_loader_full_hypergraph = DataLoader( + val_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + test_loader_full_hypergraph = DataLoader( + test_dataset, + sample_full_hypergraph=True, + shuffle=False, + num_workers=num_workers, + persistent_workers=True, + ) + + mean_hnhn_module = HNHNHlpModule( + encoder_config={ + "in_channels": num_features, + "hidden_channels": 400, + "out_channels": 400, + "bias": True, + "use_batch_normalization": False, + "drop_rate": 0.3, + }, + aggregation="mean", + lr=0.04, + weight_decay=5e-4, + scheduler_step_size=100, + scheduler_gamma=0.51, + metrics=metrics,
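+        # scheduler_step_size/scheduler_gamma configure the StepLR schedule created in + # HNHNHlpModule.configure_optimizers (decay the learning rate by gamma every step_size + # epochs); the values above match the module defaults.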
+    ) + + configs = [ + ModelConfig( + name="hnhn", + version="mean", + model=mean_hnhn_module, + train_dataloader=train_loader_full_hypergraph, + val_dataloader=val_loader_full_hypergraph, + test_dataloader=test_loader_full_hypergraph, + ), + ] + + print("Starting training and evaluation...") + + with MultiModelTrainer( + model_configs=configs, + max_epochs=200, + accelerator="auto", + log_every_n_steps=1, + enable_checkpointing=False, + auto_start_tensorboard=True, + auto_wait=True, + ) as trainer: + trainer.fit_all( + train_dataloader=train_loader_full_hypergraph, + val_dataloader=val_loader_full_hypergraph, + verbose=True, + ) + trainer.test_all(dataloader=test_loader_full_hypergraph, verbose=True) + + print("Complete!") diff --git a/examples/hypergcn.py b/examples/hypergcn.py index e94e281..d919681 100644 --- a/examples/hypergcn.py +++ b/examples/hypergcn.py @@ -5,7 +5,6 @@ BinaryPrecision, BinaryRecall, ) -from lightning.pytorch.callbacks import EarlyStopping from hyperbench.hlp import HyperGCNHlpModule from hyperbench.nn import HyperedgeWeightsEnricher, LaplacianPositionalEncodingEnricher from hyperbench.train import MultiModelTrainer, RandomNegativeSampler @@ -16,6 +15,7 @@ if __name__ == "__main__": verbose = False num_workers = 8 + num_features = 32 sampling_strategy = SamplingStrategy.HYPEREDGE metrics = MetricCollection( { @@ -34,10 +34,14 @@ print(f"Dataset:\n {dataset.hdata}\n") # Split dataset into train and test (80/20) - train_dataset, test_dataset = dataset.split(ratios=[0.8, 0.2], shuffle=True, seed=42) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) # Split train into train and val (87.5/12.5 of train = 70/10 of total) - train_dataset, val_dataset = train_dataset.split(ratios=[0.875, 0.125], shuffle=True, seed=42) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) if verbose: print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") print(f"Val dataset:\n {val_dataset.hdata}\n") print(f"Test dataset:\n {test_dataset.hdata}\n") @@ -84,11 +88,16 @@ print("Enriching node features...") train_dataset.enrich_node_features( - enricher=LaplacianPositionalEncodingEnricher(num_features=32), + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In the transductive setting, use the total number of nodes to ensure consistent + # encodings across splits, as the train dataset contains all nodes but some of them + # may not appear in any hyperedge + num_nodes=train_dataset.hdata.num_nodes, + ), enrichment_mode="replace", ) - val_dataset.hdata.x = train_dataset.hdata.x[: val_dataset.hdata.num_nodes] - test_dataset.hdata.x = train_dataset.hdata.x[:, : test_dataset.hdata.num_nodes] + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) print("Creating dataloaders...") @@ -116,7 +125,7 @@ mean_hypergcn_no_mediator_module = HyperGCNHlpModule( encoder_config={ - "in_channels": 32, + "in_channels": num_features, "hidden_channels": 16, "out_channels": 16, "bias": True, @@ -167,12 +176,6 @@ ), ] - early_stopping = EarlyStopping( - monitor="val_loss", - patience=100, - mode="min", - ) - print("Starting training and evaluation...") with MultiModelTrainer( @@ -180,7 +183,6 @@ max_epochs=200, accelerator="auto", log_every_n_steps=1, - callbacks=[early_stopping], enable_checkpointing=False, auto_start_tensorboard=True, auto_wait=True, diff --git 
a/examples/mlp_common_neighbors.py b/examples/mlp_common_neighbors.py index b638429..41dd388 100644 --- a/examples/mlp_common_neighbors.py +++ b/examples/mlp_common_neighbors.py @@ -5,7 +5,6 @@ BinaryPrecision, BinaryRecall, ) -from lightning.pytorch.callbacks import EarlyStopping from hyperbench.hlp import CommonNeighborsHlpModule, MLPHlpModule from hyperbench.nn import LaplacianPositionalEncodingEnricher from hyperbench.train import MultiModelTrainer, RandomNegativeSampler @@ -16,6 +15,7 @@ if __name__ == "__main__": verbose = False num_workers = 8 + num_features = 32 sampling_strategy = SamplingStrategy.HYPEREDGE metrics = MetricCollection( { @@ -33,10 +33,14 @@ print(f"Dataset:\n {dataset.hdata}\n") # Split dataset into train and test (80/20) - train_dataset, test_dataset = dataset.split(ratios=[0.8, 0.2], shuffle=True, seed=42) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) # Split train into train and val (87.5/12.5 of train = 70/10 of total) - train_dataset, val_dataset = train_dataset.split(ratios=[0.875, 0.125], shuffle=True, seed=42) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) if verbose: print(f"Train dataset (before train/val split):\n {train_dataset.hdata}\n") print(f"Train dataset (after train/val split):\n {train_dataset.hdata}\n") print(f"Val dataset:\n {val_dataset.hdata}\n") print(f"Test dataset:\n {test_dataset.hdata}\n") @@ -75,11 +79,16 @@ print("Enriching node features...") train_dataset.enrich_node_features( - enricher=LaplacianPositionalEncodingEnricher(num_features=32), + enricher=LaplacianPositionalEncodingEnricher( + num_features=num_features, + # In the transductive setting, use the total number of nodes to ensure consistent + # encodings across splits, as the train dataset contains all nodes but some of them + # may not appear in any hyperedge + num_nodes=train_dataset.hdata.num_nodes, + ), enrichment_mode="replace", ) - val_dataset.hdata.x = train_dataset.hdata.x[: val_dataset.hdata.num_nodes] - test_dataset.hdata.x = train_dataset.hdata.x[:, : test_dataset.hdata.num_nodes] + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) print("Creating dataloaders...") @@ -113,8 +122,8 @@ mean_mlp_module = MLPHlpModule( encoder_config={ - "in_channels": 32, - "out_channels": 32, + "in_channels": num_features, + "out_channels": num_features, "hidden_channels": 64, "num_layers": 3, "drop_rate": 0.3, @@ -133,12 +142,6 @@ ModelConfig(name="mlp", version="mean", model=mean_mlp_module), ] - early_stopping = EarlyStopping( - monitor="val_loss", - patience=100, - mode="min", - ) - print("Starting training and evaluation...") with MultiModelTrainer( @@ -146,7 +149,6 @@ max_epochs=200, accelerator="auto", log_every_n_steps=10, - callbacks=[early_stopping], enable_checkpointing=False, auto_start_tensorboard=True, auto_wait=True, diff --git a/examples/node2vec.py b/examples/node2vec.py index 63a4644..c0de8c8 100644 --- a/examples/node2vec.py +++ b/examples/node2vec.py @@ -1,4 +1,3 @@ -from lightning.pytorch.callbacks import EarlyStopping from torchmetrics import MetricCollection from torchmetrics.classification import ( BinaryAUROC, @@ -16,7 +15,7 @@ if __name__ == "__main__": verbose = False num_workers = 8 - embedding_dim = 32 + num_features = 32 sampling_strategy = SamplingStrategy.HYPEREDGE metrics = MetricCollection( { @@ -35,10 +34,14 @@ print(f"Dataset:\n {dataset.hdata}\n") # Split dataset into train and test (80/20) - train_dataset, test_dataset = 
dataset.split(ratios=[0.8, 0.2], shuffle=True, seed=42) + train_dataset, test_dataset = dataset.split( + ratios=[0.8, 0.2], shuffle=True, seed=42, node_space_setting="transductive" + ) # Split train into train and val (87.5/12.5 of train = 70/10 of total) - train_dataset, val_dataset = train_dataset.split(ratios=[0.875, 0.125], shuffle=True, seed=42) + train_dataset, val_dataset = train_dataset.split( + ratios=[0.875, 0.125], shuffle=True, seed=42, node_space_setting="transductive" + ) if verbose: print(f"Train dataset:\n {train_dataset.hdata}\n") print(f"Val dataset:\n {val_dataset.hdata}\n") @@ -74,7 +77,7 @@ print("Computing Node2Vec embeddings from the train graph...") node2vec_enricher = Node2VecEnricher( - num_features=embedding_dim, + num_features=num_features, context_size=10, walk_length=20, num_walks_per_node=10, @@ -90,8 +93,8 @@ enricher=node2vec_enricher, enrichment_mode="replace", ) - val_dataset.hdata.x = train_dataset.hdata.x[: val_dataset.hdata.num_nodes] - test_dataset.hdata.x = train_dataset.hdata.x[:, : test_dataset.hdata.num_nodes] + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) print("Creating dataloaders...") @@ -120,7 +123,7 @@ precomputed_node2vec_module = Node2VecHlpModule( encoder_config={ "mode": "precomputed", - "num_features": embedding_dim, + "num_features": num_features, }, aggregation="mean", lr=0.001, @@ -132,7 +135,7 @@ joint_node2vec_module = Node2VecHlpModule( encoder_config={ "mode": "joint", - "num_features": embedding_dim, + "num_features": num_features, "context_size": 10, "walk_length": 20, "num_walks_per_node": 10, @@ -170,12 +173,6 @@ ), ] - early_stopping = EarlyStopping( - monitor="val_loss", - patience=30, - mode="min", - ) - print("Starting training and evaluation...") with MultiModelTrainer( @@ -183,7 +180,6 @@ max_epochs=60, accelerator="auto", log_every_n_steps=1, - callbacks=[early_stopping], enable_checkpointing=False, auto_start_tensorboard=True, auto_wait=True, diff --git a/examples/node_enricher.py b/examples/node_enricher.py index dd5a932..4557daf 100644 --- a/examples/node_enricher.py +++ b/examples/node_enricher.py @@ -3,12 +3,14 @@ if __name__ == "__main__": + num_features = 32 + print("Loading and preparing dataset...") dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE, prepare=True) dataset.enrich_node_features( - enricher=LaplacianPositionalEncodingEnricher(num_features=32), + enricher=LaplacianPositionalEncodingEnricher(num_features=num_features), enrichment_mode="replace", ) @@ -18,7 +20,7 @@ print(f"- First 5 node features:\n {dataset.hdata.x[:5]}\n") node2vec_enricher = Node2VecEnricher( - num_features=32, + num_features=num_features, walk_length=20, context_size=10, num_walks_per_node=10, diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 48b874c..742120d 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -2,20 +2,25 @@ import os import tempfile import torch -import zstandard as zstd import requests import warnings +import zstandard as zstd from enum import Enum from huggingface_hub import hf_hub_download from typing import Any, Dict, List, Optional from torch import Tensor from torch.utils.data import Dataset as TorchDataset - -from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher +from hyperbench.nn.enricher import EnrichmentMode, NodeEnricher, HyperedgeEnricher from hyperbench.types import HData, HIFHypergraph, HyperedgeIndex -from hyperbench.nn import EnrichmentMode 
-from hyperbench.utils import validate_hif_json +from hyperbench.utils import ( + NodeSpaceAssignment, + NodeSpaceFiller, + NodeSpaceSetting, + is_inductive_setting, + is_transductive_split, + validate_hif_json, +) from hyperbench.data.sampling import SamplingStrategy, create_sampler_from_strategy @@ -300,6 +305,42 @@ def enrich_node_features( """ self.hdata = self.hdata.enrich_node_features(enricher, enrichment_mode) + def enrich_node_features_from( + self, + dataset_with_features: "Dataset", + node_space_setting: NodeSpaceSetting = "transductive", + fill_value: Optional[NodeSpaceFiller] = None, + ) -> None: + """ + Enrich node features from another dataset by copying features aligned by ``global_node_ids``. + + Examples: + In a transductive setting, the full node space is preserved across datasets: + >>> val_dataset.enrich_node_features_from(train_dataset) + + In an inductive setting, missing node features can be filled with 0.0: + >>> test_dataset.enrich_node_features_from( + ... train_dataset, + ... node_space_setting="inductive", + ... fill_value=0.0, # torch.tensor(0.0) also works and will be broadcast to the appropriate shape + ... ) + + Args: + dataset_with_features: Source dataset providing node features. + node_space_setting: The setting for the node space, determining how nodes are handled. + ``transductive`` (default) preserves the full node space of the target dataset. + ``inductive`` allows the target dataset to have a different node space, filling missing features with ``fill_value``. + fill_value: Scalar or vector used to fill missing node features when ``node_space_setting`` is not transductive. + + Raises: + ValueError: If the source dataset's node features cannot be aligned with the target dataset's nodes. + """ + self.hdata = self.hdata.enrich_node_features_from( + hdata_with_features=dataset_with_features.hdata, + node_space_setting=node_space_setting, + fill_value=fill_value, + ) + def enrich_hyperedge_attr( self, enricher: HyperedgeEnricher, @@ -356,28 +397,49 @@ def split( ratios: List[float], shuffle: Optional[bool] = False, seed: Optional[int] = None, + node_space_setting: NodeSpaceSetting = "transductive", + assign_node_space_to: Optional[NodeSpaceAssignment] = "first", ) -> List["Dataset"]: """ - Split the dataset by hyperedges into partitions with contiguous 0-based IDs. + Split the dataset by hyperedges into partitions with contiguous 0-based hyperedge IDs. Boundaries are computed using cumulative floor to prevent early splits from over-consuming edges. The last split absorbs any rounding remainder. Examples: - With ``num_hyperedges = 3`` and ``ratios = [0.5, 0.25, 0.25]``: - - >>> cumulative_ratios = [0.5, 0.75, 1.0] - - Boundaries: - - - ``i=0`` -> ``end = int(0.5 * 3) = 1`` -> slice ``[0:1]`` -> 1 edge - - ``i=1`` -> ``end = int(0.75 * 3) = 2`` -> slice ``[1:2]`` -> 1 edge - - ``i=2`` -> ``end = 3`` (clamped) -> slice ``[2:3]`` -> 1 edge + Transductive split keeping the full node space only on the first split (default): + >>> train, test = dataset.split([0.8, 0.2]) + >>> train.hdata.num_nodes == dataset.hdata.num_nodes + >>> test.hdata.num_nodes <= dataset.hdata.num_nodes + + Transductive split keeping the full node space on all splits: + >>> train, test = dataset.split( + ... [0.8, 0.2], + ... node_space_setting="transductive", + ... assign_node_space_to="all", + ... ) + >>> train.hdata.num_nodes == dataset.hdata.num_nodes + >>> test.hdata.num_nodes == dataset.hdata.num_nodes + + Inductive split: + >>> train, test = dataset.split( + ... [0.8, 0.2], + ... 
node_space_setting="inductive", + ... assign_node_space_to=None, + ... ) + >>> train.hdata.num_nodes <= dataset.hdata.num_nodes + >>> test.hdata.num_nodes <= dataset.hdata.num_nodes Args: ratios: List of floats summing to ``1.0``, e.g., ``[0.8, 0.1, 0.1]``. shuffle: Whether to shuffle hyperedges before splitting. Defaults to ``False`` for deterministic splits. seed: Optional random seed for reproducibility. Ignored if shuffle is set to ``False``. + node_space_setting: Whether to preserve the full node space in the splits. + ``transductive`` (default) ensures all nodes are present in every split, + while ``inductive`` allows splits to have disjoint node spaces. + assign_node_space_to: Which split(s) preserve the full node space when + ``node_space_setting="transductive"``. + ``first`` preserves only the first returned split. ``all`` preserves all splits. Returns: List of Dataset objects, one per split, each with contiguous IDs. @@ -388,6 +450,10 @@ def split( # ratios = [0.8, 0.1, 0.1, 0.0000001] -> sum = 1.0000001 (valid, allows small imprecision) if abs(sum(ratios) - 1.0) > 1e-6: raise ValueError(f"Split ratios must sum to 1.0, got {sum(ratios)}.") + if is_inductive_setting(node_space_setting) and assign_node_space_to is not None: + raise ValueError( + "assign_node_space_to can only be provided when node_space_setting='transductive'." + ) device = self.hdata.device num_hyperedges = self.hdata.num_hyperedges @@ -424,7 +490,16 @@ def split( # i=1 -> permutation[1:2] = [1] (1 edge) # i=2 -> permutation[2:3] = [2] (1 edge) split_hyperedge_ids = hyperedge_ids_permutation[start:end] - split_hdata = HData.split(self.hdata, split_hyperedge_ids).to(device=device) + + use_transductive_node_space = is_transductive_split( + node_space_setting, assign_node_space_to, split_num=i + ) + split_hdata = HData.split( + self.hdata, + split_hyperedge_ids, + node_space_setting="transductive" if use_transductive_node_space else "inductive", + ).to(device=device) + split_dataset = self.__class__( hdata=split_hdata, sampling_strategy=self.sampling_strategy, diff --git a/hyperbench/hlp/__init__.py b/hyperbench/hlp/__init__.py index 0ae682e..9863403 100644 --- a/hyperbench/hlp/__init__.py +++ b/hyperbench/hlp/__init__.py @@ -1,5 +1,6 @@ from .common_neighbors_hlp import CommonNeighborsHlpModule from .hgnn_hlp import HGNNHlpModule, HGNNEncoderConfig +from .hnhn_hlp import HNHNEncoderConfig, HNHNHlpModule from .hgnnp_hlp import HGNNPEncoderConfig, HGNNPHlpModule from .hlp import HlpModule from .hypergcn_hlp import HyperGCNHlpModule, HyperGCNEncoderConfig @@ -10,6 +11,8 @@ "CommonNeighborsHlpModule", "HGNNEncoderConfig", "HGNNHlpModule", + "HNHNEncoderConfig", + "HNHNHlpModule", "HGNNPEncoderConfig", "HGNNPHlpModule", "HlpModule", diff --git a/hyperbench/hlp/hgnn_hlp.py b/hyperbench/hlp/hgnn_hlp.py index 7c308fa..24784e6 100644 --- a/hyperbench/hlp/hgnn_hlp.py +++ b/hyperbench/hlp/hgnn_hlp.py @@ -53,7 +53,7 @@ def __init__( encoder_config: HGNNEncoderConfig, aggregation: Literal["mean", "max", "min", "sum"] = "mean", loss_fn: Optional[nn.Module] = None, - lr: float = 0.01, + lr: float = 0.001, weight_decay: float = 5e-4, metrics: Optional[MetricCollection] = None, ): diff --git a/hyperbench/hlp/hnhn_hlp.py b/hyperbench/hlp/hnhn_hlp.py new file mode 100644 index 0000000..b66d817 --- /dev/null +++ b/hyperbench/hlp/hnhn_hlp.py @@ -0,0 +1,135 @@ +from typing import Literal, Optional, TypedDict + +from torch import Tensor, nn, optim +from torchmetrics import MetricCollection +from typing_extensions import NotRequired + 
+from hyperbench.hlp.hlp import HlpModule +from hyperbench.models import HNHN, SLP +from hyperbench.nn import HyperedgeAggregator +from hyperbench.types import HData +from hyperbench.utils import Stage + + +class HNHNEncoderConfig(TypedDict): + """ + Configuration for the HNHN encoder in HNHNHlpModule. + + Args: + in_channels: Number of input features per node. + hidden_channels: Number of hidden units in the intermediate HNHN layer. + out_channels: Number of output features (embedding size) per node. + bias: Whether to include bias terms. Defaults to ``True``. + use_batch_normalization: Whether to use batch normalization. Defaults to ``False``. + drop_rate: Dropout rate. Defaults to ``0.5``. + """ + + in_channels: int + hidden_channels: int + out_channels: int + bias: NotRequired[bool] + use_batch_normalization: NotRequired[bool] + drop_rate: NotRequired[float] + + +class HNHNHlpModule(HlpModule): + """ + A LightningModule for HNHN-based Hyperedge Link Prediction. + + Uses HNHN as an encoder to produce node embeddings through explicit + hyperedge neurons, aggregates them per hyperedge, and scores each + hyperedge with a linear decoder. + + Args: + encoder_config: Configuration for the HNHN encoder. + aggregation: Method to aggregate node embeddings per hyperedge. Defaults to ``"mean"``. + loss_fn: Loss function. Defaults to ``BCEWithLogitsLoss``. + lr: Learning rate for the optimizer. Defaults to ``0.01``. + weight_decay: L2 regularization. Defaults to ``5e-4``. + scheduler_step_size: Step size for learning rate scheduler. Defaults to ``100``. + scheduler_gamma: Multiplicative factor for learning rate decay. Defaults to ``0.51``. + metrics: Optional metric collection for evaluation. + """ + + def __init__( + self, + encoder_config: HNHNEncoderConfig, + aggregation: Literal["mean", "max", "min", "sum"] = "mean", + loss_fn: Optional[nn.Module] = None, + lr: float = 0.01, + weight_decay: float = 5e-4, + scheduler_step_size: int = 100, + scheduler_gamma: float = 0.51, + metrics: Optional[MetricCollection] = None, + ): + encoder = HNHN( + in_channels=encoder_config["in_channels"], + hidden_channels=encoder_config["hidden_channels"], + num_classes=encoder_config["out_channels"], + bias=encoder_config.get("bias", True), + use_batch_normalization=encoder_config.get("use_batch_normalization", False), + drop_rate=encoder_config.get("drop_rate", 0.5), + ) + decoder = SLP(in_channels=encoder_config["out_channels"], out_channels=1) + + super().__init__( + encoder=encoder, + decoder=decoder, + loss_fn=loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss(), + metrics=metrics, + ) + + self.aggregation = aggregation + self.lr = lr + self.weight_decay = weight_decay + self.scheduler_step_size = scheduler_step_size + self.scheduler_gamma = scheduler_gamma + + def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: + """ + Run the full HNHN-based hyperedge link prediction pipeline. + + Args: + x: Node feature matrix of shape ``(num_nodes, in_channels)``. + hyperedge_index: Hyperedge connectivity of shape ``(2, num_incidences)``. + + Returns: + Logit scores of shape ``(num_hyperedges,)``. 
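+ + Example (illustrative shapes only, on a tiny random hypergraph): + >>> module = HNHNHlpModule( + ... encoder_config={"in_channels": 8, "hidden_channels": 16, "out_channels": 16} + ... ) + >>> x = torch.randn(4, 8) + >>> hyperedge_index = torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]]) + >>> module(x, hyperedge_index).shape + ... torch.Size([2])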
+ """ + if self.encoder is None: + raise ValueError("Encoder is not defined for this HLP module.") + + node_embeddings: Tensor = self.encoder(x, hyperedge_index) + hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, node_embeddings).pool( + self.aggregation + ) + scores: Tensor = self.decoder(hyperedge_embeddings).squeeze(-1) + return scores + + def training_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.TRAIN) + + def validation_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.VAL) + + def test_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.__eval_step(batch, Stage.TEST) + + def predict_step(self, batch: HData, batch_idx: int) -> Tensor: + return self.forward(batch.x, batch.hyperedge_index) + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + scheduler = optim.lr_scheduler.StepLR( + optimizer, step_size=self.scheduler_step_size, gamma=self.scheduler_gamma + ) + return [optimizer], [scheduler] + + def __eval_step(self, batch: HData, stage: Stage) -> Tensor: + scores = self.forward(batch.x, batch.hyperedge_index) + labels = batch.y + batch_size = batch.num_hyperedges + + loss = self._compute_loss(scores, labels, batch_size, stage) + self._compute_metrics(scores, labels, batch_size, stage) + return loss diff --git a/hyperbench/models/__init__.py b/hyperbench/models/__init__.py index a839efe..53b2ecf 100644 --- a/hyperbench/models/__init__.py +++ b/hyperbench/models/__init__.py @@ -1,8 +1,9 @@ from .common_neighbors import CommonNeighbors from .hgnn import HGNN +from .hnhn import HNHN from .hgnnp import HGNNP from .hypergcn import HyperGCN from .mlp import MLP, SLP from .node2vec import Node2Vec -__all__ = ["CommonNeighbors", "HGNN", "HGNNP", "HyperGCN", "MLP", "Node2Vec", "SLP"] +__all__ = ["CommonNeighbors", "HGNN", "HGNNP", "HNHN", "HyperGCN", "MLP", "Node2Vec", "SLP"] diff --git a/hyperbench/models/hnhn.py b/hyperbench/models/hnhn.py new file mode 100644 index 0000000..78f5492 --- /dev/null +++ b/hyperbench/models/hnhn.py @@ -0,0 +1,65 @@ +from torch import Tensor, nn + +from hyperbench.nn import HNHNConv + + +class HNHN(nn.Module): + """ + HNHN performs incidence-based hypergraph convolution with explicit hyperedge + embeddings between the node -> hyperedge -> node propagation steps. + - Proposed in `HNHN: Hypergraph Networks with Hyperedge Neurons `_ paper. + - Reference implementation: `source `_. + + Args: + in_channels: The number of input channels. + hidden_channels: The number of hidden channels. + num_classes: The number of output channels. + bias: If set to ``False``, the layer will not learn the bias parameter. Defaults to ``True``. + use_batch_normalization: If set to ``True``, layers will use batch normalization. Defaults to ``False``. + drop_rate: Dropout ratio. Defaults to ``0.5``. 
+ """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + num_classes: int, + bias: bool = True, + use_batch_normalization: bool = False, + drop_rate: float = 0.5, + ): + super().__init__() + + self.layers = nn.ModuleList( + [ + HNHNConv( + in_channels=in_channels, + out_channels=hidden_channels, + bias=bias, + use_batch_normalization=use_batch_normalization, + drop_rate=drop_rate, + ), + HNHNConv( + in_channels=hidden_channels, + out_channels=num_classes, + bias=bias, + use_batch_normalization=use_batch_normalization, + is_last=True, + ), + ] + ) + + def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: + """ + Apply two stacked ``HNHNConv`` layers to produce node embeddings. + + Args: + x: Input node feature matrix of size ``(num_nodes, in_channels)``. + hyperedge_index: Hyperedge incidence in COO format of size ``(2, num_incidences)``. + + Returns: + The output node feature matrix of size ``(num_nodes, num_classes)``. + """ + for layer in self.layers: + x = layer(x, hyperedge_index) + return x diff --git a/hyperbench/nn/__init__.py b/hyperbench/nn/__init__.py index 7190c0e..e7af09e 100644 --- a/hyperbench/nn/__init__.py +++ b/hyperbench/nn/__init__.py @@ -1,7 +1,7 @@ from hyperbench.utils import Aggregation -from .aggregator import HyperedgeAggregator -from .conv import HGNNConv, HGNNPConv, HyperGCNConv +from .aggregator import HyperedgeAggregator, NodeAggregator +from .conv import HGNNConv, HGNNPConv, HNHNConv, HyperGCNConv from .enricher import ( EnrichmentMode, NodeEnricher, @@ -19,9 +19,11 @@ "EnrichmentMode", "HGNNConv", "HGNNPConv", + "HNHNConv", "HyperedgeAggregator", "HyperGCNConv", "NeighborScorer", + "NodeAggregator", "NodeEnricher", "HyperedgeEnricher", "HyperedgeAttrsEnricher", diff --git a/hyperbench/nn/aggregator.py b/hyperbench/nn/aggregator.py index 03c01de..3626844 100644 --- a/hyperbench/nn/aggregator.py +++ b/hyperbench/nn/aggregator.py @@ -1,11 +1,22 @@ from torch import Tensor -from typing import Literal +from typing import Literal, Optional from torch_geometric.utils import scatter from hyperbench.types import HyperedgeIndex class HyperedgeAggregator: + """ + Pool node embeddings into hyperedge embeddings using the incidence structure. + + Each node-hyperedge incidence selects one node embedding row, then reduces + those rows per hyperedge with the requested scatter aggregation. + + Args: + hyperedge_index: Hyperedge incidence in COO format of size ``(2, num_incidences)``. + node_embeddings: Node embedding matrix of size ``(num_nodes, num_channels)``. + """ + def __init__( self, hyperedge_index: Tensor, @@ -15,25 +26,50 @@ def __init__( self.node_embeddings = node_embeddings def pool(self, aggregation: Literal["max", "min", "mean", "mul", "sum"]) -> Tensor: + """ + Aggregate node embeddings for each hyperedge. + + Examples: + >>> hyperedge_index = [[0, 1, 2, 2, 3], + ... [0, 0, 0, 1, 1]] + >>> node_embeddings = [[1, 10], [2, 20], [3, 30], [4, 40]] + >>> HyperedgeAggregator(hyperedge_index, node_embeddings).pool("mean") + ... [[2, 20], [3.5, 35]] + >>> HyperedgeAggregator(hyperedge_index, node_embeddings).pool("sum") + ... [[6, 60], [7, 70]] + >>> HyperedgeAggregator(hyperedge_index, node_embeddings).pool("max") + ... [[3, 30], [4, 40]] + + Args: + aggregation: Reduction applied across the nodes belonging to each hyperedge. + + Returns: + A hyperedge embedding matrix of shape ``(num_hyperedges, num_channels)``. + """ # Gather the embeddings for each incidence. # A node appearing in multiple hyperedges is repeated, once per incidence. 
- # Example: all_node_ids = [0, 1, 2, 2, 3] (node 2 appears twice, once per hyperedge) - # -> incidence_node_embeddings = [[e00, e01], # node 0 - # [e10, e11], # node 1 - # [e20, e21], # node 2 (for hyperedge 0) - # [e20, e21], # node 2 (for hyperedge 1) - # [e30, e31]] # node 3 - # shape: (num_incidences, out_channels) + # Example: node_embeddings = [[1, 10], # node 0 + # [2, 20], # node 1 + # [3, 30], # node 2 + # [4, 40]] # node 3 + # -> all_node_ids = [0, 1, 2, 2, 3] + # -> incidence_node_embeddings = [[1, 10], # node 0 for hyperedge 0 + # [2, 20], # node 1 for hyperedge 0 + # [3, 30], # node 2 for hyperedge 0 + # [3, 30], # node 2 for hyperedge 1 + # [4, 40]] # node 3 for hyperedge 1 + # shape: (num_incidences, num_channels) incidence_node_embeddings = self.node_embeddings[self.hyperedge_index_wrapper.all_node_ids] # Scatter-aggregate node embeddings into hyperedge embeddings. # Example: with aggregation="sum": - # [[e00+e10+e20, e01+e11+e21], # hyperedge 0 contains node 0, 1, 2 - # [e20+e30, e21+e31]] # hyperedge 1 contains node 2, 3 - # shape: (num_hyperedges, out_channels) + # [[1+2+3, 10+20+30], # hyperedge 0 contains node 0, 1, 2 + # [3+4, 30+40]] # hyperedge 1 contains node 2, 3 + # shape: (num_hyperedges, num_channels) # with aggregation="max": - # [[max(e00, e10, e20), max(e01, e11, e21)], # hyperedge 0 contains node 0, 1, 2 - # [max(e20, e30), max(e21, e31)]] # hyperedge 1 contains node 2, 3 + # [[max(1, 2, 3), max(10, 20, 30)], # hyperedge 0 contains node 0, 1, 2 + # [max(3, 4), max(30, 40)]] # hyperedge 1 contains node 2, 3 + # shape: (num_hyperedges, num_channels) return scatter( src=incidence_node_embeddings, index=self.hyperedge_index_wrapper.all_hyperedge_ids, @@ -41,3 +77,83 @@ def pool(self, aggregation: Literal["max", "min", "mean", "mul", "sum"]) -> Tens dim_size=self.hyperedge_index_wrapper.num_hyperedges, reduce=aggregation, ) + + +class NodeAggregator: + """ + Pool hyperedge embeddings into node embeddings using the incidence structure. + + Each node-hyperedge incidence selects one hyperedge embedding row, then + reduces those rows per node with the requested scatter aggregation. + + Args: + hyperedge_index: Hyperedge incidence in COO format of size ``(2, num_incidences)``. + hyperedge_embeddings: Hyperedge embedding matrix of size ``(num_hyperedges, num_channels)``. + num_nodes: Optional explicit node count. When provided, the pooled output preserves isolated nodes that do not appear in ``hyperedge_index``. + """ + + def __init__( + self, + hyperedge_index: Tensor, + hyperedge_embeddings: Tensor, + num_nodes: Optional[int] = None, + ): + self.hyperedge_index_wrapper = HyperedgeIndex(hyperedge_index) + self.hyperedge_embeddings = hyperedge_embeddings + self.num_nodes = num_nodes + + def pool(self, aggregation: Literal["max", "min", "mean", "mul", "sum"]) -> Tensor: + """ + Aggregate hyperedge embeddings for each node. + + Examples: + >>> hyperedge_index = [[0, 1, 1, 2], + ... [0, 0, 1, 1]] + >>> hyperedge_embeddings = [[10, 100], [20, 200]] + >>> NodeAggregator(hyperedge_index, hyperedge_embeddings).pool("mean") + ... [[10, 100], [15, 150], [20, 200]] + >>> NodeAggregator(hyperedge_index, hyperedge_embeddings).pool("sum") + ... [[10, 100], [30, 300], [20, 200]] + >>> NodeAggregator(hyperedge_index, hyperedge_embeddings).pool("max") + ... [[10, 100], [20, 200], [20, 200]] + + Args: + aggregation: Reduction applied across the hyperedges incident to each node. + + Returns: + A node embedding matrix of shape ``(num_nodes, num_channels)``. 
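+ + Passing ``num_nodes`` preserves isolated nodes in the output; with ``sum`` aggregation their rows are zero-filled: + >>> NodeAggregator(hyperedge_index, hyperedge_embeddings, num_nodes=4).pool("sum") + ... [[10, 100], [30, 300], [20, 200], [0, 0]]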
+ """ + # Gather the embeddings for each incidence. + # A hyperedge appearing in multiple node incidences is repeated, once per incidence. + # Example: hyperedge_embeddings = [[10, 100], # hyperedge 0 + # [20, 200]] # hyperedge 1 + # -> all_hyperedge_ids = [0, 0, 1, 1] + # -> incidence_hyperedge_embeddings = [[10, 100], # hyperedge 0 for node 0 + # [10, 100], # hyperedge 0 for node 1 + # [20, 200], # hyperedge 1 for node 1 + # [20, 200]] # hyperedge 1 for node 2 + # shape: (num_incidences, num_channels) + incidence_hyperedge_embeddings = self.hyperedge_embeddings[ + self.hyperedge_index_wrapper.all_hyperedge_ids + ] + num_nodes = ( + self.num_nodes if self.num_nodes is not None else self.hyperedge_index_wrapper.num_nodes + ) + + # Scatter-aggregate hyperedge embeddings into node embeddings. + # Example: with aggregation="sum": + # [[10, 100], # node 0 belongs to hyperedge 0 + # [10+20, 100+200], # node 1 belongs to hyperedge 0 and 1 + # [20, 200]] # node 2 belongs to hyperedge 1 + # shape: (num_nodes, num_channels) + # with aggregation="max": + # [[10, 100], # node 0 belongs to hyperedge 0 + # [max(10, 20), max(100, 200)], # node 1 belongs to hyperedge 0 and 1 + # [20, 200]] # node 2 belongs to hyperedge 1 + return scatter( + src=incidence_hyperedge_embeddings, + index=self.hyperedge_index_wrapper.all_node_ids, + dim=0, # scatter along the node dimension + dim_size=num_nodes, + reduce=aggregation, + ) diff --git a/hyperbench/nn/conv.py b/hyperbench/nn/conv.py index e18a61a..c0003f5 100644 --- a/hyperbench/nn/conv.py +++ b/hyperbench/nn/conv.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Literal, Optional from torch import Tensor, nn +from hyperbench.nn.aggregator import HyperedgeAggregator, NodeAggregator from hyperbench.types import EdgeIndex, Graph, HyperedgeIndex, Hypergraph @@ -136,12 +137,11 @@ def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: for efficiency when the hypergraph structure does not change across forward passes. Args: - x: Input node feature matrix. Size ``(num_nodes, in_channels)``. - hyperedge_index: Hyperedge incidence in COO format. Size ``(2, num_incidences)``, - where row 0 contains node IDs and row 1 contains hyperedge IDs. + x: Input node feature matrix of size ``(num_nodes, in_channels)``. + hyperedge_index: Hyperedge incidence in COO format of size ``(2, num_incidences)``. Returns: - The output node feature matrix. Size ``(num_nodes, out_channels)``. + The output node feature matrix of size ``(num_nodes, out_channels)``. """ x = self.theta(x) @@ -204,12 +204,11 @@ def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: ``X' = sigma( D_v^{-1} H D_e^{-1} H^T (X Theta) )`` Args: - x: Input node feature matrix. Size ``(num_nodes, in_channels)``. - hyperedge_index: Hyperedge incidence in COO format. Size ``(2, num_incidences)``, - where row 0 contains node IDs and row 1 contains hyperedge IDs. + x: Input node feature matrix of size ``(num_nodes, in_channels)``. + hyperedge_index: Hyperedge incidence in COO format of size ``(2, num_incidences)``. Returns: - The output node feature matrix. Size ``(num_nodes, out_channels)``. + The output node feature matrix of size ``(num_nodes, out_channels)``. """ x = self.theta(x) @@ -225,3 +224,69 @@ def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: x = self.dropout(x) return x + + +class HNHNConv(nn.Module): + """ + The HNHNConv layer proposed in `HNHN: Hypergraph Networks with Hyperedge Neurons `_ paper. + Reference implementation: `source `_. 
+ + Args: + in_channels: The number of input channels. + out_channels: The number of output channels. + bias: If set to ``False``, the layer will not learn the bias parameter. Defaults to ``True``. + use_batch_normalization: If set to ``True``, the layer will use batch normalization. Defaults to ``False``. + drop_rate: If set to a positive number, the layer will use dropout. Defaults to ``0.5``. + is_last: If set to ``True``, the layer will not apply the final activation and dropout functions. Defaults to ``False``. + """ + + __AGGREGATION: Literal["mean"] = "mean" + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + use_batch_normalization: bool = False, + drop_rate: float = 0.5, + is_last: bool = False, + ): + super().__init__() + self.is_last = is_last + self.batch_norm_1d = nn.BatchNorm1d(out_channels) if use_batch_normalization else None + self.activation_fn = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(drop_rate) + self.theta_v2e = nn.Linear(in_channels, out_channels, bias=bias) + self.theta_e2v = nn.Linear(out_channels, out_channels, bias=bias) + + def forward(self, x: Tensor, hyperedge_index: Tensor) -> Tensor: + """ + Apply one HNHN convolution layer using two learned projections around + node-to-hyperedge and hyperedge-to-node mean aggregation. + + Args: + x: Input node feature matrix of size ``(num_nodes, in_channels)``. + hyperedge_index: Hyperedge incidence in COO format of size ``(2, num_incidences)``. + + Returns: + The output node feature matrix of size ``(num_nodes, out_channels)``. + """ + x = self.theta_v2e(x) + + hyperedge_embeddings = HyperedgeAggregator(hyperedge_index, x).pool(self.__AGGREGATION) + hyperedge_embeddings = self.activation_fn(hyperedge_embeddings) + hyperedge_embeddings = self.theta_e2v(hyperedge_embeddings) + + x = NodeAggregator( + hyperedge_index=hyperedge_index, + hyperedge_embeddings=hyperedge_embeddings, + num_nodes=x.size(0), + ).pool(self.__AGGREGATION) + + if not self.is_last: + x = self.activation_fn(x) + if self.batch_norm_1d is not None: + x = self.batch_norm_1d(x) + x = self.dropout(x) + + return x diff --git a/hyperbench/nn/enricher.py b/hyperbench/nn/enricher.py index ae65d20..4e9e9bb 100644 --- a/hyperbench/nn/enricher.py +++ b/hyperbench/nn/enricher.py @@ -289,13 +289,27 @@ def enrich(self, hyperedge_index: Tensor) -> Tensor: class LaplacianPositionalEncodingEnricher(NodeEnricher): + """ + Enrich node features with Laplacian Positional Encodings computed from the symmetric normalized Laplacian of the clique expansion of the hypergraph. + + Args: + num_features: Number of positional encoding features to generate for each node. + num_nodes: Total number of nodes in the graph. If not provided, it will be inferred from the ``hyperedge_index``. + Set it explicitly whenever the ``hyperedge_index`` does not reference every node, e.g., when some + isolated nodes are missing, or in a transductive split whose hyperedges cover only part of the + full node space. + cache_dir: Optional directory to cache computed features. If ``None``, caching is disabled.
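+ + Example (illustrative shapes only; a single hyperedge over nodes 0-2, with nodes 3 and 4 isolated): + >>> enricher = LaplacianPositionalEncodingEnricher(num_features=2, num_nodes=5) + >>> hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 0]]) + >>> enricher.enrich(hyperedge_index).shape + ... torch.Size([5, 2])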
+ """ + def __init__( self, num_features: int, + num_nodes: int = 0, cache_dir: Optional[str] = None, ): super().__init__(cache_dir=cache_dir) self.num_features = num_features + self.num_nodes = num_nodes def enrich(self, hyperedge_index: Tensor) -> Tensor: """ @@ -313,7 +327,8 @@ def enrich(self, hyperedge_index: Tensor) -> Tensor: """ edge_index = HyperedgeIndex(hyperedge_index).reduce_to_edge_index_on_clique_expansion() edge_index_wrapper = EdgeIndex(edge_index) - laplacian_matrix = edge_index_wrapper.get_sparse_normalized_laplacian() + num_nodes = self.num_nodes if self.num_nodes > 0 else None + laplacian_matrix = edge_index_wrapper.get_sparse_normalized_laplacian(num_nodes=num_nodes) laplacian_matrix_dense = ( laplacian_matrix.to_dense() # torch.linalg.eigh only works on dense tensors ) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 7c44c01..3e367c0 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -1012,6 +1012,76 @@ def test_enrich_node_features_concatenate(mock_hdata): assert dataset.hdata.x.shape == (3, 5) # 1 original + 4 enriched +def test_enrich_node_features_from_dataset(): + source_dataset = Dataset.from_hdata( + HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]), + hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), + global_node_ids=torch.tensor([100, 200, 300]), + ) + ) + target_dataset = Dataset.from_hdata( + HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([300, 100]), + ) + ) + + target_dataset.enrich_node_features_from(source_dataset) + + assert torch.equal(target_dataset.hdata.x, torch.tensor([[3.0, 30.0], [1.0, 10.0]])) + + +def test_enrich_node_features_from_propagates_hdata_validation_errors(): + source_dataset = Dataset.from_hdata( + HData( + x=torch.tensor([[1.0], [2.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + ) + target_dataset = Dataset.from_hdata( + HData( + x=torch.tensor([[0.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + ) + target_dataset.hdata.global_node_ids = None + + with pytest.raises( + ValueError, + match="Both HData instances must define global_node_ids to align node features.", + ): + target_dataset.enrich_node_features_from(source_dataset) + + +def test_enrich_node_features_from_dataset_with_fill_value(): + source_dataset = Dataset.from_hdata( + HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + ) + target_dataset = Dataset.from_hdata( + HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 30]), + ) + ) + + target_dataset.enrich_node_features_from( + source_dataset, + node_space_setting="inductive", + fill_value=[7.0, 8.0], + ) + + assert torch.equal(target_dataset.hdata.x, torch.tensor([[1.0, 10.0], [7.0, 8.0]])) + + def test_enrich_hyperedge_attr_replace(mock_hdata): dataset = Dataset.from_hdata(mock_hdata) @@ -1230,6 +1300,89 @@ def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): assert split.hdata.hyperedge_attr is None +def test_split_transductive_default_preserves_first_split_node_space(): + hdata = HData( + x=torch.arange(4, dtype=torch.float).unsqueeze(1), + hyperedge_index=torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]), + global_node_ids=torch.tensor([100, 200, 300, 
400]), + ) + dataset = Dataset.from_hdata(hdata) + + train_dataset, test_dataset = dataset.split([0.75, 0.25]) + + assert train_dataset.hdata.num_nodes == dataset.hdata.num_nodes + assert torch.equal(train_dataset.hdata.x, dataset.hdata.x) + assert test_dataset.hdata.num_nodes == 1 + + +def test_split_transductive_all_preserves_all_split_node_spaces(): + hdata = HData( + x=torch.arange(4, dtype=torch.float).unsqueeze(1), + hyperedge_index=torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]), + global_node_ids=torch.tensor([100, 200, 300, 400]), + ) + dataset = Dataset.from_hdata(hdata) + + train_dataset, test_dataset = dataset.split( + [0.75, 0.25], + node_space_setting="transductive", + assign_node_space_to="all", + ) + + assert train_dataset.hdata.num_nodes == dataset.hdata.num_nodes + assert test_dataset.hdata.num_nodes == dataset.hdata.num_nodes + assert torch.equal(train_dataset.hdata.x, dataset.hdata.x) + assert torch.equal(test_dataset.hdata.x, dataset.hdata.x) + + +def test_split_raises_when_node_space_provided_with_transductive_disabled(): + hdata = HData( + x=torch.arange(4, dtype=torch.float).unsqueeze(1), + hyperedge_index=torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]), + global_node_ids=torch.tensor([100, 200, 300, 400]), + ) + dataset = Dataset.from_hdata(hdata) + + with pytest.raises( + ValueError, + match="assign_node_space_to can only be provided when node_space_setting='transductive'.", + ): + dataset.split( + [0.75, 0.25], + node_space_setting="inductive", + assign_node_space_to="first", + ) + + +def test_nested_transductive_split_supports_train_feature_reuse(): + hdata = HData( + x=torch.arange(4, dtype=torch.float).unsqueeze(1), + hyperedge_index=torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]), + global_node_ids=torch.tensor([100, 200, 300, 400]), + ) + dataset = Dataset.from_hdata(hdata) + + train_dataset, test_dataset = dataset.split( + [0.75, 0.25], + node_space_setting="transductive", + ) + train_dataset, val_dataset = train_dataset.split( + [2 / 3, 1 / 3], + node_space_setting="transductive", + ) + + enricher = MagicMock(spec=NodeEnricher) + enricher.enrich.return_value = torch.tensor( + [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0], [40.0, 41.0]] + ) + train_dataset.enrich_node_features(enricher, enrichment_mode="replace") + val_dataset.enrich_node_features_from(train_dataset) + test_dataset.enrich_node_features_from(train_dataset) + + assert torch.equal(val_dataset.hdata.x, torch.tensor([[30.0, 31.0]])) + assert torch.equal(test_dataset.hdata.x, torch.tensor([[40.0, 41.0]])) + + def test_to_device(mock_hdata): device = torch.device("cpu") diff --git a/hyperbench/tests/data/loader_test.py b/hyperbench/tests/data/loader_test.py index 1a51dae..d213126 100644 --- a/hyperbench/tests/data/loader_test.py +++ b/hyperbench/tests/data/loader_test.py @@ -328,7 +328,8 @@ def test_collate_sample_full_hypergraph_returns_cached_hdata(mock_dataset_single assert torch.equal(batched.x, expected_hdata.x) assert torch.equal(batched.hyperedge_index, expected_hdata.hyperedge_index) assert torch.equal( - utils.to_non_empty_edgeattr(batched.hyperedge_attr), expected_hdata.hyperedge_attr + utils.to_non_empty_edgeattr(batched.hyperedge_attr), + utils.to_non_empty_edgeattr(expected_hdata.hyperedge_attr), ) diff --git a/hyperbench/tests/types/hdata_test.py b/hyperbench/tests/types/hdata_test.py index e2f4e5e..91ce459 100644 --- a/hyperbench/tests/types/hdata_test.py +++ b/hyperbench/tests/types/hdata_test.py @@ -481,29 +481,87 @@ def test_cat_same_node_space_drops_hyperedge_attr_when_partially_missing(): 
torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]]), id="both_hyperedges", ), + pytest.param( + torch.tensor([0]), + 3, + 1, + torch.tensor([[0, 1, 2], [0, 0, 0]]), + id="subset_hyperedges", + ), ], ) -def test_split_counts( +def test_split_inductive_counts( split_ids, expected_num_nodes, expected_num_hyperedges, expected_hyperedge_index ): x = torch.randn(4, 2) hyperedge_index = torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]]) hdata = HData(x=x, hyperedge_index=hyperedge_index) - result = HData.split(hdata, split_hyperedge_ids=split_ids) + result = HData.split( + hdata, + split_hyperedge_ids=split_ids, + node_space_setting="inductive", + ) assert result.num_nodes == expected_num_nodes assert result.num_hyperedges == expected_num_hyperedges assert torch.equal(result.hyperedge_index, expected_hyperedge_index) -def test_split_subsets_node_features(): +@pytest.mark.parametrize( + "split_ids, expected_num_nodes, expected_num_hyperedges, expected_hyperedge_index", + [ + pytest.param( + torch.tensor([0]), 4, 1, torch.tensor([[0, 1, 2], [0, 0, 0]]), id="first_hyperedge" + ), + pytest.param( + torch.tensor([1]), 4, 1, torch.tensor([[2, 3], [0, 0]]), id="second_hyperedge" + ), + pytest.param( + torch.tensor([0, 1]), + 4, + 2, + torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]]), + id="both_hyperedges", + ), + pytest.param( + torch.tensor([0]), + 4, + 1, + torch.tensor([[0, 1, 2], [0, 0, 0]]), + id="subset_hyperedges", + ), + ], +) +def test_split_transductive_counts( + split_ids, expected_num_nodes, expected_num_hyperedges, expected_hyperedge_index +): + x = torch.randn(4, 2) + hyperedge_index = torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]]) + hdata = HData(x=x, hyperedge_index=hyperedge_index) + + result = HData.split( + hdata, + split_hyperedge_ids=split_ids, + node_space_setting="transductive", + ) + + assert result.num_nodes == expected_num_nodes + assert result.num_hyperedges == expected_num_hyperedges + assert torch.equal(result.hyperedge_index, expected_hyperedge_index) + + +def test_split_inductive_subsets_node_features(): x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) hyperedge_index = torch.tensor([[0, 1, 3, 4], [0, 0, 1, 1]]) hdata = HData(x=x, hyperedge_index=hyperedge_index) hyperedge_ids = torch.tensor([1]) # Split by hyperedge 1, which includes nodes 3 and 4 - result = HData.split(hdata, split_hyperedge_ids=hyperedge_ids) + result = HData.split( + hdata, + split_hyperedge_ids=hyperedge_ids, + node_space_setting="inductive", + ) # Only nodes 3 and 4 should be included assert result.num_nodes == 2 @@ -522,16 +580,80 @@ def test_split_subsets_labels(): assert torch.equal(result.y, torch.tensor([0.0])) -def test_split_handles_none_global_node_ids(): +@pytest.mark.parametrize( + "node_space_setting, split_hyperedge_ids, expected_global_node_ids", + [ + pytest.param( + "transductive", + torch.tensor([1]), + torch.arange(4), + id="transductive", + ), + pytest.param( + "inductive", + torch.tensor([1]), + torch.arange(2), + id="inductive", + ), + ], +) +def test_split_handles_none_global_node_ids( + node_space_setting, split_hyperedge_ids, expected_global_node_ids +): x = torch.tensor([[10.0], [20.0], [30.0], [40.0]]) hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) hdata = HData(x=x, hyperedge_index=hyperedge_index) hdata.global_node_ids = None - result = HData.split(hdata, split_hyperedge_ids=torch.tensor([1])) + result = HData.split( + hdata, + split_hyperedge_ids=split_hyperedge_ids, + node_space_setting=node_space_setting, + ) assert result.global_node_ids is not 
None - assert torch.equal(result.global_node_ids, torch.arange(result.num_nodes)) + assert torch.equal(result.global_node_ids, expected_global_node_ids) + + +def test_split_transductive_keeps_full_x_and_global_node_ids(): + x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) + hyperedge_index = torch.tensor([[0, 2, 3, 4], [0, 0, 1, 1]]) + global_node_ids = torch.tensor([10, 20, 30, 40, 50]) + hdata = HData( + x=x, + hyperedge_index=hyperedge_index, + global_node_ids=global_node_ids, + y=torch.tensor([1.0, 0.0]), + ) + + result = HData.split( + hdata, + split_hyperedge_ids=torch.tensor([1]), + node_space_setting="transductive", + ) + + assert result.num_nodes == hdata.num_nodes + assert torch.equal(result.x, x) + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, global_node_ids) + assert torch.equal(result.hyperedge_index, torch.tensor([[3, 4], [0, 0]])) + assert torch.equal(result.y, torch.tensor([0.0])) + + +def test_split_transductive_handles_none_global_node_ids(): + x = torch.tensor([[10.0], [20.0], [30.0], [40.0], [50.0]]) + hyperedge_index = torch.tensor([[0, 2, 3, 4], [0, 0, 1, 1]]) + hdata = HData(x=x, hyperedge_index=hyperedge_index, y=torch.tensor([1.0, 0.0])) + hdata.global_node_ids = None + + result = HData.split( + hdata, + split_hyperedge_ids=torch.tensor([1]), + node_space_setting="transductive", + ) + + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, torch.arange(hdata.num_nodes)) def test_split_subsets_edge_attr(): @@ -636,6 +758,236 @@ def test_enrich_node_features_concatenate(mock_hdata): assert result.x.shape == (5, 7) # 4 original + 3 enriched +@pytest.mark.parametrize( + "enrichment_mode", + [ + pytest.param("replace", id="replace"), + pytest.param("concatenate", id="concatenate"), + pytest.param(None, id="none_enrichment_mode_defaults_to_replace"), + ], +) +def test_enrich_node_features_replace_preserves_global_node_ids(mock_hdata, enrichment_mode): + global_node_ids = torch.tensor([10, 20, 30, 40, 50]) + mock_hdata.global_node_ids = global_node_ids + + enricher = MagicMock(spec=NodeEnricher) + enricher.enrich.return_value = torch.randn(5, 3) + + result = mock_hdata.enrich_node_features(enricher, enrichment_mode=enrichment_mode) + + assert result.global_node_ids is not None + assert torch.equal(result.global_node_ids, global_node_ids) + + +def test_enrich_node_features_from_aligns_by_global_node_ids(): + source_hdata = HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]), + hyperedge_index=torch.tensor([[0, 1, 2], [0, 0, 1]]), + global_node_ids=torch.tensor([100, 200, 300]), + ) + target_hdata = HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([300, 100]), + y=torch.tensor([0.0]), + ) + + result = target_hdata.enrich_node_features_from(source_hdata) + + assert torch.equal(result.x, torch.tensor([[3.0, 30.0], [1.0, 10.0]])) + assert torch.equal(result.hyperedge_index, target_hdata.hyperedge_index) + assert result.hyperedge_weights is None + assert result.hyperedge_attr is None + assert result.global_node_ids is not None + assert torch.equal( + result.global_node_ids, utils.to_non_empty_edgeattr(target_hdata.global_node_ids) + ) + assert torch.equal(result.y, target_hdata.y) + + +@pytest.mark.parametrize( + "missing_side", + [ + pytest.param("source", id="source_missing_global_node_ids"), + pytest.param("target", id="target_missing_global_node_ids"), + ], +) +def 
test_enrich_node_features_from_raises_without_global_node_ids(missing_side): + source_hdata = HData( + x=torch.tensor([[1.0], [2.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + target_hdata = HData( + x=torch.tensor([[0.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + + if missing_side == "source": + source_hdata.global_node_ids = None + else: + target_hdata.global_node_ids = None + + with pytest.raises( + ValueError, + match="Both HData instances must define global_node_ids to align node features.", + ): + target_hdata.enrich_node_features_from(source_hdata) + + +def test_enrich_node_features_from_raises_when_source_rows_do_not_match_global_node_ids(): + source_hdata = HData( + x=torch.empty((0, 0)), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + ) + target_hdata = HData( + x=torch.tensor([[0.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([0]), + ) + + with pytest.raises( + ValueError, + match="Expected hdata_with_features.x rows to align with hdata_with_features.global_node_ids.", + ): + target_hdata.enrich_node_features_from(source_hdata) + + +def test_enrich_node_features_from_raises_when_target_node_missing_from_source(): + source_hdata = HData( + x=torch.tensor([[1.0], [2.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + target_hdata = HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 30]), + ) + + with pytest.raises( + ValueError, + match=r"Missing node features for target global_node_ids: \[30\]\.", + ): + target_hdata.enrich_node_features_from(source_hdata) + + +@pytest.mark.parametrize( + "fill_value, expected_x", + [ + pytest.param(0.5, torch.tensor([[1.0, 10.0], [0.5, 0.5]]), id="scalar_fill_value"), + pytest.param( + [7.0, 8.0], + torch.tensor([[1.0, 10.0], [7.0, 8.0]]), + id="vector_fill_value", + ), + pytest.param( + torch.tensor([7.0, 8.0]), + torch.tensor([[1.0, 10.0], [7.0, 8.0]]), + id="tensor_fill_value", + ), + pytest.param( + [0.5], + torch.tensor([[1.0, 10.0], [0.5, 0.5]]), + id="missing_dimensions_scalar_vector_fill_value", + ), + pytest.param( + torch.tensor(0.5), + torch.tensor([[1.0, 10.0], [0.5, 0.5]]), + id="missing_dimensions_scalar_tensor_fill_value", + ), + ], +) +def test_enrich_node_features_from_inductive_fill_value(fill_value, expected_x): + source_hdata = HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + target_hdata = HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 30]), + ) + + result = target_hdata.enrich_node_features_from( + source_hdata, + node_space_setting="inductive", + fill_value=fill_value, + ) + + assert torch.equal(result.x, expected_x) + + +def test_enrich_node_features_from_inductive_raises_without_fill_value(): + source_hdata = HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + target_hdata = HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 30]), + ) + + with pytest.raises( + ValueError, + match="fill_value must be provided when node_space_setting='inductive'.", + ): + target_hdata.enrich_node_features_from( + 
source_hdata, + node_space_setting="inductive", + ) + + +def test_enrich_node_features_from_transductive_raises_when_fill_value_provided(): + source_hdata = HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + target_hdata = HData( + x=torch.tensor([[0.0]]), + hyperedge_index=torch.tensor([[0], [0]]), + global_node_ids=torch.tensor([10]), + ) + + with pytest.raises( + ValueError, + match="fill_value cannot be provided when node_space_setting='transductive'.", + ): + target_hdata.enrich_node_features_from( + source_hdata, + node_space_setting="transductive", + fill_value=0.0, + ) + + +def test_enrich_node_features_from_non_transductive_raises_on_fill_value_shape_mismatch(): + source_hdata = HData( + x=torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 20]), + ) + target_hdata = HData( + x=torch.tensor([[0.0], [0.0]]), + hyperedge_index=torch.tensor([[0, 1], [0, 0]]), + global_node_ids=torch.tensor([10, 30]), + ) + + with pytest.raises( + ValueError, + match=r"Expected fill_value to define exactly 2 features, got shape \(3,\)\.", + ): + target_hdata.enrich_node_features_from( + source_hdata, + node_space_setting="inductive", + fill_value=[1.0, 2.0, 3.0], + ) + + def test_enrich_hyperedge_weights_replace(): x = torch.tensor([[1.0], [2.0], [3.0]]) hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]]) diff --git a/hyperbench/tests/utils/node_utils_test.py b/hyperbench/tests/utils/node_utils_test.py new file mode 100644 index 0000000..add0929 --- /dev/null +++ b/hyperbench/tests/utils/node_utils_test.py @@ -0,0 +1,74 @@ +import pytest + +from hyperbench.utils import ( + is_assigned_to_all, + is_assigned_to_first, + is_inductive_setting, + is_transductive_setting, + is_transductive_split, +) + + +@pytest.mark.parametrize( + "node_space_assignment, expected", + [ + pytest.param("all", True, id="all"), + pytest.param("first", False, id="first"), + pytest.param(None, False, id="none"), + ], +) +def test_is_assigned_to_all(node_space_assignment, expected): + assert is_assigned_to_all(node_space_assignment) == expected + + +@pytest.mark.parametrize( + "node_space_assignment, expected", + [ + pytest.param("first", True, id="first"), + pytest.param("all", False, id="all"), + pytest.param(None, False, id="none"), + ], +) +def test_is_assigned_to_first(node_space_assignment, expected): + assert is_assigned_to_first(node_space_assignment) == expected + + +@pytest.mark.parametrize( + "node_space_setting, expected", + [ + pytest.param("inductive", True, id="inductive"), + pytest.param("transductive", False, id="transductive"), + pytest.param(None, False, id="none"), + ], +) +def test_is_inductive_setting(node_space_setting, expected): + assert is_inductive_setting(node_space_setting) == expected + + +@pytest.mark.parametrize( + "node_space_setting, expected", + [ + pytest.param("transductive", True, id="transductive"), + pytest.param("inductive", False, id="inductive"), + pytest.param(None, False, id="none"), + ], +) +def test_is_transductive_setting(node_space_setting, expected): + assert is_transductive_setting(node_space_setting) == expected + + +@pytest.mark.parametrize( + "node_space_setting, assign_node_space_to, split_num, expected", + [ + pytest.param("inductive", "all", 0, False, id="inductive_all_first_split"), + pytest.param("inductive", "first", 0, False, id="inductive_first_first_split"), + pytest.param(None, "all", 0, 
False, id="none_setting"), + pytest.param("transductive", "all", 0, True, id="transductive_all_first_split"), + pytest.param("transductive", "all", 2, True, id="transductive_all_later_split"), + pytest.param("transductive", "first", 0, True, id="transductive_first_first_split"), + pytest.param("transductive", "first", 1, False, id="transductive_first_later_split"), + pytest.param("transductive", None, 0, False, id="transductive_none_assignment"), + ], +) +def test_is_transductive_split(node_space_setting, assign_node_space_to, split_num, expected): + assert is_transductive_split(node_space_setting, assign_node_space_to, split_num) == expected diff --git a/hyperbench/train/trainer.py b/hyperbench/train/trainer.py index 2ec80c0..b251d5b 100644 --- a/hyperbench/train/trainer.py +++ b/hyperbench/train/trainer.py @@ -372,30 +372,30 @@ def __start_tensorboard_process(self) -> Optional[subprocess.Popen]: def __device(self, trainer: L.Trainer) -> str: if trainer.strategy is None: - return MultiModelTrainer.__UNKNOWN_DEVICE + return self.__UNKNOWN_DEVICE strategy = trainer.strategy if strategy.root_device is None: - return MultiModelTrainer.__UNKNOWN_DEVICE + return self.__UNKNOWN_DEVICE return str(strategy.root_device) def __next_experiment_name(self, save_dir: Path) -> Path: if not save_dir.exists(): - return Path(f"{MultiModelTrainer.EXPERIMENT_NAME_PREFIX}_0") + return Path(f"{self.EXPERIMENT_NAME_PREFIX}_0") existing_experiment_names: List[str] = [ dir.name for dir in save_dir.iterdir() - if dir.is_dir() and dir.name.startswith(MultiModelTrainer.EXPERIMENT_NAME_PREFIX) + if dir.is_dir() and dir.name.startswith(self.EXPERIMENT_NAME_PREFIX) ] if len(existing_experiment_names) < 1: - return Path(f"{MultiModelTrainer.EXPERIMENT_NAME_PREFIX}_0") + return Path(f"{self.EXPERIMENT_NAME_PREFIX}_0") last_experiment_number = max( int(experiment_name.split("_")[1]) for experiment_name in existing_experiment_names if experiment_name.split("_")[1].isdigit() ) - return Path(f"{MultiModelTrainer.EXPERIMENT_NAME_PREFIX}_{last_experiment_number + 1}") + return Path(f"{self.EXPERIMENT_NAME_PREFIX}_{last_experiment_number + 1}") def __setup_logdir( self, @@ -403,9 +403,7 @@ def __setup_logdir( experiment_name: Optional[str], ) -> Path: base_dir = ( - Path(MultiModelTrainer.DEFAULT_BASE_LOG_DIR) - if default_root_dir is None - else Path(default_root_dir) + Path(self.DEFAULT_BASE_LOG_DIR) if default_root_dir is None else Path(default_root_dir) ) next_experiment_name = ( self.__next_experiment_name(base_dir) @@ -428,7 +426,7 @@ def __setup_logger( CSVLogger( save_dir=self.log_dir, name=model_config.name, - version=f"{MultiModelTrainer.VERSION_NAME_PREFIX}_{model_config.version}", + version=f"{self.VERSION_NAME_PREFIX}_{model_config.version}", ), MarkdownTableLogger( save_dir=self.log_dir, @@ -454,7 +452,7 @@ def __setup_logger( TensorBoardLogger( save_dir=self.log_dir, name=model_config.name, - version=f"{MultiModelTrainer.VERSION_NAME_PREFIX}_{model_config.version}", + version=f"{self.VERSION_NAME_PREFIX}_{model_config.version}", ), ) diff --git a/hyperbench/types/hdata.py b/hyperbench/types/hdata.py index b0ba584..ecf6ffb 100644 --- a/hyperbench/types/hdata.py +++ b/hyperbench/types/hdata.py @@ -2,9 +2,17 @@ from torch import Tensor from typing import Optional, Sequence, Dict, Any -from hyperbench.utils import empty_hyperedgeindex, empty_nodefeatures -from hyperbench.nn.enricher import EnrichmentMode, NodeEnricher, HyperedgeEnricher +from hyperbench.utils import ( + NodeSpaceFiller, + NodeSpaceSetting, + 
empty_hyperedgeindex,
+    empty_nodefeatures,
+    is_inductive_setting,
+    is_transductive_setting,
+    to_0based_ids,
+)
+from hyperbench.nn.enricher import EnrichmentMode, NodeEnricher, HyperedgeEnricher
 from hyperbench.types.hypergraph import HyperedgeIndex
@@ -217,23 +225,33 @@ def from_hyperedge_index(cls, hyperedge_index: Tensor) -> "HData":
         )
 
     @classmethod
-    def split(cls, hdata: "HData", split_hyperedge_ids: Tensor) -> "HData":
+    def split(
+        cls,
+        hdata: "HData",
+        split_hyperedge_ids: Tensor,
+        node_space_setting: NodeSpaceSetting = "transductive",
+    ) -> "HData":
         """
         Build an :class:`HData` for a single split from the given hyperedge IDs.
 
         Examples:
-            >>> hyperedge_index = [[0, 0, 1, 2, 3, 4],
-            ...                    [0, 0, 0, 1, 2, 2]]
-            >>> split_hyperedge_ids = [0, 2]
-            >>> new_hyperedge_index = [[0, 0, 1, 2, 3],  # nodes 0 -> 0, 1 -> 1, 3 -> 2, 4 -> 3 (remapped to 0-based)
-            ...                        [0, 0, 0, 1, 1]]  # hyperedges 0 -> 0, 2 -> 1 (remapped to 0-based)
-            >>> new_x = [x[0], x[1], x[3], x[4]]
-            >>> new_hyperedge_attr = [hyperedge_attr[0], hyperedge_attr[2]]
-            >>> new_hyperedge_weights = [hyperedge_weights[0], hyperedge_weights[2]]
+            Transductive split (default) preserving the full node space:
+            >>> split_hdata = HData.split(hdata, torch.tensor([1]), node_space_setting="transductive")
+            >>> split_hdata.x.shape[0] == hdata.x.shape[0]
+            True
+            >>> split_hdata.hyperedge_index  # node IDs stay in the original row space, hyperedge IDs are rebased
+
+            Inductive split:
+            >>> split_hdata = HData.split(hdata, torch.tensor([1]), node_space_setting="inductive")
+            >>> split_hdata.x.shape[0]  # only nodes incident to hyperedge 1
+            2
 
         Args:
             hdata: The original :class:`HData` containing the full hypergraph.
             split_hyperedge_ids: Tensor of hyperedge IDs to include in this split.
+            node_space_setting: Whether to preserve the full node space in the splits.
+                ``transductive`` (default) ensures all nodes are present in every split,
+                while ``inductive`` allows splits to have disjoint node spaces.
 
         Returns:
             The split instance with rebased hyperedge IDs (node IDs are rebased as well in the inductive setting).
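
A minimal sketch of how the two `node_space_setting` modes behave, using the toy hypergraph from the parametrized tests above (only `HData` and `torch` are assumed; the expected values mirror the `second_hyperedge` test cases):

    import torch
    from hyperbench.types import HData

    hdata = HData(
        x=torch.randn(4, 2),
        hyperedge_index=torch.tensor([[0, 1, 2, 2, 3], [0, 0, 0, 1, 1]]),
    )

    # Transductive: the full node space is kept; only hyperedge IDs are rebased.
    trans = HData.split(hdata, split_hyperedge_ids=torch.tensor([1]), node_space_setting="transductive")
    assert trans.num_nodes == 4
    assert torch.equal(trans.hyperedge_index, torch.tensor([[2, 3], [0, 0]]))

    # Inductive: nodes not incident to the kept hyperedges are dropped,
    # and both node and hyperedge IDs are rebased to 0-based.
    ind = HData.split(hdata, split_hyperedge_ids=torch.tensor([1]), node_space_setting="inductive")
    assert ind.num_nodes == 2
    assert torch.equal(ind.hyperedge_index, torch.tensor([[0, 1], [0, 0]]))

Keeping the full node space in the transductive case is what later lets `enrich_node_features_from` align features across splits without fill values.
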
@@ -248,44 +266,71 @@ def split(cls, hdata: "HData", split_hyperedge_ids: Tensor) -> "HData": # Example: hyperedge_index = [[0, 0, 1, 3, 4], # [0, 0, 0, 2, 2]] # incidence [2, 1] is missing as 1 is not in split_hyperedge_ids = [0, 2] - split_hyperedge_index = hdata.hyperedge_index[:, keep_mask] + split_hyperedge_index = hdata.hyperedge_index[:, keep_mask].clone() + + # Example: split_hyperedge_index = [[2, 3, 4], + # [2, 2, 5]] + # -> split_unique_hyperedge_ids = [2, 5] + split_unique_hyperedge_ids = split_hyperedge_index[1].unique() + + split_y = hdata.y[split_unique_hyperedge_ids] + + split_hyperedge_attr = None + if hdata.hyperedge_attr is not None: + split_hyperedge_attr = hdata.hyperedge_attr[split_unique_hyperedge_ids] + + split_hyperedge_weights = None + if hdata.hyperedge_weights is not None: + split_hyperedge_weights = hdata.hyperedge_weights[split_unique_hyperedge_ids] + + # We don't need to split nodes, so we split only hyperedges and rebase their IDs to 0-based + if is_transductive_setting(node_space_setting): + # Example: split_unique_hyperedge_ids = [2, 5] + # -> hyperedge 2 -> 0, hyperedge 5 -> 1 + split_hyperedge_index[1] = to_0based_ids( + original_ids=split_hyperedge_index[1], + ids_to_rebase=split_unique_hyperedge_ids, + ) + return cls( + x=hdata.x, + hyperedge_index=split_hyperedge_index, + hyperedge_weights=split_hyperedge_weights, + hyperedge_attr=split_hyperedge_attr, + num_nodes=hdata.num_nodes, + num_hyperedges=len(split_unique_hyperedge_ids), + global_node_ids=hdata.global_node_ids, + y=split_y, + ) # Example: split_hyperedge_index = [[0, 0, 1, 3, 4], # [0, 0, 0, 2, 2]] # -> split_unique_node_ids = [0, 1, 3, 4] - # -> split_unique_hyperedge_ids = [0, 2] split_unique_node_ids = split_hyperedge_index[0].unique() - split_unique_hyperedge_ids = split_hyperedge_index[1].unique() - split_hyperedge_index_wrapper = HyperedgeIndex(split_hyperedge_index).to_0based( - node_ids_to_rebase=split_unique_node_ids, - hyperedge_ids_to_rebase=split_unique_hyperedge_ids, + split_hyperedge_index = ( + HyperedgeIndex(split_hyperedge_index) + .to_0based( + node_ids_to_rebase=split_unique_node_ids, + hyperedge_ids_to_rebase=split_unique_hyperedge_ids, + ) + .item ) - new_x = hdata.x[split_unique_node_ids] - new_global_node_ids = None - if hdata.global_node_ids is not None: - new_global_node_ids = hdata.global_node_ids[split_unique_node_ids] - new_y = hdata.y[split_unique_hyperedge_ids] - - # Subset hyperedge_attr if present - new_hyperedge_attr = None - if hdata.hyperedge_attr is not None: - new_hyperedge_attr = hdata.hyperedge_attr[split_unique_hyperedge_ids] + split_x = hdata.x[split_unique_node_ids] - new_hyperedge_weights = None - if hdata.hyperedge_weights is not None: - new_hyperedge_weights = hdata.hyperedge_weights[split_unique_hyperedge_ids] + split_global_node_ids = None + if hdata.global_node_ids is not None: + split_global_node_ids = hdata.global_node_ids[split_unique_node_ids] return cls( - x=new_x, - hyperedge_index=split_hyperedge_index_wrapper.item, - hyperedge_weights=new_hyperedge_weights, - hyperedge_attr=new_hyperedge_attr, + x=split_x, + hyperedge_index=split_hyperedge_index, + hyperedge_weights=split_hyperedge_weights, + hyperedge_attr=split_hyperedge_attr, num_nodes=len(split_unique_node_ids), num_hyperedges=len(split_unique_hyperedge_ids), - global_node_ids=new_global_node_ids, - y=new_y, + global_node_ids=split_global_node_ids, + y=split_y, ) def enrich_node_features( @@ -317,6 +362,122 @@ def enrich_node_features( hyperedge_attr=self.hyperedge_attr, 
num_nodes=self.num_nodes, num_hyperedges=self.num_hyperedges, + global_node_ids=self.global_node_ids, + y=self.y, + ) + + def enrich_node_features_from( + self, + hdata_with_features: "HData", + node_space_setting: NodeSpaceSetting = "transductive", + fill_value: Optional[NodeSpaceFiller] = None, + ) -> "HData": + """ + Copy node features from another :class:`HData` by aligning features by ``global_node_ids``. + + Examples: + Transductive enrichment (default) expecting the same node space in both source and target: + >>> target = target.enrich_node_features_from(source, node_space_setting="transductive") + + Inductive with a scalar fill value: + >>> target = target.enrich_node_features_from( + ... source, + ... node_space_setting="inductive", + ... fill_value=0.0, + ... ) + + Inductive with a feature vector fill value: + >>> target = target.enrich_node_features_from( + ... source, + ... node_space_setting="inductive", + ... fill_value=[0.0, 1.0, 0.0], + ... ) + + Args: + hdata_with_features: Source :class:`HData` providing node features. + node_space_setting: The setting for the node space, determining how nodes are handled. + If ``"transductive"``, every target node is expected to exist in the source. + If ``"inductive"``, the target dataset may have a different node space, and missing nodes are filled using ``fill_value``. + fill_value: Scalar or vector used to fill missing node features when ``node_space_setting`` is not transductive. + + Returns: + A new :class:`HData` with node features copied from ``hdata_with_features``. + + Raises: + ValueError: If either instance lacks ``global_node_ids``, if the source feature rows + do not align with the source node IDs, if ``fill_value`` is used with + ``node_space_setting="transductive"``, or if ``fill_value`` is missing or malformed when ``node_space_setting="inductive"``. + """ + source_global_node_ids = hdata_with_features.global_node_ids + source_x = hdata_with_features.x + if self.global_node_ids is None or source_global_node_ids is None: + raise ValueError( + "Both HData instances must define global_node_ids to align node features." + ) + if source_x.size(0) != source_global_node_ids.size(0): + raise ValueError( + "Expected hdata_with_features.x rows to align with hdata_with_features.global_node_ids." 
+        )
+        self.__validate_node_space_setting(node_space_setting, fill_value)
+
+        target_global_node_ids = self.global_node_ids.detach().cpu().tolist()
+
+        # Map each source global node ID to the row index of its features in source_x, so that
+        # matching a target global node ID gives us the corresponding source feature row
+        source_feature_idx_by_global_node_id = {
+            int(global_node_id): feature_idx
+            for feature_idx, global_node_id in enumerate(
+                source_global_node_ids.detach().cpu().tolist()
+            )
+        }
+
+        fill_features = self.__to_fill_features(
+            fill_value=fill_value,
+            num_features=int(source_x.size(1)),
+            dtype=source_x.dtype,
+            device=source_x.device,
+        )
+
+        enriched_rows = []
+        missing_global_node_ids = []
+        for global_node_id in target_global_node_ids:
+            source_feature_idx = source_feature_idx_by_global_node_id.get(int(global_node_id))
+            if source_feature_idx is None:
+                # Example: global_node_id = 30 is not present in the source
+                # -> strict transductive mode records it as missing and then raises an error
+                # -> non-transductive mode fills the features with fill_value and continues enriching the other nodes
+                if is_transductive_setting(node_space_setting):
+                    missing_global_node_ids.append(
+                        int(global_node_id)
+                    )  # record missing node for error message
+                else:
+                    enriched_rows.append(
+                        fill_features
+                    )  # fill missing node features with fill_value
+                continue
+
+            # Match the global node IDs in the target to the corresponding feature indices in the source
+            # Example: source_global_node_ids = [10, 20, 30], source_x has shape (3, num_features)
+            #          target_global_node_ids = [10, 30]
+            # -> source_feature_idx_by_global_node_id = {10: 0, 20: 1, 30: 2}
+            # -> pick source_x rows 0 and 2 for the target
+            enriched_rows.append(source_x[source_feature_idx])
+
+        if len(missing_global_node_ids) > 0:
+            raise ValueError(
+                f"Missing node features for target global_node_ids: {missing_global_node_ids}."
+ ) + + enriched_x = torch.stack(enriched_rows, dim=0).to(device=self.device) + + return self.__class__( + x=enriched_x, + hyperedge_index=self.hyperedge_index, + hyperedge_weights=self.hyperedge_weights, + hyperedge_attr=self.hyperedge_attr, + num_nodes=self.num_nodes, + num_hyperedges=self.num_hyperedges, + global_node_ids=self.global_node_ids, y=self.y, ) @@ -676,3 +837,47 @@ def stats(self) -> Dict[str, Any]: "distribution_node_degree_hist": distribution_node_degree_hist, "distribution_hyperedge_size_hist": distribution_hyperedge_size_hist, } + + def __to_fill_features( + self, + fill_value: Optional[NodeSpaceFiller], + num_features: int, + dtype: torch.dtype, + device: torch.device, + ) -> Tensor: + if fill_value is None: + return torch.empty((0,), dtype=dtype, device=device) + + if isinstance(fill_value, Tensor): + fill_features = fill_value.to(dtype=dtype, device=device) + elif isinstance(fill_value, (int, float)): + fill_features = torch.full( + (num_features,), float(fill_value), dtype=dtype, device=device + ) + else: + fill_features = torch.tensor(fill_value, dtype=dtype, device=device) + + # This can happen when fill_value is: + # - A scalar tensor, e.g., tensor(0.0), which should be broadcasted to all features + # - A list with a single value, e.g., [0.0], which should also be broadcasted to all features + if fill_features.numel() == 1: + fill_features = fill_features.repeat(num_features) + + if fill_features.dim() != 1 or fill_features.numel() != num_features: + raise ValueError( + f"Expected fill_value to define exactly {num_features} features, got shape " + f"{tuple(fill_features.shape)}." + ) + return fill_features + + def __validate_node_space_setting( + self, + node_space_setting: NodeSpaceSetting, + fill_value: Optional[NodeSpaceFiller], + ) -> None: + if is_transductive_setting(node_space_setting) and fill_value is not None: + raise ValueError( + "fill_value cannot be provided when node_space_setting='transductive'." 
+ ) + if is_inductive_setting(node_space_setting) and fill_value is None: + raise ValueError("fill_value must be provided when node_space_setting='inductive'.") diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 65e09ab..a30542e 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -15,6 +15,16 @@ is_input_layer, is_layer, ) +from .node_utils import ( + NodeSpaceAssignment, + NodeSpaceFiller, + NodeSpaceSetting, + is_assigned_to_all, + is_assigned_to_first, + is_inductive_setting, + is_transductive_setting, + is_transductive_split, +) from .sparse_utils import sparse_dropout __all__ = [ @@ -23,9 +33,17 @@ "NormalizationFn", "Aggregation", "Stage", + "NodeSpaceAssignment", + "NodeSpaceFiller", + "NodeSpaceSetting", "empty_edgeattr", "empty_hyperedgeindex", "empty_nodefeatures", + "is_assigned_to_all", + "is_assigned_to_first", + "is_inductive_setting", + "is_transductive_setting", + "is_transductive_split", "is_input_layer", "is_layer", "sparse_dropout", diff --git a/hyperbench/utils/node_utils.py b/hyperbench/utils/node_utils.py new file mode 100644 index 0000000..4bdb046 --- /dev/null +++ b/hyperbench/utils/node_utils.py @@ -0,0 +1,35 @@ +from torch import Tensor +from typing import Literal, Optional, Sequence, TypeAlias + + +NodeSpaceAssignment: TypeAlias = Literal["first", "all"] +NodeSpaceFiller: TypeAlias = float | int | Sequence[float] | Tensor +NodeSpaceSetting: TypeAlias = Literal["inductive", "transductive"] + + +def is_assigned_to_all(node_space_assignment: Optional[NodeSpaceAssignment]) -> bool: + return node_space_assignment == "all" + + +def is_assigned_to_first(node_space_assignment: Optional[NodeSpaceAssignment]) -> bool: + return node_space_assignment == "first" + + +def is_inductive_setting(node_space_setting: Optional[NodeSpaceSetting]) -> bool: + return node_space_setting == "inductive" + + +def is_transductive_setting(node_space_setting: Optional[NodeSpaceSetting]) -> bool: + return node_space_setting == "transductive" + + +def is_transductive_split( + node_space_setting: Optional[NodeSpaceSetting], + assign_node_space_to: Optional[NodeSpaceAssignment], + split_num: int, +) -> bool: + if not is_transductive_setting(node_space_setting): + return False + if is_assigned_to_all(assign_node_space_to): + return True + return is_assigned_to_first(assign_node_space_to) and split_num == 0
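
To make the helper's intent concrete, here is a small illustrative sketch of how a call site such as `Dataset.split` might consult `is_transductive_split` per split; the `node_space_settings_for_splits` helper below is hypothetical and only spells out the truth table covered by `test_is_transductive_split`:

    from typing import List, Optional
    from hyperbench.utils import (
        NodeSpaceAssignment,
        NodeSpaceSetting,
        is_transductive_split,
    )

    def node_space_settings_for_splits(
        num_splits: int,
        node_space_setting: Optional[NodeSpaceSetting],
        assign_node_space_to: Optional[NodeSpaceAssignment],
    ) -> List[NodeSpaceSetting]:
        # A split keeps the full node space only when is_transductive_split says so:
        # always under "all", only split 0 under "first", never in the inductive setting.
        return [
            "transductive"
            if is_transductive_split(node_space_setting, assign_node_space_to, split_num)
            else "inductive"
            for split_num in range(num_splits)
        ]

    # With "first", only the first split (e.g., train) keeps the full node space:
    assert node_space_settings_for_splits(2, "transductive", "first") == ["transductive", "inductive"]
    # With "all", every split does:
    assert node_space_settings_for_splits(3, "transductive", "all") == ["transductive"] * 3
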