60 changes: 25 additions & 35 deletions README.md
Original file line number Diff line number Diff line change
@@ -7,54 +7,44 @@ Loader package for opening source datasets before validating them with

`mlwp-data-loaders` is responsible for:

1. Importing a loader module (or Python file).
2. Using the loader to open and normalize source datasets.
3. Extracting the appropriate validation trait profiles (`TIME_PROFILE`, `SPACE_PROFILE`, `UNCERTAINTY_PROFILE`) defined by the loader.
4. Validating the returned `xarray.Dataset` automatically via `mlwp-data-specs`.
5. Returning the `xarray.Dataset` (and optionally the trait dict) for further use or machine learning workloads.
1. Using a given loader module ([bundled with mlwp-data-loaders](src/mlwp_data_loaders/loaders/) or [user-provided](#loader-module-contract)) that defines how to load and normalize source files. The loader is expected to [set global attributes on the resulting dataset](#loader-module-contract) to indicate the [dataset's traits](https://github.com/mlwp-tools/mlwp-data-specs).
2. Validating the returned dataset automatically via `mlwp-data-specs`.
3. Returning the `xarray.Dataset` (and optionally a validation report) for further use or machine learning workloads.

The intended split is:
- **`mlwp-data-loaders`**: Source-specific loading and normalization logic.
- **`mlwp-data-specs`**: General trait validation and compliance checks.
- **`mlwp-data-loaders`** (this repo): Source-specific loading and normalization logic.
- [**`mlwp-data-specs`**](https://github.com/mlwp-tools/mlwp-data-specs): Dataset trait requirement definitions and validation.

## Python API

The `loader` argument is interpreted as:
- A Python file path if it ends with `.py`.
- A Python module path if it contains `.` (e.g. `mlwp_data_loaders.loaders.anemoi.anemoi_inference`).
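The two-way resolution rule above can be sketched as a small self-contained helper (an illustration only; `resolve_loader_kind` is not part of the package API):

```python
def resolve_loader_kind(loader: str) -> str:
    """Classify a loader reference the way the rule above describes."""
    if loader.endswith(".py"):
        # Imported from a file, e.g. via importlib.util.spec_from_file_location
        return "file"
    if "." in loader:
        # Imported as a module, e.g. via importlib.import_module
        return "module"
    raise ValueError(
        "Loader must be a Python file path ending in .py or a Python module path"
    )
```

A bare name like `"myloader"` matches neither form and is rejected.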

You can load a dataset and its trait profiles natively:
You can load a dataset and get its validation report natively:

```python
from mlwp_data_loaders import load_dataset
from mlwp_data_loaders import load_and_validate_dataset
from mlwp_data_specs import validate_dataset

# 1. Load the dataset and extract the trait profiles defined by the loader
ds, dataset_traits = load_dataset(
# 1. Load the dataset and extract the validation report
ds, validation_report = load_and_validate_dataset(
[
"/path/to/anemoi-inference-20260101T00.nc",
"/path/to/anemoi-inference-20260102T00.nc",
],
loader="mlwp_data_loaders.loaders.anemoi.anemoi_inference",
return_dataset_traits=True,
return_validation_report=True,
)

# 2. Get a detailed validation report by passing the extracted traits
report = validate_dataset(
ds,
time=dataset_traits.get("time_profile"),
space=dataset_traits.get("space_profile"),
uncertainty=dataset_traits.get("uncertainty_profile"),
)

# 3. Print the validation results to the console
report.console_print()
# 2. Print the validation results to the console
validation_report.console_print()
```

If you don't need the traits dictionary returned, simply omit `return_dataset_traits` (defaults to `False`):
If you don't need the report returned, simply omit `return_validation_report` (defaults to `False`). The function will raise a `ValueError` if the dataset does not pass validation.

```python
ds = load_dataset(
ds = load_and_validate_dataset(
"s3://my-bucket/dataset.zarr",
loader="mlwp_data_loaders.loaders.anemoi.anemoi_datasets",
storage_options={"anon": True},
@@ -83,25 +73,20 @@ uv run mlwp.load_and_validate_dataset \

## Loader Module Contract

Each loader module must define a function and optionally standard profile variables:
Each loader module must define a function and assign the correct trait profile attributes to the dataset:

1. `load_dataset(path: str | list[str], **kwargs) -> xr.Dataset`
- **Required**. Handles opening the path(s), preprocessing, concatenating, and postprocessing, returning a single normalized `xarray.Dataset`.
2. `TIME_PROFILE`: `str`
- Defines the time trait profile for `mlwp-data-specs` validation (e.g. `"forecast"`).
3. `SPACE_PROFILE`: `str`
- Defines the space trait profile (e.g. `"grid"`).
4. `UNCERTAINTY_PROFILE`: `str`
- Defines the uncertainty trait profile (e.g. `"deterministic"`).
2. Attributes attached to the dataset
- Must set `mlwp_time_trait` (e.g. `"forecast"`).
- Must set `mlwp_space_trait` (e.g. `"grid"`).
- Must set `mlwp_uncertainty_trait` (e.g. `"deterministic"`).
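The attribute requirement in point 2 can be illustrated with a minimal check (a sketch only; the real validation is performed by `mlwp-data-specs`, and the literal attribute names are assumed from the examples above):

```python
# Trait attributes every loader is expected to set on the returned dataset
REQUIRED_TRAIT_ATTRS = (
    "mlwp_time_trait",
    "mlwp_space_trait",
    "mlwp_uncertainty_trait",
)


def missing_trait_attrs(attrs: dict) -> list[str]:
    """Return the required trait attributes a loader forgot to set."""
    return [name for name in REQUIRED_TRAIT_ATTRS if name not in attrs]
```

For example, a dataset whose `attrs` only carry `mlwp_time_trait` would be reported as missing the space and uncertainty traits.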

### Example Loader (`my_loader.py`)

```python
import xarray as xr

TIME_PROFILE = "observation"
SPACE_PROFILE = "grid"
UNCERTAINTY_PROFILE = "deterministic"
from mlwp_data_specs.api import SPACE_TRAIT_ATTR, TIME_TRAIT_ATTR, UNCERTAINTY_TRAIT_ATTR

def load_dataset(path: str | list[str], **kwargs) -> xr.Dataset:
if isinstance(path, list):
@@ -113,5 +98,10 @@
if "time" in ds.dims:
ds = ds.rename({"time": "valid_time"})

# Assign required traits for validation
ds.attrs[TIME_TRAIT_ATTR] = "observation"
ds.attrs[SPACE_TRAIT_ATTR] = "grid"
ds.attrs[UNCERTAINTY_TRAIT_ATTR] = "deterministic"

return ds
```
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,8 +40,8 @@ where = ["src"]
testpaths = ["tests"]

[tool.uv.sources]
mlwp-data-specs = { git = "https://github.com/mlwp-tools/mlwp-data-specs", rev = "3a7529b" }
mxalign = { git = "https://github.com/mlwp-tools/mxalign", rev = "e2232d93275c7508897a7ddb0cce8b508665f24c" }
mlwp-data-specs = { git = "https://github.com/mlwp-tools/mlwp-data-specs", rev = "059f382" }

[dependency-groups]
dev = [
4 changes: 2 additions & 2 deletions src/mlwp_data_loaders/__init__.py
@@ -1,5 +1,5 @@
"""Helpers for loading datasets before validating them with mlwp-data-specs."""

from .api import load_dataset
from .api import load_and_validate_dataset

__all__ = ["load_dataset"]
__all__ = ["load_and_validate_dataset"]
70 changes: 28 additions & 42 deletions src/mlwp_data_loaders/api.py
@@ -2,23 +2,22 @@

from __future__ import annotations

import inspect
from typing import Any

import xarray as xr
from mlwp_data_specs import validate_dataset
from mlwp_data_specs.specs.reporting import ValidationReport

from .core import get_dataset_traits_from_loader
from .core import get_loader_func


def load_dataset(
def load_and_validate_dataset(
dataset_path: str | list[str],
*,
loader: str,
storage_options: dict[str, Any] | None = None,
return_dataset_traits: bool = False,
return_validation_report: bool = False,
**kwargs: Any,
) -> xr.Dataset | tuple[xr.Dataset, dict[str, Any]]:
) -> xr.Dataset | tuple[xr.Dataset, ValidationReport]:
"""Load a dataset through a loader module and validate it.

Parameters
@@ -28,53 +27,40 @@ def load_dataset(
loader : str
Loader module reference. A value ending in ``.py`` is treated as a file
path. A value containing ``.`` is treated as a Python module path.
storage_options : dict[str, Any] | None, optional
Storage options forwarded to the loader's ``load_dataset`` function.
return_dataset_traits : bool, optional
If True, return a tuple containing the dataset and the loader traits.
return_validation_report : bool, optional
If True, return a tuple containing the dataset and the validation report.
Defaults to False.
**kwargs
Additional keyword arguments forwarded to the loader's ``load_dataset``
function if its signature accepts them.
function (e.g., ``storage_options``).

Returns
-------
xr.Dataset | tuple[xr.Dataset, dict[str, Any]]
Loaded and validated dataset. If `return_dataset_traits` is True,
returns a tuple of (dataset, dataset_traits).
xr.Dataset | tuple[xr.Dataset, ValidationReport]
Loaded and validated dataset. If `return_validation_report` is True,
returns a tuple of (dataset, validation_report).

Raises
------
ValueError
If validation fails and `return_validation_report` is False.
"""
dataset_traits = get_dataset_traits_from_loader(loader)
loader_func = get_loader_func(loader)

loader_func = dataset_traits["load_dataset"]
sig = inspect.signature(loader_func)

loader_kwargs: dict[str, Any] = {}

# Check if the loader's load_dataset accepts **kwargs
accepts_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()
)

if storage_options is not None:
if accepts_kwargs or "storage_options" in sig.parameters:
loader_kwargs["storage_options"] = storage_options

for key, value in kwargs.items():
if accepts_kwargs or key in sig.parameters:
loader_kwargs[key] = value

ds = loader_func(dataset_path, **loader_kwargs)
ds = loader_func(dataset_path, **kwargs)

if not isinstance(ds, xr.Dataset):
ds = ds.to_dataset()

validate_dataset(
ds,
time=dataset_traits.get("time_profile"),
space=dataset_traits.get("space_profile"),
uncertainty=dataset_traits.get("uncertainty_profile"),
)
report = validate_dataset(ds)

if return_validation_report:
return ds, report

if report.has_fails():
# Ideally, we should be able to format the report nicely
raise ValueError(
"Dataset validation failed. Run with return_validation_report=True for details."
)

if return_dataset_traits:
return ds, dataset_traits
return ds
24 changes: 8 additions & 16 deletions src/mlwp_data_loaders/cli.py
@@ -7,9 +7,8 @@

from loguru import logger
from mlwp_data_specs import __version__ as specs_version
from mlwp_data_specs.api import validate_dataset

from .api import load_dataset
from .api import load_and_validate_dataset


def build_parser() -> argparse.ArgumentParser:
@@ -76,23 +75,16 @@ def main(argv: Sequence[str] | None = None) -> int:

logger.info(f"Using mlwp-data-specs {specs_version}")

kwargs = {}
if storage_options:
kwargs["storage_options"] = storage_options

# Load the dataset
ds, dataset_traits = load_dataset( # type: ignore # load_dataset returns a tuple when return_dataset_traits=True
ds, report = load_and_validate_dataset( # type: ignore
dataset_input,
loader=args.loader,
storage_options=storage_options or None,
return_dataset_traits=True,
)

time_profile = dataset_traits.get("time_profile")
space_profile = dataset_traits.get("space_profile")
uncertainty_profile = dataset_traits.get("uncertainty_profile")

report = validate_dataset(
ds,
time=time_profile,
space=space_profile,
uncertainty=uncertainty_profile,
return_validation_report=True,
**kwargs,
)

report.console_print()
63 changes: 14 additions & 49 deletions src/mlwp_data_loaders/core.py
@@ -5,14 +5,11 @@
import importlib
import importlib.util
from pathlib import Path
from types import ModuleType
from typing import Any
from typing import Any, Callable

DatasetTraits = dict[str, Any]


def _load_module(loader: str) -> ModuleType:
"""Import a loader module from a Python file or module path.
def get_loader_func(loader: str) -> Callable[..., Any]:
"""Get the load_dataset function from a loader module.

Parameters
----------
@@ -22,13 +19,14 @@ def _load_module(loader: str) -> ModuleType:

Returns
-------
ModuleType
Imported module object.
Callable
The load_dataset function.

Raises
------
ValueError
If the loader reference cannot be resolved.
If the loader reference cannot be resolved or if the loader module
does not define a 'load_dataset' function.
"""
if loader.endswith(".py"):
path = Path(loader)
@@ -37,48 +35,15 @@ def load_dataset(
raise ValueError(f"Could not import loader module from file: {loader}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
if "." in loader:
return importlib.import_module(loader)
raise ValueError(
"Loader must be a Python file path ending in .py or a Python module path"
)


def get_dataset_traits_from_loader(loader: str) -> DatasetTraits:
"""Import traits from a loader module.

Parameters
----------
loader : str
Loader module reference.

Returns
-------
DatasetTraits
Mapping with trait names normalized to lowercase.

Raises
------
ValueError
If the loader module does not define a 'load_dataset' function.
"""
module = _load_module(loader)
traits: DatasetTraits = {}
elif "." in loader:
module = importlib.import_module(loader)
else:
raise ValueError(
"Loader must be a Python file path ending in .py or a Python module path"
)

if not hasattr(module, "load_dataset"):
raise ValueError(
f"Loader module {loader!r} must define a 'load_dataset' function."
)
traits["load_dataset"] = module.load_dataset

supported_traits = (
"TIME_PROFILE",
"SPACE_PROFILE",
"UNCERTAINTY_PROFILE",
)
for name in supported_traits:
if hasattr(module, name):
traits[name.lower()] = getattr(module, name)

return traits
return module.load_dataset
15 changes: 10 additions & 5 deletions src/mlwp_data_loaders/loaders/anemoi/anemoi_datasets.py
@@ -3,10 +3,11 @@
import numpy as np
import xarray as xr
from loguru import logger

TIME_PROFILE = "observation"
SPACE_PROFILE = "grid"
UNCERTAINTY_PROFILE = "deterministic"
from mlwp_data_specs.api import (
SPACE_TRAIT_ATTR,
TIME_TRAIT_ATTR,
UNCERTAINTY_TRAIT_ATTR,
)

DROP_VARS = [
"latitude",
@@ -78,7 +79,11 @@ def load_dataset(
f"to xr.Dataset, this might take some time. Consider selecting the relevant variables during loading"
)

return ds_selected.to_dataset(dim="variable")
ds_final = ds_selected.to_dataset(dim="variable")
ds_final.attrs[TIME_TRAIT_ATTR] = "observation"
ds_final.attrs[SPACE_TRAIT_ATTR] = "grid"
ds_final.attrs[UNCERTAINTY_TRAIT_ATTR] = "deterministic"
return ds_final


def _postprocess(dataset: xr.Dataset) -> xr.Dataset: