From 28e89f97ee3dda262220624d93bb1eb64e703920 Mon Sep 17 00:00:00 2001
From: ddevin96
Date: Wed, 22 Apr 2026 16:36:25 +0200
Subject: [PATCH 01/15] feat: messing coming

---
 hyperbench/data/__init__.py           |   52 +-
 hyperbench/data/dataset.py            |  413 +----
 hyperbench/data/hif.py                |  371 +++++
 hyperbench/data/supported_datasets.py |  138 +-
 hyperbench/tests/data/dataset_test.py | 2046 +++++++++++++------------
 hyperbench/tests/data/hif_test.py     |    0
 6 files changed, 1597 insertions(+), 1423 deletions(-)
 create mode 100644 hyperbench/data/hif.py
 create mode 100644 hyperbench/tests/data/hif_test.py

diff --git a/hyperbench/data/__init__.py b/hyperbench/data/__init__.py
index 70ebe46..706aba3 100644
--- a/hyperbench/data/__init__.py
+++ b/hyperbench/data/__init__.py
@@ -1,32 +1,30 @@
-from .dataset import (
-    Dataset,
-    HIFConverter,
-)
+from .dataset import Dataset
+from .hif import HIFLoader
 from .supported_datasets import (
     AlgebraDataset,
-    AmazonDataset,
-    ContactHighSchoolDataset,
-    ContactPrimarySchoolDataset,
-    CoraDataset,
-    CourseraDataset,
-    DBLPDataset,
-    EmailEnronDataset,
-    EmailW3CDataset,
-    GeometryDataset,
-    GOTDataset,
-    IMDBDataset,
-    MusicBluesReviewsDataset,
-    NBADataset,
-    NDCClassesDataset,
-    NDCSubstancesDataset,
-    PatentDataset,
-    PubmedDataset,
-    RestaurantReviewsDataset,
-    ThreadsAskUbuntuDataset,
-    ThreadsMathsxDataset,
-    TwitterDataset,
-    VegasBarsReviewsDataset,
+    # AmazonDataset,
+    # ContactHighSchoolDataset,
+    # ContactPrimarySchoolDataset,
+    # CoraDataset,
+    # CourseraDataset,
+    # DBLPDataset,
+    # EmailEnronDataset,
+    # EmailW3CDataset,
+    # GeometryDataset,
+    # GOTDataset,
+    # IMDBDataset,
+    # MusicBluesReviewsDataset,
+    # NBADataset,
+    # NDCClassesDataset,
+    # NDCSubstancesDataset,
+    # PatentDataset,
+    # PubmedDataset,
+    # RestaurantReviewsDataset,
+    # ThreadsAskUbuntuDataset,
+    # ThreadsMathsxDataset,
+    # TwitterDataset,
+    # VegasBarsReviewsDataset,
 )
 from .loader import DataLoader
@@ -54,7 +52,7 @@
     "EmailW3CDataset",
     "GeometryDataset",
     "GOTDataset",
-    "HIFConverter",
+    "HIFLoader",
     "HyperedgeSampler",
     "IMDBDataset",
     "MusicBluesReviewsDataset",
diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py
index 48b874c..bfd1e10 100644
--- a/hyperbench/data/dataset.py
+++ b/hyperbench/data/dataset.py
@@ -6,9 +6,7 @@
 import requests
 import warnings
-from enum import Enum
-from huggingface_hub import hf_hub_download
 from typing import Any, Dict, List, Optional
 from torch import Tensor
 from torch.utils.data import Dataset as TorchDataset
@@ -18,114 +16,7 @@
 from hyperbench.utils import validate_hif_json
 from hyperbench.data.sampling import SamplingStrategy, create_sampler_from_strategy
-
-
-class DatasetNames(Enum):
-    """
-    Enumeration of available datasets.
- """ - - ALGEBRA = "algebra" - AMAZON = "amazon" - CONTACT_HIGH_SCHOOL = "contact-high-school" - CONTACT_PRIMARY_SCHOOL = "contact-primary-school" - CORA = "cora" - COURSERA = "coursera" - DBLP = "dblp" - EMAIL_ENRON = "email-Enron" - EMAIL_W3C = "email-W3C" - GEOMETRY = "geometry" - GOT = "got" - IMDB = "imdb" - MUSIC_BLUES_REVIEWS = "music-blues-reviews" - NBA = "nba" - NDC_CLASSES = "NDC-classes" - NDC_SUBSTANCES = "NDC-substances" - PATENT = "patent" - PUBMED = "pubmed" - RESTAURANT_REVIEWS = "restaurant-reviews" - THREADS_ASK_UBUNTU = "threads-ask-ubuntu" - THREADS_MATH_SX = "threads-math-sx" - TWITTER = "twitter" - VEGAS_BARS_REVIEWS = "vegas-bars-reviews" - - -class HIFConverter: - """A utility class to load hypergraphs from HIF format.""" - - @staticmethod - def load_from_hif(dataset_name: Optional[str], save_on_disk: bool = False) -> HIFHypergraph: - if dataset_name is None: - raise ValueError(f"Dataset name (provided: {dataset_name}) must be provided.") - if dataset_name not in DatasetNames.__members__: - raise ValueError(f"Dataset '{dataset_name}' not found.") - - dataset_name = DatasetNames[dataset_name].value - current_dir = os.path.dirname(os.path.abspath(__file__)) - zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") - - if not os.path.exists(zst_filename): - github_dataset_repo = f"https://github.com/hypernetwork-research-group/datasets/blob/main/{dataset_name}.json.zst?raw=true" - - response = requests.get(github_dataset_repo) - if response.status_code != 200: - warnings.warn( - f"GitHub raw download failed for dataset '{dataset_name}' with status code {response.status_code}\n" - "Falling back to Hugging Face Hub download for dataset", - category=UserWarning, - stacklevel=2, - ) - - REPO_ID = f"HypernetworkRG/{dataset_name}" - FILENAME = f"{dataset_name}.json.zst" - - with tempfile.NamedTemporaryFile( - mode="wb", suffix=".json.zst", delete=False - ) as tmp_hf_file: - try: - downloaded_path = hf_hub_download( - repo_id=REPO_ID, - filename=FILENAME, - repo_type="dataset", - ) - except Exception as e: - raise ValueError( - f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {str(e)}" - ) - with open(downloaded_path, "rb") as hf_file: - hf_content = hf_file.read() - tmp_hf_file.write(hf_content) - - response._content = hf_content - - if save_on_disk: - os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) - with open(zst_filename, "wb") as f: - f.write(response.content) - else: - # Create temporary file for downloaded zst content - with tempfile.NamedTemporaryFile( - mode="wb", suffix=".json.zst", delete=False - ) as tmp_zst_file: - tmp_zst_file.write(response.content) - zst_filename = tmp_zst_file.name - - # Decompress the downloaded zst file - dctx = zstd.ZstdDecompressor() - with ( - open(zst_filename, "rb") as input_f, - tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, - ): - dctx.copy_stream(input_f, tmp_file) - output = tmp_file.name - - with open(output, "r") as f: - hiftext = json.load(f) - if not validate_hif_json(output): - raise ValueError(f"Dataset '{dataset_name}' is not HIF-compliant.") - - hypergraph = HIFHypergraph.from_hif(hiftext) - return hypergraph +from hyperbench.data.hif import HIFLoader class Dataset(TorchDataset): @@ -140,13 +31,10 @@ class Dataset(TorchDataset): If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. 
""" - DATASET_NAME = None - def __init__( self, hdata: Optional[HData] = None, sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, - prepare: bool = True, ) -> None: """ Initialize the Dataset. @@ -158,19 +46,10 @@ def __init__( prepare: Whether to load and process the original dataset from HIF format. If set to ``False``, the dataset will be initialized with the provided hdata instead. Defaults to ``True``. """ - self.__is_prepared = prepare + self.__sampler = create_sampler_from_strategy(sampling_strategy) self.sampling_strategy = sampling_strategy - - if self.__is_prepared: - self.hypergraph = self.download() - self.hdata = self.process() - else: - if hdata is None: - raise ValueError("hdata must be provided when prepare is set to False.") - - self.hypergraph = HIFHypergraph.empty() - self.hdata = hdata + self.hdata = hdata if hdata is not None else HData.empty() def __len__(self) -> int: return self.__sampler.len(self.hdata) @@ -210,79 +89,69 @@ def from_hdata( Returns: The :class:`Dataset` instance with the provided :class:`HData`. """ - return cls(hdata=hdata, sampling_strategy=sampling_strategy, prepare=False) + return cls(hdata=hdata, sampling_strategy=sampling_strategy) - def download(self) -> HIFHypergraph: + @classmethod + def from_url( + cls, + url: str, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + save_on_disk: bool = False, + ) -> "Dataset": """ - Load the hypergraph from HIF format using HIFConverter class. + Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format. + + Args: + url: The URL to the .json or .json.zst file containing the HIF hypergraph data. + sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. + save_on_disk: Whether to save the downloaded file on disk. + + Returns: + The :class:`Dataset` instance with the loaded hypergraph data. """ - if not self.__is_prepared: - raise ValueError("download can only be called for the original dataset (prepare=True).") + hdata = HIFLoader.load_from_url(url=url, save_on_disk=save_on_disk) + dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy) + return dataset - if hasattr(self, "hypergraph") and self.hypergraph is not None: - return self.hypergraph + @classmethod + def from_path( + cls, + filepath: str, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + ) -> "Dataset": + """ + Create a :class:`Dataset` instance by loading a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format. - return HIFConverter.load_from_hif(self.DATASET_NAME, save_on_disk=True) + Args: + filepath: The local file path to the .json or .json.zst file containing the HIF hypergraph data. + sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. - def process(self) -> HData: + Returns: + The :class:`Dataset` instance with the loaded hypergraph data. + """ + hypergraph = HIFLoader.load_from_path(filepath=filepath) + dataset = cls.from_hdata(hdata=hypergraph, sampling_strategy=sampling_strategy) + return dataset + + @classmethod + def from_default( + cls, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + save_on_disk: bool = False, + ) -> "Dataset": """ - Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. 
+ Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format. + + Args: + sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. + save_on_disk: Whether to save the downloaded file on disk. Returns: - The processed hypergraph data. + The :class:`Dataset` instance with the loaded hypergraph data. """ - if not self.__is_prepared: - raise ValueError("process can only be called for the original dataset.") - - num_nodes = len(self.hypergraph.nodes) - x = self.__process_x(num_nodes) - - # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order - node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(self.hypergraph.nodes)} - # Initialize edge_set only with edges that have incidences, so that - # we avoid inflating edge count due to isolated nodes/missing incidences - hyperedge_id_to_idx: Dict[Any, int] = {} - - node_ids = [] - hyperedge_ids = [] - nodes_with_incidences = set() - for incidence in self.hypergraph.incidences: - node_id = incidence.get("node", 0) - hyperedge_id = incidence.get("edge", 0) - - if hyperedge_id not in hyperedge_id_to_idx: - # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences - hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx) - - node_ids.append(node_id_to_idx[node_id]) - hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id]) - nodes_with_incidences.add(node_id_to_idx[node_id]) - - # Handle isolated nodes by assigning them to a new unique hyperedge (self-loop) - for node_idx in range(num_nodes): - if node_idx not in nodes_with_incidences: - new_hyperedge_id = len(hyperedge_id_to_idx) - # Unique dummy key to reserve the index in hyperedge_set - hyperedge_id_to_idx[f"__self_loop_{node_idx}__"] = new_hyperedge_id - node_ids.append(node_idx) - hyperedge_ids.append(new_hyperedge_id) - - num_hyperedges = len(hyperedge_id_to_idx) - hyperedge_attr = self.__process_hyperedge_attr(hyperedge_id_to_idx, num_hyperedges) - - hyperedge_weights = self.__process_hyperedge_weights() - - hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) - - return HData( - x=x, - hyperedge_index=hyperedge_index, - hyperedge_weights=hyperedge_weights, - hyperedge_attr=hyperedge_attr, - num_nodes=num_nodes, - num_hyperedges=num_hyperedges, - global_node_ids=HyperedgeIndex(hyperedge_index).node_ids, - ) + hdata = HIFLoader.load(dataset_name="", save_on_disk=save_on_disk) + dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy) + return dataset def enrich_node_features( self, @@ -340,7 +209,7 @@ def update_from_hdata(self, hdata: HData) -> "Dataset": Returns: The :class:`Dataset` instance with the provided :class:`HData`. 
""" - return self.__class__(hdata=hdata, sampling_strategy=self.sampling_strategy, prepare=False) + return self.__class__(hdata=hdata, sampling_strategy=self.sampling_strategy) def remove_hyperedges_with_fewer_than_k_nodes(self, k: int) -> None: """ @@ -428,7 +297,6 @@ def split( split_dataset = self.__class__( hdata=split_hdata, sampling_strategy=self.sampling_strategy, - prepare=False, ) split_datasets.append(split_dataset) @@ -449,70 +317,6 @@ def to(self, device: torch.device) -> "Dataset": self.hdata = self.hdata.to(device) return self - def transform_node_attrs( - self, - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - return self.transform_attrs(attrs, attr_keys) - - def transform_hyperedge_attrs( - self, - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - return self.transform_attrs(attrs, attr_keys) - - def transform_attrs( - self, - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - """ - Extract and encode numeric attributes to tensor. - Non-numeric attributes are discarded. Missing attributes are filled with ``0.0``. - - Args: - attrs: Dictionary of attributes - attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``. - - Returns: - Tensor of numeric attribute values - """ - numeric_attrs = { - key: value - for key, value in attrs.items() - if isinstance(value, (int, float)) and not isinstance(value, bool) - } - - if attr_keys is not None: - values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys] - return torch.tensor(values, dtype=torch.float) - - if not numeric_attrs: - return torch.tensor([], dtype=torch.float) - - values = [float(value) for value in numeric_attrs.values()] - return torch.tensor(values, dtype=torch.float) - - def __collect_attr_keys(self, attr_keys: List[Dict[str, Any]]) -> List[str]: - """ - Collect unique numeric attribute keys from a list of attribute dictionaries. - - Args: - attr_keys: List of attribute dictionaries. - - Returns: - List of unique numeric attribute keys. 
- """ - unique_keys = [] - for attrs in attr_keys: - for key, value in attrs.items(): - if key not in unique_keys and isinstance(value, (int, float)): - unique_keys.append(key) - - return unique_keys - def __get_hyperedge_ids_permutation( self, num_hyperedges: int, @@ -537,86 +341,27 @@ def __get_hyperedge_ids_permutation( ranged_hyperedge_ids_permutation = torch.arange(num_hyperedges, device=device) return ranged_hyperedge_ids_permutation - def __process_hyperedge_attr( - self, - hyperedge_id_to_idx: Dict[Any, int], - num_hyperedges: int, - ) -> Optional[Tensor]: - # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] - hyperedge_attr = None - has_hyperedges = ( - self.hypergraph.hyperedges is not None and len(self.hypergraph.hyperedges) > 0 - ) - has_any_hyperedge_attrs = has_hyperedges and any( - "attrs" in edge for edge in self.hypergraph.hyperedges - ) - - if has_any_hyperedge_attrs: - hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { - e.get("edge"): e.get("attrs", {}) for e in self.hypergraph.hyperedges - } - - hyperedge_attr_keys = self.__collect_attr_keys(list(hyperedge_id_to_attrs.values())) - - # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) - hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} - - attrs = [] - for hyperedge_idx in range(num_hyperedges): - hyperedge_id = hyperedge_idx_to_id[hyperedge_idx] - - transformed_attrs = self.transform_hyperedge_attrs( - # If it's a real hyperedge, get its attrs; if self-loop, get empty dict - attrs=hyperedge_id_to_attrs.get(hyperedge_id, {}), - attr_keys=hyperedge_attr_keys, - ) - attrs.append(transformed_attrs) - - hyperedge_attr = torch.stack(attrs) - - return hyperedge_attr - - def __process_x(self, num_nodes: int) -> Tensor: - # Collect all attribute keys to have tensors of same size - node_attr_keys = self.__collect_attr_keys( - [node.get("attrs", {}) for node in self.hypergraph.nodes] - ) - - if node_attr_keys: - x = torch.stack( - [ - self.transform_node_attrs(node.get("attrs", {}), attr_keys=node_attr_keys) - for node in self.hypergraph.nodes - ] - ) - else: - # Fallback to ones if no node features, 1 is better as it can help during - # training (e.g., avoid zero multiplication), especially in first epochs - x = torch.ones((num_nodes, 1), dtype=torch.float) - - return x # shape [num_nodes, num_node_features] - - def __process_hyperedge_weights(self) -> Optional[Tensor]: - # Initialize the hyperedge weights tensor - hyperedge_weights = None - - has_hyperedge_weights = self.hypergraph.hyperedges is not None and all( - "weight" in edge for edge in self.hypergraph.hyperedges - ) - - if has_hyperedge_weights: - weights = [edge.get("weight", 1.0) for edge in self.hypergraph.hyperedges] - hyperedge_weights = torch.tensor(weights, dtype=torch.float) - elif ( - has_hyperedge_weights is False - and self.hypergraph.hyperedges is not None - and any("weight" in edge for edge in self.hypergraph.hyperedges) - ): - raise ValueError( - "Some hyperedges have weights while others do not. All hyperedges must either have weights or none." 
- ) - - return hyperedge_weights + # def __process_hyperedge_weights(self) -> Optional[Tensor]: + # # Initialize the hyperedge weights tensor + # hyperedge_weights = None + + # has_hyperedge_weights = self.hypergraph.hyperedges is not None and all( + # "weight" in edge for edge in self.hypergraph.hyperedges + # ) + + # if has_hyperedge_weights: + # weights = [edge.get("weight", 1.0) for edge in self.hypergraph.hyperedges] + # hyperedge_weights = torch.tensor(weights, dtype=torch.float) + # elif ( + # has_hyperedge_weights is False + # and self.hypergraph.hyperedges is not None + # and any("weight" in edge for edge in self.hypergraph.hyperedges) + # ): + # raise ValueError( + # "Some hyperedges have weights while others do not. All hyperedges must either have weights or none." + # ) + + # return hyperedge_weights def stats(self) -> Dict[str, Any]: """ diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py new file mode 100644 index 0000000..aae6f72 --- /dev/null +++ b/hyperbench/data/hif.py @@ -0,0 +1,371 @@ +from turtle import st +import torch +from sympy.physics.units import h +import os +import json +import zstandard as zstd +import requests +import tempfile +import warnings +from huggingface_hub import hf_hub_download +from typing import Optional, Dict, Any, List +from urllib.parse import urlparse +from torch import Tensor + +from hyperbench.types import HData, HIFHypergraph +from hyperbench.utils import validate_hif_json + + +def _validate_http_url(value: str) -> str: + parsed = urlparse(value) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError(f"Invalid URL: {value}") + return value + + +class HIFLoader: + """A utility class to load hypergraphs from HIF format.""" + + @staticmethod + def load_from_url(url: str, save_on_disk: bool = False) -> HData: + """ + Load a hypergraph from a given URL pointing to a .json or .json.zst file in HIF format. + Args: + url (str): The URL to the .json or .json.zst file containing the HIF hypergraph data. + save_on_disk (bool): Whether to save the downloaded file on disk. + Returns: + HData: The loaded hypergraph object. + """ + url = _validate_http_url(url) + + response = requests.get(url, timeout=20) + if response.status_code != 200: + raise ValueError( + f"Failed to download dataset from URL '{url}' with status code {response.status_code}" + ) + + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_zst_file: + tmp_zst_file.write(response.content) + zst_filename = tmp_zst_file.name + + if zst_filename.endswith(".zst"): + if save_on_disk: + HIFLoader.__save_on_disk(os.path.basename(url), response.content) + output = HIFLoader.__decompress_zst(zst_filename) + elif zst_filename.endswith(".json"): + if save_on_disk: + compressed = HIFLoader.__compress_to_zst(zst_filename) + HIFLoader.__save_on_disk(os.path.basename(url), compressed) + output = zst_filename + else: + raise ValueError( + f"Unsupported file format for URL '{url}'. Expected .json or .json.zst" + ) + + hypergraph = HIFLoader.__extract_hif(output) + hdata = HIFLoader.__process(hypergraph) + return hdata + + @staticmethod + def load_from_path(filepath: str) -> HData: + """ + Load a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format. + Args: + filepath (str): The local file path to the .json or .json.zst file + containing the HIF hypergraph data. + Returns: + HData: The loaded hypergraph object. 
+ """ + if not os.path.exists(filepath): + raise ValueError(f"File '{filepath}' does not exist.") + + if filepath.endswith(".zst"): + output = HIFLoader.__decompress_zst(filepath) + elif filepath.endswith(".json"): + output = filepath + else: + raise ValueError( + f"Unsupported file format for filepath '{filepath}'. Expected .json or .json.zst" + ) + + hypergraph = HIFLoader.__extract_hif(output) + hdata = HIFLoader.__process(hypergraph) + return hdata + + @staticmethod + def load(dataset_name: str, save_on_disk: bool = False) -> HData: + print(f"Loading dataset '{dataset_name}' from disk or remote sources...") + current_dir = os.path.dirname(os.path.abspath(__file__)) + zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") + + if not os.path.exists(zst_filename): + github_dataset_repo = f"https://github.com/hypernetwork-research-group/datasets/blob/main/{dataset_name}.json.zst?raw=true" + + response = requests.get(github_dataset_repo) + if response.status_code != 200: + warnings.warn( + f"GitHub raw download failed for dataset '{dataset_name}' with status code {response.status_code}\n" + "Falling back to Hugging Face Hub download for dataset", + category=UserWarning, + stacklevel=2, + ) + + REPO_ID = f"HypernetworkRG/{dataset_name}" + FILENAME = f"{dataset_name}.json.zst" + + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_hf_file: + try: + downloaded_path = hf_hub_download( + repo_id=REPO_ID, + filename=FILENAME, + repo_type="dataset", + ) + except Exception as e: + raise ValueError( + f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {str(e)}" + ) + with open(downloaded_path, "rb") as hf_file: + hf_content = hf_file.read() + tmp_hf_file.write(hf_content) + + response._content = hf_content + + if save_on_disk: + print(f"Saving downloaded dataset '{dataset_name}' to disk at '{zst_filename}'") + os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) + with open(zst_filename, "wb") as f: + f.write(response.content) + else: + # Create temporary file for downloaded zst content + with tempfile.NamedTemporaryFile( + mode="wb", suffix=".json.zst", delete=False + ) as tmp_zst_file: + tmp_zst_file.write(response.content) + zst_filename = tmp_zst_file.name + + # Decompress the downloaded zst file + output = HIFLoader.__decompress_zst(zst_filename) + hypergraph = HIFLoader.__extract_hif(output) + hdata = HIFLoader.__process(hypergraph) + return hdata + + @staticmethod + def __decompress_zst(zst_path: str) -> str: + dctx = zstd.ZstdDecompressor() + with ( + open(zst_path, "rb") as input_f, + tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, + ): + dctx.copy_stream(input_f, tmp_file) + output = tmp_file.name + return output + + @staticmethod + def __compress_to_zst(json_path: str) -> bytes: + cctx = zstd.ZstdCompressor() + with open(json_path, "rb") as input_f: + compressed_content = cctx.compress(input_f.read()) + return compressed_content + + @staticmethod + def __extract_hif(json_file: str) -> HIFHypergraph: + with open(json_file, "r") as f: + hiftext = json.load(f) + if not validate_hif_json(json_file): + raise ValueError(f"Dataset from file '{json_file}' is not HIF-compliant.") + hypergraph = HIFHypergraph.from_hif(hiftext) + return hypergraph + + @staticmethod + def __save_on_disk(dataset_name: str, content: bytes) -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + zst_filename 
= os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") + os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) + + with open(zst_filename, "wb") as f: + f.write(content) + + @staticmethod + def __process_hyperedge_attr( + hypergraph: HIFHypergraph, + hyperedge_id_to_idx: Dict[Any, int], + num_hyperedges: int, + ) -> Optional[Tensor]: + # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] + hyperedge_attr = None + has_hyperedges = hypergraph.edges is not None and len(hypergraph.edges) > 0 + has_any_hyperedge_attrs = has_hyperedges and any( + "attrs" in edge for edge in hypergraph.edges + ) + + if has_any_hyperedge_attrs: + hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { + e.get("edge"): e.get("attrs", {}) for e in hypergraph.edges + } + + hyperedge_attr_keys = HIFLoader.__collect_attr_keys( + list(hyperedge_id_to_attrs.values()) + ) + + # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) + hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} + + attrs = [] + for hyperedge_idx in range(num_hyperedges): + hyperedge_id = hyperedge_idx_to_id[hyperedge_idx] + + transformed_attrs = HIFLoader.transform_hyperedge_attrs( + # If it's a real hyperedge, get its attrs; if self-loop, get empty dict + attrs=hyperedge_id_to_attrs.get(hyperedge_id, {}), + attr_keys=hyperedge_attr_keys, + ) + attrs.append(transformed_attrs) + + hyperedge_attr = torch.stack(attrs) + + return hyperedge_attr + + @staticmethod + def transform_node_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + return HIFLoader.transform_attrs(attrs, attr_keys) + + @staticmethod + def __process_x(hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: + # Collect all attribute keys to have tensors of same size + node_attr_keys = HIFLoader.__collect_attr_keys( + [node.get("attrs", {}) for node in hypergraph.nodes] + ) + + if node_attr_keys: + x = torch.stack( + [ + HIFLoader.transform_node_attrs(node.get("attrs", {}), attr_keys=node_attr_keys) + for node in hypergraph.nodes + ] + ) + else: + # Fallback to ones if no node features, 1 is better as it can help during + # training (e.g., avoid zero multiplication), especially in first epochs + x = torch.ones((num_nodes, 1), dtype=torch.float) + + return x # shape [num_nodes, num_node_features] + + @staticmethod + def __process(hypergraph: HIFHypergraph) -> HData: + """ + Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. + + Returns: + The processed hypergraph data. 
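+
+        Note: one row of ``x`` is built per node (following ``nodes`` order), one
+        column of ``hyperedge_index`` per incidence, and isolated nodes are
+        attached to a fresh self-loop hyperedge so that every node is covered.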
+ """ + # if not self.__is_prepared: + # raise ValueError("process can only be called for the original dataset.") + + num_nodes = len(hypergraph.nodes) + x = HIFLoader.__process_x(hypergraph, num_nodes) + + # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order + node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} + # Initialize edge_set only with edges that have incidences, so that + # we avoid inflating edge count due to isolated nodes/missing incidences + hyperedge_id_to_idx: Dict[Any, int] = {} + + node_ids = [] + hyperedge_ids = [] + nodes_with_incidences = set() + for incidence in hypergraph.incidences: + node_id = incidence.get("node", 0) + hyperedge_id = incidence.get("edge", 0) + + if hyperedge_id not in hyperedge_id_to_idx: + # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences + hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx) + + node_ids.append(node_id_to_idx[node_id]) + hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id]) + nodes_with_incidences.add(node_id_to_idx[node_id]) + + # Handle isolated nodes by assigning them to a new unique hyperedge (self-loop) + for node_idx in range(num_nodes): + if node_idx not in nodes_with_incidences: + new_hyperedge_id = len(hyperedge_id_to_idx) + # Unique dummy key to reserve the index in hyperedge_set + hyperedge_id_to_idx[f"__self_loop_{node_idx}__"] = new_hyperedge_id + node_ids.append(node_idx) + hyperedge_ids.append(new_hyperedge_id) + + num_hyperedges = len(hyperedge_id_to_idx) + hyperedge_attr = HIFLoader.__process_hyperedge_attr( + hypergraph=hypergraph, + hyperedge_id_to_idx=hyperedge_id_to_idx, + num_hyperedges=num_hyperedges, + ) + + hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) + + return HData(x, hyperedge_index, hyperedge_attr, num_nodes, num_hyperedges) + + @staticmethod + def __collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: + """ + Collect unique numeric attribute keys from a list of attribute dictionaries. + + Args: + attr_keys: List of attribute dictionaries. + + Returns: + List of unique numeric attribute keys. + """ + unique_keys = [] + for attrs in attr_keys: + for key, value in attrs.items(): + if key not in unique_keys and isinstance(value, (int, float)): + unique_keys.append(key) + + return unique_keys + + @staticmethod + def transform_hyperedge_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + return HIFLoader.transform_attrs(attrs, attr_keys) + + @staticmethod + def transform_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + """ + Extract and encode numeric attributes to tensor. + Non-numeric attributes are discarded. Missing attributes are filled with ``0.0``. + + Args: + attrs: Dictionary of attributes + attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``. 
+ + Returns: + Tensor of numeric attribute values + """ + numeric_attrs = { + key: value + for key, value in attrs.items() + if isinstance(value, (int, float)) and not isinstance(value, bool) + } + + if attr_keys is not None: + values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys] + return torch.tensor(values, dtype=torch.float) + + if not numeric_attrs: + return torch.tensor([], dtype=torch.float) + + values = [float(value) for value in numeric_attrs.values()] + return torch.tensor(values, dtype=torch.float) diff --git a/hyperbench/data/supported_datasets.py b/hyperbench/data/supported_datasets.py index 443f043..198aa1e 100644 --- a/hyperbench/data/supported_datasets.py +++ b/hyperbench/data/supported_datasets.py @@ -1,93 +1,139 @@ +from hyperbench.data import HIFLoader from hyperbench.data.dataset import Dataset +from hyperbench.data.sampling import SamplingStrategy -class AlgebraDataset(Dataset): - DATASET_NAME = "ALGEBRA" +class PreloadedDataset(Dataset): + DATASET_NAME = "" + def __init__( + self, + hdata=None, + sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + ) -> None: + super().__init__(hdata=hdata, sampling_strategy=sampling_strategy) + if hdata is None: + self.hdata = HIFLoader.load(self.DATASET_NAME, save_on_disk=True) -class AmazonDataset(Dataset): - DATASET_NAME = "AMAZON" +class AlgebraDataset(PreloadedDataset): + DATASET_NAME = "algebra" -class ContactHighSchoolDataset(Dataset): - DATASET_NAME = "CONTACT_HIGH_SCHOOL" +class AmazonDataset(PreloadedDataset): + DATASET_NAME = "amazon" -class ContactPrimarySchoolDataset(Dataset): - DATASET_NAME = "CONTACT_PRIMARY_SCHOOL" +class ContactHighSchoolDataset(PreloadedDataset): + DATASET_NAME = "contact-high-school" -class CoraDataset(Dataset): - DATASET_NAME = "CORA" +class ContactPrimarySchoolDataset(PreloadedDataset): + DATASET_NAME = "contact-primary-school" -class CourseraDataset(Dataset): - DATASET_NAME = "COURSERA" +class CoraDataset(PreloadedDataset): + DATASET_NAME = "cora" -class DBLPDataset(Dataset): - DATASET_NAME = "DBLP" +class CourseraDataset(PreloadedDataset): + DATASET_NAME = "coursera" -class EmailEnronDataset(Dataset): - DATASET_NAME = "EMAIL_ENRON" +class DBLPDataset(PreloadedDataset): + DATASET_NAME = "dblp" -class EmailW3CDataset(Dataset): - DATASET_NAME = "EMAIL_W3C" +class EmailEnronDataset(PreloadedDataset): + DATASET_NAME = "email-Enron" -class GeometryDataset(Dataset): - DATASET_NAME = "GEOMETRY" +class EmailW3CDataset(PreloadedDataset): + DATASET_NAME = "email-W3C" -class GOTDataset(Dataset): - DATASET_NAME = "GOT" +class GeometryDataset(PreloadedDataset): + DATASET_NAME = "geometry" -class IMDBDataset(Dataset): - DATASET_NAME = "IMDB" +class GOTDataset(PreloadedDataset): + DATASET_NAME = "got" -class MusicBluesReviewsDataset(Dataset): - DATASET_NAME = "MUSIC_BLUES_REVIEWS" +class IMDBDataset(PreloadedDataset): + DATASET_NAME = "imdb" -class NBADataset(Dataset): - DATASET_NAME = "NBA" +class MusicBluesReviewsDataset(PreloadedDataset): + DATASET_NAME = "music-blues-reviews" -class NDCClassesDataset(Dataset): - DATASET_NAME = "NDC_CLASSES" +class NBADataset(PreloadedDataset): + DATASET_NAME = "nba" -class NDCSubstancesDataset(Dataset): - DATASET_NAME = "NDC_SUBSTANCES" +class NDCClassesDataset(PreloadedDataset): + DATASET_NAME = "NDC-classes" -class PatentDataset(Dataset): - DATASET_NAME = "PATENT" +class NDCSubstancesDataset(PreloadedDataset): + DATASET_NAME = "NDC-substances" -class PubmedDataset(Dataset): - DATASET_NAME = "PUBMED" +class PatentDataset(PreloadedDataset): + 
DATASET_NAME = "patent" -class RestaurantReviewsDataset(Dataset): - DATASET_NAME = "RESTAURANT_REVIEWS" +class PubmedDataset(PreloadedDataset): + DATASET_NAME = "pubmed" -class ThreadsAskUbuntuDataset(Dataset): - DATASET_NAME = "THREADS_ASK_UBUNTU" +class RestaurantReviewsDataset(PreloadedDataset): + DATASET_NAME = "restaurant-reviews" -class ThreadsMathsxDataset(Dataset): - DATASET_NAME = "THREADS_MATH_SX" +class ThreadsAskUbuntuDataset(PreloadedDataset): + DATASET_NAME = "threads-ask-ubuntu" -class TwitterDataset(Dataset): - DATASET_NAME = "TWITTER" +class ThreadsMathsxDataset(PreloadedDataset): + DATASET_NAME = "threads-math-sx" -class VegasBarsReviewsDataset(Dataset): - DATASET_NAME = "VEGAS_BARS_REVIEWS" + +class TwitterDataset(PreloadedDataset): + DATASET_NAME = "twitter" + + +class VegasBarsReviewsDataset(PreloadedDataset): + DATASET_NAME = "vegas-bars-reviews" + + +if __name__ == "__main__": + # test loading each dataset + for dataset_cls in [ + AlgebraDataset, + AmazonDataset, + ContactHighSchoolDataset, + # ContactPrimarySchoolDataset, + # CoraDataset, + # CourseraDataset, + # DBLPDataset, + # EmailEnronDataset, + # EmailW3CDataset, + # GeometryDataset, + # GOTDataset, + # IMDBDataset, + # MusicBluesReviewsDataset, + # NBADataset, + # NDCClassesDataset, + # NDCSubstancesDataset, + # PatentDataset, + # PubmedDataset, + # RestaurantReviewsDataset, + # ThreadsAskUbuntuDataset, + # ThreadsMathsxDataset, + # TwitterDataset, + # VegasBarsReviewsDataset, + ]: + dataset = dataset_cls() + print(dataset.hdata.num_nodes, dataset.hdata.num_hyperedges) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 7c44c01..aaa5015 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -4,10 +4,10 @@ import torch from unittest.mock import patch, mock_open, MagicMock -from hyperbench.data import AlgebraDataset, Dataset, HIFConverter, SamplingStrategy +from hyperbench.data import AlgebraDataset, Dataset, HIFLoader, SamplingStrategy from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher from hyperbench.types import HData, HIFHypergraph - +from hyperbench.data.supported_datasets import PreloadedDataset @pytest.fixture def mock_hdata() -> HData: @@ -148,594 +148,608 @@ def mock_multiple_edges_attr_hypergraph(): ], ) - -def test_HIFConverter_num_nodes_and_edges(): - dataset_name = "ALGEBRA" - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": str(i)} for i in range(20)], - hyperedges=[{"edge": str(i)} for i in range(30)], - incidences=[{"node": "0", "edge": "0"}], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - hypergraph = HIFConverter.load_from_hif(dataset_name) - - assert hypergraph is not None - assert hasattr(hypergraph, "nodes") - assert hasattr(hypergraph, "hyperedges") - assert hasattr(hypergraph, "incidences") - assert hasattr(hypergraph, "metadata") - assert hasattr(hypergraph, "network_type") - - assert hypergraph.num_nodes == 20 - assert hypergraph.num_hyperedges == 30 - - -def test_HIFConverter_loads_invalid_dataset(): - dataset_name = "INVALID_DATASET" - - with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): - HIFConverter.load_from_hif(dataset_name) - - -def test_HIFConverter_loads_invalid_hif_format(): - dataset_name = "ALGEBRA" - - invalid_hif_json = '{"network-type": "undirected", "nodes": []}' - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - 
patch("hyperbench.data.dataset.validate_hif_json", return_value=False), - patch("builtins.open", mock_open(read_data=invalid_hif_json)), - patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), - ): - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.content = b"mock_zst_content" - - with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): - HIFConverter.load_from_hif(dataset_name) - - -def test_HIFConverter_stores_on_disk_when_save_on_disk_true(): - dataset_name = "ALGEBRA" - - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) - - mock_hif_json = { - "network-type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - patch("hyperbench.data.dataset.os.makedirs"), - patch("builtins.open", mock_open()) as mock_file, - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - # Mock successful download - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.content = b"mock_zst_content" - - # Mock decompressor - mock_stream = mock_decomp.return_value.stream_reader.return_value - mock_stream.__enter__ = lambda self: mock_stream - mock_stream.__exit__ = lambda self, *args: None - - hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) - - assert hypergraph is not None - assert hypergraph.network_type == "undirected" - mock_get.assert_called_once() - # Verify file was written to disk (not temp file) - assert mock_file.call_count >= 2 # Once for write, once for read - - -def test_HIFConverter_uses_temp_file_when_save_on_disk_false(): - dataset_name = "ALGEBRA" - - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) - - mock_hif_json = { - "network-type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, - patch("builtins.open", mock_open()), - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - # Mock successful download - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.content = b"mock_zst_content" - - # Mock temp file - mock_temp_file = mock_temp.return_value.__enter__.return_value - mock_temp_file.name = "/tmp/fake_temp.json.zst" - - # Mock decompressor - mock_stream = mock_decomp.return_value.stream_reader.return_value - mock_stream.__enter__ = lambda self: mock_stream - 
mock_stream.__exit__ = lambda self, *args: None - - hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=False) - - assert hypergraph is not None - assert hypergraph.network_type == "undirected" - mock_get.assert_called_once() - # Verify temp file was used - assert mock_temp.call_count >= 1 - - -def test_HIFConverter_download_failure(): - dataset_name = "ALGEBRA" - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.hf_hub_download", side_effect=Exception("HFHub failed")), - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - ): - # Mock failed download - mock_response = mock_get.return_value - mock_response.status_code = 404 - mock_response.content = b"" - - with pytest.warns( - UserWarning, - match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", - ): - with pytest.raises( - ValueError, - match=r"Failed to download dataset 'algebra'", - ): - HIFConverter.load_from_hif(dataset_name) - - -def test_HIFConverter_falls_back_to_hf_hub_download_when_github_raw_download_fails( - tmp_path, mock_sample_hypergraph -): - dataset_name = "ALGEBRA" - - mock_hypergraph = mock_sample_hypergraph - - mock_hif_json = { - "network_type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } - - fallback_file = tmp_path / "algebra.json.zst" - fallback_file.write_bytes(b"mock_zst_content") - - created_temp_files = [] - original_named_tempfile = tempfile.NamedTemporaryFile - - def named_tempfile_side_effect(*args, **kwargs): - temp_file = original_named_tempfile(*args, **kwargs) - created_temp_files.append(temp_file) - return temp_file - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch( - "hyperbench.data.dataset.hf_hub_download", - return_value=str(fallback_file), - ) as mock_hf_hub_download, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - patch( - "hyperbench.data.dataset.tempfile.NamedTemporaryFile", - side_effect=named_tempfile_side_effect, - ), - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - mock_response = mock_get.return_value - mock_response.status_code = 404 - mock_response.content = b"" - - def fake_copy_stream(src, dst): - dst.write(b'{"network_type":"undirected","nodes":[],"edges":[],"incidences":[]}') - return - - mock_decomp.return_value.copy_stream.side_effect = fake_copy_stream - - with pytest.warns( - UserWarning, - match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", - ): - hypergraph = HIFConverter.load_from_hif(dataset_name, save_on_disk=False) - - assert hypergraph.network_type == "undirected" - mock_get.assert_called_once() - mock_hf_hub_download.assert_called_once() - assert created_temp_files[0].name is not None - assert fallback_file.read_bytes() == b"mock_zst_content" - - -def test_HIFConverter_download_raises_when_network_error(): - dataset_name = "ALGEBRA" - - with ( - patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=False), - ): - # Mock network error - mock_get.side_effect = 
requests.RequestException("Network error") - - with pytest.raises(requests.RequestException, match="Network error"): - HIFConverter.load_from_hif(dataset_name) - - -def test_init_with_prepare_false_and_no_hdata_raises(): - with pytest.raises(ValueError, match="hdata must be provided when prepare is set to False."): - Dataset(hdata=None, prepare=False) - - -def test_dataset_is_not_available(): - class FakeMockDataset(Dataset): - DATASET_NAME = "FAKE" - - with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): - FakeMockDataset() - - -@pytest.mark.parametrize( - "strategy, expected_len", - [ - pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), - ], -) -def test_dataset_is_available_with_all_strategies( - strategy, expected_len, mock_four_node_hypergraph -): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - assert dataset.DATASET_NAME == "ALGEBRA" - assert dataset.hypergraph is not None - assert len(dataset) == expected_len - - -def test_download_already_downloaded_dataset_uses_local_value(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() - - hg1 = dataset.download() - hg2 = dataset.download() - - assert hg1 is hg2 - - -def test_throw_when_dataset_name_is_none(): - class FakeMockDataset(Dataset): - DATASET_NAME = None - - with pytest.raises( - ValueError, - match=r"Dataset name \(provided: None\) must be provided\.", - ): - FakeMockDataset() - - -def test_dataset_process_no_incidences(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata is not None - assert dataset.hdata.x.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 2 - assert dataset.hdata.hyperedge_attr is not None - assert dataset.hdata.hyperedge_attr.shape == (2, 0) - assert dataset.hdata.hyperedge_attr[0].shape == (0,) - assert dataset.hdata.hyperedge_attr[1].shape == (0,) - - -def test_dataset_process_with_edge_attributes(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0, "type": 2.0}}, - {"edge": "1", "attrs": {"weight": 3.0, "type": 0.1}}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata is not None - assert dataset.hdata.x.shape[0] == 3 - assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 3 - assert dataset.hdata.hyperedge_attr is not None - # Two edges with two attributes each: shape [2, 2] - assert dataset.hdata.hyperedge_attr.shape == (2, 2) - # Attributes maintain dictionary insertion order (no sorting) - - assert torch.allclose(dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 2.0])) # weight, type - assert torch.allclose(dataset.hdata.hyperedge_attr[1], 
torch.tensor([3.0, 0.1])) # weight, type - - -def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata is not None - assert dataset.hdata.hyperedge_index.shape[0] == 2 - assert dataset.hdata.hyperedge_index.shape[1] == 2 - assert dataset.hdata.hyperedge_attr is None - - -def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata.hyperedge_index.shape == (2, 4) - assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) - assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) - - -def test_dataset_process_random_ids(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "abc", "attrs": {}}, - {"node": "ss", "attrs": {}}, - {"node": "fewao", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], - incidences=[ - {"node": "abc", "edge": "0"}, - {"node": "ss", "edge": "0"}, - {"node": "fewao", "edge": "1"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata.hyperedge_index.shape == (2, 3) - assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) - assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) - assert dataset.hdata.hyperedge_attr is not None - assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each - - -@pytest.mark.parametrize( - "strategy", - [ - pytest.param(SamplingStrategy.NODE, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), - ], -) -def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_simple_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - with pytest.raises(ValueError, match="Index list cannot be empty."): - dataset[[]] - - -@pytest.mark.parametrize( - "strategy, index_list, expected_message", - [ - pytest.param( - SamplingStrategy.NODE, - [0, 1, 2, 3, 4], - r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", - id="node_strategy", - ), - pytest.param( - SamplingStrategy.HYPEREDGE, - [0, 1, 2], - r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", - id="hyperedge_strategy", - ), - ], -) -def test_getitem_raises_when_index_list_larger_than_max( - mock_four_node_hypergraph, strategy, index_list, expected_message -): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - with pytest.raises(ValueError, match=expected_message): - dataset[index_list] - - -@pytest.mark.parametrize( - "strategy, index, expected_message", - [ - pytest.param( - SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" - ), - pytest.param( - SamplingStrategy.HYPEREDGE, - 2, - r"Hyperedge ID 2 is out of bounds \(0, 1\)\.", - id="hyperedge_strategy", - ), - ], -) -def test_getitem_raises_when_index_out_of_bounds( - mock_four_node_hypergraph, strategy, index, expected_message -): - with patch.object(HIFConverter, "load_from_hif", 
return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - with pytest.raises(IndexError, match=expected_message): - dataset[index] - - -@pytest.mark.parametrize( - "strategy, index, expected_shape, expected_num_hyperedges", - [ - # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge - pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), - # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge - pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), - ], -) -def test_getitem_single_index( - mock_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges -): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_sample_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - data = dataset[index] - - assert data.hyperedge_index.shape == expected_shape - assert data.num_hyperedges == expected_num_hyperedges - - -@pytest.mark.parametrize( - "strategy, index, expected_shape, expected_num_hyperedges", - [ - # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges - pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"), - # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges - pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"), - ], -) -def test_getitem_when_list_index_provided( - mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges -): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - data = dataset[index] - - assert data.hyperedge_index.shape == expected_shape - assert data.num_hyperedges == expected_num_hyperedges - - -@pytest.mark.parametrize( - "strategy", - [ - pytest.param(SamplingStrategy.NODE, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), - ], -) -def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): - with patch.object( - HIFConverter, "load_from_hif", return_value=mock_three_node_weighted_hypergraph - ): - dataset = AlgebraDataset(sampling_strategy=strategy) - - data = dataset[0] - - assert data.hyperedge_index.shape == (2, 2) - assert data.num_hyperedges == 1 - assert data.hyperedge_attr is None - - -@pytest.mark.parametrize( - "strategy", - [ - pytest.param(SamplingStrategy.NODE, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), - ], -) -def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) - - data = dataset[0] - assert data.hyperedge_attr is None - - -@pytest.mark.parametrize( - "strategy, index", - [ - # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges - pytest.param(SamplingStrategy.NODE, [0, 2], id="node_strategy"), - # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges - pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), - ], -) -def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index): - with patch.object( - HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph - ): - dataset = 
AlgebraDataset(sampling_strategy=strategy) - - data = dataset[index] - assert data.num_hyperedges == 2 - - # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None - # as the hyperedge attributes are handled by the loader's collate function during batching - assert data.hyperedge_attr is None - - -def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0, "abc": 5.0}}, - {"edge": "1", "attrs": {"weight": 2.0}}, # Missing 'abc' - {"edge": "2", "attrs": {"abc": 3.0}}, # Missing 'weight' - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - {"node": "3", "edge": "2"}, - ], - ) +def test_Preloaded_dataset_init(): + mock_hdata = MagicMock(spec=HData) + dataset = PreloadedDataset(hdata=mock_hdata) + + assert dataset.hdata == mock_hdata + assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE + +def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): + mock_hdata = MagicMock(spec=HData) + with patch.object(HIFLoader, "load", return_value=mock_hdata) as mock_load: + dataset = AlgebraDataset(hdata=None) + + assert dataset.hdata == mock_hdata + mock_load.assert_called_once_with("algebra", save_on_disk=True) + +# def test_HIFLoader_num_nodes_and_edges(): +# dataset_name = "ALGEBRA" +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": str(i)} for i in range(20)], +# edges=[{"edge": str(i)} for i in range(30)], +# incidences=[{"node": "0", "edge": "0"}], +# ) + +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# hypergraph = HIFLoader.load(dataset_name) + +# assert hypergraph is not None +# assert hasattr(hypergraph, "nodes") +# assert hasattr(hypergraph, "hyperedges") +# assert hasattr(hypergraph, "incidences") +# assert hasattr(hypergraph, "metadata") +# assert hasattr(hypergraph, "network_type") + +# assert hypergraph.num_nodes == 20 +# assert hypergraph.num_hyperedges == 30 + + +# def test_HIFLoader_loads_invalid_dataset(): +# dataset_name = "INVALID_DATASET" + +# with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): +# HIFLoader.load(dataset_name) + + +# def test_HIFLoader_loads_invalid_hif_format(): +# dataset_name = "ALGEBRA" + +# invalid_hif_json = '{"network-type": "undirected", "nodes": []}' + +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch("hyperbench.data.dataset.validate_hif_json", return_value=False), +# patch("builtins.open", mock_open(read_data=invalid_hif_json)), +# patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), +# ): +# mock_response = mock_get.return_value +# mock_response.status_code = 200 +# mock_response.content = b"mock_zst_content" + +# with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): +# HIFLoader.load(dataset_name) + + +# def test_HIFLoader_stores_on_disk_when_save_on_disk_true(): +# dataset_name = "ALGEBRA" + +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": "0"}, {"node": "1"}], +# hyperedges=[{"edge": "0"}], +# incidences=[{"node": "0", "edge": "0"}], +# ) + +# mock_hif_json = { +# "network-type": "undirected", +# "nodes": [{"node": "0"}, {"node": "1"}], +# "edges": [{"edge": "0"}], +# "incidences": 
[{"node": "0", "edge": "0"}], +# } + +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# patch("hyperbench.data.dataset.os.makedirs"), +# patch("builtins.open", mock_open()) as mock_file, +# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), +# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# ): +# # Mock successful download +# mock_response = mock_get.return_value +# mock_response.status_code = 200 +# mock_response.content = b"mock_zst_content" + +# # Mock decompressor +# mock_stream = mock_decomp.return_value.stream_reader.return_value +# mock_stream.__enter__ = lambda self: mock_stream +# mock_stream.__exit__ = lambda self, *args: None + +# hypergraph = HIFLoader.load(dataset_name, save_on_disk=True) + +# assert hypergraph is not None +# assert hypergraph.network_type == "undirected" +# mock_get.assert_called_once() +# # Verify file was written to disk (not temp file) +# assert mock_file.call_count >= 2 # Once for write, once for read + + +# def test_HIFLoader_uses_temp_file_when_save_on_disk_false(): +# dataset_name = "ALGEBRA" + +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": "0"}, {"node": "1"}], +# hyperedges=[{"edge": "0"}], +# incidences=[{"node": "0", "edge": "0"}], +# ) + +# mock_hif_json = { +# "network-type": "undirected", +# "nodes": [{"node": "0"}, {"node": "1"}], +# "edges": [{"edge": "0"}], +# "incidences": [{"node": "0", "edge": "0"}], +# } + +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, +# patch("builtins.open", mock_open()), +# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), +# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# ): +# # Mock successful download +# mock_response = mock_get.return_value +# mock_response.status_code = 200 +# mock_response.content = b"mock_zst_content" + +# # Mock temp file +# mock_temp_file = mock_temp.return_value.__enter__.return_value +# mock_temp_file.name = "/tmp/fake_temp.json.zst" + +# # Mock decompressor +# mock_stream = mock_decomp.return_value.stream_reader.return_value +# mock_stream.__enter__ = lambda self: mock_stream +# mock_stream.__exit__ = lambda self, *args: None + +# hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) + +# assert hypergraph is not None +# assert hypergraph.network_type == "undirected" +# mock_get.assert_called_once() +# # Verify temp file was used +# assert mock_temp.call_count >= 1 + + +# def test_HIFLoader_download_failure(): +# dataset_name = "ALGEBRA" + +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch("hyperbench.data.dataset.hf_hub_download", side_effect=Exception("HFHub failed")), +# patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# ): +# # Mock failed download +# mock_response = mock_get.return_value +# mock_response.status_code = 404 +# mock_response.content = b"" + +# with pytest.warns( +# UserWarning, +# match=r"(?s)GitHub raw download failed for 
dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", +# ): +# with pytest.raises( +# ValueError, +# match=r"Failed to download dataset 'algebra'", +# ): +# HIFLoader.load(dataset_name) + + +# def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( +# tmp_path, mock_sample_hypergraph +# ): +# dataset_name = "ALGEBRA" + +# mock_hypergraph = mock_sample_hypergraph + +# mock_hif_json = { +# "network_type": "undirected", +# "nodes": [{"node": "0"}, {"node": "1"}], +# "edges": [{"edge": "0"}], +# "incidences": [{"node": "0", "edge": "0"}], +# } + +# fallback_file = tmp_path / "algebra.json.zst" +# fallback_file.write_bytes(b"mock_zst_content") + +# created_temp_files = [] +# original_named_tempfile = tempfile.NamedTemporaryFile + +# def named_tempfile_side_effect(*args, **kwargs): +# temp_file = original_named_tempfile(*args, **kwargs) +# created_temp_files.append(temp_file) +# return temp_file + +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch( +# "hyperbench.data.dataset.hf_hub_download", +# return_value=str(fallback_file), +# ) as mock_hf_hub_download, +# patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# patch( +# "hyperbench.data.dataset.tempfile.NamedTemporaryFile", +# side_effect=named_tempfile_side_effect, +# ), +# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), +# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# ): +# mock_response = mock_get.return_value +# mock_response.status_code = 404 +# mock_response.content = b"" + +# def fake_copy_stream(src, dst): +# dst.write(b'{"network_type":"undirected","nodes":[],"edges":[],"incidences":[]}') +# return + +# mock_decomp.return_value.copy_stream.side_effect = fake_copy_stream + +# with pytest.warns( +# UserWarning, +# match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", +# ): +# hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) + +# assert hypergraph.network_type == "undirected" +# mock_get.assert_called_once() +# mock_hf_hub_download.assert_called_once() +# assert created_temp_files[0].name is not None +# assert fallback_file.read_bytes() == b"mock_zst_content" + + +# def test_HIFLoader_download_raises_when_network_error(): +# dataset_name = "ALGEBRA" + +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# ): +# # Mock network error +# mock_get.side_effect = requests.RequestException("Network error") + +# with pytest.raises(requests.RequestException, match="Network error"): +# HIFLoader.load(dataset_name) + + +# def test_init_with_prepare_false_and_no_hdata_raises(): +# with pytest.raises(ValueError, match="hdata must be provided when prepare is set to False."): +# Dataset(hdata=None, prepare=False) + + +# def test_dataset_is_not_available(): +# class FakeMockDataset(Dataset): +# DATASET_NAME = "FAKE" + +# with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): +# FakeMockDataset() + + +# @pytest.mark.parametrize( +# "strategy, expected_len", +# [ +# pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), +# pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), +# ], +# ) +# def 
test_dataset_is_available_with_all_strategies( +# strategy, expected_len, mock_four_node_hypergraph +# ): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset(sampling_strategy=strategy) + +# assert dataset.DATASET_NAME == "ALGEBRA" +# assert dataset.hypergraph is not None +# assert len(dataset) == expected_len + + +# def test_download_already_downloaded_dataset_uses_local_value(mock_four_node_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() + +# hg1 = dataset.download() +# hg2 = dataset.download() + +# assert hg1 is hg2 + + +# def test_throw_when_dataset_name_is_none(): +# class FakeMockDataset(Dataset): +# DATASET_NAME = None + +# with pytest.raises( +# ValueError, +# match=r"Dataset name \(provided: None\) must be provided\.", +# ): +# FakeMockDataset() + + +# def test_dataset_process_no_incidences(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], +# hyperedges=[{"edge": "0", "attrs": {}}], +# incidences=[], +# ) + +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# dataset = AlgebraDataset() + +# assert dataset.hdata is not None +# assert dataset.hdata.x.shape[0] == 2 +# assert dataset.hdata.hyperedge_index.shape[0] == 2 +# assert dataset.hdata.hyperedge_index.shape[1] == 2 +# assert dataset.hdata.hyperedge_attr is not None +# assert dataset.hdata.hyperedge_attr.shape == (2, 0) +# assert dataset.hdata.hyperedge_attr[0].shape == (0,) +# assert dataset.hdata.hyperedge_attr[1].shape == (0,) + + +# def test_dataset_process_with_edge_attributes(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "0", "attrs": {}}, +# {"node": "1", "attrs": {}}, +# {"node": "2", "attrs": {}}, +# ], +# hyperedges=[ +# {"edge": "0", "attrs": {"weight": 1.0, "type": 2.0}}, +# {"edge": "1", "attrs": {"weight": 3.0, "type": 0.1}}, +# ], +# incidences=[ +# {"node": "0", "edge": "0"}, +# {"node": "1", "edge": "0"}, +# {"node": "2", "edge": "1"}, +# ], +# ) + +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# dataset = AlgebraDataset() + +# assert dataset.hdata is not None +# assert dataset.hdata.x.shape[0] == 3 +# assert dataset.hdata.hyperedge_index.shape[0] == 2 +# assert dataset.hdata.hyperedge_index.shape[1] == 3 +# assert dataset.hdata.hyperedge_attr is not None +# # Two edges with two attributes each: shape [2, 2] +# assert dataset.hdata.hyperedge_attr.shape == (2, 2) +# # Attributes maintain dictionary insertion order (no sorting) + +# assert torch.allclose(dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 2.0])) # weight, type +# assert torch.allclose(dataset.hdata.hyperedge_attr[1], torch.tensor([3.0, 0.1])) # weight, type + + +# def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): +# dataset = AlgebraDataset() + +# assert dataset.hdata is not None +# assert dataset.hdata.hyperedge_index.shape[0] == 2 +# assert dataset.hdata.hyperedge_index.shape[1] == 2 +# assert dataset.hdata.hyperedge_attr is None + + +# def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() + +# assert dataset.hdata.hyperedge_index.shape == (2, 4) +# assert 
torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) +# assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) + + +# def test_dataset_process_random_ids(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "abc", "attrs": {}}, +# {"node": "ss", "attrs": {}}, +# {"node": "fewao", "attrs": {}}, +# ], +# hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], +# incidences=[ +# {"node": "abc", "edge": "0"}, +# {"node": "ss", "edge": "0"}, +# {"node": "fewao", "edge": "1"}, +# ], +# ) + +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# dataset = AlgebraDataset() + +# assert dataset.hdata.hyperedge_index.shape == (2, 3) +# assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) +# assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) +# assert dataset.hdata.hyperedge_attr is not None +# assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each + + +# @pytest.mark.parametrize( +# "strategy", +# [ +# pytest.param(SamplingStrategy.NODE, id="node_strategy"), +# pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), +# ], +# ) +# def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): +# with patch.object(HIFLoader, "load", return_value=mock_simple_hypergraph): +# dataset = AlgebraDataset(sampling_strategy=strategy) + +# with pytest.raises(ValueError, match="Index list cannot be empty."): +# dataset[[]] + + +# @pytest.mark.parametrize( +# "strategy, index_list, expected_message", +# [ +# pytest.param( +# SamplingStrategy.NODE, +# [0, 1, 2, 3, 4], +# r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", +# id="node_strategy", +# ), +# pytest.param( +# SamplingStrategy.HYPEREDGE, +# [0, 1, 2], +# r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", +# id="hyperedge_strategy", +# ), +# ], +# ) +# def test_getitem_raises_when_index_list_larger_than_max( +# mock_four_node_hypergraph, strategy, index_list, expected_message +# ): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset(sampling_strategy=strategy) + +# with pytest.raises(ValueError, match=expected_message): +# dataset[index_list] + + +# @pytest.mark.parametrize( +# "strategy, index, expected_message", +# [ +# pytest.param( +# SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" +# ), +# pytest.param( +# SamplingStrategy.HYPEREDGE, +# 2, +# r"Hyperedge ID 2 is out of bounds \(0, 1\)\.", +# id="hyperedge_strategy", +# ), +# ], +# ) +# def test_getitem_raises_when_index_out_of_bounds( +# mock_four_node_hypergraph, strategy, index, expected_message +# ): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset(sampling_strategy=strategy) + +# with pytest.raises(IndexError, match=expected_message): +# dataset[index] + + +# @pytest.mark.parametrize( +# "strategy, index, expected_shape, expected_num_hyperedges", +# [ +# # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge +# pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), +# # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge +# pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), +# ], +# ) +# def test_getitem_single_index( +# mock_sample_hypergraph, strategy, 
index, expected_shape, expected_num_hyperedges
+# ):
+#     with patch.object(HIFLoader, "load", return_value=mock_sample_hypergraph):
+#         dataset = AlgebraDataset(sampling_strategy=strategy)
+
+#         data = dataset[index]
+
+#         assert data.hyperedge_index.shape == expected_shape
+#         assert data.num_hyperedges == expected_num_hyperedges
+
+
+# @pytest.mark.parametrize(
+#     "strategy, index, expected_shape, expected_num_hyperedges",
+#     [
+#         # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges
+#         pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"),
+#         # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges
+#         pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"),
+#     ],
+# )
+# def test_getitem_when_list_index_provided(
+#     mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges
+# ):
+#     with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph):
+#         dataset = AlgebraDataset(sampling_strategy=strategy)
+
+#         data = dataset[index]
+
+#         assert data.hyperedge_index.shape == expected_shape
+#         assert data.num_hyperedges == expected_num_hyperedges
+
+
+# @pytest.mark.parametrize(
+#     "strategy",
+#     [
+#         pytest.param(SamplingStrategy.NODE, id="node_strategy"),
+#         pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"),
+#     ],
+# )
+# def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy):
+#     with patch.object(
+#         HIFLoader, "load", return_value=mock_three_node_weighted_hypergraph
+#     ):
+#         dataset = AlgebraDataset(sampling_strategy=strategy)
+
+#         data = dataset[0]
+
+#         assert data.hyperedge_index.shape == (2, 2)
+#         assert data.num_hyperedges == 1
+#         assert data.hyperedge_attr is None
+
+
+# @pytest.mark.parametrize(
+#     "strategy",
+#     [
+#         pytest.param(SamplingStrategy.NODE, id="node_strategy"),
+#         pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"),
+#     ],
+# )
+# def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy):
+#     with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph):
+#         dataset = AlgebraDataset(sampling_strategy=strategy)
+
+#         data = dataset[0]
+#         assert data.hyperedge_attr is None
+
+
+# @pytest.mark.parametrize(
+#     "strategy, index",
+#     [
+#         # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges
+#         pytest.param(SamplingStrategy.NODE, [0, 2], id="node_strategy"),
+#         # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges
+#         pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"),
+#     ],
+# )
+# def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index):
+#     with patch.object(
+#         HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph
+#     ):
+#         dataset = AlgebraDataset(sampling_strategy=strategy)
+
+#         data = dataset[index]
+#         assert data.num_hyperedges == 2
+
+#         # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None
+#         # as the hyperedge attributes are handled by the loader's collate function during batching
+#         assert data.hyperedge_attr is None
+
+
+# def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges():
+#     mock_hypergraph = HIFHypergraph(
+#         network_type="undirected",
+#         nodes=[
+#             {"node": "0", "attrs": {}},
+#             {"node": "1", "attrs": {}},
+#             {"node": "2", "attrs": {}},
+#             {"node": "3", "attrs": {}},
+#         ],
+#         hyperedges=[
+#             {"edge": "0", "attrs": 
{"weight": 1.0, "abc": 5.0}}, +# {"edge": "1", "attrs": {"weight": 2.0}}, # Missing 'abc' +# {"edge": "2", "attrs": {"abc": 3.0}}, # Missing 'weight' +# ], +# incidences=[ +# {"node": "0", "edge": "0"}, +# {"node": "1", "edge": "0"}, +# {"node": "2", "edge": "1"}, +# {"node": "3", "edge": "2"}, +# ], +# ) with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): dataset = AlgebraDataset() @@ -814,195 +828,195 @@ def test_process_extracts_top_level_hyperedge_weights(): assert torch.allclose(hyperedge_weights[2], torch.tensor([2.5])) -def test_transform_attrs_empty_attrs(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) +# def test_transform_attrs_empty_attrs(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": "0", "attrs": {}}], +# hyperedges=[{"edge": "0", "attrs": {}}], +# incidences=[{"node": "0", "edge": "0"}], +# ) - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): + +# class TestDataset(Dataset): +# DATASET_NAME = "TEST" - class TestDataset(Dataset): - DATASET_NAME = "TEST" +# dataset = TestDataset() - dataset = TestDataset() +# result = dataset.transform_attrs({}) +# assert len(result) == 0 - result = dataset.transform_attrs({}) - assert len(result) == 0 +# attrs = {"name": "node1", "active": True} +# result = dataset.transform_attrs(attrs) +# assert len(result) == 0 - attrs = {"name": "node1", "active": True} - result = dataset.transform_attrs(attrs) - assert len(result) == 0 +# def test_process_adds_padding_zero_when_inconsistent_node_attributes(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "0", "attrs": {"weight": 1.0}}, # Missing 'score' +# {"node": "1", "attrs": {"weight": 2.0, "score": 0.8}}, +# {"node": "2", "attrs": {"score": 0.5}}, # Missing 'weight' +# ], +# hyperedges=[{"edge": "0", "attrs": {}}], +# incidences=[ +# {"node": "0", "edge": "0"}, +# {"node": "1", "edge": "0"}, +# {"node": "2", "edge": "0"}, +# ], +# ) + +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -def test_process_adds_padding_zero_when_inconsistent_node_attributes(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"weight": 1.0}}, # Missing 'score' - {"node": "1", "attrs": {"weight": 2.0, "score": 0.8}}, - {"node": "2", "attrs": {"score": 0.5}}, # Missing 'weight' - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "0"}, - ], - ) +# class TestDataset(Dataset): +# DATASET_NAME = "TEST" + +# dataset = TestDataset() + +# assert dataset.hdata.x.shape == ( +# 3, +# 2, +# ) # 3 nodes, 2 features each (weight, score in insertion order) +# assert torch.allclose(dataset.hdata.x[0], torch.tensor([1.0, 0.0])) # weight=1.0, score=0.0 +# assert torch.allclose(dataset.hdata.x[1], torch.tensor([2.0, 0.8])) # weight=2.0, score=0.8 +# assert torch.allclose(dataset.hdata.x[2], torch.tensor([0.0, 0.5])) # weight=0.0, score=0.5 + + +# def test_process_with_no_node_attributes_fallback_to_one(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "0", "attrs": {"name": "node0"}}, +# {"node": "1", "attrs": {}}, +# ], +# hyperedges=[{"edge": "0", "attrs": {}}], +# 
incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], +# ) - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - class TestDataset(Dataset): - DATASET_NAME = "TEST" +# class TestDataset(Dataset): +# DATASET_NAME = "TEST" - dataset = TestDataset() +# dataset = TestDataset() - assert dataset.hdata.x.shape == ( - 3, - 2, - ) # 3 nodes, 2 features each (weight, score in insertion order) - assert torch.allclose(dataset.hdata.x[0], torch.tensor([1.0, 0.0])) # weight=1.0, score=0.0 - assert torch.allclose(dataset.hdata.x[1], torch.tensor([2.0, 0.8])) # weight=2.0, score=0.8 - assert torch.allclose(dataset.hdata.x[2], torch.tensor([0.0, 0.5])) # weight=0.0, score=0.5 +# assert dataset.hdata.x.shape == (2, 1) +# assert torch.allclose(dataset.hdata.x, torch.tensor([[1.0], [1.0]])) -def test_process_with_no_node_attributes_fallback_to_one(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"name": "node0"}}, - {"node": "1", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - - class TestDataset(Dataset): - DATASET_NAME = "TEST" - - dataset = TestDataset() - - assert dataset.hdata.x.shape == (2, 1) - assert torch.allclose(dataset.hdata.x, torch.tensor([[1.0], [1.0]])) - - -def test_process_with_single_node_attribute(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {"weight": 1.5}}, - {"node": "1", "attrs": {"weight": 2.5}}, - {"node": "2", "attrs": {"weight": 3.5}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "0"}, - ], - ) +# def test_process_with_single_node_attribute(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "0", "attrs": {"weight": 1.5}}, +# {"node": "1", "attrs": {"weight": 2.5}}, +# {"node": "2", "attrs": {"weight": 3.5}}, +# ], +# hyperedges=[{"edge": "0", "attrs": {}}], +# incidences=[ +# {"node": "0", "edge": "0"}, +# {"node": "1", "edge": "0"}, +# {"node": "2", "edge": "0"}, +# ], +# ) - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - class TestDataset(Dataset): - DATASET_NAME = "TEST" +# class TestDataset(Dataset): +# DATASET_NAME = "TEST" - dataset = TestDataset() +# dataset = TestDataset() - # Single attribute should remain 2D: [num_nodes, 1] - assert dataset.hdata.x.shape == (3, 1) - assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) +# # Single attribute should remain 2D: [num_nodes, 1] +# assert dataset.hdata.x.shape == (3, 1) +# assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) -def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) +# def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": "0", "attrs": {}}], +# hyperedges=[{"edge": "0", "attrs": {}}], +# incidences=[{"node": "0", "edge": "0"}], +# 
) - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - class TestDataset(Dataset): - DATASET_NAME = "TEST" +# class TestDataset(Dataset): +# DATASET_NAME = "TEST" - dataset = TestDataset() +# dataset = TestDataset() - # Test with attr_keys - should pad missing attributes with 0.0 - attrs = {"weight": 1.5} - result = dataset.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) - assert torch.allclose( - result, torch.tensor([0.0, 1.5, 0.0]) - ) # score=0.0, weight=1.5, age=0.0 +# # Test with attr_keys - should pad missing attributes with 0.0 +# attrs = {"weight": 1.5} +# result = dataset.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) +# assert torch.allclose( +# result, torch.tensor([0.0, 1.5, 0.0]) +# ) # score=0.0, weight=1.5, age=0.0 - # Test with all attributes present - attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} - result = dataset.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) - assert torch.allclose( - result, torch.tensor([25.0, 0.8, 1.5]) - ) # age=25.0, score=0.8, weight=1.5 +# # Test with all attributes present +# attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} +# result = dataset.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) +# assert torch.allclose( +# result, torch.tensor([25.0, 0.8, 1.5]) +# ) # age=25.0, score=0.8, weight=1.5 - # Test without attr_keys - maintains insertion order - attrs = {"weight": 1.5, "score": 0.8} - result = dataset.transform_attrs(attrs) - assert torch.allclose(result, torch.tensor([1.5, 0.8])) # weight, score (insertion order) +# # Test without attr_keys - maintains insertion order +# attrs = {"weight": 1.5, "score": 0.8} +# result = dataset.transform_attrs(attrs) +# assert torch.allclose(result, torch.tensor([1.5, 0.8])) # weight, score (insertion order) -@pytest.mark.parametrize( - "strategy, expected_len", - [ - # mock_hdata: 3 nodes, 2 hyperedges - pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), - ], -) -def test_from_hdata(strategy, expected_len, mock_hdata): - dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) +# @pytest.mark.parametrize( +# "strategy, expected_len", +# [ +# # mock_hdata: 3 nodes, 2 hyperedges +# pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), +# pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), +# ], +# ) +# def test_from_hdata(strategy, expected_len, mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) - assert dataset.hdata is mock_hdata - assert len(dataset) == expected_len +# assert dataset.hdata is mock_hdata +# assert len(dataset) == expected_len -def test_from_hdata_download_raises(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) +# def test_from_hdata_download_raises(mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata) - with pytest.raises(ValueError, match="download can only be called for the original dataset."): - dataset.download() +# with pytest.raises(ValueError, match="download can only be called for the original dataset."): +# dataset.download() -def test_from_hdata_process_raises(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) +# def test_from_hdata_process_raises(mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata) - with pytest.raises(ValueError, match="process can only be called for the original dataset."): - dataset.process() +# with pytest.raises(ValueError, 
match="process can only be called for the original dataset."): +# dataset.process() -def test_enrich_node_features_replace(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) +# def test_enrich_node_features_replace(mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata) - enricher = MagicMock(spec=NodeEnricher) - enriched_x = torch.randn(3, 4) - enricher.enrich.return_value = enriched_x +# enricher = MagicMock(spec=NodeEnricher) +# enriched_x = torch.randn(3, 4) +# enricher.enrich.return_value = enriched_x - dataset.enrich_node_features(enricher) +# dataset.enrich_node_features(enricher) - enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) - assert torch.equal(dataset.hdata.x, enriched_x) +# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) +# assert torch.equal(dataset.hdata.x, enriched_x) -def test_enrich_node_features_concatenate(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) - original_x = dataset.hdata.x.clone() +# def test_enrich_node_features_concatenate(mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata) +# original_x = dataset.hdata.x.clone() - enricher = MagicMock(spec=NodeEnricher) - enriched_x = torch.randn(3, 4) - enricher.enrich.return_value = enriched_x +# enricher = MagicMock(spec=NodeEnricher) +# enriched_x = torch.randn(3, 4) +# enricher.enrich.return_value = enriched_x dataset.enrich_node_features(enricher, enrichment_mode="concatenate") @@ -1082,341 +1096,341 @@ def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights) assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched -@pytest.mark.parametrize( - "hyperedge_index, k, expected_hyperedge_index", - [ - pytest.param( - torch.tensor([[0, 1, 2], [0, 0, 0]]), - 4, - torch.zeros((2, 0), dtype=torch.long), - id="single_hyperedge_below_k_removed", - ), - pytest.param( - torch.tensor([[0, 1, 2], [0, 0, 0]]), - 3, - torch.tensor([[0, 1, 2], [0, 0, 0]]), - id="single_hyperedge_at_exact_k_kept", - ), - pytest.param( - torch.tensor([[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]), - 3, - torch.tensor([[0, 1, 2], [0, 0, 0]]), - id="two_hyperedges_first_kept_second_removed", - ), - pytest.param( - torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), - 3, - torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), - id="two_hyperedges_both_kept", - ), - pytest.param( - torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 2, 2]]), - 3, - torch.zeros((2, 0), dtype=torch.long), - id="three_hyperedges_all_removed", - ), - ], -) -def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_hyperedge_index): - num_nodes = hyperedge_index[0].max().item() + 1 if hyperedge_index.shape[1] > 0 else 0 - x = torch.ones((num_nodes, 1), dtype=torch.float) - hdata = HData(x=x, hyperedge_index=hyperedge_index) - dataset = Dataset.from_hdata(hdata) - - dataset.remove_hyperedges_with_fewer_than_k_nodes(k) - - expected_num_nodes = expected_hyperedge_index[0].unique().shape[0] - expected_num_hyperedges = expected_hyperedge_index[1].unique().shape[0] - - assert torch.equal(dataset.hdata.hyperedge_index, expected_hyperedge_index) - assert dataset.hdata.x.shape[0] == expected_num_nodes - assert dataset.hdata.y.shape[0] == expected_num_hyperedges - - -def test_split_with_equal_ratios(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() +# @pytest.mark.parametrize( +# "hyperedge_index, k, expected_hyperedge_index", +# [ +# pytest.param( +# torch.tensor([[0, 1, 2], [0, 0, 0]]), 
+# 4, +# torch.zeros((2, 0), dtype=torch.long), +# id="single_hyperedge_below_k_removed", +# ), +# pytest.param( +# torch.tensor([[0, 1, 2], [0, 0, 0]]), +# 3, +# torch.tensor([[0, 1, 2], [0, 0, 0]]), +# id="single_hyperedge_at_exact_k_kept", +# ), +# pytest.param( +# torch.tensor([[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]), +# 3, +# torch.tensor([[0, 1, 2], [0, 0, 0]]), +# id="two_hyperedges_first_kept_second_removed", +# ), +# pytest.param( +# torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), +# 3, +# torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), +# id="two_hyperedges_both_kept", +# ), +# pytest.param( +# torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 2, 2]]), +# 3, +# torch.zeros((2, 0), dtype=torch.long), +# id="three_hyperedges_all_removed", +# ), +# ], +# ) +# def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_hyperedge_index): +# num_nodes = hyperedge_index[0].max().item() + 1 if hyperedge_index.shape[1] > 0 else 0 +# x = torch.ones((num_nodes, 1), dtype=torch.float) +# hdata = HData(x=x, hyperedge_index=hyperedge_index) +# dataset = Dataset.from_hdata(hdata) + +# dataset.remove_hyperedges_with_fewer_than_k_nodes(k) + +# expected_num_nodes = expected_hyperedge_index[0].unique().shape[0] +# expected_num_hyperedges = expected_hyperedge_index[1].unique().shape[0] + +# assert torch.equal(dataset.hdata.hyperedge_index, expected_hyperedge_index) +# assert dataset.hdata.x.shape[0] == expected_num_nodes +# assert dataset.hdata.y.shape[0] == expected_num_hyperedges + + +# def test_split_with_equal_ratios(mock_four_node_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() - splits = dataset.split([0.5, 0.5]) +# splits = dataset.split([0.5, 0.5]) - assert len(splits) == 2 - assert ( - splits[0].hdata.num_hyperedges + splits[1].hdata.num_hyperedges - == dataset.hdata.num_hyperedges - ) - for split in splits: - assert split.hdata.x is not None - assert split.hdata.num_nodes > 0 - assert split.hdata.num_hyperedges > 0 +# assert len(splits) == 2 +# assert ( +# splits[0].hdata.num_hyperedges + splits[1].hdata.num_hyperedges +# == dataset.hdata.num_hyperedges +# ) +# for split in splits: +# assert split.hdata.x is not None +# assert split.hdata.num_nodes > 0 +# assert split.hdata.num_hyperedges > 0 -def test_split_three_way(mock_multiple_edges_attr_hypergraph): - with patch.object( - HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph - ): - dataset = AlgebraDataset() +# def test_split_three_way(mock_multiple_edges_attr_hypergraph): +# with patch.object( +# HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph +# ): +# dataset = AlgebraDataset() + +# splits = dataset.split([0.5, 0.25, 0.25]) +# total_edges = sum(split.hdata.num_hyperedges for split in splits) - splits = dataset.split([0.5, 0.25, 0.25]) - total_edges = sum(split.hdata.num_hyperedges for split in splits) +# assert len(splits) == 3 +# assert total_edges == dataset.hdata.num_hyperedges + +# for split in splits: +# assert split.hdata.x is not None +# assert split.hdata.num_nodes > 0 +# assert split.hdata.num_hyperedges > 0 + + +# def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() + +# with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): +# dataset.split([0.8, 0.1, 0.05]) + + +# def 
test_split_with_shuffle_produces_deterministic_results_when_seed_provided( +# mock_four_node_hypergraph, +# ): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() + +# splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) +# splits_b = dataset.split([0.5, 0.5], shuffle=True, seed=42) + +# assert torch.equal(splits_a[0].hdata.hyperedge_index, splits_b[0].hdata.hyperedge_index) +# assert torch.equal(splits_a[1].hdata.hyperedge_index, splits_b[1].hdata.hyperedge_index) + + +# def test_split_with_shuffle_when_no_seed_provided( +# mock_four_node_hypergraph, +# ): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() - assert len(splits) == 3 - assert total_edges == dataset.hdata.num_hyperedges +# splits = dataset.split([0.5, 0.5], shuffle=True) +# total_edges = sum(split.hdata.num_hyperedges for split in splits) - for split in splits: - assert split.hdata.x is not None - assert split.hdata.num_nodes > 0 - assert split.hdata.num_hyperedges > 0 +# assert len(splits) == 2 +# assert total_edges == dataset.hdata.num_hyperedges +# for split in splits: +# assert split.hdata.x is not None +# assert split.hdata.num_nodes > 0 +# assert split.hdata.num_hyperedges > 0 -def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() - with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): - dataset.split([0.8, 0.1, 0.05]) +# def test_split_preserves_edge_attr(mock_multiple_edges_attr_hypergraph): +# with patch.object( +# HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph +# ): +# dataset = AlgebraDataset() +# splits = dataset.split([0.5, 0.5]) -def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( - mock_four_node_hypergraph, -): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() +# for split in splits: +# assert split.hdata.hyperedge_attr is not None +# assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges - splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) - splits_b = dataset.split([0.5, 0.5], shuffle=True, seed=42) - assert torch.equal(splits_a[0].hdata.hyperedge_index, splits_b[0].hdata.hyperedge_index) - assert torch.equal(splits_a[1].hdata.hyperedge_index, splits_b[1].hdata.hyperedge_index) +# def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): +# dataset = AlgebraDataset() +# splits = dataset.split([0.5, 0.5]) -def test_split_with_shuffle_when_no_seed_provided( - mock_four_node_hypergraph, -): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() +# for split in splits: +# assert split.hdata.hyperedge_attr is None - splits = dataset.split([0.5, 0.5], shuffle=True) - total_edges = sum(split.hdata.num_hyperedges for split in splits) - assert len(splits) == 2 - assert total_edges == dataset.hdata.num_hyperedges +# def test_to_device(mock_hdata): +# device = torch.device("cpu") - for split in splits: - assert split.hdata.x is not None - assert split.hdata.num_nodes > 0 - assert split.hdata.num_hyperedges > 0 +# dataset = Dataset.from_hdata(mock_hdata) +# result = dataset.to(device) -def 
test_split_preserves_edge_attr(mock_multiple_edges_attr_hypergraph): - with patch.object( - HIFConverter, "load_from_hif", return_value=mock_multiple_edges_attr_hypergraph - ): - dataset = AlgebraDataset() +# assert result is dataset +# assert dataset.hdata.device == device - splits = dataset.split([0.5, 0.5]) - for split in splits: - assert split.hdata.hyperedge_attr is not None - assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges +# def test_load_skips_download_when_file_exists(): +# dataset_name = "ALGEBRA" +# sample_hif = { +# "network-type": "undirected", +# "nodes": [{"node": "0"}, {"node": "1"}], +# "edges": [{"edge": "0"}], +# "incidences": [{"node": "0", "edge": "0"}], +# } -def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_no_edge_attr_hypergraph): - dataset = AlgebraDataset() +# mock_hypergraph = HIFHypergraph( +# network_type="undirected", +# nodes=[{"node": "0"}, {"node": "1"}], +# hyperedges=[{"edge": "0"}], +# incidences=[{"node": "0", "edge": "0"}], +# ) - splits = dataset.split([0.5, 0.5]) +# with ( +# patch("hyperbench.data.dataset.requests.get") as mock_get, +# patch("hyperbench.data.dataset.os.path.exists", return_value=True), +# patch("builtins.open", mock_open()) as mock_file, +# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, +# patch("hyperbench.data.dataset.json.load", return_value=sample_hif), +# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# ): +# mock_dctx = mock_decomp.return_value +# mock_dctx.copy_stream = lambda input_f, tmp_file: None - for split in splits: - assert split.hdata.hyperedge_attr is None +# mock_temp_instance = mock_temp.return_value.__enter__.return_value +# mock_temp_instance.name = "/tmp/decompressed.json" +# result = HIFLoader.load(dataset_name, save_on_disk=True) +# mock_get.assert_not_called() +# assert result == mock_hypergraph -def test_to_device(mock_hdata): - device = torch.device("cpu") - - dataset = Dataset.from_hdata(mock_hdata) - - result = dataset.to(device) - assert result is dataset - assert dataset.hdata.device == device +# def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset() +# # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) +# assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE +# assert len(dataset) == 2 -def test_load_from_hif_skips_download_when_file_exists(): - dataset_name = "ALGEBRA" - sample_hif = { - "network-type": "undirected", - "nodes": [{"node": "0"}, {"node": "1"}], - "edges": [{"edge": "0"}], - "incidences": [{"node": "0", "edge": "0"}], - } +# def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) +# # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) +# assert dataset.sampling_strategy == SamplingStrategy.NODE +# assert len(dataset) == 4 - with ( - 
patch("hyperbench.data.dataset.requests.get") as mock_get, - patch("hyperbench.data.dataset.os.path.exists", return_value=True), - patch("builtins.open", mock_open()) as mock_file, - patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, - patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, - patch("hyperbench.data.dataset.json.load", return_value=sample_hif), - patch("hyperbench.data.dataset.validate_hif_json", return_value=True), - patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), - ): - mock_dctx = mock_decomp.return_value - mock_dctx.copy_stream = lambda input_f, tmp_file: None - - mock_temp_instance = mock_temp.return_value.__enter__.return_value - mock_temp_instance.name = "/tmp/decompressed.json" - - result = HIFConverter.load_from_hif(dataset_name, save_on_disk=True) - mock_get.assert_not_called() - assert result == mock_hypergraph - - -def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset() - # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) - assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE - assert len(dataset) == 2 +# @pytest.mark.parametrize( +# "strategy", +# [ +# pytest.param(SamplingStrategy.NODE, id="node_strategy"), +# pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), +# ], +# ) +# def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): +# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# dataset = AlgebraDataset(sampling_strategy=strategy) +# splits = dataset.split([0.5, 0.5]) -def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) +# for split in splits: +# assert split.sampling_strategy == strategy - # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) - assert dataset.sampling_strategy == SamplingStrategy.NODE - assert len(dataset) == 4 +# def test_from_hdata_with_explicit_strategy(mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) -@pytest.mark.parametrize( - "strategy", - [ - pytest.param(SamplingStrategy.NODE, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), - ], -) -def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): - with patch.object(HIFConverter, "load_from_hif", return_value=mock_four_node_hypergraph): - dataset = AlgebraDataset(sampling_strategy=strategy) +# assert dataset.sampling_strategy == SamplingStrategy.NODE +# assert len(dataset) == 3 # mock_hdata has 3 nodes - splits = dataset.split([0.5, 0.5]) - for split in splits: - assert split.sampling_strategy == strategy +# def test_update_from_hdata_returns_new_dataset(mock_hdata): +# dataset = Dataset(hdata=mock_hdata, prepare=False) +# new_x = torch.ones((2, 1), dtype=torch.float) +# new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) +# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# result = dataset.update_from_hdata(new_hdata) -def test_from_hdata_with_explicit_strategy(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) +# assert result is not dataset +# assert result.hdata is new_hdata +# assert 
dataset.hdata is mock_hdata - assert dataset.sampling_strategy == SamplingStrategy.NODE - assert len(dataset) == 3 # mock_hdata has 3 nodes +# def test_update_from_hdata_stores_provided_hdata(mock_hdata): +# dataset = Dataset(hdata=mock_hdata, prepare=False) +# new_x = torch.ones((2, 1), dtype=torch.float) +# new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) +# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) -def test_update_from_hdata_returns_new_dataset(mock_hdata): - dataset = Dataset(hdata=mock_hdata, prepare=False) - new_x = torch.ones((2, 1), dtype=torch.float) - new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) - new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# result = dataset.update_from_hdata(new_hdata) + +# assert result.hdata is new_hdata + + +# @pytest.mark.parametrize( +# "strategy, expected_len", +# [ +# pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), +# pytest.param(SamplingStrategy.HYPEREDGE, 3, id="hyperedge_strategy"), +# ], +# ) +# def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): +# dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy, prepare=False) +# new_x = torch.ones((4, 1), dtype=torch.float) +# new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) +# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - result = dataset.update_from_hdata(new_hdata) +# result = dataset.update_from_hdata(new_hdata) - assert result is not dataset - assert result.hdata is new_hdata - assert dataset.hdata is mock_hdata +# assert result.sampling_strategy == strategy +# assert len(result) == expected_len -def test_update_from_hdata_stores_provided_hdata(mock_hdata): - dataset = Dataset(hdata=mock_hdata, prepare=False) - new_x = torch.ones((2, 1), dtype=torch.float) - new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) - new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# def test_update_from_hdata_preserves_subclass_type(mock_hdata): +# dataset = AlgebraDataset(hdata=mock_hdata, prepare=False) +# new_x = torch.ones((2, 1), dtype=torch.float) +# new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) +# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - result = dataset.update_from_hdata(new_hdata) +# result = dataset.update_from_hdata(new_hdata) - assert result.hdata is new_hdata +# assert type(result) is AlgebraDataset -@pytest.mark.parametrize( - "strategy, expected_len", - [ - pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), - pytest.param(SamplingStrategy.HYPEREDGE, 3, id="hyperedge_strategy"), - ], -) -def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): - dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy, prepare=False) - new_x = torch.ones((4, 1), dtype=torch.float) - new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) - new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - - result = dataset.update_from_hdata(new_hdata) - - assert result.sampling_strategy == strategy - assert len(result) == expected_len - - -def test_update_from_hdata_preserves_subclass_type(mock_hdata): - dataset = AlgebraDataset(hdata=mock_hdata, prepare=False) - new_x = torch.ones((2, 1), dtype=torch.float) - new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) - new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - - result = 
dataset.update_from_hdata(new_hdata) - - assert type(result) is AlgebraDataset - - -@pytest.fixture -def mock_hdata_stats(): - x = torch.tensor( - [ - [0.0, 1.0, 2.0, 3.0], - [1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 4.0, 5.0], - [3.0, 4.0, 5.0, 6.0], - ], - dtype=torch.float, - ) - hyperedge_index = torch.tensor( - [ - [0, 1, 2, 2, 3], - [0, 0, 0, 1, 1], - ] - ) - return HData(x=x, hyperedge_index=hyperedge_index) +# @pytest.fixture +# def mock_hdata_stats(): +# x = torch.tensor( +# [ +# [0.0, 1.0, 2.0, 3.0], +# [1.0, 2.0, 3.0, 4.0], +# [2.0, 3.0, 4.0, 5.0], +# [3.0, 4.0, 5.0, 6.0], +# ], +# dtype=torch.float, +# ) +# hyperedge_index = torch.tensor( +# [ +# [0, 1, 2, 2, 3], +# [0, 0, 0, 1, 1], +# ] +# ) +# return HData(x=x, hyperedge_index=hyperedge_index) -def test_dataset_stats_computation(mock_hdata_stats): - expected_stats = { - "shape_x": torch.Size([4, 4]), - "shape_hyperedge_attr": None, - "shape_hyperedge_weights": None, +# def test_dataset_stats_computation(mock_hdata_stats): +# expected_stats = { +# "shape_x": torch.Size([4, 4]), +# "shape_hyperedge_attr": None, +# "shape_hyperedge_weights": None, "num_nodes": 4, - "num_hyperedges": 2, - "avg_degree_node_raw": 1.25, - "avg_degree_node": 1, - "avg_degree_hyperedge_raw": 2.5, - "avg_degree_hyperedge": 2, - "node_degree_max": 2, - "hyperedge_degree_max": 3, - "node_degree_median": 1, - "hyperedge_degree_median": 2, - "distribution_node_degree": [1, 1, 2, 1], - "distribution_hyperedge_size": [3, 2], - "distribution_node_degree_hist": {1: 3, 2: 1}, - "distribution_hyperedge_size_hist": {2: 1, 3: 1}, - } - - dataset = Dataset.from_hdata(mock_hdata_stats) - - stats = dataset.stats() - assert stats == expected_stats +# "num_hyperedges": 2, +# "avg_degree_node_raw": 1.25, +# "avg_degree_node": 1, +# "avg_degree_hyperedge_raw": 2.5, +# "avg_degree_hyperedge": 2, +# "node_degree_max": 2, +# "hyperedge_degree_max": 3, +# "node_degree_median": 1, +# "hyperedge_degree_median": 2, +# "distribution_node_degree": [1, 1, 2, 1], +# "distribution_hyperedge_size": [3, 2], +# "distribution_node_degree_hist": {1: 3, 2: 1}, +# "distribution_hyperedge_size_hist": {2: 1, 3: 1}, +# } + +# dataset = Dataset.from_hdata(mock_hdata_stats) + +# stats = dataset.stats() +# assert stats == expected_stats diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py new file mode 100644 index 0000000..e69de29 From a0e033f82831afb46c056c88b7f7baee470fcac3 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 22 Apr 2026 16:36:56 +0200 Subject: [PATCH 02/15] feat: messing comming --- hyperbench/tests/data/dataset_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index aaa5015..0089cf6 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -9,6 +9,7 @@ from hyperbench.types import HData, HIFHypergraph from hyperbench.data.supported_datasets import PreloadedDataset + @pytest.fixture def mock_hdata() -> HData: x = torch.ones((3, 1), dtype=torch.float) @@ -148,6 +149,7 @@ def mock_multiple_edges_attr_hypergraph(): ], ) + def test_Preloaded_dataset_init(): mock_hdata = MagicMock(spec=HData) dataset = PreloadedDataset(hdata=mock_hdata) @@ -155,6 +157,7 @@ def test_Preloaded_dataset_init(): assert dataset.hdata == mock_hdata assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE + def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): mock_hdata = MagicMock(spec=HData) with patch.object(HIFLoader, 
"load", return_value=mock_hdata) as mock_load: @@ -163,6 +166,7 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): assert dataset.hdata == mock_hdata mock_load.assert_called_once_with("algebra", save_on_disk=True) + # def test_HIFLoader_num_nodes_and_edges(): # dataset_name = "ALGEBRA" # mock_hypergraph = HIFHypergraph( @@ -239,7 +243,7 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), # ): -# # Mock successful download +# # Mock successful download # mock_response = mock_get.return_value # mock_response.status_code = 200 # mock_response.content = b"mock_zst_content" From f54e63b8ee1c3e787f8a7779e48f70b3ed7cc402 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 22 Apr 2026 17:25:18 +0200 Subject: [PATCH 03/15] chore: rebase with some issues --- hyperbench/data/hif.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index aae6f72..b96dd56 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -197,14 +197,14 @@ def __process_hyperedge_attr( ) -> Optional[Tensor]: # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] hyperedge_attr = None - has_hyperedges = hypergraph.edges is not None and len(hypergraph.edges) > 0 + has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 has_any_hyperedge_attrs = has_hyperedges and any( - "attrs" in edge for edge in hypergraph.edges + "attrs" in edge for edge in hypergraph.hyperedges ) if has_any_hyperedge_attrs: hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { - e.get("edge"): e.get("attrs", {}) for e in hypergraph.edges + e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges } hyperedge_attr_keys = HIFLoader.__collect_attr_keys( @@ -307,10 +307,10 @@ def __process(hypergraph: HIFHypergraph) -> HData: hyperedge_id_to_idx=hyperedge_id_to_idx, num_hyperedges=num_hyperedges, ) - + hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) - return HData(x, hyperedge_index, hyperedge_attr, num_nodes, num_hyperedges) + return HData(x, hyperedge_index, hyperedge_attr) @staticmethod def __collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: From 993c1669278394bb2470c76f9290cf9d93d5ece8 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Wed, 22 Apr 2026 17:27:00 +0200 Subject: [PATCH 04/15] fix: rebase issue --- hyperbench/data/hif.py | 2 +- hyperbench/tests/data/dataset_test.py | 2498 ++++++++++++------------- 2 files changed, 1250 insertions(+), 1250 deletions(-) diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index b96dd56..0d3fd7d 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -307,7 +307,7 @@ def __process(hypergraph: HIFHypergraph) -> HData: hyperedge_id_to_idx=hyperedge_id_to_idx, num_hyperedges=num_hyperedges, ) - + hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) return HData(x, hyperedge_index, hyperedge_attr) diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 0089cf6..8e1300d 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -1,494 +1,61 @@ -import pytest -import requests -import tempfile -import torch - -from unittest.mock import patch, mock_open, MagicMock -from hyperbench.data import AlgebraDataset, Dataset, HIFLoader, SamplingStrategy -from hyperbench.nn 
import EnrichmentMode, NodeEnricher, HyperedgeEnricher -from hyperbench.types import HData, HIFHypergraph -from hyperbench.data.supported_datasets import PreloadedDataset - - -@pytest.fixture -def mock_hdata() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - return HData(x=x, hyperedge_index=hyperedge_index) - - -@pytest.fixture -def mock_hdata_with_hyperedge_attr() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_attr = torch.ones((3, 1), dtype=torch.float) - return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) - - -@pytest.fixture -def mock_hdata_with_hyperedge_weights() -> HData: - x = torch.ones((3, 1), dtype=torch.float) - hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) - hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) - return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - - -@pytest.fixture -def mock_sample_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0"}, {"node": "1"}], - hyperedges=[{"edge": "0"}], - incidences=[{"node": "0", "edge": "0"}], - ) - - -@pytest.fixture -def mock_simple_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) - - -@pytest.fixture -def mock_three_node_weighted_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0}}, - {"edge": "1", "attrs": {"weight": 2.0}}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - ], - ) - - -@pytest.fixture -def mock_four_node_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - {"node": "3", "edge": "1"}, - ], - ) - - -@pytest.fixture -def mock_five_node_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - {"node": "4", "attrs": {}}, - ], - hyperedges=[{"edge": "0", "attrs": {}}], - incidences=[{"node": "0", "edge": "0"}], - ) - - -@pytest.fixture -def mock_no_edge_attr_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - ], - hyperedges=[{"edge": "0"}], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "0"}, - ], - ) - - -@pytest.fixture -def mock_multiple_edges_attr_hypergraph(): - return HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - {"node": "3", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "attrs": {"weight": 1.0}}, - {"edge": "1", "attrs": {"weight": 2.0}}, - {"edge": "2", "attrs": {"weight": 3.0}}, - ], - incidences=[ - {"node": "0", 
"edge": "0"}, - {"node": "1", "edge": "0"}, - {"node": "2", "edge": "1"}, - {"node": "3", "edge": "2"}, - ], - ) - - -def test_Preloaded_dataset_init(): - mock_hdata = MagicMock(spec=HData) - dataset = PreloadedDataset(hdata=mock_hdata) - - assert dataset.hdata == mock_hdata - assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE - - -def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): - mock_hdata = MagicMock(spec=HData) - with patch.object(HIFLoader, "load", return_value=mock_hdata) as mock_load: - dataset = AlgebraDataset(hdata=None) - - assert dataset.hdata == mock_hdata - mock_load.assert_called_once_with("algebra", save_on_disk=True) - - -# def test_HIFLoader_num_nodes_and_edges(): -# dataset_name = "ALGEBRA" -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": str(i)} for i in range(20)], -# edges=[{"edge": str(i)} for i in range(30)], -# incidences=[{"node": "0", "edge": "0"}], -# ) - -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# hypergraph = HIFLoader.load(dataset_name) - -# assert hypergraph is not None -# assert hasattr(hypergraph, "nodes") -# assert hasattr(hypergraph, "hyperedges") -# assert hasattr(hypergraph, "incidences") -# assert hasattr(hypergraph, "metadata") -# assert hasattr(hypergraph, "network_type") +# import pytest +# import requests +# import tempfile +# import torch -# assert hypergraph.num_nodes == 20 -# assert hypergraph.num_hyperedges == 30 +# from unittest.mock import patch, mock_open, MagicMock +# from hyperbench.data import AlgebraDataset, Dataset, HIFLoader, SamplingStrategy +# from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher +# from hyperbench.types import HData, HIFHypergraph +# from hyperbench.data.supported_datasets import PreloadedDataset -# def test_HIFLoader_loads_invalid_dataset(): -# dataset_name = "INVALID_DATASET" - -# with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): -# HIFLoader.load(dataset_name) - - -# def test_HIFLoader_loads_invalid_hif_format(): -# dataset_name = "ALGEBRA" - -# invalid_hif_json = '{"network-type": "undirected", "nodes": []}' - -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch("hyperbench.data.dataset.validate_hif_json", return_value=False), -# patch("builtins.open", mock_open(read_data=invalid_hif_json)), -# patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), -# ): -# mock_response = mock_get.return_value -# mock_response.status_code = 200 -# mock_response.content = b"mock_zst_content" +# @pytest.fixture +# def mock_hdata() -> HData: +# x = torch.ones((3, 1), dtype=torch.float) +# hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) +# return HData(x=x, hyperedge_index=hyperedge_index) -# with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): -# HIFLoader.load(dataset_name) +# @pytest.fixture +# def mock_hdata_with_hyperedge_attr() -> HData: +# x = torch.ones((3, 1), dtype=torch.float) +# hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) +# hyperedge_attr = torch.ones((3, 1), dtype=torch.float) +# return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) -# def test_HIFLoader_stores_on_disk_when_save_on_disk_true(): -# dataset_name = "ALGEBRA" -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": "0"}, {"node": "1"}], -# hyperedges=[{"edge": "0"}], -# incidences=[{"node": "0", "edge": "0"}], -# ) +# @pytest.fixture +# def 
mock_hdata_with_hyperedge_weights() -> HData: +# x = torch.ones((3, 1), dtype=torch.float) +# hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) +# hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) +# return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) -# mock_hif_json = { -# "network-type": "undirected", -# "nodes": [{"node": "0"}, {"node": "1"}], -# "edges": [{"edge": "0"}], -# "incidences": [{"node": "0", "edge": "0"}], -# } - -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# patch("hyperbench.data.dataset.os.makedirs"), -# patch("builtins.open", mock_open()) as mock_file, -# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), -# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# ): -# # Mock successful download -# mock_response = mock_get.return_value -# mock_response.status_code = 200 -# mock_response.content = b"mock_zst_content" - -# # Mock decompressor -# mock_stream = mock_decomp.return_value.stream_reader.return_value -# mock_stream.__enter__ = lambda self: mock_stream -# mock_stream.__exit__ = lambda self, *args: None - -# hypergraph = HIFLoader.load(dataset_name, save_on_disk=True) - -# assert hypergraph is not None -# assert hypergraph.network_type == "undirected" -# mock_get.assert_called_once() -# # Verify file was written to disk (not temp file) -# assert mock_file.call_count >= 2 # Once for write, once for read - - -# def test_HIFLoader_uses_temp_file_when_save_on_disk_false(): -# dataset_name = "ALGEBRA" -# mock_hypergraph = HIFHypergraph( +# @pytest.fixture +# def mock_sample_hypergraph(): +# return HIFHypergraph( # network_type="undirected", # nodes=[{"node": "0"}, {"node": "1"}], # hyperedges=[{"edge": "0"}], # incidences=[{"node": "0", "edge": "0"}], # ) -# mock_hif_json = { -# "network-type": "undirected", -# "nodes": [{"node": "0"}, {"node": "1"}], -# "edges": [{"edge": "0"}], -# "incidences": [{"node": "0", "edge": "0"}], -# } - -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, -# patch("builtins.open", mock_open()), -# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), -# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# ): -# # Mock successful download -# mock_response = mock_get.return_value -# mock_response.status_code = 200 -# mock_response.content = b"mock_zst_content" - -# # Mock temp file -# mock_temp_file = mock_temp.return_value.__enter__.return_value -# mock_temp_file.name = "/tmp/fake_temp.json.zst" - -# # Mock decompressor -# mock_stream = mock_decomp.return_value.stream_reader.return_value -# mock_stream.__enter__ = lambda self: mock_stream -# mock_stream.__exit__ = lambda self, *args: None - -# hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) - -# assert hypergraph is not None -# assert hypergraph.network_type == "undirected" -# mock_get.assert_called_once() -# # Verify temp file was used -# assert mock_temp.call_count >= 
1 - - -# def test_HIFLoader_download_failure(): -# dataset_name = "ALGEBRA" - -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch("hyperbench.data.dataset.hf_hub_download", side_effect=Exception("HFHub failed")), -# patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# ): -# # Mock failed download -# mock_response = mock_get.return_value -# mock_response.status_code = 404 -# mock_response.content = b"" - -# with pytest.warns( -# UserWarning, -# match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", -# ): -# with pytest.raises( -# ValueError, -# match=r"Failed to download dataset 'algebra'", -# ): -# HIFLoader.load(dataset_name) - - -# def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( -# tmp_path, mock_sample_hypergraph -# ): -# dataset_name = "ALGEBRA" - -# mock_hypergraph = mock_sample_hypergraph - -# mock_hif_json = { -# "network_type": "undirected", -# "nodes": [{"node": "0"}, {"node": "1"}], -# "edges": [{"edge": "0"}], -# "incidences": [{"node": "0", "edge": "0"}], -# } - -# fallback_file = tmp_path / "algebra.json.zst" -# fallback_file.write_bytes(b"mock_zst_content") - -# created_temp_files = [] -# original_named_tempfile = tempfile.NamedTemporaryFile - -# def named_tempfile_side_effect(*args, **kwargs): -# temp_file = original_named_tempfile(*args, **kwargs) -# created_temp_files.append(temp_file) -# return temp_file - -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch( -# "hyperbench.data.dataset.hf_hub_download", -# return_value=str(fallback_file), -# ) as mock_hf_hub_download, -# patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# patch( -# "hyperbench.data.dataset.tempfile.NamedTemporaryFile", -# side_effect=named_tempfile_side_effect, -# ), -# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), -# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# ): -# mock_response = mock_get.return_value -# mock_response.status_code = 404 -# mock_response.content = b"" - -# def fake_copy_stream(src, dst): -# dst.write(b'{"network_type":"undirected","nodes":[],"edges":[],"incidences":[]}') -# return - -# mock_decomp.return_value.copy_stream.side_effect = fake_copy_stream - -# with pytest.warns( -# UserWarning, -# match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", -# ): -# hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) - -# assert hypergraph.network_type == "undirected" -# mock_get.assert_called_once() -# mock_hf_hub_download.assert_called_once() -# assert created_temp_files[0].name is not None -# assert fallback_file.read_bytes() == b"mock_zst_content" - - -# def test_HIFLoader_download_raises_when_network_error(): -# dataset_name = "ALGEBRA" - -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# ): -# # Mock network error -# mock_get.side_effect = requests.RequestException("Network error") - -# with pytest.raises(requests.RequestException, match="Network error"): -# HIFLoader.load(dataset_name) - - -# def test_init_with_prepare_false_and_no_hdata_raises(): -# with pytest.raises(ValueError, match="hdata 
must be provided when prepare is set to False."): -# Dataset(hdata=None, prepare=False) - - -# def test_dataset_is_not_available(): -# class FakeMockDataset(Dataset): -# DATASET_NAME = "FAKE" - -# with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): -# FakeMockDataset() - - -# @pytest.mark.parametrize( -# "strategy, expected_len", -# [ -# pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), -# ], -# ) -# def test_dataset_is_available_with_all_strategies( -# strategy, expected_len, mock_four_node_hypergraph -# ): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# assert dataset.DATASET_NAME == "ALGEBRA" -# assert dataset.hypergraph is not None -# assert len(dataset) == expected_len - - -# def test_download_already_downloaded_dataset_uses_local_value(mock_four_node_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() - -# hg1 = dataset.download() -# hg2 = dataset.download() - -# assert hg1 is hg2 - - -# def test_throw_when_dataset_name_is_none(): -# class FakeMockDataset(Dataset): -# DATASET_NAME = None - -# with pytest.raises( -# ValueError, -# match=r"Dataset name \(provided: None\) must be provided\.", -# ): -# FakeMockDataset() - -# def test_dataset_process_no_incidences(): -# mock_hypergraph = HIFHypergraph( +# @pytest.fixture +# def mock_simple_hypergraph(): +# return HIFHypergraph( # network_type="undirected", # nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], # hyperedges=[{"edge": "0", "attrs": {}}], -# incidences=[], +# incidences=[{"node": "0", "edge": "0"}], # ) -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# dataset = AlgebraDataset() - -# assert dataset.hdata is not None -# assert dataset.hdata.x.shape[0] == 2 -# assert dataset.hdata.hyperedge_index.shape[0] == 2 -# assert dataset.hdata.hyperedge_index.shape[1] == 2 -# assert dataset.hdata.hyperedge_attr is not None -# assert dataset.hdata.hyperedge_attr.shape == (2, 0) -# assert dataset.hdata.hyperedge_attr[0].shape == (0,) -# assert dataset.hdata.hyperedge_attr[1].shape == (0,) - -# def test_dataset_process_with_edge_attributes(): -# mock_hypergraph = HIFHypergraph( +# @pytest.fixture +# def mock_three_node_weighted_hypergraph(): +# return HIFHypergraph( # network_type="undirected", # nodes=[ # {"node": "0", "attrs": {}}, @@ -496,8 +63,8 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): # {"node": "2", "attrs": {}}, # ], # hyperedges=[ -# {"edge": "0", "attrs": {"weight": 1.0, "type": 2.0}}, -# {"edge": "1", "attrs": {"weight": 3.0, "type": 0.1}}, +# {"edge": "0", "attrs": {"weight": 1.0}}, +# {"edge": "1", "attrs": {"weight": 2.0}}, # ], # incidences=[ # {"node": "0", "edge": "0"}, @@ -506,235 +73,62 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): # ], # ) -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# dataset = AlgebraDataset() - -# assert dataset.hdata is not None -# assert dataset.hdata.x.shape[0] == 3 -# assert dataset.hdata.hyperedge_index.shape[0] == 2 -# assert dataset.hdata.hyperedge_index.shape[1] == 3 -# assert dataset.hdata.hyperedge_attr is not None -# # Two edges with two attributes each: shape [2, 2] -# assert dataset.hdata.hyperedge_attr.shape == (2, 2) -# # Attributes maintain dictionary insertion order (no sorting) - -# assert 
torch.allclose(dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 2.0])) # weight, type -# assert torch.allclose(dataset.hdata.hyperedge_attr[1], torch.tensor([3.0, 0.1])) # weight, type - - -# def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): -# dataset = AlgebraDataset() - -# assert dataset.hdata is not None -# assert dataset.hdata.hyperedge_index.shape[0] == 2 -# assert dataset.hdata.hyperedge_index.shape[1] == 2 -# assert dataset.hdata.hyperedge_attr is None +# @pytest.fixture +# def mock_four_node_hypergraph(): +# return HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "0", "attrs": {}}, +# {"node": "1", "attrs": {}}, +# {"node": "2", "attrs": {}}, +# {"node": "3", "attrs": {}}, +# ], +# hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], +# incidences=[ +# {"node": "0", "edge": "0"}, +# {"node": "1", "edge": "0"}, +# {"node": "2", "edge": "1"}, +# {"node": "3", "edge": "1"}, +# ], +# ) -# def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() -# assert dataset.hdata.hyperedge_index.shape == (2, 4) -# assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) -# assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) +# @pytest.fixture +# def mock_five_node_hypergraph(): +# return HIFHypergraph( +# network_type="undirected", +# nodes=[ +# {"node": "0", "attrs": {}}, +# {"node": "1", "attrs": {}}, +# {"node": "2", "attrs": {}}, +# {"node": "3", "attrs": {}}, +# {"node": "4", "attrs": {}}, +# ], +# hyperedges=[{"edge": "0", "attrs": {}}], +# incidences=[{"node": "0", "edge": "0"}], +# ) -# def test_dataset_process_random_ids(): -# mock_hypergraph = HIFHypergraph( +# @pytest.fixture +# def mock_no_edge_attr_hypergraph(): +# return HIFHypergraph( # network_type="undirected", # nodes=[ -# {"node": "abc", "attrs": {}}, -# {"node": "ss", "attrs": {}}, -# {"node": "fewao", "attrs": {}}, +# {"node": "0", "attrs": {}}, +# {"node": "1", "attrs": {}}, # ], -# hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], +# hyperedges=[{"edge": "0"}], # incidences=[ -# {"node": "abc", "edge": "0"}, -# {"node": "ss", "edge": "0"}, -# {"node": "fewao", "edge": "1"}, +# {"node": "0", "edge": "0"}, +# {"node": "1", "edge": "0"}, # ], # ) -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# dataset = AlgebraDataset() -# assert dataset.hdata.hyperedge_index.shape == (2, 3) -# assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) -# assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) -# assert dataset.hdata.hyperedge_attr is not None -# assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each - - -# @pytest.mark.parametrize( -# "strategy", -# [ -# pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# ], -# ) -# def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): -# with patch.object(HIFLoader, "load", return_value=mock_simple_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# with pytest.raises(ValueError, match="Index list cannot be empty."): -# dataset[[]] - - -# @pytest.mark.parametrize( -# "strategy, index_list, expected_message", -# [ -# 
pytest.param( -# SamplingStrategy.NODE, -# [0, 1, 2, 3, 4], -# r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", -# id="node_strategy", -# ), -# pytest.param( -# SamplingStrategy.HYPEREDGE, -# [0, 1, 2], -# r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", -# id="hyperedge_strategy", -# ), -# ], -# ) -# def test_getitem_raises_when_index_list_larger_than_max( -# mock_four_node_hypergraph, strategy, index_list, expected_message -# ): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# with pytest.raises(ValueError, match=expected_message): -# dataset[index_list] - - -# @pytest.mark.parametrize( -# "strategy, index, expected_message", -# [ -# pytest.param( -# SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" -# ), -# pytest.param( -# SamplingStrategy.HYPEREDGE, -# 2, -# r"Hyperedge ID 2 is out of bounds \(0, 1\)\.", -# id="hyperedge_strategy", -# ), -# ], -# ) -# def test_getitem_raises_when_index_out_of_bounds( -# mock_four_node_hypergraph, strategy, index, expected_message -# ): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# with pytest.raises(IndexError, match=expected_message): -# dataset[index] - - -# @pytest.mark.parametrize( -# "strategy, index, expected_shape, expected_num_hyperedges", -# [ -# # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge -# pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), -# # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge -# pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), -# ], -# ) -# def test_getitem_single_index( -# mock_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges -# ): -# with patch.object(HIFLoader, "load", return_value=mock_sample_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# data = dataset[index] - -# assert data.hyperedge_index.shape == expected_shape -# assert data.num_hyperedges == expected_num_hyperedges - - -# @pytest.mark.parametrize( -# "strategy, index, expected_shape, expected_num_hyperedges", -# [ -# # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges -# pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"), -# # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges -# pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"), -# ], -# ) -# def test_getitem_when_list_index_provided( -# mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges -# ): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# data = dataset[index] - -# assert data.hyperedge_index.shape == expected_shape -# assert data.num_hyperedges == expected_num_hyperedges - - -# @pytest.mark.parametrize( -# "strategy", -# [ -# pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# ], -# ) -# def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): -# with patch.object( -# HIFLoader, "load", return_value=mock_three_node_weighted_hypergraph -# ): -# dataset = 
AlgebraDataset(sampling_strategy=strategy) - -# data = dataset[0] - -# assert data.hyperedge_index.shape == (2, 2) -# assert data.num_hyperedges == 1 -# assert data.hyperedge_attr is None - - -# @pytest.mark.parametrize( -# "strategy", -# [ -# pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# ], -# ) -# def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): -# with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# data = dataset[0] -# assert data.hyperedge_attr is None - - -# @pytest.mark.parametrize( -# "strategy, index", -# [ -# # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges -# pytest.param(SamplingStrategy.NODE, [0, 2], id="node_strategy"), -# # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges -# pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), -# ], -# ) -# def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index): -# with patch.object( -# HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph -# ): -# dataset = AlgebraDataset(sampling_strategy=strategy) - -# data = dataset[index] -# assert data.num_hyperedges == 2 - -# # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None -# # as the hyperedge attributes are handled by the loader's collate function during batching -# assert data.hyperedge_attr is None - - -# def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges(): -# mock_hypergraph = HIFHypergraph( +# @pytest.fixture +# def mock_multiple_edges_attr_hypergraph(): +# return HIFHypergraph( # network_type="undirected", # nodes=[ # {"node": "0", "attrs": {}}, @@ -742,10 +136,10 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): # {"node": "2", "attrs": {}}, # {"node": "3", "attrs": {}}, # ], -# edges=[ -# {"edge": "0", "attrs": {"weight": 1.0, "abc": 5.0}}, -# {"edge": "1", "attrs": {"weight": 2.0}}, # Missing 'abc' -# {"edge": "2", "attrs": {"abc": 3.0}}, # Missing 'weight' +# hyperedges=[ +# {"edge": "0", "attrs": {"weight": 1.0}}, +# {"edge": "1", "attrs": {"weight": 2.0}}, +# {"edge": "2", "attrs": {"weight": 3.0}}, # ], # incidences=[ # {"node": "0", "edge": "0"}, @@ -755,686 +149,1292 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): # ], # ) - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - assert dataset.hdata.hyperedge_attr is not None - assert dataset.hdata.hyperedge_attr.shape == ( - 3, - 2, - ) # 3 edges, 2 features each (weight, abc in insertion order) - assert torch.allclose( - dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 5.0]) - ) # weight=1.0, abc=5.0 - assert torch.allclose( - dataset.hdata.hyperedge_attr[1], torch.tensor([2.0, 0.0]) - ) # weight=2.0, abc=0.0 - assert torch.allclose( - dataset.hdata.hyperedge_attr[2], torch.tensor([0.0, 3.0]) - ) # weight=0.0, abc=3.0 - - -def test_process_not_all_hyperedge_weights_(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "weight": 1.5}, - {"edge": "1"}, - {"edge": "2", "weight": 2.5}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "1"}, - {"node": "2", "edge": 
"2"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - with pytest.raises( - ValueError, - match="Some hyperedges have weights while others do not. All hyperedges must either have weights or none.", - ): - dataset = AlgebraDataset() - - -def test_process_extracts_top_level_hyperedge_weights(): - mock_hypergraph = HIFHypergraph( - network_type="undirected", - nodes=[ - {"node": "0", "attrs": {}}, - {"node": "1", "attrs": {}}, - {"node": "2", "attrs": {}}, - ], - hyperedges=[ - {"edge": "0", "weight": 1.5}, - {"edge": "1", "weight": 3.0}, - {"edge": "2", "weight": 2.5}, - ], - incidences=[ - {"node": "0", "edge": "0"}, - {"node": "1", "edge": "1"}, - {"node": "2", "edge": "2"}, - ], - ) - - with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): - dataset = AlgebraDataset() - - hyperedge_weights = dataset.hdata.hyperedge_weights - assert hyperedge_weights is not None - assert torch.allclose(hyperedge_weights[0], torch.tensor([1.5])) - assert torch.allclose(hyperedge_weights[1], torch.tensor([3.0])) - assert torch.allclose(hyperedge_weights[2], torch.tensor([2.5])) - - -# def test_transform_attrs_empty_attrs(): -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": "0", "attrs": {}}], -# hyperedges=[{"edge": "0", "attrs": {}}], -# incidences=[{"node": "0", "edge": "0"}], -# ) -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# def test_Preloaded_dataset_init(): +# mock_hdata = MagicMock(spec=HData) +# dataset = PreloadedDataset(hdata=mock_hdata) -# class TestDataset(Dataset): -# DATASET_NAME = "TEST" +# assert dataset.hdata == mock_hdata +# assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE -# dataset = TestDataset() -# result = dataset.transform_attrs({}) -# assert len(result) == 0 +# def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): +# mock_hdata = MagicMock(spec=HData) +# with patch.object(HIFLoader, "load", return_value=mock_hdata) as mock_load: +# dataset = AlgebraDataset(hdata=None) -# attrs = {"name": "node1", "active": True} -# result = dataset.transform_attrs(attrs) -# assert len(result) == 0 +# assert dataset.hdata == mock_hdata +# mock_load.assert_called_once_with("algebra", save_on_disk=True) -# def test_process_adds_padding_zero_when_inconsistent_node_attributes(): +# # def test_HIFLoader_num_nodes_and_edges(): +# # dataset_name = "ALGEBRA" +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": str(i)} for i in range(20)], +# # edges=[{"edge": str(i)} for i in range(30)], +# # incidences=[{"node": "0", "edge": "0"}], +# # ) + +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# # hypergraph = HIFLoader.load(dataset_name) + +# # assert hypergraph is not None +# # assert hasattr(hypergraph, "nodes") +# # assert hasattr(hypergraph, "hyperedges") +# # assert hasattr(hypergraph, "incidences") +# # assert hasattr(hypergraph, "metadata") +# # assert hasattr(hypergraph, "network_type") + +# # assert hypergraph.num_nodes == 20 +# # assert hypergraph.num_hyperedges == 30 + + +# # def test_HIFLoader_loads_invalid_dataset(): +# # dataset_name = "INVALID_DATASET" + +# # with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): +# # HIFLoader.load(dataset_name) + + +# # def test_HIFLoader_loads_invalid_hif_format(): +# # dataset_name = "ALGEBRA" + +# # invalid_hif_json = '{"network-type": "undirected", "nodes": []}' + +# # with ( +# # 
patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch("hyperbench.data.dataset.validate_hif_json", return_value=False), +# # patch("builtins.open", mock_open(read_data=invalid_hif_json)), +# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), +# # ): +# # mock_response = mock_get.return_value +# # mock_response.status_code = 200 +# # mock_response.content = b"mock_zst_content" + +# # with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): +# # HIFLoader.load(dataset_name) + + +# # def test_HIFLoader_stores_on_disk_when_save_on_disk_true(): +# # dataset_name = "ALGEBRA" + +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": "0"}, {"node": "1"}], +# # hyperedges=[{"edge": "0"}], +# # incidences=[{"node": "0", "edge": "0"}], +# # ) + +# # mock_hif_json = { +# # "network-type": "undirected", +# # "nodes": [{"node": "0"}, {"node": "1"}], +# # "edges": [{"edge": "0"}], +# # "incidences": [{"node": "0", "edge": "0"}], +# # } + +# # with ( +# # patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# # patch("hyperbench.data.dataset.os.makedirs"), +# # patch("builtins.open", mock_open()) as mock_file, +# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# # patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), +# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# # ): +# # # Mock successful download +# # mock_response = mock_get.return_value +# # mock_response.status_code = 200 +# # mock_response.content = b"mock_zst_content" + +# # # Mock decompressor +# # mock_stream = mock_decomp.return_value.stream_reader.return_value +# # mock_stream.__enter__ = lambda self: mock_stream +# # mock_stream.__exit__ = lambda self, *args: None + +# # hypergraph = HIFLoader.load(dataset_name, save_on_disk=True) + +# # assert hypergraph is not None +# # assert hypergraph.network_type == "undirected" +# # mock_get.assert_called_once() +# # # Verify file was written to disk (not temp file) +# # assert mock_file.call_count >= 2 # Once for write, once for read + + +# # def test_HIFLoader_uses_temp_file_when_save_on_disk_false(): +# # dataset_name = "ALGEBRA" + +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": "0"}, {"node": "1"}], +# # hyperedges=[{"edge": "0"}], +# # incidences=[{"node": "0", "edge": "0"}], +# # ) + +# # mock_hif_json = { +# # "network-type": "undirected", +# # "nodes": [{"node": "0"}, {"node": "1"}], +# # "edges": [{"edge": "0"}], +# # "incidences": [{"node": "0", "edge": "0"}], +# # } + +# # with ( +# # patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# # patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, +# # patch("builtins.open", mock_open()), +# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# # patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), +# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# # ): +# # # Mock successful download +# # mock_response = mock_get.return_value +# # mock_response.status_code = 200 +# # mock_response.content = b"mock_zst_content" + +# # # Mock temp file +# # 
mock_temp_file = mock_temp.return_value.__enter__.return_value +# # mock_temp_file.name = "/tmp/fake_temp.json.zst" + +# # # Mock decompressor +# # mock_stream = mock_decomp.return_value.stream_reader.return_value +# # mock_stream.__enter__ = lambda self: mock_stream +# # mock_stream.__exit__ = lambda self, *args: None + +# # hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) + +# # assert hypergraph is not None +# # assert hypergraph.network_type == "undirected" +# # mock_get.assert_called_once() +# # # Verify temp file was used +# # assert mock_temp.call_count >= 1 + + +# # def test_HIFLoader_download_failure(): +# # dataset_name = "ALGEBRA" + +# # with ( +# # patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch("hyperbench.data.dataset.hf_hub_download", side_effect=Exception("HFHub failed")), +# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# # ): +# # # Mock failed download +# # mock_response = mock_get.return_value +# # mock_response.status_code = 404 +# # mock_response.content = b"" + +# # with pytest.warns( +# # UserWarning, +# # match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", +# # ): +# # with pytest.raises( +# # ValueError, +# # match=r"Failed to download dataset 'algebra'", +# # ): +# # HIFLoader.load(dataset_name) + + +# # def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( +# # tmp_path, mock_sample_hypergraph +# # ): +# # dataset_name = "ALGEBRA" + +# # mock_hypergraph = mock_sample_hypergraph + +# # mock_hif_json = { +# # "network_type": "undirected", +# # "nodes": [{"node": "0"}, {"node": "1"}], +# # "edges": [{"edge": "0"}], +# # "incidences": [{"node": "0", "edge": "0"}], +# # } + +# # fallback_file = tmp_path / "algebra.json.zst" +# # fallback_file.write_bytes(b"mock_zst_content") + +# # created_temp_files = [] +# # original_named_tempfile = tempfile.NamedTemporaryFile + +# # def named_tempfile_side_effect(*args, **kwargs): +# # temp_file = original_named_tempfile(*args, **kwargs) +# # created_temp_files.append(temp_file) +# # return temp_file + +# # with ( +# # patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch( +# # "hyperbench.data.dataset.hf_hub_download", +# # return_value=str(fallback_file), +# # ) as mock_hf_hub_download, +# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# # patch( +# # "hyperbench.data.dataset.tempfile.NamedTemporaryFile", +# # side_effect=named_tempfile_side_effect, +# # ), +# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# # patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), +# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# # ): +# # mock_response = mock_get.return_value +# # mock_response.status_code = 404 +# # mock_response.content = b"" + +# # def fake_copy_stream(src, dst): +# # dst.write(b'{"network_type":"undirected","nodes":[],"edges":[],"incidences":[]}') +# # return + +# # mock_decomp.return_value.copy_stream.side_effect = fake_copy_stream + +# # with pytest.warns( +# # UserWarning, +# # match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", +# # ): +# # hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) + +# # assert hypergraph.network_type == "undirected" +# # 
mock_get.assert_called_once() +# # mock_hf_hub_download.assert_called_once() +# # assert created_temp_files[0].name is not None +# # assert fallback_file.read_bytes() == b"mock_zst_content" + + +# # def test_HIFLoader_download_raises_when_network_error(): +# # dataset_name = "ALGEBRA" + +# # with ( +# # patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), +# # ): +# # # Mock network error +# # mock_get.side_effect = requests.RequestException("Network error") + +# # with pytest.raises(requests.RequestException, match="Network error"): +# # HIFLoader.load(dataset_name) + + +# # def test_init_with_prepare_false_and_no_hdata_raises(): +# # with pytest.raises(ValueError, match="hdata must be provided when prepare is set to False."): +# # Dataset(hdata=None, prepare=False) + + +# # def test_dataset_is_not_available(): +# # class FakeMockDataset(Dataset): +# # DATASET_NAME = "FAKE" + +# # with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): +# # FakeMockDataset() + + +# # @pytest.mark.parametrize( +# # "strategy, expected_len", +# # [ +# # pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_dataset_is_available_with_all_strategies( +# # strategy, expected_len, mock_four_node_hypergraph +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # assert dataset.DATASET_NAME == "ALGEBRA" +# # assert dataset.hypergraph is not None +# # assert len(dataset) == expected_len + + +# # def test_download_already_downloaded_dataset_uses_local_value(mock_four_node_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() + +# # hg1 = dataset.download() +# # hg2 = dataset.download() + +# # assert hg1 is hg2 + + +# # def test_throw_when_dataset_name_is_none(): +# # class FakeMockDataset(Dataset): +# # DATASET_NAME = None + +# # with pytest.raises( +# # ValueError, +# # match=r"Dataset name \(provided: None\) must be provided\.", +# # ): +# # FakeMockDataset() + + +# # def test_dataset_process_no_incidences(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], +# # hyperedges=[{"edge": "0", "attrs": {}}], +# # incidences=[], +# # ) + +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# # dataset = AlgebraDataset() + +# # assert dataset.hdata is not None +# # assert dataset.hdata.x.shape[0] == 2 +# # assert dataset.hdata.hyperedge_index.shape[0] == 2 +# # assert dataset.hdata.hyperedge_index.shape[1] == 2 +# # assert dataset.hdata.hyperedge_attr is not None +# # assert dataset.hdata.hyperedge_attr.shape == (2, 0) +# # assert dataset.hdata.hyperedge_attr[0].shape == (0,) +# # assert dataset.hdata.hyperedge_attr[1].shape == (0,) + + +# # def test_dataset_process_with_edge_attributes(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[ +# # {"node": "0", "attrs": {}}, +# # {"node": "1", "attrs": {}}, +# # {"node": "2", "attrs": {}}, +# # ], +# # hyperedges=[ +# # {"edge": "0", "attrs": {"weight": 1.0, "type": 2.0}}, +# # {"edge": "1", "attrs": {"weight": 3.0, "type": 0.1}}, +# # ], +# # incidences=[ +# # {"node": "0", "edge": "0"}, +# # {"node": "1", "edge": "0"}, +# # {"node": "2", "edge": "1"}, +# # ], +# # 
) + +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# # dataset = AlgebraDataset() + +# # assert dataset.hdata is not None +# # assert dataset.hdata.x.shape[0] == 3 +# # assert dataset.hdata.hyperedge_index.shape[0] == 2 +# # assert dataset.hdata.hyperedge_index.shape[1] == 3 +# # assert dataset.hdata.hyperedge_attr is not None +# # # Two edges with two attributes each: shape [2, 2] +# # assert dataset.hdata.hyperedge_attr.shape == (2, 2) +# # # Attributes maintain dictionary insertion order (no sorting) + +# # assert torch.allclose(dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 2.0])) # weight, type +# # assert torch.allclose(dataset.hdata.hyperedge_attr[1], torch.tensor([3.0, 0.1])) # weight, type + + +# # def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): +# # dataset = AlgebraDataset() + +# # assert dataset.hdata is not None +# # assert dataset.hdata.hyperedge_index.shape[0] == 2 +# # assert dataset.hdata.hyperedge_index.shape[1] == 2 +# # assert dataset.hdata.hyperedge_attr is None + + +# # def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() + +# # assert dataset.hdata.hyperedge_index.shape == (2, 4) +# # assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) +# # assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) + + +# # def test_dataset_process_random_ids(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[ +# # {"node": "abc", "attrs": {}}, +# # {"node": "ss", "attrs": {}}, +# # {"node": "fewao", "attrs": {}}, +# # ], +# # hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], +# # incidences=[ +# # {"node": "abc", "edge": "0"}, +# # {"node": "ss", "edge": "0"}, +# # {"node": "fewao", "edge": "1"}, +# # ], +# # ) + +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# # dataset = AlgebraDataset() + +# # assert dataset.hdata.hyperedge_index.shape == (2, 3) +# # assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) +# # assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) +# # assert dataset.hdata.hyperedge_attr is not None +# # assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each + + +# # @pytest.mark.parametrize( +# # "strategy", +# # [ +# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): +# # with patch.object(HIFLoader, "load", return_value=mock_simple_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # with pytest.raises(ValueError, match="Index list cannot be empty."): +# # dataset[[]] + + +# # @pytest.mark.parametrize( +# # "strategy, index_list, expected_message", +# # [ +# # pytest.param( +# # SamplingStrategy.NODE, +# # [0, 1, 2, 3, 4], +# # r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", +# # id="node_strategy", +# # ), +# # pytest.param( +# # SamplingStrategy.HYPEREDGE, +# # [0, 1, 2], +# # r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", +# # id="hyperedge_strategy", +# # ), +# # ], +# # ) +# # def 
test_getitem_raises_when_index_list_larger_than_max( +# # mock_four_node_hypergraph, strategy, index_list, expected_message +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # with pytest.raises(ValueError, match=expected_message): +# # dataset[index_list] + + +# # @pytest.mark.parametrize( +# # "strategy, index, expected_message", +# # [ +# # pytest.param( +# # SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" +# # ), +# # pytest.param( +# # SamplingStrategy.HYPEREDGE, +# # 2, +# # r"Hyperedge ID 2 is out of bounds \(0, 1\)\.", +# # id="hyperedge_strategy", +# # ), +# # ], +# # ) +# # def test_getitem_raises_when_index_out_of_bounds( +# # mock_four_node_hypergraph, strategy, index, expected_message +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # with pytest.raises(IndexError, match=expected_message): +# # dataset[index] + + +# # @pytest.mark.parametrize( +# # "strategy, index, expected_shape, expected_num_hyperedges", +# # [ +# # # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge +# # pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), +# # # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge +# # pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_getitem_single_index( +# # mock_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_sample_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # data = dataset[index] + +# # assert data.hyperedge_index.shape == expected_shape +# # assert data.num_hyperedges == expected_num_hyperedges + + +# # @pytest.mark.parametrize( +# # "strategy, index, expected_shape, expected_num_hyperedges", +# # [ +# # # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges +# # pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"), +# # # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges +# # pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_getitem_when_list_index_provided( +# # mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # data = dataset[index] + +# # assert data.hyperedge_index.shape == expected_shape +# # assert data.num_hyperedges == expected_num_hyperedges + + +# # @pytest.mark.parametrize( +# # "strategy", +# # [ +# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): +# # with patch.object( +# # HIFLoader, "load", return_value=mock_three_node_weighted_hypergraph +# # ): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # data = dataset[0] + +# # assert data.hyperedge_index.shape == (2, 2) +# # assert data.num_hyperedges == 1 +# # assert data.hyperedge_attr is None + + +# # @pytest.mark.parametrize( +# # 
"strategy", +# # [ +# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): +# # with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # data = dataset[0] +# # assert data.hyperedge_attr is None + + +# # @pytest.mark.parametrize( +# # "strategy, index", +# # [ +# # # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges +# # pytest.param(SamplingStrategy.NODE, [0, 2], id="node_strategy"), +# # # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges +# # pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), +# # ], +# # ) +# # def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index): +# # with patch.object( +# # HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph +# # ): +# # dataset = AlgebraDataset(sampling_strategy=strategy) + +# # data = dataset[index] +# # assert data.num_hyperedges == 2 + +# # # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None +# # # as the hyperedge attributes are handled by the loader's collate function during batching +# # assert data.hyperedge_attr is None + + +# # def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[ +# # {"node": "0", "attrs": {}}, +# # {"node": "1", "attrs": {}}, +# # {"node": "2", "attrs": {}}, +# # {"node": "3", "attrs": {}}, +# # ], +# # edges=[ +# # {"edge": "0", "attrs": {"weight": 1.0, "abc": 5.0}}, +# # {"edge": "1", "attrs": {"weight": 2.0}}, # Missing 'abc' +# # {"edge": "2", "attrs": {"abc": 3.0}}, # Missing 'weight' +# # ], +# # incidences=[ +# # {"node": "0", "edge": "0"}, +# # {"node": "1", "edge": "0"}, +# # {"node": "2", "edge": "1"}, +# # {"node": "3", "edge": "2"}, +# # ], +# # ) + +# with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# dataset = AlgebraDataset() + +# assert dataset.hdata.hyperedge_attr is not None +# assert dataset.hdata.hyperedge_attr.shape == ( +# 3, +# 2, +# ) # 3 edges, 2 features each (weight, abc in insertion order) +# assert torch.allclose( +# dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 5.0]) +# ) # weight=1.0, abc=5.0 +# assert torch.allclose( +# dataset.hdata.hyperedge_attr[1], torch.tensor([2.0, 0.0]) +# ) # weight=2.0, abc=0.0 +# assert torch.allclose( +# dataset.hdata.hyperedge_attr[2], torch.tensor([0.0, 3.0]) +# ) # weight=0.0, abc=3.0 + + +# def test_process_not_all_hyperedge_weights_(): # mock_hypergraph = HIFHypergraph( # network_type="undirected", # nodes=[ -# {"node": "0", "attrs": {"weight": 1.0}}, # Missing 'score' -# {"node": "1", "attrs": {"weight": 2.0, "score": 0.8}}, -# {"node": "2", "attrs": {"score": 0.5}}, # Missing 'weight' +# {"node": "0", "attrs": {}}, +# {"node": "1", "attrs": {}}, +# {"node": "2", "attrs": {}}, +# ], +# hyperedges=[ +# {"edge": "0", "weight": 1.5}, +# {"edge": "1"}, +# {"edge": "2", "weight": 2.5}, # ], -# hyperedges=[{"edge": "0", "attrs": {}}], # incidences=[ # {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "0"}, -# {"node": "2", "edge": "0"}, +# {"node": "1", "edge": "1"}, +# {"node": "2", "edge": "2"}, # ], # ) -# with patch.object(HIFLoader, "load", 
return_value=mock_hypergraph): - -# class TestDataset(Dataset): -# DATASET_NAME = "TEST" - -# dataset = TestDataset() - -# assert dataset.hdata.x.shape == ( -# 3, -# 2, -# ) # 3 nodes, 2 features each (weight, score in insertion order) -# assert torch.allclose(dataset.hdata.x[0], torch.tensor([1.0, 0.0])) # weight=1.0, score=0.0 -# assert torch.allclose(dataset.hdata.x[1], torch.tensor([2.0, 0.8])) # weight=2.0, score=0.8 -# assert torch.allclose(dataset.hdata.x[2], torch.tensor([0.0, 0.5])) # weight=0.0, score=0.5 +# with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# with pytest.raises( +# ValueError, +# match="Some hyperedges have weights while others do not. All hyperedges must either have weights or none.", +# ): +# dataset = AlgebraDataset() -# def test_process_with_no_node_attributes_fallback_to_one(): +# def test_process_extracts_top_level_hyperedge_weights(): # mock_hypergraph = HIFHypergraph( # network_type="undirected", # nodes=[ -# {"node": "0", "attrs": {"name": "node0"}}, +# {"node": "0", "attrs": {}}, # {"node": "1", "attrs": {}}, +# {"node": "2", "attrs": {}}, # ], -# hyperedges=[{"edge": "0", "attrs": {}}], -# incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], -# ) - -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - -# class TestDataset(Dataset): -# DATASET_NAME = "TEST" - -# dataset = TestDataset() - -# assert dataset.hdata.x.shape == (2, 1) -# assert torch.allclose(dataset.hdata.x, torch.tensor([[1.0], [1.0]])) - - -# def test_process_with_single_node_attribute(): -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {"weight": 1.5}}, -# {"node": "1", "attrs": {"weight": 2.5}}, -# {"node": "2", "attrs": {"weight": 3.5}}, +# hyperedges=[ +# {"edge": "0", "weight": 1.5}, +# {"edge": "1", "weight": 3.0}, +# {"edge": "2", "weight": 2.5}, # ], -# hyperedges=[{"edge": "0", "attrs": {}}], # incidences=[ # {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "0"}, -# {"node": "2", "edge": "0"}, +# {"node": "1", "edge": "1"}, +# {"node": "2", "edge": "2"}, # ], # ) -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): +# dataset = AlgebraDataset() -# class TestDataset(Dataset): -# DATASET_NAME = "TEST" +# hyperedge_weights = dataset.hdata.hyperedge_weights +# assert hyperedge_weights is not None +# assert torch.allclose(hyperedge_weights[0], torch.tensor([1.5])) +# assert torch.allclose(hyperedge_weights[1], torch.tensor([3.0])) +# assert torch.allclose(hyperedge_weights[2], torch.tensor([2.5])) + + +# # def test_transform_attrs_empty_attrs(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": "0", "attrs": {}}], +# # hyperedges=[{"edge": "0", "attrs": {}}], +# # incidences=[{"node": "0", "edge": "0"}], +# # ) + +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): + +# # class TestDataset(Dataset): +# # DATASET_NAME = "TEST" + +# # dataset = TestDataset() -# dataset = TestDataset() +# # result = dataset.transform_attrs({}) +# # assert len(result) == 0 -# # Single attribute should remain 2D: [num_nodes, 1] -# assert dataset.hdata.x.shape == (3, 1) -# assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) +# # attrs = {"name": "node1", "active": True} +# # result = dataset.transform_attrs(attrs) +# # assert len(result) == 0 -# def 
test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": "0", "attrs": {}}], -# hyperedges=[{"edge": "0", "attrs": {}}], -# incidences=[{"node": "0", "edge": "0"}], -# ) +# # def test_process_adds_padding_zero_when_inconsistent_node_attributes(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[ +# # {"node": "0", "attrs": {"weight": 1.0}}, # Missing 'score' +# # {"node": "1", "attrs": {"weight": 2.0, "score": 0.8}}, +# # {"node": "2", "attrs": {"score": 0.5}}, # Missing 'weight' +# # ], +# # hyperedges=[{"edge": "0", "attrs": {}}], +# # incidences=[ +# # {"node": "0", "edge": "0"}, +# # {"node": "1", "edge": "0"}, +# # {"node": "2", "edge": "0"}, +# # ], +# # ) + +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): + +# # class TestDataset(Dataset): +# # DATASET_NAME = "TEST" + +# # dataset = TestDataset() + +# # assert dataset.hdata.x.shape == ( +# # 3, +# # 2, +# # ) # 3 nodes, 2 features each (weight, score in insertion order) +# # assert torch.allclose(dataset.hdata.x[0], torch.tensor([1.0, 0.0])) # weight=1.0, score=0.0 +# # assert torch.allclose(dataset.hdata.x[1], torch.tensor([2.0, 0.8])) # weight=2.0, score=0.8 +# # assert torch.allclose(dataset.hdata.x[2], torch.tensor([0.0, 0.5])) # weight=0.0, score=0.5 + + +# # def test_process_with_no_node_attributes_fallback_to_one(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[ +# # {"node": "0", "attrs": {"name": "node0"}}, +# # {"node": "1", "attrs": {}}, +# # ], +# # hyperedges=[{"edge": "0", "attrs": {}}], +# # incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], +# # ) -# with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# class TestDataset(Dataset): -# DATASET_NAME = "TEST" +# # class TestDataset(Dataset): +# # DATASET_NAME = "TEST" -# dataset = TestDataset() +# # dataset = TestDataset() -# # Test with attr_keys - should pad missing attributes with 0.0 -# attrs = {"weight": 1.5} -# result = dataset.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) -# assert torch.allclose( -# result, torch.tensor([0.0, 1.5, 0.0]) -# ) # score=0.0, weight=1.5, age=0.0 +# # assert dataset.hdata.x.shape == (2, 1) +# # assert torch.allclose(dataset.hdata.x, torch.tensor([[1.0], [1.0]])) -# # Test with all attributes present -# attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} -# result = dataset.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) -# assert torch.allclose( -# result, torch.tensor([25.0, 0.8, 1.5]) -# ) # age=25.0, score=0.8, weight=1.5 -# # Test without attr_keys - maintains insertion order -# attrs = {"weight": 1.5, "score": 0.8} -# result = dataset.transform_attrs(attrs) -# assert torch.allclose(result, torch.tensor([1.5, 0.8])) # weight, score (insertion order) +# # def test_process_with_single_node_attribute(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[ +# # {"node": "0", "attrs": {"weight": 1.5}}, +# # {"node": "1", "attrs": {"weight": 2.5}}, +# # {"node": "2", "attrs": {"weight": 3.5}}, +# # ], +# # hyperedges=[{"edge": "0", "attrs": {}}], +# # incidences=[ +# # {"node": "0", "edge": "0"}, +# # {"node": "1", "edge": "0"}, +# # {"node": "2", "edge": "0"}, +# # ], +# # ) +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# @pytest.mark.parametrize( -# "strategy, 
expected_len", -# [ -# # mock_hdata: 3 nodes, 2 hyperedges -# pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), -# ], -# ) -# def test_from_hdata(strategy, expected_len, mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) +# # class TestDataset(Dataset): +# # DATASET_NAME = "TEST" -# assert dataset.hdata is mock_hdata -# assert len(dataset) == expected_len +# # dataset = TestDataset() +# # # Single attribute should remain 2D: [num_nodes, 1] +# # assert dataset.hdata.x.shape == (3, 1) +# # assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) -# def test_from_hdata_download_raises(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata) -# with pytest.raises(ValueError, match="download can only be called for the original dataset."): -# dataset.download() +# # def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": "0", "attrs": {}}], +# # hyperedges=[{"edge": "0", "attrs": {}}], +# # incidences=[{"node": "0", "edge": "0"}], +# # ) +# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# def test_from_hdata_process_raises(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata) +# # class TestDataset(Dataset): +# # DATASET_NAME = "TEST" -# with pytest.raises(ValueError, match="process can only be called for the original dataset."): -# dataset.process() +# # dataset = TestDataset() +# # # Test with attr_keys - should pad missing attributes with 0.0 +# # attrs = {"weight": 1.5} +# # result = dataset.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) +# # assert torch.allclose( +# # result, torch.tensor([0.0, 1.5, 0.0]) +# # ) # score=0.0, weight=1.5, age=0.0 -# def test_enrich_node_features_replace(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata) +# # # Test with all attributes present +# # attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} +# # result = dataset.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) +# # assert torch.allclose( +# # result, torch.tensor([25.0, 0.8, 1.5]) +# # ) # age=25.0, score=0.8, weight=1.5 -# enricher = MagicMock(spec=NodeEnricher) -# enriched_x = torch.randn(3, 4) -# enricher.enrich.return_value = enriched_x +# # # Test without attr_keys - maintains insertion order +# # attrs = {"weight": 1.5, "score": 0.8} +# # result = dataset.transform_attrs(attrs) +# # assert torch.allclose(result, torch.tensor([1.5, 0.8])) # weight, score (insertion order) -# dataset.enrich_node_features(enricher) -# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) -# assert torch.equal(dataset.hdata.x, enriched_x) +# # @pytest.mark.parametrize( +# # "strategy, expected_len", +# # [ +# # # mock_hdata: 3 nodes, 2 hyperedges +# # pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_from_hdata(strategy, expected_len, mock_hdata): +# # dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) +# # assert dataset.hdata is mock_hdata +# # assert len(dataset) == expected_len -# def test_enrich_node_features_concatenate(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata) -# original_x = dataset.hdata.x.clone() -# enricher = MagicMock(spec=NodeEnricher) -# enriched_x = torch.randn(3, 4) -# enricher.enrich.return_value = enriched_x +# # def 
test_from_hdata_download_raises(mock_hdata): +# # dataset = Dataset.from_hdata(mock_hdata) - dataset.enrich_node_features(enricher, enrichment_mode="concatenate") - - enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) - expected_x = torch.cat([original_x, enriched_x], dim=1) - assert torch.equal(dataset.hdata.x, expected_x) - assert dataset.hdata.x.shape == (3, 5) # 1 original + 4 enriched - - -def test_enrich_hyperedge_attr_replace(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) - - enricher = MagicMock(spec=HyperedgeEnricher) - enriched_x = torch.randn(3, 4) - enricher.enrich.return_value = enriched_x - - dataset.enrich_hyperedge_attr(enricher) - - enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) - hyperedge_attr = dataset.hdata.hyperedge_attr - assert hyperedge_attr is not None - assert torch.equal(hyperedge_attr, enriched_x) - - -def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): - dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_attr) - original_hyperedge_attr = dataset.hdata.hyperedge_attr - assert original_hyperedge_attr is not None - original_hyperedge_attr = original_hyperedge_attr.clone() - - enricher = MagicMock(spec=HyperedgeEnricher) - enriched_x = torch.randn(3, 4) - enricher.enrich.return_value = enriched_x - - dataset.enrich_hyperedge_attr(enricher, enrichment_mode="concatenate") - - enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_attr.hyperedge_index) - expected_x = torch.cat([original_hyperedge_attr, enriched_x], dim=1) - hyperedge_attr = dataset.hdata.hyperedge_attr - assert hyperedge_attr is not None - assert torch.equal(hyperedge_attr, expected_x) - assert hyperedge_attr.shape == (3, 5) # 1 original + 4 enriched - - -def test_enrich_hyperedge_weights_replace(mock_hdata): - dataset = Dataset.from_hdata(mock_hdata) - - enricher = MagicMock(spec=HyperedgeEnricher) - enriched_weights = torch.randn(3) - enricher.enrich.return_value = enriched_weights - - dataset.enrich_hyperedge_weights(enricher) - - enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) - hyperedge_weights = dataset.hdata.hyperedge_weights - assert hyperedge_weights is not None - assert torch.equal(hyperedge_weights, enriched_weights) - - -def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights): - dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_weights) - original_weights = dataset.hdata.hyperedge_weights - assert original_weights is not None - original_weights = original_weights.clone() - - enricher = MagicMock(spec=HyperedgeEnricher) - enriched_weights = torch.randn(3) - enricher.enrich.return_value = enriched_weights - - dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") - - enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_weights.hyperedge_index) - expected_weights = torch.cat([original_weights, enriched_weights], dim=0) - hyperedge_weights = dataset.hdata.hyperedge_weights - assert hyperedge_weights is not None - assert torch.equal(hyperedge_weights, expected_weights) - assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched - - -# @pytest.mark.parametrize( -# "hyperedge_index, k, expected_hyperedge_index", -# [ -# pytest.param( -# torch.tensor([[0, 1, 2], [0, 0, 0]]), -# 4, -# torch.zeros((2, 0), dtype=torch.long), -# id="single_hyperedge_below_k_removed", -# ), -# pytest.param( -# torch.tensor([[0, 1, 2], [0, 0, 0]]), -# 3, -# torch.tensor([[0, 1, 2], [0, 0, 0]]), -# id="single_hyperedge_at_exact_k_kept", -# ), -# 
pytest.param( -# torch.tensor([[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]), -# 3, -# torch.tensor([[0, 1, 2], [0, 0, 0]]), -# id="two_hyperedges_first_kept_second_removed", -# ), -# pytest.param( -# torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), -# 3, -# torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), -# id="two_hyperedges_both_kept", -# ), -# pytest.param( -# torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 2, 2]]), -# 3, -# torch.zeros((2, 0), dtype=torch.long), -# id="three_hyperedges_all_removed", -# ), -# ], -# ) -# def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_hyperedge_index): -# num_nodes = hyperedge_index[0].max().item() + 1 if hyperedge_index.shape[1] > 0 else 0 -# x = torch.ones((num_nodes, 1), dtype=torch.float) -# hdata = HData(x=x, hyperedge_index=hyperedge_index) -# dataset = Dataset.from_hdata(hdata) - -# dataset.remove_hyperedges_with_fewer_than_k_nodes(k) - -# expected_num_nodes = expected_hyperedge_index[0].unique().shape[0] -# expected_num_hyperedges = expected_hyperedge_index[1].unique().shape[0] - -# assert torch.equal(dataset.hdata.hyperedge_index, expected_hyperedge_index) -# assert dataset.hdata.x.shape[0] == expected_num_nodes -# assert dataset.hdata.y.shape[0] == expected_num_hyperedges - - -# def test_split_with_equal_ratios(mock_four_node_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() +# # with pytest.raises(ValueError, match="download can only be called for the original dataset."): +# # dataset.download() -# splits = dataset.split([0.5, 0.5]) -# assert len(splits) == 2 -# assert ( -# splits[0].hdata.num_hyperedges + splits[1].hdata.num_hyperedges -# == dataset.hdata.num_hyperedges -# ) -# for split in splits: -# assert split.hdata.x is not None -# assert split.hdata.num_nodes > 0 -# assert split.hdata.num_hyperedges > 0 +# # def test_from_hdata_process_raises(mock_hdata): +# # dataset = Dataset.from_hdata(mock_hdata) +# # with pytest.raises(ValueError, match="process can only be called for the original dataset."): +# # dataset.process() -# def test_split_three_way(mock_multiple_edges_attr_hypergraph): -# with patch.object( -# HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph -# ): -# dataset = AlgebraDataset() -# splits = dataset.split([0.5, 0.25, 0.25]) -# total_edges = sum(split.hdata.num_hyperedges for split in splits) +# # def test_enrich_node_features_replace(mock_hdata): +# # dataset = Dataset.from_hdata(mock_hdata) -# assert len(splits) == 3 -# assert total_edges == dataset.hdata.num_hyperedges +# # enricher = MagicMock(spec=NodeEnricher) +# # enriched_x = torch.randn(3, 4) +# # enricher.enrich.return_value = enriched_x -# for split in splits: -# assert split.hdata.x is not None -# assert split.hdata.num_nodes > 0 -# assert split.hdata.num_hyperedges > 0 +# # dataset.enrich_node_features(enricher) +# # enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) +# # assert torch.equal(dataset.hdata.x, enriched_x) -# def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() - -# with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): -# dataset.split([0.8, 0.1, 0.05]) +# # def test_enrich_node_features_concatenate(mock_hdata): +# # dataset = Dataset.from_hdata(mock_hdata) +# # original_x = dataset.hdata.x.clone() -# def 
test_split_with_shuffle_produces_deterministic_results_when_seed_provided( -# mock_four_node_hypergraph, -# ): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() +# # enricher = MagicMock(spec=NodeEnricher) +# # enriched_x = torch.randn(3, 4) +# # enricher.enrich.return_value = enriched_x -# splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) -# splits_b = dataset.split([0.5, 0.5], shuffle=True, seed=42) +# dataset.enrich_node_features(enricher, enrichment_mode="concatenate") -# assert torch.equal(splits_a[0].hdata.hyperedge_index, splits_b[0].hdata.hyperedge_index) -# assert torch.equal(splits_a[1].hdata.hyperedge_index, splits_b[1].hdata.hyperedge_index) +# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) +# expected_x = torch.cat([original_x, enriched_x], dim=1) +# assert torch.equal(dataset.hdata.x, expected_x) +# assert dataset.hdata.x.shape == (3, 5) # 1 original + 4 enriched -# def test_split_with_shuffle_when_no_seed_provided( -# mock_four_node_hypergraph, -# ): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() +# def test_enrich_hyperedge_attr_replace(mock_hdata): +# dataset = Dataset.from_hdata(mock_hdata) -# splits = dataset.split([0.5, 0.5], shuffle=True) -# total_edges = sum(split.hdata.num_hyperedges for split in splits) +# enricher = MagicMock(spec=HyperedgeEnricher) +# enriched_x = torch.randn(3, 4) +# enricher.enrich.return_value = enriched_x -# assert len(splits) == 2 -# assert total_edges == dataset.hdata.num_hyperedges +# dataset.enrich_hyperedge_attr(enricher) -# for split in splits: -# assert split.hdata.x is not None -# assert split.hdata.num_nodes > 0 -# assert split.hdata.num_hyperedges > 0 +# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) +# hyperedge_attr = dataset.hdata.hyperedge_attr +# assert hyperedge_attr is not None +# assert torch.equal(hyperedge_attr, enriched_x) -# def test_split_preserves_edge_attr(mock_multiple_edges_attr_hypergraph): -# with patch.object( -# HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph -# ): -# dataset = AlgebraDataset() +# def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): +# dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_attr) +# original_hyperedge_attr = dataset.hdata.hyperedge_attr +# assert original_hyperedge_attr is not None +# original_hyperedge_attr = original_hyperedge_attr.clone() -# splits = dataset.split([0.5, 0.5]) +# enricher = MagicMock(spec=HyperedgeEnricher) +# enriched_x = torch.randn(3, 4) +# enricher.enrich.return_value = enriched_x -# for split in splits: -# assert split.hdata.hyperedge_attr is not None -# assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges +# dataset.enrich_hyperedge_attr(enricher, enrichment_mode="concatenate") +# enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_attr.hyperedge_index) +# expected_x = torch.cat([original_hyperedge_attr, enriched_x], dim=1) +# hyperedge_attr = dataset.hdata.hyperedge_attr +# assert hyperedge_attr is not None +# assert torch.equal(hyperedge_attr, expected_x) +# assert hyperedge_attr.shape == (3, 5) # 1 original + 4 enriched -# def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): -# dataset = AlgebraDataset() -# splits = dataset.split([0.5, 0.5]) +# def test_enrich_hyperedge_weights_replace(mock_hdata): +# dataset = 
Dataset.from_hdata(mock_hdata) -# for split in splits: -# assert split.hdata.hyperedge_attr is None +# enricher = MagicMock(spec=HyperedgeEnricher) +# enriched_weights = torch.randn(3) +# enricher.enrich.return_value = enriched_weights +# dataset.enrich_hyperedge_weights(enricher) -# def test_to_device(mock_hdata): -# device = torch.device("cpu") +# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) +# hyperedge_weights = dataset.hdata.hyperedge_weights +# assert hyperedge_weights is not None +# assert torch.equal(hyperedge_weights, enriched_weights) + + +# def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights): +# dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_weights) +# original_weights = dataset.hdata.hyperedge_weights +# assert original_weights is not None +# original_weights = original_weights.clone() + +# enricher = MagicMock(spec=HyperedgeEnricher) +# enriched_weights = torch.randn(3) +# enricher.enrich.return_value = enriched_weights + +# dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") + +# enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_weights.hyperedge_index) +# expected_weights = torch.cat([original_weights, enriched_weights], dim=0) +# hyperedge_weights = dataset.hdata.hyperedge_weights +# assert hyperedge_weights is not None +# assert torch.equal(hyperedge_weights, expected_weights) +# assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched + + +# # @pytest.mark.parametrize( +# # "hyperedge_index, k, expected_hyperedge_index", +# # [ +# # pytest.param( +# # torch.tensor([[0, 1, 2], [0, 0, 0]]), +# # 4, +# # torch.zeros((2, 0), dtype=torch.long), +# # id="single_hyperedge_below_k_removed", +# # ), +# # pytest.param( +# # torch.tensor([[0, 1, 2], [0, 0, 0]]), +# # 3, +# # torch.tensor([[0, 1, 2], [0, 0, 0]]), +# # id="single_hyperedge_at_exact_k_kept", +# # ), +# # pytest.param( +# # torch.tensor([[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]), +# # 3, +# # torch.tensor([[0, 1, 2], [0, 0, 0]]), +# # id="two_hyperedges_first_kept_second_removed", +# # ), +# # pytest.param( +# # torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), +# # 3, +# # torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), +# # id="two_hyperedges_both_kept", +# # ), +# # pytest.param( +# # torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 2, 2]]), +# # 3, +# # torch.zeros((2, 0), dtype=torch.long), +# # id="three_hyperedges_all_removed", +# # ), +# # ], +# # ) +# # def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_hyperedge_index): +# # num_nodes = hyperedge_index[0].max().item() + 1 if hyperedge_index.shape[1] > 0 else 0 +# # x = torch.ones((num_nodes, 1), dtype=torch.float) +# # hdata = HData(x=x, hyperedge_index=hyperedge_index) +# # dataset = Dataset.from_hdata(hdata) + +# # dataset.remove_hyperedges_with_fewer_than_k_nodes(k) + +# # expected_num_nodes = expected_hyperedge_index[0].unique().shape[0] +# # expected_num_hyperedges = expected_hyperedge_index[1].unique().shape[0] + +# # assert torch.equal(dataset.hdata.hyperedge_index, expected_hyperedge_index) +# # assert dataset.hdata.x.shape[0] == expected_num_nodes +# # assert dataset.hdata.y.shape[0] == expected_num_hyperedges + + +# # def test_split_with_equal_ratios(mock_four_node_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() + +# # splits = dataset.split([0.5, 0.5]) + +# # assert len(splits) == 2 +# # assert ( +# # splits[0].hdata.num_hyperedges 
+ splits[1].hdata.num_hyperedges +# # == dataset.hdata.num_hyperedges +# # ) +# # for split in splits: +# # assert split.hdata.x is not None +# # assert split.hdata.num_nodes > 0 +# # assert split.hdata.num_hyperedges > 0 + + +# # def test_split_three_way(mock_multiple_edges_attr_hypergraph): +# # with patch.object( +# # HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph +# # ): +# # dataset = AlgebraDataset() + +# # splits = dataset.split([0.5, 0.25, 0.25]) +# # total_edges = sum(split.hdata.num_hyperedges for split in splits) + +# # assert len(splits) == 3 +# # assert total_edges == dataset.hdata.num_hyperedges + +# # for split in splits: +# # assert split.hdata.x is not None +# # assert split.hdata.num_nodes > 0 +# # assert split.hdata.num_hyperedges > 0 + + +# # def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() + +# # with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): +# # dataset.split([0.8, 0.1, 0.05]) + + +# # def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( +# # mock_four_node_hypergraph, +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() + +# # splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) +# # splits_b = dataset.split([0.5, 0.5], shuffle=True, seed=42) + +# # assert torch.equal(splits_a[0].hdata.hyperedge_index, splits_b[0].hdata.hyperedge_index) +# # assert torch.equal(splits_a[1].hdata.hyperedge_index, splits_b[1].hdata.hyperedge_index) + + +# # def test_split_with_shuffle_when_no_seed_provided( +# # mock_four_node_hypergraph, +# # ): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() -# dataset = Dataset.from_hdata(mock_hdata) +# # splits = dataset.split([0.5, 0.5], shuffle=True) +# # total_edges = sum(split.hdata.num_hyperedges for split in splits) -# result = dataset.to(device) +# # assert len(splits) == 2 +# # assert total_edges == dataset.hdata.num_hyperedges -# assert result is dataset -# assert dataset.hdata.device == device +# # for split in splits: +# # assert split.hdata.x is not None +# # assert split.hdata.num_nodes > 0 +# # assert split.hdata.num_hyperedges > 0 -# def test_load_skips_download_when_file_exists(): -# dataset_name = "ALGEBRA" +# # def test_split_preserves_edge_attr(mock_multiple_edges_attr_hypergraph): +# # with patch.object( +# # HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph +# # ): +# # dataset = AlgebraDataset() -# sample_hif = { -# "network-type": "undirected", -# "nodes": [{"node": "0"}, {"node": "1"}], -# "edges": [{"edge": "0"}], -# "incidences": [{"node": "0", "edge": "0"}], -# } +# # splits = dataset.split([0.5, 0.5]) -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": "0"}, {"node": "1"}], -# hyperedges=[{"edge": "0"}], -# incidences=[{"node": "0", "edge": "0"}], -# ) +# # for split in splits: +# # assert split.hdata.hyperedge_attr is not None +# # assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges -# with ( -# patch("hyperbench.data.dataset.requests.get") as mock_get, -# patch("hyperbench.data.dataset.os.path.exists", return_value=True), -# patch("builtins.open", mock_open()) as mock_file, -# patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# 
patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, -# patch("hyperbench.data.dataset.json.load", return_value=sample_hif), -# patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# ): -# mock_dctx = mock_decomp.return_value -# mock_dctx.copy_stream = lambda input_f, tmp_file: None - -# mock_temp_instance = mock_temp.return_value.__enter__.return_value -# mock_temp_instance.name = "/tmp/decompressed.json" - -# result = HIFLoader.load(dataset_name, save_on_disk=True) -# mock_get.assert_not_called() -# assert result == mock_hypergraph - - -# def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset() -# # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) -# assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE -# assert len(dataset) == 2 +# # def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): +# # dataset = AlgebraDataset() +# # splits = dataset.split([0.5, 0.5]) -# def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) +# # for split in splits: +# # assert split.hdata.hyperedge_attr is None -# # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) -# assert dataset.sampling_strategy == SamplingStrategy.NODE -# assert len(dataset) == 4 +# # def test_to_device(mock_hdata): +# # device = torch.device("cpu") -# @pytest.mark.parametrize( -# "strategy", -# [ -# pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# ], -# ) -# def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): -# with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# dataset = AlgebraDataset(sampling_strategy=strategy) +# # dataset = Dataset.from_hdata(mock_hdata) -# splits = dataset.split([0.5, 0.5]) +# # result = dataset.to(device) -# for split in splits: -# assert split.sampling_strategy == strategy +# # assert result is dataset +# # assert dataset.hdata.device == device -# def test_from_hdata_with_explicit_strategy(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) +# # def test_load_skips_download_when_file_exists(): +# # dataset_name = "ALGEBRA" -# assert dataset.sampling_strategy == SamplingStrategy.NODE -# assert len(dataset) == 3 # mock_hdata has 3 nodes +# # sample_hif = { +# # "network-type": "undirected", +# # "nodes": [{"node": "0"}, {"node": "1"}], +# # "edges": [{"edge": "0"}], +# # "incidences": [{"node": "0", "edge": "0"}], +# # } +# # mock_hypergraph = HIFHypergraph( +# # network_type="undirected", +# # nodes=[{"node": "0"}, {"node": "1"}], +# # hyperedges=[{"edge": "0"}], +# # incidences=[{"node": "0", "edge": "0"}], +# # ) -# def test_update_from_hdata_returns_new_dataset(mock_hdata): -# dataset = Dataset(hdata=mock_hdata, prepare=False) -# new_x = torch.ones((2, 1), dtype=torch.float) -# new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) -# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# # with ( +# # 
patch("hyperbench.data.dataset.requests.get") as mock_get, +# # patch("hyperbench.data.dataset.os.path.exists", return_value=True), +# # patch("builtins.open", mock_open()) as mock_file, +# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, +# # patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, +# # patch("hyperbench.data.dataset.json.load", return_value=sample_hif), +# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), +# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), +# # ): +# # mock_dctx = mock_decomp.return_value +# # mock_dctx.copy_stream = lambda input_f, tmp_file: None -# result = dataset.update_from_hdata(new_hdata) +# # mock_temp_instance = mock_temp.return_value.__enter__.return_value +# # mock_temp_instance.name = "/tmp/decompressed.json" -# assert result is not dataset -# assert result.hdata is new_hdata -# assert dataset.hdata is mock_hdata +# # result = HIFLoader.load(dataset_name, save_on_disk=True) +# # mock_get.assert_not_called() +# # assert result == mock_hypergraph -# def test_update_from_hdata_stores_provided_hdata(mock_hdata): -# dataset = Dataset(hdata=mock_hdata, prepare=False) -# new_x = torch.ones((2, 1), dtype=torch.float) -# new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) -# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# # def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset() -# result = dataset.update_from_hdata(new_hdata) +# # # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) +# # assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE +# # assert len(dataset) == 2 -# assert result.hdata is new_hdata +# # def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) -# @pytest.mark.parametrize( -# "strategy, expected_len", -# [ -# pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), -# pytest.param(SamplingStrategy.HYPEREDGE, 3, id="hyperedge_strategy"), -# ], -# ) -# def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): -# dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy, prepare=False) -# new_x = torch.ones((4, 1), dtype=torch.float) -# new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) -# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# # # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) +# # assert dataset.sampling_strategy == SamplingStrategy.NODE +# # assert len(dataset) == 4 -# result = dataset.update_from_hdata(new_hdata) -# assert result.sampling_strategy == strategy -# assert len(result) == expected_len +# # @pytest.mark.parametrize( +# # "strategy", +# # [ +# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): +# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): +# # dataset = AlgebraDataset(sampling_strategy=strategy) +# # splits = dataset.split([0.5, 0.5]) -# def test_update_from_hdata_preserves_subclass_type(mock_hdata): -# dataset = 
AlgebraDataset(hdata=mock_hdata, prepare=False) -# new_x = torch.ones((2, 1), dtype=torch.float) -# new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) -# new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) +# # for split in splits: +# # assert split.sampling_strategy == strategy -# result = dataset.update_from_hdata(new_hdata) -# assert type(result) is AlgebraDataset +# # def test_from_hdata_with_explicit_strategy(mock_hdata): +# # dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) +# # assert dataset.sampling_strategy == SamplingStrategy.NODE +# # assert len(dataset) == 3 # mock_hdata has 3 nodes -# @pytest.fixture -# def mock_hdata_stats(): -# x = torch.tensor( -# [ -# [0.0, 1.0, 2.0, 3.0], -# [1.0, 2.0, 3.0, 4.0], -# [2.0, 3.0, 4.0, 5.0], -# [3.0, 4.0, 5.0, 6.0], -# ], -# dtype=torch.float, -# ) -# hyperedge_index = torch.tensor( -# [ -# [0, 1, 2, 2, 3], -# [0, 0, 0, 1, 1], -# ] -# ) -# return HData(x=x, hyperedge_index=hyperedge_index) +# # def test_update_from_hdata_returns_new_dataset(mock_hdata): +# # dataset = Dataset(hdata=mock_hdata, prepare=False) +# # new_x = torch.ones((2, 1), dtype=torch.float) +# # new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) +# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + +# # result = dataset.update_from_hdata(new_hdata) + +# # assert result is not dataset +# # assert result.hdata is new_hdata +# # assert dataset.hdata is mock_hdata + + +# # def test_update_from_hdata_stores_provided_hdata(mock_hdata): +# # dataset = Dataset(hdata=mock_hdata, prepare=False) +# # new_x = torch.ones((2, 1), dtype=torch.float) +# # new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) +# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + +# # result = dataset.update_from_hdata(new_hdata) + +# # assert result.hdata is new_hdata + + +# # @pytest.mark.parametrize( +# # "strategy, expected_len", +# # [ +# # pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), +# # pytest.param(SamplingStrategy.HYPEREDGE, 3, id="hyperedge_strategy"), +# # ], +# # ) +# # def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): +# # dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy, prepare=False) +# # new_x = torch.ones((4, 1), dtype=torch.float) +# # new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) +# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + +# # result = dataset.update_from_hdata(new_hdata) + +# # assert result.sampling_strategy == strategy +# # assert len(result) == expected_len + + +# # def test_update_from_hdata_preserves_subclass_type(mock_hdata): +# # dataset = AlgebraDataset(hdata=mock_hdata, prepare=False) +# # new_x = torch.ones((2, 1), dtype=torch.float) +# # new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) +# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + +# # result = dataset.update_from_hdata(new_hdata) + +# # assert type(result) is AlgebraDataset -# def test_dataset_stats_computation(mock_hdata_stats): -# expected_stats = { -# "shape_x": torch.Size([4, 4]), -# "shape_hyperedge_attr": None, -# "shape_hyperedge_weights": None, - "num_nodes": 4, -# "num_hyperedges": 2, -# "avg_degree_node_raw": 1.25, -# "avg_degree_node": 1, -# "avg_degree_hyperedge_raw": 2.5, -# "avg_degree_hyperedge": 2, -# "node_degree_max": 2, -# "hyperedge_degree_max": 3, -# "node_degree_median": 1, -# 
"hyperedge_degree_median": 2, -# "distribution_node_degree": [1, 1, 2, 1], -# "distribution_hyperedge_size": [3, 2], -# "distribution_node_degree_hist": {1: 3, 2: 1}, -# "distribution_hyperedge_size_hist": {2: 1, 3: 1}, -# } - -# dataset = Dataset.from_hdata(mock_hdata_stats) - -# stats = dataset.stats() -# assert stats == expected_stats + +# # @pytest.fixture +# # def mock_hdata_stats(): +# # x = torch.tensor( +# # [ +# # [0.0, 1.0, 2.0, 3.0], +# # [1.0, 2.0, 3.0, 4.0], +# # [2.0, 3.0, 4.0, 5.0], +# # [3.0, 4.0, 5.0, 6.0], +# # ], +# # dtype=torch.float, +# # ) +# # hyperedge_index = torch.tensor( +# # [ +# # [0, 1, 2, 2, 3], +# # [0, 0, 0, 1, 1], +# # ] +# # ) +# # return HData(x=x, hyperedge_index=hyperedge_index) + + +# # def test_dataset_stats_computation(mock_hdata_stats): +# # expected_stats = { +# # "shape_x": torch.Size([4, 4]), +# # "shape_hyperedge_attr": None, +# # "shape_hyperedge_weights": None, +# "num_nodes": 4, +# # "num_hyperedges": 2, +# # "avg_degree_node_raw": 1.25, +# # "avg_degree_node": 1, +# # "avg_degree_hyperedge_raw": 2.5, +# # "avg_degree_hyperedge": 2, +# # "node_degree_max": 2, +# # "hyperedge_degree_max": 3, +# # "node_degree_median": 1, +# # "hyperedge_degree_median": 2, +# # "distribution_node_degree": [1, 1, 2, 1], +# # "distribution_hyperedge_size": [3, 2], +# # "distribution_node_degree_hist": {1: 3, 2: 1}, +# # "distribution_hyperedge_size_hist": {2: 1, 3: 1}, +# # } + +# # dataset = Dataset.from_hdata(mock_hdata_stats) + +# # stats = dataset.stats() +# # assert stats == expected_stats From fe01450cd5f59e6d5f78f5516a8b373fefebab35 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 23 Apr 2026 08:56:08 +0200 Subject: [PATCH 05/15] fix: removed prepare as parameter --- examples/hgnn.py | 2 +- examples/hyperedge_enricher.py | 5 +++-- examples/hypergcn.py | 2 +- examples/mlp_common_neighbors.py | 2 +- examples/node_enricher.py | 4 ++-- hyperbench/data/dataset.py | 38 ++++++++++++++++---------------- hyperbench/data/hif.py | 2 -- 7 files changed, 27 insertions(+), 28 deletions(-) diff --git a/examples/hgnn.py b/examples/hgnn.py index b1644aa..f0bc454 100644 --- a/examples/hgnn.py +++ b/examples/hgnn.py @@ -28,7 +28,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/hyperedge_enricher.py b/examples/hyperedge_enricher.py index 43882f6..dfc9e3c 100644 --- a/examples/hyperedge_enricher.py +++ b/examples/hyperedge_enricher.py @@ -9,6 +9,7 @@ print("Enriching hyperedge weights...") + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) # HyperedgeWeightsEnricher enriches hyperedges with their degree (number of nodes in each hyperedge) as weights. # It optionally applies scaling and adds a constant to the weights. dataset.enrich_hyperedge_weights( @@ -25,8 +26,8 @@ print("Enriching hyperedge attributes...") - # HyperedgeAttrsEnricher adds a feature of 1.0 for each hyperedge, - # which can be used as a baseline or for methods that require hyperedge features. + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) + # HyperedgeAttrsEnricher adds a feature of 1.0 for each hyperedge, which can be used as a baseline or for methods that require hyperedge features. 
dataset.enrich_hyperedge_attr( enricher=HyperedgeAttrsEnricher(), enrichment_mode="replace", diff --git a/examples/hypergcn.py b/examples/hypergcn.py index e94e281..c01c79f 100644 --- a/examples/hypergcn.py +++ b/examples/hypergcn.py @@ -28,7 +28,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) dataset.remove_hyperedges_with_fewer_than_k_nodes(k=2) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/mlp_common_neighbors.py b/examples/mlp_common_neighbors.py index b638429..adb7c56 100644 --- a/examples/mlp_common_neighbors.py +++ b/examples/mlp_common_neighbors.py @@ -28,7 +28,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/node_enricher.py b/examples/node_enricher.py index dd5a932..0803c2c 100644 --- a/examples/node_enricher.py +++ b/examples/node_enricher.py @@ -5,8 +5,8 @@ if __name__ == "__main__": print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE, prepare=True) - + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) + # NodeEnricher adds features for each node. dataset.enrich_node_features( enricher=LaplacianPositionalEncodingEnricher(num_features=32), enrichment_mode="replace", diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index bfd1e10..1ecda68 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -133,25 +133,25 @@ def from_path( dataset = cls.from_hdata(hdata=hypergraph, sampling_strategy=sampling_strategy) return dataset - @classmethod - def from_default( - cls, - sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, - save_on_disk: bool = False, - ) -> "Dataset": - """ - Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format. - - Args: - sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. - save_on_disk: Whether to save the downloaded file on disk. - - Returns: - The :class:`Dataset` instance with the loaded hypergraph data. - """ - hdata = HIFLoader.load(dataset_name="", save_on_disk=save_on_disk) - dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy) - return dataset + # @classmethod + # def from_default( + # cls, + # sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, + # save_on_disk: bool = False, + # ) -> "Dataset": + # """ + # Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format. + + # Args: + # sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. + # save_on_disk: Whether to save the downloaded file on disk. + + # Returns: + # The :class:`Dataset` instance with the loaded hypergraph data. 
+ # """ + # hdata = HIFLoader.load(dataset_name="", save_on_disk=save_on_disk) + # dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy) + # return dataset def enrich_node_features( self, diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 0d3fd7d..1f02597 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -1,6 +1,4 @@ -from turtle import st import torch -from sympy.physics.units import h import os import json import zstandard as zstd From e59ee85ef20e111d942dd55e151bea1b0e2d9423 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 23 Apr 2026 12:07:55 +0200 Subject: [PATCH 06/15] feat: moved functions to proper place --- hyperbench/data/hif.py | 29 +++++------------------------ hyperbench/utils/__init__.py | 4 +++- hyperbench/utils/hif_utils.py | 20 ++++++++++++++++++++ 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 1f02597..45c3208 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -11,7 +11,7 @@ from torch import Tensor from hyperbench.types import HData, HIFHypergraph -from hyperbench.utils import validate_hif_json +from hyperbench.utils import validate_hif_json, decompress_zst, compress_to_zst def _validate_http_url(value: str) -> str: @@ -51,10 +51,10 @@ def load_from_url(url: str, save_on_disk: bool = False) -> HData: if zst_filename.endswith(".zst"): if save_on_disk: HIFLoader.__save_on_disk(os.path.basename(url), response.content) - output = HIFLoader.__decompress_zst(zst_filename) + output = decompress_zst(zst_filename) elif zst_filename.endswith(".json"): if save_on_disk: - compressed = HIFLoader.__compress_to_zst(zst_filename) + compressed = compress_to_zst(zst_filename) HIFLoader.__save_on_disk(os.path.basename(url), compressed) output = zst_filename else: @@ -80,7 +80,7 @@ def load_from_path(filepath: str) -> HData: raise ValueError(f"File '{filepath}' does not exist.") if filepath.endswith(".zst"): - output = HIFLoader.__decompress_zst(filepath) + output = decompress_zst(filepath) elif filepath.endswith(".json"): output = filepath else: @@ -145,30 +145,11 @@ def load(dataset_name: str, save_on_disk: bool = False) -> HData: tmp_zst_file.write(response.content) zst_filename = tmp_zst_file.name - # Decompress the downloaded zst file - output = HIFLoader.__decompress_zst(zst_filename) + output = decompress_zst(zst_filename) hypergraph = HIFLoader.__extract_hif(output) hdata = HIFLoader.__process(hypergraph) return hdata - @staticmethod - def __decompress_zst(zst_path: str) -> str: - dctx = zstd.ZstdDecompressor() - with ( - open(zst_path, "rb") as input_f, - tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, - ): - dctx.copy_stream(input_f, tmp_file) - output = tmp_file.name - return output - - @staticmethod - def __compress_to_zst(json_path: str) -> bytes: - cctx = zstd.ZstdCompressor() - with open(json_path, "rb") as input_f: - compressed_content = cctx.compress(input_f.read()) - return compressed_content - @staticmethod def __extract_hif(json_file: str) -> HIFHypergraph: with open(json_file, "r") as f: diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 65e09ab..e149372 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -5,7 +5,7 @@ to_non_empty_edgeattr, to_0based_ids, ) -from .hif_utils import validate_hif_json +from .hif_utils import validate_hif_json, decompress_zst, compress_to_zst from .nn_utils import ( INPUT_LAYER, ActivationFn, @@ -32,4 +32,6 @@ 
"to_non_empty_edgeattr", "to_0based_ids", "validate_hif_json", + "decompress_zst", + "compress_to_zst", ] diff --git a/hyperbench/utils/hif_utils.py b/hyperbench/utils/hif_utils.py index c9cd944..3f5eb7d 100644 --- a/hyperbench/utils/hif_utils.py +++ b/hyperbench/utils/hif_utils.py @@ -1,6 +1,8 @@ import fastjsonschema import json import requests +import zstandard as zstd +import tempfile def validate_hif_json(filename: str) -> bool: @@ -26,3 +28,21 @@ def validate_hif_json(filename: str) -> bool: return True except Exception: return False + + +def decompress_zst(zst_path: str) -> str: + dctx = zstd.ZstdDecompressor() + with ( + open(zst_path, "rb") as input_f, + tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, + ): + dctx.copy_stream(input_f, tmp_file) + output = tmp_file.name + return output + + +def compress_to_zst(json_path: str) -> bytes: + cctx = zstd.ZstdCompressor() + with open(json_path, "rb") as input_f: + compressed_content = cctx.compress(input_f.read()) + return compressed_content From 40a0c1bc56f3b075d8da93da0264f4390d934110 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 23 Apr 2026 14:00:22 +0200 Subject: [PATCH 07/15] fix: reordering class function and small adjustments to module --- hyperbench/data/__init__.py | 3 +- hyperbench/data/dataset.py | 58 +--- hyperbench/data/hif.py | 410 +++++++++++++------------- hyperbench/data/supported_datasets.py | 7 +- hyperbench/utils/__init__.py | 6 +- hyperbench/utils/file_utils.py | 31 ++ hyperbench/utils/hif_utils.py | 20 -- hyperbench/utils/url_utils.py | 8 + 8 files changed, 272 insertions(+), 271 deletions(-) create mode 100644 hyperbench/utils/file_utils.py create mode 100644 hyperbench/utils/url_utils.py diff --git a/hyperbench/data/__init__.py b/hyperbench/data/__init__.py index 706aba3..aa77d60 100644 --- a/hyperbench/data/__init__.py +++ b/hyperbench/data/__init__.py @@ -1,5 +1,5 @@ from .dataset import Dataset -from .hif import HIFLoader +from .hif import HIFLoader, HIFProcessor from .supported_datasets import ( AlgebraDataset, @@ -53,6 +53,7 @@ "GeometryDataset", "GOTDataset", "HIFLoader", + "HIFProcessor", "HyperedgeSampler", "IMDBDataset", "MusicBluesReviewsDataset", diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 1ecda68..55ca3b3 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -16,7 +16,7 @@ from hyperbench.utils import validate_hif_json from hyperbench.data.sampling import SamplingStrategy, create_sampler_from_strategy -from hyperbench.data.hif import HIFLoader +from hyperbench.data.hif import HIFLoader, HIFProcessor class Dataset(TorchDataset): @@ -43,8 +43,6 @@ def __init__( hdata: Optional HData object to initialize the dataset with. If provided, the dataset will be initialized with this data instead of loading and processing from HIF. Must be provided if prepare is set to ``False``. sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. - prepare: Whether to load and process the original dataset from HIF format. - If set to ``False``, the dataset will be initialized with the provided hdata instead. Defaults to ``True``. 
""" self.__sampler = create_sampler_from_strategy(sampling_strategy) @@ -133,26 +131,6 @@ def from_path( dataset = cls.from_hdata(hdata=hypergraph, sampling_strategy=sampling_strategy) return dataset - # @classmethod - # def from_default( - # cls, - # sampling_strategy: SamplingStrategy = SamplingStrategy.HYPEREDGE, - # save_on_disk: bool = False, - # ) -> "Dataset": - # """ - # Create a :class:`Dataset` instance by loading a hypergraph from a URL pointing to a .json or .json.zst file in HIF format. - - # Args: - # sampling_strategy: The sampling strategy to use for the dataset. If not provided, defaults to ``SamplingStrategy.HYPEREDGE``. - # save_on_disk: Whether to save the downloaded file on disk. - - # Returns: - # The :class:`Dataset` instance with the loaded hypergraph data. - # """ - # hdata = HIFLoader.load(dataset_name="", save_on_disk=save_on_disk) - # dataset = cls.from_hdata(hdata=hdata, sampling_strategy=sampling_strategy) - # return dataset - def enrich_node_features( self, enricher: NodeEnricher, @@ -341,27 +319,19 @@ def __get_hyperedge_ids_permutation( ranged_hyperedge_ids_permutation = torch.arange(num_hyperedges, device=device) return ranged_hyperedge_ids_permutation - # def __process_hyperedge_weights(self) -> Optional[Tensor]: - # # Initialize the hyperedge weights tensor - # hyperedge_weights = None - - # has_hyperedge_weights = self.hypergraph.hyperedges is not None and all( - # "weight" in edge for edge in self.hypergraph.hyperedges - # ) - - # if has_hyperedge_weights: - # weights = [edge.get("weight", 1.0) for edge in self.hypergraph.hyperedges] - # hyperedge_weights = torch.tensor(weights, dtype=torch.float) - # elif ( - # has_hyperedge_weights is False - # and self.hypergraph.hyperedges is not None - # and any("weight" in edge for edge in self.hypergraph.hyperedges) - # ): - # raise ValueError( - # "Some hyperedges have weights while others do not. All hyperedges must either have weights or none." - # ) - - # return hyperedge_weights + @staticmethod + def transform_node_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + return HIFProcessor.transform_attrs(attrs, attr_keys) + + @staticmethod + def transform_hyperedge_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + return HIFProcessor.transform_attrs(attrs, attr_keys) def stats(self) -> Dict[str, Any]: """ diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 45c3208..1ab65e0 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -7,18 +7,210 @@ import warnings from huggingface_hub import hf_hub_download from typing import Optional, Dict, Any, List -from urllib.parse import urlparse from torch import Tensor from hyperbench.types import HData, HIFHypergraph -from hyperbench.utils import validate_hif_json, decompress_zst, compress_to_zst +from hyperbench.utils import ( + validate_hif_json, + decompress_zst, + compress_to_zst, + validate_http_url, +) +from hyperbench.utils import save_on_disk as save -def _validate_http_url(value: str) -> str: - parsed = urlparse(value) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError(f"Invalid URL: {value}") - return value +class HIFProcessor: + """A utility class to process HIF hypergraph data into :class:`HData` format.""" + + @staticmethod + def transform_attrs( + attrs: Dict[str, Any], + attr_keys: Optional[List[str]] = None, + ) -> Tensor: + """ + Extract and encode numeric attributes to tensor. + Non-numeric attributes are discarded. 
Missing attributes are filled with ``0.0``. + + Args: + attrs: Dictionary of attributes + attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``. + + Returns: + Tensor of numeric attribute values + """ + numeric_attrs = { + key: value + for key, value in attrs.items() + if isinstance(value, (int, float)) and not isinstance(value, bool) + } + + if attr_keys is not None: + values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys] + return torch.tensor(values, dtype=torch.float) + + if not numeric_attrs: + return torch.tensor([], dtype=torch.float) + + values = [float(value) for value in numeric_attrs.values()] + return torch.tensor(values, dtype=torch.float) + + @staticmethod + def _process_hypergraph(hypergraph: HIFHypergraph) -> HData: + """ + Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. + + Returns: + The processed hypergraph data. + """ + # if not self.__is_prepared: + # raise ValueError("process can only be called for the original dataset.") + + num_nodes = len(hypergraph.nodes) + x = HIFProcessor._process_x(hypergraph, num_nodes) + + # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order + node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} + # Initialize edge_set only with edges that have incidences, so that + # we avoid inflating edge count due to isolated nodes/missing incidences + hyperedge_id_to_idx: Dict[Any, int] = {} + + node_ids = [] + hyperedge_ids = [] + nodes_with_incidences = set() + for incidence in hypergraph.incidences: + node_id = incidence.get("node", 0) + hyperedge_id = incidence.get("edge", 0) + + if hyperedge_id not in hyperedge_id_to_idx: + # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences + hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx) + + node_ids.append(node_id_to_idx[node_id]) + hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id]) + nodes_with_incidences.add(node_id_to_idx[node_id]) + + # Handle isolated nodes by assigning them to a new unique hyperedge (self-loop) + for node_idx in range(num_nodes): + if node_idx not in nodes_with_incidences: + new_hyperedge_id = len(hyperedge_id_to_idx) + # Unique dummy key to reserve the index in hyperedge_set + hyperedge_id_to_idx[f"__self_loop_{node_idx}__"] = new_hyperedge_id + node_ids.append(node_idx) + hyperedge_ids.append(new_hyperedge_id) + + num_hyperedges = len(hyperedge_id_to_idx) + hyperedge_attr = HIFProcessor._process_hyperedge_attr( + hypergraph=hypergraph, + hyperedge_id_to_idx=hyperedge_id_to_idx, + num_hyperedges=num_hyperedges, + ) + + hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) + + return HData(x, hyperedge_index, hyperedge_attr) + + @staticmethod + def _collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: + """ + Collect unique numeric attribute keys from a list of attribute dictionaries. + + Args: + attr_keys: List of attribute dictionaries. + + Returns: + List of unique numeric attribute keys. 
+ """ + unique_keys = [] + for attrs in attr_keys: + for key, value in attrs.items(): + if key not in unique_keys and isinstance(value, (int, float)): + unique_keys.append(key) + + return unique_keys + + @staticmethod + def _process_hyperedge_attr( + hypergraph: HIFHypergraph, + hyperedge_id_to_idx: Dict[Any, int], + num_hyperedges: int, + ) -> Optional[Tensor]: + # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] + hyperedge_attr = None + has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 + has_any_hyperedge_attrs = has_hyperedges and any( + "attrs" in edge for edge in hypergraph.hyperedges + ) + + if has_any_hyperedge_attrs: + hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { + e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges + } + + hyperedge_attr_keys = HIFProcessor._collect_attr_keys( + list(hyperedge_id_to_attrs.values()) + ) + + # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) + hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} + + attrs = [] + for hyperedge_idx in range(num_hyperedges): + hyperedge_id = hyperedge_idx_to_id[hyperedge_idx] + + transformed_attrs = HIFProcessor.transform_attrs( + # If it's a real hyperedge, get its attrs; if self-loop, get empty dict + attrs=hyperedge_id_to_attrs.get(hyperedge_id, {}), + attr_keys=hyperedge_attr_keys, + ) + attrs.append(transformed_attrs) + + hyperedge_attr = torch.stack(attrs) + + return hyperedge_attr + + @staticmethod + def _process_x(hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: + # Collect all attribute keys to have tensors of same size + node_attr_keys = HIFProcessor._collect_attr_keys( + [node.get("attrs", {}) for node in hypergraph.nodes] + ) + + if node_attr_keys: + x = torch.stack( + [ + HIFProcessor.transform_attrs(node.get("attrs", {}), attr_keys=node_attr_keys) + for node in hypergraph.nodes + ] + ) + else: + # Fallback to ones if no node features, 1 is better as it can help during + # training (e.g., avoid zero multiplication), especially in first epochs + x = torch.ones((num_nodes, 1), dtype=torch.float) + + return x # shape [num_nodes, num_node_features] + + @staticmethod + def _process_hyperedge_weights(hypergraph: HIFHypergraph) -> Optional[Tensor]: + # Initialize the hyperedge weights tensor + hyperedge_weights = None + + has_hyperedge_weights = hypergraph.hyperedges is not None and all( + "weight" in edge for edge in hypergraph.hyperedges + ) + + if has_hyperedge_weights: + weights = [edge.get("weight", 1.0) for edge in hypergraph.hyperedges] + hyperedge_weights = torch.tensor(weights, dtype=torch.float) + elif ( + has_hyperedge_weights is False + and hypergraph.hyperedges is not None + and any("weight" in edge for edge in hypergraph.hyperedges) + ): + raise ValueError( + "Some hyperedges have weights while others do not. All hyperedges must either have weights or none." + ) + + return hyperedge_weights class HIFLoader: @@ -34,7 +226,7 @@ def load_from_url(url: str, save_on_disk: bool = False) -> HData: Returns: HData: The loaded hypergraph object. 
""" - url = _validate_http_url(url) + url = validate_http_url(url) response = requests.get(url, timeout=20) if response.status_code != 200: @@ -50,12 +242,12 @@ def load_from_url(url: str, save_on_disk: bool = False) -> HData: if zst_filename.endswith(".zst"): if save_on_disk: - HIFLoader.__save_on_disk(os.path.basename(url), response.content) + save(os.path.basename(url), response.content) output = decompress_zst(zst_filename) elif zst_filename.endswith(".json"): if save_on_disk: compressed = compress_to_zst(zst_filename) - HIFLoader.__save_on_disk(os.path.basename(url), compressed) + save(os.path.basename(url), compressed) output = zst_filename else: raise ValueError( @@ -63,7 +255,7 @@ def load_from_url(url: str, save_on_disk: bool = False) -> HData: ) hypergraph = HIFLoader.__extract_hif(output) - hdata = HIFLoader.__process(hypergraph) + hdata = HIFProcessor._process_hypergraph(hypergraph) return hdata @staticmethod @@ -89,11 +281,11 @@ def load_from_path(filepath: str) -> HData: ) hypergraph = HIFLoader.__extract_hif(output) - hdata = HIFLoader.__process(hypergraph) + hdata = HIFProcessor._process_hypergraph(hypergraph) return hdata @staticmethod - def load(dataset_name: str, save_on_disk: bool = False) -> HData: + def load_from_name(dataset_name: str, save_on_disk: bool = False) -> HData: print(f"Loading dataset '{dataset_name}' from disk or remote sources...") current_dir = os.path.dirname(os.path.abspath(__file__)) zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") @@ -147,7 +339,7 @@ def load(dataset_name: str, save_on_disk: bool = False) -> HData: output = decompress_zst(zst_filename) hypergraph = HIFLoader.__extract_hif(output) - hdata = HIFLoader.__process(hypergraph) + hdata = HIFProcessor._process_hypergraph(hypergraph) return hdata @staticmethod @@ -158,193 +350,3 @@ def __extract_hif(json_file: str) -> HIFHypergraph: raise ValueError(f"Dataset from file '{json_file}' is not HIF-compliant.") hypergraph = HIFHypergraph.from_hif(hiftext) return hypergraph - - @staticmethod - def __save_on_disk(dataset_name: str, content: bytes) -> None: - current_dir = os.path.dirname(os.path.abspath(__file__)) - zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") - os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) - - with open(zst_filename, "wb") as f: - f.write(content) - - @staticmethod - def __process_hyperedge_attr( - hypergraph: HIFHypergraph, - hyperedge_id_to_idx: Dict[Any, int], - num_hyperedges: int, - ) -> Optional[Tensor]: - # hyperedge-attr: shape [num_hyperedges, num_hyperedge_attributes] - hyperedge_attr = None - has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 - has_any_hyperedge_attrs = has_hyperedges and any( - "attrs" in edge for edge in hypergraph.hyperedges - ) - - if has_any_hyperedge_attrs: - hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { - e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges - } - - hyperedge_attr_keys = HIFLoader.__collect_attr_keys( - list(hyperedge_id_to_attrs.values()) - ) - - # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) - hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} - - attrs = [] - for hyperedge_idx in range(num_hyperedges): - hyperedge_id = hyperedge_idx_to_id[hyperedge_idx] - - transformed_attrs = HIFLoader.transform_hyperedge_attrs( - # If it's a real hyperedge, get its attrs; if self-loop, get empty dict - 
attrs=hyperedge_id_to_attrs.get(hyperedge_id, {}), - attr_keys=hyperedge_attr_keys, - ) - attrs.append(transformed_attrs) - - hyperedge_attr = torch.stack(attrs) - - return hyperedge_attr - - @staticmethod - def transform_node_attrs( - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - return HIFLoader.transform_attrs(attrs, attr_keys) - - @staticmethod - def __process_x(hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: - # Collect all attribute keys to have tensors of same size - node_attr_keys = HIFLoader.__collect_attr_keys( - [node.get("attrs", {}) for node in hypergraph.nodes] - ) - - if node_attr_keys: - x = torch.stack( - [ - HIFLoader.transform_node_attrs(node.get("attrs", {}), attr_keys=node_attr_keys) - for node in hypergraph.nodes - ] - ) - else: - # Fallback to ones if no node features, 1 is better as it can help during - # training (e.g., avoid zero multiplication), especially in first epochs - x = torch.ones((num_nodes, 1), dtype=torch.float) - - return x # shape [num_nodes, num_node_features] - - @staticmethod - def __process(hypergraph: HIFHypergraph) -> HData: - """ - Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. - - Returns: - The processed hypergraph data. - """ - # if not self.__is_prepared: - # raise ValueError("process can only be called for the original dataset.") - - num_nodes = len(hypergraph.nodes) - x = HIFLoader.__process_x(hypergraph, num_nodes) - - # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order - node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} - # Initialize edge_set only with edges that have incidences, so that - # we avoid inflating edge count due to isolated nodes/missing incidences - hyperedge_id_to_idx: Dict[Any, int] = {} - - node_ids = [] - hyperedge_ids = [] - nodes_with_incidences = set() - for incidence in hypergraph.incidences: - node_id = incidence.get("node", 0) - hyperedge_id = incidence.get("edge", 0) - - if hyperedge_id not in hyperedge_id_to_idx: - # Hyperedges start from 0 and are assigned IDs in the order they are first encountered in incidences - hyperedge_id_to_idx[hyperedge_id] = len(hyperedge_id_to_idx) - - node_ids.append(node_id_to_idx[node_id]) - hyperedge_ids.append(hyperedge_id_to_idx[hyperedge_id]) - nodes_with_incidences.add(node_id_to_idx[node_id]) - - # Handle isolated nodes by assigning them to a new unique hyperedge (self-loop) - for node_idx in range(num_nodes): - if node_idx not in nodes_with_incidences: - new_hyperedge_id = len(hyperedge_id_to_idx) - # Unique dummy key to reserve the index in hyperedge_set - hyperedge_id_to_idx[f"__self_loop_{node_idx}__"] = new_hyperedge_id - node_ids.append(node_idx) - hyperedge_ids.append(new_hyperedge_id) - - num_hyperedges = len(hyperedge_id_to_idx) - hyperedge_attr = HIFLoader.__process_hyperedge_attr( - hypergraph=hypergraph, - hyperedge_id_to_idx=hyperedge_id_to_idx, - num_hyperedges=num_hyperedges, - ) - - hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) - - return HData(x, hyperedge_index, hyperedge_attr) - - @staticmethod - def __collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: - """ - Collect unique numeric attribute keys from a list of attribute dictionaries. - - Args: - attr_keys: List of attribute dictionaries. - - Returns: - List of unique numeric attribute keys. 
- """ - unique_keys = [] - for attrs in attr_keys: - for key, value in attrs.items(): - if key not in unique_keys and isinstance(value, (int, float)): - unique_keys.append(key) - - return unique_keys - - @staticmethod - def transform_hyperedge_attrs( - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - return HIFLoader.transform_attrs(attrs, attr_keys) - - @staticmethod - def transform_attrs( - attrs: Dict[str, Any], - attr_keys: Optional[List[str]] = None, - ) -> Tensor: - """ - Extract and encode numeric attributes to tensor. - Non-numeric attributes are discarded. Missing attributes are filled with ``0.0``. - - Args: - attrs: Dictionary of attributes - attr_keys: Optional list of attribute keys to encode. If provided, ensures consistent ordering and fill missing with ``0.0``. - - Returns: - Tensor of numeric attribute values - """ - numeric_attrs = { - key: value - for key, value in attrs.items() - if isinstance(value, (int, float)) and not isinstance(value, bool) - } - - if attr_keys is not None: - values = [float(numeric_attrs.get(key, 0.0)) for key in attr_keys] - return torch.tensor(values, dtype=torch.float) - - if not numeric_attrs: - return torch.tensor([], dtype=torch.float) - - values = [float(value) for value in numeric_attrs.values()] - return torch.tensor(values, dtype=torch.float) diff --git a/hyperbench/data/supported_datasets.py b/hyperbench/data/supported_datasets.py index 198aa1e..4f28bd7 100644 --- a/hyperbench/data/supported_datasets.py +++ b/hyperbench/data/supported_datasets.py @@ -4,6 +4,11 @@ class PreloadedDataset(Dataset): + """ + Base class for datasets that use default loading. Subclasses should specify the DATASET_NAME class variable. + The dataset will be saved on disk after the first load. 
+ """ + DATASET_NAME = "" def __init__( @@ -13,7 +18,7 @@ def __init__( ) -> None: super().__init__(hdata=hdata, sampling_strategy=sampling_strategy) if hdata is None: - self.hdata = HIFLoader.load(self.DATASET_NAME, save_on_disk=True) + self.hdata = HIFLoader.load_from_name(self.DATASET_NAME, save_on_disk=True) class AlgebraDataset(PreloadedDataset): diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index e149372..e9aff4e 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -5,7 +5,7 @@ to_non_empty_edgeattr, to_0based_ids, ) -from .hif_utils import validate_hif_json, decompress_zst, compress_to_zst +from .hif_utils import validate_hif_json from .nn_utils import ( INPUT_LAYER, ActivationFn, @@ -16,6 +16,8 @@ is_layer, ) from .sparse_utils import sparse_dropout +from .url_utils import validate_http_url +from .file_utils import decompress_zst, compress_to_zst, save_on_disk __all__ = [ "INPUT_LAYER", @@ -34,4 +36,6 @@ "validate_hif_json", "decompress_zst", "compress_to_zst", + "validate_http_url", + "save_on_disk", ] diff --git a/hyperbench/utils/file_utils.py b/hyperbench/utils/file_utils.py new file mode 100644 index 0000000..73573d9 --- /dev/null +++ b/hyperbench/utils/file_utils.py @@ -0,0 +1,31 @@ +import zstandard as zstd +import tempfile +import os + + +def decompress_zst(zst_path: str) -> str: + dctx = zstd.ZstdDecompressor() + with ( + open(zst_path, "rb") as input_f, + tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, + ): + dctx.copy_stream(input_f, tmp_file) + output = tmp_file.name + return output + + +def compress_to_zst(json_path: str) -> bytes: + cctx = zstd.ZstdCompressor() + with open(json_path, "rb") as input_f: + compressed_content = cctx.compress(input_f.read()) + return compressed_content + + +def save_on_disk(dataset_name: str, content: bytes) -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + datasets_dir = os.path.join(current_dir, "..", "data", "datasets") + zst_filename = os.path.join(datasets_dir, f"{dataset_name}.json.zst") + os.makedirs(datasets_dir, exist_ok=True) + + with open(zst_filename, "wb") as f: + f.write(content) diff --git a/hyperbench/utils/hif_utils.py b/hyperbench/utils/hif_utils.py index 3f5eb7d..c9cd944 100644 --- a/hyperbench/utils/hif_utils.py +++ b/hyperbench/utils/hif_utils.py @@ -1,8 +1,6 @@ import fastjsonschema import json import requests -import zstandard as zstd -import tempfile def validate_hif_json(filename: str) -> bool: @@ -28,21 +26,3 @@ def validate_hif_json(filename: str) -> bool: return True except Exception: return False - - -def decompress_zst(zst_path: str) -> str: - dctx = zstd.ZstdDecompressor() - with ( - open(zst_path, "rb") as input_f, - tempfile.NamedTemporaryFile(mode="wb", suffix=".json", delete=False) as tmp_file, - ): - dctx.copy_stream(input_f, tmp_file) - output = tmp_file.name - return output - - -def compress_to_zst(json_path: str) -> bytes: - cctx = zstd.ZstdCompressor() - with open(json_path, "rb") as input_f: - compressed_content = cctx.compress(input_f.read()) - return compressed_content diff --git a/hyperbench/utils/url_utils.py b/hyperbench/utils/url_utils.py new file mode 100644 index 0000000..d8e192a --- /dev/null +++ b/hyperbench/utils/url_utils.py @@ -0,0 +1,8 @@ +from urllib.parse import urlparse + + +def validate_http_url(value: str) -> str: + parsed = urlparse(value) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError(f"Invalid URL: {value}") + return value From 
2651d75d4315427c741d70ad61324075b81f461a Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Thu, 23 Apr 2026 21:10:05 +0200 Subject: [PATCH 08/15] fix: check on hyperedge weights - moved and corrected tests --- hyperbench/data/hif.py | 58 +- hyperbench/tests/data/dataset_test.py | 2109 ++++++++------------- hyperbench/tests/data/hif_test.py | 540 ++++++ hyperbench/tests/types/hypergraph_test.py | 2 + hyperbench/tests/utils/hif_utils_test.py | 41 +- 5 files changed, 1375 insertions(+), 1375 deletions(-) diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 1ab65e0..3afeb7a 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -105,9 +105,22 @@ def _process_hypergraph(hypergraph: HIFHypergraph) -> HData: num_hyperedges=num_hyperedges, ) + hyperedge_weights = HIFProcessor._process_hyperedge_weights( + hypergraph=hypergraph, + hyperedge_id_to_idx=hyperedge_id_to_idx, + num_hyperedges=num_hyperedges, + ) + hyperedge_index = torch.tensor([node_ids, hyperedge_ids], dtype=torch.long) - return HData(x, hyperedge_index, hyperedge_attr) + return HData( + x=x, + hyperedge_index=hyperedge_index, + hyperedge_weights=hyperedge_weights, + hyperedge_attr=hyperedge_attr, + num_nodes=num_nodes, + num_hyperedges=num_hyperedges, + ) @staticmethod def _collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: @@ -190,27 +203,34 @@ def _process_x(hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: return x # shape [num_nodes, num_node_features] @staticmethod - def _process_hyperedge_weights(hypergraph: HIFHypergraph) -> Optional[Tensor]: - # Initialize the hyperedge weights tensor - hyperedge_weights = None - - has_hyperedge_weights = hypergraph.hyperedges is not None and all( - "weight" in edge for edge in hypergraph.hyperedges + def _process_hyperedge_weights( + hypergraph: HIFHypergraph, + hyperedge_id_to_idx: Dict[Any, int], + num_hyperedges: int, + ) -> Optional[Tensor]: + has_hyperedges = hypergraph.hyperedges is not None and len(hypergraph.hyperedges) > 0 + has_any_hyperedge_attrs = has_hyperedges and any( + "attrs" in edge for edge in hypergraph.hyperedges ) - if has_hyperedge_weights: - weights = [edge.get("weight", 1.0) for edge in hypergraph.hyperedges] - hyperedge_weights = torch.tensor(weights, dtype=torch.float) - elif ( - has_hyperedge_weights is False - and hypergraph.hyperedges is not None - and any("weight" in edge for edge in hypergraph.hyperedges) - ): - raise ValueError( - "Some hyperedges have weights while others do not. All hyperedges must either have weights or none." - ) + # Keep old behavior for fixtures where edges have no attrs at all. + if not has_any_hyperedge_attrs: + return None + + # Map real edge id -> attrs (self-loops are absent and will default to 1.0) + hyperedge_id_to_attrs: Dict[Any, Dict[str, Any]] = { + e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges + } + + # Build in exact hyperedge index order, defaulting missing weights to 1.0. 
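+        # e.g. (hypothetical values): hyperedge_id_to_idx == {"e0": 0, "__self_loop_2__": 1}
+        # with hyperedge_id_to_attrs == {"e0": {"weight": 2.5}} yields weights == [2.5, 1.0],
+        # since the synthetic self-loop edge has no attrs entry and falls back to 1.0.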
+ hyperedge_idx_to_id = {idx: edge_id for edge_id, idx in hyperedge_id_to_idx.items()} + weights = [] + for hyperedge_idx in range(num_hyperedges): + edge_id = hyperedge_idx_to_id[hyperedge_idx] + edge_attrs = hyperedge_id_to_attrs.get(edge_id, {}) + weights.append(float(edge_attrs.get("weight", 1.0))) - return hyperedge_weights + return torch.tensor(weights, dtype=torch.float) class HIFLoader: diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 8e1300d..856acba 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -1,1440 +1,839 @@ -# import pytest -# import requests -# import tempfile -# import torch - -# from unittest.mock import patch, mock_open, MagicMock -# from hyperbench.data import AlgebraDataset, Dataset, HIFLoader, SamplingStrategy -# from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher -# from hyperbench.types import HData, HIFHypergraph -# from hyperbench.data.supported_datasets import PreloadedDataset - - -# @pytest.fixture -# def mock_hdata() -> HData: -# x = torch.ones((3, 1), dtype=torch.float) -# hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) -# return HData(x=x, hyperedge_index=hyperedge_index) - - -# @pytest.fixture -# def mock_hdata_with_hyperedge_attr() -> HData: -# x = torch.ones((3, 1), dtype=torch.float) -# hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) -# hyperedge_attr = torch.ones((3, 1), dtype=torch.float) -# return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) - - -# @pytest.fixture -# def mock_hdata_with_hyperedge_weights() -> HData: -# x = torch.ones((3, 1), dtype=torch.float) -# hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) -# hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) -# return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) - - -# @pytest.fixture -# def mock_sample_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": "0"}, {"node": "1"}], -# hyperedges=[{"edge": "0"}], -# incidences=[{"node": "0", "edge": "0"}], -# ) - - -# @pytest.fixture -# def mock_simple_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], -# hyperedges=[{"edge": "0", "attrs": {}}], -# incidences=[{"node": "0", "edge": "0"}], -# ) - - -# @pytest.fixture -# def mock_three_node_weighted_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {}}, -# {"node": "1", "attrs": {}}, -# {"node": "2", "attrs": {}}, -# ], -# hyperedges=[ -# {"edge": "0", "attrs": {"weight": 1.0}}, -# {"edge": "1", "attrs": {"weight": 2.0}}, -# ], -# incidences=[ -# {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "0"}, -# {"node": "2", "edge": "1"}, -# ], -# ) - - -# @pytest.fixture -# def mock_four_node_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {}}, -# {"node": "1", "attrs": {}}, -# {"node": "2", "attrs": {}}, -# {"node": "3", "attrs": {}}, -# ], -# hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], -# incidences=[ -# {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "0"}, -# {"node": "2", "edge": "1"}, -# {"node": "3", "edge": "1"}, -# ], -# ) - - -# @pytest.fixture -# def mock_five_node_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", 
"attrs": {}}, -# {"node": "1", "attrs": {}}, -# {"node": "2", "attrs": {}}, -# {"node": "3", "attrs": {}}, -# {"node": "4", "attrs": {}}, -# ], -# hyperedges=[{"edge": "0", "attrs": {}}], -# incidences=[{"node": "0", "edge": "0"}], -# ) - - -# @pytest.fixture -# def mock_no_edge_attr_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {}}, -# {"node": "1", "attrs": {}}, -# ], -# hyperedges=[{"edge": "0"}], -# incidences=[ -# {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "0"}, -# ], -# ) - - -# @pytest.fixture -# def mock_multiple_edges_attr_hypergraph(): -# return HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {}}, -# {"node": "1", "attrs": {}}, -# {"node": "2", "attrs": {}}, -# {"node": "3", "attrs": {}}, -# ], -# hyperedges=[ -# {"edge": "0", "attrs": {"weight": 1.0}}, -# {"edge": "1", "attrs": {"weight": 2.0}}, -# {"edge": "2", "attrs": {"weight": 3.0}}, -# ], -# incidences=[ -# {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "0"}, -# {"node": "2", "edge": "1"}, -# {"node": "3", "edge": "2"}, -# ], -# ) - - -# def test_Preloaded_dataset_init(): -# mock_hdata = MagicMock(spec=HData) -# dataset = PreloadedDataset(hdata=mock_hdata) - -# assert dataset.hdata == mock_hdata -# assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE - - -# def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): -# mock_hdata = MagicMock(spec=HData) -# with patch.object(HIFLoader, "load", return_value=mock_hdata) as mock_load: -# dataset = AlgebraDataset(hdata=None) - -# assert dataset.hdata == mock_hdata -# mock_load.assert_called_once_with("algebra", save_on_disk=True) - - -# # def test_HIFLoader_num_nodes_and_edges(): -# # dataset_name = "ALGEBRA" -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[{"node": str(i)} for i in range(20)], -# # edges=[{"edge": str(i)} for i in range(30)], -# # incidences=[{"node": "0", "edge": "0"}], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# # hypergraph = HIFLoader.load(dataset_name) - -# # assert hypergraph is not None -# # assert hasattr(hypergraph, "nodes") -# # assert hasattr(hypergraph, "hyperedges") -# # assert hasattr(hypergraph, "incidences") -# # assert hasattr(hypergraph, "metadata") -# # assert hasattr(hypergraph, "network_type") - -# # assert hypergraph.num_nodes == 20 -# # assert hypergraph.num_hyperedges == 30 - - -# # def test_HIFLoader_loads_invalid_dataset(): -# # dataset_name = "INVALID_DATASET" - -# # with pytest.raises(ValueError, match="Dataset 'INVALID_DATASET' not found"): -# # HIFLoader.load(dataset_name) - - -# # def test_HIFLoader_loads_invalid_hif_format(): -# # dataset_name = "ALGEBRA" - -# # invalid_hif_json = '{"network-type": "undirected", "nodes": []}' - -# # with ( -# # patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch("hyperbench.data.dataset.validate_hif_json", return_value=False), -# # patch("builtins.open", mock_open(read_data=invalid_hif_json)), -# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor"), -# # ): -# # mock_response = mock_get.return_value -# # mock_response.status_code = 200 -# # mock_response.content = b"mock_zst_content" - -# # with pytest.raises(ValueError, match="Dataset 'algebra' is not HIF-compliant"): -# # HIFLoader.load(dataset_name) - - -# # def test_HIFLoader_stores_on_disk_when_save_on_disk_true(): -# # dataset_name = "ALGEBRA" - -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # 
nodes=[{"node": "0"}, {"node": "1"}], -# # hyperedges=[{"edge": "0"}], -# # incidences=[{"node": "0", "edge": "0"}], -# # ) - -# # mock_hif_json = { -# # "network-type": "undirected", -# # "nodes": [{"node": "0"}, {"node": "1"}], -# # "edges": [{"edge": "0"}], -# # "incidences": [{"node": "0", "edge": "0"}], -# # } - -# # with ( -# # patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# # patch("hyperbench.data.dataset.os.makedirs"), -# # patch("builtins.open", mock_open()) as mock_file, -# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# # patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), -# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# # ): -# # # Mock successful download -# # mock_response = mock_get.return_value -# # mock_response.status_code = 200 -# # mock_response.content = b"mock_zst_content" - -# # # Mock decompressor -# # mock_stream = mock_decomp.return_value.stream_reader.return_value -# # mock_stream.__enter__ = lambda self: mock_stream -# # mock_stream.__exit__ = lambda self, *args: None - -# # hypergraph = HIFLoader.load(dataset_name, save_on_disk=True) - -# # assert hypergraph is not None -# # assert hypergraph.network_type == "undirected" -# # mock_get.assert_called_once() -# # # Verify file was written to disk (not temp file) -# # assert mock_file.call_count >= 2 # Once for write, once for read - - -# # def test_HIFLoader_uses_temp_file_when_save_on_disk_false(): -# # dataset_name = "ALGEBRA" - -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[{"node": "0"}, {"node": "1"}], -# # hyperedges=[{"edge": "0"}], -# # incidences=[{"node": "0", "edge": "0"}], -# # ) - -# # mock_hif_json = { -# # "network-type": "undirected", -# # "nodes": [{"node": "0"}, {"node": "1"}], -# # "edges": [{"edge": "0"}], -# # "incidences": [{"node": "0", "edge": "0"}], -# # } - -# # with ( -# # patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# # patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, -# # patch("builtins.open", mock_open()), -# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# # patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), -# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# # ): -# # # Mock successful download -# # mock_response = mock_get.return_value -# # mock_response.status_code = 200 -# # mock_response.content = b"mock_zst_content" - -# # # Mock temp file -# # mock_temp_file = mock_temp.return_value.__enter__.return_value -# # mock_temp_file.name = "/tmp/fake_temp.json.zst" - -# # # Mock decompressor -# # mock_stream = mock_decomp.return_value.stream_reader.return_value -# # mock_stream.__enter__ = lambda self: mock_stream -# # mock_stream.__exit__ = lambda self, *args: None - -# # hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) - -# # assert hypergraph is not None -# # assert hypergraph.network_type == "undirected" -# # mock_get.assert_called_once() -# # # Verify temp file was used -# # assert mock_temp.call_count >= 1 - - -# # def test_HIFLoader_download_failure(): -# # dataset_name = "ALGEBRA" - -# # with ( -# # 
patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch("hyperbench.data.dataset.hf_hub_download", side_effect=Exception("HFHub failed")), -# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# # ): -# # # Mock failed download -# # mock_response = mock_get.return_value -# # mock_response.status_code = 404 -# # mock_response.content = b"" - -# # with pytest.warns( -# # UserWarning, -# # match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", -# # ): -# # with pytest.raises( -# # ValueError, -# # match=r"Failed to download dataset 'algebra'", -# # ): -# # HIFLoader.load(dataset_name) - - -# # def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( -# # tmp_path, mock_sample_hypergraph -# # ): -# # dataset_name = "ALGEBRA" - -# # mock_hypergraph = mock_sample_hypergraph - -# # mock_hif_json = { -# # "network_type": "undirected", -# # "nodes": [{"node": "0"}, {"node": "1"}], -# # "edges": [{"edge": "0"}], -# # "incidences": [{"node": "0", "edge": "0"}], -# # } - -# # fallback_file = tmp_path / "algebra.json.zst" -# # fallback_file.write_bytes(b"mock_zst_content") - -# # created_temp_files = [] -# # original_named_tempfile = tempfile.NamedTemporaryFile - -# # def named_tempfile_side_effect(*args, **kwargs): -# # temp_file = original_named_tempfile(*args, **kwargs) -# # created_temp_files.append(temp_file) -# # return temp_file - -# # with ( -# # patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch( -# # "hyperbench.data.dataset.hf_hub_download", -# # return_value=str(fallback_file), -# # ) as mock_hf_hub_download, -# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# # patch( -# # "hyperbench.data.dataset.tempfile.NamedTemporaryFile", -# # side_effect=named_tempfile_side_effect, -# # ), -# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# # patch("hyperbench.data.dataset.json.load", return_value=mock_hif_json), -# # patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# # ): -# # mock_response = mock_get.return_value -# # mock_response.status_code = 404 -# # mock_response.content = b"" - -# # def fake_copy_stream(src, dst): -# # dst.write(b'{"network_type":"undirected","nodes":[],"edges":[],"incidences":[]}') -# # return - -# # mock_decomp.return_value.copy_stream.side_effect = fake_copy_stream - -# # with pytest.warns( -# # UserWarning, -# # match=r"(?s)GitHub raw download failed for dataset 'algebra' with status code 404.*Falling back to Hugging Face Hub download for dataset", -# # ): -# # hypergraph = HIFLoader.load(dataset_name, save_on_disk=False) - -# # assert hypergraph.network_type == "undirected" -# # mock_get.assert_called_once() -# # mock_hf_hub_download.assert_called_once() -# # assert created_temp_files[0].name is not None -# # assert fallback_file.read_bytes() == b"mock_zst_content" - - -# # def test_HIFLoader_download_raises_when_network_error(): -# # dataset_name = "ALGEBRA" - -# # with ( -# # patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch("hyperbench.data.dataset.os.path.exists", return_value=False), -# # ): -# # # Mock network error -# # mock_get.side_effect = requests.RequestException("Network error") - -# # with pytest.raises(requests.RequestException, match="Network error"): -# # HIFLoader.load(dataset_name) - - -# # def 
test_init_with_prepare_false_and_no_hdata_raises(): -# # with pytest.raises(ValueError, match="hdata must be provided when prepare is set to False."): -# # Dataset(hdata=None, prepare=False) - - -# # def test_dataset_is_not_available(): -# # class FakeMockDataset(Dataset): -# # DATASET_NAME = "FAKE" - -# # with pytest.raises(ValueError, match=r"Dataset 'FAKE' not found"): -# # FakeMockDataset() - - -# # @pytest.mark.parametrize( -# # "strategy, expected_len", -# # [ -# # pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_dataset_is_available_with_all_strategies( -# # strategy, expected_len, mock_four_node_hypergraph -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # assert dataset.DATASET_NAME == "ALGEBRA" -# # assert dataset.hypergraph is not None -# # assert len(dataset) == expected_len - - -# # def test_download_already_downloaded_dataset_uses_local_value(mock_four_node_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() - -# # hg1 = dataset.download() -# # hg2 = dataset.download() - -# # assert hg1 is hg2 - - -# # def test_throw_when_dataset_name_is_none(): -# # class FakeMockDataset(Dataset): -# # DATASET_NAME = None - -# # with pytest.raises( -# # ValueError, -# # match=r"Dataset name \(provided: None\) must be provided\.", -# # ): -# # FakeMockDataset() - - -# # def test_dataset_process_no_incidences(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], -# # hyperedges=[{"edge": "0", "attrs": {}}], -# # incidences=[], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# # dataset = AlgebraDataset() - -# # assert dataset.hdata is not None -# # assert dataset.hdata.x.shape[0] == 2 -# # assert dataset.hdata.hyperedge_index.shape[0] == 2 -# # assert dataset.hdata.hyperedge_index.shape[1] == 2 -# # assert dataset.hdata.hyperedge_attr is not None -# # assert dataset.hdata.hyperedge_attr.shape == (2, 0) -# # assert dataset.hdata.hyperedge_attr[0].shape == (0,) -# # assert dataset.hdata.hyperedge_attr[1].shape == (0,) - - -# # def test_dataset_process_with_edge_attributes(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[ -# # {"node": "0", "attrs": {}}, -# # {"node": "1", "attrs": {}}, -# # {"node": "2", "attrs": {}}, -# # ], -# # hyperedges=[ -# # {"edge": "0", "attrs": {"weight": 1.0, "type": 2.0}}, -# # {"edge": "1", "attrs": {"weight": 3.0, "type": 0.1}}, -# # ], -# # incidences=[ -# # {"node": "0", "edge": "0"}, -# # {"node": "1", "edge": "0"}, -# # {"node": "2", "edge": "1"}, -# # ], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# # dataset = AlgebraDataset() - -# # assert dataset.hdata is not None -# # assert dataset.hdata.x.shape[0] == 3 -# # assert dataset.hdata.hyperedge_index.shape[0] == 2 -# # assert dataset.hdata.hyperedge_index.shape[1] == 3 -# # assert dataset.hdata.hyperedge_attr is not None -# # # Two edges with two attributes each: shape [2, 2] -# # assert dataset.hdata.hyperedge_attr.shape == (2, 2) -# # # Attributes maintain dictionary insertion order (no sorting) - -# # assert torch.allclose(dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 2.0])) # weight, type -# # assert 
torch.allclose(dataset.hdata.hyperedge_attr[1], torch.tensor([3.0, 0.1])) # weight, type - - -# # def test_dataset_process_without_edge_attributes(mock_no_edge_attr_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): -# # dataset = AlgebraDataset() - -# # assert dataset.hdata is not None -# # assert dataset.hdata.hyperedge_index.shape[0] == 2 -# # assert dataset.hdata.hyperedge_index.shape[1] == 2 -# # assert dataset.hdata.hyperedge_attr is None - - -# # def test_dataset_process_hyperedge_index_in_correct_format(mock_four_node_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() - -# # assert dataset.hdata.hyperedge_index.shape == (2, 4) -# # assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) -# # assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) - - -# # def test_dataset_process_random_ids(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[ -# # {"node": "abc", "attrs": {}}, -# # {"node": "ss", "attrs": {}}, -# # {"node": "fewao", "attrs": {}}, -# # ], -# # hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], -# # incidences=[ -# # {"node": "abc", "edge": "0"}, -# # {"node": "ss", "edge": "0"}, -# # {"node": "fewao", "edge": "1"}, -# # ], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): -# # dataset = AlgebraDataset() - -# # assert dataset.hdata.hyperedge_index.shape == (2, 3) -# # assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) -# # assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) -# # assert dataset.hdata.hyperedge_attr is not None -# # assert dataset.hdata.hyperedge_attr.shape == (2, 0) # 2 edges, 0 attributes each - - -# # @pytest.mark.parametrize( -# # "strategy", -# # [ -# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_getitem_index_list_empty(mock_simple_hypergraph, strategy): -# # with patch.object(HIFLoader, "load", return_value=mock_simple_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # with pytest.raises(ValueError, match="Index list cannot be empty."): -# # dataset[[]] - - -# # @pytest.mark.parametrize( -# # "strategy, index_list, expected_message", -# # [ -# # pytest.param( -# # SamplingStrategy.NODE, -# # [0, 1, 2, 3, 4], -# # r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", -# # id="node_strategy", -# # ), -# # pytest.param( -# # SamplingStrategy.HYPEREDGE, -# # [0, 1, 2], -# # r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", -# # id="hyperedge_strategy", -# # ), -# # ], -# # ) -# # def test_getitem_raises_when_index_list_larger_than_max( -# # mock_four_node_hypergraph, strategy, index_list, expected_message -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # with pytest.raises(ValueError, match=expected_message): -# # dataset[index_list] - - -# # @pytest.mark.parametrize( -# # "strategy, index, expected_message", -# # [ -# # pytest.param( -# # SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" -# # ), -# # pytest.param( -# # SamplingStrategy.HYPEREDGE, -# # 2, -# # r"Hyperedge ID 2 is out of bounds 
\(0, 1\)\.", -# # id="hyperedge_strategy", -# # ), -# # ], -# # ) -# # def test_getitem_raises_when_index_out_of_bounds( -# # mock_four_node_hypergraph, strategy, index, expected_message -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # with pytest.raises(IndexError, match=expected_message): -# # dataset[index] - - -# # @pytest.mark.parametrize( -# # "strategy, index, expected_shape, expected_num_hyperedges", -# # [ -# # # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge -# # pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), -# # # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge -# # pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_getitem_single_index( -# # mock_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_sample_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # data = dataset[index] - -# # assert data.hyperedge_index.shape == expected_shape -# # assert data.num_hyperedges == expected_num_hyperedges - - -# # @pytest.mark.parametrize( -# # "strategy, index, expected_shape, expected_num_hyperedges", -# # [ -# # # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges -# # pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"), -# # # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges -# # pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_getitem_when_list_index_provided( -# # mock_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # data = dataset[index] - -# # assert data.hyperedge_index.shape == expected_shape -# # assert data.num_hyperedges == expected_num_hyperedges - - -# # @pytest.mark.parametrize( -# # "strategy", -# # [ -# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_getitem_with_edge_attr(mock_three_node_weighted_hypergraph, strategy): -# # with patch.object( -# # HIFLoader, "load", return_value=mock_three_node_weighted_hypergraph -# # ): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # data = dataset[0] - -# # assert data.hyperedge_index.shape == (2, 2) -# # assert data.num_hyperedges == 1 -# # assert data.hyperedge_attr is None - - -# # @pytest.mark.parametrize( -# # "strategy", -# # [ -# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_getitem_without_edge_attr(mock_no_edge_attr_hypergraph, strategy): -# # with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # data = dataset[0] -# # assert data.hyperedge_attr is None - - -# # @pytest.mark.parametrize( -# # "strategy, index", -# # [ -# # # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges -# # pytest.param(SamplingStrategy.NODE, [0, 2], 
id="node_strategy"), -# # # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges -# # pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), -# # ], -# # ) -# # def test_getitem_with_multiple_edges_attr(mock_multiple_edges_attr_hypergraph, strategy, index): -# # with patch.object( -# # HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph -# # ): -# # dataset = AlgebraDataset(sampling_strategy=strategy) - -# # data = dataset[index] -# # assert data.num_hyperedges == 2 - -# # # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None -# # # as the hyperedge attributes are handled by the loader's collate function during batching -# # assert data.hyperedge_attr is None - - -# # def test_getitem_hyperedge_attr_are_padded_with_zero_when_no_uniform_edges(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[ -# # {"node": "0", "attrs": {}}, -# # {"node": "1", "attrs": {}}, -# # {"node": "2", "attrs": {}}, -# # {"node": "3", "attrs": {}}, -# # ], -# # edges=[ -# # {"edge": "0", "attrs": {"weight": 1.0, "abc": 5.0}}, -# # {"edge": "1", "attrs": {"weight": 2.0}}, # Missing 'abc' -# # {"edge": "2", "attrs": {"abc": 3.0}}, # Missing 'weight' -# # ], -# # incidences=[ -# # {"node": "0", "edge": "0"}, -# # {"node": "1", "edge": "0"}, -# # {"node": "2", "edge": "1"}, -# # {"node": "3", "edge": "2"}, -# # ], -# # ) - -# with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): -# dataset = AlgebraDataset() - -# assert dataset.hdata.hyperedge_attr is not None -# assert dataset.hdata.hyperedge_attr.shape == ( -# 3, -# 2, -# ) # 3 edges, 2 features each (weight, abc in insertion order) -# assert torch.allclose( -# dataset.hdata.hyperedge_attr[0], torch.tensor([1.0, 5.0]) -# ) # weight=1.0, abc=5.0 -# assert torch.allclose( -# dataset.hdata.hyperedge_attr[1], torch.tensor([2.0, 0.0]) -# ) # weight=2.0, abc=0.0 -# assert torch.allclose( -# dataset.hdata.hyperedge_attr[2], torch.tensor([0.0, 3.0]) -# ) # weight=0.0, abc=3.0 - - -# def test_process_not_all_hyperedge_weights_(): -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {}}, -# {"node": "1", "attrs": {}}, -# {"node": "2", "attrs": {}}, -# ], -# hyperedges=[ -# {"edge": "0", "weight": 1.5}, -# {"edge": "1"}, -# {"edge": "2", "weight": 2.5}, -# ], -# incidences=[ -# {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "1"}, -# {"node": "2", "edge": "2"}, -# ], -# ) - -# with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): -# with pytest.raises( -# ValueError, -# match="Some hyperedges have weights while others do not. 
All hyperedges must either have weights or none.", -# ): -# dataset = AlgebraDataset() - - -# def test_process_extracts_top_level_hyperedge_weights(): -# mock_hypergraph = HIFHypergraph( -# network_type="undirected", -# nodes=[ -# {"node": "0", "attrs": {}}, -# {"node": "1", "attrs": {}}, -# {"node": "2", "attrs": {}}, -# ], -# hyperedges=[ -# {"edge": "0", "weight": 1.5}, -# {"edge": "1", "weight": 3.0}, -# {"edge": "2", "weight": 2.5}, -# ], -# incidences=[ -# {"node": "0", "edge": "0"}, -# {"node": "1", "edge": "1"}, -# {"node": "2", "edge": "2"}, -# ], -# ) - -# with patch.object(HIFConverter, "load_from_hif", return_value=mock_hypergraph): -# dataset = AlgebraDataset() - -# hyperedge_weights = dataset.hdata.hyperedge_weights -# assert hyperedge_weights is not None -# assert torch.allclose(hyperedge_weights[0], torch.tensor([1.5])) -# assert torch.allclose(hyperedge_weights[1], torch.tensor([3.0])) -# assert torch.allclose(hyperedge_weights[2], torch.tensor([2.5])) - - -# # def test_transform_attrs_empty_attrs(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[{"node": "0", "attrs": {}}], -# # hyperedges=[{"edge": "0", "attrs": {}}], -# # incidences=[{"node": "0", "edge": "0"}], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - -# # class TestDataset(Dataset): -# # DATASET_NAME = "TEST" - -# # dataset = TestDataset() - -# # result = dataset.transform_attrs({}) -# # assert len(result) == 0 - -# # attrs = {"name": "node1", "active": True} -# # result = dataset.transform_attrs(attrs) -# # assert len(result) == 0 - - -# # def test_process_adds_padding_zero_when_inconsistent_node_attributes(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[ -# # {"node": "0", "attrs": {"weight": 1.0}}, # Missing 'score' -# # {"node": "1", "attrs": {"weight": 2.0, "score": 0.8}}, -# # {"node": "2", "attrs": {"score": 0.5}}, # Missing 'weight' -# # ], -# # hyperedges=[{"edge": "0", "attrs": {}}], -# # incidences=[ -# # {"node": "0", "edge": "0"}, -# # {"node": "1", "edge": "0"}, -# # {"node": "2", "edge": "0"}, -# # ], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - -# # class TestDataset(Dataset): -# # DATASET_NAME = "TEST" - -# # dataset = TestDataset() - -# # assert dataset.hdata.x.shape == ( -# # 3, -# # 2, -# # ) # 3 nodes, 2 features each (weight, score in insertion order) -# # assert torch.allclose(dataset.hdata.x[0], torch.tensor([1.0, 0.0])) # weight=1.0, score=0.0 -# # assert torch.allclose(dataset.hdata.x[1], torch.tensor([2.0, 0.8])) # weight=2.0, score=0.8 -# # assert torch.allclose(dataset.hdata.x[2], torch.tensor([0.0, 0.5])) # weight=0.0, score=0.5 - - -# # def test_process_with_no_node_attributes_fallback_to_one(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[ -# # {"node": "0", "attrs": {"name": "node0"}}, -# # {"node": "1", "attrs": {}}, -# # ], -# # hyperedges=[{"edge": "0", "attrs": {}}], -# # incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], -# # ) - -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): - -# # class TestDataset(Dataset): -# # DATASET_NAME = "TEST" - -# # dataset = TestDataset() - -# # assert dataset.hdata.x.shape == (2, 1) -# # assert torch.allclose(dataset.hdata.x, torch.tensor([[1.0], [1.0]])) - - -# # def test_process_with_single_node_attribute(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[ -# # {"node": "0", "attrs": 
{"weight": 1.5}}, -# # {"node": "1", "attrs": {"weight": 2.5}}, -# # {"node": "2", "attrs": {"weight": 3.5}}, -# # ], -# # hyperedges=[{"edge": "0", "attrs": {}}], -# # incidences=[ -# # {"node": "0", "edge": "0"}, -# # {"node": "1", "edge": "0"}, -# # {"node": "2", "edge": "0"}, -# # ], -# # ) +import pytest +import torch -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +from unittest.mock import patch, MagicMock +from hyperbench.data import AlgebraDataset, Dataset, HIFLoader, SamplingStrategy +from hyperbench.nn import NodeEnricher, HyperedgeEnricher +from hyperbench.types import HData, HIFHypergraph +from hyperbench.data.supported_datasets import PreloadedDataset -# # class TestDataset(Dataset): -# # DATASET_NAME = "TEST" -# # dataset = TestDataset() +@pytest.fixture +def mock_hdata() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # # Single attribute should remain 2D: [num_nodes, 1] -# # assert dataset.hdata.x.shape == (3, 1) -# # assert torch.allclose(dataset.hdata.x, torch.tensor([[1.5], [2.5], [3.5]])) +@pytest.fixture +def mock_hdata_with_hyperedge_attr() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + hyperedge_attr = torch.ones((3, 1), dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_attr=hyperedge_attr) -# # def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[{"node": "0", "attrs": {}}], -# # hyperedges=[{"edge": "0", "attrs": {}}], -# # incidences=[{"node": "0", "edge": "0"}], -# # ) -# # with patch.object(HIFLoader, "load", return_value=mock_hypergraph): +@pytest.fixture +def mock_hdata_with_hyperedge_weights() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + hyperedge_weights = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) -# # class TestDataset(Dataset): -# # DATASET_NAME = "TEST" -# # dataset = TestDataset() +@pytest.fixture +def mock_hdata_sample_hypergraph() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # # Test with attr_keys - should pad missing attributes with 0.0 -# # attrs = {"weight": 1.5} -# # result = dataset.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) -# # assert torch.allclose( -# # result, torch.tensor([0.0, 1.5, 0.0]) -# # ) # score=0.0, weight=1.5, age=0.0 -# # # Test with all attributes present -# # attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} -# # result = dataset.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) -# # assert torch.allclose( -# # result, torch.tensor([25.0, 0.8, 1.5]) -# # ) # age=25.0, score=0.8, weight=1.5 +@pytest.fixture +def mock_hdata_simple_hypergraph() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # # Test without attr_keys - maintains insertion order -# # attrs = {"weight": 1.5, "score": 0.8} -# # result = dataset.transform_attrs(attrs) -# # assert torch.allclose(result, torch.tensor([1.5, 0.8])) 
# weight, score (insertion order) +@pytest.fixture +def mock_hdata_three_node_weighted_hypergraph() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + hyperedge_weights = torch.tensor([[1.0], [2.0]], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) -# # @pytest.mark.parametrize( -# # "strategy, expected_len", -# # [ -# # # mock_hdata: 3 nodes, 2 hyperedges -# # pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_from_hdata(strategy, expected_len, mock_hdata): -# # dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) -# # assert dataset.hdata is mock_hdata -# # assert len(dataset) == expected_len +@pytest.fixture +def mock_hdata_four_node_hypergraph() -> HData: + x = torch.ones((4, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # def test_from_hdata_download_raises(mock_hdata): -# # dataset = Dataset.from_hdata(mock_hdata) +@pytest.fixture +def mock_hdata_no_edge_attr_hypergraph() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # with pytest.raises(ValueError, match="download can only be called for the original dataset."): -# # dataset.download() +@pytest.fixture +def mock_hdata_multiple_edges_attr_hypergraph() -> HData: + x = torch.ones((4, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) + hyperedge_weights = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) -# # def test_from_hdata_process_raises(mock_hdata): -# # dataset = Dataset.from_hdata(mock_hdata) -# # with pytest.raises(ValueError, match="process can only be called for the original dataset."): -# # dataset.process() +@pytest.fixture +def mock_hdata_no_incidences() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # def test_enrich_node_features_replace(mock_hdata): -# # dataset = Dataset.from_hdata(mock_hdata) +@pytest.fixture +def mock_hdata_with_two_edge_attributes() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + hyperedge_weights = torch.tensor([[1.0, 2.0], [3.0, 0.1]], dtype=torch.float) + return HData(x=x, hyperedge_index=hyperedge_index, hyperedge_weights=hyperedge_weights) -# # enricher = MagicMock(spec=NodeEnricher) -# # enriched_x = torch.randn(3, 4) -# # enricher.enrich.return_value = enriched_x -# # dataset.enrich_node_features(enricher) +@pytest.fixture +def mock_hdata_random_ids() -> HData: + x = torch.ones((3, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1, 2], [0, 0, 1]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) -# # enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) -# # assert torch.equal(dataset.hdata.x, enriched_x) +def test_Preloaded_dataset_init(): + mock_hdata = MagicMock(spec=HData) + dataset = PreloadedDataset(hdata=mock_hdata) -# # def test_enrich_node_features_concatenate(mock_hdata): 
-# # dataset = Dataset.from_hdata(mock_hdata) -# # original_x = dataset.hdata.x.clone() + assert dataset.hdata == mock_hdata + assert dataset.sampling_strategy is SamplingStrategy.HYPEREDGE -# # enricher = MagicMock(spec=NodeEnricher) -# # enriched_x = torch.randn(3, 4) -# # enricher.enrich.return_value = enriched_x -# dataset.enrich_node_features(enricher, enrichment_mode="concatenate") +def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): + mock_hdata = MagicMock(spec=HData) + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata) as mock_load: + dataset = AlgebraDataset(hdata=None) -# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) -# expected_x = torch.cat([original_x, enriched_x], dim=1) -# assert torch.equal(dataset.hdata.x, expected_x) -# assert dataset.hdata.x.shape == (3, 5) # 1 original + 4 enriched + assert dataset.hdata == mock_hdata + mock_load.assert_called_once_with("algebra", save_on_disk=True) -# def test_enrich_hyperedge_attr_replace(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata) - -# enricher = MagicMock(spec=HyperedgeEnricher) -# enriched_x = torch.randn(3, 4) -# enricher.enrich.return_value = enriched_x - -# dataset.enrich_hyperedge_attr(enricher) - -# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) -# hyperedge_attr = dataset.hdata.hyperedge_attr -# assert hyperedge_attr is not None -# assert torch.equal(hyperedge_attr, enriched_x) - - -# def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): -# dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_attr) -# original_hyperedge_attr = dataset.hdata.hyperedge_attr -# assert original_hyperedge_attr is not None -# original_hyperedge_attr = original_hyperedge_attr.clone() - -# enricher = MagicMock(spec=HyperedgeEnricher) -# enriched_x = torch.randn(3, 4) -# enricher.enrich.return_value = enriched_x - -# dataset.enrich_hyperedge_attr(enricher, enrichment_mode="concatenate") +@pytest.mark.parametrize( + "strategy, expected_len", + [ + pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), + ], +) +def test_dataset_is_available_with_all_strategies( + strategy, expected_len, mock_hdata_four_node_hypergraph +): + + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + assert dataset.DATASET_NAME == "algebra" + assert len(dataset) == expected_len -# enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_attr.hyperedge_index) -# expected_x = torch.cat([original_hyperedge_attr, enriched_x], dim=1) -# hyperedge_attr = dataset.hdata.hyperedge_attr -# assert hyperedge_attr is not None -# assert torch.equal(hyperedge_attr, expected_x) -# assert hyperedge_attr.shape == (3, 5) # 1 original + 4 enriched - - -# def test_enrich_hyperedge_weights_replace(mock_hdata): -# dataset = Dataset.from_hdata(mock_hdata) -# enricher = MagicMock(spec=HyperedgeEnricher) -# enriched_weights = torch.randn(3) -# enricher.enrich.return_value = enriched_weights - -# dataset.enrich_hyperedge_weights(enricher) - -# enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) -# hyperedge_weights = dataset.hdata.hyperedge_weights -# assert hyperedge_weights is not None -# assert torch.equal(hyperedge_weights, enriched_weights) +def test_dataset_process_no_incidences(mock_hdata_no_incidences): + with patch.object(HIFLoader, "load_from_name", 
return_value=mock_hdata_no_incidences): + dataset = AlgebraDataset() + + assert dataset.hdata is not None + assert dataset.hdata.x.shape[0] == 2 + assert dataset.hdata.hyperedge_index.shape[0] == 2 + assert dataset.hdata.hyperedge_index.shape[1] == 2 + assert dataset.hdata.hyperedge_attr is None + + +def test_dataset_process_with_edge_attributes(mock_hdata_with_two_edge_attributes): + with patch.object( + HIFLoader, "load_from_name", return_value=mock_hdata_with_two_edge_attributes + ): + dataset = AlgebraDataset() + + assert dataset.hdata is not None + assert dataset.hdata.x.shape[0] == 3 + assert dataset.hdata.hyperedge_index.shape[0] == 2 + assert dataset.hdata.hyperedge_index.shape[1] == 3 + assert dataset.hdata.hyperedge_attr is None + assert dataset.hdata.hyperedge_weights is not None + assert dataset.hdata.hyperedge_weights.shape == (2, 2) + assert torch.allclose(dataset.hdata.hyperedge_weights[0], torch.tensor([1.0, 2.0])) + assert torch.allclose(dataset.hdata.hyperedge_weights[1], torch.tensor([3.0, 0.1])) + + +def test_dataset_process_without_edge_attributes(mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_edge_attr_hypergraph): + dataset = AlgebraDataset() + + assert dataset.hdata is not None + assert dataset.hdata.hyperedge_index.shape[0] == 2 + assert dataset.hdata.hyperedge_index.shape[1] == 2 + assert dataset.hdata.hyperedge_attr is None + + +def test_dataset_process_hyperedge_index_in_correct_format(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset() + + assert dataset.hdata.hyperedge_index.shape == (2, 4) + assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2, 3])) + assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1, 1])) + + +def test_dataset_process_random_ids(mock_hdata_random_ids): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_random_ids): + dataset = AlgebraDataset() + + assert dataset.hdata.hyperedge_index.shape == (2, 3) + assert torch.allclose(dataset.hdata.hyperedge_index[0], torch.tensor([0, 1, 2])) + assert torch.allclose(dataset.hdata.hyperedge_index[1], torch.tensor([0, 0, 1])) + assert dataset.hdata.hyperedge_attr is None + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_index_list_empty(mock_hdata_simple_hypergraph, strategy): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_simple_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + with pytest.raises(ValueError, match="Index list cannot be empty."): + dataset[[]] + + +@pytest.mark.parametrize( + "strategy, index_list, expected_message", + [ + pytest.param( + SamplingStrategy.NODE, + [0, 1, 2, 3, 4], + r"Index list length \(5\) cannot exceed the number of sampleable items \(4\)\.", + id="node_strategy", + ), + pytest.param( + SamplingStrategy.HYPEREDGE, + [0, 1, 2], + r"Index list length \(3\) cannot exceed the number of sampleable items \(2\)\.", + id="hyperedge_strategy", + ), + ], +) +def test_getitem_raises_when_index_list_larger_than_max( + mock_hdata_four_node_hypergraph, strategy, index_list, expected_message +): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) 
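+        # The four-node fixture exposes 4 sampleable nodes / 2 hyperedges, so each
+        # parametrized index list above is deliberately one element too long.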
+ + with pytest.raises(ValueError, match=expected_message): + dataset[index_list] + + +@pytest.mark.parametrize( + "strategy, index, expected_message", + [ + pytest.param( + SamplingStrategy.NODE, 4, r"Node ID 4 is out of bounds \(0, 3\)\.", id="node_strategy" + ), + pytest.param( + SamplingStrategy.HYPEREDGE, + 2, + r"Hyperedge ID 2 is out of bounds \(0, 1\)\.", + id="hyperedge_strategy", + ), + ], +) +def test_getitem_raises_when_index_out_of_bounds( + mock_hdata_four_node_hypergraph, strategy, index, expected_message +): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + with pytest.raises(IndexError, match=expected_message): + dataset[index] + + +@pytest.mark.parametrize( + "strategy, index, expected_shape, expected_num_hyperedges", + [ + # When node 1 is selected, we get hyperedge 0 with nodes 0 and 1 -> 2 incidences, 1 hyperedge + pytest.param(SamplingStrategy.NODE, 1, (2, 1), 1, id="node_strategy"), + # When hyperedge 0 is selected, we get nodes 0 and 1 -> 2 incidences, 1 hyperedge + pytest.param(SamplingStrategy.HYPEREDGE, 0, (2, 1), 1, id="hyperedge_strategy"), + ], +) +def test_getitem_single_index( + mock_hdata_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges +): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_sample_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[index] + + assert data.hyperedge_index.shape == expected_shape + assert data.num_hyperedges == expected_num_hyperedges + + +@pytest.mark.parametrize( + "strategy, index, expected_shape, expected_num_hyperedges", + [ + # When nodes (0, 2, 3) -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges + pytest.param(SamplingStrategy.NODE, [0, 2, 3], (2, 4), 2, id="node_strategy"), + # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (nodes 2, 3) -> 4 incidences, 2 hyperedges + pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], (2, 4), 2, id="hyperedge_strategy"), + ], +) +def test_getitem_when_list_index_provided( + mock_hdata_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges +): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[index] + + assert data.hyperedge_index.shape == expected_shape + assert data.num_hyperedges == expected_num_hyperedges + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_with_edge_attr(mock_hdata_three_node_weighted_hypergraph, strategy): + with patch.object( + HIFLoader, "load_from_name", return_value=mock_hdata_three_node_weighted_hypergraph + ): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[0] + + assert data.hyperedge_index.shape == (2, 2) + assert data.num_hyperedges == 1 + assert data.hyperedge_attr is None + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_getitem_without_edge_attr(mock_hdata_no_edge_attr_hypergraph, strategy): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_edge_attr_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[0] + assert 
data.hyperedge_attr is None + + +@pytest.mark.parametrize( + "strategy, index", + [ + # When nodes 0,2 -> hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges + pytest.param(SamplingStrategy.NODE, [0, 2], id="node_strategy"), + # When hyperedge 0 (nodes 0, 1) + hyperedge 1 (node 2) -> 2 hyperedges + pytest.param(SamplingStrategy.HYPEREDGE, [0, 1], id="hyperedge_strategy"), + ], +) +def test_getitem_with_multiple_edges_attr( + mock_hdata_multiple_edges_attr_hypergraph, strategy, index +): + with patch.object( + HIFLoader, "load_from_name", return_value=mock_hdata_multiple_edges_attr_hypergraph + ): + dataset = AlgebraDataset(sampling_strategy=strategy) + + data = dataset[index] + assert data.num_hyperedges == 2 + + # Even though the original hypergraph has edge attributes, __getitem__ should return hyperedge_attr as None + # as the hyperedge attributes are handled by the loader's collate function during batching + assert data.hyperedge_attr is None + + +@pytest.mark.parametrize( + "strategy, expected_len", + [ + # mock_hdata: 3 nodes, 2 hyperedges + pytest.param(SamplingStrategy.NODE, 3, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, 2, id="hyperedge_strategy"), + ], +) +def test_from_hdata(strategy, expected_len, mock_hdata): + dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=strategy) + + assert dataset.hdata is mock_hdata + assert len(dataset) == expected_len + + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + ], +) +def test_from_url(strategy, mock_hdata): + url = "https://example.com/sample.json.zst" + + with patch.object(HIFLoader, "load_from_url", return_value=mock_hdata) as mock_load_from_url: + dataset = Dataset.from_url(url=url, sampling_strategy=strategy, save_on_disk=True) + + mock_load_from_url.assert_called_once_with(url=url, save_on_disk=True) + assert dataset.hdata is mock_hdata + assert dataset.sampling_strategy == strategy + +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + ], +) +def test_from_path(strategy, mock_hdata): + filepath = "/tmp/sample.json.zst" -# def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights): -# dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_weights) -# original_weights = dataset.hdata.hyperedge_weights -# assert original_weights is not None -# original_weights = original_weights.clone() + with patch.object(HIFLoader, "load_from_path", return_value=mock_hdata) as mock_load_from_path: + dataset = Dataset.from_path(filepath=filepath, sampling_strategy=strategy) -# enricher = MagicMock(spec=HyperedgeEnricher) -# enriched_weights = torch.randn(3) -# enricher.enrich.return_value = enriched_weights + mock_load_from_path.assert_called_once_with(filepath=filepath) + assert dataset.hdata is mock_hdata + assert dataset.sampling_strategy == strategy -# dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") - -# enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_weights.hyperedge_index) -# expected_weights = torch.cat([original_weights, enriched_weights], dim=0) -# hyperedge_weights = dataset.hdata.hyperedge_weights -# assert hyperedge_weights is not None -# assert torch.equal(hyperedge_weights, expected_weights) -# assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched +def 
test_enrich_node_features_replace(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata) -# # @pytest.mark.parametrize( -# # "hyperedge_index, k, expected_hyperedge_index", -# # [ -# # pytest.param( -# # torch.tensor([[0, 1, 2], [0, 0, 0]]), -# # 4, -# # torch.zeros((2, 0), dtype=torch.long), -# # id="single_hyperedge_below_k_removed", -# # ), -# # pytest.param( -# # torch.tensor([[0, 1, 2], [0, 0, 0]]), -# # 3, -# # torch.tensor([[0, 1, 2], [0, 0, 0]]), -# # id="single_hyperedge_at_exact_k_kept", -# # ), -# # pytest.param( -# # torch.tensor([[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]), -# # 3, -# # torch.tensor([[0, 1, 2], [0, 0, 0]]), -# # id="two_hyperedges_first_kept_second_removed", -# # ), -# # pytest.param( -# # torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), -# # 3, -# # torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), -# # id="two_hyperedges_both_kept", -# # ), -# # pytest.param( -# # torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 2, 2]]), -# # 3, -# # torch.zeros((2, 0), dtype=torch.long), -# # id="three_hyperedges_all_removed", -# # ), -# # ], -# # ) -# # def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_hyperedge_index): -# # num_nodes = hyperedge_index[0].max().item() + 1 if hyperedge_index.shape[1] > 0 else 0 -# # x = torch.ones((num_nodes, 1), dtype=torch.float) -# # hdata = HData(x=x, hyperedge_index=hyperedge_index) -# # dataset = Dataset.from_hdata(hdata) - -# # dataset.remove_hyperedges_with_fewer_than_k_nodes(k) - -# # expected_num_nodes = expected_hyperedge_index[0].unique().shape[0] -# # expected_num_hyperedges = expected_hyperedge_index[1].unique().shape[0] - -# # assert torch.equal(dataset.hdata.hyperedge_index, expected_hyperedge_index) -# # assert dataset.hdata.x.shape[0] == expected_num_nodes -# # assert dataset.hdata.y.shape[0] == expected_num_hyperedges - + enricher = MagicMock(spec=NodeEnricher) + enriched_x = torch.randn(3, 4) + enricher.enrich.return_value = enriched_x + + dataset.enrich_node_features(enricher) + + enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) + assert torch.equal(dataset.hdata.x, enriched_x) + + +def test_enrich_node_features_concatenate(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata) + original_x = dataset.hdata.x.clone() + + enricher = MagicMock(spec=NodeEnricher) + enriched_x = torch.randn(3, 4) + enricher.enrich.return_value = enriched_x + + dataset.enrich_node_features(enricher, enrichment_mode="concatenate") + + enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) + expected_x = torch.cat([original_x, enriched_x], dim=1) + assert torch.equal(dataset.hdata.x, expected_x) + assert dataset.hdata.x.shape == (3, 5) # 1 original + 4 enriched + + +def test_enrich_hyperedge_attr_replace(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata) + + enricher = MagicMock(spec=HyperedgeEnricher) + enriched_x = torch.randn(3, 4) + enricher.enrich.return_value = enriched_x + + dataset.enrich_hyperedge_attr(enricher) + + enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) + hyperedge_attr = dataset.hdata.hyperedge_attr + assert hyperedge_attr is not None + assert torch.equal(hyperedge_attr, enriched_x) + + +def test_enrich_hyperedge_attr_concatenate(mock_hdata_with_hyperedge_attr): + dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_attr) + original_hyperedge_attr = dataset.hdata.hyperedge_attr + assert original_hyperedge_attr is not None + original_hyperedge_attr = original_hyperedge_attr.clone() + + enricher = MagicMock(spec=HyperedgeEnricher) + 
enriched_x = torch.randn(3, 4) + enricher.enrich.return_value = enriched_x + + dataset.enrich_hyperedge_attr(enricher, enrichment_mode="concatenate") + + enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_attr.hyperedge_index) + expected_x = torch.cat([original_hyperedge_attr, enriched_x], dim=1) + hyperedge_attr = dataset.hdata.hyperedge_attr + assert hyperedge_attr is not None + assert torch.equal(hyperedge_attr, expected_x) + assert hyperedge_attr.shape == (3, 5) # 1 original + 4 enriched + + +def test_enrich_hyperedge_weights_replace(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata) + + enricher = MagicMock(spec=HyperedgeEnricher) + enriched_weights = torch.randn(3) + enricher.enrich.return_value = enriched_weights + + dataset.enrich_hyperedge_weights(enricher) + + enricher.enrich.assert_called_once_with(mock_hdata.hyperedge_index) + hyperedge_weights = dataset.hdata.hyperedge_weights + assert hyperedge_weights is not None + assert torch.equal(hyperedge_weights, enriched_weights) + + +def test_enrich_hyperedge_weights_concatenate(mock_hdata_with_hyperedge_weights): + dataset = Dataset.from_hdata(mock_hdata_with_hyperedge_weights) + original_weights = dataset.hdata.hyperedge_weights + assert original_weights is not None + original_weights = original_weights.clone() + + enricher = MagicMock(spec=HyperedgeEnricher) + enriched_weights = torch.randn(3) + enricher.enrich.return_value = enriched_weights + + dataset.enrich_hyperedge_weights(enricher, enrichment_mode="concatenate") + + enricher.enrich.assert_called_once_with(mock_hdata_with_hyperedge_weights.hyperedge_index) + expected_weights = torch.cat([original_weights, enriched_weights], dim=0) + hyperedge_weights = dataset.hdata.hyperedge_weights + assert hyperedge_weights is not None + assert torch.equal(hyperedge_weights, expected_weights) + assert hyperedge_weights.shape == (6,) # 3 original + 3 enriched -# # def test_split_with_equal_ratios(mock_four_node_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() -# # splits = dataset.split([0.5, 0.5]) +@pytest.mark.parametrize( + "hyperedge_index, k, expected_hyperedge_index", + [ + pytest.param( + torch.tensor([[0, 1, 2], [0, 0, 0]]), + 4, + torch.zeros((2, 0), dtype=torch.long), + id="single_hyperedge_below_k_removed", + ), + pytest.param( + torch.tensor([[0, 1, 2], [0, 0, 0]]), + 3, + torch.tensor([[0, 1, 2], [0, 0, 0]]), + id="single_hyperedge_at_exact_k_kept", + ), + pytest.param( + torch.tensor([[0, 1, 2, 3, 4], [0, 0, 0, 1, 1]]), + 3, + torch.tensor([[0, 1, 2], [0, 0, 0]]), + id="two_hyperedges_first_kept_second_removed", + ), + pytest.param( + torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), + 3, + torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 0, 1, 1, 1]]), + id="two_hyperedges_both_kept", + ), + pytest.param( + torch.tensor([[0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 2, 2]]), + 3, + torch.zeros((2, 0), dtype=torch.long), + id="three_hyperedges_all_removed", + ), + ], +) +def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_hyperedge_index): + num_nodes = hyperedge_index[0].max().item() + 1 if hyperedge_index.shape[1] > 0 else 0 + x = torch.ones((num_nodes, 1), dtype=torch.float) + hdata = HData(x=x, hyperedge_index=hyperedge_index) + dataset = Dataset.from_hdata(hdata) + + dataset.remove_hyperedges_with_fewer_than_k_nodes(k) + + expected_num_nodes = expected_hyperedge_index[0].unique().shape[0] + expected_num_hyperedges = expected_hyperedge_index[1].unique().shape[0] 
+ + assert torch.equal(dataset.hdata.hyperedge_index, expected_hyperedge_index) + assert dataset.hdata.x.shape[0] == expected_num_nodes + assert dataset.hdata.y.shape[0] == expected_num_hyperedges + -# # assert len(splits) == 2 -# # assert ( -# # splits[0].hdata.num_hyperedges + splits[1].hdata.num_hyperedges -# # == dataset.hdata.num_hyperedges -# # ) -# # for split in splits: -# # assert split.hdata.x is not None -# # assert split.hdata.num_nodes > 0 -# # assert split.hdata.num_hyperedges > 0 +def test_split_with_equal_ratios(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset() + splits = dataset.split([0.5, 0.5]) -# # def test_split_three_way(mock_multiple_edges_attr_hypergraph): -# # with patch.object( -# # HIFLoader, "load", return_value=mock_multiple_edges_attr_hypergraph -# # ): -# # dataset = AlgebraDataset() - -# # splits = dataset.split([0.5, 0.25, 0.25]) -# # total_edges = sum(split.hdata.num_hyperedges for split in splits) - -# # assert len(splits) == 3 -# # assert total_edges == dataset.hdata.num_hyperedges - -# # for split in splits: -# # assert split.hdata.x is not None -# # assert split.hdata.num_nodes > 0 -# # assert split.hdata.num_hyperedges > 0 - - -# # def test_split_raises_when_ratios_do_not_sum_to_one(mock_four_node_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() - -# # with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): -# # dataset.split([0.8, 0.1, 0.05]) - - -# # def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( -# # mock_four_node_hypergraph, -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() - -# # splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) -# # splits_b = dataset.split([0.5, 0.5], shuffle=True, seed=42) - -# # assert torch.equal(splits_a[0].hdata.hyperedge_index, splits_b[0].hdata.hyperedge_index) -# # assert torch.equal(splits_a[1].hdata.hyperedge_index, splits_b[1].hdata.hyperedge_index) - - -# # def test_split_with_shuffle_when_no_seed_provided( -# # mock_four_node_hypergraph, -# # ): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() + assert len(splits) == 2 + assert ( + splits[0].hdata.num_hyperedges + splits[1].hdata.num_hyperedges + == dataset.hdata.num_hyperedges + ) + for split in splits: + assert split.hdata.x is not None + assert split.hdata.num_nodes > 0 + assert split.hdata.num_hyperedges > 0 -# # splits = dataset.split([0.5, 0.5], shuffle=True) -# # total_edges = sum(split.hdata.num_hyperedges for split in splits) -# # assert len(splits) == 2 -# # assert total_edges == dataset.hdata.num_hyperedges +def test_split_three_way(mock_hdata_multiple_edges_attr_hypergraph): + with patch.object( + HIFLoader, "load_from_name", return_value=mock_hdata_multiple_edges_attr_hypergraph + ): + dataset = AlgebraDataset() -# # for split in splits: -# # assert split.hdata.x is not None -# # assert split.hdata.num_nodes > 0 -# # assert split.hdata.num_hyperedges > 0 + splits = dataset.split([0.5, 0.25, 0.25]) + total_edges = sum(split.hdata.num_hyperedges for split in splits) + assert len(splits) == 3 + assert total_edges == dataset.hdata.num_hyperedges -# # def test_split_preserves_edge_attr(mock_multiple_edges_attr_hypergraph): -# # with patch.object( -# # HIFLoader, "load", 
return_value=mock_multiple_edges_attr_hypergraph -# # ): -# # dataset = AlgebraDataset() + for split in splits: + assert split.hdata.x is not None + assert split.hdata.num_nodes > 0 + assert split.hdata.num_hyperedges > 0 -# # splits = dataset.split([0.5, 0.5]) -# # for split in splits: -# # assert split.hdata.hyperedge_attr is not None -# # assert split.hdata.hyperedge_attr.shape[0] == split.hdata.num_hyperedges +def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset() + with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): + dataset.split([0.8, 0.1, 0.05]) -# # def test_split_without_edge_attr(mock_no_edge_attr_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_no_edge_attr_hypergraph): -# # dataset = AlgebraDataset() -# # splits = dataset.split([0.5, 0.5]) +def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( + mock_hdata_four_node_hypergraph, +): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset() -# # for split in splits: -# # assert split.hdata.hyperedge_attr is None + splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) + splits_b = dataset.split([0.5, 0.5], shuffle=True, seed=42) + assert torch.equal(splits_a[0].hdata.hyperedge_index, splits_b[0].hdata.hyperedge_index) + assert torch.equal(splits_a[1].hdata.hyperedge_index, splits_b[1].hdata.hyperedge_index) -# # def test_to_device(mock_hdata): -# # device = torch.device("cpu") -# # dataset = Dataset.from_hdata(mock_hdata) +def test_split_with_shuffle_when_no_seed_provided( + mock_hdata_four_node_hypergraph, +): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset() -# # result = dataset.to(device) + splits = dataset.split([0.5, 0.5], shuffle=True) + total_edges = sum(split.hdata.num_hyperedges for split in splits) -# # assert result is dataset -# # assert dataset.hdata.device == device + assert len(splits) == 2 + assert total_edges == dataset.hdata.num_hyperedges + for split in splits: + assert split.hdata.x is not None + assert split.hdata.num_nodes > 0 + assert split.hdata.num_hyperedges > 0 -# # def test_load_skips_download_when_file_exists(): -# # dataset_name = "ALGEBRA" -# # sample_hif = { -# # "network-type": "undirected", -# # "nodes": [{"node": "0"}, {"node": "1"}], -# # "edges": [{"edge": "0"}], -# # "incidences": [{"node": "0", "edge": "0"}], -# # } +def test_split_preserves_edge_attr(mock_hdata_multiple_edges_attr_hypergraph): + with patch.object( + HIFLoader, "load_from_name", return_value=mock_hdata_multiple_edges_attr_hypergraph + ): + dataset = AlgebraDataset() -# # mock_hypergraph = HIFHypergraph( -# # network_type="undirected", -# # nodes=[{"node": "0"}, {"node": "1"}], -# # hyperedges=[{"edge": "0"}], -# # incidences=[{"node": "0", "edge": "0"}], -# # ) + splits = dataset.split([0.5, 0.5]) -# # with ( -# # patch("hyperbench.data.dataset.requests.get") as mock_get, -# # patch("hyperbench.data.dataset.os.path.exists", return_value=True), -# # patch("builtins.open", mock_open()) as mock_file, -# # patch("hyperbench.data.dataset.zstd.ZstdDecompressor") as mock_decomp, -# # patch("hyperbench.data.dataset.tempfile.NamedTemporaryFile") as mock_temp, -# # patch("hyperbench.data.dataset.json.load", return_value=sample_hif), -# # 
patch("hyperbench.data.dataset.validate_hif_json", return_value=True), -# # patch.object(HIFHypergraph, "from_hif", return_value=mock_hypergraph), -# # ): -# # mock_dctx = mock_decomp.return_value -# # mock_dctx.copy_stream = lambda input_f, tmp_file: None + for split in splits: + assert split.hdata.hyperedge_weights is not None + assert split.hdata.hyperedge_weights.shape[0] == split.hdata.num_hyperedges -# # mock_temp_instance = mock_temp.return_value.__enter__.return_value -# # mock_temp_instance.name = "/tmp/decompressed.json" -# # result = HIFLoader.load(dataset_name, save_on_disk=True) -# # mock_get.assert_not_called() -# # assert result == mock_hypergraph +def test_split_without_edge_attr(mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_edge_attr_hypergraph): + dataset = AlgebraDataset() + splits = dataset.split([0.5, 0.5]) -# # def test_default_sampling_strategy_is_hyperedge(mock_four_node_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset() + for split in splits: + assert split.hdata.hyperedge_attr is None -# # # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) -# # assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE -# # assert len(dataset) == 2 +def test_to_device(mock_hdata): + device = torch.device("cpu") -# # def test_explicit_node_sampling_strategy(mock_four_node_hypergraph): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) + dataset = Dataset.from_hdata(mock_hdata) -# # # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) -# # assert dataset.sampling_strategy == SamplingStrategy.NODE -# # assert len(dataset) == 4 + result = dataset.to(device) + assert result is dataset + assert dataset.hdata.device == device -# # @pytest.mark.parametrize( -# # "strategy", -# # [ -# # pytest.param(SamplingStrategy.NODE, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_split_preserves_sampling_strategy(mock_four_node_hypergraph, strategy): -# # with patch.object(HIFLoader, "load", return_value=mock_four_node_hypergraph): -# # dataset = AlgebraDataset(sampling_strategy=strategy) -# # splits = dataset.split([0.5, 0.5]) +def test_default_sampling_strategy_is_hyperedge(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset() -# # for split in splits: -# # assert split.sampling_strategy == strategy + # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) + assert dataset.sampling_strategy == SamplingStrategy.HYPEREDGE + assert len(dataset) == 2 -# # def test_from_hdata_with_explicit_strategy(mock_hdata): -# # dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) +def test_explicit_node_sampling_strategy(mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) -# # assert dataset.sampling_strategy == SamplingStrategy.NODE -# # assert len(dataset) == 3 # mock_hdata has 3 nodes + # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) + assert dataset.sampling_strategy == SamplingStrategy.NODE + assert 
len(dataset) == 4 -# # def test_update_from_hdata_returns_new_dataset(mock_hdata): -# # dataset = Dataset(hdata=mock_hdata, prepare=False) -# # new_x = torch.ones((2, 1), dtype=torch.float) -# # new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) -# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - -# # result = dataset.update_from_hdata(new_hdata) - -# # assert result is not dataset -# # assert result.hdata is new_hdata -# # assert dataset.hdata is mock_hdata - - -# # def test_update_from_hdata_stores_provided_hdata(mock_hdata): -# # dataset = Dataset(hdata=mock_hdata, prepare=False) -# # new_x = torch.ones((2, 1), dtype=torch.float) -# # new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) -# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - -# # result = dataset.update_from_hdata(new_hdata) - -# # assert result.hdata is new_hdata - - -# # @pytest.mark.parametrize( -# # "strategy, expected_len", -# # [ -# # pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), -# # pytest.param(SamplingStrategy.HYPEREDGE, 3, id="hyperedge_strategy"), -# # ], -# # ) -# # def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): -# # dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy, prepare=False) -# # new_x = torch.ones((4, 1), dtype=torch.float) -# # new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) -# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - -# # result = dataset.update_from_hdata(new_hdata) - -# # assert result.sampling_strategy == strategy -# # assert len(result) == expected_len - - -# # def test_update_from_hdata_preserves_subclass_type(mock_hdata): -# # dataset = AlgebraDataset(hdata=mock_hdata, prepare=False) -# # new_x = torch.ones((2, 1), dtype=torch.float) -# # new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) -# # new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) - -# # result = dataset.update_from_hdata(new_hdata) - -# # assert type(result) is AlgebraDataset +@pytest.mark.parametrize( + "strategy", + [ + pytest.param(SamplingStrategy.NODE, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, id="hyperedge_strategy"), + ], +) +def test_split_preserves_sampling_strategy(mock_hdata_four_node_hypergraph, strategy): + with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + dataset = AlgebraDataset(sampling_strategy=strategy) - -# # @pytest.fixture -# # def mock_hdata_stats(): -# # x = torch.tensor( -# # [ -# # [0.0, 1.0, 2.0, 3.0], -# # [1.0, 2.0, 3.0, 4.0], -# # [2.0, 3.0, 4.0, 5.0], -# # [3.0, 4.0, 5.0, 6.0], -# # ], -# # dtype=torch.float, -# # ) -# # hyperedge_index = torch.tensor( -# # [ -# # [0, 1, 2, 2, 3], -# # [0, 0, 0, 1, 1], -# # ] -# # ) -# # return HData(x=x, hyperedge_index=hyperedge_index) - - -# # def test_dataset_stats_computation(mock_hdata_stats): -# # expected_stats = { -# # "shape_x": torch.Size([4, 4]), -# # "shape_hyperedge_attr": None, -# # "shape_hyperedge_weights": None, -# "num_nodes": 4, -# # "num_hyperedges": 2, -# # "avg_degree_node_raw": 1.25, -# # "avg_degree_node": 1, -# # "avg_degree_hyperedge_raw": 2.5, -# # "avg_degree_hyperedge": 2, -# # "node_degree_max": 2, -# # "hyperedge_degree_max": 3, -# # "node_degree_median": 1, -# # "hyperedge_degree_median": 2, -# # "distribution_node_degree": [1, 1, 2, 1], -# # "distribution_hyperedge_size": [3, 2], -# # "distribution_node_degree_hist": {1: 3, 2: 1}, 
-# # "distribution_hyperedge_size_hist": {2: 1, 3: 1}, -# # } - -# # dataset = Dataset.from_hdata(mock_hdata_stats) - -# # stats = dataset.stats() -# # assert stats == expected_stats + splits = dataset.split([0.5, 0.5]) + + for split in splits: + assert split.sampling_strategy == strategy + + +def test_from_hdata_with_explicit_strategy(mock_hdata): + dataset = Dataset.from_hdata(mock_hdata, sampling_strategy=SamplingStrategy.NODE) + + assert dataset.sampling_strategy == SamplingStrategy.NODE + assert len(dataset) == 3 # mock_hdata has 3 nodes + + +def test_update_from_hdata_returns_new_dataset(mock_hdata): + dataset = Dataset(hdata=mock_hdata) + new_x = torch.ones((2, 1), dtype=torch.float) + new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + + result = dataset.update_from_hdata(new_hdata) + + assert result is not dataset + assert result.hdata is new_hdata + assert dataset.hdata is mock_hdata + + +def test_update_from_hdata_stores_provided_hdata(mock_hdata): + dataset = Dataset(hdata=mock_hdata) + new_x = torch.ones((2, 1), dtype=torch.float) + new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + + result = dataset.update_from_hdata(new_hdata) + + assert result.hdata is new_hdata + + +@pytest.mark.parametrize( + "strategy, expected_len", + [ + pytest.param(SamplingStrategy.NODE, 4, id="node_strategy"), + pytest.param(SamplingStrategy.HYPEREDGE, 3, id="hyperedge_strategy"), + ], +) +def test_update_from_hdata_inherits_sampling_strategy(mock_hdata, strategy, expected_len): + dataset = Dataset(hdata=mock_hdata, sampling_strategy=strategy) + new_x = torch.ones((4, 1), dtype=torch.float) + new_hyperedge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]], dtype=torch.long) + new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + + result = dataset.update_from_hdata(new_hdata) + + assert result.sampling_strategy == strategy + assert len(result) == expected_len + + +def test_update_from_hdata_preserves_subclass_type(mock_hdata): + dataset = AlgebraDataset(hdata=mock_hdata) + new_x = torch.ones((2, 1), dtype=torch.float) + new_hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + new_hdata = HData(x=new_x, hyperedge_index=new_hyperedge_index) + + result = dataset.update_from_hdata(new_hdata) + + assert type(result) is AlgebraDataset + + +@pytest.fixture +def mock_hdata_stats(): + x = torch.tensor( + [ + [0.0, 1.0, 2.0, 3.0], + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + ], + dtype=torch.float, + ) + hyperedge_index = torch.tensor( + [ + [0, 1, 2, 2, 3], + [0, 0, 0, 1, 1], + ] + ) + return HData(x=x, hyperedge_index=hyperedge_index) + + +def test_dataset_stats_computation(mock_hdata_stats): + expected_stats = { + "shape_x": torch.Size([4, 4]), + "shape_hyperedge_attr": None, + "shape_hyperedge_weights": None, + "num_nodes": 4, + "num_hyperedges": 2, + "avg_degree_node_raw": 1.25, + "avg_degree_node": 1, + "avg_degree_hyperedge_raw": 2.5, + "avg_degree_hyperedge": 2, + "node_degree_max": 2, + "hyperedge_degree_max": 3, + "node_degree_median": 1, + "hyperedge_degree_median": 2, + "distribution_node_degree": [1, 1, 2, 1], + "distribution_hyperedge_size": [3, 2], + "distribution_node_degree_hist": {1: 3, 2: 1}, + "distribution_hyperedge_size_hist": {2: 1, 3: 1}, + } + + dataset = Dataset.from_hdata(mock_hdata_stats) + + stats = dataset.stats() + assert stats == expected_stats + + 
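+# The two delegation tests below only check that Dataset forwards to
+# HIFProcessor.transform_attrs. For reference, the forwarded behaviour
+# (exercised directly in hif_test.py) orders numeric attrs by attr_keys
+# and zero-pads missing keys:
+#
+#     HIFProcessor.transform_attrs({"weight": 2.5}, ["score", "weight"])
+#     # -> tensor([0.0, 2.5])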
+def test_transform_node_attrs_delegates_to_hifprocessor(): + attrs = {"weight": 1.5, "score": 0.8} + attr_keys = ["score", "weight"] + expected = torch.tensor([0.8, 1.5], dtype=torch.float) + + with patch( + "hyperbench.data.dataset.HIFProcessor.transform_attrs", + return_value=expected, + ) as mock_transform: + result = Dataset.transform_node_attrs(attrs, attr_keys) + + mock_transform.assert_called_once_with(attrs, attr_keys) + assert torch.equal(result, expected) + + +def test_transform_hyperedge_attrs_delegates_to_hifprocessor(): + attrs = {"weight": 2.5} + attr_keys = ["weight", "score"] + expected = torch.tensor([2.5, 0.0], dtype=torch.float) + + with patch( + "hyperbench.data.dataset.HIFProcessor.transform_attrs", + return_value=expected, + ) as mock_transform: + result = Dataset.transform_hyperedge_attrs(attrs, attr_keys) + + mock_transform.assert_called_once_with(attrs, attr_keys) + assert torch.equal(result, expected) diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index e69de29..bd4bb68 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -0,0 +1,540 @@ +import pytest +import requests +import torch +import json +import os + +from unittest.mock import patch, MagicMock + +from hyperbench.data.hif import HIFLoader, HIFProcessor +import hyperbench.data.hif as hif_module +from hyperbench.types import HData, HIFHypergraph + + +@pytest.fixture +def mock_sample_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0"}, {"node": "1"}], + hyperedges=[{"edge": "0"}], + incidences=[{"node": "0", "edge": "0"}], + ) + + +@pytest.fixture +def mock_simple_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], + hyperedges=[{"edge": "0", "attrs": {}}], + incidences=[{"node": "0", "edge": "0"}], + ) + + +@pytest.fixture +def mock_three_node_weighted_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + ], + hyperedges=[ + {"edge": "0", "attrs": {"weight": 1.0}}, + {"edge": "1", "attrs": {"weight": 2.0}}, + ], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + ], + ) + + +@pytest.fixture +def mock_four_node_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + {"node": "3", "attrs": {}}, + ], + hyperedges=[{"edge": "0", "attrs": {}}, {"edge": "1", "attrs": {}}], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + {"node": "3", "edge": "1"}, + ], + ) + + +@pytest.fixture +def mock_five_node_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + {"node": "3", "attrs": {}}, + {"node": "4", "attrs": {}}, + ], + hyperedges=[{"edge": "0", "attrs": {}}], + incidences=[{"node": "0", "edge": "0"}], + ) + + +@pytest.fixture +def mock_no_edge_attr_hypergraph(): + return HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + ], + hyperedges=[{"edge": "0"}], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + ], + ) + + +@pytest.fixture +def mock_multiple_edges_attr_hypergraph(): + return 
HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {}}, + {"node": "1", "attrs": {}}, + {"node": "2", "attrs": {}}, + {"node": "3", "attrs": {}}, + ], + hyperedges=[ + {"edge": "0", "attrs": {"weight": 1.0}}, + {"edge": "1", "attrs": {"weight": 2.0}}, + {"edge": "2", "attrs": {"weight": 3.0}}, + ], + incidences=[ + {"node": "0", "edge": "0"}, + {"node": "1", "edge": "0"}, + {"node": "2", "edge": "1"}, + {"node": "3", "edge": "2"}, + ], + ) + + +@pytest.fixture +def mock_hypergraph() -> HIFHypergraph: + return HIFHypergraph( + network_type="undirected", + nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}], + hyperedges=[{"edge": "0", "attrs": {"weight": 1.0}}], + incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], + ) + + +@pytest.fixture +def mock_hdata() -> HData: + x = torch.ones((2, 1), dtype=torch.float) + hyperedge_index = torch.tensor([[0, 1], [0, 0]], dtype=torch.long) + return HData(x=x, hyperedge_index=hyperedge_index) + + +def _write_hif_json(tmp_path, hypergraph: HIFHypergraph, filename: str = "sample.json") -> str: + path = tmp_path / filename + payload = { + "network-type": hypergraph.network_type, + "metadata": hypergraph.metadata, + "nodes": hypergraph.nodes, + "edges": hypergraph.hyperedges, + "incidences": hypergraph.incidences, + } + with open(path, "w") as f: + json.dump(payload, f) + return str(path) + + +def _mock_named_temporary_file(path): + file_handle = open(path, "wb") + mocked_cm = MagicMock() + mocked_cm.__enter__.return_value = file_handle + mocked_cm.__exit__.side_effect = lambda exc_type, exc, tb: file_handle.close() + return mocked_cm + + +def test_transform_attrs_empty_attrs(): + result = HIFProcessor.transform_attrs({}) + assert len(result) == 0 + + attrs = {"name": "node1", "active": True} + result = HIFProcessor.transform_attrs(attrs) + assert len(result) == 0 + + +def test_transform_attrs_adds_padding_zero_when_attr_keys_padding(): + attrs = {"weight": 1.5} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["score", "weight", "age"]) + assert torch.allclose(result, torch.tensor([0.0, 1.5, 0.0])) + + attrs = {"weight": 1.5, "score": 0.8, "age": 25.0} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["age", "score", "weight"]) + assert torch.allclose(result, torch.tensor([25.0, 0.8, 1.5])) + + attrs = {"weight": 1.5, "score": 0.8} + result = HIFProcessor.transform_attrs(attrs) + assert torch.allclose(result, torch.tensor([1.5, 0.8])) + + +def test_transform_hyperedge_attrs_adds_padding_zero_when_attr_keys_padding(): + attrs = {"weight": 2.5} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["score", "weight"]) + assert torch.allclose(result, torch.tensor([0.0, 2.5])) + + +def test_transform_node_attrs_adds_padding_zero_when_attr_keys_padding(): + attrs = {"weight": 2.5} + result = HIFProcessor.transform_attrs(attrs, attr_keys=["score", "weight"]) + assert torch.allclose(result, torch.tensor([0.0, 2.5])) + + +def test_load_from_url_rejects_invalid_url(): + with pytest.raises(ValueError, match="Invalid URL"): + HIFLoader.load_from_url("not-a-url") + + +def test_load_from_url_raises_when_status_is_not_200(): + with patch("hyperbench.data.hif.requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.status_code = 404 + + with pytest.raises(ValueError, match="Failed to download dataset from URL"): + HIFLoader.load_from_url("https://example.com/file.json.zst") + + +def test_load_from_path_raises_for_missing_file(): + with pytest.raises(ValueError, 
match="does not exist"): + HIFLoader.load_from_path("/tmp/does-not-exist.json.zst") + + +def test_load_from_path_raises_for_unsupported_extension(tmp_path): + invalid = tmp_path / "sample.txt" + invalid.write_text("{}") + + with pytest.raises(ValueError, match="Unsupported file format"): + HIFLoader.load_from_path(str(invalid)) + + +@pytest.mark.parametrize( + "fixture_name, expected_nodes, expected_hyperedges, expected_incidences, has_hyperedge_weights", + [ + pytest.param("mock_sample_hypergraph", 2, 2, 2, False, id="sample_with_isolated_node"), + pytest.param("mock_simple_hypergraph", 2, 2, 2, True, id="simple_with_empty_attrs"), + pytest.param("mock_three_node_weighted_hypergraph", 3, 2, 3, True, id="weighted"), + pytest.param("mock_four_node_hypergraph", 4, 2, 4, True, id="four_nodes_two_edges"), + pytest.param("mock_five_node_hypergraph", 5, 5, 5, True, id="five_nodes_with_self_loops"), + pytest.param("mock_no_edge_attr_hypergraph", 2, 1, 2, False, id="no_edge_attr"), + pytest.param("mock_multiple_edges_attr_hypergraph", 4, 3, 4, True, id="multiple_weighted"), + ], +) +def test_load_from_path_processes_hypergraph_cases( + tmp_path, + request, + fixture_name, + expected_nodes, + expected_hyperedges, + expected_incidences, + has_hyperedge_weights, +): + hypergraph = request.getfixturevalue(fixture_name) + json_path = _write_hif_json(tmp_path, hypergraph, filename=f"{fixture_name}.json") + + with patch("hyperbench.data.hif.validate_hif_json", return_value=True): + hdata = HIFLoader.load_from_path(json_path) + + assert hdata.num_nodes == expected_nodes + assert hdata.num_hyperedges == expected_hyperedges + assert hdata.hyperedge_index.shape[1] == expected_incidences + assert (hdata.hyperedge_weights is not None) is has_hyperedge_weights + + +def test_load_from_path_zst_uses_decompress(tmp_path, mock_hypergraph): + zst_path = tmp_path / "sample.json.zst" + zst_path.write_bytes(b"dummy") + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.decompress_zst", return_value=json_path) as mock_decompress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + hdata = HIFLoader.load_from_path(str(zst_path)) + + mock_decompress.assert_called_once_with(str(zst_path)) + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_path_raises_for_non_hif_compliant_json(tmp_path, mock_hypergraph): + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with patch("hyperbench.data.hif.validate_hif_json", return_value=False): + with pytest.raises(ValueError, match="is not HIF-compliant"): + HIFLoader.load_from_path(json_path) + + +def test_load_from_url_processes_zst_and_saves_to_disk(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json.zst" + url = f"https://example.com/{unique_name}" + json_path = _write_hif_json(tmp_path, mock_hypergraph) + current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) + saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock-zst-content" + + hdata = HIFLoader.load_from_url(url, save_on_disk=True) + + try: + assert os.path.exists(saved_path) + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + 
finally: + if os.path.exists(saved_path): + os.remove(saved_path) + + +def test_load_from_url_processes_json_and_saves_compressed_copy(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json" + url = f"https://example.com/{unique_name}" + payload = { + "network-type": mock_hypergraph.network_type, + "metadata": mock_hypergraph.metadata, + "nodes": mock_hypergraph.nodes, + "edges": mock_hypergraph.hyperedges, + "incidences": mock_hypergraph.incidences, + } + current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) + saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.tempfile.NamedTemporaryFile", + return_value=_mock_named_temporary_file(tmp_path / "downloaded.json"), + ), + patch("hyperbench.data.hif.compress_to_zst", return_value=b"compressed") as mock_compress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = json.dumps(payload).encode("utf-8") + + hdata = HIFLoader.load_from_url(url, save_on_disk=True) + + mock_compress.assert_called_once_with(str(tmp_path / "downloaded.json")) + try: + assert os.path.exists(saved_path) + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + finally: + if os.path.exists(saved_path): + os.remove(saved_path) + + +def test_load_from_url_processes_zst_without_saving_to_disk(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json.zst" + url = f"https://example.com/{unique_name}" + json_path = _write_hif_json(tmp_path, mock_hypergraph) + current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) + saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path) as mock_decompress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"mock-zst-content" + + hdata = HIFLoader.load_from_url(url, save_on_disk=False) + + assert not os.path.exists(saved_path) + mock_decompress.assert_called_once() + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_url_processes_json_without_saving_to_disk(tmp_path, mock_hypergraph): + unique_name = f"algebra_{tmp_path.name}.json" + url = f"https://example.com/{unique_name}" + payload = { + "network-type": mock_hypergraph.network_type, + "metadata": mock_hypergraph.metadata, + "nodes": mock_hypergraph.nodes, + "edges": mock_hypergraph.hyperedges, + "incidences": mock_hypergraph.incidences, + } + current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) + saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") + + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.tempfile.NamedTemporaryFile", + return_value=_mock_named_temporary_file(tmp_path / "downloaded_no_save.json"), + ), + patch("hyperbench.data.hif.compress_to_zst") as mock_compress, + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = json.dumps(payload).encode("utf-8") + + hdata = HIFLoader.load_from_url(url, save_on_disk=False) + + 
mock_compress.assert_not_called() + assert not os.path.exists(saved_path) + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 + + +def test_load_from_path_processes_node_numeric_attrs_into_features(tmp_path): + hypergraph = HIFHypergraph( + network_type="undirected", + nodes=[ + {"node": "0", "attrs": {"weight": 1.0, "score": 0.5}}, + {"node": "1", "attrs": {"weight": 2.0, "score": 1.5}}, + ], + hyperedges=[{"edge": "0", "attrs": {}}], + incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}], + ) + json_path = _write_hif_json(tmp_path, hypergraph, filename="nodes_with_attrs.json") + + with patch("hyperbench.data.hif.validate_hif_json", return_value=True): + hdata = HIFLoader.load_from_path(json_path) + + assert hdata.x.shape == (2, 2) + assert torch.allclose(hdata.x[0], torch.tensor([1.0, 0.5])) + assert torch.allclose(hdata.x[1], torch.tensor([2.0, 1.5])) + + +def test_load_from_url_raises_for_unsupported_temp_extension(tmp_path): + with ( + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.tempfile.NamedTemporaryFile", + return_value=_mock_named_temporary_file(tmp_path / "downloaded.bin"), + ), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"bytes" + + with pytest.raises(ValueError, match="Unsupported file format"): + HIFLoader.load_from_url("https://example.com/algebra.unknown") + + +def test_load_skips_download_when_file_exists(tmp_path, mock_hypergraph): + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=True), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + result = HIFLoader.load_from_name("algebra", save_on_disk=True) + + mock_get.assert_not_called() + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_HIFLoader_download_failure_when_hf_fallback_fails(): + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.hf_hub_download", side_effect=Exception("HFHub failed")), + ): + mock_response = mock_get.return_value + mock_response.status_code = 404 + mock_response.content = b"" + + with pytest.warns(UserWarning, match="GitHub raw download failed"): + with pytest.raises(ValueError, match="Failed to download dataset 'algebra'"): + HIFLoader.load_from_name("algebra") + + +def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( + tmp_path, mock_hypergraph +): + fallback_file = tmp_path / "algebra.json.zst" + fallback_file.write_bytes(b"mock_zst_content") + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch( + "hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file) + ) as mock_hf_hub_download, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + ): + mock_response = mock_get.return_value + mock_response.status_code = 404 + mock_response.content = b"" + + with pytest.warns(UserWarning, match="GitHub raw download failed"): + result = HIFLoader.load_from_name("algebra", save_on_disk=False) + + mock_get.assert_called_once() + 
mock_hf_hub_download.assert_called_once() + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_load_saves_downloaded_dataset_on_disk(tmp_path, mock_hypergraph): + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch("hyperbench.data.hif.requests.get") as mock_get, + patch("hyperbench.data.hif.decompress_zst", return_value=json_path), + patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.os.path.abspath", return_value=str(tmp_path / "hif.py")), + ): + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.content = b"downloaded-content" + + result = HIFLoader.load_from_name("algebra", save_on_disk=True) + + saved = tmp_path / "datasets" / "algebra.json.zst" + assert saved.exists() + assert saved.read_bytes() == b"downloaded-content" + assert result.num_nodes == 2 + assert result.num_hyperedges == 1 + + +def test_HIFLoader_download_raises_when_network_error(): + with ( + patch("hyperbench.data.hif.os.path.exists", return_value=False), + patch( + "hyperbench.data.hif.requests.get", + side_effect=requests.RequestException("Network error"), + ), + ): + with pytest.raises(requests.RequestException, match="Network error"): + HIFLoader.load_from_name("algebra") diff --git a/hyperbench/tests/types/hypergraph_test.py b/hyperbench/tests/types/hypergraph_test.py index a2c6f59..9e4137c 100644 --- a/hyperbench/tests/types/hypergraph_test.py +++ b/hyperbench/tests/types/hypergraph_test.py @@ -20,6 +20,8 @@ def test_build_HIFHypergraph_instance(): hypergraph = HIFHypergraph.from_hif(hiftext) assert isinstance(hypergraph, HIFHypergraph) + assert hypergraph.num_nodes == 423 + assert hypergraph.num_hyperedges == 1268 def test_empty_hifhypergraph_returns_empty_hifhypergraph(): diff --git a/hyperbench/tests/utils/hif_utils_test.py b/hyperbench/tests/utils/hif_utils_test.py index 8695ea5..66c7d2c 100644 --- a/hyperbench/tests/utils/hif_utils_test.py +++ b/hyperbench/tests/utils/hif_utils_test.py @@ -1,7 +1,9 @@ import requests +import json +import os from unittest.mock import patch, mock_open, MagicMock -from hyperbench.utils import validate_hif_json +from hyperbench.utils import validate_hif_json, compress_to_zst, decompress_zst from hyperbench.tests import MOCK_BASE_PATH @@ -54,3 +56,40 @@ def test_validate_hif_json_with_url_request_exception_fallback(): # local file was opened calls = [str(call) for call in mock_file.call_args_list] assert any("../schema/hif_schema.json" in call for call in calls) + + +def test_compress_to_zst_returns_non_empty_bytes(tmp_path): + json_path = tmp_path / "sample.json" + json_path.write_text('{"nodes": [], "edges": [], "incidences": []}') + + compressed_content = compress_to_zst(str(json_path)) + + assert isinstance(compressed_content, bytes) + assert len(compressed_content) > 0 + + +def test_decompress_zst_round_trip_preserves_json_content(tmp_path): + expected_data = { + "network-type": "undirected", + "nodes": [{"node": "0", "attrs": {"weight": 1.0}}], + "edges": [{"edge": "0", "attrs": {}}], + "incidences": [{"node": "0", "edge": "0"}], + } + + json_path = tmp_path / "sample.json" + with open(json_path, "w") as f: + json.dump(expected_data, f) + + compressed_content = compress_to_zst(str(json_path)) + zst_path = tmp_path / "sample.json.zst" + zst_path.write_bytes(compressed_content) + + decompressed_path = decompress_zst(str(zst_path)) + + assert decompressed_path.endswith(".json") + 
assert os.path.exists(decompressed_path) + + with open(decompressed_path, "r") as f: + decompressed_data = json.load(f) + + assert decompressed_data == expected_data From 770909ad50fab5e94f323b2bbb1d12182d6f1077 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Fri, 24 Apr 2026 10:47:52 +0200 Subject: [PATCH 09/15] feat: renamed functions - changed scope of functions --- hyperbench/data/dataset.py | 4 +- hyperbench/data/hif.py | 53 ++++++++++++--------------- hyperbench/data/supported_datasets.py | 5 ++- hyperbench/tests/data/dataset_test.py | 52 +++++++++++++------------- hyperbench/tests/data/hif_test.py | 10 ++--- hyperbench/utils/__init__.py | 4 +- hyperbench/utils/file_utils.py | 2 +- 7 files changed, 62 insertions(+), 68 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 55ca3b3..146d0ad 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -23,7 +23,7 @@ class Dataset(TorchDataset): """ A dataset class for loading and processing hypergraph data. - Attributes: + Args: DATASET_NAME: Class variable indicating the name of the dataset to load. hypergraph: The loaded hypergraph in HIF format. Can be ``None`` if initialized from an HData object. hdata: The processed hypergraph data in HData format. @@ -319,14 +319,12 @@ def __get_hyperedge_ids_permutation( ranged_hyperedge_ids_permutation = torch.arange(num_hyperedges, device=device) return ranged_hyperedge_ids_permutation - @staticmethod def transform_node_attrs( attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None, ) -> Tensor: return HIFProcessor.transform_attrs(attrs, attr_keys) - @staticmethod def transform_hyperedge_attrs( attrs: Dict[str, Any], attr_keys: Optional[List[str]] = None, diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 3afeb7a..77686b1 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -5,6 +5,7 @@ import requests import tempfile import warnings + from huggingface_hub import hf_hub_download from typing import Optional, Dict, Any, List from torch import Tensor @@ -15,8 +16,8 @@ decompress_zst, compress_to_zst, validate_http_url, + write_to_disk, ) -from hyperbench.utils import save_on_disk as save class HIFProcessor: @@ -54,8 +55,8 @@ def transform_attrs( values = [float(value) for value in numeric_attrs.values()] return torch.tensor(values, dtype=torch.float) - @staticmethod - def _process_hypergraph(hypergraph: HIFHypergraph) -> HData: + @classmethod + def process_hypergraph(cls, hypergraph: HIFHypergraph) -> HData: """ Process the loaded hypergraph into :class:`HData` format, mapping HIF structure to tensors. 
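+
+        Illustrative sketch (mirrors the ``mock_hypergraph`` test fixture;
+        tensor layouts beyond those shown are implementation details):
+
+            hg = HIFHypergraph(
+                network_type="undirected",
+                nodes=[{"node": "0", "attrs": {}}, {"node": "1", "attrs": {}}],
+                hyperedges=[{"edge": "0", "attrs": {"weight": 1.0}}],
+                incidences=[{"node": "0", "edge": "0"}, {"node": "1", "edge": "0"}],
+            )
+            hdata = HIFProcessor.process_hypergraph(hg)
+            # hdata.num_nodes == 2, hdata.num_hyperedges == 1
+            # hdata.hyperedge_index -> tensor([[0, 1], [0, 0]])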
@@ -66,7 +67,7 @@ def _process_hypergraph(hypergraph: HIFHypergraph) -> HData: # raise ValueError("process can only be called for the original dataset.") num_nodes = len(hypergraph.nodes) - x = HIFProcessor._process_x(hypergraph, num_nodes) + x = cls.__process_x(hypergraph, num_nodes) # Remap node IDs to 0-based contiguous IDs (using indices) matching the x tensor order node_id_to_idx = {node.get("node"): idx for idx, node in enumerate(hypergraph.nodes)} @@ -99,13 +100,13 @@ def _process_hypergraph(hypergraph: HIFHypergraph) -> HData: hyperedge_ids.append(new_hyperedge_id) num_hyperedges = len(hyperedge_id_to_idx) - hyperedge_attr = HIFProcessor._process_hyperedge_attr( + hyperedge_attr = cls.__process_hyperedge_attr( hypergraph=hypergraph, hyperedge_id_to_idx=hyperedge_id_to_idx, num_hyperedges=num_hyperedges, ) - hyperedge_weights = HIFProcessor._process_hyperedge_weights( + hyperedge_weights = cls.__process_hyperedge_weights( hypergraph=hypergraph, hyperedge_id_to_idx=hyperedge_id_to_idx, num_hyperedges=num_hyperedges, @@ -122,8 +123,7 @@ def _process_hypergraph(hypergraph: HIFHypergraph) -> HData: num_hyperedges=num_hyperedges, ) - @staticmethod - def _collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: + def __collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: """ Collect unique numeric attribute keys from a list of attribute dictionaries. @@ -141,8 +141,9 @@ def _collect_attr_keys(attr_keys: List[Dict[str, Any]]) -> List[str]: return unique_keys - @staticmethod - def _process_hyperedge_attr( + @classmethod + def __process_hyperedge_attr( + cls, hypergraph: HIFHypergraph, hyperedge_id_to_idx: Dict[Any, int], num_hyperedges: int, @@ -159,9 +160,7 @@ def _process_hyperedge_attr( e.get("edge"): e.get("attrs", {}) for e in hypergraph.hyperedges } - hyperedge_attr_keys = HIFProcessor._collect_attr_keys( - list(hyperedge_id_to_attrs.values()) - ) + hyperedge_attr_keys = cls.__collect_attr_keys(list(hyperedge_id_to_attrs.values())) # Build attributes in exact order of hyperedge_set indices (0 to num_hyperedges - 1) hyperedge_idx_to_id = {idx: id for id, idx in hyperedge_id_to_idx.items()} @@ -181,10 +180,10 @@ def _process_hyperedge_attr( return hyperedge_attr - @staticmethod - def _process_x(hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: + @classmethod + def __process_x(cls, hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: # Collect all attribute keys to have tensors of same size - node_attr_keys = HIFProcessor._collect_attr_keys( + node_attr_keys = cls.__collect_attr_keys( [node.get("attrs", {}) for node in hypergraph.nodes] ) @@ -202,8 +201,9 @@ def _process_x(hypergraph: HIFHypergraph, num_nodes: int) -> Tensor: return x # shape [num_nodes, num_node_features] - @staticmethod - def _process_hyperedge_weights( + @classmethod + def __process_hyperedge_weights( + cls, hypergraph: HIFHypergraph, hyperedge_id_to_idx: Dict[Any, int], num_hyperedges: int, @@ -236,7 +236,6 @@ def _process_hyperedge_weights( class HIFLoader: """A utility class to load hypergraphs from HIF format.""" - @staticmethod def load_from_url(url: str, save_on_disk: bool = False) -> HData: """ Load a hypergraph from a given URL pointing to a .json or .json.zst file in HIF format. 
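+
+        Usage sketch (hypothetical URL; both ``.json`` and ``.json.zst``
+        endpoints are accepted):
+
+            hdata = HIFLoader.load_from_url(
+                "https://example.com/algebra.json.zst", save_on_disk=False
+            )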
@@ -262,12 +261,12 @@ def load_from_url(url: str, save_on_disk: bool = False) -> HData: if zst_filename.endswith(".zst"): if save_on_disk: - save(os.path.basename(url), response.content) + write_to_disk(os.path.basename(url), response.content) output = decompress_zst(zst_filename) elif zst_filename.endswith(".json"): if save_on_disk: compressed = compress_to_zst(zst_filename) - save(os.path.basename(url), compressed) + write_to_disk(os.path.basename(url), compressed) output = zst_filename else: raise ValueError( @@ -275,10 +274,9 @@ def load_from_url(url: str, save_on_disk: bool = False) -> HData: ) hypergraph = HIFLoader.__extract_hif(output) - hdata = HIFProcessor._process_hypergraph(hypergraph) + hdata = HIFProcessor.process_hypergraph(hypergraph) return hdata - @staticmethod def load_from_path(filepath: str) -> HData: """ Load a hypergraph from a local file path pointing to a .json or .json.zst file in HIF format. @@ -301,12 +299,10 @@ def load_from_path(filepath: str) -> HData: ) hypergraph = HIFLoader.__extract_hif(output) - hdata = HIFProcessor._process_hypergraph(hypergraph) + hdata = HIFProcessor.process_hypergraph(hypergraph) return hdata - @staticmethod - def load_from_name(dataset_name: str, save_on_disk: bool = False) -> HData: - print(f"Loading dataset '{dataset_name}' from disk or remote sources...") + def load_by_name(dataset_name: str, save_on_disk: bool = False) -> HData: current_dir = os.path.dirname(os.path.abspath(__file__)) zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst") @@ -345,7 +341,6 @@ def load_from_name(dataset_name: str, save_on_disk: bool = False) -> HData: response._content = hf_content if save_on_disk: - print(f"Saving downloaded dataset '{dataset_name}' to disk at '{zst_filename}'") os.makedirs(os.path.join(current_dir, "datasets"), exist_ok=True) with open(zst_filename, "wb") as f: f.write(response.content) @@ -359,7 +354,7 @@ def load_from_name(dataset_name: str, save_on_disk: bool = False) -> HData: output = decompress_zst(zst_filename) hypergraph = HIFLoader.__extract_hif(output) - hdata = HIFProcessor._process_hypergraph(hypergraph) + hdata = HIFProcessor.process_hypergraph(hypergraph) return hdata @staticmethod diff --git a/hyperbench/data/supported_datasets.py b/hyperbench/data/supported_datasets.py index 4f28bd7..7a21741 100644 --- a/hyperbench/data/supported_datasets.py +++ b/hyperbench/data/supported_datasets.py @@ -7,6 +7,9 @@ class PreloadedDataset(Dataset): """ Base class for datasets that use default loading. Subclasses should specify the DATASET_NAME class variable. The dataset will be saved on disk after the first load. + Args: + hdata: Optional HData object. If None, the dataset will be loaded using the DATASET_NAME. + sampling_strategy: The sampling strategy to use for this dataset. Default is SamplingStrategy.HYPEREDGE. 
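+
+    Example (illustrative; the first call downloads the dataset, later
+    instantiations reuse the on-disk copy):
+
+        dataset = AlgebraDataset()
+        len(dataset)  # num_hyperedges under the default HYPEREDGE strategy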
""" DATASET_NAME = "" @@ -18,7 +21,7 @@ def __init__( ) -> None: super().__init__(hdata=hdata, sampling_strategy=sampling_strategy) if hdata is None: - self.hdata = HIFLoader.load_from_name(self.DATASET_NAME, save_on_disk=True) + self.hdata = HIFLoader.load_by_name(self.DATASET_NAME, save_on_disk=True) class AlgebraDataset(PreloadedDataset): diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 856acba..7c13b61 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -107,7 +107,7 @@ def test_Preloaded_dataset_init(): def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): mock_hdata = MagicMock(spec=HData) - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata) as mock_load: + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata) as mock_load: dataset = AlgebraDataset(hdata=None) assert dataset.hdata == mock_hdata @@ -125,7 +125,7 @@ def test_dataset_is_available_with_all_strategies( strategy, expected_len, mock_hdata_four_node_hypergraph ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) assert dataset.DATASET_NAME == "algebra" @@ -133,7 +133,7 @@ def test_dataset_is_available_with_all_strategies( def test_dataset_process_no_incidences(mock_hdata_no_incidences): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_incidences): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_incidences): dataset = AlgebraDataset() assert dataset.hdata is not None @@ -144,9 +144,7 @@ def test_dataset_process_no_incidences(mock_hdata_no_incidences): def test_dataset_process_with_edge_attributes(mock_hdata_with_two_edge_attributes): - with patch.object( - HIFLoader, "load_from_name", return_value=mock_hdata_with_two_edge_attributes - ): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_with_two_edge_attributes): dataset = AlgebraDataset() assert dataset.hdata is not None @@ -161,7 +159,7 @@ def test_dataset_process_with_edge_attributes(mock_hdata_with_two_edge_attribute def test_dataset_process_without_edge_attributes(mock_hdata_no_edge_attr_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_edge_attr_hypergraph): dataset = AlgebraDataset() assert dataset.hdata is not None @@ -171,7 +169,7 @@ def test_dataset_process_without_edge_attributes(mock_hdata_no_edge_attr_hypergr def test_dataset_process_hyperedge_index_in_correct_format(mock_hdata_four_node_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() assert dataset.hdata.hyperedge_index.shape == (2, 4) @@ -180,7 +178,7 @@ def test_dataset_process_hyperedge_index_in_correct_format(mock_hdata_four_node_ def test_dataset_process_random_ids(mock_hdata_random_ids): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_random_ids): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_random_ids): dataset = AlgebraDataset() assert dataset.hdata.hyperedge_index.shape == (2, 3) @@ -197,7 +195,7 @@ def 
test_dataset_process_random_ids(mock_hdata_random_ids): ], ) def test_getitem_index_list_empty(mock_hdata_simple_hypergraph, strategy): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_simple_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_simple_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(ValueError, match="Index list cannot be empty."): @@ -224,7 +222,7 @@ def test_getitem_index_list_empty(mock_hdata_simple_hypergraph, strategy): def test_getitem_raises_when_index_list_larger_than_max( mock_hdata_four_node_hypergraph, strategy, index_list, expected_message ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(ValueError, match=expected_message): @@ -248,7 +246,7 @@ def test_getitem_raises_when_index_list_larger_than_max( def test_getitem_raises_when_index_out_of_bounds( mock_hdata_four_node_hypergraph, strategy, index, expected_message ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) with pytest.raises(IndexError, match=expected_message): @@ -267,7 +265,7 @@ def test_getitem_raises_when_index_out_of_bounds( def test_getitem_single_index( mock_hdata_sample_hypergraph, strategy, index, expected_shape, expected_num_hyperedges ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_sample_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_sample_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[index] @@ -288,7 +286,7 @@ def test_getitem_single_index( def test_getitem_when_list_index_provided( mock_hdata_four_node_hypergraph, strategy, index, expected_shape, expected_num_hyperedges ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[index] @@ -306,7 +304,7 @@ def test_getitem_when_list_index_provided( ) def test_getitem_with_edge_attr(mock_hdata_three_node_weighted_hypergraph, strategy): with patch.object( - HIFLoader, "load_from_name", return_value=mock_hdata_three_node_weighted_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_three_node_weighted_hypergraph ): dataset = AlgebraDataset(sampling_strategy=strategy) @@ -325,7 +323,7 @@ def test_getitem_with_edge_attr(mock_hdata_three_node_weighted_hypergraph, strat ], ) def test_getitem_without_edge_attr(mock_hdata_no_edge_attr_hypergraph, strategy): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_edge_attr_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) data = dataset[0] @@ -345,7 +343,7 @@ def test_getitem_with_multiple_edges_attr( mock_hdata_multiple_edges_attr_hypergraph, strategy, index ): with patch.object( - HIFLoader, "load_from_name", return_value=mock_hdata_multiple_edges_attr_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_multiple_edges_attr_hypergraph ): dataset = 
AlgebraDataset(sampling_strategy=strategy) @@ -559,7 +557,7 @@ def test_remove_hyperedges_with_fewer_than_k_nodes(hyperedge_index, k, expected_ def test_split_with_equal_ratios(mock_hdata_four_node_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5]) @@ -577,7 +575,7 @@ def test_split_with_equal_ratios(mock_hdata_four_node_hypergraph): def test_split_three_way(mock_hdata_multiple_edges_attr_hypergraph): with patch.object( - HIFLoader, "load_from_name", return_value=mock_hdata_multiple_edges_attr_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_multiple_edges_attr_hypergraph ): dataset = AlgebraDataset() @@ -594,7 +592,7 @@ def test_split_three_way(mock_hdata_multiple_edges_attr_hypergraph): def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_node_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() with pytest.raises(ValueError, match="Split ratios must sum to 1.0"): @@ -604,7 +602,7 @@ def test_split_raises_when_ratios_do_not_sum_to_one(mock_hdata_four_node_hypergr def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( mock_hdata_four_node_hypergraph, ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() splits_a = dataset.split([0.5, 0.5], shuffle=True, seed=42) @@ -617,7 +615,7 @@ def test_split_with_shuffle_produces_deterministic_results_when_seed_provided( def test_split_with_shuffle_when_no_seed_provided( mock_hdata_four_node_hypergraph, ): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5], shuffle=True) @@ -634,7 +632,7 @@ def test_split_with_shuffle_when_no_seed_provided( def test_split_preserves_edge_attr(mock_hdata_multiple_edges_attr_hypergraph): with patch.object( - HIFLoader, "load_from_name", return_value=mock_hdata_multiple_edges_attr_hypergraph + HIFLoader, "load_by_name", return_value=mock_hdata_multiple_edges_attr_hypergraph ): dataset = AlgebraDataset() @@ -646,7 +644,7 @@ def test_split_preserves_edge_attr(mock_hdata_multiple_edges_attr_hypergraph): def test_split_without_edge_attr(mock_hdata_no_edge_attr_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_no_edge_attr_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_no_edge_attr_hypergraph): dataset = AlgebraDataset() splits = dataset.split([0.5, 0.5]) @@ -667,7 +665,7 @@ def test_to_device(mock_hdata): def test_default_sampling_strategy_is_hyperedge(mock_hdata_four_node_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset() # Default strategy is HYPEREDGE, so len should be num_hyperedges (2), not num_nodes (4) @@ -676,7 +674,7 @@ def 
test_default_sampling_strategy_is_hyperedge(mock_hdata_four_node_hypergraph) def test_explicit_node_sampling_strategy(mock_hdata_four_node_hypergraph): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.NODE) # NODE strategy, so len should be num_nodes (4), not num_hyperedges (2) @@ -692,7 +690,7 @@ def test_explicit_node_sampling_strategy(mock_hdata_four_node_hypergraph): ], ) def test_split_preserves_sampling_strategy(mock_hdata_four_node_hypergraph, strategy): - with patch.object(HIFLoader, "load_from_name", return_value=mock_hdata_four_node_hypergraph): + with patch.object(HIFLoader, "load_by_name", return_value=mock_hdata_four_node_hypergraph): dataset = AlgebraDataset(sampling_strategy=strategy) splits = dataset.split([0.5, 0.5]) diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index bd4bb68..c9c9b4d 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -454,7 +454,7 @@ def test_load_skips_download_when_file_exists(tmp_path, mock_hypergraph): patch("hyperbench.data.hif.decompress_zst", return_value=json_path), patch("hyperbench.data.hif.validate_hif_json", return_value=True), ): - result = HIFLoader.load_from_name("algebra", save_on_disk=True) + result = HIFLoader.load_by_name("algebra", save_on_disk=True) mock_get.assert_not_called() assert result.num_nodes == 2 @@ -473,7 +473,7 @@ def test_HIFLoader_download_failure_when_hf_fallback_fails(): with pytest.warns(UserWarning, match="GitHub raw download failed"): with pytest.raises(ValueError, match="Failed to download dataset 'algebra'"): - HIFLoader.load_from_name("algebra") + HIFLoader.load_by_name("algebra") def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( @@ -497,7 +497,7 @@ def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( mock_response.content = b"" with pytest.warns(UserWarning, match="GitHub raw download failed"): - result = HIFLoader.load_from_name("algebra", save_on_disk=False) + result = HIFLoader.load_by_name("algebra", save_on_disk=False) mock_get.assert_called_once() mock_hf_hub_download.assert_called_once() @@ -519,7 +519,7 @@ def test_load_saves_downloaded_dataset_on_disk(tmp_path, mock_hypergraph): mock_response.status_code = 200 mock_response.content = b"downloaded-content" - result = HIFLoader.load_from_name("algebra", save_on_disk=True) + result = HIFLoader.load_by_name("algebra", save_on_disk=True) saved = tmp_path / "datasets" / "algebra.json.zst" assert saved.exists() @@ -537,4 +537,4 @@ def test_HIFLoader_download_raises_when_network_error(): ), ): with pytest.raises(requests.RequestException, match="Network error"): - HIFLoader.load_from_name("algebra") + HIFLoader.load_by_name("algebra") diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index e9aff4e..5763312 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -17,7 +17,7 @@ ) from .sparse_utils import sparse_dropout from .url_utils import validate_http_url -from .file_utils import decompress_zst, compress_to_zst, save_on_disk +from .file_utils import decompress_zst, compress_to_zst, write_to_disk __all__ = [ "INPUT_LAYER", @@ -37,5 +37,5 @@ "decompress_zst", "compress_to_zst", "validate_http_url", - "save_on_disk", + "write_to_disk", ] diff --git a/hyperbench/utils/file_utils.py 
b/hyperbench/utils/file_utils.py index 73573d9..b99aa8c 100644 --- a/hyperbench/utils/file_utils.py +++ b/hyperbench/utils/file_utils.py @@ -21,7 +21,7 @@ def compress_to_zst(json_path: str) -> bytes: return compressed_content -def save_on_disk(dataset_name: str, content: bytes) -> None: +def write_to_disk(dataset_name: str, content: bytes) -> None: current_dir = os.path.dirname(os.path.abspath(__file__)) datasets_dir = os.path.join(current_dir, "..", "data", "datasets") zst_filename = os.path.join(datasets_dir, f"{dataset_name}.json.zst") From 480718493a17690393aff8d92bac5ea41d6bd20a Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Fri, 24 Apr 2026 12:51:44 +0200 Subject: [PATCH 10/15] feat: added codefactor badge - fix tests --- README.md | 1 + hyperbench/tests/data/hif_test.py | 39 +++++++++++-------------------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 516e305..e3d78fa 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ [![project_license][license-shield]][license-url] [![codecov](https://codecov.io/github/hypernetwork-research-group/hyperbench/graph/badge.svg?token=XE0TB5JMOS)](https://codecov.io/github/hypernetwork-research-group/hyperbench) +[![CodeFactor](https://www.codefactor.io/repository/github/hypernetwork-research-group/hyperbench/badge)](https://www.codefactor.io/repository/github/hypernetwork-research-group/hyperbench) ## About the project diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index c9c9b4d..0a8b5ea 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -6,8 +6,7 @@ from unittest.mock import patch, MagicMock -from hyperbench.data.hif import HIFLoader, HIFProcessor -import hyperbench.data.hif as hif_module +from hyperbench.data import HIFLoader, HIFProcessor from hyperbench.types import HData, HIFHypergraph @@ -290,13 +289,12 @@ def test_load_from_url_processes_zst_and_saves_to_disk(tmp_path, mock_hypergraph unique_name = f"algebra_{tmp_path.name}.json.zst" url = f"https://example.com/{unique_name}" json_path = _write_hif_json(tmp_path, mock_hypergraph) - current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) - saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") with ( patch("hyperbench.data.hif.requests.get") as mock_get, patch("hyperbench.data.hif.decompress_zst", return_value=json_path), patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, ): mock_response = mock_get.return_value mock_response.status_code = 200 @@ -304,13 +302,9 @@ def test_load_from_url_processes_zst_and_saves_to_disk(tmp_path, mock_hypergraph hdata = HIFLoader.load_from_url(url, save_on_disk=True) - try: - assert os.path.exists(saved_path) - assert hdata.num_nodes == 2 - assert hdata.num_hyperedges == 1 - finally: - if os.path.exists(saved_path): - os.remove(saved_path) + mock_write_to_disk.assert_called_once_with(unique_name, b"mock-zst-content") + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 def test_load_from_url_processes_json_and_saves_compressed_copy(tmp_path, mock_hypergraph): @@ -323,8 +317,6 @@ def test_load_from_url_processes_json_and_saves_compressed_copy(tmp_path, mock_h "edges": mock_hypergraph.hyperedges, "incidences": mock_hypergraph.incidences, } - current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) - saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") with ( 
patch("hyperbench.data.hif.requests.get") as mock_get, @@ -334,6 +326,7 @@ def test_load_from_url_processes_json_and_saves_compressed_copy(tmp_path, mock_h ), patch("hyperbench.data.hif.compress_to_zst", return_value=b"compressed") as mock_compress, patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, ): mock_response = mock_get.return_value mock_response.status_code = 200 @@ -342,26 +335,21 @@ def test_load_from_url_processes_json_and_saves_compressed_copy(tmp_path, mock_h hdata = HIFLoader.load_from_url(url, save_on_disk=True) mock_compress.assert_called_once_with(str(tmp_path / "downloaded.json")) - try: - assert os.path.exists(saved_path) - assert hdata.num_nodes == 2 - assert hdata.num_hyperedges == 1 - finally: - if os.path.exists(saved_path): - os.remove(saved_path) + mock_write_to_disk.assert_called_once_with(unique_name, b"compressed") + assert hdata.num_nodes == 2 + assert hdata.num_hyperedges == 1 def test_load_from_url_processes_zst_without_saving_to_disk(tmp_path, mock_hypergraph): unique_name = f"algebra_{tmp_path.name}.json.zst" url = f"https://example.com/{unique_name}" json_path = _write_hif_json(tmp_path, mock_hypergraph) - current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) - saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") with ( patch("hyperbench.data.hif.requests.get") as mock_get, patch("hyperbench.data.hif.decompress_zst", return_value=json_path) as mock_decompress, patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, ): mock_response = mock_get.return_value mock_response.status_code = 200 @@ -369,7 +357,7 @@ def test_load_from_url_processes_zst_without_saving_to_disk(tmp_path, mock_hyper hdata = HIFLoader.load_from_url(url, save_on_disk=False) - assert not os.path.exists(saved_path) + mock_write_to_disk.assert_not_called() mock_decompress.assert_called_once() assert hdata.num_nodes == 2 assert hdata.num_hyperedges == 1 @@ -385,8 +373,6 @@ def test_load_from_url_processes_json_without_saving_to_disk(tmp_path, mock_hype "edges": mock_hypergraph.hyperedges, "incidences": mock_hypergraph.incidences, } - current_dir = os.path.dirname(os.path.abspath(hif_module.__file__)) - saved_path = os.path.join(current_dir, "datasets", f"{unique_name}.json.zst") with ( patch("hyperbench.data.hif.requests.get") as mock_get, @@ -396,6 +382,7 @@ def test_load_from_url_processes_json_without_saving_to_disk(tmp_path, mock_hype ), patch("hyperbench.data.hif.compress_to_zst") as mock_compress, patch("hyperbench.data.hif.validate_hif_json", return_value=True), + patch("hyperbench.data.hif.write_to_disk") as mock_write_to_disk, ): mock_response = mock_get.return_value mock_response.status_code = 200 @@ -404,7 +391,7 @@ def test_load_from_url_processes_json_without_saving_to_disk(tmp_path, mock_hype hdata = HIFLoader.load_from_url(url, save_on_disk=False) mock_compress.assert_not_called() - assert not os.path.exists(saved_path) + mock_write_to_disk.assert_not_called() assert hdata.num_nodes == 2 assert hdata.num_hyperedges == 1 From 0a1d667e983dfe4cfb48561b660a6bcb7550d661 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Fri, 24 Apr 2026 13:11:41 +0200 Subject: [PATCH 11/15] fix: removed prepare --- examples/hyperedge_enricher.py | 5 +---- examples/node2vec.py | 2 +- examples/node_enricher.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/hyperedge_enricher.py 
b/examples/hyperedge_enricher.py index dfc9e3c..3bb6604 100644 --- a/examples/hyperedge_enricher.py +++ b/examples/hyperedge_enricher.py @@ -5,11 +5,9 @@ if __name__ == "__main__": print("Loading and preparing dataset...\n") - dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE, prepare=True) + dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE) print("Enriching hyperedge weights...") - - dataset = AlgebraDataset(sampling_strategy=sampling_strategy) # HyperedgeWeightsEnricher enriches hyperedges with their degree (number of nodes in each hyperedge) as weights. # It optionally applies scaling and adds a constant to the weights. dataset.enrich_hyperedge_weights( @@ -26,7 +24,6 @@ print("Enriching hyperedge attributes...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy) # HyperedgeAttrsEnricher adds a feature of 1.0 for each hyperedge, which can be used as a baseline or for methods that require hyperedge features. dataset.enrich_hyperedge_attr( enricher=HyperedgeAttrsEnricher(), diff --git a/examples/node2vec.py b/examples/node2vec.py index 63a4644..bbd09d9 100644 --- a/examples/node2vec.py +++ b/examples/node2vec.py @@ -29,7 +29,7 @@ print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy, prepare=True) + dataset = AlgebraDataset(sampling_strategy=sampling_strategy) dataset.remove_hyperedges_with_fewer_than_k_nodes(k=2) if verbose: print(f"Dataset:\n {dataset.hdata}\n") diff --git a/examples/node_enricher.py b/examples/node_enricher.py index 0803c2c..8149e2b 100644 --- a/examples/node_enricher.py +++ b/examples/node_enricher.py @@ -5,7 +5,7 @@ if __name__ == "__main__": print("Loading and preparing dataset...") - dataset = AlgebraDataset(sampling_strategy=sampling_strategy) + dataset = AlgebraDataset(sampling_strategy=SamplingStrategy.HYPEREDGE) # NodeEnricher adds features for each node. 
dataset.enrich_node_features( enricher=LaplacianPositionalEncodingEnricher(num_features=32), From 088eab2a83a5d783a3ef31829fb01d300ff13036 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Fri, 24 Apr 2026 13:42:19 +0200 Subject: [PATCH 12/15] feat: removed main in supported dataset - added test file for file_utils --- hyperbench/data/supported_datasets.py | 31 ------------------- hyperbench/tests/utils/file_utils_test.py | 37 +++++++++++++++++++++++ hyperbench/utils/file_utils.py | 37 ++++++++++++++++++++--- 3 files changed, 69 insertions(+), 36 deletions(-) create mode 100644 hyperbench/tests/utils/file_utils_test.py diff --git a/hyperbench/data/supported_datasets.py b/hyperbench/data/supported_datasets.py index 7a21741..b08f777 100644 --- a/hyperbench/data/supported_datasets.py +++ b/hyperbench/data/supported_datasets.py @@ -114,34 +114,3 @@ class TwitterDataset(PreloadedDataset): class VegasBarsReviewsDataset(PreloadedDataset): DATASET_NAME = "vegas-bars-reviews" - - -if __name__ == "__main__": - # test loading each dataset - for dataset_cls in [ - AlgebraDataset, - AmazonDataset, - ContactHighSchoolDataset, - # ContactPrimarySchoolDataset, - # CoraDataset, - # CourseraDataset, - # DBLPDataset, - # EmailEnronDataset, - # EmailW3CDataset, - # GeometryDataset, - # GOTDataset, - # IMDBDataset, - # MusicBluesReviewsDataset, - # NBADataset, - # NDCClassesDataset, - # NDCSubstancesDataset, - # PatentDataset, - # PubmedDataset, - # RestaurantReviewsDataset, - # ThreadsAskUbuntuDataset, - # ThreadsMathsxDataset, - # TwitterDataset, - # VegasBarsReviewsDataset, - ]: - dataset = dataset_cls() - print(dataset.hdata.num_nodes, dataset.hdata.num_hyperedges) diff --git a/hyperbench/tests/utils/file_utils_test.py b/hyperbench/tests/utils/file_utils_test.py new file mode 100644 index 0000000..6f703b0 --- /dev/null +++ b/hyperbench/tests/utils/file_utils_test.py @@ -0,0 +1,37 @@ +import os + +from hyperbench.utils import write_to_disk +from unittest.mock import patch, MagicMock + + +def test_write_to_disk_writes_file_default_output_dir(tmp_path): + dataset_name = "test_dataset" + content = b"test content" + + # Force write_to_disk default branch to resolve under tmp_path. + fake_module_file = tmp_path / "hyperbench" / "utils" / "file_utils.py" + + with patch( + "hyperbench.utils.file_utils.os.path.abspath", + return_value=str(fake_module_file), + ): + write_to_disk(dataset_name, content) + + expected_path = tmp_path / "hyperbench" / "data" / "datasets" / f"{dataset_name}.json.zst" + assert expected_path.is_file() + assert expected_path.read_bytes() == content + + +def test_write_to_disk_writes_file_optional_output_dir(tmp_path): + dataset_name = "test_dataset" + content = b"test content" + output_dir = tmp_path + + write_to_disk(dataset_name, content, output_dir) + + expected_path = tmp_path / f"{dataset_name}.json.zst" + assert expected_path.is_file() + + with open(expected_path, "rb") as f: + file_content = f.read() + assert file_content == content diff --git a/hyperbench/utils/file_utils.py b/hyperbench/utils/file_utils.py index b99aa8c..0931efa 100644 --- a/hyperbench/utils/file_utils.py +++ b/hyperbench/utils/file_utils.py @@ -1,9 +1,17 @@ +from typing import Optional import zstandard as zstd import tempfile import os def decompress_zst(zst_path: str) -> str: + """ + Decompresses a .zst file and returns the path to the decompressed JSON file. + Args: + zst_path: The path to the .zst file to decompress. + Returns: + The path to the decompressed JSON file. 
+ """ dctx = zstd.ZstdDecompressor() with ( open(zst_path, "rb") as input_f, @@ -15,17 +23,36 @@ def decompress_zst(zst_path: str) -> str: def compress_to_zst(json_path: str) -> bytes: + """ + Compresses a JSON file to .zst format and returns the compressed bytes. + + Args: + json_path: The path to the JSON file to compress. + Returns: + The compressed content as bytes. + """ cctx = zstd.ZstdCompressor() with open(json_path, "rb") as input_f: compressed_content = cctx.compress(input_f.read()) return compressed_content -def write_to_disk(dataset_name: str, content: bytes) -> None: - current_dir = os.path.dirname(os.path.abspath(__file__)) - datasets_dir = os.path.join(current_dir, "..", "data", "datasets") - zst_filename = os.path.join(datasets_dir, f"{dataset_name}.json.zst") - os.makedirs(datasets_dir, exist_ok=True) +def write_to_disk(dataset_name: str, content: bytes, output_dir: Optional[str] = None) -> None: + """ + Writes the compressed content to disk in the specified output directory or a default location. + Args: + dataset_name: The name of the dataset. + content: The compressed content as bytes. + output_dir: The directory to write the file to. If None, a default location is used. + """ + if output_dir is not None: + zst_filename = os.path.join(output_dir, f"{dataset_name}.json.zst") + else: + current_dir = os.path.dirname(os.path.abspath(__file__)) + output_dir = os.path.join(current_dir, "..", "data", "datasets") + zst_filename = os.path.join(output_dir, f"{dataset_name}.json.zst") + + os.makedirs(output_dir, exist_ok=True) with open(zst_filename, "wb") as f: f.write(content) From 43ae18f20a9f080d121ccb6f301cb5803c2830f8 Mon Sep 17 00:00:00 2001 From: ddevin96 Date: Fri, 24 Apr 2026 13:51:35 +0200 Subject: [PATCH 13/15] chore: removed unused imports + order --- hyperbench/data/dataset.py | 16 ++++++---------- hyperbench/data/hif.py | 6 +++--- hyperbench/utils/file_utils.py | 7 ++++--- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/hyperbench/data/dataset.py b/hyperbench/data/dataset.py index 146d0ad..ccd3b46 100644 --- a/hyperbench/data/dataset.py +++ b/hyperbench/data/dataset.py @@ -1,22 +1,18 @@ import json import os +import requests import tempfile import torch -import zstandard as zstd -import requests -import warnings -from typing import Any, Dict, List, Optional, TypeAlias, Literal + +from typing import Any, Dict, List, Optional, Literal from torch import Tensor from torch.utils.data import Dataset as TorchDataset -from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher -from hyperbench.types import HData, HIFHypergraph, HyperedgeIndex -from hyperbench.nn import EnrichmentMode -from hyperbench.utils import validate_hif_json - -from hyperbench.data.sampling import SamplingStrategy, create_sampler_from_strategy from hyperbench.data.hif import HIFLoader, HIFProcessor +from hyperbench.data.sampling import SamplingStrategy, create_sampler_from_strategy +from hyperbench.nn import EnrichmentMode, NodeEnricher, HyperedgeEnricher +from hyperbench.types import HData class Dataset(TorchDataset): diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py index 77686b1..7e6edc5 100644 --- a/hyperbench/data/hif.py +++ b/hyperbench/data/hif.py @@ -1,10 +1,10 @@ -import torch -import os import json -import zstandard as zstd +import os import requests import tempfile +import torch import warnings +import zstandard as zstd from huggingface_hub import hf_hub_download from typing import Optional, Dict, Any, List diff --git a/hyperbench/utils/file_utils.py 
b/hyperbench/utils/file_utils.py
index 0931efa..c263be2 100644
--- a/hyperbench/utils/file_utils.py
+++ b/hyperbench/utils/file_utils.py
@@ -1,7 +1,8 @@
-from typing import Optional
-import zstandard as zstd
-import tempfile
 import os
+import tempfile
+import zstandard as zstd
+
+from typing import Optional
 
 
 def decompress_zst(zst_path: str) -> str:

From 056e24d970c4886fd9a79e0642d90189c4fd462a Mon Sep 17 00:00:00 2001
From: ddevin96
Date: Fri, 24 Apr 2026 14:12:36 +0200
Subject: [PATCH 14/15] fix: removed insecure paths from tests

---
 hyperbench/tests/data/dataset_test.py | 2 +-
 hyperbench/tests/data/hif_test.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py
index 7c13b61..4453045 100644
--- a/hyperbench/tests/data/dataset_test.py
+++ b/hyperbench/tests/data/dataset_test.py
@@ -396,7 +396,7 @@ def test_from_url(strategy, mock_hdata):
     ],
 )
 def test_from_path(strategy, mock_hdata):
-    filepath = "/tmp/sample.json.zst"
+    filepath = "/abc/sample.json.zst"
 
     with patch.object(HIFLoader, "load_from_path", return_value=mock_hdata) as mock_load_from_path:
         dataset = Dataset.from_path(filepath=filepath, sampling_strategy=strategy)
diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py
index 0a8b5ea..e7e3838 100644
--- a/hyperbench/tests/data/hif_test.py
+++ b/hyperbench/tests/data/hif_test.py
@@ -217,7 +217,7 @@ def test_load_from_url_raises_when_status_is_not_200():
 
 
 def test_load_from_path_raises_for_missing_file():
     with pytest.raises(ValueError, match="does not exist"):
-        HIFLoader.load_from_path("/tmp/does-not-exist.json.zst")
+        HIFLoader.load_from_path("/abc/does-not-exist.json.zst")
 

From 3c74d1e177ea787159c7dba37249003968f43f6e Mon Sep 17 00:00:00 2001
From: ddevin96
Date: Fri, 24 Apr 2026 15:34:41 +0200
Subject: [PATCH 15/15] feat: pinned versions of external files for GitHub and HF - added related tests

---
 hyperbench/data/hif.py                   |  34 ++++---
 hyperbench/data/supported_datasets.py    |  28 ++++-
 hyperbench/tests/data/dataset_test.py    |   4 +-
 hyperbench/tests/data/hif_test.py        | 124 +++++++++++++++++++++--
 hyperbench/tests/utils/hif_utils_test.py |  67 +++++++++++-
 hyperbench/utils/__init__.py             |   4 +-
 hyperbench/utils/hif_utils.py            |  33 +++++-
 7 files changed, 270 insertions(+), 24 deletions(-)

diff --git a/hyperbench/data/hif.py b/hyperbench/data/hif.py
index 7e6edc5..4485061 100644
--- a/hyperbench/data/hif.py
+++ b/hyperbench/data/hif.py
@@ -19,6 +19,8 @@
     write_to_disk,
 )
 
+GITHUB_COMMIT_SHA = "3879b2ce84750e54f984ca06ce3246dff22c71c7"
+
 
 class HIFProcessor:
     """A utility class to process HIF hypergraph data into :class:`HData` format."""
@@ -302,14 +304,15 @@ def load_from_path(filepath: str) -> HData:
         hdata = HIFProcessor.process_hypergraph(hypergraph)
         return hdata
 
-    def load_by_name(dataset_name: str, save_on_disk: bool = False) -> HData:
+    def load_by_name(
+        dataset_name: str, hf_sha: Optional[str] = None, save_on_disk: bool = False
+    ) -> HData:
         current_dir = os.path.dirname(os.path.abspath(__file__))
         zst_filename = os.path.join(current_dir, "datasets", f"{dataset_name}.json.zst")
 
         if not os.path.exists(zst_filename):
-            github_dataset_repo = f"https://github.com/hypernetwork-research-group/datasets/blob/main/{dataset_name}.json.zst?raw=true"
-
-            response = requests.get(github_dataset_repo)
+            github_url = 
f"https://raw.githubusercontent.com/hypernetwork-research-group/datasets/{GITHUB_COMMIT_SHA}/{dataset_name}.json.zst" + response = requests.get(github_url, timeout=20) if response.status_code != 200: warnings.warn( f"GitHub raw download failed for dataset '{dataset_name}' with status code {response.status_code}\n" @@ -324,16 +327,23 @@ def load_by_name(dataset_name: str, save_on_disk: bool = False) -> HData: with tempfile.NamedTemporaryFile( mode="wb", suffix=".json.zst", delete=False ) as tmp_hf_file: - try: - downloaded_path = hf_hub_download( - repo_id=REPO_ID, - filename=FILENAME, - repo_type="dataset", - ) - except Exception as e: + if hf_sha is not None: + try: + downloaded_path = hf_hub_download( + repo_id=REPO_ID, + filename=FILENAME, + repo_type="dataset", + revision=hf_sha, + ) + except Exception as e: + raise ValueError( + f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {str(e)}" + ) + else: raise ValueError( - f"Failed to download dataset '{dataset_name}' from GitHub and Hugging Face Hub. GitHub error: {response.status_code} | Hugging Face error: {str(e)}" + f"Failed to download dataset '{dataset_name}' from GitHub with status code {response.status_code} and no SHA provided for Hugging Face Hub fallback." ) + with open(downloaded_path, "rb") as hf_file: hf_content = hf_file.read() tmp_hf_file.write(hf_content) diff --git a/hyperbench/data/supported_datasets.py b/hyperbench/data/supported_datasets.py index b08f777..8cb178f 100644 --- a/hyperbench/data/supported_datasets.py +++ b/hyperbench/data/supported_datasets.py @@ -13,6 +13,7 @@ class PreloadedDataset(Dataset): """ DATASET_NAME = "" + HF_SHA = None def __init__( self, @@ -21,96 +22,121 @@ def __init__( ) -> None: super().__init__(hdata=hdata, sampling_strategy=sampling_strategy) if hdata is None: - self.hdata = HIFLoader.load_by_name(self.DATASET_NAME, save_on_disk=True) + self.hdata = HIFLoader.load_by_name( + self.DATASET_NAME, hf_sha=self.HF_SHA, save_on_disk=True + ) class AlgebraDataset(PreloadedDataset): DATASET_NAME = "algebra" + HF_SHA = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" class AmazonDataset(PreloadedDataset): DATASET_NAME = "amazon" + HF_SHA = "614f75d1847d233ee06da0cc3ee10f51220b8243" class ContactHighSchoolDataset(PreloadedDataset): DATASET_NAME = "contact-high-school" + HF_SHA = "b991fde34631a357961a244a5c4d734cf3093199" class ContactPrimarySchoolDataset(PreloadedDataset): DATASET_NAME = "contact-primary-school" + HF_SHA = "f6f5453777d1fc62f6305b17d131ec1e32cdbe66" class CoraDataset(PreloadedDataset): DATASET_NAME = "cora" + HF_SHA = "91fda9ed324e2cce2430638747e9b032bd9c22ad" class CourseraDataset(PreloadedDataset): DATASET_NAME = "coursera" + HF_SHA = "e68679a01af61c43292575839e451eb0bbeee202" class DBLPDataset(PreloadedDataset): DATASET_NAME = "dblp" + HF_SHA = "151c360ed77042abebb9709fd3d738763d5c5044" class EmailEnronDataset(PreloadedDataset): DATASET_NAME = "email-Enron" + HF_SHA = "05247a5441a6a337cdccee24c0060255815905be" class EmailW3CDataset(PreloadedDataset): DATASET_NAME = "email-W3C" + HF_SHA = "18b8c795504388c1d075ffcea7eada281ec5e416" class GeometryDataset(PreloadedDataset): DATASET_NAME = "geometry" + HF_SHA = "49a8647d6ff7361485c953949010155b0b522a12" class GOTDataset(PreloadedDataset): DATASET_NAME = "got" + HF_SHA = "2efb505e5d82457f6e5ba21820c8d8f2298f0ece" class IMDBDataset(PreloadedDataset): DATASET_NAME = "imdb" + HF_SHA = "c3a583313d1611b292933d77e725b11be2c39a05" class 
MusicBluesReviewsDataset(PreloadedDataset): DATASET_NAME = "music-blues-reviews" + HF_SHA = "7d218b727097ed007e7f368ab91c064b3eeff184" class NBADataset(PreloadedDataset): DATASET_NAME = "nba" + HF_SHA = "5b3b1c7e425bc407bc0843f443cdf889b51e1ca7" class NDCClassesDataset(PreloadedDataset): DATASET_NAME = "NDC-classes" + HF_SHA = "c9bb31897646fb3f964ee4affe126f9885954d92" class NDCSubstancesDataset(PreloadedDataset): DATASET_NAME = "NDC-substances" + HF_SHA = "bbdde0839ca5913a2535e6fe3ce397b990803af9" class PatentDataset(PreloadedDataset): DATASET_NAME = "patent" + HF_SHA = "608b4fab97d17adbc01b0b4636b060a550231307" class PubmedDataset(PreloadedDataset): DATASET_NAME = "pubmed" + HF_SHA = "b8f846a3c812b3b23f10bd69f65f739983f6a390" class RestaurantReviewsDataset(PreloadedDataset): DATASET_NAME = "restaurant-reviews" + HF_SHA = "668a90391fcb968c786da7bc9e7bbc55e2832066" class ThreadsAskUbuntuDataset(PreloadedDataset): DATASET_NAME = "threads-ask-ubuntu" + HF_SHA = "704c54c7f21b4e313ab6bb50bcd30f58ade469b6" class ThreadsMathsxDataset(PreloadedDataset): DATASET_NAME = "threads-math-sx" + HF_SHA = "b024111c16fdb266e159a4c647ff1a31ec40db5b" class TwitterDataset(PreloadedDataset): DATASET_NAME = "twitter" + HF_SHA = "d93c55af8e04cf70d65ed0059325009a21699a25" class VegasBarsReviewsDataset(PreloadedDataset): DATASET_NAME = "vegas-bars-reviews" + HF_SHA = "4f1e4e4c87957679efc38c05129a694d315a8c9b" diff --git a/hyperbench/tests/data/dataset_test.py b/hyperbench/tests/data/dataset_test.py index 4453045..a0f2ebb 100644 --- a/hyperbench/tests/data/dataset_test.py +++ b/hyperbench/tests/data/dataset_test.py @@ -111,7 +111,9 @@ def test_Preloaded_dataset_loads_hdata_when_hdata_is_none(): dataset = AlgebraDataset(hdata=None) assert dataset.hdata == mock_hdata - mock_load.assert_called_once_with("algebra", save_on_disk=True) + mock_load.assert_called_once_with( + "algebra", hf_sha="2bb641461e00c103fb5ef4fe6a30aad42500fc21", save_on_disk=True + ) @pytest.mark.parametrize( diff --git a/hyperbench/tests/data/hif_test.py b/hyperbench/tests/data/hif_test.py index e7e3838..b6b1a47 100644 --- a/hyperbench/tests/data/hif_test.py +++ b/hyperbench/tests/data/hif_test.py @@ -484,12 +484,11 @@ def test_HIFLoader_falls_back_to_hf_hub_download_when_github_raw_download_fails( mock_response.content = b"" with pytest.warns(UserWarning, match="GitHub raw download failed"): - result = HIFLoader.load_by_name("algebra", save_on_disk=False) - - mock_get.assert_called_once() - mock_hf_hub_download.assert_called_once() - assert result.num_nodes == 2 - assert result.num_hyperedges == 1 + with pytest.raises( + ValueError, + match="Failed to download dataset 'algebra' from GitHub with status code 404 and no SHA provided for Hugging Face Hub fallback.", + ): + result = HIFLoader.load_by_name("algebra", save_on_disk=False) def test_load_saves_downloaded_dataset_on_disk(tmp_path, mock_hypergraph): @@ -525,3 +524,116 @@ def test_HIFLoader_download_raises_when_network_error(): ): with pytest.raises(requests.RequestException, match="Network error"): HIFLoader.load_by_name("algebra") + + +def test_load_by_name_uses_hf_revision_when_github_download_fails(tmp_path, mock_hypergraph): + hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21" + fallback_file = tmp_path / "algebra.json.zst" + fallback_file.write_bytes(b"mock_zst_content") + json_path = _write_hif_json(tmp_path, mock_hypergraph) + + response = requests.Response() + response.status_code = 404 + response._content = b"" + + with ( + patch("hyperbench.data.hif.os.path.exists", 
return_value=False),
+        patch("hyperbench.data.hif.requests.get", return_value=response),
+        patch(
+            "hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file)
+        ) as mock_hf_hub_download,
+        patch("hyperbench.data.hif.decompress_zst", return_value=json_path),
+        patch("hyperbench.data.hif.validate_hif_json", return_value=True),
+    ):
+        with pytest.warns(UserWarning, match="GitHub raw download failed"):
+            result = HIFLoader.load_by_name("algebra", hf_sha=hf_sha, save_on_disk=False)
+
+    mock_hf_hub_download.assert_called_once_with(
+        repo_id="HypernetworkRG/algebra",
+        filename="algebra.json.zst",
+        repo_type="dataset",
+        revision=hf_sha,
+    )
+    assert result.num_nodes == 2
+    assert result.num_hyperedges == 1
+
+
+def test_load_by_name_raises_when_hf_sha_is_missing_on_fallback():
+    response = requests.Response()
+    response.status_code = 404
+    response._content = b""
+
+    with (
+        patch("hyperbench.data.hif.os.path.exists", return_value=False),
+        patch("hyperbench.data.hif.requests.get", return_value=response),
+        patch("hyperbench.data.hif.hf_hub_download") as mock_hf_hub_download,
+    ):
+        with pytest.warns(UserWarning, match="GitHub raw download failed"):
+            with pytest.raises(
+                ValueError,
+                match="no SHA provided for Hugging Face Hub fallback",
+            ):
+                HIFLoader.load_by_name("algebra", save_on_disk=False)
+
+    mock_hf_hub_download.assert_not_called()
+
+
+def test_load_by_name_reads_hf_download_and_saves_its_content(tmp_path, mock_hypergraph):
+    hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21"
+    hf_content = b"mock_zst_content"
+    fallback_file = tmp_path / "algebra.json.zst"
+    fallback_file.write_bytes(hf_content)
+    json_path = _write_hif_json(tmp_path, mock_hypergraph)
+
+    response = requests.Response()
+    response.status_code = 404
+    response._content = b""
+
+    with (
+        patch("hyperbench.data.hif.os.path.exists", return_value=False),
+        patch("hyperbench.data.hif.requests.get", return_value=response),
+        patch("hyperbench.data.hif.hf_hub_download", return_value=str(fallback_file)),
+        patch("hyperbench.data.hif.decompress_zst", return_value=json_path),
+        patch("hyperbench.data.hif.validate_hif_json", return_value=True),
+        patch("hyperbench.data.hif.__file__", str(tmp_path / "hif.py")),
+    ):
+        with pytest.warns(UserWarning, match="GitHub raw download failed"):
+            result = HIFLoader.load_by_name("algebra", hf_sha=hf_sha, save_on_disk=True)
+
+    saved = tmp_path / "datasets" / "algebra.json.zst"
+    assert saved.exists()
+    assert saved.read_bytes() == hf_content
+    assert result.num_nodes == 2
+    assert result.num_hyperedges == 1
+
+
+def test_HIFLoader_download_failure_when_hf_fallback_fails_with_sha():
+    hf_sha = "2bb641461e00c103fb5ef4fe6a30aad42500fc21"
+    response = requests.Response()
+    response.status_code = 404
+    response._content = b""
+
+    with (
+        patch("hyperbench.data.hif.os.path.exists", return_value=False),
+        patch("hyperbench.data.hif.requests.get", return_value=response),
+        patch(
+            "hyperbench.data.hif.hf_hub_download",
+            side_effect=Exception("HFHub failed"),
+        ) as mock_hf_hub_download,
+    ):
+        with pytest.warns(UserWarning, match="GitHub raw download failed"):
+            with pytest.raises(
+                ValueError,
+                match=(
+                    r"Failed to download dataset 'algebra' from GitHub and Hugging Face Hub\. 
" + r"GitHub error: 404 \| Hugging Face error: HFHub failed" + ), + ): + HIFLoader.load_by_name("algebra", hf_sha=hf_sha) + + mock_hf_hub_download.assert_called_once_with( + repo_id="HypernetworkRG/algebra", + filename="algebra.json.zst", + repo_type="dataset", + revision=hf_sha, + ) diff --git a/hyperbench/tests/utils/hif_utils_test.py b/hyperbench/tests/utils/hif_utils_test.py index 66c7d2c..9161177 100644 --- a/hyperbench/tests/utils/hif_utils_test.py +++ b/hyperbench/tests/utils/hif_utils_test.py @@ -3,7 +3,13 @@ import os from unittest.mock import patch, mock_open, MagicMock -from hyperbench.utils import validate_hif_json, compress_to_zst, decompress_zst +from hyperbench.utils import ( + validate_hif_json, + compress_to_zst, + decompress_zst, + get_datasets_shas, + get_dataset_sha, +) from hyperbench.tests import MOCK_BASE_PATH @@ -25,7 +31,7 @@ def test_validate_hif_json_with_url_success(): validate_hif_json(path_valid) mock_get.assert_called_once_with( - "https://raw.githubusercontent.com/HIF-org/HIF-standard/main/schemas/hif_schema.json", + "https://raw.githubusercontent.com/HIF-org/HIF-standard/b691a3d2ec32100c0229ebe1151e9afad015c356/schemas/hif_schema.json", timeout=10, ) @@ -93,3 +99,60 @@ def test_decompress_zst_round_trip_preserves_json_content(tmp_path): decompressed_data = json.load(f) assert decompressed_data == expected_data + + +def test_get_datasets_shas_returns_shas_and_none_on_failure(): + names = ["algebra", "missing-dataset"] + + def dataset_info_side_effect(*, repo_id): + if repo_id.endswith("/missing-dataset"): + raise RuntimeError("not found") + info = MagicMock() + info.sha = "sha-algebra" + return info + + with ( + patch("hyperbench.utils.hif_utils.HfApi") as mock_hf_api, + patch("builtins.print") as mock_print, + ): + mock_hf_api.return_value.dataset_info.side_effect = dataset_info_side_effect + + result = get_datasets_shas(names) + + assert result == { + "algebra": "sha-algebra", + "missing-dataset": None, + } + mock_hf_api.return_value.dataset_info.assert_any_call(repo_id="HypernetworkRG/algebra") + mock_hf_api.return_value.dataset_info.assert_any_call(repo_id="HypernetworkRG/missing-dataset") + assert any( + "missing-dataset: failed to retrieve SHA" in call.args[0] + for call in mock_print.call_args_list + ) + + +def test_get_dataset_sha_returns_sha(): + with patch("hyperbench.utils.hif_utils.HfApi") as mock_hf_api: + info = MagicMock() + info.sha = "sha-amazon" + mock_hf_api.return_value.dataset_info.return_value = info + + result = get_dataset_sha("amazon") + + assert result == "sha-amazon" + mock_hf_api.return_value.dataset_info.assert_called_once_with(repo_id="HypernetworkRG/amazon") + + +def test_get_dataset_sha_returns_none_on_failure(): + with ( + patch("hyperbench.utils.hif_utils.HfApi") as mock_hf_api, + patch("builtins.print") as mock_print, + ): + mock_hf_api.return_value.dataset_info.side_effect = RuntimeError("boom") + + result = get_dataset_sha("amazon") + + assert result is None + mock_hf_api.return_value.dataset_info.assert_called_once_with(repo_id="HypernetworkRG/amazon") + mock_print.assert_called_once() + assert "amazon: failed to retrieve SHA" in mock_print.call_args[0][0] diff --git a/hyperbench/utils/__init__.py b/hyperbench/utils/__init__.py index 5763312..c9c708a 100644 --- a/hyperbench/utils/__init__.py +++ b/hyperbench/utils/__init__.py @@ -5,7 +5,7 @@ to_non_empty_edgeattr, to_0based_ids, ) -from .hif_utils import validate_hif_json +from .hif_utils import validate_hif_json, get_datasets_shas, get_dataset_sha from .nn_utils import 
( INPUT_LAYER, ActivationFn, @@ -38,4 +38,6 @@ "compress_to_zst", "validate_http_url", "write_to_disk", + "get_datasets_shas", + "get_dataset_sha", ] diff --git a/hyperbench/utils/hif_utils.py b/hyperbench/utils/hif_utils.py index c9cd944..f26178f 100644 --- a/hyperbench/utils/hif_utils.py +++ b/hyperbench/utils/hif_utils.py @@ -2,6 +2,10 @@ import json import requests +from huggingface_hub import HfApi + +HIF_SCHEMA_COMMIT_SHA = "b691a3d2ec32100c0229ebe1151e9afad015c356" + def validate_hif_json(filename: str) -> bool: """ @@ -13,7 +17,7 @@ def validate_hif_json(filename: str) -> bool: Returns: ``True`` if the file is valid HIF, ``False`` otherwise. """ - url = "https://raw.githubusercontent.com/HIF-org/HIF-standard/main/schemas/hif_schema.json" + url = f"https://raw.githubusercontent.com/HIF-org/HIF-standard/{HIF_SCHEMA_COMMIT_SHA}/schemas/hif_schema.json" try: schema = requests.get(url, timeout=10).json() except (requests.RequestException, requests.Timeout): @@ -26,3 +30,30 @@ def validate_hif_json(filename: str) -> bool: return True except Exception: return False + + +def get_datasets_shas(names: list[str], namespace: str = "HypernetworkRG") -> dict[str, str | None]: + api = HfApi() + shas: dict[str, str | None] = {} + + for dataset_name in names: + repo_id = f"{namespace}/{dataset_name}" + try: + info = api.dataset_info(repo_id=repo_id) + shas[dataset_name] = info.sha + except Exception as e: + shas[dataset_name] = None + print(f"{dataset_name}: failed to retrieve SHA ({e})") + + return shas + + +def get_dataset_sha(dataset_name: str, namespace: str = "HypernetworkRG") -> str | None: + api = HfApi() + repo_id = f"{namespace}/{dataset_name}" + try: + info = api.dataset_info(repo_id=repo_id) + return info.sha + except Exception as e: + print(f"{dataset_name}: failed to retrieve SHA ({e})") + return None
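
A usage sketch (not part of the patches above) showing how the PATCH 15/15 pieces compose: get_dataset_sha resolves the revision currently published for a dataset repo on the Hub, and that value is what a PreloadedDataset.HF_SHA pin should hold so the load_by_name fallback stays deterministic. Assumes network access to huggingface.co and that the "algebra" repo exists under HypernetworkRG.

    from hyperbench.data.hif import HIFLoader
    from hyperbench.utils import get_dataset_sha

    # Ask the Hub which revision the dataset repo currently points at.
    sha = get_dataset_sha("algebra")

    if sha is not None:
        # With an explicit revision, the Hugging Face fallback fetches the
        # same file even if the repo's main branch moves later.
        hdata = HIFLoader.load_by_name("algebra", hf_sha=sha, save_on_disk=False)
        print(hdata.num_nodes, hdata.num_hyperedges)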