2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -14,4 +14,4 @@ jobs:
with:
python-version: '3.12'
- name: Run tests
- run: uv run python -m pytest
+ run: uv run --extra test python -m pytest
68 changes: 68 additions & 0 deletions DEVELOPING.md
@@ -0,0 +1,68 @@
# Developing

## Environment

This project uses `uv` for dependency management and for running local commands.

Install dependencies into the project environment with:

```bash
uv sync --extra test --group dev
```

If you only need to run a one-off command, you can also use `uv run ...`
without activating the environment.
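
For example, a quick one-off sanity check that the package imports:

```bash
uv run python -c "import mlwp_data_loaders"
```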

## Running Tests

Run the full test suite with:

```bash
uv run python -m pytest
```

Run a single test file with:

```bash
uv run python -m pytest tests/test_anemoi_datasets_integration.py
```
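
To narrow it further, pytest's `-k` filter selects tests by name (the pattern here is just illustrative):

```bash
uv run python -m pytest tests/test_anemoi_datasets_integration.py -k "certifi"
```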

## Pre-commit

This repository includes a checked-in
[`.pre-commit-config.yaml`](.pre-commit-config.yaml).
CI runs the same hooks via `.github/workflows/pre-commit.yml`.

Install the development dependencies first:

```bash
uv sync --extra test --group dev
```

Then run the hooks locally with:

```bash
pre-commit run --all-files
```

You can also install the git hook so checks run before each commit:

```bash
pre-commit install
```

## CA Certificates

`certifi` is included so that `botocore` and `aiobotocore` use an up-to-date CA
bundle when opening datasets from custom S3 endpoints over HTTPS.

This matters for the ECMWF object-store endpoint used by the
`anemoi-datasets` integration test:

- the endpoint's certificate chain is valid for standard HTTPS clients
- older CA bundles shipped with `botocore` do not include the required
  `HARICA TLS RSA Root CA 2021` root
- when `certifi` is installed, `botocore` uses the `certifi` bundle by default

The minimum supported version is `certifi>=2021.10.8` because that is the first
`certifi` release we verified to include `HARICA TLS RSA Root CA 2021`.
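
A quick way to confirm which bundle `botocore` resolves (a minimal sketch, assuming `botocore.httpsession` keeps its fallback from `certifi` to the bundled CAs, which is how recent botocore releases behave):

```python
# Sketch: print the CA bundle paths certifi and botocore resolve.
# Assumes botocore.httpsession imports where() from certifi when available.
import certifi
import botocore.httpsession

print(certifi.where())               # certifi's CA bundle
print(botocore.httpsession.where())  # should match when certifi is installed
```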
26 changes: 19 additions & 7 deletions pyproject.toml
@@ -9,13 +9,13 @@ description = "Loader package for opening datasets before mlwp-data-specs valida
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"xarray",
"loguru",
"fsspec",
"s3fs",
"zarr>2,<3",
"h5netcdf",
"mlwp-data-specs",
"xarray>=2024.1.0",
"loguru>=0.7.0",
"fsspec>=2024.2.0",
"s3fs>=2024.2.0",
"zarr>=3.1.5",
"h5netcdf>=1.3.0",
]

[tool.isort]
@@ -24,6 +24,7 @@ profile = "black"
[project.optional-dependencies]
test = [
"pytest>=8.0.0",
"certifi>=2021.10.8",
]

[project.scripts]
@@ -39,4 +40,15 @@ where = ["src"]
testpaths = ["tests"]

[tool.uv.sources]
- mlwp-data-specs = { git = "https://github.com/mlwp-tools/mlwp-data-specs", rev = "1aafd6b" }
+ mlwp-data-specs = { git = "https://github.com/mlwp-tools/mlwp-data-specs", rev = "3a7529b" }
+ mxalign = { git = "https://github.com/mlwp-tools/mxalign", rev = "e2232d93275c7508897a7ddb0cce8b508665f24c" }
+
+ [dependency-groups]
+ dev = [
+ "pre-commit>=4.0.0",
+ "ipython>=9.11.0",
+ "pytest>=9.0.2",
+ ]
+ test = [
+ "mxalign",
+ ]
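
Note: `test` now names both an optional-dependency extra and a dependency group. With `uv` these are selected by different flags, e.g.:

```bash
# Sketch: install the "test" extra (pytest, certifi) plus the "test"
# (mxalign) and "dev" dependency groups in one sync.
uv sync --extra test --group test --group dev
```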
40 changes: 35 additions & 5 deletions src/mlwp_data_loaders/cli.py
@@ -11,10 +11,17 @@

from .api import load_dataset
from .core import get_dataset_traits_from_loader
+ from .mxalign_api import validate_dataset_with_mxalign


def build_parser() -> argparse.ArgumentParser:
"""Build the CLI argument parser."""
"""Build the CLI argument parser.

Returns
-------
argparse.ArgumentParser
The configured argument parser.
"""
parser = argparse.ArgumentParser(
description=(
"Load a dataset through a loader module and validate it with "
@@ -44,7 +51,18 @@ def build_parser() -> argparse.ArgumentParser:

@logger.catch
def main(argv: Sequence[str] | None = None) -> int:
"""Run dataset loading and validation from CLI arguments."""
"""Run dataset loading and validation from CLI arguments.

Parameters
----------
argv : Sequence[str] | None, optional
Command line arguments. Defaults to None, which uses sys.argv[1:].

Returns
-------
int
Exit code: 0 for success, 1 for validation failures.
"""
parser = build_parser()
args = parser.parse_args(argv)

@@ -69,11 +87,23 @@ def main(argv: Sequence[str] | None = None) -> int:

# Re-run validation to get the report for printing
traits = get_dataset_traits_from_loader(args.loader)

+ time_profile = traits.get("time_profile")
+ space_profile = traits.get("space_profile")
+ uncertainty_profile = traits.get("uncertainty_profile")
+
report = validate_dataset(
ds,
- time=traits.get("time_profile"),
- space=traits.get("space_profile"),
- uncertainty=traits.get("uncertainty_profile"),
+ time=time_profile,
+ space=space_profile,
+ uncertainty=uncertainty_profile,
)

+ report += validate_dataset_with_mxalign(
+ ds,
+ time=time_profile,
+ space=space_profile,
+ uncertainty=uncertainty_profile,
+ )

report.console_print()
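
Since `main()` takes an explicit `argv`, the entry point can also be exercised in-process; the loader argument below is hypothetical because the parser's options are elided above:

```python
# Sketch: drive the CLI programmatically. "my_pkg.my_loader" is a
# hypothetical loader module path; main(None) falls back to sys.argv[1:].
from mlwp_data_loaders.cli import main

exit_code = main(["my_pkg.my_loader"])
raise SystemExit(exit_code)
```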
47 changes: 39 additions & 8 deletions src/mlwp_data_loaders/loaders/anemoi/anemoi_datasets.py
@@ -1,3 +1,5 @@
+ from typing import Any
+
import numpy as np
import xarray as xr
from loguru import logger
@@ -24,20 +26,31 @@
COORDS = dict(longitude="longitudes", latitude="latitudes", valid_time="dates")


- def load_dataset(path, chunks="auto", consolidated=False, variables=None):
+ def load_dataset(
+ path: str,
+ chunks: str | dict | None = "auto",
+ consolidated: bool = False,
+ variables: str | list[str] | None = None,
+ storage_options: dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> xr.Dataset:
"""
Load Anemoi datasets from Zarr files.

Parameters
----------
path : str
Path to the Zarr dataset.
- chunks : str or dict, default: "auto"
+ chunks : str or dict or None, default: "auto"
Chunk size or strategy for dask arrays.
consolidated : bool, default: False
Whether to use consolidated metadata when opening the Zarr store.
variables : str or list of str, optional
List of variables to select from the dataset. If None, all variables are kept.
+ storage_options : dict of str to Any, optional
+ Storage options passed to xarray.open_zarr (e.g. for S3 access).
+ **kwargs
+ Additional keyword arguments passed to xarray.open_zarr.

Returns
-------
@@ -46,7 +59,13 @@ def load_dataset(path, chunks="auto", consolidated=False, variables=None):
"""
variables = [variables] if isinstance(variables, str) else variables

- ds = xr.open_zarr(path, consolidated=consolidated, chunks=chunks)  # type: ignore
+ ds = xr.open_zarr(
+ path,
+ consolidated=consolidated,
+ chunks=chunks,
+ storage_options=storage_options,
+ **kwargs,
+ )  # type: ignore
ds_postproc = _postprocess(ds)

if variables:
@@ -65,12 +84,15 @@ def load_dataset(path, chunks="auto", consolidated=False, variables=None):
def _postprocess(dataset: xr.Dataset) -> xr.Dataset:
"""Post-process the dataset to add coordinates and drop unused variables.

- Args:
- dataset (xr.Dataset): The input dataset to be processed.
+ Parameters
+ ----------
+ dataset : xr.Dataset
+ The input dataset to be processed.

- Returns:
- xr.Dataset: The processed dataset with assigned coordinates and
- attributes.
+ Returns
+ -------
+ xr.Dataset
+ The processed dataset with assigned coordinates and attributes.
"""

# Add coordinates
@@ -97,4 +119,13 @@ def _postprocess(dataset: xr.Dataset) -> xr.Dataset:
.swap_dims({"time": "valid_time"})
.rename({"cell": "grid_index"})
)

ds_pruned.coords["valid_time"].attrs["standard_name"] = "time"
ds_pruned.coords["latitude"].attrs.update(
{"standard_name": "latitude", "units": "degrees_north"}
)
ds_pruned.coords["longitude"].attrs.update(
{"standard_name": "longitude", "units": "degrees_east"}
)

return ds_pruned # type: ignore
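
With `storage_options` now forwarded to `xarray.open_zarr`, an S3-backed store can be opened roughly like this (bucket, endpoint, and variable names are hypothetical):

```python
# Sketch: open an Anemoi Zarr store on a custom S3 endpoint.
from mlwp_data_loaders.loaders.anemoi.anemoi_datasets import load_dataset

ds = load_dataset(
    "s3://my-bucket/anemoi-dataset.zarr",  # hypothetical path
    variables=["2t", "10u"],               # hypothetical variable names
    storage_options={
        "anon": True,  # assumption: public read access
        "client_kwargs": {"endpoint_url": "https://object-store.example.com"},
    },
)
print(ds)
```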
26 changes: 23 additions & 3 deletions src/mlwp_data_loaders/loaders/anemoi/anemoi_inference.py
@@ -1,19 +1,27 @@
+ from typing import Any
+
import xarray as xr

TIME_PROFILE = "forecast"
SPACE_PROFILE = "grid"
UNCERTAINTY_PROFILE = "deterministic"


- def load_dataset(paths, chunks="auto", engine="h5netcdf", parallel=True, **kwargs):
+ def load_dataset(
+ paths: str | list[str],
+ chunks: str | dict | None = "auto",
+ engine: str = "h5netcdf",
+ parallel: bool = True,
+ **kwargs: Any,
+ ) -> xr.Dataset:
"""
Load Anemoi inference datasets from NetCDF/HDF5 files.

Parameters
----------
paths : str or list of str
Path or list of paths to the dataset files.
- chunks : str or dict, default: "auto"
+ chunks : str or dict or None, default: "auto"
Chunk size or strategy for dask arrays.
engine : str, default: "h5netcdf"
Engine to use for reading the files.
@@ -54,7 +62,19 @@ def load_dataset(paths, chunks="auto", engine="h5netcdf", parallel=True, **kwarg
return ds_out


- def _preprocess(ds):
+ def _preprocess(ds: xr.Dataset) -> xr.Dataset:
+ """Preprocess individual datasets before concatenation.
+
+ Parameters
+ ----------
+ ds : xr.Dataset
+ The input dataset to preprocess.
+
+ Returns
+ -------
+ xr.Dataset
+ The preprocessed dataset with reference time expanded.
+ """
ds_out = (
ds.set_coords(["longitude", "latitude"])
.expand_dims("reference_time")
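
And a usage sketch for the inference loader (file names hypothetical); per `_preprocess`, each input gains a `reference_time` dimension before the files are combined:

```python
# Sketch: load two per-run forecast files written by anemoi-inference.
from mlwp_data_loaders.loaders.anemoi.anemoi_inference import load_dataset

ds = load_dataset(["forecast_run1.nc", "forecast_run2.nc"])
print(ds.sizes)  # expect a "reference_time" dimension after combining
```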