diff --git a/README.md b/README.md index fbf5436..c8c22fd 100644 --- a/README.md +++ b/README.md @@ -7,15 +7,13 @@ Loader package for opening source datasets before validating them with `mlwp-data-loaders` is responsible for: -1. Importing a loader module (or Python file). -2. Using the loader to open and normalize source datasets. -3. Extracting the appropriate validation trait profiles (`TIME_PROFILE`, `SPACE_PROFILE`, `UNCERTAINTY_PROFILE`) defined by the loader. -4. Validating the returned `xarray.Dataset` automatically via `mlwp-data-specs`. -5. Returning the `xarray.Dataset` (and optionally the trait dict) for further use or machine learning workloads. +1. Using a given loader module ([bundled with mlwp-data-loaders](src/mlwp_data_loaders/loaders/) or [user-provided](#loader-module-contract)) that defines how to load and normalize source files. The loader is expected to [set global attributes on the resulting dataset](#loader-module-contract) to indicate the [dataset's traits](https://github.com/mlwp-tools/mlwp-data-specs). +2. Validating the returned dataset automatically via `mlwp-data-specs`. +3. Returning the `xarray.Dataset` (and optionally a validation report) for further use or machine learning workloads. The intended split is: -- **`mlwp-data-loaders`**: Source-specific loading and normalization logic. -- **`mlwp-data-specs`**: General trait validation and compliance checks. +- **`mlwp-data-loaders`** (this repo): Source-specific loading and normalization logic. +- [**`mlwp-data-specs`**](https://github.com/mlwp-tools/mlwp-data-specs): Trait-based dataset requirement definitions and validation. ## Python API @@ -23,38 +21,30 @@ The `loader` argument is interpreted as: - A Python file path if it ends with `.py`. - A Python module path if it contains `.` (e.g. `mlwp_data_loaders.loaders.anemoi.anemoi_inference`). 
-You can load a dataset and its trait profiles natively: +You can load a dataset and get its validation report natively: ```python -from mlwp_data_loaders import load_dataset +from mlwp_data_loaders import load_and_validate_dataset from mlwp_data_specs import validate_dataset -# 1. Load the dataset and extract the trait profiles defined by the loader -ds, dataset_traits = load_dataset( +# 1. Load the dataset and extract the validation report +ds, validation_report = load_and_validate_dataset( [ "/path/to/anemoi-inference-20260101T00.nc", "/path/to/anemoi-inference-20260102T00.nc", ], loader="mlwp_data_loaders.loaders.anemoi.anemoi_inference", - return_dataset_traits=True, + return_validation_report=True, ) -# 2. Get a detailed validation report by passing the extracted traits -report = validate_dataset( - ds, - time=dataset_traits.get("time_profile"), - space=dataset_traits.get("space_profile"), - uncertainty=dataset_traits.get("uncertainty_profile"), -) - -# 3. Print the validation results to the console -report.console_print() +# 2. Print the validation results to the console +validation_report.console_print() ``` -If you don't need the traits dictionary returned, simply omit `return_dataset_traits` (defaults to `False`): +If you don't need the report returned, simply omit `return_validation_report` (defaults to `False`). The function will raise a `ValueError` if the dataset does not pass the validation. ```python -ds = load_dataset( +ds = load_and_validate_dataset( "s3://my-bucket/dataset.zarr", loader="mlwp_data_loaders.loaders.anemoi.anemoi_datasets", storage_options={"anon": True}, @@ -83,25 +73,20 @@ uv run mlwp.load_and_validate_dataset \ ## Loader Module Contract -Each loader module must define a function and optionally standard profile variables: +Each loader module must define a function and assign the correct trait profile attributes to the dataset: 1. `load_dataset(path: str | list[str], **kwargs) -> xr.Dataset` - **Required**. 
Handles opening the path(s), preprocessing, concatenating, and postprocessing, returning a single normalized `xarray.Dataset`. -2. `TIME_PROFILE`: `str` - - Defines the time trait profile for `mlwp-data-specs` validation (e.g. `"forecast"`). -3. `SPACE_PROFILE`: `str` - - Defines the space trait profile (e.g. `"grid"`). -4. `UNCERTAINTY_PROFILE`: `str` - - Defines the uncertainty trait profile (e.g. `"deterministic"`). +2. Attributes attached to the dataset + - Must set `mlwp_time_trait` (e.g. `"forecast"`). + - Must set `mlwp_space_trait` (e.g. `"grid"`). + - Must set `mlwp_uncertainty_trait` (e.g. `"deterministic"`). ### Example Loader (`my_loader.py`) ```python import xarray as xr - -TIME_PROFILE = "observation" -SPACE_PROFILE = "grid" -UNCERTAINTY_PROFILE = "deterministic" +from mlwp_data_specs.api import SPACE_TRAIT_ATTR, TIME_TRAIT_ATTR, UNCERTAINTY_TRAIT_ATTR def load_dataset(path: str | list[str], **kwargs) -> xr.Dataset: if isinstance(path, list): @@ -113,5 +98,10 @@ def load_dataset(path: str | list[str], **kwargs) -> xr.Dataset: if "time" in ds.dims: ds = ds.rename({"time": "valid_time"}) + # Assign required traits for validation + ds.attrs[TIME_TRAIT_ATTR] = "observation" + ds.attrs[SPACE_TRAIT_ATTR] = "grid" + ds.attrs[UNCERTAINTY_TRAIT_ATTR] = "deterministic" + return ds ``` diff --git a/pyproject.toml b/pyproject.toml index 7a59460..2d93271 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ where = ["src"] testpaths = ["tests"] [tool.uv.sources] -mlwp-data-specs = { git = "https://github.com/mlwp-tools/mlwp-data-specs", rev = "3a7529b" } mxalign = { git = "https://github.com/mlwp-tools/mxalign", rev = "e2232d93275c7508897a7ddb0cce8b508665f24c" } +mlwp-data-specs = { git = "https://github.com/mlwp-tools/mlwp-data-specs", rev = "059f382" } [dependency-groups] dev = [ diff --git a/src/mlwp_data_loaders/__init__.py b/src/mlwp_data_loaders/__init__.py index 54df21d..ba2a087 100644 --- a/src/mlwp_data_loaders/__init__.py +++ 
b/src/mlwp_data_loaders/__init__.py @@ -1,5 +1,5 @@ """Helpers for loading datasets before validating them with mlwp-data-specs.""" -from .api import load_dataset +from .api import load_and_validate_dataset -__all__ = ["load_dataset"] +__all__ = ["load_and_validate_dataset"] diff --git a/src/mlwp_data_loaders/api.py b/src/mlwp_data_loaders/api.py index 4d077c0..995f282 100644 --- a/src/mlwp_data_loaders/api.py +++ b/src/mlwp_data_loaders/api.py @@ -2,23 +2,22 @@ from __future__ import annotations -import inspect from typing import Any import xarray as xr from mlwp_data_specs import validate_dataset +from mlwp_data_specs.specs.reporting import ValidationReport -from .core import get_dataset_traits_from_loader +from .core import get_loader_func -def load_dataset( +def load_and_validate_dataset( dataset_path: str | list[str], *, loader: str, - storage_options: dict[str, Any] | None = None, - return_dataset_traits: bool = False, + return_validation_report: bool = False, **kwargs: Any, -) -> xr.Dataset | tuple[xr.Dataset, dict[str, Any]]: +) -> xr.Dataset | tuple[xr.Dataset, ValidationReport]: """Load a dataset through a loader module and validate it. Parameters @@ -28,53 +27,40 @@ def load_dataset( loader : str Loader module reference. A value ending in ``.py`` is treated as a file path. A value containing ``.`` is treated as a Python module path. - storage_options : dict[str, Any] | None, optional - Storage options forwarded to the loader's ``load_dataset`` function. - return_dataset_traits : bool, optional - If True, return a tuple containing the dataset and the loader traits. + return_validation_report : bool, optional + If True, return a tuple containing the dataset and the validation report. Defaults to False. **kwargs Additional keyword arguments forwarded to the loader's ``load_dataset`` - function if its signature accepts them. + function (e.g., ``storage_options``). 
Returns ------- - xr.Dataset | tuple[xr.Dataset, dict[str, Any]] - Loaded and validated dataset. If `return_dataset_traits` is True, - returns a tuple of (dataset, dataset_traits). + xr.Dataset | tuple[xr.Dataset, ValidationReport] + Loaded and validated dataset. If `return_validation_report` is True, + returns a tuple of (dataset, validation_report). + + Raises + ------ + ValueError + If validation fails and `return_validation_report` is False. """ - dataset_traits = get_dataset_traits_from_loader(loader) + loader_func = get_loader_func(loader) - loader_func = dataset_traits["load_dataset"] - sig = inspect.signature(loader_func) - - loader_kwargs: dict[str, Any] = {} - - # Check if the loader's load_dataset accepts **kwargs - accepts_kwargs = any( - param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values() - ) - - if storage_options is not None: - if accepts_kwargs or "storage_options" in sig.parameters: - loader_kwargs["storage_options"] = storage_options - - for key, value in kwargs.items(): - if accepts_kwargs or key in sig.parameters: - loader_kwargs[key] = value - - ds = loader_func(dataset_path, **loader_kwargs) + ds = loader_func(dataset_path, **kwargs) if not isinstance(ds, xr.Dataset): ds = ds.to_dataset() - validate_dataset( - ds, - time=dataset_traits.get("time_profile"), - space=dataset_traits.get("space_profile"), - uncertainty=dataset_traits.get("uncertainty_profile"), - ) + report = validate_dataset(ds) + + if return_validation_report: + return ds, report + + if report.has_fails(): + # Ideally, we should be able to format the report nicely + raise ValueError( + "Dataset validation failed. Run with return_validation_report=True for details." 
+ ) - if return_dataset_traits: - return ds, dataset_traits return ds diff --git a/src/mlwp_data_loaders/cli.py b/src/mlwp_data_loaders/cli.py index 949f69b..0944818 100644 --- a/src/mlwp_data_loaders/cli.py +++ b/src/mlwp_data_loaders/cli.py @@ -7,9 +7,8 @@ from loguru import logger from mlwp_data_specs import __version__ as specs_version -from mlwp_data_specs.api import validate_dataset -from .api import load_dataset +from .api import load_and_validate_dataset def build_parser() -> argparse.ArgumentParser: @@ -76,23 +75,16 @@ def main(argv: Sequence[str] | None = None) -> int: logger.info(f"Using mlwp-data-specs {specs_version}") + kwargs = {} + if storage_options: + kwargs["storage_options"] = storage_options + # Load the dataset - ds, dataset_traits = load_dataset( # type: ignore # load_dataset returns a tuple when return_dataset_traits=True + ds, report = load_and_validate_dataset( # type: ignore dataset_input, loader=args.loader, - storage_options=storage_options or None, - return_dataset_traits=True, - ) - - time_profile = dataset_traits.get("time_profile") - space_profile = dataset_traits.get("space_profile") - uncertainty_profile = dataset_traits.get("uncertainty_profile") - - report = validate_dataset( - ds, - time=time_profile, - space=space_profile, - uncertainty=uncertainty_profile, + return_validation_report=True, + **kwargs, ) report.console_print() diff --git a/src/mlwp_data_loaders/core.py b/src/mlwp_data_loaders/core.py index 94313c8..726f9a5 100644 --- a/src/mlwp_data_loaders/core.py +++ b/src/mlwp_data_loaders/core.py @@ -5,14 +5,11 @@ import importlib import importlib.util from pathlib import Path -from types import ModuleType -from typing import Any +from typing import Any, Callable -DatasetTraits = dict[str, Any] - -def _load_module(loader: str) -> ModuleType: - """Import a loader module from a Python file or module path. +def get_loader_func(loader: str) -> Callable[..., Any]: + """Get the load_dataset function from a loader module. 
Parameters ---------- @@ -22,13 +19,14 @@ def _load_module(loader: str) -> ModuleType: Returns ------- - ModuleType - Imported module object. + Callable + The load_dataset function. Raises ------ ValueError - If the loader reference cannot be resolved. + If the loader reference cannot be resolved or if the loader module + does not define a 'load_dataset' function. """ if loader.endswith(".py"): path = Path(loader) @@ -37,48 +35,15 @@ def _load_module(loader: str) -> ModuleType: raise ValueError(f"Could not import loader module from file: {loader}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - return module - if "." in loader: - return importlib.import_module(loader) - raise ValueError( - "Loader must be a Python file path ending in .py or a Python module path" - ) - - -def get_dataset_traits_from_loader(loader: str) -> DatasetTraits: - """Import traits from a loader module. - - Parameters - ---------- - loader : str - Loader module reference. - - Returns - ------- - DatasetTraits - Mapping with trait names normalized to lowercase. - - Raises - ------ - ValueError - If the loader module does not define a 'load_dataset' function. - """ - module = _load_module(loader) - traits: DatasetTraits = {} + elif "." in loader: + module = importlib.import_module(loader) + else: + raise ValueError( + "Loader must be a Python file path ending in .py or a Python module path" + ) if not hasattr(module, "load_dataset"): raise ValueError( f"Loader module {loader!r} must define a 'load_dataset' function." 
) - traits["load_dataset"] = module.load_dataset - - supported_traits = ( - "TIME_PROFILE", - "SPACE_PROFILE", - "UNCERTAINTY_PROFILE", - ) - for name in supported_traits: - if hasattr(module, name): - traits[name.lower()] = getattr(module, name) - - return traits + return module.load_dataset diff --git a/src/mlwp_data_loaders/loaders/anemoi/anemoi_datasets.py b/src/mlwp_data_loaders/loaders/anemoi/anemoi_datasets.py index 571b444..cd5e8a5 100644 --- a/src/mlwp_data_loaders/loaders/anemoi/anemoi_datasets.py +++ b/src/mlwp_data_loaders/loaders/anemoi/anemoi_datasets.py @@ -3,10 +3,11 @@ import numpy as np import xarray as xr from loguru import logger - -TIME_PROFILE = "observation" -SPACE_PROFILE = "grid" -UNCERTAINTY_PROFILE = "deterministic" +from mlwp_data_specs.api import ( + SPACE_TRAIT_ATTR, + TIME_TRAIT_ATTR, + UNCERTAINTY_TRAIT_ATTR, +) DROP_VARS = [ "latitude", @@ -78,7 +79,11 @@ def load_dataset( f"to xr.Dataset, this might take some time. Consider selecting the relevant variables during loading" ) - return ds_selected.to_dataset(dim="variable") + ds_final = ds_selected.to_dataset(dim="variable") + ds_final.attrs[TIME_TRAIT_ATTR] = "observation" + ds_final.attrs[SPACE_TRAIT_ATTR] = "grid" + ds_final.attrs[UNCERTAINTY_TRAIT_ATTR] = "deterministic" + return ds_final def _postprocess(dataset: xr.Dataset) -> xr.Dataset: diff --git a/src/mlwp_data_loaders/loaders/anemoi/anemoi_inference.py b/src/mlwp_data_loaders/loaders/anemoi/anemoi_inference.py index a421734..f434938 100644 --- a/src/mlwp_data_loaders/loaders/anemoi/anemoi_inference.py +++ b/src/mlwp_data_loaders/loaders/anemoi/anemoi_inference.py @@ -1,10 +1,11 @@ from typing import Any import xarray as xr - -TIME_PROFILE = "forecast" -SPACE_PROFILE = "grid" -UNCERTAINTY_PROFILE = "deterministic" +from mlwp_data_specs.api import ( + SPACE_TRAIT_ATTR, + TIME_TRAIT_ATTR, + UNCERTAINTY_TRAIT_ATTR, +) def load_dataset( @@ -59,6 +60,10 @@ def load_dataset( .swap_dims({"time": "lead_time"}) ) + 
ds_out.attrs[TIME_TRAIT_ATTR] = "forecast" + ds_out.attrs[SPACE_TRAIT_ATTR] = "grid" + ds_out.attrs[UNCERTAINTY_TRAIT_ATTR] = "deterministic" + return ds_out diff --git a/src/mlwp_data_loaders/loaders/harp/obstable.py b/src/mlwp_data_loaders/loaders/harp/obstable.py index fc1cd76..3b6c3bd 100644 --- a/src/mlwp_data_loaders/loaders/harp/obstable.py +++ b/src/mlwp_data_loaders/loaders/harp/obstable.py @@ -6,10 +6,11 @@ import pandas as pd import xarray as xr - -TIME_PROFILE = "observation" -SPACE_PROFILE = "point" -UNCERTAINTY_PROFILE = "deterministic" +from mlwp_data_specs.api import ( + SPACE_TRAIT_ATTR, + TIME_TRAIT_ATTR, + UNCERTAINTY_TRAIT_ATTR, +) COORDS = { "longitude": "lon", @@ -100,6 +101,12 @@ def load_dataset( {"standard_name": "longitude", "units": "degrees_east"} ) - return ds.rename_dims({"code": "point_index"}).transpose( + ds_final = ds.rename_dims({"code": "point_index"}).transpose( "valid_time", "point_index" ) + + ds_final.attrs[TIME_TRAIT_ATTR] = "observation" + ds_final.attrs[SPACE_TRAIT_ATTR] = "point" + ds_final.attrs[UNCERTAINTY_TRAIT_ATTR] = "deterministic" + + return ds_final diff --git a/tests/test_anemoi_datasets_integration.py b/tests/test_anemoi_datasets_integration.py index 4a33c00..8c7df9c 100644 --- a/tests/test_anemoi_datasets_integration.py +++ b/tests/test_anemoi_datasets_integration.py @@ -2,9 +2,13 @@ from __future__ import annotations -from mlwp_data_specs import validate_dataset +from mlwp_data_specs.api import ( + SPACE_TRAIT_ATTR, + TIME_TRAIT_ATTR, + UNCERTAINTY_TRAIT_ATTR, +) -from mlwp_data_loaders.api import load_dataset +from mlwp_data_loaders.api import load_and_validate_dataset from mlwp_data_loaders.mxalign_api import validate_dataset_with_mxalign # Use small CERRA sample dataset stored on EWC (European Weather Cloud) @@ -24,32 +28,26 @@ def test_load_dataset_opens_anemoi_store_from_ewc() -> None: "anon": True, } - ds, dataset_traits = load_dataset( # type: ignore # load_dataset returns a tuple when 
return_dataset_traits=True + ds, report_specs = load_and_validate_dataset( # type: ignore DATASET_PATH, loader=LOADER, storage_options=storage_options, chunks=None, - return_dataset_traits=True, + return_validation_report=True, ) # Note: mxalign validation is temporarily kept here during early development # to ensure `mlwp-data-specs` behaves identically. It will eventually be removed. report_mxalign = validate_dataset_with_mxalign( ds, - time=dataset_traits.get("time_profile"), - space=dataset_traits.get("space_profile"), - uncertainty=dataset_traits.get("uncertainty_profile"), + time=ds.attrs.get(TIME_TRAIT_ATTR), + space=ds.attrs.get(SPACE_TRAIT_ATTR), + uncertainty=ds.attrs.get(UNCERTAINTY_TRAIT_ATTR), ) if report_mxalign.has_fails(): report_mxalign.console_print() assert not report_mxalign.has_fails() - report_specs = validate_dataset( - ds, - time=dataset_traits.get("time_profile"), - space=dataset_traits.get("space_profile"), - uncertainty=dataset_traits.get("uncertainty_profile"), - ) if report_specs.has_fails(): report_specs.console_print() assert not report_specs.has_fails() diff --git a/tests/test_api.py b/tests/test_api.py index 64eccad..591abce 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,8 +6,8 @@ import xarray as xr import mlwp_data_loaders.mxalign_api as mxalign_api -from mlwp_data_loaders.api import load_dataset -from mlwp_data_loaders.core import get_dataset_traits_from_loader +from mlwp_data_loaders.api import load_and_validate_dataset +from mlwp_data_loaders.core import get_loader_func def _forecast_grid_ds() -> xr.Dataset: @@ -33,87 +33,82 @@ def _forecast_grid_ds() -> xr.Dataset: return ds -def test_get_dataset_traits_from_loader_raises_missing_load_dataset(tmp_path) -> None: +def test_get_loader_func_raises_missing_load_dataset(tmp_path) -> None: """Loader modules must define a 'load_dataset' function.""" loader_file = tmp_path / "loader_missing.py" loader_file.write_text("TIME_PROFILE = 'forecast'\n", encoding="utf-8") with 
pytest.raises(ValueError, match="must define a 'load_dataset' function"): - get_dataset_traits_from_loader(str(loader_file)) + get_loader_func(str(loader_file)) -def test_get_dataset_traits_from_loader_finds_constants(tmp_path) -> None: - """Loader modules can define trait constants which are correctly captured.""" - loader_file = tmp_path / "loader_valid.py" - loader_file.write_text( - "def load_dataset(path, **kwargs): return None\n" - "TIME_PROFILE = 'forecast'\n" - "SPACE_PROFILE = 'grid'\n", - encoding="utf-8", - ) - traits = get_dataset_traits_from_loader(str(loader_file)) - assert "load_dataset" in traits - assert traits["time_profile"] == "forecast" - assert traits["space_profile"] == "grid" - assert "uncertainty_profile" not in traits - - -def test_load_dataset_filters_kwargs(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: - """Check that api.load_dataset filters kwargs based on loader signature.""" +def test_load_dataset_rejects_unsupported_kwargs( + tmp_path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Check that api.load_and_validate_dataset rejects kwargs not supported by the loader.""" loader_file = tmp_path / "loader_strict.py" loader_file.write_text( "def load_dataset(path, chunks=None):\n" " from xarray import Dataset\n" " ds = Dataset()\n" " ds.attrs['chunks'] = chunks\n" - " return ds\n" - "TIME_PROFILE = 'forecast'\n", + " ds.attrs['mlwp_time_trait'] = 'forecast'\n" + " ds.attrs['mlwp_space_trait'] = 'grid'\n" + " ds.attrs['mlwp_uncertainty_trait'] = 'deterministic'\n" + " return ds\n", encoding="utf-8", ) + class MockReport: + def has_fails(self): + return False + # Mock validate_dataset to bypass validation on an empty dataset monkeypatch.setattr( - "mlwp_data_loaders.api.validate_dataset", lambda *args, **kwargs: None + "mlwp_data_loaders.api.validate_dataset", lambda *args, **kwargs: MockReport() ) - ds = load_dataset( - "dummy.nc", - loader=str(loader_file), - chunks="auto", - engine="h5netcdf", # Should be ignored because strict load_dataset 
doesn't take 'engine' - storage_options={"anon": True}, # Should be ignored - ) - assert isinstance(ds, xr.Dataset) + with pytest.raises(TypeError, match="unexpected keyword argument 'engine'"): + load_and_validate_dataset( + "dummy.nc", + loader=str(loader_file), + chunks="auto", + engine="h5netcdf", # Should raise TypeError + ) - assert ds.attrs["chunks"] == "auto" - assert "engine" not in ds.attrs - -def test_load_dataset_returns_traits(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: - """Check that api.load_dataset returns traits when requested.""" +def test_load_dataset_returns_report(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None: + """Check that load_and_validate_dataset returns report when requested.""" loader_file = tmp_path / "loader_traits.py" loader_file.write_text( "def load_dataset(path, **kwargs):\n" " from xarray import Dataset\n" - " return Dataset()\n" - "TIME_PROFILE = 'forecast'\n", + " ds = Dataset()\n" + " ds.attrs['mlwp_time_trait'] = 'forecast'\n" + " ds.attrs['mlwp_space_trait'] = 'grid'\n" + " ds.attrs['mlwp_uncertainty_trait'] = 'deterministic'\n" + " return ds\n", encoding="utf-8", ) + class MockReport: + def has_fails(self): + return False + monkeypatch.setattr( - "mlwp_data_loaders.api.validate_dataset", lambda *args, **kwargs: None + "mlwp_data_loaders.api.validate_dataset", lambda *args, **kwargs: MockReport() ) - res = load_dataset( + res = load_and_validate_dataset( "dummy.nc", loader=str(loader_file), - return_dataset_traits=True, + return_validation_report=True, ) assert isinstance(res, tuple) - ds, dataset_traits = res # type: ignore # load_dataset returns a tuple when return_dataset_traits=True + ds, report = res # type: ignore assert isinstance(ds, xr.Dataset) - assert isinstance(dataset_traits, dict) - assert dataset_traits.get("time_profile") == "forecast" + assert not report.has_fails() + assert ds.attrs.get("mlwp_time_trait") == "forecast" def test_validate_dataset_with_mxalign_returns_fail_report_for_invalid_dims( 
diff --git a/tests/test_cli.py b/tests/test_cli.py index be11051..fc9c729 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -35,16 +35,6 @@ def test_cli_accepts_multiple_dataset_paths(monkeypatch: MonkeyPatch) -> None: """CLI passes multiple dataset paths through to the load/validate API.""" observed: dict[str, object] = {} - def _load_dataset(dataset_path, **kwargs): - observed["dataset_path"] = dataset_path - if kwargs.get("return_dataset_traits"): - return _forecast_grid_ds(), { - "time_profile": "forecast", - "space_profile": "grid", - "uncertainty_profile": "deterministic", - } - return _forecast_grid_ds() - class _Report: def __init__(self): self.fails = False @@ -58,11 +48,13 @@ def has_fails(self): def __iadd__(self, other): return self - def _validate_dataset(ds, **kwargs): - return _Report() + def _load_and_validate_dataset(dataset_path, **kwargs): + observed["dataset_path"] = dataset_path + if kwargs.get("return_validation_report"): + return _forecast_grid_ds(), _Report() + return _forecast_grid_ds() - monkeypatch.setattr(cli, "load_dataset", _load_dataset) - monkeypatch.setattr(cli, "validate_dataset", _validate_dataset) + monkeypatch.setattr(cli, "load_and_validate_dataset", _load_and_validate_dataset) code = cli.main( [ diff --git a/tests/test_harp_obstable_integration.py b/tests/test_harp_obstable_integration.py index 744ad02..f834f12 100644 --- a/tests/test_harp_obstable_integration.py +++ b/tests/test_harp_obstable_integration.py @@ -4,9 +4,13 @@ import pooch import pytest -from mlwp_data_specs import validate_dataset +from mlwp_data_specs.api import ( + SPACE_TRAIT_ATTR, + TIME_TRAIT_ATTR, + UNCERTAINTY_TRAIT_ATTR, +) -from mlwp_data_loaders.api import load_dataset +from mlwp_data_loaders.api import load_and_validate_dataset from mlwp_data_loaders.mxalign_api import validate_dataset_with_mxalign HARP_DATA_URL = "https://raw.githubusercontent.com/harphub/harpData/master/inst/OBSTABLE/OBSTABLE_2019.sqlite" @@ -25,30 +29,24 @@ def 
obstable_path() -> str: def test_load_dataset_opens_harp_obstable(obstable_path: str) -> None: """The harp.obstable loader can open and validate the sample SQLite file.""" - ds, dataset_traits = load_dataset( # type: ignore # load_dataset returns a tuple when return_dataset_traits=True + ds, report_specs = load_and_validate_dataset( # type: ignore obstable_path, loader=LOADER, - return_dataset_traits=True, + return_validation_report=True, ) # Note: mxalign validation is temporarily kept here during early development # to ensure `mlwp-data-specs` behaves identically. It will eventually be removed. report_mxalign = validate_dataset_with_mxalign( ds, - time=dataset_traits.get("time_profile"), - space=dataset_traits.get("space_profile"), - uncertainty=dataset_traits.get("uncertainty_profile"), + time=ds.attrs.get(TIME_TRAIT_ATTR), + space=ds.attrs.get(SPACE_TRAIT_ATTR), + uncertainty=ds.attrs.get(UNCERTAINTY_TRAIT_ATTR), ) if report_mxalign.has_fails(): report_mxalign.console_print() assert not report_mxalign.has_fails() - report_specs = validate_dataset( - ds, - time=dataset_traits.get("time_profile"), - space=dataset_traits.get("space_profile"), - uncertainty=dataset_traits.get("uncertainty_profile"), - ) if report_specs.has_fails(): report_specs.console_print() assert not report_specs.has_fails() diff --git a/uv.lock b/uv.lock index addea55..061da36 100644 --- a/uv.lock +++ b/uv.lock @@ -1450,7 +1450,7 @@ requires-dist = [ { name = "fsspec" }, { name = "h5netcdf" }, { name = "loguru" }, - { name = "mlwp-data-specs", git = "https://github.com/mlwp-tools/mlwp-data-specs?rev=3a7529b" }, + { name = "mlwp-data-specs", git = "https://github.com/mlwp-tools/mlwp-data-specs?rev=059f382" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" }, { name = "s3fs" }, { name = "xarray" }, @@ -1473,7 +1473,7 @@ test = [ [[package]] name = "mlwp-data-specs" version = "0.1.0" -source = { git = 
"https://github.com/mlwp-tools/mlwp-data-specs?rev=3a7529b#3a7529bbaa32ab705b0feb24f0d3928e98f99029" } +source = { git = "https://github.com/mlwp-tools/mlwp-data-specs?rev=059f382#059f3820134e0ebff4deb9180dc5a0da18399289" } dependencies = [ { name = "fsspec" }, { name = "loguru" },