diff --git a/README.md b/README.md index 516e305..e3d78fa 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ [![project_license][license-shield]][license-url] [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) +[![CodeFactor](https://www.codefactor.io/repository/github/hypernetwork-research-group/hyperbench/badge)](https://www.codefactor.io/repository/github/hypernetwork-research-group/hyperbench) ## About the project diff --git a/examples/hgnn.py b/examples/hgnn.py index b1644aa..f0bc454 100644 --- a/examples/hgnn.py +++ b/examples/hgnn.py @@ -28,7 +28,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/hyperedge_enricher.py b/examples/hyperedge_enricher.py index 43882f6..3bb6604 100644 --- a/examples/hyperedge_enricher.py +++ b/examples/hyperedge_enricher.py @@ -5,10 +5,9 @@ if __name__ == "__main__": print("Loading and preparing dataset...\n") - dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE, prepare=True) + dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE) print("Enriching hyperedge weights...") - # HyperedgeWeightsEnricher enriches hyperedges with their degree (number of nodes in each hyperedge) as weights. # It optionally applies scaling and adds a constant to the weights. dataset.enrich_hyperedge_weights( @@ -25,8 +24,7 @@ print("Enriching hyperedge attributes...") - # HyperedgeAttrsEnricher adds a feature of 1.0 for each hyperedge, - # which can be used as a baseline or for methods that require hyperedge features. + # HyperedgeAttrsEnricher adds a feature of 1.0 for each hyperedge, which can be used as a baseline or for methods that require hyperedge features. dataset.enrich_hyperedge_attr( enricher=HyperedgeAttrsEnricher(), enrichment_mode="replace", diff --git a/examples/hypergcn.py b/examples/hypergcn.py index e94e281..c01c79f 100644 --- a/examples/hypergcn.py +++ b/examples/hypergcn.py @@ -28,7 +28,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) dataset.remove_hyperedges_with_fewer_than_k_nodes(k=2) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/mlp_common_neighbors.py b/examples/mlp_common_neighbors.py index b638429..adb7c56 100644 --- a/examples/mlp_common_neighbors.py +++ b/examples/mlp_common_neighbors.py @@ -28,7 +28,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/node2vec.py b/examples/node2vec.py index 63a4644..bbd09d9 100644 --- a/examples/node2vec.py +++ b/examples/node2vec.py @@ -29,7 +29,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) dataset.remove_hyperedges_with_fewer_than_k_nodes(k=2) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/node_enricher.py b/examples/node_enricher.py index dd5a932..8149e2b 100644 --- a/examples/node_enricher.py +++ b/examples/node_enricher.py @@ -5,8 +5,8 @@ if __name__ == "__main__": print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE, prepare=True) - + dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE) + # NodeEnricher adds features for each node. dataset.enrich_node_features( enricher=LaplacianPositionalEncodingEnricher(num_features=32), enrichment_mode="replace", diff --git a/hyperbench/data/__init__.py b/hyperbench/data/__init__.py index 70ebe46..aa77d60 100644 --- a/hyperbench/data/__init__.py +++ b/hyperbench/data/__init__.py @@ -1,32 +1,30 @@ -from .dataset import ( - Dataset, - HIFConverter, -) +from .dataset import Dataset +from .hif import HIFLoader, HIFProcessor from .supported_datasets import ( AlgebraDataset, - AmazonDataset, - ContactHighSchoolDataset, - ContactPrimarySchoolDataset, - CoraDataset, - CourseraDataset, - DBLPDataset, - EmailEnronDataset, - EmailW3CDataset, - GeometryDataset, - GOTDataset, - IMDBDataset, - MusicBluesReviewsDataset, - NBADataset, - NDCClassesDataset, - NDCSubstancesDataset, - PatentDataset, - PubmedDataset, - RestaurantReviewsDataset, - ThreadsAskUbuntuDataset, - ThreadsMathsxDataset, - TwitterDataset, - VegasBarsReviewsDataset, + # AmazonDataset, + # ContactHighSchoolDataset, + # ContactPrimarySchoolDataset, + # CoraDataset, + # CourseraDataset, + # DBLPDataset, + # EmailEnronDataset, + # EmailW3CDataset, + # GeometryDataset, + # GOTDataset, + # IMDBDataset, + # MusicBluesReviewsDataset, + # NBADataset, + # NDCClassesDataset, + # NDCSubstancesDataset, + # PatentDataset, + # PubmedDataset, + # RestaurantReviewsDataset, + # ThreadsAskUbuntuDataset, + # ThreadsMathsxDataset, + # TwitterDataset, + # VegasBarsReviewsDataset, ) from .loader import DataLoader @@ -54,7 +52,8 @@ "EmailW3CDataset", "GeometryDataset", "GOTDataset", - "HIFConverter", + "HIFLoader", + "HIFProcessor", "HyperedgeSampler", "IMDBDataset", "MusicBluesReviewsDataset", diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 48b874c..ccd3b46 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -1,138 +1,25 @@ import json import os +import requests import tempfile import torch -import zstandard as zstd -import requests -import warnings -from enum import Enum -from huggingface_hub import hf_hub_download -from typing import Any, Dict, List, Optional + +from typing import Any, Dict, List, Optional, Literal from torch import Tensor from torch.utils.data import Dataset as TorchDataset -from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher -from hyperbench.types import HData, HIFHypergraph, HyperedgeIndex -from hyperbench.nn import EnrichmentMode -from hyperbench.utils import validate_hif_json - +from hyperbench.data.hif import HIFLoader, HIFProcessor from hyperbench.data.sampling import SamplingStrategy, create_sampler_from_strategy - - -class DatasetNames(Enum): - """ - Enumeration of available datasets. - """ - - ALGEBRA = "algebra" - AMAZON = "amazon" - CONTACT_HIGH_SCHOOL = "contact-high-school" - CONTACT_PRIMARY_SCHOOL = "contact-primary-school" - CORA = "cora" - COURSERA = "coursera" - DBLP = "dblp" - EMAIL_ENRON = "email-Enron" - EMAIL_W3C = "email-W3C" - GEOMETRY = "geometry" - GOT = "got" - IMDB = "imdb" - MUSIC_BLUES_REVIEWS = "music-blues-reviews" - NBA = "nba" - NDC_CLASSES = "NDC-classes" - NDC_SUBSTANCES = "NDC-substances" - PATENT = "patent" - PUBMED = "pubmed" - RESTAURANT_REVIEWS = "restaurant-reviews" - THREADS_ASK_UBUNTU = "threads-ask-ubuntu" - THREADS_MATH_SX = "threads-math-sx" - TWITTER = "twitter" - VEGAS_BARS_REVIEWS = "vegas-bars-reviews" - - -class HIFConverter: - """A utility class to load hypergraphs from HIF format.""" - - @staticmethod - def load_from_hif(dataset_name: Optional[str], save_on_disk: bool = False) -> HIFHypergraph: - if dataset_name is None: - raise ValueError(f"Dataset name (provided: {dataset_name}) must be provided.") - if dataset_name not in DatasetNames.__members__: - raise ValueError(f"Dataset '{dataset_name}' not found.") - - dataset_name = DatasetNames[dataset_name].value - current_dir = os.path.dirname(os.path.abspath(__file__)) - zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") - - if not os.path.exists(zst_filename): - github_dataset_repo = f"https://github.com/hypernetwork-research-group/datasets/blob/main/{dataset_name}.json.zst?raw=true" - - response = requests.get(github_dataset_repo) - if response.status_code != 200: - warnings.warn( - f"GitHub raw download failed for dataset '{dataset_name}' with status code {response.status_code}\n" - "Falling back to Hugging Face Hub download for dataset", - category=UserWarning, - stacklevel=2, - ) - - REPO_ID = f"HypernetworkRG/{dataset_name}" - FILENAME = f"{dataset_name}.json.zst" - - with tempfile.NamedTemporaryFile( - mode="wb", suffix=".json.zst", delete=False - ) as tmp_hf_file: - try: - downloaded_path = hf_hub_download( - repo_id=REPO_ID, - filename=FILENAME, - repo_type="dataset", - ) - except Exception as e: - raise ValueError( - f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {str(e)}" - ) - with open(downloaded_path, "rb") as hf_file: - hf_content = hf_file.read() - tmp_hf_file.write(hf_content) - - response._content = hf_content - - if save_on_disk: - os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) - with open(zst_filename, "wb") as f: - f.write(response.content) - else: - # Create temporary file for downloaded zst content - with tempfile.NamedTemporaryFile( - mode="wb", suffix=".json.zst", delete=False - ) as tmp_zst_file: - tmp_zst_file.write(response.content) - zst_filename = tmp_zst_file.name - - # Decompress the downloaded zst file - dctx = zstd.ZstdDecompressor() - with ( - open(zst_filename, "rb") as input_f, - tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, - ): - dctx.copy_stream(input_f, tmp_file) - output = tmp_file.name - - with open(output, "r") as f: - hiftext = json.load(f) - if not validate_hif_json(output): - raise ValueError(f"Dataset '{dataset_name}' is not HIF-compliant.") - - hypergraph = HIFHypergraph.from_hif(hiftext) - return hypergraph +from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher +from hyperbench.types import HData class Dataset(TorchDataset): """ A dataset class for loading and processing hypergraph data. - Attributes: + Args: DATASET_NAME: Class variable indicating the name of the dataset to load. hypergraph: The loaded hypergraph in HIF format. Can be ``None`` if initialized from an HData object. hdata: The processed hypergraph data in HData format. @@ -140,13 +27,10 @@ class Dataset(TorchDataset): If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. """ - DATASET_NAME = None - def __init__( self, hdata: Optional[HData] = None, sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, - prepare: bool = True, ) -> None: """ Initialize the Dataset. @@ -155,22 +39,11 @@ def __init__( hdata: Optional HData object to initialize the dataset with. If provided, the dataset will be initialized with this data instead of loading and processing from HIF. Must be provided if prepare is set to ``False``. sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. - prepare: Whether to load and process the original dataset from HIF format. - If set to ``False``, the dataset will be initialized with the provided hdata instead. Defaults to ``True``. """ - self.__is_prepared = prepare + self.__sampler = create_sampler_from_strategy(sampling_strategy) self.sampling_strategy = sampling_strategy - - if self.__is_prepared: - self.hypergraph = self.download() - self.hdata = self.process() - else: - if hdata is None: - raise ValueError("hdata must be provided when prepare is set to False.") - - self.hypergraph = HIFHypergraph.empty() - self.hdata = hdata + self.hdata = hdata if hdata is not None else HData.empty() def __len__(self) -> int: return self.__sampler.len(self.hdata) @@ -210,79 +83,49 @@ def from_hdata( Returns: The :class:`Dataset` instance with the provided :class:`HData`. """ - return cls(hdata=hdata, sampling_strategy=sampling_strategy, prepare=False) + return cls(hdata=hdata, sampling_strategy=sampling_strategy) - def download(self) -> HIFHypergraph: - """ - Load the hypergraph from HIF format using HIFConverter class. + @classmethod + def from_url( + cls, + url: str, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + save_on_disk: bool = False, + ) -> "Dataset": """ - if not self.__is_prepared: - raise ValueError("download can only be called for the original dataset (prepare=True).") + Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format. - if hasattr(self, "hypergraph") and self.hypergraph is not None: - return self.hypergraph + Args: + url: The URL to the .json or .json.zst file containing the HIF hypergraph data. + sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. + save_on_disk: Whether to save the downloaded file on disk. - return HIFConverter.load_from_hif(self.DATASET_NAME, save_on_disk=True) + Returns: + The :class:`Dataset` instance with the loaded hypergraph data. + """ + hdata = HIFLoader.load_from_url(url=url, save_on_disk=save_on_disk) + dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy) + return dataset - def process(self) -> HData: + @classmethod + def from_path( + cls, + filepath: str, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + ) -> "Dataset": """ - Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. + Create a :class:`Dataset` instance by loading a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format. + + Args: + filepath: The local file path to the .json or .json.zst file containing the HIF hypergraph data. + sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. Returns: - The processed hypergraph data. + The :class:`Dataset` instance with the loaded hypergraph data. """ - if not self.__is_prepared: - raise ValueError("process can only be called for the original dataset.") - - num_nodes = len(self.hypergraph.nodes) - x = self.__process_x(num_nodes) - - # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order - node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(self.hypergraph.nodes)} - # Initialize edge_set only with edges that have incidences, so that - # we avoid inflating edge count due to isolated nodes/missing incidences - hyperedge_id_to_idx: Dict[Any, int] = {} - - node_ids = [] - hyperedge_ids = [] - nodes_with_incidences = set() - for incidence in self.hypergraph.incidences: - node_id = incidence.get("node", 0) - hyperedge_id = incidence.get("edge", 0) - - if hyperedge_id not in hyperedge_id_to_idx: - # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences - hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx) - - node_ids.append(node_id_to_idx[node_id]) - hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id]) - nodes_with_incidences.add(node_id_to_idx[node_id]) - - # Handle isolated nodes by assigning them to a new unique hyperedge (self-loop) - for node_idx in range(num_nodes): - if node_idx not in nodes_with_incidences: - new_hyperedge_id = len(hyperedge_id_to_idx) - # Unique dummy key to reserve the index in hyperedge_set - hyperedge_id_to_idx[f"__self_loop_{node_idx}__"] = new_hyperedge_id - node_ids.append(node_idx) - hyperedge_ids.append(new_hyperedge_id) - - num_hyperedges = len(hyperedge_id_to_idx) - hyperedge_attr = self.__process_hyperedge_attr(hyperedge_id_to_idx, num_hyperedges) - - hyperedge_weights = self.__process_hyperedge_weights() - - hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) - - return HData( - x=x, - hyperedge_index=hyperedge_index, - hyperedge_weights=hyperedge_weights, - hyperedge_attr=hyperedge_attr, - num_nodes=num_nodes, - num_hyperedges=num_hyperedges, - global_node_ids=HyperedgeIndex(hyperedge_index).node_ids, - ) + hypergraph = HIFLoader.load_from_path(filepath=filepath) + dataset = cls.from_hdata(hdata=hypergraph, sampling_strategy=sampling_strategy) + return dataset def enrich_node_features( self, @@ -340,7 +183,7 @@ def update_from_hdata(self, hdata: HData) -> "Dataset": Returns: The :class:`Dataset` instance with the provided :class:`HData`. """ - return self.__class__(hdata=hdata, sampling_strategy=self.sampling_strategy, prepare=False) + return self.__class__(hdata=hdata, sampling_strategy=self.sampling_strategy) def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> None: """ @@ -428,7 +271,6 @@ def split( split_dataset = self.__class__( hdata=split_hdata, sampling_strategy=self.sampling_strategy, - prepare=False, ) split_datasets.append(split_dataset) @@ -449,70 +291,6 @@ def to(self, device: torch.device) -> "Dataset": self.hdata = self.hdata.to(device) return self - def transform_node_attrs( - self, - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - return self.transform_attrs(attrs, attr_keys) - - def transform_hyperedge_attrs( - self, - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - return self.transform_attrs(attrs, attr_keys) - - def transform_attrs( - self, - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - """ - Extract and encode numeric attributes to tensor. - Non-numeric attributes are discarded. Missing attributes are filled with ``0.0``. - - Args: - attrs: Dictionary of attributes - attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``. - - Returns: - Tensor of numeric attribute values - """ - numeric_attrs = { - key: value - for key, value in attrs.items() - if isinstance(value, (int, float)) and not isinstance(value, bool) - } - - if attr_keys is not None: - values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys] - return torch.tensor(values, dtype=torch.float) - - if not numeric_attrs: - return torch.tensor([], dtype=torch.float) - - values = [float(value) for value in numeric_attrs.values()] - return torch.tensor(values, dtype=torch.float) - - def __collect_attr_keys(self, attr_keys: List[Dict[str, Any]]) -> List[str]: - """ - Collect unique numeric attribute keys from a list of attribute dictionaries. - - Args: - attr_keys: List of attribute dictionaries. - - Returns: - List of unique numeric attribute keys. - """ - unique_keys = [] - for attrs in attr_keys: - for key, value in attrs.items(): - if key not in unique_keys and isinstance(value, (int, float)): - unique_keys.append(key) - - return unique_keys - def __get_hyperedge_ids_permutation( self, num_hyperedges: int, @@ -537,86 +315,17 @@ def __get_hyperedge_ids_permutation( ranged_hyperedge_ids_permutation = torch.arange(num_hyperedges, device=device) return ranged_hyperedge_ids_permutation - def __process_hyperedge_attr( - self, - hyperedge_id_to_idx: Dict[Any, int], - num_hyperedges: int, - ) -> Optional[Tensor]: - # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] - hyperedge_attr = None - has_hyperedges = ( - self.hypergraph.hyperedges is not None and len(self.hypergraph.hyperedges) > 0 - ) - has_any_hyperedge_attrs = has_hyperedges and any( - "attrs" in edge for edge in self.hypergraph.hyperedges - ) - - if has_any_hyperedge_attrs: - hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { - e.get("edge"): e.get("attrs", {}) for e in self.hypergraph.hyperedges - } - - hyperedge_attr_keys = self.__collect_attr_keys(list(hyperedge_id_to_attrs.values())) - - # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) - hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} - - attrs = [] - for hyperedge_idx in range(num_hyperedges): - hyperedge_id = hyperedge_idx_to_id[hyperedge_idx] - - transformed_attrs = self.transform_hyperedge_attrs( - # If it's a real hyperedge, get its attrs; if self-loop, get empty dict - attrs=hyperedge_id_to_attrs.get(hyperedge_id, {}), - attr_keys=hyperedge_attr_keys, - ) - attrs.append(transformed_attrs) - - hyperedge_attr = torch.stack(attrs) - - return hyperedge_attr - - def __process_x(self, num_nodes: int) -> Tensor: - # Collect all attribute keys to have tensors of same size - node_attr_keys = self.__collect_attr_keys( - [node.get("attrs", {}) for node in self.hypergraph.nodes] - ) - - if node_attr_keys: - x = torch.stack( - [ - self.transform_node_attrs(node.get("attrs", {}), attr_keys=node_attr_keys) - for node in self.hypergraph.nodes - ] - ) - else: - # Fallback to ones if no node features, 1 is better as it can help during - # training (e.g., avoid zero multiplication), especially in first epochs - x = torch.ones((num_nodes, 1), dtype=torch.float) - - return x # shape [num_nodes, num_node_features] - - def __process_hyperedge_weights(self) -> Optional[Tensor]: - # Initialize the hyperedge weights tensor - hyperedge_weights = None - - has_hyperedge_weights = self.hypergraph.hyperedges is not None and all( - "weight" in edge for edge in self.hypergraph.hyperedges - ) - - if has_hyperedge_weights: - weights = [edge.get("weight", 1.0) for edge in self.hypergraph.hyperedges] - hyperedge_weights = torch.tensor(weights, dtype=torch.float) - elif ( - has_hyperedge_weights is False - and self.hypergraph.hyperedges is not None - and any("weight" in edge for edge in self.hypergraph.hyperedges) - ): - raise ValueError( - "Some hyperedges have weights while others do not. All hyperedges must either have weights or none." - ) + def transform_node_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + return HIFProcessor.transform_attrs(attrs, attr_keys) - return hyperedge_weights + def transform_hyperedge_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + return HIFProcessor.transform_attrs(attrs, attr_keys) def stats(self) -> Dict[str, Any]: """ diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py new file mode 100644 index 0000000..4485061 --- /dev/null +++ b/hyperbench/data/hif.py @@ -0,0 +1,377 @@ +import json +import os +import requests +import tempfile +import torch +import warnings +import zstandard as zstd + +from huggingface_hub import hf_hub_download +from typing import Optional, Dict, Any, List +from torch import Tensor + +from hyperbench.types import HData, HIFHypergraph +from hyperbench.utils import ( + validate_hif_json, + decompress_zst, + compress_to_zst, + validate_http_url, + write_to_disk, +) + +GITHUB_COMMIT_SHA = "3879b2ce84750e54f984ca06ce3246dff22c71c7" + + +class HIFProcessor: + """A utility class to process HIF hypergraph data into :class:`HData` format.""" + + @staticmethod + def transform_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + """ + Extract and encode numeric attributes to tensor. + Non-numeric attributes are discarded. Missing attributes are filled with ``0.0``. + + Args: + attrs: Dictionary of attributes + attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``. + + Returns: + Tensor of numeric attribute values + """ + numeric_attrs = { + key: value + for key, value in attrs.items() + if isinstance(value, (int, float)) and not isinstance(value, bool) + } + + if attr_keys is not None: + values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys] + return torch.tensor(values, dtype=torch.float) + + if not numeric_attrs: + return torch.tensor([], dtype=torch.float) + + values = [float(value) for value in numeric_attrs.values()] + return torch.tensor(values, dtype=torch.float) + + @classmethod + def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData: + """ + Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. + + Returns: + The processed hypergraph data. + """ + # if not self.__is_prepared: + # raise ValueError("process can only be called for the original dataset.") + + num_nodes = len(hypergraph.nodes) + x = cls.__process_x(hypergraph, num_nodes) + + # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order + node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} + # Initialize edge_set only with edges that have incidences, so that + # we avoid inflating edge count due to isolated nodes/missing incidences + hyperedge_id_to_idx: Dict[Any, int] = {} + + node_ids = [] + hyperedge_ids = [] + nodes_with_incidences = set() + for incidence in hypergraph.incidences: + node_id = incidence.get("node", 0) + hyperedge_id = incidence.get("edge", 0) + + if hyperedge_id not in hyperedge_id_to_idx: + # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences + hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx) + + node_ids.append(node_id_to_idx[node_id]) + hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id]) + nodes_with_incidences.add(node_id_to_idx[node_id]) + + # Handle isolated nodes by assigning them to a new unique hyperedge (self-loop) + for node_idx in range(num_nodes): + if node_idx not in nodes_with_incidences: + new_hyperedge_id = len(hyperedge_id_to_idx) + # Unique dummy key to reserve the index in hyperedge_set + hyperedge_id_to_idx[f"__self_loop_{node_idx}__"] = new_hyperedge_id + node_ids.append(node_idx) + hyperedge_ids.append(new_hyperedge_id) + + num_hyperedges = len(hyperedge_id_to_idx) + hyperedge_attr = cls.__process_hyperedge_attr( + hypergraph=hypergraph, + hyperedge_id_to_idx=hyperedge_id_to_idx, + num_hyperedges=num_hyperedges, + ) + + hyperedge_weights = cls.__process_hyperedge_weights( + hypergraph=hypergraph, + hyperedge_id_to_idx=hyperedge_id_to_idx, + num_hyperedges=num_hyperedges, + ) + + hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) + + return HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_weights=hyperedge_weights, + hyperedge_attr=hyperedge_attr, + num_nodes=num_nodes, + num_hyperedges=num_hyperedges, + ) + + def __collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: + """ + Collect unique numeric attribute keys from a list of attribute dictionaries. + + Args: + attr_keys: List of attribute dictionaries. + + Returns: + List of unique numeric attribute keys. + """ + unique_keys = [] + for attrs in attr_keys: + for key, value in attrs.items(): + if key not in unique_keys and isinstance(value, (int, float)): + unique_keys.append(key) + + return unique_keys + + @classmethod + def __process_hyperedge_attr( + cls, + hypergraph: HIFHypergraph, + hyperedge_id_to_idx: Dict[Any, int], + num_hyperedges: int, + ) -> Optional[Tensor]: + # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] + hyperedge_attr = None + has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 + has_any_hyperedge_attrs = has_hyperedges and any( + "attrs" in edge for edge in hypergraph.hyperedges + ) + + if has_any_hyperedge_attrs: + hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { + e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges + } + + hyperedge_attr_keys = cls.__collect_attr_keys(list(hyperedge_id_to_attrs.values())) + + # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) + hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} + + attrs = [] + for hyperedge_idx in range(num_hyperedges): + hyperedge_id = hyperedge_idx_to_id[hyperedge_idx] + + transformed_attrs = HIFProcessor.transform_attrs( + # If it's a real hyperedge, get its attrs; if self-loop, get empty dict + attrs=hyperedge_id_to_attrs.get(hyperedge_id, {}), + attr_keys=hyperedge_attr_keys, + ) + attrs.append(transformed_attrs) + + hyperedge_attr = torch.stack(attrs) + + return hyperedge_attr + + @classmethod + def __process_x(cls, hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: + # Collect all attribute keys to have tensors of same size + node_attr_keys = cls.__collect_attr_keys( + [node.get("attrs", {}) for node in hypergraph.nodes] + ) + + if node_attr_keys: + x = torch.stack( + [ + HIFProcessor.transform_attrs(node.get("attrs", {}), attr_keys=node_attr_keys) + for node in hypergraph.nodes + ] + ) + else: + # Fallback to ones if no node features, 1 is better as it can help during + # training (e.g., avoid zero multiplication), especially in first epochs + x = torch.ones((num_nodes, 1), dtype=torch.float) + + return x # shape [num_nodes, num_node_features] + + @classmethod + def __process_hyperedge_weights( + cls, + hypergraph: HIFHypergraph, + hyperedge_id_to_idx: Dict[Any, int], + num_hyperedges: int, + ) -> Optional[Tensor]: + has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 + has_any_hyperedge_attrs = has_hyperedges and any( + "attrs" in edge for edge in hypergraph.hyperedges + ) + + # Keep old behavior for fixtures where edges have no attrs at all. + if not has_any_hyperedge_attrs: + return None + + # Map real edge id -> attrs (self-loops are absent and will default to 1.0) + hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { + e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges + } + + # Build in exact hyperedge index order, defaulting missing weights to 1.0. + hyperedge_idx_to_id = {idx: edge_id for edge_id, idx in hyperedge_id_to_idx.items()} + weights = [] + for hyperedge_idx in range(num_hyperedges): + edge_id = hyperedge_idx_to_id[hyperedge_idx] + edge_attrs = hyperedge_id_to_attrs.get(edge_id, {}) + weights.append(float(edge_attrs.get("weight", 1.0))) + + return torch.tensor(weights, dtype=torch.float) + + +class HIFLoader: + """A utility class to load hypergraphs from HIF format.""" + + def load_from_url(url: str, save_on_disk: bool = False) -> HData: + """ + Load a hypergraph from a given URL pointing to a .json or .json.zst file in HIF format. + Args: + url (str): The URL to the .json or .json.zst file containing the HIF hypergraph data. + save_on_disk (bool): Whether to save the downloaded file on disk. + Returns: + HData: The loaded hypergraph object. + """ + url = validate_http_url(url) + + response = requests.get(url, timeout=20) + if response.status_code != 200: + raise ValueError( + f"Failed to download dataset from URL '{url}' with status code {response.status_code}" + ) + + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_zst_file: + tmp_zst_file.write(response.content) + zst_filename = tmp_zst_file.name + + if zst_filename.endswith(".zst"): + if save_on_disk: + write_to_disk(os.path.basename(url), response.content) + output = decompress_zst(zst_filename) + elif zst_filename.endswith(".json"): + if save_on_disk: + compressed = compress_to_zst(zst_filename) + write_to_disk(os.path.basename(url), compressed) + output = zst_filename + else: + raise ValueError( + f"Unsupported file format for URL '{url}'. Expected .json or .json.zst" + ) + + hypergraph = HIFLoader.__extract_hif(output) + hdata = HIFProcessor.process_hypergraph(hypergraph) + return hdata + + def load_from_path(filepath: str) -> HData: + """ + Load a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format. + Args: + filepath (str): The local file path to the .json or .json.zst file + containing the HIF hypergraph data. + Returns: + HData: The loaded hypergraph object. + """ + if not os.path.exists(filepath): + raise ValueError(f"File '{filepath}' does not exist.") + + if filepath.endswith(".zst"): + output = decompress_zst(filepath) + elif filepath.endswith(".json"): + output = filepath + else: + raise ValueError( + f"Unsupported file format for filepath '{filepath}'. Expected .json or .json.zst" + ) + + hypergraph = HIFLoader.__extract_hif(output) + hdata = HIFProcessor.process_hypergraph(hypergraph) + return hdata + + def load_by_name( + dataset_name: str, hf_sha: Optional[str] = None, save_on_disk: bool = False + ) -> HData: + current_dir = os.path.dirname(os.path.abspath(__file__)) + zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") + + if not os.path.exists(zst_filename): + github_url = f"https://raw.githubusercontent.com/hypernetwork-research-group/datasets/{GITHUB_COMMIT_SHA}/{dataset_name}.json.zst" + response = requests.get(github_url, timeout=20) + if response.status_code != 200: + warnings.warn( + f"GitHub raw download failed for dataset '{dataset_name}' with status code {response.status_code}\n" + "Falling back to Hugging Face Hub download for dataset", + category=UserWarning, + stacklevel=2, + ) + + REPO_ID = f"HypernetworkRG/{dataset_name}" + FILENAME = f"{dataset_name}.json.zst" + + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_hf_file: + if hf_sha is not None: + try: + downloaded_path = hf_hub_download( + repo_id=REPO_ID, + filename=FILENAME, + repo_type="dataset", + revision=hf_sha, + ) + except Exception as e: + raise ValueError( + f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {str(e)}" + ) + else: + raise ValueError( + f"Failed to download dataset '{dataset_name}' from GitHub with status code {response.status_code} and no SHA provided for Hugging Face Hub fallback." + ) + + with open(downloaded_path, "rb") as hf_file: + hf_content = hf_file.read() + tmp_hf_file.write(hf_content) + + response._content = hf_content + + if save_on_disk: + os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) + with open(zst_filename, "wb") as f: + f.write(response.content) + else: + # Create temporary file for downloaded zst content + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_zst_file: + tmp_zst_file.write(response.content) + zst_filename = tmp_zst_file.name + + output = decompress_zst(zst_filename) + hypergraph = HIFLoader.__extract_hif(output) + hdata = HIFProcessor.process_hypergraph(hypergraph) + return hdata + + @staticmethod + def __extract_hif(json_file: str) -> HIFHypergraph: + with open(json_file, "r") as f: + hiftext = json.load(f) + if not validate_hif_json(json_file): + raise ValueError(f"Dataset from file '{json_file}' is not HIF-compliant.") + hypergraph = HIFHypergraph.from_hif(hiftext) + return hypergraph diff --git a/hyperbench/data/supported_datasets.py b/hyperbench/data/supported_datasets.py index 443f043..8cb178f 100644 --- a/hyperbench/data/supported_datasets.py +++ b/hyperbench/data/supported_datasets.py @@ -1,93 +1,142 @@ +from hyperbench.data import HIFLoader from hyperbench.data.dataset import Dataset +from hyperbench.data.sampling import SamplingStrategy -class AlgebraDataset(Dataset): - DATASET_NAME = "ALGEBRA" +class PreloadedDataset(Dataset): + """ + Base class for datasets that use default loading. Subclasses should specify the DATASET_NAME class variable. + The dataset will be saved on disk after the first load. + Args: + hdata: Optional HData object. If None, the dataset will be loaded using the DATASET_NAME. + sampling_strategy: The sampling strategy to use for this dataset. Default is SamplingStrategy.HYPEREDGE. + """ + DATASET_NAME = "" + HF_SHA = None -class AmazonDataset(Dataset): - DATASET_NAME = "AMAZON" + def __init__( + self, + hdata=None, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + ) -> None: + super().__init__(hdata=hdata, sampling_strategy=sampling_strategy) + if hdata is None: + self.hdata = HIFLoader.load_by_name( + self.DATASET_NAME, hf_sha=self.HF_SHA, save_on_disk=True + ) -class ContactHighSchoolDataset(Dataset): - DATASET_NAME = "CONTACT_HIGH_SCHOOL" +class AlgebraDataset(PreloadedDataset): + DATASET_NAME = "algebra" + HF_SHA = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" -class ContactPrimarySchoolDataset(Dataset): - DATASET_NAME = "CONTACT_PRIMARY_SCHOOL" +class AmazonDataset(PreloadedDataset): + DATASET_NAME = "amazon" + HF_SHA = "614f75d1847d233ee06da0cc3ee10f51220b8243" -class CoraDataset(Dataset): - DATASET_NAME = "CORA" +class ContactHighSchoolDataset(PreloadedDataset): + DATASET_NAME = "contact-high-school" + HF_SHA = "b991fde34631a357961a244a5c4d734cf3093199" -class CourseraDataset(Dataset): - DATASET_NAME = "COURSERA" +class ContactPrimarySchoolDataset(PreloadedDataset): + DATASET_NAME = "contact-primary-school" + HF_SHA = "f6f5453777d1fc62f6305b17d131ec1e32cdbe66" -class DBLPDataset(Dataset): - DATASET_NAME = "DBLP" +class CoraDataset(PreloadedDataset): + DATASET_NAME = "cora" + HF_SHA = "91fda9ed324e2cce2430638747e9b032bd9c22ad" -class EmailEnronDataset(Dataset): - DATASET_NAME = "EMAIL_ENRON" +class CourseraDataset(PreloadedDataset): + DATASET_NAME = "coursera" + HF_SHA = "e68679a01af61c43292575839e451eb0bbeee202" -class EmailW3CDataset(Dataset): - DATASET_NAME = "EMAIL_W3C" +class DBLPDataset(PreloadedDataset): + DATASET_NAME = "dblp" + HF_SHA = "151c360ed77042abebb9709fd3d738763d5c5044" -class GeometryDataset(Dataset): - DATASET_NAME = "GEOMETRY" +class EmailEnronDataset(PreloadedDataset): + DATASET_NAME = "email-Enron" + HF_SHA = "05247a5441a6a337cdccee24c0060255815905be" -class GOTDataset(Dataset): - DATASET_NAME = "GOT" +class EmailW3CDataset(PreloadedDataset): + DATASET_NAME = "email-W3C" + HF_SHA = "18b8c795504388c1d075ffcea7eada281ec5e416" -class IMDBDataset(Dataset): - DATASET_NAME = "IMDB" +class GeometryDataset(PreloadedDataset): + DATASET_NAME = "geometry" + HF_SHA = "49a8647d6ff7361485c953949010155b0b522a12" -class MusicBluesReviewsDataset(Dataset): - DATASET_NAME = "MUSIC_BLUES_REVIEWS" +class GOTDataset(PreloadedDataset): + DATASET_NAME = "got" + HF_SHA = "2efb505e5d82457f6e5ba21820c8d8f2298f0ece" -class NBADataset(Dataset): - DATASET_NAME = "NBA" +class IMDBDataset(PreloadedDataset): + DATASET_NAME = "imdb" + HF_SHA = "c3a583313d1611b292933d77e725b11be2c39a05" -class NDCClassesDataset(Dataset): - DATASET_NAME = "NDC_CLASSES" +class MusicBluesReviewsDataset(PreloadedDataset): + DATASET_NAME = "music-blues-reviews" + HF_SHA = "7d218b727097ed007e7f368ab91c064b3eeff184" -class NDCSubstancesDataset(Dataset): - DATASET_NAME = "NDC_SUBSTANCES" +class NBADataset(PreloadedDataset): + DATASET_NAME = "nba" + HF_SHA = "5b3b1c7e425bc407bc0843f443cdf889b51e1ca7" -class PatentDataset(Dataset): - DATASET_NAME = "PATENT" +class NDCClassesDataset(PreloadedDataset): + DATASET_NAME = "NDC-classes" + HF_SHA = "c9bb31897646fb3f964ee4affe126f9885954d92" -class PubmedDataset(Dataset): - DATASET_NAME = "PUBMED" +class NDCSubstancesDataset(PreloadedDataset): + DATASET_NAME = "NDC-substances" + HF_SHA = "bbdde0839ca5913a2535e6fe3ce397b990803af9" -class RestaurantReviewsDataset(Dataset): - DATASET_NAME = "RESTAURANT_REVIEWS" +class PatentDataset(PreloadedDataset): + DATASET_NAME = "patent" + HF_SHA = "608b4fab97d17adbc01b0b4636b060a550231307" -class ThreadsAskUbuntuDataset(Dataset): - DATASET_NAME = "THREADS_ASK_UBUNTU" +class PubmedDataset(PreloadedDataset): + DATASET_NAME = "pubmed" + HF_SHA = "b8f846a3c812b3b23f10bd69f65f739983f6a390" -class ThreadsMathsxDataset(Dataset): - DATASET_NAME = "THREADS_MATH_SX" +class RestaurantReviewsDataset(PreloadedDataset): + DATASET_NAME = "restaurant-reviews" + HF_SHA = "668a90391fcb968c786da7bc9e7bbc55e2832066" -class TwitterDataset(Dataset): - DATASET_NAME = "TWITTER" +class ThreadsAskUbuntuDataset(PreloadedDataset): + DATASET_NAME = "threads-ask-ubuntu" + HF_SHA = "704c54c7f21b4e313ab6bb50bcd30f58ade469b6" -class VegasBarsReviewsDataset(Dataset): - DATASET_NAME = "VEGAS_BARS_REVIEWS" +class ThreadsMathsxDataset(PreloadedDataset): + DATASET_NAME = "threads-math-sx" + HF_SHA = "b024111c16fdb266e159a4c647ff1a31ec40db5b" + + +class TwitterDataset(PreloadedDataset): + DATASET_NAME = "twitter" + HF_SHA = "d93c55af8e04cf70d65ed0059325009a21699a25" + + +class VegasBarsReviewsDataset(PreloadedDataset): + DATASET_NAME = "vegas-bars-reviews" + HF_SHA = "4f1e4e4c87957679efc38c05129a694d315a8c9b" diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 7c44c01..a0f2ebb 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -1,12 +1,11 @@ import pytest -import requests -import tempfile import torch -from unittest.mock import patch, mock_open, MagicMock -from hyperbench.data import AlgebraDataset, Dataset, HIFConverter, SamplingStrategy -from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher +from unittest.mock import patch, MagicMock +from hyperbench.data import AlgebraDataset, Dataset, HIFLoader, SamplingStrategy +from hyperbench.nn import NodeEnricher, HyperedgeEnricher from hyperbench.types import HData, HIFHypergraph +from hyperbench.data.supported_datasets import PreloadedDataset @pytest.fixture @@ -33,381 +32,89 @@ def mock_hdata_with_hyperedge_weights() -> HData: @pytest.fixture -def mock_sample_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) +def mock_hdata_sample_hypergraph() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) @pytest.fixture -def mock_simple_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) +def mock_hdata_simple_hypergraph() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) @pytest.fixture -def mock_three_node_weighted_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0}}, - {"edge": "1", "attrs": {"weight": 2.0}}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - ], - ) +def mock_hdata_three_node_weighted_hypergraph() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + hyperedge_weights = torch.tensor([[1.0], [2.0]], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) @pytest.fixture -def mock_four_node_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - {"node": "3", "edge": "1"}, - ], - ) +def mock_hdata_four_node_hypergraph() -> HData: + x = torch.ones((4, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) @pytest.fixture -def mock_five_node_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - {"node": "4", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) +def mock_hdata_no_edge_attr_hypergraph() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) @pytest.fixture -def mock_no_edge_attr_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - ], - hyperedges=[{"edge": "0"}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - ], - ) +def mock_hdata_multiple_edges_attr_hypergraph() -> HData: + x = torch.ones((4, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) + hyperedge_weights = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) @pytest.fixture -def mock_multiple_edges_attr_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0}}, - {"edge": "1", "attrs": {"weight": 2.0}}, - {"edge": "2", "attrs": {"weight": 3.0}}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - {"node": "3", "edge": "2"}, - ], - ) - - -def test_HIFConverter_num_nodes_and_edges(): - dataset_name = "ALGEBRA" - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": str(i)} for i in range(20)], - hyperedges=[{"edge": str(i)} for i in range(30)], - incidences=[{"node": "0", "edge": "0"}], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - hypergraph = HIFConverter.load_from_hif(dataset_name) - - assert hypergraph is not None - assert hasattr(hypergraph, "nodes") - assert hasattr(hypergraph, "hyperedges") - assert hasattr(hypergraph, "incidences") - assert hasattr(hypergraph, "metadata") - assert hasattr(hypergraph, "network_type") - - assert hypergraph.num_nodes == 20 - assert hypergraph.num_hyperedges == 30 - - -def test_HIFConverter_loads_invalid_dataset(): - dataset_name = "INVALID_DATASET" - - with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): - HIFConverter.load_from_hif(dataset_name) - - -def test_HIFConverter_loads_invalid_hif_format(): - dataset_name = "ALGEBRA" - - invalid_hif_json = '{"network-type": "undirected", "nodes": []}' - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.validate_hif_json", return_value=False), - patch("builtins.open", mock_open(read_data=invalid_hif_json)), - patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), - ): - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.content = b"mock_zst_content" - - with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): - HIFConverter.load_from_hif(dataset_name) +def mock_hdata_no_incidences() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -def test_HIFConverter_stores_on_disk_when_save_on_disk_true(): - dataset_name = "ALGEBRA" +@pytest.fixture +def mock_hdata_with_two_edge_attributes() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + hyperedge_weights = torch.tensor([[1.0, 2.0], [3.0, 0.1]], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) - mock_hif_json = { - "network-type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - patch("hyperbench.data.dataset.os.makedirs"), - patch("builtins.open", mock_open()) as mock_file, - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - # Mock successful download - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.content = b"mock_zst_content" +@pytest.fixture +def mock_hdata_random_ids() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) - # Mock decompressor - mock_stream = mock_decomp.return_value.stream_reader.return_value - mock_stream.__enter__ = lambda self: mock_stream - mock_stream.__exit__ = lambda self, *args: None - hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) +def test_Preloaded_dataset_init(): + mock_hdata = MagicMock(spec=HData) + dataset = PreloadedDataset(hdata=mock_hdata) - assert hypergraph is not None - assert hypergraph.network_type == "undirected" - mock_get.assert_called_once() - # Verify file was written to disk (not temp file) - assert mock_file.call_count >= 2 # Once for write, once for read + assert dataset.hdata == mock_hdata + assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE -def test_HIFConverter_uses_temp_file_when_save_on_disk_false(): - dataset_name = "ALGEBRA" +def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): + mock_hdata = MagicMock(spec=HData) + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata) as mock_load: + dataset = AlgebraDataset(hdata=None) - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], + assert dataset.hdata == mock_hdata + mock_load.assert_called_once_with( + "algebra", hf_sha="2bb641461e00c103fb5ef4fe6a30aad42500fc21", save_on_disk=True ) - mock_hif_json = { - "network-type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, - patch("builtins.open", mock_open()), - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - # Mock successful download - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.content = b"mock_zst_content" - - # Mock temp file - mock_temp_file = mock_temp.return_value.__enter__.return_value - mock_temp_file.name = "/tmp/fake_temp.json.zst" - - # Mock decompressor - mock_stream = mock_decomp.return_value.stream_reader.return_value - mock_stream.__enter__ = lambda self: mock_stream - mock_stream.__exit__ = lambda self, *args: None - - hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=False) - - assert hypergraph is not None - assert hypergraph.network_type == "undirected" - mock_get.assert_called_once() - # Verify temp file was used - assert mock_temp.call_count >= 1 - - -def test_HIFConverter_download_failure(): - dataset_name = "ALGEBRA" - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.hf_hub_download", side_effect=Exception("HFHub failed")), - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - ): - # Mock failed download - mock_response = mock_get.return_value - mock_response.status_code = 404 - mock_response.content = b"" - - with pytest.warns( - UserWarning, - match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", - ): - with pytest.raises( - ValueError, - match=r"Failed to download dataset 'algebra'", - ): - HIFConverter.load_from_hif(dataset_name) - - -def test_HIFConverter_falls_back_to_hf_hub_download_when_github_raw_download_fails( - tmp_path, mock_sample_hypergraph -): - dataset_name = "ALGEBRA" - - mock_hypergraph = mock_sample_hypergraph - - mock_hif_json = { - "network_type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - fallback_file = tmp_path / "algebra.json.zst" - fallback_file.write_bytes(b"mock_zst_content") - - created_temp_files = [] - original_named_tempfile = tempfile.NamedTemporaryFile - - def named_tempfile_side_effect(*args, **kwargs): - temp_file = original_named_tempfile(*args, **kwargs) - created_temp_files.append(temp_file) - return temp_file - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch( - "hyperbench.data.dataset.hf_hub_download", - return_value=str(fallback_file), - ) as mock_hf_hub_download, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - patch( - "hyperbench.data.dataset.tempfile.NamedTemporaryFile", - side_effect=named_tempfile_side_effect, - ), - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - mock_response = mock_get.return_value - mock_response.status_code = 404 - mock_response.content = b"" - - def fake_copy_stream(src, dst): - dst.write(b'{"network_type":"undirected","nodes":[],"edges":[],"incidences":[]}') - return - - mock_decomp.return_value.copy_stream.side_effect = fake_copy_stream - - with pytest.warns( - UserWarning, - match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", - ): - hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=False) - - assert hypergraph.network_type == "undirected" - mock_get.assert_called_once() - mock_hf_hub_download.assert_called_once() - assert created_temp_files[0].name is not None - assert fallback_file.read_bytes() == b"mock_zst_content" - - -def test_HIFConverter_download_raises_when_network_error(): - dataset_name = "ALGEBRA" - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - ): - # Mock network error - mock_get.side_effect = requests.RequestException("Network error") - - with pytest.raises(requests.RequestException, match="Network error"): - HIFConverter.load_from_hif(dataset_name) - - -def test_init_with_prepare_false_and_no_hdata_raises(): - with pytest.raises(ValueError, match="hdata must be provided when prepare is set to False."): - Dataset(hdata=None, prepare=False) - - -def test_dataset_is_not_available(): - class FakeMockDataset(Dataset): - DATASET_NAME = "FAKE" - - with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): - FakeMockDataset() - @pytest.mark.parametrize( "strategy, expected_len", @@ -417,95 +124,44 @@ class FakeMockDataset(Dataset): ], ) def test_dataset_is_available_with_all_strategies( - strategy, expected_len, mock_four_node_hypergraph + strategy, expected_len, mock_hdata_four_node_hypergraph ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) - assert dataset.DATASET_NAME == "ALGEBRA" - assert dataset.hypergraph is not None + assert dataset.DATASET_NAME == "algebra" assert len(dataset) == expected_len -def test_download_already_downloaded_dataset_uses_local_value(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() - - hg1 = dataset.download() - hg2 = dataset.download() - - assert hg1 is hg2 - - -def test_throw_when_dataset_name_is_none(): - class FakeMockDataset(Dataset): - DATASET_NAME = None - - with pytest.raises( - ValueError, - match=r"Dataset name \(provided: None\) must be provided\.", - ): - FakeMockDataset() - - -def test_dataset_process_no_incidences(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +def test_dataset_process_no_incidences(mock_hdata_no_incidences): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_incidences): dataset = AlgebraDataset() assert dataset.hdata is not None assert dataset.hdata.x.shape[0] == 2 assert dataset.hdata.hyperedge_index.shape[0] == 2 assert dataset.hdata.hyperedge_index.shape[1] == 2 - assert dataset.hdata.hyperedge_attr is not None - assert dataset.hdata.hyperedge_attr.shape == (2, 0) - assert dataset.hdata.hyperedge_attr[0].shape == (0,) - assert dataset.hdata.hyperedge_attr[1].shape == (0,) - - -def test_dataset_process_with_edge_attributes(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0, "type": 2.0}}, - {"edge": "1", "attrs": {"weight": 3.0, "type": 0.1}}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - ], - ) + assert dataset.hdata.hyperedge_attr is None + - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +def test_dataset_process_with_edge_attributes(mock_hdata_with_two_edge_attributes): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_two_edge_attributes): dataset = AlgebraDataset() assert dataset.hdata is not None assert dataset.hdata.x.shape[0] == 3 assert dataset.hdata.hyperedge_index.shape[0] == 2 assert dataset.hdata.hyperedge_index.shape[1] == 3 - assert dataset.hdata.hyperedge_attr is not None - # Two edges with two attributes each: shape [2, 2] - assert dataset.hdata.hyperedge_attr.shape == (2, 2) - # Attributes maintain dictionary insertion order (no sorting) - - assert torch.allclose(dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 2.0])) # weight, type - assert torch.allclose(dataset.hdata.hyperedge_attr[1], torch.tensor([3.0, 0.1])) # weight, type + assert dataset.hdata.hyperedge_attr is None + assert dataset.hdata.hyperedge_weights is not None + assert dataset.hdata.hyperedge_weights.shape == (2, 2) + assert torch.allclose(dataset.hdata.hyperedge_weights[0], torch.tensor([1.0, 2.0])) + assert torch.allclose(dataset.hdata.hyperedge_weights[1], torch.tensor([3.0, 0.1])) -def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): +def test_dataset_process_without_edge_attributes(mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_edge_attr_hypergraph): dataset = AlgebraDataset() assert dataset.hdata is not None @@ -514,8 +170,8 @@ def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): assert dataset.hdata.hyperedge_attr is None -def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): +def test_dataset_process_hyperedge_index_in_correct_format(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() assert dataset.hdata.hyperedge_index.shape == (2, 4) @@ -523,30 +179,14 @@ def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hyperg assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) -def test_dataset_process_random_ids(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "abc", "attrs": {}}, - {"node": "ss", "attrs": {}}, - {"node": "fewao", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], - incidences=[ - {"node": "abc", "edge": "0"}, - {"node": "ss", "edge": "0"}, - {"node": "fewao", "edge": "1"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +def test_dataset_process_random_ids(mock_hdata_random_ids): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_random_ids): dataset = AlgebraDataset() assert dataset.hdata.hyperedge_index.shape == (2, 3) assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) - assert dataset.hdata.hyperedge_attr is not None - assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each + assert dataset.hdata.hyperedge_attr is None @pytest.mark.parametrize( @@ -556,8 +196,8 @@ def test_dataset_process_random_ids(): pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_simple_hypergraph): +def test_getitem_index_list_empty(mock_hdata_simple_hypergraph, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_simple_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(ValueError, match="Index list cannot be empty."): @@ -582,9 +222,9 @@ def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): ], ) def test_getitem_raises_when_index_list_larger_than_max( - mock_four_node_hypergraph, strategy, index_list, expected_message + mock_hdata_four_node_hypergraph, strategy, index_list, expected_message ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(ValueError, match=expected_message): @@ -606,9 +246,9 @@ def test_getitem_raises_when_index_list_larger_than_max( ], ) def test_getitem_raises_when_index_out_of_bounds( - mock_four_node_hypergraph, strategy, index, expected_message + mock_hdata_four_node_hypergraph, strategy, index, expected_message ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(IndexError, match=expected_message): @@ -625,9 +265,9 @@ def test_getitem_raises_when_index_out_of_bounds( ], ) def test_getitem_single_index( - mock_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges + mock_hdata_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_sample_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_sample_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[index] @@ -646,9 +286,9 @@ def test_getitem_single_index( ], ) def test_getitem_when_list_index_provided( - mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges + mock_hdata_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[index] @@ -664,9 +304,9 @@ def test_getitem_when_list_index_provided( pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): +def test_getitem_with_edge_attr(mock_hdata_three_node_weighted_hypergraph, strategy): with patch.object( - HIFConverter, "load_from_hif", return_value=mock_three_node_weighted_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_three_node_weighted_hypergraph ): dataset = AlgebraDataset(sampling_strategy=strategy) @@ -684,8 +324,8 @@ def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): +def test_getitem_without_edge_attr(mock_hdata_no_edge_attr_hypergraph, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_edge_attr_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] @@ -701,9 +341,11 @@ def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), ], ) -def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index): +def test_getitem_with_multiple_edges_attr( + mock_hdata_multiple_edges_attr_hypergraph, strategy, index +): with patch.object( - HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_multiple_edges_attr_hypergraph ): dataset = AlgebraDataset(sampling_strategy=strategy) @@ -715,245 +357,6 @@ def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, s assert data.hyperedge_attr is None -def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0, "abc": 5.0}}, - {"edge": "1", "attrs": {"weight": 2.0}}, # Missing 'abc' - {"edge": "2", "attrs": {"abc": 3.0}}, # Missing 'weight' - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - {"node": "3", "edge": "2"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata.hyperedge_attr is not None - assert dataset.hdata.hyperedge_attr.shape == ( - 3, - 2, - ) # 3 edges, 2 features each (weight, abc in insertion order) - assert torch.allclose( - dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 5.0]) - ) # weight=1.0, abc=5.0 - assert torch.allclose( - dataset.hdata.hyperedge_attr[1], torch.tensor([2.0, 0.0]) - ) # weight=2.0, abc=0.0 - assert torch.allclose( - dataset.hdata.hyperedge_attr[2], torch.tensor([0.0, 3.0]) - ) # weight=0.0, abc=3.0 - - -def test_process_not_all_hyperedge_weights_(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "weight": 1.5}, - {"edge": "1"}, - {"edge": "2", "weight": 2.5}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "1"}, - {"node": "2", "edge": "2"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - with pytest.raises( - ValueError, - match="Some hyperedges have weights while others do not. All hyperedges must either have weights or none.", - ): - dataset = AlgebraDataset() - - -def test_process_extracts_top_level_hyperedge_weights(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "weight": 1.5}, - {"edge": "1", "weight": 3.0}, - {"edge": "2", "weight": 2.5}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "1"}, - {"node": "2", "edge": "2"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - hyperedge_weights = dataset.hdata.hyperedge_weights - assert hyperedge_weights is not None - assert torch.allclose(hyperedge_weights[0], torch.tensor([1.5])) - assert torch.allclose(hyperedge_weights[1], torch.tensor([3.0])) - assert torch.allclose(hyperedge_weights[2], torch.tensor([2.5])) - - -def test_transform_attrs_empty_attrs(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - result = dataset.transform_attrs({}) - assert len(result) == 0 - - attrs = {"name": "node1", "active": True} - result = dataset.transform_attrs(attrs) - assert len(result) == 0 - - -def test_process_adds_padding_zero_when_inconsistent_node_attributes(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"weight": 1.0}}, # Missing 'score' - {"node": "1", "attrs": {"weight": 2.0, "score": 0.8}}, - {"node": "2", "attrs": {"score": 0.5}}, # Missing 'weight' - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "0"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - assert dataset.hdata.x.shape == ( - 3, - 2, - ) # 3 nodes, 2 features each (weight, score in insertion order) - assert torch.allclose(dataset.hdata.x[0], torch.tensor([1.0, 0.0])) # weight=1.0, score=0.0 - assert torch.allclose(dataset.hdata.x[1], torch.tensor([2.0, 0.8])) # weight=2.0, score=0.8 - assert torch.allclose(dataset.hdata.x[2], torch.tensor([0.0, 0.5])) # weight=0.0, score=0.5 - - -def test_process_with_no_node_attributes_fallback_to_one(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"name": "node0"}}, - {"node": "1", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - assert dataset.hdata.x.shape == (2, 1) - assert torch.allclose(dataset.hdata.x, torch.tensor([[1.0], [1.0]])) - - -def test_process_with_single_node_attribute(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"weight": 1.5}}, - {"node": "1", "attrs": {"weight": 2.5}}, - {"node": "2", "attrs": {"weight": 3.5}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "0"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - # Single attribute should remain 2D: [num_nodes, 1] - assert dataset.hdata.x.shape == (3, 1) - assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) - - -def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - # Test with attr_keys - should pad missing attributes with 0.0 - attrs = {"weight": 1.5} - result = dataset.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) - assert torch.allclose( - result, torch.tensor([0.0, 1.5, 0.0]) - ) # score=0.0, weight=1.5, age=0.0 - - # Test with all attributes present - attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} - result = dataset.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) - assert torch.allclose( - result, torch.tensor([25.0, 0.8, 1.5]) - ) # age=25.0, score=0.8, weight=1.5 - - # Test without attr_keys - maintains insertion order - attrs = {"weight": 1.5, "score": 0.8} - result = dataset.transform_attrs(attrs) - assert torch.allclose(result, torch.tensor([1.5, 0.8])) # weight, score (insertion order) - - @pytest.mark.parametrize( "strategy, expected_len", [ @@ -969,18 +372,40 @@ def test_from_hdata(strategy, expected_len, mock_hdata): assert len(dataset) == expected_len -def test_from_hdata_download_raises(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + ], +) +def test_from_url(strategy, mock_hdata): + url = "https://example.com/sample.json.zst" - with pytest.raises(ValueError, match="download can only be called for the original dataset."): - dataset.download() + with patch.object(HIFLoader, "load_from_url", return_value=mock_hdata) as mock_load_from_url: + dataset = Dataset.from_url(url=url, sampling_strategy=strategy, save_on_disk=True) + mock_load_from_url.assert_called_once_with(url=url, save_on_disk=True) + assert dataset.hdata is mock_hdata + assert dataset.sampling_strategy == strategy -def test_from_hdata_process_raises(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) - with pytest.raises(ValueError, match="process can only be called for the original dataset."): - dataset.process() +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + ], +) +def test_from_path(strategy, mock_hdata): + filepath = "/abc/sample.json.zst" + + with patch.object(HIFLoader, "load_from_path", return_value=mock_hdata) as mock_load_from_path: + dataset = Dataset.from_path(filepath=filepath, sampling_strategy=strategy) + + mock_load_from_path.assert_called_once_with(filepath=filepath) + assert dataset.hdata is mock_hdata + assert dataset.sampling_strategy == strategy def test_enrich_node_features_replace(mock_hdata): @@ -1133,8 +558,8 @@ def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_ assert dataset.hdata.y.shape[0] == expected_num_hyperedges -def test_split_with_equal_ratios(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): +def test_split_with_equal_ratios(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5]) @@ -1150,9 +575,9 @@ def test_split_with_equal_ratios(mock_four_node_hypergraph): assert split.hdata.num_hyperedges > 0 -def test_split_three_way(mock_multiple_edges_attr_hypergraph): +def test_split_three_way(mock_hdata_multiple_edges_attr_hypergraph): with patch.object( - HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_multiple_edges_attr_hypergraph ): dataset = AlgebraDataset() @@ -1168,8 +593,8 @@ def test_split_three_way(mock_multiple_edges_attr_hypergraph): assert split.hdata.num_hyperedges > 0 -def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): +def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): @@ -1177,9 +602,9 @@ def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( - mock_four_node_hypergraph, + mock_hdata_four_node_hypergraph, ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) @@ -1190,9 +615,9 @@ def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( def test_split_with_shuffle_when_no_seed_provided( - mock_four_node_hypergraph, + mock_hdata_four_node_hypergraph, ): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5], shuffle=True) @@ -1207,21 +632,21 @@ def test_split_with_shuffle_when_no_seed_provided( assert split.hdata.num_hyperedges > 0 -def test_split_preserves_edge_attr(mock_multiple_edges_attr_hypergraph): +def test_split_preserves_edge_attr(mock_hdata_multiple_edges_attr_hypergraph): with patch.object( - HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_multiple_edges_attr_hypergraph ): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5]) for split in splits: - assert split.hdata.hyperedge_attr is not None - assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges + assert split.hdata.hyperedge_weights is not None + assert split.hdata.hyperedge_weights.shape[0] == split.hdata.num_hyperedges -def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): +def test_split_without_edge_attr(mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_edge_attr_hypergraph): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5]) @@ -1241,46 +666,8 @@ def test_to_device(mock_hdata): assert dataset.hdata.device == device -def test_load_from_hif_skips_download_when_file_exists(): - dataset_name = "ALGEBRA" - - sample_hif = { - "network-type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=True), - patch("builtins.open", mock_open()) as mock_file, - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, - patch("hyperbench.data.dataset.json.load", return_value=sample_hif), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - mock_dctx = mock_decomp.return_value - mock_dctx.copy_stream = lambda input_f, tmp_file: None - - mock_temp_instance = mock_temp.return_value.__enter__.return_value - mock_temp_instance.name = "/tmp/decompressed.json" - - result = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) - mock_get.assert_not_called() - assert result == mock_hypergraph - - -def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): +def test_default_sampling_strategy_is_hyperedge(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) @@ -1288,8 +675,8 @@ def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): assert len(dataset) == 2 -def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): +def test_explicit_node_sampling_strategy(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) @@ -1304,8 +691,8 @@ def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), ], ) -def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): +def test_split_preserves_sampling_strategy(mock_hdata_four_node_hypergraph, strategy): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) splits = dataset.split([0.5, 0.5]) @@ -1322,7 +709,7 @@ def test_from_hdata_with_explicit_strategy(mock_hdata): def test_update_from_hdata_returns_new_dataset(mock_hdata): - dataset = Dataset(hdata=mock_hdata, prepare=False) + dataset = Dataset(hdata=mock_hdata) new_x = torch.ones((2, 1), dtype=torch.float) new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) @@ -1335,7 +722,7 @@ def test_update_from_hdata_returns_new_dataset(mock_hdata): def test_update_from_hdata_stores_provided_hdata(mock_hdata): - dataset = Dataset(hdata=mock_hdata, prepare=False) + dataset = Dataset(hdata=mock_hdata) new_x = torch.ones((2, 1), dtype=torch.float) new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) @@ -1353,7 +740,7 @@ def test_update_from_hdata_stores_provided_hdata(mock_hdata): ], ) def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): - dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy, prepare=False) + dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy) new_x = torch.ones((4, 1), dtype=torch.float) new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) @@ -1365,7 +752,7 @@ def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expe def test_update_from_hdata_preserves_subclass_type(mock_hdata): - dataset = AlgebraDataset(hdata=mock_hdata, prepare=False) + dataset = AlgebraDataset(hdata=mock_hdata) new_x = torch.ones((2, 1), dtype=torch.float) new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) @@ -1420,3 +807,33 @@ def test_dataset_stats_computation(mock_hdata_stats): stats = dataset.stats() assert stats == expected_stats + + +def test_transform_node_attrs_delegates_to_hifprocessor(): + attrs = {"weight": 1.5, "score": 0.8} + attr_keys = ["score", "weight"] + expected = torch.tensor([0.8, 1.5], dtype=torch.float) + + with patch( + "hyperbench.data.dataset.HIFProcessor.transform_attrs", + return_value=expected, + ) as mock_transform: + result = Dataset.transform_node_attrs(attrs, attr_keys) + + mock_transform.assert_called_once_with(attrs, attr_keys) + assert torch.equal(result, expected) + + +def test_transform_hyperedge_attrs_delegates_to_hifprocessor(): + attrs = {"weight": 2.5} + attr_keys = ["weight", "score"] + expected = torch.tensor([2.5, 0.0], dtype=torch.float) + + with patch( + "hyperbench.data.dataset.HIFProcessor.transform_attrs", + return_value=expected, + ) as mock_transform: + result = Dataset.transform_hyperedge_attrs(attrs, attr_keys) + + mock_transform.assert_called_once_with(attrs, attr_keys) + assert torch.equal(result, expected) diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py new file mode 100644 index 0000000..b6b1a47 --- /dev/null +++ b/hyperbench/tests/data/hif_test.py @@ -0,0 +1,639 @@ +import pytest +import requests +import torch +import json +import os + +from unittest.mock import patch, MagicMock + +from hyperbench.data import HIFLoader, HIFProcessor +from hyperbench.types import HData, HIFHypergraph + + +@pytest.fixture +def mock_sample_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "1"}], + hyperedges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + +@pytest.fixture +def mock_simple_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], + hyperedges=[{"edge": "0", "attrs": {}}], + incidences=[{"node": "0", "edge": "0"}], + ) + + +@pytest.fixture +def mock_three_node_weighted_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + ], + hyperedges=[ + {"edge": "0", "attrs": {"weight": 1.0}}, + {"edge": "1", "attrs": {"weight": 2.0}}, + ], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + ], + ) + + +@pytest.fixture +def mock_four_node_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + {"node": "3", "attrs": {}}, + ], + hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + {"node": "3", "edge": "1"}, + ], + ) + + +@pytest.fixture +def mock_five_node_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + {"node": "3", "attrs": {}}, + {"node": "4", "attrs": {}}, + ], + hyperedges=[{"edge": "0", "attrs": {}}], + incidences=[{"node": "0", "edge": "0"}], + ) + + +@pytest.fixture +def mock_no_edge_attr_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + ], + hyperedges=[{"edge": "0"}], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + ], + ) + + +@pytest.fixture +def mock_multiple_edges_attr_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + {"node": "3", "attrs": {}}, + ], + hyperedges=[ + {"edge": "0", "attrs": {"weight": 1.0}}, + {"edge": "1", "attrs": {"weight": 2.0}}, + {"edge": "2", "attrs": {"weight": 3.0}}, + ], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + {"node": "3", "edge": "2"}, + ], + ) + + +@pytest.fixture +def mock_hypergraph() -> HIFHypergraph: + return HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], + hyperedges=[{"edge": "0", "attrs": {"weight": 1.0}}], + incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], + ) + + +@pytest.fixture +def mock_hdata() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) + + +def _write_hif_json(tmp_path, hypergraph: HIFHypergraph, filename: str = "sample.json") -> str: + path = tmp_path / filename + payload = { + "network-type": hypergraph.network_type, + "metadata": hypergraph.metadata, + "nodes": hypergraph.nodes, + "edges": hypergraph.hyperedges, + "incidences": hypergraph.incidences, + } + with open(path, "w") as f: + json.dump(payload, f) + return str(path) + + +def _mock_named_temporary_file(path): + file_handle = open(path, "wb") + mocked_cm = MagicMock() + mocked_cm.__enter__.return_value = file_handle + mocked_cm.__exit__.side_effect = lambda exc_type, exc, tb: file_handle.close() + return mocked_cm + + +def test_transform_attrs_empty_attrs(): + result = HIFProcessor.transform_attrs({}) + assert len(result) == 0 + + attrs = {"name": "node1", "active": True} + result = HIFProcessor.transform_attrs(attrs) + assert len(result) == 0 + + +def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): + attrs = {"weight": 1.5} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) + assert torch.allclose(result, torch.tensor([0.0, 1.5, 0.0])) + + attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) + assert torch.allclose(result, torch.tensor([25.0, 0.8, 1.5])) + + attrs = {"weight": 1.5, "score": 0.8} + result = HIFProcessor.transform_attrs(attrs) + assert torch.allclose(result, torch.tensor([1.5, 0.8])) + + +def test_transform_hyperedge_attrs_adds_padding_zero_when_attr_keys_padding(): + attrs = {"weight": 2.5} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["score", "weight"]) + assert torch.allclose(result, torch.tensor([0.0, 2.5])) + + +def test_transform_node_attrs_adds_padding_zero_when_attr_keys_padding(): + attrs = {"weight": 2.5} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["score", "weight"]) + assert torch.allclose(result, torch.tensor([0.0, 2.5])) + + +def test_load_from_url_rejects_invalid_url(): + with pytest.raises(ValueError, match="Invalid URL"): + HIFLoader.load_from_url("not-a-url") + + +def test_load_from_url_raises_when_status_is_not_200(): + with patch("hyperbench.data.hif.requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 404 + + with pytest.raises(ValueError, match="Failed to download dataset from URL"): + HIFLoader.load_from_url("https://example.com/file.json.zst") + + +def test_load_from_path_raises_for_missing_file(): + with pytest.raises(ValueError, match="does not exist"): + HIFLoader.load_from_path("/abc/does-not-exist.json.zst") + + +def test_load_from_path_raises_for_unsupported_extension(tmp_path): + invalid = tmp_path / "sample.txt" + invalid.write_text("{}") + + with pytest.raises(ValueError, match="Unsupported file format"): + HIFLoader.load_from_path(str(invalid)) + + +@pytest.mark.parametrize( + "fixture_name, expected_nodes, expected_hyperedges, expected_incidences, has_hyperedge_weights", + [ + pytest.param("mock_sample_hypergraph", 2, 2, 2, False, id="sample_with_isolated_node"), + pytest.param("mock_simple_hypergraph", 2, 2, 2, True, id="simple_with_empty_attrs"), + pytest.param("mock_three_node_weighted_hypergraph", 3, 2, 3, True, id="weighted"), + pytest.param("mock_four_node_hypergraph", 4, 2, 4, True, id="four_nodes_two_edges"), + pytest.param("mock_five_node_hypergraph", 5, 5, 5, True, id="five_nodes_with_self_loops"), + pytest.param("mock_no_edge_attr_hypergraph", 2, 1, 2, False, id="no_edge_attr"), + pytest.param("mock_multiple_edges_attr_hypergraph", 4, 3, 4, True, id="multiple_weighted"), + ], +) +def test_load_from_path_processes_hypergraph_cases( + tmp_path, + request, + fixture_name, + expected_nodes, + expected_hyperedges, + expected_incidences, + has_hyperedge_weights, +): + hypergraph = request.getfixturevalue(fixture_name) + json_path = _write_hif_json(tmp_path, hypergraph, filename=f"{fixture_name}.json") + + with patch("hyperbench.data.hif.validate_hif_json", return_value=True): + hdata = HIFLoader.load_from_path(json_path) + + assert hdata.num_nodes == expected_nodes + assert hdata.num_hyperedges == expected_hyperedges + assert hdata.hyperedge_index.shape[1] == expected_incidences + assert (hdata.hyperedge_weights is not None) is has_hyperedge_weights + + +def test_load_from_path_zst_uses_decompress(tmp_path, mock_hypergraph): + zst_path = tmp_path / "sample.json.zst" + zst_path.write_bytes(b"dummy") + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.decompress_zst", return_value=json_path) as mock_decompress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + hdata = HIFLoader.load_from_path(str(zst_path)) + + mock_decompress.assert_called_once_with(str(zst_path)) + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_path_raises_for_non_hif_compliant_json(tmp_path, mock_hypergraph): + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with patch("hyperbench.data.hif.validate_hif_json", return_value=False): + with pytest.raises(ValueError, match="is not HIF-compliant"): + HIFLoader.load_from_path(json_path) + + +def test_load_from_url_processes_zst_and_saves_to_disk(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json.zst" + url = f"https://example.com/{unique_name}" + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock-zst-content" + + hdata = HIFLoader.load_from_url(url, save_on_disk=True) + + mock_write_to_disk.assert_called_once_with(unique_name, b"mock-zst-content") + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_url_processes_json_and_saves_compressed_copy(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json" + url = f"https://example.com/{unique_name}" + payload = { + "network-type": mock_hypergraph.network_type, + "metadata": mock_hypergraph.metadata, + "nodes": mock_hypergraph.nodes, + "edges": mock_hypergraph.hyperedges, + "incidences": mock_hypergraph.incidences, + } + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.tempfile.NamedTemporaryFile", + return_value=_mock_named_temporary_file(tmp_path / "downloaded.json"), + ), + patch("hyperbench.data.hif.compress_to_zst", return_value=b"compressed") as mock_compress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = json.dumps(payload).encode("utf-8") + + hdata = HIFLoader.load_from_url(url, save_on_disk=True) + + mock_compress.assert_called_once_with(str(tmp_path / "downloaded.json")) + mock_write_to_disk.assert_called_once_with(unique_name, b"compressed") + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_url_processes_zst_without_saving_to_disk(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json.zst" + url = f"https://example.com/{unique_name}" + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path) as mock_decompress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock-zst-content" + + hdata = HIFLoader.load_from_url(url, save_on_disk=False) + + mock_write_to_disk.assert_not_called() + mock_decompress.assert_called_once() + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_url_processes_json_without_saving_to_disk(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json" + url = f"https://example.com/{unique_name}" + payload = { + "network-type": mock_hypergraph.network_type, + "metadata": mock_hypergraph.metadata, + "nodes": mock_hypergraph.nodes, + "edges": mock_hypergraph.hyperedges, + "incidences": mock_hypergraph.incidences, + } + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.tempfile.NamedTemporaryFile", + return_value=_mock_named_temporary_file(tmp_path / "downloaded_no_save.json"), + ), + patch("hyperbench.data.hif.compress_to_zst") as mock_compress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = json.dumps(payload).encode("utf-8") + + hdata = HIFLoader.load_from_url(url, save_on_disk=False) + + mock_compress.assert_not_called() + mock_write_to_disk.assert_not_called() + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_path_processes_node_numeric_attrs_into_features(tmp_path): + hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {"weight": 1.0, "score": 0.5}}, + {"node": "1", "attrs": {"weight": 2.0, "score": 1.5}}, + ], + hyperedges=[{"edge": "0", "attrs": {}}], + incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], + ) + json_path = _write_hif_json(tmp_path, hypergraph, filename="nodes_with_attrs.json") + + with patch("hyperbench.data.hif.validate_hif_json", return_value=True): + hdata = HIFLoader.load_from_path(json_path) + + assert hdata.x.shape == (2, 2) + assert torch.allclose(hdata.x[0], torch.tensor([1.0, 0.5])) + assert torch.allclose(hdata.x[1], torch.tensor([2.0, 1.5])) + + +def test_load_from_url_raises_for_unsupported_temp_extension(tmp_path): + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.tempfile.NamedTemporaryFile", + return_value=_mock_named_temporary_file(tmp_path / "downloaded.bin"), + ), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"bytes" + + with pytest.raises(ValueError, match="Unsupported file format"): + HIFLoader.load_from_url("https://example.com/algebra.unknown") + + +def test_load_skips_download_when_file_exists(tmp_path, mock_hypergraph): + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=True), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + result = HIFLoader.load_by_name("algebra", save_on_disk=True) + + mock_get.assert_not_called() + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_HIFLoader_download_failure_when_hf_fallback_fails(): + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.hf_hub_download", side_effect=Exception("HFHub failed")), + ): + mock_response = mock_get.return_value + mock_response.status_code = 404 + mock_response.content = b"" + + with pytest.warns(UserWarning, match="GitHub raw download failed"): + with pytest.raises(ValueError, match="Failed to download dataset 'algebra'"): + HIFLoader.load_by_name("algebra") + + +def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( + tmp_path, mock_hypergraph +): + fallback_file = tmp_path / "algebra.json.zst" + fallback_file.write_bytes(b"mock_zst_content") + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file) + ) as mock_hf_hub_download, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + mock_response = mock_get.return_value + mock_response.status_code = 404 + mock_response.content = b"" + + with pytest.warns(UserWarning, match="GitHub raw download failed"): + with pytest.raises( + ValueError, + match="Failed to download dataset 'algebra' from GitHub with status code 404 and no SHA provided for Hugging Face Hub fallback.", + ): + result = HIFLoader.load_by_name("algebra", save_on_disk=False) + + +def test_load_saves_downloaded_dataset_on_disk(tmp_path, mock_hypergraph): + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.os.path.abspath", return_value=str(tmp_path / "hif.py")), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"downloaded-content" + + result = HIFLoader.load_by_name("algebra", save_on_disk=True) + + saved = tmp_path / "datasets" / "algebra.json.zst" + assert saved.exists() + assert saved.read_bytes() == b"downloaded-content" + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_HIFLoader_download_raises_when_network_error(): + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch( + "hyperbench.data.hif.requests.get", + side_effect=requests.RequestException("Network error"), + ), + ): + with pytest.raises(requests.RequestException, match="Network error"): + HIFLoader.load_by_name("algebra") + + +def test_load_by_name_uses_hf_revision_when_github_download_fails(tmp_path, mock_hypergraph): + hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" + fallback_file = tmp_path / "algebra.json.zst" + fallback_file.write_bytes(b"mock_zst_content") + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + response = requests.Response() + response.status_code = 404 + response._content = b"" + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get", return_value=response), + patch( + "hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file) + ) as mock_hf_hub_download, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + with pytest.warns(UserWarning, match="GitHub raw download failed"): + result = HIFLoader.load_by_name("algebra", hf_sha=hf_sha, save_on_disk=False) + + mock_hf_hub_download.assert_called_once_with( + repo_id="HypernetworkRG/algebra", + filename="algebra.json.zst", + repo_type="dataset", + revision=hf_sha, + ) + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_load_by_name_raises_when_hf_sha_is_missing_on_fallback(): + response = requests.Response() + response.status_code = 404 + response._content = b"" + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get", return_value=response), + patch("hyperbench.data.hif.hf_hub_download") as mock_hf_hub_download, + ): + with pytest.warns(UserWarning, match="GitHub raw download failed"): + with pytest.raises( + ValueError, + match="no SHA provided for Hugging Face Hub fallback", + ): + HIFLoader.load_by_name("algebra", save_on_disk=False) + + mock_hf_hub_download.assert_not_called() + + +def test_load_by_name_reads_hf_download_and_saves_its_content(tmp_path, mock_hypergraph): + hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" + hf_content = b"mock_zst_content" + fallback_file = tmp_path / "algebra.json.zst" + fallback_file.write_bytes(hf_content) + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + response = requests.Response() + response.status_code = 404 + response._content = b"" + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get", return_value=response), + patch("hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file)), + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.__file__", str(tmp_path / "hif.py")), + ): + with pytest.warns(UserWarning, match="GitHub raw download failed"): + result = HIFLoader.load_by_name("algebra", hf_sha=hf_sha, save_on_disk=True) + + saved = tmp_path / "datasets" / "algebra.json.zst" + assert saved.exists() + assert saved.read_bytes() == hf_content + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_HIFLoader_download_failure_when_hf_fallback_fails(): + hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" + response = requests.Response() + response.status_code = 404 + response._content = b"" + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get", return_value=response), + patch( + "hyperbench.data.hif.hf_hub_download", + side_effect=Exception("HFHub failed"), + ) as mock_hf_hub_download, + ): + with pytest.warns(UserWarning, match="GitHub raw download failed"): + with pytest.raises( + ValueError, + match=( + r"Failed to download dataset 'algebra' from GitHub and Hugging Face Hub\. " + r"GitHub error: 404 \| Hugging Face error: HFHub failed" + ), + ): + HIFLoader.load_by_name("algebra", hf_sha=hf_sha) + + mock_hf_hub_download.assert_called_once_with( + repo_id="HypernetworkRG/algebra", + filename="algebra.json.zst", + repo_type="dataset", + revision=hf_sha, + ) diff --git a/hyperbench/tests/types/hypergraph_test.py b/hyperbench/tests/types/hypergraph_test.py index a2c6f59..9e4137c 100644 --- a/hyperbench/tests/types/hypergraph_test.py +++ b/hyperbench/tests/types/hypergraph_test.py @@ -20,6 +20,8 @@ def test_build_HIFHypergraph_instance(): hypergraph = HIFHypergraph.from_hif(hiftext) assert isinstance(hypergraph, HIFHypergraph) + assert hypergraph.num_nodes == 423 + assert hypergraph.num_hyperedges == 1268 def test_empty_hifhypergraph_returns_empty_hifhypergraph(): diff --git a/hyperbench/tests/utils/file_utils_test.py b/hyperbench/tests/utils/file_utils_test.py new file mode 100644 index 0000000..6f703b0 --- /dev/null +++ b/hyperbench/tests/utils/file_utils_test.py @@ -0,0 +1,37 @@ +import os + +from hyperbench.utils import write_to_disk +from unittest.mock import patch, MagicMock + + +def test_write_to_disk_writes_file_default_output_dir(tmp_path): + dataset_name = "test_dataset" + content = b"test content" + + # Force write_to_disk default branch to resolve under tmp_path. + fake_module_file = tmp_path / "hyperbench" / "utils" / "file_utils.py" + + with patch( + "hyperbench.utils.file_utils.os.path.abspath", + return_value=str(fake_module_file), + ): + write_to_disk(dataset_name, content) + + expected_path = tmp_path / "hyperbench" / "data" / "datasets" / f"{dataset_name}.json.zst" + assert expected_path.is_file() + assert expected_path.read_bytes() == content + + +def test_write_to_disk_writes_file_optional_output_dir(tmp_path): + dataset_name = "test_dataset" + content = b"test content" + output_dir = tmp_path + + write_to_disk(dataset_name, content, output_dir) + + expected_path = tmp_path / f"{dataset_name}.json.zst" + assert expected_path.is_file() + + with open(expected_path, "rb") as f: + file_content = f.read() + assert file_content == content diff --git a/hyperbench/tests/utils/hif_utils_test.py b/hyperbench/tests/utils/hif_utils_test.py index 8695ea5..9161177 100644 --- a/hyperbench/tests/utils/hif_utils_test.py +++ b/hyperbench/tests/utils/hif_utils_test.py @@ -1,7 +1,15 @@ import requests +import json +import os from unittest.mock import patch, mock_open, MagicMock -from hyperbench.utils import validate_hif_json +from hyperbench.utils import ( + validate_hif_json, + compress_to_zst, + decompress_zst, + get_datasets_shas, + get_dataset_sha, +) from hyperbench.tests import MOCK_BASE_PATH @@ -23,7 +31,7 @@ def test_validate_hif_json_with_url_success(): validate_hif_json(path_valid) mock_get.assert_called_once_with( - "https://raw.githubusercontent.com/HIF-org/HIF-standard/main/schemas/hif_schema.json", + "https://raw.githubusercontent.com/HIF-org/HIF-standard/b691a3d2ec32100c0229ebe1151e9afad015c356/schemas/hif_schema.json", timeout=10, ) @@ -54,3 +62,97 @@ def test_validate_hif_json_with_url_request_exception_fallback(): # local file was opened calls = [str(call) for call in mock_file.call_args_list] assert any("../schema/hif_schema.json" in call for call in calls) + + +def test_compress_to_zst_returns_non_empty_bytes(tmp_path): + json_path = tmp_path / "sample.json" + json_path.write_text('{"nodes": [], "edges": [], "incidences": []}') + + compressed_content = compress_to_zst(str(json_path)) + + assert isinstance(compressed_content, bytes) + assert len(compressed_content) > 0 + + +def test_decompress_zst_round_trip_preserves_json_content(tmp_path): + expected_data = { + "network-type": "undirected", + "nodes": [{"node": "0", "attrs": {"weight": 1.0}}], + "edges": [{"edge": "0", "attrs": {}}], + "incidences": [{"node": "0", "edge": "0"}], + } + + json_path = tmp_path / "sample.json" + with open(json_path, "w") as f: + json.dump(expected_data, f) + + compressed_content = compress_to_zst(str(json_path)) + zst_path = tmp_path / "sample.json.zst" + zst_path.write_bytes(compressed_content) + + decompressed_path = decompress_zst(str(zst_path)) + + assert decompressed_path.endswith(".json") + assert os.path.exists(decompressed_path) + + with open(decompressed_path, "r") as f: + decompressed_data = json.load(f) + + assert decompressed_data == expected_data + + +def test_get_datasets_shas_returns_shas_and_none_on_failure(): + names = ["algebra", "missing-dataset"] + + def dataset_info_side_effect(*, repo_id): + if repo_id.endswith("/missing-dataset"): + raise RuntimeError("not found") + info = MagicMock() + info.sha = "sha-algebra" + return info + + with ( + patch("hyperbench.utils.hif_utils.HfApi") as mock_hf_api, + patch("builtins.print") as mock_print, + ): + mock_hf_api.return_value.dataset_info.side_effect = dataset_info_side_effect + + result = get_datasets_shas(names) + + assert result == { + "algebra": "sha-algebra", + "missing-dataset": None, + } + mock_hf_api.return_value.dataset_info.assert_any_call(repo_id="HypernetworkRG/algebra") + mock_hf_api.return_value.dataset_info.assert_any_call(repo_id="HypernetworkRG/missing-dataset") + assert any( + "missing-dataset: failed to retrieve SHA" in call.args[0] + for call in mock_print.call_args_list + ) + + +def test_get_dataset_sha_returns_sha(): + with patch("hyperbench.utils.hif_utils.HfApi") as mock_hf_api: + info = MagicMock() + info.sha = "sha-amazon" + mock_hf_api.return_value.dataset_info.return_value = info + + result = get_dataset_sha("amazon") + + assert result == "sha-amazon" + mock_hf_api.return_value.dataset_info.assert_called_once_with(repo_id="HypernetworkRG/amazon") + + +def test_get_dataset_sha_returns_none_on_failure(): + with ( + patch("hyperbench.utils.hif_utils.HfApi") as mock_hf_api, + patch("builtins.print") as mock_print, + ): + mock_hf_api.return_value.dataset_info.side_effect = RuntimeError("boom") + + result = get_dataset_sha("amazon") + + assert result is None + mock_hf_api.return_value.dataset_info.assert_called_once_with(repo_id="HypernetworkRG/amazon") + mock_print.assert_called_once() + assert "amazon: failed to retrieve SHA" in mock_print.call_args[0][0] diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 65e09ab..c9c708a 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -5,7 +5,7 @@ to_non_empty_edgeattr, to_0based_ids, ) -from .hif_utils import validate_hif_json +from .hif_utils import validate_hif_json, get_datasets_shas, get_dataset_sha from .nn_utils import ( INPUT_LAYER, ActivationFn, @@ -16,6 +16,8 @@ is_layer, ) from .sparse_utils import sparse_dropout +from .url_utils import validate_http_url +from .file_utils import decompress_zst, compress_to_zst, write_to_disk __all__ = [ "INPUT_LAYER", @@ -32,4 +34,10 @@ "to_non_empty_edgeattr", "to_0based_ids", "validate_hif_json", + "decompress_zst", + "compress_to_zst", + "validate_http_url", + "write_to_disk", + "get_datasets_shas", + "get_dataset_sha", ] diff --git a/hyperbench/utils/file_utils.py b/hyperbench/utils/file_utils.py new file mode 100644 index 0000000..c263be2 --- /dev/null +++ b/hyperbench/utils/file_utils.py @@ -0,0 +1,59 @@ +import os +import tempfile +import zstandard as zstd + +from typing import Optional + + +def decompress_zst(zst_path: str) -> str: + """ + Decompresses a .zst file and returns the path to the decompressed JSON file. + Args: + zst_path: The path to the .zst file to decompress. + Returns: + The path to the decompressed JSON file. + """ + dctx = zstd.ZstdDecompressor() + with ( + open(zst_path, "rb") as input_f, + tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, + ): + dctx.copy_stream(input_f, tmp_file) + output = tmp_file.name + return output + + +def compress_to_zst(json_path: str) -> bytes: + """ + Compresses a JSON file to .zst format and returns the compressed bytes. + + Args: + json_path: The path to the JSON file to compress. + Returns: + The compressed content as bytes. + """ + cctx = zstd.ZstdCompressor() + with open(json_path, "rb") as input_f: + compressed_content = cctx.compress(input_f.read()) + return compressed_content + + +def write_to_disk(dataset_name: str, content: bytes, output_dir: Optional[str] = None) -> None: + """ + Writes the compressed content to disk in the specified output directory or a default location. + Args: + dataset_name: The name of the dataset. + content: The compressed content as bytes. + output_dir: The directory to write the file to. If None, a default location is used. + """ + if output_dir is not None: + zst_filename = os.path.join(output_dir, f"{dataset_name}.json.zst") + else: + current_dir = os.path.dirname(os.path.abspath(__file__)) + output_dir = os.path.join(current_dir, "..", "data", "datasets") + zst_filename = os.path.join(output_dir, f"{dataset_name}.json.zst") + + os.makedirs(output_dir, exist_ok=True) + + with open(zst_filename, "wb") as f: + f.write(content) diff --git a/hyperbench/utils/hif_utils.py b/hyperbench/utils/hif_utils.py index c9cd944..f26178f 100644 --- a/hyperbench/utils/hif_utils.py +++ b/hyperbench/utils/hif_utils.py @@ -2,6 +2,10 @@ import json import requests +from huggingface_hub import HfApi + +HIF_SCHEMA_COMMIT_SHA = "b691a3d2ec32100c0229ebe1151e9afad015c356" + def validate_hif_json(filename: str) -> bool: """ @@ -13,7 +17,7 @@ def validate_hif_json(filename: str) -> bool: Returns: ``True`` if the file is valid HIF, ``False`` otherwise. """ - url = "https://raw.githubusercontent.com/HIF-org/HIF-standard/main/schemas/hif_schema.json" + url = f"https://raw.githubusercontent.com/HIF-org/HIF-standard/{HIF_SCHEMA_COMMIT_SHA}/schemas/hif_schema.json" try: schema = requests.get(url, timeout=10).json() except (requests.RequestException, requests.Timeout): @@ -26,3 +30,30 @@ def validate_hif_json(filename: str) -> bool: return True except Exception: return False + + +def get_datasets_shas(names: list[str], namespace: str = "HypernetworkRG") -> dict[str, str | None]: + api = HfApi() + shas: dict[str, str | None] = {} + + for dataset_name in names: + repo_id = f"{namespace}/{dataset_name}" + try: + info = api.dataset_info(repo_id=repo_id) + shas[dataset_name] = info.sha + except Exception as e: + shas[dataset_name] = None + print(f"{dataset_name}: failed to retrieve SHA ({e})") + + return shas + + +def get_dataset_sha(dataset_name: str, namespace: str = "HypernetworkRG") -> str | None: + api = HfApi() + repo_id = f"{namespace}/{dataset_name}" + try: + info = api.dataset_info(repo_id=repo_id) + return info.sha + except Exception as e: + print(f"{dataset_name}: failed to retrieve SHA ({e})") + return None diff --git a/hyperbench/utils/url_utils.py b/hyperbench/utils/url_utils.py new file mode 100644 index 0000000..d8e192a --- /dev/null +++ b/hyperbench/utils/url_utils.py @@ -0,0 +1,8 @@ +from urllib.parse import urlparse + + +def validate_http_url(value: str) -> str: + parsed = urlparse(value) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError(f"Invalid URL: {value}") + return value