103 changes: 54 additions & 49 deletions README.md
Original file line number Diff line number Diff line change
@@ -7,106 +7,111 @@ Loader package for opening source datasets before validating them with

`mlwp-data-loaders` is responsible for:

1. importing a loader module or script
2. opening one or more source datasets using loader-defined traits
3. optionally checking that the chosen trait profiles are compatible with the loader
4. returning an `xarray.Dataset` that can then be validated with `mlwp-data-specs`
1. Importing a loader module (or Python file).
2. Using the loader to open and normalize source datasets.
3. Extracting the appropriate validation trait profiles (`TIME_PROFILE`, `SPACE_PROFILE`, `UNCERTAINTY_PROFILE`) defined by the loader.
4. Validating the returned `xarray.Dataset` automatically via `mlwp-data-specs`.
5. Returning the `xarray.Dataset` (and optionally the traits dict) for downstream use in machine learning workloads.

The intended split is:

1. `mlwp-data-loaders`: source-specific loading and normalization
2. `mlwp-data-specs`: trait validation
- **`mlwp-data-loaders`**: Source-specific loading and normalization logic.
- **`mlwp-data-specs`**: General trait validation and compliance checks.

## Python API

The `loader` argument is interpreted as:
- A Python file path if it ends with `.py`.
- A Python module path if it contains `.` (e.g. `mlwp_data_loaders.loaders.anemoi.anemoi_inference`).

- a Python file path if it ends with `.py`
- a Python module path if it contains `.`
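The file-vs-module dispatch described above could be sketched like this (a hypothetical `resolve_loader` helper, not the package's actual implementation):

```python
from importlib import import_module, util


def resolve_loader(loader: str):
    """Return a loader module from a `.py` file path or a dotted module path.

    Hypothetical sketch of the dispatch rule described above.
    """
    if loader.endswith(".py"):
        # Treat the string as a path to a Python file and execute it.
        spec = util.spec_from_file_location("user_loader", loader)
        module = util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    if "." in loader:
        # Treat the string as a dotted module path on sys.path.
        return import_module(loader)
    raise ValueError(f"cannot interpret loader reference: {loader!r}")
```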

Load a dataset through a loader module:
You can load a dataset and its trait profiles in a single call:

```python
from mlwp_data_loaders import load_dataset
from mlwp_data_specs import validate_dataset

ds = load_dataset(
# 1. Load the dataset and extract the trait profiles defined by the loader
ds, dataset_traits = load_dataset(
[
"/path/to/anemoi-inference-20260101T00.nc",
"/path/to/anemoi-inference-20260102T00.nc",
],
loader="mlwp_data_loaders.loaders.anemoi.anemoi_inference",
time="forecast",
space="grid",
uncertainty="deterministic",
return_dataset_traits=True,
)

# 2. Get a detailed validation report by passing the extracted traits
report = validate_dataset(
ds,
time="forecast",
space="grid",
uncertainty="deterministic",
time=dataset_traits.get("time_profile"),
space=dataset_traits.get("space_profile"),
uncertainty=dataset_traits.get("uncertainty_profile"),
)

# 3. Print the validation results to the console
report.console_print()
```

If you don't need the traits dictionary, omit `return_dataset_traits` (it defaults to `False`):

```python
ds = load_dataset(
"s3://my-bucket/dataset.zarr",
loader="mlwp_data_loaders.loaders.anemoi.anemoi_datasets",
storage_options={"anon": True},
)
```
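Extra keyword arguments (like `chunks` or `storage_options` above) are only forwarded when the loader's `load_dataset` signature accepts them. A minimal sketch of that filtering via `inspect.signature`, using a hypothetical loader function:

```python
import inspect


def my_loader_func(path, chunks=None):
    """Hypothetical loader accepting only `path` and `chunks`."""
    return {"path": path, "chunks": chunks}


def filter_kwargs(func, kwargs):
    """Keep only the kwargs that func's signature accepts (sketch)."""
    sig = inspect.signature(func)
    # A **kwargs parameter means the function accepts everything.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


requested = {"chunks": "auto", "engine": "h5netcdf"}
result = my_loader_func("data.nc", **filter_kwargs(my_loader_func, requested))
# `engine` is dropped because my_loader_func does not accept it.
```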

## CLI

Use the loader-aware CLI:
Use the loader-aware CLI to load and validate data from the command line:

```bash
uv run mlwp.load_and_validate_dataset \
/path/to/anemoi-inference-20260101T00.nc \
/path/to/anemoi-inference-20260102T00.nc \
--loader mlwp_data_loaders.loaders.anemoi.anemoi_inference \
--time forecast \
--space grid \
--uncertainty deterministic
--loader mlwp_data_loaders.loaders.anemoi.anemoi_inference
```

Using a user-provided loader script:
Using a custom loader script:

```bash
uv run mlwp.load_and_validate_dataset \
/path/to/source-a.nc \
/path/to/source-b.nc \
--loader ./examples/my_loader.py \
--time forecast \
--space grid \
--uncertainty deterministic
--loader ./examples/my_loader.py
```

## Loader module contract
## Loader Module Contract

The loader module may define a subset of the following:
Each loader module must define a `load_dataset` function and may optionally define standard profile variables:

1. Variables defining how each provided path is opened with `xarray.open_dataset`
- `OPEN_KWARGS`: keyword arguments forwarded to `xarray.open_dataset`, including backend selection such as `{"engine": "zarr"}` or `{"engine": "h5netcdf"}`
2. Functions and variables around preprocessing, concatenation, and postprocessing
- `preprocess(ds)`: normalize each opened source dataset before combination
- `CONCAT_DIM`: dimension used when combining multiple inputs; required if more than one path is provided
- `postprocess(ds)`: finalize the combined dataset before validation
3. Variables defining valid trait profiles
- `valid_time_profiles`: allowed `time=` profile values for this loader
- `valid_space_profiles`: allowed `space=` profile values for this loader
- `valid_uncertainty_profiles`: allowed `uncertainty=` profile values for this loader
1. `load_dataset(path: str | list[str], **kwargs) -> xr.Dataset`
- **Required**. Handles opening the path(s), preprocessing, concatenating, and postprocessing, returning a single normalized `xarray.Dataset`.
2. `TIME_PROFILE`: `str`
- Defines the time trait profile for `mlwp-data-specs` validation (e.g. `"forecast"`).
3. `SPACE_PROFILE`: `str`
- Defines the space trait profile (e.g. `"grid"`).
4. `UNCERTAINTY_PROFILE`: `str`
- Defines the uncertainty trait profile (e.g. `"deterministic"`).

Example:
### Example Loader (`my_loader.py`)

```python
import xarray as xr

OPEN_KWARGS = {}


def preprocess(ds: xr.Dataset) -> xr.Dataset | xr.DataArray:
return ds

TIME_PROFILE = "observation"
SPACE_PROFILE = "grid"
UNCERTAINTY_PROFILE = "deterministic"

CONCAT_DIM = "valid_time"
def load_dataset(path: str | list[str], **kwargs) -> xr.Dataset:
if isinstance(path, list):
ds = xr.open_mfdataset(path, combine="by_coords", **kwargs)
else:
ds = xr.open_dataset(path, **kwargs)

# Example post-processing
if "time" in ds.dims:
ds = ds.rename({"time": "valid_time"})

def postprocess(ds: xr.Dataset | xr.DataArray) -> xr.Dataset | xr.DataArray:
return ds
```
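Because the profile variables are optional, a caller can read them defensively with `getattr`, falling back to `None` when a loader omits one. A sketch (`my_loader` here is a stand-in module built inline, not a real import):

```python
import types

# Stand-in for a loader module that defines only two of the three profiles.
my_loader = types.ModuleType("my_loader")
my_loader.TIME_PROFILE = "observation"
my_loader.SPACE_PROFILE = "grid"

# Profile variables the loader omits simply become None.
traits = {
    "time_profile": getattr(my_loader, "TIME_PROFILE", None),
    "space_profile": getattr(my_loader, "SPACE_PROFILE", None),
    "uncertainty_profile": getattr(my_loader, "UNCERTAINTY_PROFILE", None),
}
```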
23 changes: 15 additions & 8 deletions src/mlwp_data_loaders/api.py
@@ -16,8 +16,9 @@ def load_dataset(
*,
loader: str,
storage_options: dict[str, Any] | None = None,
return_dataset_traits: bool = False,
**kwargs: Any,
) -> xr.Dataset:
) -> xr.Dataset | tuple[xr.Dataset, dict[str, Any]]:
"""Load a dataset through a loader module and validate it.

Parameters
@@ -29,18 +30,22 @@
path. A value containing ``.`` is treated as a Python module path.
storage_options : dict[str, Any] | None, optional
Storage options forwarded to the loader's ``load_dataset`` function.
return_dataset_traits : bool, optional
If True, return a tuple containing the dataset and the loader traits.
Defaults to False.
**kwargs
Additional keyword arguments forwarded to the loader's ``load_dataset``
function if its signature accepts them.

Returns
-------
xr.Dataset
Loaded and validated dataset.
xr.Dataset | tuple[xr.Dataset, dict[str, Any]]
Loaded and validated dataset. If ``return_dataset_traits`` is True,
returns a tuple of ``(dataset, dataset_traits)``.
"""
traits = get_dataset_traits_from_loader(loader)
dataset_traits = get_dataset_traits_from_loader(loader)

loader_func = traits["load_dataset"]
loader_func = dataset_traits["load_dataset"]
sig = inspect.signature(loader_func)

loader_kwargs: dict[str, Any] = {}
@@ -65,9 +70,11 @@

validate_dataset(
ds,
time=traits.get("time_profile"),
space=traits.get("space_profile"),
uncertainty=traits.get("uncertainty_profile"),
time=dataset_traits.get("time_profile"),
space=dataset_traits.get("space_profile"),
uncertainty=dataset_traits.get("uncertainty_profile"),
)

if return_dataset_traits:
return ds, dataset_traits
return ds
21 changes: 5 additions & 16 deletions src/mlwp_data_loaders/cli.py
@@ -10,8 +10,6 @@
from mlwp_data_specs.api import validate_dataset

from .api import load_dataset
from .core import get_dataset_traits_from_loader
from .mxalign_api import validate_dataset_with_mxalign


def build_parser() -> argparse.ArgumentParser:
@@ -79,18 +77,16 @@ def main(argv: Sequence[str] | None = None) -> int:
logger.info(f"Using mlwp-data-specs {specs_version}")

# Load the dataset
ds = load_dataset(
ds, dataset_traits = load_dataset( # type: ignore # load_dataset returns a tuple when return_dataset_traits=True
dataset_input,
loader=args.loader,
storage_options=storage_options or None,
return_dataset_traits=True,
)

# Re-run validation to get the report for printing
traits = get_dataset_traits_from_loader(args.loader)

time_profile = traits.get("time_profile")
space_profile = traits.get("space_profile")
uncertainty_profile = traits.get("uncertainty_profile")
time_profile = dataset_traits.get("time_profile")
space_profile = dataset_traits.get("space_profile")
uncertainty_profile = dataset_traits.get("uncertainty_profile")

report = validate_dataset(
ds,
@@ -99,13 +95,6 @@
uncertainty=uncertainty_profile,
)

report += validate_dataset_with_mxalign(
ds,
time=time_profile,
space=space_profile,
uncertainty=uncertainty_profile,
)

report.console_print()
return 1 if report.has_fails() else 0

20 changes: 10 additions & 10 deletions tests/test_anemoi_datasets_integration.py
@@ -5,7 +5,6 @@
from mlwp_data_specs import validate_dataset

from mlwp_data_loaders.api import load_dataset
from mlwp_data_loaders.core import get_dataset_traits_from_loader
from mlwp_data_loaders.mxalign_api import validate_dataset_with_mxalign

# Use small CERRA sample dataset stored on EWC (European Weather Cloud)
@@ -25,30 +24,31 @@ def test_load_dataset_opens_anemoi_store_from_ewc() -> None:
"anon": True,
}

ds = load_dataset(
ds, dataset_traits = load_dataset( # type: ignore # load_dataset returns a tuple when return_dataset_traits=True
DATASET_PATH,
loader=LOADER,
storage_options=storage_options,
chunks=None,
return_dataset_traits=True,
)

traits = get_dataset_traits_from_loader(LOADER)

# Note: mxalign validation is temporarily kept here during early development
# to ensure `mlwp-data-specs` behaves identically. It will eventually be removed.
report_mxalign = validate_dataset_with_mxalign(
ds,
time=traits.get("time_profile"),
space=traits.get("space_profile"),
uncertainty=traits.get("uncertainty_profile"),
time=dataset_traits.get("time_profile"),
space=dataset_traits.get("space_profile"),
uncertainty=dataset_traits.get("uncertainty_profile"),
)
if report_mxalign.has_fails():
report_mxalign.console_print()
assert not report_mxalign.has_fails()

report_specs = validate_dataset(
ds,
time=traits.get("time_profile"),
space=traits.get("space_profile"),
uncertainty=traits.get("uncertainty_profile"),
time=dataset_traits.get("time_profile"),
space=dataset_traits.get("space_profile"),
uncertainty=dataset_traits.get("uncertainty_profile"),
)
if report_specs.has_fails():
report_specs.console_print()
29 changes: 29 additions & 0 deletions tests/test_api.py
@@ -82,11 +82,40 @@ def test_load_dataset_filters_kwargs(tmp_path, monkeypatch: pytest.MonkeyPatch)
engine="h5netcdf", # Should be ignored because strict load_dataset doesn't take 'engine'
storage_options={"anon": True}, # Should be ignored
)
assert isinstance(ds, xr.Dataset)

assert ds.attrs["chunks"] == "auto"
assert "engine" not in ds.attrs


def test_load_dataset_returns_traits(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Check that api.load_dataset returns traits when requested."""
loader_file = tmp_path / "loader_traits.py"
loader_file.write_text(
"def load_dataset(path, **kwargs):\n"
" from xarray import Dataset\n"
" return Dataset()\n"
"TIME_PROFILE = 'forecast'\n",
encoding="utf-8",
)

monkeypatch.setattr(
"mlwp_data_loaders.api.validate_dataset", lambda *args, **kwargs: None
)

res = load_dataset(
"dummy.nc",
loader=str(loader_file),
return_dataset_traits=True,
)
assert isinstance(res, tuple)

ds, dataset_traits = res # type: ignore # load_dataset returns a tuple when return_dataset_traits=True
assert isinstance(ds, xr.Dataset)
assert isinstance(dataset_traits, dict)
assert dataset_traits.get("time_profile") == "forecast"


def test_validate_dataset_with_mxalign_returns_fail_report_for_invalid_dims(
monkeypatch: pytest.MonkeyPatch,
) -> None:
19 changes: 6 additions & 13 deletions tests/test_cli.py
@@ -37,15 +37,14 @@ def test_cli_accepts_multiple_dataset_paths(monkeypatch: MonkeyPatch) -> None:

def _load_dataset(dataset_path, **kwargs):
observed["dataset_path"] = dataset_path
if kwargs.get("return_dataset_traits"):
return _forecast_grid_ds(), {
"time_profile": "forecast",
"space_profile": "grid",
"uncertainty_profile": "deterministic",
}
return _forecast_grid_ds()

def _get_dataset_traits_from_loader(loader):
return {
"time_profile": "forecast",
"space_profile": "grid",
"uncertainty_profile": "deterministic",
}

class _Report:
def __init__(self):
self.fails = False
@@ -63,13 +62,7 @@ def _validate_dataset(ds, **kwargs):
return _Report()

monkeypatch.setattr(cli, "load_dataset", _load_dataset)
monkeypatch.setattr(
cli, "get_dataset_traits_from_loader", _get_dataset_traits_from_loader
)
monkeypatch.setattr(cli, "validate_dataset", _validate_dataset)
monkeypatch.setattr(
cli, "validate_dataset_with_mxalign", lambda *args, **kwargs: _Report()
)

code = cli.main(
[