From 5c52e09ebde74159539381cbe56d616d1977db43 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 30 Apr 2026 09:30:58 +0100 Subject: [PATCH 1/3] Add modern RHIME runners alongside fixedbasisMCMC - Adds `run_rhime(...)` for standard single-sector RHIME inversions. - Adds `run_rhime_multisector(...)` for shared-basis multi-sector RHIME inversions. - Adds `openghg-inversions run-rhime ...` and `openghg-inversions run-rhime-multisector ...` CLI entry points for config-file driven runs. - Adds lightweight RHIME result/spec dataclasses and a RHIME config template. - Adds direct modern `InversionOutput` construction for the standard RHIME path. - Adds sector-aware diagnostic output for multi-sector runs. The new RHIME runners reuse the existing data preparation and component-based PyMC model pieces, but do not route public modern behavior through `fixedbasisMCMC` or the legacy `inferpymc` output adapter. --- CHANGELOG.md | 1 + README.md | 31 + openghg_inversions/__init__.py | 1 + openghg_inversions/cli.py | 84 ++ .../config/templates/rhime_template.ini | 78 ++ openghg_inversions/inversion_data/get_data.py | 12 +- openghg_inversions/inversion_inputs.py | 9 +- openghg_inversions/models/rhime.py | 304 +++++ .../postprocessing/inversion_output.py | 7 +- .../postprocessing/make_outputs.py | 13 +- openghg_inversions/rhime.py | 1126 +++++++++++++++++ pyproject.toml | 4 + tests/test_get_data.py | 37 +- tests/test_inversion_inputs.py | 12 + tests/test_rhime.py | 425 +++++++ 15 files changed, 2120 insertions(+), 24 deletions(-) create mode 100644 openghg_inversions/cli.py create mode 100644 openghg_inversions/config/templates/rhime_template.ini create mode 100644 openghg_inversions/models/rhime.py create mode 100644 openghg_inversions/rhime.py create mode 100644 tests/test_rhime.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bdbf87e9..db55ab93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ## Code changes +- Added modern `run_rhime` and shared-basis 
`run_rhime_multisector` pipelines, RHIME CLI entry points, RHIME config template, modern result/spec objects, and focused tests for the new public runners. [#398](https://github.com/openghg/openghg_inversions/issues/398) - Made concat-gather handling of mismatched site data variables order-independent, added an opt-in drop policy used by `make_inv_inputs`, and added lightweight regression tests for issue #394. [#394](https://github.com/openghg/openghg_inversions/issues/394) - Fix bug which was assigninig the wrong times to inversion flux outputs in non-standard cases, such as 3-monthly inversions. [#PR 387](https://github.com/openghg/openghg_inversions/pull/387) - Fix small bug where postprocessing was failing if country codes in file didn't match exactly those in `paris_regions_dict`. [#PR 377](https://github.com/openghg/openghg_inversions/pull/377) diff --git a/README.md b/README.md index b341f3dc..8b1917fa 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,37 @@ Solutions to this are: For an overview of OpenGHG inversions, see this [primer](docs/getting_started.md). +### Modern RHIME entry points + +New RHIME runs can be launched without calling an internal source file path: + +```python +from openghg_inversions.rhime import run_rhime, run_rhime_multisector + +result = run_rhime( + species="ch4", + sites=["TAC"], + averaging_period=["1h"], + domain="EUROPE", + start_date="2019-01-01", + end_date="2019-01-02", + output_path="outputs", + output_name="example", + flux_sources=["total-ukghg-edgar7"], +) +``` + +For SLURM batch scripts and installed environments, use the console entry point: + +```bash +openghg-inversions run-rhime 2019-01-01 2019-01-02 -c rhime.ini --output-path outputs +openghg-inversions run-rhime-multisector 2019-01-01 2019-01-02 -c rhime_multisector.ini +``` + +The new RHIME config template is available at +`openghg_inversions/config/templates/rhime_template.ini`. 
New configs should use +`flux_sources`; legacy `emissions_name` is accepted when `flux_sources` is absent. + ### Passing parameters to the inversion Keyword arguments are propagated as follows: diff --git a/openghg_inversions/__init__.py b/openghg_inversions/__init__.py index e69de29b..1b1de400 100644 --- a/openghg_inversions/__init__.py +++ b/openghg_inversions/__init__.py @@ -0,0 +1 @@ +"""OpenGHG inversions.""" diff --git a/openghg_inversions/cli.py b/openghg_inversions/cli.py new file mode 100644 index 00000000..6ce756c9 --- /dev/null +++ b/openghg_inversions/cli.py @@ -0,0 +1,84 @@ +"""Command line interface for OpenGHG inversions.""" + +from __future__ import annotations + +import argparse +import json +from typing import Any + + +def _add_run_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("start", help="Start date string of the format YYYY-MM-DD", nargs="?") + parser.add_argument("end", help="End date string of the format YYYY-MM-DD", nargs="?") + parser.add_argument("-c", "--config", help="Name including path of configuration file", required=True) + parser.add_argument( + "--kwargs", + type=json.loads, + help="Pass keyword arguments to the RHIME function, e.g. 
'{\"nit\": 10}'.", + ) + parser.add_argument("--output-path", help="Path to write results to.") + + +def _command_kwargs(args: argparse.Namespace) -> dict[str, Any]: + """Create keyword overrides from parsed CLI arguments.""" + kwargs: dict[str, Any] = {} + if args.start: + kwargs["start_date"] = args.start + if args.end: + kwargs["end_date"] = args.end + if args.output_path: + kwargs["output_path"] = args.output_path + if args.kwargs: + kwargs.update(args.kwargs) + return kwargs + + +def _run_rhime_command(args: argparse.Namespace) -> None: + """Run the standard RHIME command with lazy imports for fast help output.""" + from openghg_inversions.rhime import run_rhime + + run_rhime(config_file=args.config, **_command_kwargs(args)) + + +def _run_rhime_multisector_command(args: argparse.Namespace) -> None: + """Run the multi-sector RHIME command with lazy imports for fast help output.""" + from openghg_inversions.rhime import run_rhime_multisector + + run_rhime_multisector(config_file=args.config, **_command_kwargs(args)) + + +def build_parser() -> argparse.ArgumentParser: + """Build the OpenGHG inversions CLI argument parser. + + Returns: + Configured argument parser. + """ + parser = argparse.ArgumentParser(prog="openghg-inversions", description="OpenGHG inversions CLI") + subparsers = parser.add_subparsers(dest="command", required=True) + + run_parser = subparsers.add_parser("run-rhime", help="Run a standard RHIME inversion") + _add_run_args(run_parser) + run_parser.set_defaults(func=_run_rhime_command) + + run_multi_parser = subparsers.add_parser( + "run-rhime-multisector", help="Run a shared-basis multi-sector RHIME inversion" + ) + _add_run_args(run_multi_parser) + run_multi_parser.set_defaults(func=_run_rhime_multisector_command) + + return parser + + +def main(argv: list[str] | None = None) -> None: + """Run the OpenGHG inversions CLI. + + Args: + argv: Optional argument vector. Defaults to ``sys.argv`` when omitted. 
+ """ + parser = build_parser() + args = parser.parse_args(argv) + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/openghg_inversions/config/templates/rhime_template.ini b/openghg_inversions/config/templates/rhime_template.ini new file mode 100644 index 00000000..df5b8601 --- /dev/null +++ b/openghg_inversions/config/templates/rhime_template.ini @@ -0,0 +1,78 @@ +; ======================================================================================= +; OpenGHG Inversions RHIME configuration file +; ======================================================================================= +; New RHIME configs prefer flux_sources. Legacy emissions_name is still accepted when +; flux_sources is absent. + +[INPUT.MEASUREMENTS] +species = "" +sites = [] +averaging_period = [] +start_date = " " +end_date = " " +inlet = None +instrument = None +calibration_scale = None +obs_data_level = None +filters = [] + +[INPUT.STORES] +bc_store = "user" +obs_store = "user" +footprint_store = "user" +emissions_store = "user" + +[INPUT.PRIORS] +domain = " " +met_model = None +fp_model = None +fp_height = None +fp_species = None +flux_sources = [] +bc_input = None + +[INPUT.BASIS_CASE] +basis_algorithm = "weighted" +bc_basis_case = "NESW" +fp_basis_case = None +nbasis = 100 +basis_directory = None +bc_basis_directory = None +country_file = None +country_directory = None + +[RHIME.PDF] +x_prior = {"pdf": "lognormal", "mean": 1.0, "stdev": 1.0} +sector_priors = None +bc_prior = {"pdf": "truncatednormal", "mu": 1.0, "sigma": 0.05, "lower": 0.0} +sigma_prior = {"pdf": "uniform", "lower": 0.1, "upper": 3.0} +add_offset = False +offset_prior = {"pdf": "normal", "mu": 0, "sigma": 1} + +[RHIME.SPLIT] +bc_freq = None +sigma_freq = None +sigma_per_site = True + +[RHIME.ITERATIONS] +nit = 1000 +burn = 0 +tune = 1000 +nchain = 4 + +[RHIME.OPTIONS] +averaging_error = True +min_error = 0.0 +fix_basis_outer_regions = False +use_bc = True +nuts_sampler = "pymc" +save_trace = False 
+save_inversion_output = True +pollution_events_from_obs = False +no_model_error = False +sampler_kwargs = {} + +[RHIME.OUTPUT] +output_path = " " +output_name = "rhime" +output_format = "inv_out" diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index e92cdbc9..642ac46b 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -61,19 +61,21 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr # TODO: do we want to fill missing values in repeatability or variability? for site in sites: ds = fp_all[site] + mf_long_name = ds.mf.attrs.get("long_name", "") + mf_units = ds.mf.attrs.get("units", None) variability_missing = False if "mf_variability" not in ds: ds["mf_variability"] = xr.zeros_like(ds.mf) - ds["mf_variability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_variability" variability_missing = True + ds["mf_variability"].attrs["long_name"] = mf_long_name + "_variability" + ds["mf_variability"].attrs["units"] = mf_units if "mf_repeatability" not in ds: if variability_missing: raise ValueError(f"Obs data for site {site} is missing both repeatability and variability.") ds["mf_repeatability"] = xr.zeros_like(ds.mf_variability) - ds["mf_repeatability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_repeatability" ds["mf_error"] = ds["mf_variability"] @@ -90,8 +92,10 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr else: ds["mf_error"] = ds["mf_repeatability"] - ds["mf_error"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_error" - ds["mf_error"].attrs["units"] = ds.mf.attrs.get("units", None) + ds["mf_repeatability"].attrs["long_name"] = mf_long_name + "_repeatability" + ds["mf_repeatability"].attrs["units"] = mf_units + ds["mf_error"].attrs["long_name"] = mf_long_name + "_error" + ds["mf_error"].attrs["units"] = mf_units # warnings/info for debugging err0 
= (ds["mf_error"] == 0) | ( diff --git a/openghg_inversions/inversion_inputs.py b/openghg_inversions/inversion_inputs.py index d459823a..44f8f75d 100644 --- a/openghg_inversions/inversion_inputs.py +++ b/openghg_inversions/inversion_inputs.py @@ -1,6 +1,7 @@ """Functions for creating the inputs needed by PyMC.""" import datetime as dt +import numbers from typing import Any, Iterable, Literal import numpy as np @@ -116,12 +117,14 @@ def make_sigma_freq( def add_min_error( ds: xr.Dataset, fp_data: dict[str, Any], - min_error: str | dict[str, float] | float = 0.0, + min_error: str | dict[str, float] | int | float = 0.0, min_error_per_site: bool = True, ) -> xr.Dataset: """Add min_error to combined Dataset.""" min_error_data: xr.DataArray | float | np.ndarray - if isinstance(min_error, float) or (isinstance(min_error, np.ndarray) and min_error.ndim == 0): + if isinstance(min_error, numbers.Real) and not isinstance(min_error, bool): + min_error_data = float(min_error) * xr.ones_like(ds.mf) + elif isinstance(min_error, np.ndarray) and min_error.ndim == 0: min_error_data = min_error * xr.ones_like(ds.mf) elif isinstance(min_error, dict): sites = [k for k in fp_data if not k.startswith(".")] @@ -274,7 +277,7 @@ def make_inv_inputs( sites: list[str] | None = None, bc_freq: Literal["monthly"] | str | None = None, sigma_freq: Literal["monthly"] | str | None = None, - min_error: str | dict[str, float] | float = 0.0, + min_error: str | dict[str, float] | int | float = 0.0, min_error_per_site: bool = True, start_date: DatetimeLike | None = None, ) -> xr.Dataset: diff --git a/openghg_inversions/models/rhime.py b/openghg_inversions/models/rhime.py new file mode 100644 index 00000000..03168021 --- /dev/null +++ b/openghg_inversions/models/rhime.py @@ -0,0 +1,304 @@ +"""RHIME model builders. + +These builders are the modern public model-construction names. They reuse the +component-based PyMC helpers, while keeping the legacy ``inferpymc`` adapter out +of the RHIME runtime path. 
+""" + +from __future__ import annotations + +import re +from collections.abc import Mapping, Sequence + +import pymc as pm +import pytensor.tensor as pt +import xarray as xr + +from openghg_inversions.models.components import ( + add_inferpymc_likelihood_component, + add_linear_component, + add_offset_component, +) +from openghg_inversions.models.coords import CoordRegistry, attach_coord_registry +from openghg_inversions.models.priors import PriorArgs + +DEFAULT_X_PRIOR: PriorArgs = {"pdf": "lognormal", "mean": 1.0, "stdev": 1.0, "reparameterise": True} +DEFAULT_BC_PRIOR: PriorArgs = {"pdf": "truncatednormal", "mu": 1.0, "sigma": 0.05, "lower": 0.0} +DEFAULT_SIGMA_PRIOR: PriorArgs = {"pdf": "uniform", "lower": 0.1, "upper": 3.0} +DEFAULT_OFFSET_PRIOR: PriorArgs = {"pdf": "normal", "mu": 0, "sigma": 1} + + +def safe_pymc_name(value: str) -> str: + """Return a stable PyMC-safe suffix for a user-facing sector/source name. + + Args: + value: User-facing sector or source name. + + Returns: + Lowercase snake-case suffix safe to use in PyMC variable names. 
+ """ + name = re.sub(r"\W+", "_", str(value).strip().lower()).strip("_") + return name or "sector" + + +def _prepare_builder_priors( + *, + x_prior: dict | None, + bc_prior: dict | None, + sigma_prior: dict | None, + offset_prior: dict | None, +) -> tuple[dict, dict, dict, dict]: + """Copy builder priors, applying RHIME model defaults when omitted.""" + prepared_x_prior = DEFAULT_X_PRIOR.copy() if x_prior is None else x_prior.copy() + prepared_bc_prior = DEFAULT_BC_PRIOR.copy() if bc_prior is None else bc_prior.copy() + prepared_sigma_prior = DEFAULT_SIGMA_PRIOR.copy() if sigma_prior is None else sigma_prior.copy() + prepared_offset_prior = DEFAULT_OFFSET_PRIOR.copy() if offset_prior is None else offset_prior.copy() + return prepared_x_prior, prepared_bc_prior, prepared_sigma_prior, prepared_offset_prior + + +def build_rhime_model( + inv_inputs: xr.Dataset, + *, + x_prior: dict | None = None, + bc_prior: dict | None = None, + sigma_prior: dict | None = None, + sigma_per_site: bool = True, + offset_prior: dict | None = None, + add_offset: bool = False, + use_bc: bool = True, + pollution_events_from_obs: bool = False, + no_model_error: bool = False, + offset_args: dict | None = None, + power: dict | float = 1.99, +) -> pm.Model: + """Build the standard single-sector RHIME model. + + Args: + inv_inputs: Canonical inversion-input dataset produced by + ``make_inv_inputs``. + x_prior: Prior specification for flux scaling factors. + bc_prior: Prior specification for boundary-condition scaling factors. + sigma_prior: Prior specification for model-error terms. + sigma_per_site: Whether model-error terms vary by site. + offset_prior: Prior specification for optional offsets. + add_offset: Whether to include an offset term. + use_bc: Whether to include boundary-condition terms. + pollution_events_from_obs: Whether to derive pollution-event scaling + from observations rather than modelled concentrations. + no_model_error: Whether to suppress the explicit model-error term. 
+ offset_args: Extra keyword arguments forwarded to the offset component. + power: Exponent or prior specification used in likelihood error scaling. + + Returns: + Built PyMC model. + """ + x_prior, bc_prior, sigma_prior, offset_prior = _prepare_builder_priors( + x_prior=x_prior, + bc_prior=bc_prior, + sigma_prior=sigma_prior, + offset_prior=offset_prior, + ) + + with pm.Model() as model: + attach_coord_registry(model, CoordRegistry()) + flux_component = add_linear_component( + inv_inputs["H"], + data_name="hx", + prior_args=x_prior, + var_name="x", + output_name="mu", + output_dim="nmeasure", + compute_deterministic=True, + ) + + mu_bc = None + if use_bc: + if "H_bc" not in inv_inputs: + raise ValueError("If `use_bc` is True, `inv_inputs` must contain `H_bc`.") + bc_component = add_linear_component( + inv_inputs["H_bc"], + data_name="hbc", + prior_args=bc_prior, + var_name="bc", + output_name="mu_bc", + output_dim="nmeasure", + compute_deterministic=True, + ) + mu_bc = bc_component.output + + offset = None + if add_offset: + offset_args = offset_args or {} + offset = add_offset_component( + inv_inputs["site_indicator"], + prior_args=offset_prior, + output_name="offset", + output_dim="nmeasure", + **offset_args, + ) + + add_inferpymc_likelihood_component( + inv_inputs, + mu=flux_component.output, + mu_bc=mu_bc, + offset=offset, + sigprior=sigma_prior, + power=power, + pollution_events_from_obs=pollution_events_from_obs, + no_model_error=no_model_error, + sigma_per_site=sigma_per_site, + output_dim="nmeasure", + ) + + return model + + +def _resolve_sectors(inv_inputs: xr.Dataset, sectors: Sequence[str] | None) -> list[str]: + """Resolve requested sector names against the source coordinate.""" + if "source" not in inv_inputs["H"].dims: + raise ValueError("Multi-sector RHIME requires inv_inputs['H'] to include a 'source' dimension.") + + available = [str(value) for value in inv_inputs["H"].coords["source"].values] + if sectors is None: + sectors = available + sectors 
= [str(sector) for sector in sectors] + + missing = [sector for sector in sectors if sector not in available] + if missing: + raise ValueError(f"Sector(s) {missing!r} are not present in inv_inputs['H'].source.") + if len(sectors) < 2: + raise ValueError("Multi-sector RHIME requires at least two sectors.") + + return sectors + + +def _sector_prior( + sector: str, + *, + sector_priors: Mapping[str, dict] | None, + x_prior: dict | None, +) -> dict: + """Resolve the prior for a sector, falling back to the shared x prior.""" + if sector_priors is not None and sector in sector_priors: + return dict(sector_priors[sector]) + return dict(DEFAULT_X_PRIOR if x_prior is None else x_prior) + + +def build_rhime_multisector_model( + inv_inputs: xr.Dataset, + *, + sectors: Sequence[str] | None = None, + sector_priors: Mapping[str, dict] | None = None, + x_prior: dict | None = None, + bc_prior: dict | None = None, + sigma_prior: dict | None = None, + sigma_per_site: bool = True, + offset_prior: dict | None = None, + add_offset: bool = False, + use_bc: bool = True, + pollution_events_from_obs: bool = False, + no_model_error: bool = False, + offset_args: dict | None = None, + power: dict | float = 1.99, +) -> pm.Model: + """Build the first shared-basis multi-sector RHIME model. + + Each sector receives its own state vector ``x_<suffix>`` and forward-model + contribution ``mu_<suffix>``. The total ``mu`` is the sum of sector + contributions and is passed to the standard RHIME likelihood. + + Args: + inv_inputs: Canonical inversion-input dataset with + ``H(region, nmeasure, source)``. + sectors: Ordered sector/source names to optimise. Defaults to all + ``inv_inputs.H.source`` values. + sector_priors: Optional per-sector flux-scaling priors. + x_prior: Shared fallback flux-scaling prior. + bc_prior: Prior specification for boundary-condition scaling factors. + sigma_prior: Prior specification for model-error terms. + sigma_per_site: Whether model-error terms vary by site. 
+ offset_prior: Prior specification for optional offsets. + add_offset: Whether to include an offset term. + use_bc: Whether to include boundary-condition terms. + pollution_events_from_obs: Whether to derive pollution-event scaling + from observations rather than modelled concentrations. + no_model_error: Whether to suppress explicit model-error terms. + offset_args: Extra keyword arguments forwarded to the offset component. + power: Exponent or prior specification used in likelihood error scaling. + + Returns: + Built PyMC model. + """ + sectors = _resolve_sectors(inv_inputs, sectors) + bc_prior = dict(DEFAULT_BC_PRIOR if bc_prior is None else bc_prior) + sigma_prior = dict(DEFAULT_SIGMA_PRIOR if sigma_prior is None else sigma_prior) + offset_prior = dict(DEFAULT_OFFSET_PRIOR if offset_prior is None else offset_prior) + + with pm.Model() as model: + attach_coord_registry(model, CoordRegistry()) + + sector_outputs = [] + used_names: set[str] = set() + for sector in sectors: + suffix = safe_pymc_name(sector) + if suffix in used_names: + raise ValueError( + "Sector names must be unique after PyMC name sanitisation; " + f"duplicate sanitized name {suffix!r}." 
+ ) + used_names.add(suffix) + + h_sector = inv_inputs["H"].sel(source=sector).drop_vars("source", errors="ignore") + component = add_linear_component( + h_sector, + data_name=f"hx_{suffix}", + prior_args=_sector_prior(sector, sector_priors=sector_priors, x_prior=x_prior), + var_name=f"x_{suffix}", + output_name=f"mu_{suffix}", + output_dim="nmeasure", + compute_deterministic=True, + ) + sector_outputs.append(component.output) + + total_mu = pm.Deterministic("mu", pt.stack(sector_outputs, axis=0).sum(axis=0), dims="nmeasure") + + mu_bc = None + if use_bc: + if "H_bc" not in inv_inputs: + raise ValueError("If `use_bc` is True, `inv_inputs` must contain `H_bc`.") + bc_component = add_linear_component( + inv_inputs["H_bc"], + data_name="hbc", + prior_args=bc_prior, + var_name="bc", + output_name="mu_bc", + output_dim="nmeasure", + compute_deterministic=True, + ) + mu_bc = bc_component.output + + offset = None + if add_offset: + offset_args = offset_args or {} + offset = add_offset_component( + inv_inputs["site_indicator"], + prior_args=offset_prior, + output_name="offset", + output_dim="nmeasure", + **offset_args, + ) + + add_inferpymc_likelihood_component( + inv_inputs, + mu=total_mu, + mu_bc=mu_bc, + offset=offset, + sigprior=sigma_prior, + power=power, + pollution_events_from_obs=pollution_events_from_obs, + no_model_error=no_model_error, + sigma_per_site=sigma_per_site, + output_dim="nmeasure", + ) + + return model diff --git a/openghg_inversions/postprocessing/inversion_output.py b/openghg_inversions/postprocessing/inversion_output.py index 42893270..04df8682 100644 --- a/openghg_inversions/postprocessing/inversion_output.py +++ b/openghg_inversions/postprocessing/inversion_output.py @@ -579,7 +579,12 @@ def from_datatree(cls: type[Self], dt: xr.DataTree) -> Self: "species": dt.attrs.get("species"), "domain": dt.attrs.get("domain"), } - basis = get_xr_dummies(dt.basis.basis, cat_dim="nx", categories=dt.trace.posterior.nx) + basis_dim = "nx" if "nx" in 
dt.trace.posterior.coords else "region" + basis = get_xr_dummies( + dt.basis.basis, + cat_dim=basis_dim, + categories=dt.trace.posterior[basis_dim], + ) trace = az.InferenceData(**{group: val.to_dataset() for group, val in dt.trace.items()}) return cls( **obs_and_errs, diff --git a/openghg_inversions/postprocessing/make_outputs.py b/openghg_inversions/postprocessing/make_outputs.py index fe19dc51..659a24b2 100644 --- a/openghg_inversions/postprocessing/make_outputs.py +++ b/openghg_inversions/postprocessing/make_outputs.py @@ -50,7 +50,7 @@ def make_flux_outputs( if stats is not None: stats_args["stats"] = stats - stats_args["chunk_dim"] = "nx" + stats_args["chunk_dim"] = "nx" if "nx" in trace.dims else "region" stats_ds = calculate_stats(trace, **stats_args) if report_flux_on_inversion_grid: @@ -192,7 +192,10 @@ def make_concentration_outputs( trace[dv] = trace[dv] + trace[offset_dv] # update long name, creating if not present - trace[dv].attrs["long_name"] = trace[dv].attrs.get("long_name", str(dv).split("_")[-1] + "_baseline") + "_including_offset" + trace[dv].attrs["long_name"] = ( + trace[dv].attrs.get("long_name", str(dv).split("_")[-1] + "_baseline") + + "_including_offset" + ) if stats_args is None: stats_args = {} @@ -214,7 +217,7 @@ def make_country_outputs( country_regions: str | Path | dict[str, list[str]] | Literal["paris"] | None = None, stats: list[str] | None = None, stats_args: dict | None = None, - country_code: Literal["alpha2", "alpha3"] | None = "alpha3" + country_code: Literal["alpha2", "alpha3"] | None = "alpha3", ) -> xr.Dataset: """Calculate country emission stats. @@ -276,7 +279,7 @@ def basic_output( country_file: str | Path | None = None, country_regions: str | Path | dict[str, list[str]] | Literal["paris"] | None = None, stats: list[str] | None = None, - stats_args: dict | None = None + stats_args: dict | None = None, ) -> xr.Dataset: """Create basic output with concentrations, flux totals, and country totals. 
@@ -305,7 +308,7 @@ def basic_output( country_file=country_file, country_regions=country_regions, stats=stats, - stats_args=stats_args + stats_args=stats_args, ) model_data = inv_out.get_model_data(var_names=["hx", "hbc", "min_error"]).rename( diff --git a/openghg_inversions/rhime.py b/openghg_inversions/rhime.py new file mode 100644 index 00000000..9ed3efdb --- /dev/null +++ b/openghg_inversions/rhime.py @@ -0,0 +1,1126 @@ +"""Modern public RHIME run functions.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +import inspect +from pathlib import Path +from typing import Any, Literal +import time +import warnings + +import arviz as az +import numpy as np +import pymc as pm +import xarray as xr +from openghg.util import split_function_inputs + +from openghg_inversions.array_ops import get_xr_dummies, sparse_xr_dot +from openghg_inversions.basis import basis_functions_wrapper +from openghg_inversions.config import config +from openghg_inversions.filters import filtering +from openghg_inversions.inversion_data import data_processing_surface_notracer, load_merged_data +from openghg_inversions.inversion_inputs import make_inv_inputs +from openghg_inversions.models.rhime import ( + DEFAULT_X_PRIOR, + build_rhime_model, + build_rhime_multisector_model, + safe_pymc_name, +) +from openghg_inversions.postprocessing.inversion_output import InversionOutput +from openghg_inversions.utils import ncdf_encoding + +OutputFormat = Literal["none", "inv_out", "basic", "paris"] + + +@dataclass(frozen=True) +class SectorSpec: + """Configuration for one separately optimised flux sector. + + Args: + name: User-facing sector name. + flux_source: OpenGHG flux ``source`` used to retrieve this sector. + x_prior: Prior specification for this sector's flux scaling factors. + variable_suffix: PyMC-safe suffix used in model variable names. 
+ """ + + name: str + flux_source: str + x_prior: dict[str, Any] + variable_suffix: str + + +@dataclass(frozen=True) +class RhimeModelSpec: + """Model options used to build a RHIME PyMC model. + + Args: + species: Primary species name. + domain: Model domain name. + sectors: Flux sectors included in the model. + use_bc: Whether boundary-condition scaling is included. + sigma_per_site: Whether model-error terms vary by site. + add_offset: Whether model-data offsets are included. + pollution_events_from_obs: Whether model error scales with observed + enhancements instead of modelled enhancements. + no_model_error: Whether explicit model-error terms are disabled. + power: Exponent or prior specification used in likelihood error scaling. + """ + + species: str + domain: str + sectors: tuple[SectorSpec, ...] + use_bc: bool = True + sigma_per_site: bool = True + add_offset: bool = False + pollution_events_from_obs: bool = False + no_model_error: bool = False + power: dict[str, Any] | float = 1.99 + + +@dataclass(frozen=True) +class RhimeOutputSpec: + """Output settings for a RHIME run. + + Args: + output_format: Output mode. ``"inv_out"`` saves/returns the modern + inversion output, ``"basic"`` and ``"paris"`` additionally create + derived outputs, and ``"none"`` skips output products. + output_path: Directory for saved outputs. + output_name: Base output name. + save_trace: Trace save setting. If true, save to ``output_path`` using + the default trace file name; if a path, save there. + save_inversion_output: Inversion-output save setting. Defaults to true + for CLI-friendly behaviour. + country_file: Optional country mask file used by derived outputs. + paris_postprocessing_kwargs: Extra keyword arguments for PARIS output + creation. 
+ """ + + output_format: OutputFormat = "inv_out" + output_path: str | None = None + output_name: str = "rhime" + save_trace: str | Path | bool = False + save_inversion_output: str | Path | bool = True + country_file: str | None = None + paris_postprocessing_kwargs: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class RhimeRunSpec: + """Top-level run metadata for a RHIME run. + + Args: + start_date: Inclusive inversion start date. + end_date: Exclusive inversion end date. + sites: Sites included after data preparation and filtering. + averaging_period: Observation averaging period per retained site. + model: Mathematical model specification. + output: Output settings. + split_by_sectors: Whether flux data were prepared in sector-resolved + mode. + """ + + start_date: str + end_date: str + sites: tuple[str, ...] + averaging_period: tuple[str | None, ...] + model: RhimeModelSpec + output: RhimeOutputSpec + split_by_sectors: bool = False + + +@dataclass +class RhimeResult: + """Modern RHIME run result. + + Args: + run_spec: Top-level run metadata. + model_spec: Model specification used to build the PyMC model. + output_spec: Output settings used by the run. + inv_inputs: Canonical xarray inversion inputs consumed by the model. + idata: ArviZ ``InferenceData`` returned by sampling. + output_metadata: Paths and notes for generated outputs. + outputs: In-memory derived outputs keyed by output kind. + model: Built PyMC model. + inv_out: Modern inversion output object when created. 
+ """ + + run_spec: RhimeRunSpec + model_spec: RhimeModelSpec + output_spec: RhimeOutputSpec + inv_inputs: xr.Dataset + idata: az.InferenceData + output_metadata: dict[str, Any] = field(default_factory=dict) + outputs: dict[str, Any] = field(default_factory=dict) + model: pm.Model | None = None + inv_out: InversionOutput | None = None + + +@dataclass +class _PreparedRhimeData: + """Prepared data needed after data gathering, basis application, and filtering.""" + + inv_inputs: xr.Dataset + basis: xr.DataArray + flux: xr.DataArray + sites: list[str] + averaging_period: list[str | None] + + +def _as_list(value: str | Sequence[str] | None) -> list[str] | None: + """Convert a scalar/list-like value to a list of strings.""" + if value is None: + return None + if isinstance(value, str): + return [value] + return [str(item) for item in value] + + +def resolve_flux_sources( + *, + flux_sources: str | Sequence[str] | None = None, + emissions_name: str | Sequence[str] | None = None, +) -> list[str]: + """Resolve new ``flux_sources`` and legacy ``emissions_name`` arguments. + + Args: + flux_sources: Preferred OpenGHG flux source names. + emissions_name: Legacy name for flux sources. + + Returns: + Resolved flux source names. + + Raises: + ValueError: If no usable flux source is supplied. + """ + resolved = _as_list(flux_sources) + if resolved is None: + resolved = _as_list(emissions_name) + if not resolved or any(source in {"", "None", "none"} for source in resolved): + raise ValueError("At least one flux source must be supplied via `flux_sources`.") + return resolved + + +def params_from_config( + config_file: str | Path, + *, + start_date: str | None = None, + end_date: str | None = None, + output_path: str | None = None, + extra_kwargs: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """Load RHIME run parameters from an INI config file. + + Args: + config_file: Path to an INI configuration file. + start_date: Optional command-line start-date override. 
+ end_date: Optional command-line end-date override. + output_path: Optional command-line output-path override. + extra_kwargs: Optional keyword overrides, normally parsed from CLI JSON. + + Returns: + Normalized RHIME run parameters using snake-case public names. + + Raises: + ValueError: If deprecated unsupported parameters are present. + """ + params = dict(config.all_param(str(config_file), exclude_not_found=True, allow_new=True)) + if start_date is not None: + params["start_date"] = start_date + if end_date is not None: + params["end_date"] = end_date + if output_path is not None: + params["output_path"] = output_path + if extra_kwargs: + params.update(extra_kwargs) + return _normalise_params(params) + + +def _normalise_params(params: Mapping[str, Any]) -> dict[str, Any]: + """Normalize legacy config spellings to modern snake-case names.""" + normalized = dict(params) + aliases = { + "outputpath": "output_path", + "outputname": "output_name", + "xprior": "x_prior", + "bcprior": "bc_prior", + "sigprior": "sigma_prior", + "offsetprior": "offset_prior", + "emissions_name": "flux_sources", + } + for old, new in aliases.items(): + if old not in normalized: + continue + if new in normalized: + warnings.warn( + f"Ignoring deprecated RHIME parameter {old!r} because {new!r} was also supplied.", + UserWarning, + stacklevel=2, + ) + else: + warnings.warn( + f"RHIME parameter {old!r} is deprecated; use {new!r} instead.", + UserWarning, + stacklevel=2, + ) + normalized[new] = normalized[old] + del normalized[old] + + if "calculate_min_error" in normalized: + raise ValueError("`calculate_min_error` is not supported by RHIME runners; use `min_error`.") + if "reparameterise_log_normal" in normalized: + raise ValueError( + "`reparameterise_log_normal` is not supported by RHIME runners; " + "set `reparameterise` in the relevant prior dictionary if needed." 
+ ) + if "mcmc_type" in normalized: + raise ValueError("`mcmc_type` is not supported by RHIME runners; use `nuts_sampler` if needed.") + + return normalized + + +def _required_run_params() -> set[str]: + return { + "species", + "sites", + "averaging_period", + "domain", + "start_date", + "end_date", + "output_path", + "output_name", + } + + +def _validate_required_params(params: Mapping[str, Any]) -> None: + """Raise if normalized run parameters are missing required values.""" + missing = [ + name for name in sorted(_required_run_params()) if name not in params or params[name] in (None, " ") + ] + if missing: + raise ValueError(f"Required RHIME parameter(s) missing: {missing!r}") + + +def _validate_supported_params(params: Mapping[str, Any]) -> None: + """Raise if normalized run parameters contain unsupported keys.""" + data_params = set(inspect.signature(_prepare_data).parameters) + runner_params = { + "x_prior", + "bc_prior", + "sigma_prior", + "offset_prior", + "sector_priors", + "pollution_events_from_obs", + "no_model_error", + "power", + "nit", + "burn", + "tune", + "nchain", + "nuts_sampler", + "verbose", + "sampler_kwargs", + "output_format", + "save_trace", + "save_inversion_output", + "paris_postprocessing_kwargs", + "offset_args", + "country_file", + "add_offset", + "sigma_per_site", + } + required = _required_run_params() + supported = data_params | runner_params | required + unsupported = sorted(set(params) - supported) + if unsupported: + raise ValueError(f"Unsupported RHIME parameter(s): {unsupported!r}") + + +def _validate_output_format(output_format: str) -> None: + """Raise if a RHIME output format is not supported by the modern runners.""" + valid_formats = {"none", "inv_out", "basic", "paris"} + if output_format not in valid_formats: + raise ValueError( + f"Unsupported RHIME output_format {output_format!r}; expected one of {sorted(valid_formats)!r}." 
+ ) + + +def _resolve_output_path( + save_setting: str | Path | bool, output_path: str | None, filename: str +) -> Path | None: + """Resolve an optional output path from a bool/path save setting.""" + if not save_setting: + return None + if isinstance(save_setting, str | Path): + return Path(save_setting) + if output_path is None: + raise ValueError("An output path is required when saving RHIME artifacts.") + return Path(output_path) / filename + + +def _define_output_filename( + output_path: str | Path, + species: str, + domain: str, + output_name: str, + start_date: str, + *, + ext: str = ".nc", +) -> Path: + """Create the RHIME output filename used for derived NetCDF products.""" + return Path(output_path) / f"{output_name}_{species}_{domain}_{start_date}{ext}" + + +def _prepare_data( + *, + species: str, + sites: list[str], + domain: str, + averaging_period: list[str | None], + start_date: str, + end_date: str, + output_name: str, + flux_sources: list[str], + split_by_sectors: bool, + bc_store: str = "user", + obs_store: str = "user", + footprint_store: str = "user", + emissions_store: str = "user", + met_model: list[str | None] | str | None = None, + fp_model: str | None = None, + fp_height: list[str | None] | str | None = None, + fp_species: str | None = None, + inlet: list[str | None] | str | None = None, + instrument: list[str | None] | str | None = None, + max_level: int | None = None, + calibration_scale: str | None = None, + obs_data_level: list[str | None] | str | None = None, + platform: list[str | None] | str | None = None, + use_tracer: bool = False, + use_bc: bool = True, + fp_basis_case: str | None = None, + basis_directory: str | None = None, + bc_basis_case: str = "NESW", + bc_basis_directory: str | None = None, + country_directory: str | None = None, + bc_input: str | None = None, + basis_algorithm: str = "weighted", + nbasis: int = 100, + filters: None | list | dict[str, list[str] | None] = None, + fix_basis_outer_regions: bool = False, + 
averaging_error: bool = True, + bc_freq: str | None = None, + sigma_freq: str | None = None, + reload_merged_data: bool = False, + save_merged_data: bool = False, + merged_data_dir: str | None = None, + merged_data_name: str | None = None, + basis_output_path: str | None = None, + min_error: Literal["percentile", "residual"] | None | int | float = 0.0, + min_error_options: dict | None = None, +) -> _PreparedRhimeData: + """Gather data, apply basis functions, filter observations, and make canonical inputs. + + This is the only RHIME runner stage that should know about the legacy + ``fp_all``/``fp_data`` containers. It returns explicit modern objects needed + downstream. + """ + if use_tracer: + raise ValueError("RHIME public runners do not support tracer inversions in issue #398.") + + rerun_merge = True + if reload_merged_data and merged_data_dir is not None: + try: + fp_all = load_merged_data(merged_data_dir, species, start_date, output_name, merged_data_name) + except ValueError as exc: + print(f"{exc}, re-running data merge.") + else: + fp_all[".split_by_sectors"] = split_by_sectors + rerun_merge = False + elif reload_merged_data: + print("Cannot reload merged data without `merged_data_dir`; re-running data merge.") + + if rerun_merge: + ( + fp_all, + sites, + inlet, + fp_height, + instrument, + averaging_period, + ) = data_processing_surface_notracer( + species=species, + sites=sites, + domain=domain, + averaging_period=averaging_period, + start_date=start_date, + end_date=end_date, + obs_data_level=obs_data_level, + platform=platform, + met_model=met_model, + fp_model=fp_model, + fp_height=fp_height, + fp_species=fp_species, + emissions_name=flux_sources, + inlet=inlet, + instrument=instrument, + max_level=max_level, + calibration_scale=calibration_scale, + use_bc=use_bc, + bc_input=bc_input, + bc_store=bc_store, + obs_store=obs_store, + footprint_store=footprint_store, + emissions_store=emissions_store, + split_by_sectors=split_by_sectors, + 
averagingerror=averaging_error, + save_merged_data=save_merged_data, + merged_data_name=merged_data_name, + merged_data_dir=merged_data_dir, + output_name=output_name, + ) + + fp_data, basis_objects = basis_functions_wrapper( + basis_algorithm=basis_algorithm, + nbasis=nbasis, + fp_basis_case=fp_basis_case, + bc_basis_case=bc_basis_case, + basis_directory=basis_directory, + bc_basis_directory=bc_basis_directory, + country_directory=country_directory, + fp_all=fp_all, + use_bc=use_bc, + species=species, + domain=domain, + start_date=start_date, + fix_outer_regions=fix_basis_outer_regions, + emissions_name=flux_sources, + outputname=output_name, + output_path=basis_output_path, + return_basis_objects=True, + ) + + if filters is not None: + try: + fp_data = filtering(fp_data, filters) + except ValueError: + for site in sites: + fp_data[site] = fp_data[site].compute() + fp_data = filtering(fp_data, filters) + + dropped_sites = [] + for site in sites: + if fp_data[site].time.values.shape[0] == 0: + dropped_sites.append(site) + del fp_data[site] + if dropped_sites: + sites = [site for site in sites if site not in dropped_sites] + print(f"\nDropping {dropped_sites} sites as no data passed the filtering.\n") + + for site in sites: + fp_data[site].attrs["Domain"] = domain + + min_error_options = min_error_options or {} + if isinstance(min_error, int): + min_error = float(min_error) + inv_inputs = make_inv_inputs( + fp_data, + sites=sites, + bc_freq=bc_freq, + sigma_freq=sigma_freq, + min_error=min_error, + min_error_per_site=min_error_options.get("by_site", False), + start_date=start_date, + ) + + if np.isnan(inv_inputs.H.values).any(): + warnings.warn(f"H matrix contains {np.isnan(inv_inputs.H.values).flatten().sum()} NaN values") + if use_bc and "H_bc" in inv_inputs and np.isnan(inv_inputs.H_bc.values).any(): + warnings.warn(f"H_bc matrix contains {np.isnan(inv_inputs.H_bc.values).flatten().sum()} NaN values") + + basis = get_xr_dummies(fp_data[".basis"], 
cat_dim="region", categories=inv_inputs.region) + flux = basis_objects["emissions"].flux + + return _PreparedRhimeData( + inv_inputs=inv_inputs, + basis=basis, + flux=flux, + sites=sites, + averaging_period=averaging_period, + ) + + +def _make_model_spec( + *, + species: str, + domain: str, + flux_sources: list[str], + x_prior: dict | None, + sector_priors: Mapping[str, dict] | None, + use_bc: bool, + sigma_per_site: bool, + add_offset: bool, + pollution_events_from_obs: bool, + no_model_error: bool, + power: dict | float, +) -> RhimeModelSpec: + """Create a lightweight model spec from normalized run parameters.""" + default_x_prior = DEFAULT_X_PRIOR.copy() if x_prior is None else x_prior.copy() + sectors = [] + for source in flux_sources: + prior = ( + sector_priors[source] + if sector_priors is not None and source in sector_priors + else default_x_prior + ) + sectors.append( + SectorSpec( + name=source, + flux_source=source, + x_prior=dict(prior), + variable_suffix=safe_pymc_name(source), + ) + ) + return RhimeModelSpec( + species=species, + domain=domain, + sectors=tuple(sectors), + use_bc=use_bc, + sigma_per_site=sigma_per_site, + add_offset=add_offset, + pollution_events_from_obs=pollution_events_from_obs, + no_model_error=no_model_error, + power=power, + ) + + +def _sample_model( + model: pm.Model, + *, + nit: int, + burn: int, + tune: int, + nchain: int, + nuts_sampler: str, + verbose: bool, + sampler_kwargs: dict | None, +) -> az.InferenceData: + """Sample a built RHIME model and return InferenceData.""" + sampler_kwargs = dict(sampler_kwargs or {}) + sampler_kwargs.setdefault("progressbar", verbose) + sampler_kwargs.setdefault("cores", nchain) + return _sample( + model, + draws=int(nit), + burn=int(burn), + tune=int(tune), + chains=int(nchain), + sample_prior_predictive=True, + sample_posterior_predictive=["y"], + nuts_sampler=nuts_sampler, + **sampler_kwargs, + ) + + +def _extend_inferencedata_predictive( + trace: az.InferenceData, + *, + model: pm.Model, 
+ sample_prior_predictive: bool | int = False, + sample_posterior_predictive: bool | list[str] = False, +) -> az.InferenceData: + """Extend an InferenceData object with requested predictive groups.""" + if sample_prior_predictive: + prior_draws = ( + trace.posterior.sizes["draw"] if sample_prior_predictive is True else int(sample_prior_predictive) + ) + with model: + trace.extend(pm.sample_prior_predictive(prior_draws, model)) + + if sample_posterior_predictive: + posterior_var_names = ( + None if sample_posterior_predictive is True else list(sample_posterior_predictive) + ) + with model: + trace.extend(pm.sample_posterior_predictive(trace, model=model, var_names=posterior_var_names)) + + return trace + + +def _sample( + model: pm.Model, + *, + draws: int = 1000, + tune: int = 1000, + chains: int = 4, + burn: int = 0, + sample_prior_predictive: bool | int = False, + sample_posterior_predictive: bool | list[str] = False, + **kwargs: Any, +) -> az.InferenceData: + """Sample from a built RHIME model and apply burn slicing/predictive requests.""" + sample_kwargs = dict(kwargs) + sample_kwargs.pop("return_inferencedata", None) + idata_kwargs = dict(sample_kwargs.pop("idata_kwargs", {})) + idata_kwargs["log_likelihood"] = True + + with model: + raw_trace = pm.sample( + draws=draws, + tune=tune, + chains=chains, + return_inferencedata=True, + idata_kwargs=idata_kwargs, + **sample_kwargs, + ) + + burned_trace = raw_trace.isel(draw=slice(burn, None)) + return _extend_inferencedata_predictive( + burned_trace, + model=model, + sample_prior_predictive=sample_prior_predictive, + sample_posterior_predictive=sample_posterior_predictive, + ) + + +def _make_inversion_output( + *, + prepared: _PreparedRhimeData, + idata: az.InferenceData, + start_date: str, + end_date: str, + species: str, + domain: str, +) -> InversionOutput: + """Create an InversionOutput directly from RHIME inputs and InferenceData. + + This is a transitional direct constructor for the modern RHIME path. 
It is + deliberately not routed through the fixed-basis/inferpymc legacy adapter. + This should be refactored when issue #401 defines the modern + ``InversionOutput`` contract. + """ + inv_inputs = prepared.inv_inputs + nmeasure = np.arange(inv_inputs.sizes["nmeasure"]) + site_names = ( + inv_inputs["site_names"] if "site_names" in inv_inputs else xr.DataArray(prepared.sites, dims="nsite") + ) + + obs_prior_factor = inv_inputs["mf_prior_factor"] if "mf_prior_factor" in inv_inputs else None + obs_prior_upper_level_factor = ( + inv_inputs["mf_prior_upper_level_factor"] if "mf_prior_upper_level_factor" in inv_inputs else None + ) + + def nmeasure_array(name: str, source: xr.DataArray) -> xr.DataArray: + """Create a clean nmeasure DataArray without inherited MultiIndex coords.""" + result = xr.DataArray( + source.values, + dims=["nmeasure"], + coords={"nmeasure": nmeasure}, + name=name, + ) + result.attrs = source.attrs + return result + + return InversionOutput( + obs=nmeasure_array("Yobs", inv_inputs["mf"]), + obs_err=nmeasure_array("Yerror", inv_inputs["mf_error"]), + obs_repeatability=nmeasure_array("Yerror_repeatability", inv_inputs["mf_repeatability"]), + obs_variability=nmeasure_array("Yerror_variability", inv_inputs["mf_variability"]), + obs_prior_factor=( + nmeasure_array("Yobs_prior_factor", obs_prior_factor) if obs_prior_factor is not None else None + ), + obs_prior_upper_level_factor=( + nmeasure_array("Yobs_prior_upper_level_factor", obs_prior_upper_level_factor) + if obs_prior_upper_level_factor is not None + else None + ), + site_indicators=nmeasure_array("site_indicator", inv_inputs["site_indicator"]), + flux=prepared.flux, + basis=prepared.basis, + trace=idata, + site_names=site_names, + times=nmeasure_array("times", inv_inputs["time"]), + start_date=start_date, + end_date=end_date, + species=species, + domain=domain, + ) + + +def _write_standard_outputs( + *, + result: RhimeResult, + prepared: _PreparedRhimeData, + country_file: str | None, +) -> 
None: + """Create and optionally save standard RHIME outputs.""" + output_spec = result.output_spec + if output_spec.output_format == "none": + return + + inv_out = _make_inversion_output( + prepared=prepared, + idata=result.idata, + start_date=result.run_spec.start_date, + end_date=result.run_spec.end_date, + species=result.model_spec.species, + domain=result.model_spec.domain, + ) + result.inv_out = inv_out + result.outputs["inversion_output"] = inv_out + + trace_path = _resolve_output_path( + output_spec.save_trace, + output_spec.output_path, + f"{output_spec.output_name}{result.run_spec.start_date}_trace.nc", + ) + if trace_path is not None: + trace_path.parent.mkdir(parents=True, exist_ok=True) + result.idata.to_netcdf(str(trace_path), engine="netcdf4", compress=True) + result.output_metadata["trace_path"] = str(trace_path) + + inv_out_path = _resolve_output_path( + output_spec.save_inversion_output, + output_spec.output_path, + f"{output_spec.output_name}{result.run_spec.start_date}_inversion_output.nc", + ) + if inv_out_path is not None: + inv_out_path.parent.mkdir(parents=True, exist_ok=True) + inv_out.save(inv_out_path) + result.output_metadata["inversion_output_path"] = str(inv_out_path) + + if output_spec.output_format == "basic": + from openghg_inversions.postprocessing.make_outputs import basic_output + + result.outputs["basic"] = basic_output(inv_out, country_file=country_file) + elif output_spec.output_format == "paris": + from openghg_inversions.postprocessing.make_paris_outputs import make_paris_outputs + + obs_avg_period = prepared.averaging_period[0] or "0h" + kwargs = output_spec.paris_postprocessing_kwargs or {} + flux_outs, conc_outs = make_paris_outputs( + inv_out, + country_file=country_file, + domain=result.model_spec.domain, + obs_avg_period=obs_avg_period, + **kwargs, + ) + result.outputs["paris_flux"] = flux_outs + result.outputs["paris_concentration"] = conc_outs + + if output_spec.output_path is not None: + 
Path(output_spec.output_path).mkdir(parents=True, exist_ok=True) + conc_file = _define_output_filename( + output_spec.output_path, + result.model_spec.species, + result.model_spec.domain, + output_spec.output_name + "_conc", + result.run_spec.start_date, + ext=".nc", + ) + flux_file = _define_output_filename( + output_spec.output_path, + result.model_spec.species, + result.model_spec.domain, + output_spec.output_name + "_flux", + result.run_spec.start_date, + ext=".nc", + ) + conc_outs.to_netcdf( + conc_file, unlimited_dims=["time"], mode="w", encoding=ncdf_encoding(conc_outs) + ) + flux_outs.to_netcdf( + flux_file, unlimited_dims=["time"], mode="w", encoding=ncdf_encoding(flux_outs) + ) + result.output_metadata["paris_concentration_path"] = str(conc_file) + result.output_metadata["paris_flux_path"] = str(flux_file) + + +def make_multisector_flux_diagnostics( + *, + idata: az.InferenceData, + prepared: _PreparedRhimeData, + model_spec: RhimeModelSpec, +) -> xr.Dataset: + """Create sector-aware posterior flux diagnostics for shared-basis RHIME. + + Args: + idata: InferenceData returned by RHIME sampling. + prepared: Prepared RHIME data object containing basis and flux arrays. + model_spec: Model spec containing sector names and variable suffixes. + + Returns: + Dataset containing posterior mean scaling factors, sector posterior flux + means, and total posterior flux mean. 
+ """ + basis = prepared.basis + flux = prepared.flux + posterior_flux = [] + posterior_scaling = [] + + for sector in model_spec.sectors: + x_name = f"x_{sector.variable_suffix}" + x_mean = idata.posterior[x_name].mean(("chain", "draw")) + scale_grid = sparse_xr_dot(basis, x_mean) + sector_flux = flux.sel(source=sector.flux_source) if "source" in flux.dims else flux + posterior_scaling.append(scale_grid.expand_dims(sector=[sector.name])) + posterior_flux.append((scale_grid * sector_flux).expand_dims(sector=[sector.name])) + + scaling = xr.concat(posterior_scaling, dim="sector").rename("posterior_scaling_mean") + flux_by_sector = xr.concat(posterior_flux, dim="sector").rename("posterior_flux_mean") + total_flux = flux_by_sector.sum("sector").rename("posterior_flux_total_mean") + return xr.merge([scaling, flux_by_sector, total_flux]) + + +def _write_multisector_outputs( + *, + result: RhimeResult, + prepared: _PreparedRhimeData, +) -> None: + """Create and optionally save shared-basis multi-sector RHIME outputs.""" + diagnostics = make_multisector_flux_diagnostics( + idata=result.idata, + prepared=prepared, + model_spec=result.model_spec, + ) + result.outputs["sector_flux_diagnostics"] = diagnostics + + output_spec = result.output_spec + if output_spec.output_format == "paris": + result.output_metadata["paris_note"] = ( + "Multi-sector PARIS schema support is not implemented in issue #398; " + "sector-aware modern diagnostics were generated instead." 
+ ) + if output_spec.output_path is not None and output_spec.output_format != "none": + Path(output_spec.output_path).mkdir(parents=True, exist_ok=True) + diagnostics_path = ( + Path(output_spec.output_path) + / f"{output_spec.output_name}{result.run_spec.start_date}_sector_flux_diagnostics.nc" + ) + diagnostics.to_netcdf(diagnostics_path, mode="w", encoding=ncdf_encoding(diagnostics)) + result.output_metadata["sector_flux_diagnostics_path"] = str(diagnostics_path) + + +def _run_common( + *, + multisector: bool, + params: dict[str, Any], +) -> RhimeResult: + """Run the shared RHIME pipeline after public wrapper/config normalization.""" + params = _normalise_params(params) + _validate_required_params(params) + _validate_supported_params(params) + + flux_sources = resolve_flux_sources( + flux_sources=params.pop("flux_sources", None), + emissions_name=params.pop("emissions_name", None), + ) + if multisector and len(flux_sources) < 2: + raise ValueError("`run_rhime_multisector` requires at least two flux sources.") + if not multisector and len(flux_sources) != 1: + raise ValueError("`run_rhime` requires exactly one flux source.") + + species = params.pop("species") + sites = _as_list(params.pop("sites")) or [] + domain = params.pop("domain") + averaging_period = _as_list(params.pop("averaging_period")) or [] + start_date = params.pop("start_date") + end_date = params.pop("end_date") + output_path = params.pop("output_path") + output_name = params.pop("output_name") + + x_prior = params.pop("x_prior", None) + bc_prior = params.pop("bc_prior", None) + sigma_prior = params.pop("sigma_prior", None) + offset_prior = params.pop("offset_prior", None) + sector_priors = params.pop("sector_priors", None) + if sector_priors is not None: + sector_priors = {key: dict(value) for key, value in sector_priors.items()} + + use_bc = params.get("use_bc", True) + sigma_per_site = params.get("sigma_per_site", True) + add_offset = params.get("add_offset", False) + pollution_events_from_obs = 
params.pop("pollution_events_from_obs", False) + no_model_error = params.pop("no_model_error", False) + power = params.pop("power", 1.99) + nit = int(params.pop("nit", 1000)) + burn = int(params.pop("burn", 0)) + tune = int(params.pop("tune", 1000)) + nchain = int(params.pop("nchain", 4)) + nuts_sampler = params.pop("nuts_sampler", "pymc") + verbose = params.pop("verbose", False) + sampler_kwargs = params.pop("sampler_kwargs", None) + output_format = params.pop("output_format", "inv_out") + _validate_output_format(output_format) + save_trace = params.pop("save_trace", False) + save_inversion_output = params.pop("save_inversion_output", True) + country_file = params.get("country_file") + paris_postprocessing_kwargs = params.pop("paris_postprocessing_kwargs", None) + + data_args, _ = split_function_inputs( + { + **params, + "species": species, + "sites": sites, + "domain": domain, + "averaging_period": averaging_period, + "start_date": start_date, + "end_date": end_date, + "output_name": output_name, + "flux_sources": flux_sources, + "split_by_sectors": multisector, + }, + _prepare_data, + ) + prepared = _prepare_data(**data_args) + + model_spec = _make_model_spec( + species=species, + domain=domain, + flux_sources=flux_sources, + x_prior=x_prior, + sector_priors=sector_priors, + use_bc=use_bc, + sigma_per_site=sigma_per_site, + add_offset=add_offset, + pollution_events_from_obs=pollution_events_from_obs, + no_model_error=no_model_error, + power=power, + ) + output_spec = RhimeOutputSpec( + output_format=output_format, + output_path=output_path, + output_name=output_name, + save_trace=save_trace, + save_inversion_output=save_inversion_output, + country_file=country_file, + paris_postprocessing_kwargs=paris_postprocessing_kwargs, + ) + run_spec = RhimeRunSpec( + start_date=start_date, + end_date=end_date, + sites=tuple(prepared.sites), + averaging_period=tuple(prepared.averaging_period), + model=model_spec, + output=output_spec, + split_by_sectors=multisector, + ) + + 
start_build = time.time() + if multisector: + model = build_rhime_multisector_model( + prepared.inv_inputs, + sectors=flux_sources, + sector_priors=sector_priors, + x_prior=x_prior, + bc_prior=bc_prior, + sigma_prior=sigma_prior, + sigma_per_site=sigma_per_site, + offset_prior=offset_prior, + add_offset=add_offset, + use_bc=use_bc, + pollution_events_from_obs=pollution_events_from_obs, + no_model_error=no_model_error, + offset_args=params.get("offset_args"), + power=power, + ) + else: + model = build_rhime_model( + prepared.inv_inputs, + x_prior=x_prior, + bc_prior=bc_prior, + sigma_prior=sigma_prior, + sigma_per_site=sigma_per_site, + offset_prior=offset_prior, + add_offset=add_offset, + use_bc=use_bc, + pollution_events_from_obs=pollution_events_from_obs, + no_model_error=no_model_error, + offset_args=params.get("offset_args"), + power=power, + ) + + idata = _sample_model( + model, + nit=nit, + burn=burn, + tune=tune, + nchain=nchain, + nuts_sampler=nuts_sampler, + verbose=verbose, + sampler_kwargs=sampler_kwargs, + ) + result = RhimeResult( + run_spec=run_spec, + model_spec=model_spec, + output_spec=output_spec, + inv_inputs=prepared.inv_inputs, + idata=idata, + model=model, + output_metadata={"build_and_sample_seconds": time.time() - start_build}, + ) + + if multisector: + _write_multisector_outputs(result=result, prepared=prepared) + else: + _write_standard_outputs( + result=result, + prepared=prepared, + country_file=country_file, + ) + + return result + + +def run_rhime( + *, + config_file: str | Path | None = None, + **kwargs: Any, +) -> RhimeResult: + """Run a standard single-sector RHIME inversion. + + Args: + config_file: Optional INI configuration file. Values in ``kwargs`` + override values read from this file. + **kwargs: RHIME run parameters using snake-case names, such as + ``output_path``, ``output_name``, ``flux_sources``, and + ``x_prior``. 
+ + Returns: + Modern RHIME result containing canonical inputs, InferenceData, specs, + output metadata, and generated outputs. + + Raises: + ValueError: If required parameters are missing, unsupported parameters + are supplied, or the flux-source count is invalid. + """ + params = params_from_config(config_file, extra_kwargs=kwargs) if config_file is not None else dict(kwargs) + return _run_common(multisector=False, params=params) + + +def run_rhime_multisector( + *, + config_file: str | Path | None = None, + **kwargs: Any, +) -> RhimeResult: + """Run a shared-basis multi-sector RHIME inversion. + + Args: + config_file: Optional INI configuration file. Values in ``kwargs`` + override values read from this file. + **kwargs: RHIME run parameters using snake-case names. Multi-sector + runs require at least two ``flux_sources`` and may include + ``sector_priors`` keyed by flux source. + + Returns: + Modern RHIME result containing canonical inputs, InferenceData, specs, + output metadata, and sector diagnostics. + + Raises: + ValueError: If required parameters are missing, unsupported parameters + are supplied, or fewer than two flux sources are provided. 
+ """ + params = params_from_config(config_file, extra_kwargs=kwargs) if config_file is not None else dict(kwargs) + return _run_common(multisector=True, params=params) diff --git a/pyproject.toml b/pyproject.toml index 57716882..586c98bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,9 @@ uv_dev = ["openghg_inversions[jupyter,dev]"] "Home" = "https://github.com/openghg/openghg_inversions" "Bug Tracker" = "https://github.com/openghg/openghg_inversions/issues" +[project.scripts] +openghg-inversions = "openghg_inversions.cli:main" + [tool.setuptools.packages.find] where = ["."] @@ -66,6 +69,7 @@ openghg_inversions = [ "basis/algorithms/*.nc", "postprocessing/*.cdl", "postprocessing/*.json", + "config/templates/*.ini", "hbmcmc/config/*.ini" ] diff --git a/tests/test_get_data.py b/tests/test_get_data.py index 9672047e..49698748 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -22,9 +22,7 @@ from openghg_inversions.inversion_data.getters import get_flux_data -def test_data_processing_surface_notracer( - tac_ch4_data_args, merged_data_file_name, raw_data_path -): +def test_data_processing_surface_notracer(tac_ch4_data_args, merged_data_file_name, raw_data_path): """Check that `data_processing_surface_notracer` produces the same output as v0.1, with test data frozen on 9 Feb 2024, or the same as v0.2, with test data frozen on 15 Apr 2024 (using the zarr backend). 
@@ -35,7 +33,15 @@ def test_data_processing_surface_notracer( assert len(result) == 6 # check keys of "fp_all" - assert list(result[0].keys()) == [".species", ".flux", ".split_by_sectors", ".bc", "TAC", ".scales", ".units"] + assert list(result[0].keys()) == [ + ".species", + ".flux", + ".split_by_sectors", + ".bc", + "TAC", + ".scales", + ".units", + ] # variables to check (to avoid surprises from new variables added to data) check_vars = ["mf", "fp", "mf_mod", "bc_mod", "fp_x_flux", "bc_n"] @@ -50,6 +56,7 @@ def test_data_processing_surface_notracer( expected_tac_combined_scenario["TAC"][check_vars].isel(time=0), ) + def test_load_merged_data(merged_data_dir, merged_data_file_name): """This should pass by finding the merged data with .zarr suffix.""" result = load_merged_data(merged_data_dir, merged_data_name=merged_data_file_name + "no_zip") @@ -179,6 +186,11 @@ def test_add_averaging_error(tac_ch4_data_args): for var in ["mf_error", "mf_repeatability", "mf_variability"]: for ds in [ds1, ds2]: assert var in ds + assert "number_of_observations" not in ds[var].attrs["long_name"] + + assert ds1.mf_error.attrs["long_name"] == ds1.mf.attrs["long_name"] + "_error" + assert ds1.mf_repeatability.attrs["long_name"] == ds1.mf.attrs["long_name"] + "_repeatability" + assert ds1.mf_variability.attrs["long_name"] == ds1.mf.attrs["long_name"] + "_variability" # averagingerror=True is default, so for ds1, "mf_error" should have repeatability # and variability added @@ -250,14 +262,17 @@ def test_looking_older_flux_files(tac_ch4_data_args, capsys): assert "Using flux data from 2019-01-01" in stdout -@pytest.mark.parametrize("end_date, time_period", [("2019-02-01", "monthly"), ("2020-01-01", "1 year"), ("2019-01-02", "1 year")]) +@pytest.mark.parametrize( + "end_date, time_period", [("2019-02-01", "monthly"), ("2020-01-01", "1 year"), ("2019-01-02", "1 year")] +) def test_flux_time_period_inference(end_date, time_period, tac_ch4_data_args): - kwargs = {"sources": 
tac_ch4_data_args["emissions_name"], - "species": tac_ch4_data_args["species"], - "domain": tac_ch4_data_args["domain"], - "start_date": "2019-01-01", - "end_date": end_date, - } + kwargs = { + "sources": tac_ch4_data_args["emissions_name"], + "species": tac_ch4_data_args["species"], + "domain": tac_ch4_data_args["domain"], + "start_date": "2019-01-01", + "end_date": end_date, + } flux_data = get_flux_data(**kwargs) source = tac_ch4_data_args["emissions_name"][0] diff --git a/tests/test_inversion_inputs.py b/tests/test_inversion_inputs.py index 94b85033..1ce756ca 100644 --- a/tests/test_inversion_inputs.py +++ b/tests/test_inversion_inputs.py @@ -266,3 +266,15 @@ def test_make_inv_inputs_raises_if_required_var_would_be_dropped(): with pytest.raises(ValueError, match="Required inversion data variables.*mf_error"): make_inv_inputs(fp_data=fp_data, sites=["AAA", "BBB"], min_error=0.0) + + +def test_make_inv_inputs_accepts_integer_min_error(): + """Integer min_error values should be treated as numeric scalar errors.""" + fp_data = { + "AAA": _make_minimal_fp_site(mf_base=10.0, include_inlet_height=False), + "BBB": _make_minimal_fp_site(mf_base=20.0, include_inlet_height=False), + } + + result = make_inv_inputs(fp_data=fp_data, sites=["AAA", "BBB"], min_error=40) + + assert np.all(result.min_error.values == 40.0) diff --git a/tests/test_rhime.py b/tests/test_rhime.py new file mode 100644 index 00000000..3eafc9d5 --- /dev/null +++ b/tests/test_rhime.py @@ -0,0 +1,425 @@ +from __future__ import annotations + +from pathlib import Path + +import pymc as pm +import pytest +import xarray as xr + +import openghg_inversions.rhime as rhime_module +from openghg_inversions.cli import main +from openghg_inversions.inversion_inputs import make_inv_inputs +from openghg_inversions.models.rhime import ( + build_rhime_model, + build_rhime_multisector_model, + safe_pymc_name, +) +from openghg_inversions.postprocessing.inversion_output import InversionOutput +from openghg_inversions.rhime 
import ( + RhimeResult, + params_from_config, + resolve_flux_sources, + run_rhime, + run_rhime_multisector, +) + + +@pytest.fixture(scope="module") +def rhime_inv_inputs(mhd_and_tac_fp_data) -> xr.Dataset: + return make_inv_inputs( + mhd_and_tac_fp_data, + sites=["MHD", "TAC"], + bc_freq="3h", + sigma_freq="3h", + min_error=0.0, + start_date="2019-01-01", + ) + + +@pytest.fixture +def multisector_inv_inputs(rhime_inv_inputs: xr.Dataset) -> xr.Dataset: + ds = rhime_inv_inputs.copy() + ds["H"] = xr.concat( + [ + rhime_inv_inputs["H"].expand_dims(source=["total-ukghg-edgar7"]), + (2.0 * rhime_inv_inputs["H"]).expand_dims(source=["sector-2"]), + ], + dim="source", + ) + return ds + + +@pytest.fixture +def builder_args() -> dict: + return { + "x_prior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "bc_prior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "sigma_prior": {"pdf": "uniform", "lower": 0.1, "upper": 10.0}, + "sigma_per_site": True, + "offset_prior": {"pdf": "normal", "mu": 0, "sigma": 1}, + "add_offset": False, + "use_bc": True, + "pollution_events_from_obs": True, + "no_model_error": False, + "power": 1.99, + } + + +def test_build_rhime_model_contains_expected_variables( + rhime_inv_inputs: xr.Dataset, builder_args: dict +) -> None: + model = build_rhime_model(rhime_inv_inputs, **builder_args) + + assert isinstance(model, pm.Model) + expected = {"x", "mu", "bc", "mu_bc", "sigma", "epsilon", "y"} + assert expected.issubset(model.named_vars) + + +def test_build_rhime_multisector_model_contains_expected_variables( + multisector_inv_inputs: xr.Dataset, builder_args: dict +) -> None: + sectors = ["total-ukghg-edgar7", "sector-2"] + model = build_rhime_multisector_model(multisector_inv_inputs, sectors=sectors, **builder_args) + + expected = { + "x_total_ukghg_edgar7", + "mu_total_ukghg_edgar7", + "x_sector_2", + "mu_sector_2", + "mu", + "bc", + "mu_bc", + "sigma", + "epsilon", + "y", + } + assert expected.issubset(model.named_vars) + assert 
len(model.coords["region"]) == multisector_inv_inputs.sizes["region"] + + +def test_build_rhime_multisector_model_requires_multiple_sectors( + multisector_inv_inputs: xr.Dataset, builder_args: dict +) -> None: + with pytest.raises(ValueError, match="at least two sectors"): + build_rhime_multisector_model( + multisector_inv_inputs, + sectors=["total-ukghg-edgar7"], + **builder_args, + ) + + +def test_resolve_flux_sources_prefers_new_name() -> None: + assert resolve_flux_sources(flux_sources=["new"], emissions_name=["legacy"]) == ["new"] + assert resolve_flux_sources(emissions_name=["legacy"]) == ["legacy"] + + +def test_params_from_config_maps_legacy_emissions_name(tmp_path: Path) -> None: + config_file = tmp_path / "rhime.ini" + config_file.write_text( + """ +[INPUT.MEASUREMENTS] +species = "ch4" +sites = ["TAC"] +averaging_period = ["1h"] +start_date = "2019-01-01" +end_date = "2019-01-02" + +[INPUT.PRIORS] +domain = "EUROPE" +emissions_name = ["legacy-source"] + +[RHIME.OUTPUT] +output_path = "out" +output_name = "test" +""", + encoding="utf-8", + ) + + params = params_from_config(config_file) + assert params["flux_sources"] == ["legacy-source"] + + +def test_params_from_config_rejects_unsupported_deprecated_option(tmp_path: Path) -> None: + config_file = tmp_path / "rhime.ini" + config_file.write_text( + """ +[INPUT.MEASUREMENTS] +species = "ch4" +sites = ["TAC"] +averaging_period = ["1h"] +start_date = "2019-01-01" +end_date = "2019-01-02" + +[INPUT.PRIORS] +domain = "EUROPE" +flux_sources = ["total-ukghg-edgar7"] + +[RHIME.OUTPUT] +output_path = "out" +output_name = "test" + +[RHIME.DATA] +calculate_min_error = true +""", + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="calculate_min_error"): + params_from_config(config_file) + + +@pytest.mark.parametrize( + ("name", "value"), + [ + ("reparameterise_log_normal", "true"), + ("mcmc_type", '"hmc"'), + ], +) +def test_params_from_config_rejects_unsupported_legacy_runner_options( + tmp_path: Path, 
name: str, value: str +) -> None: + config_file = tmp_path / "rhime.ini" + config_file.write_text( + f""" +[INPUT.MEASUREMENTS] +species = "ch4" +sites = ["TAC"] +averaging_period = ["1h"] +start_date = "2019-01-01" +end_date = "2019-01-02" + +[INPUT.PRIORS] +domain = "EUROPE" +flux_sources = ["total-ukghg-edgar7"] + +[RHIME.OUTPUT] +output_path = "out" +output_name = "test" + +[RHIME.MCMC] +{name} = {value} +""", + encoding="utf-8", + ) + + with pytest.raises(ValueError, match=name): + params_from_config(config_file) + + +def test_run_rhime_rejects_unknown_parameter_before_data_preparation(tmp_path: Path) -> None: + args = { + "species": "ch4", + "sites": ["TAC"], + "averaging_period": ["1h"], + "domain": "EUROPE", + "start_date": "2019-01-01", + "end_date": "2019-01-02", + "flux_sources": ["total-ukghg-edgar7"], + "output_path": str(tmp_path), + "output_name": "test", + "definitely_not_a_rhime_parameter": True, + } + + with pytest.raises(ValueError, match="Unsupported RHIME parameter"): + run_rhime(**args) + + +def test_run_rhime_rejects_unsupported_output_format(tmp_path: Path) -> None: + args = { + "species": "ch4", + "sites": ["TAC"], + "averaging_period": ["1h"], + "domain": "EUROPE", + "start_date": "2019-01-01", + "end_date": "2019-01-02", + "flux_sources": ["total-ukghg-edgar7"], + "output_path": str(tmp_path), + "output_name": "test", + "output_format": "legacy", + } + + with pytest.raises(ValueError, match="Unsupported RHIME output_format"): + run_rhime(**args) + + +def test_supported_parameter_validation_accepts_sigma_per_site(tmp_path: Path) -> None: + args = { + "species": "ch4", + "sites": ["TAC"], + "averaging_period": ["1h"], + "domain": "EUROPE", + "start_date": "2019-01-01", + "end_date": "2019-01-02", + "flux_sources": ["total-ukghg-edgar7"], + "output_path": str(tmp_path), + "output_name": "test", + "sigma_per_site": False, + } + + rhime_module._validate_supported_params(args) + + +def 
test_run_rhime_rejects_multiple_flux_sources(tac_ch4_data_args, tmp_path: Path) -> None: + args = tac_ch4_data_args.copy() + args.update( + { + "flux_sources": ["a", "b"], + "output_path": str(tmp_path), + "output_name": "test", + } + ) + args.pop("emissions_name") + + with pytest.raises(ValueError, match="exactly one flux source"): + run_rhime(**args) + + +def test_run_rhime_multisector_rejects_single_flux_source(tac_ch4_data_args, tmp_path: Path) -> None: + args = tac_ch4_data_args.copy() + args.update( + { + "flux_sources": ["total-ukghg-edgar7"], + "output_path": str(tmp_path), + "output_name": "test", + } + ) + args.pop("emissions_name") + + with pytest.raises(ValueError, match="at least two flux sources"): + run_rhime_multisector(**args) + + +def test_run_rhime_api_smoke(tac_ch4_data_args, tmp_path: Path) -> None: + args = tac_ch4_data_args.copy() + args.update( + { + "flux_sources": args.pop("emissions_name"), + "output_name": "rhime_test", + "output_path": str(tmp_path), + "basis_algorithm": "quadtree", + "basis_output_path": str(tmp_path), + "nbasis": 4, + "nit": 1, + "burn": 0, + "tune": 0, + "nchain": 1, + "reload_merged_data": False, + "x_prior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "bc_prior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "sigma_prior": {"pdf": "uniform", "lower": 0.1, "upper": 10.0}, + "sampler_kwargs": {"random_seed": 123, "compute_convergence_checks": False}, + } + ) + + result = run_rhime(**args) + + assert isinstance(result, RhimeResult) + assert "x" in result.idata.posterior + assert "mu" in result.idata.posterior + assert result.run_spec.split_by_sectors is False + assert "inversion_output" in result.outputs + inv_input_long_names = [ + result.inv_inputs.mf.attrs.get("long_name", ""), + result.inv_inputs.mf_error.attrs.get("long_name", ""), + result.inv_inputs.mf_repeatability.attrs.get("long_name", ""), + result.inv_inputs.mf_variability.attrs.get("long_name", ""), + ] + assert all("number_of_observations" not in 
long_name for long_name in inv_input_long_names) + output_file = tmp_path / "rhime_test2019-01-01_inversion_output.nc" + assert output_file.exists() + reloaded = InversionOutput.load(output_file) + assert reloaded.species == "ch4" + obs_long_names = [ + reloaded.obs.attrs.get("long_name", ""), + reloaded.obs_err.attrs.get("long_name", ""), + reloaded.obs_repeatability.attrs.get("long_name", ""), + reloaded.obs_variability.attrs.get("long_name", ""), + ] + assert all("number_of_observations" not in long_name for long_name in obs_long_names) + + +def test_run_rhime_multisector_api_smoke(tac_ch4_data_args, tmp_path: Path) -> None: + args = tac_ch4_data_args.copy() + args.update( + { + "flux_sources": ["total-ukghg-edgar7", "total-ukghg-edgar7-shuffled"], + "output_name": "rhime_multisector_test", + "output_path": str(tmp_path), + "basis_algorithm": "quadtree", + "basis_output_path": str(tmp_path), + "nbasis": 4, + "nit": 1, + "burn": 0, + "tune": 0, + "nchain": 1, + "reload_merged_data": False, + "output_format": "none", + "x_prior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "bc_prior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "sigma_prior": {"pdf": "uniform", "lower": 0.1, "upper": 10.0}, + "sampler_kwargs": {"random_seed": 123, "compute_convergence_checks": False}, + } + ) + args.pop("emissions_name") + + result = run_rhime_multisector(**args) + + assert isinstance(result, RhimeResult) + assert result.run_spec.split_by_sectors is True + assert "x_total_ukghg_edgar7" in result.idata.posterior + assert "x_total_ukghg_edgar7_shuffled" in result.idata.posterior + assert "sector_flux_diagnostics" in result.outputs + + +def test_cli_run_rhime_passes_config_and_overrides(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "rhime.ini" + config_file.write_text('[RHIME.OUTPUT]\noutput_name = "test"\n', encoding="utf-8") + seen = {} + + def fake_run_rhime(*, config_file, **kwargs): + seen["config_file"] = config_file + seen["kwargs"] = kwargs + + 
monkeypatch.setattr("openghg_inversions.rhime.run_rhime", fake_run_rhime) + + main( + [ + "run-rhime", + "2019-01-01", + "2019-01-02", + "-c", + str(config_file), + "--output-path", + str(tmp_path), + "--kwargs", + '{"nit": 1}', + ] + ) + + assert seen["config_file"] == str(config_file) + assert seen["kwargs"]["start_date"] == "2019-01-01" + assert seen["kwargs"]["end_date"] == "2019-01-02" + assert seen["kwargs"]["output_path"] == str(tmp_path) + assert seen["kwargs"]["nit"] == 1 + + +def test_cli_run_rhime_multisector_passes_config(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "rhime.ini" + config_file.write_text('[RHIME.OUTPUT]\noutput_name = "test"\n', encoding="utf-8") + seen = {} + + def fake_run_rhime_multisector(*, config_file, **kwargs): + seen["config_file"] = config_file + seen["kwargs"] = kwargs + + monkeypatch.setattr("openghg_inversions.rhime.run_rhime_multisector", fake_run_rhime_multisector) + + main(["run-rhime-multisector", "-c", str(config_file)]) + + assert seen["config_file"] == str(config_file) + assert seen["kwargs"] == {} + + +def test_safe_pymc_name_sanitizes_source_names() -> None: + assert safe_pymc_name("total-ukghg-edgar7") == "total_ukghg_edgar7" + assert safe_pymc_name("Sector 2") == "sector_2" From b165c62d8177ba4edbff105777e746afdc277aea Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Fri, 1 May 2026 10:00:50 +0100 Subject: [PATCH 2/3] Relax RHIME output path validation and fix trace saving --- openghg_inversions/rhime.py | 45 ++++++++++++++++++++++++++-- tests/test_rhime.py | 59 +++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/openghg_inversions/rhime.py b/openghg_inversions/rhime.py index 9ed3efdb..221c4f57 100644 --- a/openghg_inversions/rhime.py +++ b/openghg_inversions/rhime.py @@ -289,15 +289,27 @@ def _required_run_params() -> set[str]: "domain", "start_date", "end_date", - "output_path", "output_name", } +def _is_missing_required_value(value: Any) -> 
bool: + """Return true when a required RHIME parameter has no usable value.""" + if value is None: + return True + if isinstance(value, str): + return not value.strip() + if isinstance(value, Sequence) and not isinstance(value, str | bytes) and len(value) == 0: + return True + return False + + def _validate_required_params(params: Mapping[str, Any]) -> None: """Raise if normalized run parameters are missing required values.""" missing = [ - name for name in sorted(_required_run_params()) if name not in params or params[name] in (None, " ") + name + for name in sorted(_required_run_params()) + if name not in params or _is_missing_required_value(params[name]) ] if missing: raise ValueError(f"Required RHIME parameter(s) missing: {missing!r}") @@ -323,6 +335,7 @@ def _validate_supported_params(params: Mapping[str, Any]) -> None: "verbose", "sampler_kwargs", "output_format", + "output_path", "save_trace", "save_inversion_output", "paris_postprocessing_kwargs", @@ -347,6 +360,25 @@ def _validate_output_format(output_format: str) -> None: ) +def _validate_output_path_settings( + *, + output_format: str, + output_path: str | None, + save_trace: str | Path | bool, + save_inversion_output: str | Path | bool, + multisector: bool, +) -> None: + """Raise if output settings imply a default save path but none is supplied.""" + if output_format == "none": + return + if output_path is not None: + return + if save_trace is True: + raise ValueError("`output_path` is required when `save_trace=True`.") + if not multisector and save_inversion_output is True: + raise ValueError("`output_path` is required when saving the standard RHIME InversionOutput.") + + def _resolve_output_path( save_setting: str | Path | bool, output_path: str | None, filename: str ) -> Path | None: @@ -778,7 +810,7 @@ def _write_standard_outputs( ) if trace_path is not None: trace_path.parent.mkdir(parents=True, exist_ok=True) - result.idata.to_netcdf(str(trace_path), engine="netcdf4", compress=True) + 
result.idata.to_netcdf(str(trace_path), compress=True) result.output_metadata["trace_path"] = str(trace_path) inv_out_path = _resolve_output_path( @@ -956,6 +988,13 @@ def _run_common( _validate_output_format(output_format) save_trace = params.pop("save_trace", False) save_inversion_output = params.pop("save_inversion_output", True) + _validate_output_path_settings( + output_format=output_format, + output_path=output_path, + save_trace=save_trace, + save_inversion_output=save_inversion_output, + multisector=multisector, + ) country_file = params.get("country_file") paris_postprocessing_kwargs = params.pop("paris_postprocessing_kwargs", None) diff --git a/tests/test_rhime.py b/tests/test_rhime.py index 3eafc9d5..d3d980bb 100644 --- a/tests/test_rhime.py +++ b/tests/test_rhime.py @@ -242,6 +242,65 @@ def test_run_rhime_rejects_unsupported_output_format(tmp_path: Path) -> None: run_rhime(**args) +def test_required_parameter_validation_allows_missing_output_path_for_in_memory_runs() -> None: + args = { + "species": "ch4", + "sites": ["TAC"], + "averaging_period": ["1h"], + "domain": "EUROPE", + "start_date": "2019-01-01", + "end_date": "2019-01-02", + "output_name": "test", + } + + rhime_module._validate_required_params(args) + + +@pytest.mark.parametrize( + ("name", "value"), + [ + ("species", ""), + ("sites", []), + ("domain", " "), + ], +) +def test_required_parameter_validation_rejects_empty_values(name: str, value) -> None: + args = { + "species": "ch4", + "sites": ["TAC"], + "averaging_period": ["1h"], + "domain": "EUROPE", + "start_date": "2019-01-01", + "end_date": "2019-01-02", + "output_name": "test", + } + args[name] = value + + with pytest.raises(ValueError, match=name): + rhime_module._validate_required_params(args) + + +def test_output_path_validation_allows_output_none_without_path() -> None: + rhime_module._validate_output_path_settings( + output_format="none", + output_path=None, + save_trace=False, + save_inversion_output=True, + multisector=False, + 
) + + +def test_output_path_validation_rejects_default_standard_save_without_path() -> None: + with pytest.raises(ValueError, match="output_path"): + rhime_module._validate_output_path_settings( + output_format="inv_out", + output_path=None, + save_trace=False, + save_inversion_output=True, + multisector=False, + ) + + def test_supported_parameter_validation_accepts_sigma_per_site(tmp_path: Path) -> None: args = { "species": "ch4", From f980e92b9ef0689b779e64097d6ba78662c46ae2 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Fri, 1 May 2026 10:46:10 +0100 Subject: [PATCH 3/3] Prefer h5netcdf and allow in-memory RHIME runs --- openghg_inversions/rhime.py | 25 +++++++++++++++-- tests/test_rhime.py | 54 +++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/openghg_inversions/rhime.py b/openghg_inversions/rhime.py index 221c4f57..791caf4e 100644 --- a/openghg_inversions/rhime.py +++ b/openghg_inversions/rhime.py @@ -405,6 +405,27 @@ def _define_output_filename( return Path(output_path) / f"{output_name}_{species}_{domain}_{start_date}{ext}" +def _save_inferencedata(idata: az.InferenceData, path: str | Path) -> None: + """Save InferenceData, preferring the h5netcdf backend with fallbacks.""" + failures = [] + for engine in ("h5netcdf", None, "netcdf4"): + try: + if engine is None: + idata.to_netcdf(str(path), compress=True) + else: + idata.to_netcdf(str(path), engine=engine, compress=True) + except Exception as exc: + engine_name = "arviz-default" if engine is None else engine + failures.append(f"{engine_name}: {exc}") + else: + return + + joined_failures = "\n".join(failures) + raise RuntimeError( + f"Could not save RHIME trace to {path}. 
Tried h5netcdf, ArviZ default, and netcdf4:\n{joined_failures}" + ) + + def _prepare_data( *, species: str, @@ -810,7 +831,7 @@ def _write_standard_outputs( ) if trace_path is not None: trace_path.parent.mkdir(parents=True, exist_ok=True) - result.idata.to_netcdf(str(trace_path), compress=True) + _save_inferencedata(result.idata, trace_path) result.output_metadata["trace_path"] = str(trace_path) inv_out_path = _resolve_output_path( @@ -960,7 +981,7 @@ def _run_common( averaging_period = _as_list(params.pop("averaging_period")) or [] start_date = params.pop("start_date") end_date = params.pop("end_date") - output_path = params.pop("output_path") + output_path = params.pop("output_path", None) output_name = params.pop("output_name") x_prior = params.pop("x_prior", None) diff --git a/tests/test_rhime.py b/tests/test_rhime.py index d3d980bb..5cf6604a 100644 --- a/tests/test_rhime.py +++ b/tests/test_rhime.py @@ -242,6 +242,23 @@ def test_run_rhime_rejects_unsupported_output_format(tmp_path: Path) -> None: run_rhime(**args) +def test_run_rhime_can_validate_output_format_without_output_path() -> None: + args = { + "species": "ch4", + "sites": ["TAC"], + "averaging_period": ["1h"], + "domain": "EUROPE", + "start_date": "2019-01-01", + "end_date": "2019-01-02", + "flux_sources": ["total-ukghg-edgar7"], + "output_name": "test", + "output_format": "legacy", + } + + with pytest.raises(ValueError, match="Unsupported RHIME output_format"): + run_rhime(**args) + + def test_required_parameter_validation_allows_missing_output_path_for_in_memory_runs() -> None: args = { "species": "ch4", @@ -301,6 +318,43 @@ def test_output_path_validation_rejects_default_standard_save_without_path() -> ) +def test_save_inferencedata_prefers_h5netcdf(tmp_path: Path) -> None: + class FakeInferenceData: + def __init__(self) -> None: + self.calls = [] + + def to_netcdf(self, path, **kwargs): + self.calls.append((path, kwargs)) + + idata = FakeInferenceData() + path = tmp_path / "trace.nc" + + 
rhime_module._save_inferencedata(idata, path) + + assert idata.calls == [(str(path), {"engine": "h5netcdf", "compress": True})] + + +def test_save_inferencedata_falls_back_after_h5netcdf_failure(tmp_path: Path) -> None: + class FakeInferenceData: + def __init__(self) -> None: + self.calls = [] + + def to_netcdf(self, path, **kwargs): + self.calls.append((path, kwargs)) + if kwargs.get("engine") == "h5netcdf": + raise ValueError("h5netcdf unavailable") + + idata = FakeInferenceData() + path = tmp_path / "trace.nc" + + rhime_module._save_inferencedata(idata, path) + + assert idata.calls == [ + (str(path), {"engine": "h5netcdf", "compress": True}), + (str(path), {"compress": True}), + ] + + def test_supported_parameter_validation_accepts_sigma_per_site(tmp_path: Path) -> None: args = { "species": "ch4",