diff --git a/openghg_inversions/basis/_functions.py b/openghg_inversions/basis/_functions.py index 79b355a2..91b8a55f 100644 --- a/openghg_inversions/basis/_functions.py +++ b/openghg_inversions/basis/_functions.py @@ -132,7 +132,7 @@ def basis_boundary_conditions(domain: str, basis_case: str, bc_basis_directory: def _flux_fp_from_fp_all( - fp_all: dict, emissions_name: list[str] | None = None + fp_all: dict, emissions_name: list[str] | None = None, scenario = "standard" ) -> tuple[xr.DataArray, list[xr.DataArray]]: """Get flux and list of footprints from `fp_all` dictionary and optional list of emissions names. @@ -149,15 +149,38 @@ def _flux_fp_from_fp_all( footprints (list): List of xarray DataArray containing the footprints of each sites. """ + flux_key = ".inner_flux" if scenario == "inner" and ".inner_flux" in fp_all else ".flux" + if emissions_name is not None: - flux = fp_all[".flux"][emissions_name[0]].data.flux + flux = fp_all[flux_key][emissions_name[0]].data.flux else: - first_flux = next(iter(fp_all[".flux"].values())) + first_flux = next(iter(fp_all[flux_key].values())) flux = first_flux.data.flux flux = cast(xr.DataArray, flux) - footprints: list[xr.DataArray] = [v.fp for k, v in fp_all.items() if not k.startswith(".")] + footprints = [] + for k, v in fp_all.items(): + if k.startswith("."): + continue + + # Need to discuss this further + # fp_x_flux is guaranteed to be on the + # EUROPE grid (it comes from the standard scenario). Raw .fp may be + # on the 6km grid if OpenGHG snapped it to the footprint resolution. 
+ if scenario == "inner" and isinstance(v, xr.DataTree) and "inner" in v.children: + fp = v["inner"].ds["fp_x_flux"] + elif "fp_x_flux" in v: + fp = v["fp_x_flux"] + else: + fp = v[scenario].ds.fp + + # if grid still doesn't match flux, regrid to flux grid + if fp.sizes.get("lat") != flux.sizes.get("lat") or fp.sizes.get("lon") != flux.sizes.get("lon"): + fp = fp.interp(lat=flux.lat, lon=flux.lon, method="nearest").fillna(0.0) + fp = fp.assign_coords(lat=flux.lat, lon=flux.lon) + + footprints.append(fp) return flux, footprints @@ -213,6 +236,7 @@ def quadtreebasisfunction( fp_all: dict, start_date: str, domain : str, + scenario: str = "standard", emissions_name: list[str] | None = None, nbasis: int = 100, country_directory: str | None = None, @@ -258,7 +282,7 @@ def quadtreebasisfunction( quad_basis (xarray.DataArray): Array with lat/lon dimensions and basis regions encoded by integers. """ - flux, footprints = _flux_fp_from_fp_all(fp_all, emissions_name) + flux, footprints = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario) fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy() # use xr.apply_ufunc to keep xarray coords @@ -283,7 +307,8 @@ def bucketbasisfunction( nbasis: int = 100, country_directory: str | None = None, abs_flux: bool = False, - mask: xr.DataArray | None = None + mask: xr.DataArray | None = None, + scenario: str = "standard", ) -> xr.DataArray: """Basis functions calculated using a weighted region approach where each basis function / scaling region contains approximately @@ -311,12 +336,15 @@ def bucketbasisfunction( Default None country_directory (str): Directory containing land-sea files. If None, will use default files. + scenario (str): + Scenario for retrieving emissions files. + Default "standard" Returns: bucket_basis (xarray.DataArray): Array with lat/lon dimensions and basis regions encoded by integers. 
""" - flux, footprints = _flux_fp_from_fp_all(fp_all, emissions_name) + flux, footprints = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario) fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy() fps = fps / fps.max() @@ -351,6 +379,7 @@ def fixed_outer_regions_basis( nbasis: int = 100, country_directory: str | None = None, abs_flux: bool = False, + scenario: str = "standard" ) -> xr.DataArray: """Fix outer region of basis functions to InTEM regions, and fit the inner regions using `basis_algorithm`. @@ -389,7 +418,7 @@ def fixed_outer_regions_basis( intem_regions = xr.open_dataset(intem_regions_path).region # force intem_regions to use flux coordinates - flux, _ = _flux_fp_from_fp_all(fp_all, emissions_name) + flux, _ = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario) _, intem_regions = xr.align(flux, intem_regions, join="override") inner_index = intem_regions.values.max() @@ -397,7 +426,7 @@ def fixed_outer_regions_basis( mask = intem_regions == inner_index basis_function = basis_functions[basis_algorithm].algorithm - inner_region = basis_function(fp_all, start_date, domain, emissions_name, nbasis, country_directory, abs_flux, mask=mask) + inner_region = basis_function(fp_all=fp_all, start_date=start_date, domain=domain, emissions_name=emissions_name, nbasis=nbasis, country_directory=country_directory, abs_flux=abs_flux, mask=mask) basis = intem_regions.rename("basis") diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 91ddfe3e..63c1e5fd 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -6,7 +6,11 @@ from ._functions import basis_boundary_conditions -def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.DataArray]) -> dict: +def fp_sensitivity( + fp_and_data: dict, + basis_func: xr.DataArray | dict[str, xr.DataArray], + inner_basis_func: 
xr.DataArray | None = None, +) -> dict: """Add a sensitivity matrix, H, to each site xr.Dataset in fp_and_data. The sensitivity matrix H takes the footprint sensitivities (the `fp` variable), @@ -38,7 +42,6 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if len(flux_sources) == 1: if not isinstance(basis_func, xr.DataArray): basis_func = next(iter(basis_func.values())) - fp_x_flux_name = "fp_x_flux" else: @@ -63,18 +66,107 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if "time" in basis_func.dims and basis_func.sizes["time"] <= 1: basis_func = basis_func.squeeze("time") + if inner_basis_func is not None and "time" in inner_basis_func.dims and inner_basis_func.sizes["time"] <= 1: + inner_basis_func = inner_basis_func.squeeze("time") + fp_and_data[".basis"] = basis_func + if inner_basis_func is not None: + fp_and_data[".basis_inner"] = inner_basis_func + + if inner_basis_func is not None: + invalid_inner_sites = [ + site + for site in sites + if not ( + isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children + ) + ] + if invalid_inner_sites: + raise ValueError( + "Inner basis supplied, but some sites are not DataTree entries with an inner child: " + f"{invalid_inner_sites}." 
+ ) for site in sites: - sensitivity = apply_fp_basis_functions( - fp_x_flux=fp_and_data[site][fp_x_flux_name], - basis_func=basis_func, - ) - fp_and_data[site]["H"] = sensitivity + entry = fp_and_data[site] + + + # extract root fp_x_flux (already masked if inner domain exists) + if isinstance(entry, xr.DataTree): + if "standard" in entry.children: + root_ds = entry["standard"].ds + else: + root_ds = entry.ds + fp_x_flux_outer = root_ds[fp_x_flux_name] + + # Compute outer H from the (already masked) fp_x_flux + sensitivity = apply_fp_basis_functions( + fp_x_flux=fp_x_flux_outer, + basis_func=basis_func, + ) + + # Compute H_inner from the inner child's fp_x_flux (its own lat/lon grid) + if "inner" in entry.children: + if inner_basis_func is None: + raise ValueError("Inner-domain data exists but no inner basis function was provided.") + + inner_fp_x_flux = entry["inner"].ds[fp_x_flux_name] + H_inner = apply_fp_basis_functions( + fp_x_flux=inner_fp_x_flux, + basis_func=inner_basis_func, + ) + # Write both back into the DataTree + new_root = root_ds.assign({"H": sensitivity}) + new_inner = entry["inner"].ds.assign({"H_inner": H_inner}) + fp_and_data[site] = xr.DataTree.from_dict({ + "/standard": new_root, + "/inner": new_inner, + }) + else: + fp_and_data[site] = xr.DataTree.from_dict({ + "/standard": root_ds.assign({"H": sensitivity}) + }) + + else: + if inner_basis_func is not None: + raise ValueError( + "Inner-domain inversion requires DataTree site entries with an inner child. " + f"Site '{site}' is a plain Dataset." 
+ ) + + # Legacy: plain xr.Dataset path — unchanged + sensitivity = apply_fp_basis_functions( + fp_x_flux=entry[fp_x_flux_name], + basis_func=basis_func, + ) + fp_and_data[site]["H"] = sensitivity return fp_and_data +def combine_inner_outer_fp_x_flux( + inner_fp_x_flux: xr.DataArray, + outer_fp_x_flux: xr.DataArray, +) -> xr.DataArray: + """Merge inner (6km) and outer (EUROPE) fp_x_flux.""" + # Regrid inner fp_x_flux to the same grid as outer fp_x_flux, and then patch it in where the inner domain mask is True. + # regrid inner to EUROPE lat/lon coords + inner_regridded = inner_fp_x_flux.interp(lat=outer_fp_x_flux.lat, lon=outer_fp_x_flux.lon, method="nearest") + + # force coordinates to exactly match outer (avoids float precision + # mismatches that prevent xr.align / xr.where from working correctly) + inner_regridded = inner_regridded.assign_coords(lat=outer_fp_x_flux.lat, lon=outer_fp_x_flux.lon) + + # fill NaN (points outside the inner domain extent) with 0 + inner_regridded = inner_regridded.fillna(0.0) + + # True where the inner domain contributed non-zero values at any timestep + inner_has_coverage = (inner_regridded != 0).any("time") + + # Both arrays are now on the EUROPE grid so xr.where is safe + return xr.where(inner_has_coverage, inner_regridded, outer_fp_x_flux) + + def apply_fp_basis_functions( fp_x_flux: xr.DataArray, basis_func: xr.DataArray, @@ -100,7 +192,9 @@ def apply_fp_basis_functions( _, basis_aligned = xr.align(fp_x_flux.isel(time=0), basis_func, join="override") basis_mat = get_xr_dummies(basis_aligned, cat_dim="region") - sensitivity = sparse_xr_dot(basis_mat, fp_x_flux.fillna(0.0), dim=["lat", "lon"]) + spatial_dims = ["lat", "lon"] + + sensitivity = sparse_xr_dot(basis_mat, fp_x_flux.fillna(0.0), dim=spatial_dims) if sensitivity.dims[:2] != ("region", "time"): sensitivity = sensitivity.transpose("region", "time", ...) 
@@ -126,14 +220,33 @@ def bc_sensitivity( dict of xr.Datasets in same format as fp_and_data with `H_bc` sensitivity matrix added. """ + def _outer_ds(entry): + if isinstance(entry, xr.DataTree): + if "standard" in entry.children: + return entry["standard"].ds + return entry.ds + return entry + + def _with_updated_outer(entry, updated_outer_ds): + if not isinstance(entry, xr.DataTree): + return updated_outer_ds + + if "standard" in entry.children: + tree_dict = {f"/{name}": child.ds for name, child in entry.children.items() if child.ds is not None} + tree_dict["/standard"] = updated_outer_ds + return xr.DataTree.from_dict(tree_dict) + + return xr.DataTree(dataset=updated_outer_ds, children=dict(entry.children)) + sites = [key for key in list(fp_and_data.keys()) if key[0] != "."] if basis_case.lower() == "nesw": for site in sites: - ds = fp_and_data[site] - bc_ds = ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) + entry = fp_and_data[site] + outer_ds = _outer_ds(entry) + bc_ds = outer_ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") - fp_and_data[site]["H_bc"] = sensitivity + fp_and_data[site] = _with_updated_outer(entry, outer_ds.assign({"H_bc": sensitivity})) return fp_and_data @@ -151,10 +264,11 @@ def bc_sensitivity( bc_basis = basis_func.rename({dv: str(dv).replace("basis_", "") for dv in basis_func.data_vars}) for site in sites: - ds = fp_and_data[site] - bc_ds = ds[[f"bc_{d}" for d in "nesw"]] + entry = fp_and_data[site] + outer_ds = _outer_ds(entry) + bc_ds = outer_ds[[f"bc_{d}" for d in "nesw"]] sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") sensitivity = sensitivity.rename(region="bc_region") - fp_and_data[site]["H_bc"] = sensitivity + fp_and_data[site] = _with_updated_outer(entry, outer_ds.assign({"H_bc": sensitivity})) return fp_and_data diff --git 
a/openghg_inversions/basis/_wrapper.py b/openghg_inversions/basis/_wrapper.py index 6d719e49..eadfa480 100644 --- a/openghg_inversions/basis/_wrapper.py +++ b/openghg_inversions/basis/_wrapper.py @@ -26,6 +26,7 @@ def basis_functions_wrapper( country_directory: str | None = None, outputname: str | None = None, output_path: str | None = None, + inner_domain: str | None = None, ): """Wrapper function for selecting basis function algorithm. @@ -81,6 +82,7 @@ def basis_functions_wrapper( Dictionary object similar to fp_all but with information on basis functions and sensitivities """ + inner_basis_data_array = None if use_bc is True and bc_basis_case is None: raise ValueError("If `use_bc` is True, you must specify `bc_basis_case`.") @@ -118,13 +120,39 @@ def basis_functions_wrapper( "Basis algorithm not recognised. Please use either 'quadtree' or 'weighted', or input a basis function file" ) from e print(f"Using {basis_function.description} to derive basis functions.") - basis_data_array = basis_function.algorithm(fp_all, start_date, domain, emissions_name, nbasis, country_directory=country_directory) + + if inner_domain is not None: + inner_basis_data_array = basis_function.algorithm( + fp_all=fp_all, + start_date=start_date, + domain=f"{domain}-{inner_domain}", + emissions_name=emissions_name, + nbasis=nbasis, + country_directory=country_directory, + scenario="inner", + ) + + print(f"Computing inner basis took {time() - basis_start}s.") + + basis_data_array = basis_function.algorithm( + fp_all=fp_all, + start_date=start_date, + domain=domain, + emissions_name=emissions_name, + nbasis=nbasis, + country_directory=country_directory, + ) print(f"Computing basis took {time() - basis_start}s.") fp_sens_start = time() - fp_data = fp_sensitivity(fp_all, basis_func=basis_data_array) - print(f"Computing fp sensitivity took {time() - fp_sens_start}s.") + if basis_data_array is not None: + fp_data = fp_sensitivity( + fp_all, + basis_func=basis_data_array, + 
inner_basis_func=inner_basis_data_array, + ) + print(f"Computing fp sensitivity took {time() - fp_sens_start}s.") if use_bc is True: bc_sens_start = time() diff --git a/openghg_inversions/filters.py b/openghg_inversions/filters.py index a9308fab..eb8c6163 100644 --- a/openghg_inversions/filters.py +++ b/openghg_inversions/filters.py @@ -26,7 +26,14 @@ # this dictionary will be populated by using the decorator `register_filter` filtering_functions = {} - +def _time_mask(mask: xr.DataArray, filter_name: str) -> xr.DataArray: + if "time" not in mask.dims: + raise ValueError(f"{filter_name} expects a 'time' dimension.") + extra_dims = [dim for dim in mask.dims if dim != "time"] + if extra_dims: + mask = mask.any(dim=extra_dims) + return mask + def register_filter(filt: Callable) -> Callable: """Decorator function to register filters. @@ -134,15 +141,46 @@ def filtering( for site in filters: if filters[site] is not None and site in sites: for filt in filters[site]: - n_nofilter = datasets[site].time.values.shape[0] - - datasets[site] = filtering_functions[filt](datasets[site], keep_missing=keep_missing) - - n_filter = datasets[site].time.values.shape[0] - n_dropped = n_nofilter - n_filter - perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) - print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") - if n_filter == 0: break # no values left, so we won't apply remaining filters + + site_entry = datasets[site] + + # --- DataTree handling --- + if isinstance(site_entry, xr.DataTree): + outer_ds = site_entry["standard"].ds if "standard" in site_entry.children else site_entry.ds + n_nofilter = outer_ds.time.values.shape[0] + + filtered_outer = filtering_functions[filt](outer_ds, keep_missing=keep_missing) + + n_filter = filtered_outer.time.values.shape[0] + n_dropped = n_nofilter - n_filter + perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) + print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") + + # Rebuild 
DataTree preserving the standard/inner layout. + standard_key = "/standard" if "standard" in site_entry.children else "/" + dt_dict = {standard_key: filtered_outer} + if "inner" in site_entry.children: + inner_ds = site_entry["inner"].ds + # Keep inner time axis aligned with the (now filtered) outer time axis + dt_dict["/inner"] = inner_ds.reindex( + time=filtered_outer.time, fill_value=0.0 + ) + datasets[site] = xr.DataTree.from_dict(dt_dict) + + if n_filter == 0: + break + + # --- original flat Dataset handling (unchanged) --- + else: + n_nofilter = datasets[site].time.values.shape[0] + + datasets[site] = filtering_functions[filt](datasets[site], keep_missing=keep_missing) + + n_filter = datasets[site].time.values.shape[0] + n_dropped = n_nofilter - n_filter + perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) + print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") + if n_filter == 0: break # no values left, so we won't apply remaining filters return datasets @@ -372,22 +410,17 @@ def pblh_min(dataset: xr.Dataset, pblh_threshold: float = 200.0, keep_missing: b """ pblh_da = dataset.PBLH if "PBLH" in dataset.data_vars else dataset.atmosphere_boundary_layer_thickness - ti = [i for i, pblh in enumerate(pblh_da) if pblh > pblh_threshold] - - if keep_missing is True: - mf_data_array = dataset.mf - dataset_temp = dataset.drop("mf") - - dataarray_temp = mf_data_array[dict(time=ti)] + filt = _time_mask(pblh_da > pblh_threshold, "pblh_min") - mf_ds = xr.Dataset( - {"mf": (["time"], dataarray_temp)}, coords={"time": (dataarray_temp.coords["time"])} - ) + # Some inputs can include extra dimensions; collapse to a per-time mask. 
+ if "time" not in filt.dims: + raise ValueError("PBLH filter expects a 'time' dimension.") + extra_dims = [dim for dim in filt.dims if dim != "time"] + if extra_dims: + filt = filt.any(dim=extra_dims) - dataset_out = combine_datasets(dataset_temp, mf_ds, method=None) - return dataset_out - else: - return dataset[dict(time=ti)] + drop = not keep_missing + return dataset.where(filt.compute(), drop=drop) @register_filter @@ -421,7 +454,7 @@ def pblh_inlet_diff( pblh_da = dataset.PBLH if "PBLH" in dataset.data_vars else dataset.atmosphere_boundary_layer_thickness - filt = pblh_da > inlet_height + diff_threshold + filt = _time_mask(pblh_da > inlet_height + diff_threshold, "pblh_inlet_diff") drop = not keep_missing return dataset.where(filt.compute(), drop=drop) diff --git a/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini b/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini index ce6aab27..0399d9ac 100644 --- a/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini +++ b/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini @@ -16,12 +16,12 @@ ; start_date (str): Start of observations to extract (format YYYY-MM-DD) ; end_date (str): End of observations to extract (format YYYY-MM-DD) (non-inclusive) -species = "" ; (required) +species = "ch4" ; (required) use_tracer = False ; (required) -sites = [] ; (required) -averaging_period = [] ; (required) -start_date = " " ; (required - but can be specified on command line instead) -end_date = " " ; (required - but can be specified on command line instead) +sites = ["mhd"] ; (required) +averaging_period = ["4h"] ; (required) +start_date = "2020-01-01" ; (required - but can be specified on command line instead) +end_date = "2024-12-31" ; (required - but can be specified on command line instead) ; save_merged_data (bool): If True, saves merged data object ; reload_merged_data (bool): If True, reads merged data object rather than rerunning get_data @@ -37,7 +37,7 @@ 
merged_data_dir = " " ; obs_data_level (list/str): Measurement data quality level ; filters (str): Data filtering approach to apply -inlet = None +inlet = "9m" instrument = None calibration_scale = None obs_data_level = None @@ -51,10 +51,13 @@ filters = [] ; footprint_store (str): Name of object store with footprints data ; emissions_store (str): Name of flux emissions object store -bc_store = " " ; (required) -obs_store = " " ; (required) -footprint_store = " " ; (required) -emissions_store = " " ; (required) +bc_store = "/group/chem/acrg/object_stores/shared_store_zarr" ; (required) +obs_store = "/group/chem/acrg/object_stores/paris/obs_icos_2025_08_store" ; (required) +footprint_store = "shared_store_zarr" ; (required) +emissions_store = "/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE " ; (required) +inner_footprint_store = "shared_store_zarr" +inner_emissions_store = "/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE_6km" + [INPUT.PRIORS] @@ -67,13 +70,14 @@ emissions_store = " " ; (required) ; emissions_name (list): Name of emissions sources as used when adding flux files to the object store ; bc_input (list/str): Name of boundary conditions data to use from object store -domain = " " ; (required) +domain = "europe" ; (required) +inner_domain = "6km" met_model = None fp_model = None -fp_height = None -fp_species = None -emissions_name = [None] ; (required) -bc_input = None +fp_height = "10m" +fp_species = "inert" +emissions_name = ["edgarv80_wetchartsv131"] ; (required) +bc_input = "camsv22r2_daily" [INPUT.BASIS_CASE] ; Input values to extract the basis cases to use within the inversion for boundary conditions and emissions @@ -89,14 +93,13 @@ bc_input = None ; country_file (str/None): Directory with filename containing the indices of country boundaries in domain ; country_directory (str/None): Directory containing auxiliary country files for deriving basis functions (land-sea mask, InTEM outer regions) -basis_algorithm = 
" " +basis_algorithm = "quadtree" bc_basis_case = "NESW" fp_basis_case = None -nbasis = -basis_directory = None -bc_basis_directory = None -country_file = " " -country_directory = None +nbasis = 250 +basis_directory = "/group/chem/acrg/LPDM/basis_functions/" +bc_basis_directory = "/group/chem/acrg/LPDM/bc_basis_functions/" +country_file = country_file = "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc" [MCMC.TYPE] ; Which MCMC setup to use. This defines the function which will be called and the expected inputs. @@ -173,9 +176,9 @@ sigma_per_site = True ; burn (int): Number of iterations to burn/discard in MCMC ; tune (int): Number of iterations to use to tune step size -nit = ; (required) -burn = ; (required) -tune = ; (required) +nit = 1000 ; (required) +burn = 100 ; (required) +tune = 1000 ; (required) [MCMC.NCHAIN] ; nchain (int): Number of chains to run simultaneously. Must be >=2 to allow convergence to be checked. @@ -222,5 +225,5 @@ no_model_error = False ; outputpath (str): Directory to write output ; outputname (str): Unique identifier for output/run name. -outputpath = " " ; (required) -outputname = " " ; (required) +outputpath = "/group/chem/acrg/prasad/job_scripts/openghg_inversions/inversion_outputs/paris_bc" ; (required) +outputname = "mhd_openghg_inversions" ; (required) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 85e51c17..b8ac7ffc 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -88,11 +88,31 @@ def make_inv_inputs( offset_args, power, ): + def _outer_ds(site_entry): + if isinstance(site_entry, xr.DataTree): + if "standard" in site_entry.children: + return site_entry["standard"].ds + return site_entry.ds + return site_entry + + def _with_updated_outer(site_entry, updated_outer_ds): + if not isinstance(site_entry, xr.DataTree): + return updated_outer_ds + + # Preserve empty-root DataTree structure by writing data to /standard. 
+ if "standard" in site_entry.children: + tree_dict = {f"/{name}": child.ds for name, child in site_entry.children.items() if child.ds is not None} + tree_dict["/standard"] = updated_outer_ds + return xr.DataTree.from_dict(tree_dict) + + return xr.DataTree(dataset=updated_outer_ds, children=dict(site_entry.children)) + # Trigger dask computations # we only compute the variables we need below to_compute = [ "H", "H_bc", + "H_inner", "mf", "mf_error", "mf_repeatability", @@ -103,8 +123,25 @@ def make_inv_inputs( "mf_mod", ] for site in sites: - to_compute_site = [dv for dv in to_compute if dv in fp_data[site].data_vars] - fp_data[site][to_compute_site] = fp_data[site][to_compute_site].compute() + site_entry = fp_data[site] + outer_ds = _outer_ds(site_entry) + + to_compute_outer = [dv for dv in to_compute if dv in outer_ds.data_vars] + if to_compute_outer: + computed_outer = outer_ds.assign({var: outer_ds[var].compute() for var in to_compute_outer}) + fp_data[site] = _with_updated_outer(site_entry, computed_outer) + + if isinstance(fp_data[site], xr.DataTree) and "inner" in fp_data[site].children: + inner_ds = fp_data[site]["inner"].ds + if "H_inner" in inner_ds.data_vars: + computed_inner = inner_ds.assign({"H_inner": inner_ds["H_inner"].compute()}) + tree_dict = { + f"/{name}": child.ds + for name, child in fp_data[site].children.items() + if child.ds is not None and name != "inner" + } + tree_dict["/inner"] = computed_inner + fp_data[site] = xr.DataTree.from_dict(tree_dict) # Get inputs ready error = np.zeros(0) @@ -116,6 +153,12 @@ def make_inv_inputs( Y = np.zeros(0) siteindicator = np.zeros(0) + Hx_inner = None + has_inner = any( + isinstance(fp_data[s], xr.DataTree) and "H_inner" in fp_data[s]["inner"].ds.data_vars + for s in sites if s in fp_data + ) + for si, site in enumerate(sites): # if site was dropped, skip; this makes the site indicator numbers consistent # even if a site is dropped @@ -125,39 +168,43 @@ def make_inv_inputs( # select variables to drop 
NaNs from drop_vars = [] for var in ["H", "H_bc", "mf", "mf_error"]: - if var in fp_data[site].data_vars: + if var in _outer_ds(fp_data[site]).data_vars: drop_vars.append(var) # pymc doesn't like NaNs, so drop them for the variables used below - fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) + # DataTree doesn't support dropna; use sel with valid time indices instead + if isinstance(fp_data[site], xr.DataTree): + valid_times = _outer_ds(fp_data[site]).dropna("time", subset=drop_vars).time + fp_data[site] = fp_data[site].sel(time=valid_times) + else: + fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) # repeatability/variability chosen/combined into mf_error in `get_data.py` - error = np.concatenate((error, fp_data[site].mf_error.values)) + ds = _outer_ds(fp_data[site]) - # make repeatability and variability for outputs (not used directly in inversions) - obs_repeatability = np.concatenate((obs_repeatability, fp_data[site].mf_repeatability.values)) - obs_variability = np.concatenate((obs_variability, fp_data[site].mf_variability.values)) + error = np.concatenate((error, ds["mf_error"].values)) + obs_repeatability = np.concatenate((obs_repeatability, ds["mf_repeatability"].values)) + obs_variability = np.concatenate((obs_variability, ds["mf_variability"].values)) + Y = np.concatenate((Y, ds["mf"].values)) - Y = np.concatenate((Y, fp_data[site].mf.values)) - if fp_data[site].attrs.get("inlet") == "column" or fp_data[site].attrs.get("platform") == "satellite": - obs_prior_factor = np.concatenate((obs_prior_factor, fp_data[site].mf_prior_factor.values)) + if ds.attrs.get("inlet") == "column" or ds.attrs.get("platform") == "satellite": + obs_prior_factor = np.concatenate((obs_prior_factor, ds["mf_prior_factor"].values)) obs_prior_upper_level_factor = np.concatenate( - (obs_prior_upper_level_factor, fp_data[site].mf_prior_upper_level_factor.values) + (obs_prior_upper_level_factor, ds["mf_prior_upper_level_factor"].values) ) else: - # If not a 
column/satellite measurement, set prior factors to zero - # This is required if there is mix of insitu and column measurements - obs_prior_factor = np.concatenate((obs_prior_factor, np.zeros(fp_data[site].mf.size))) + obs_prior_factor = np.concatenate((obs_prior_factor, np.zeros(ds["mf"].size))) obs_prior_upper_level_factor = np.concatenate( - (obs_prior_upper_level_factor, np.zeros(fp_data[site].mf.size)) + (obs_prior_upper_level_factor, np.zeros(ds["mf"].size)) ) - siteindicator = np.concatenate((siteindicator, np.ones_like(fp_data[site].mf.values) * si)) - if si == 0: - Ytime = fp_data[site].time.values - else: - Ytime = np.concatenate((Ytime, fp_data[site].time.values)) - Hx = fp_data[site].H.values if si == 0 else np.hstack((Hx, fp_data[site].H.values)) + siteindicator = np.concatenate((siteindicator, np.ones_like(ds["mf"].values) * si)) + Ytime = ds["time"].values if si == 0 else np.concatenate((Ytime, ds["time"].values)) + Hx = ds["H"].values if si == 0 else np.hstack((Hx, ds["H"].values)) + + if has_inner and isinstance(fp_data[site], xr.DataTree) and "inner" in fp_data[site].children: + h_inner = fp_data[site]["inner"].ds["H_inner"].values + Hx_inner = h_inner if Hx_inner is None else np.hstack((Hx_inner, h_inner)) if np.isnan(Hx).any(): warnings.warn(f"Hx matrix contains {np.isnan(Hx).flatten().sum()} NaN values") @@ -194,6 +241,7 @@ def make_inv_inputs( mcmc_args = { "Hx": Hx, + "Hx_inner": Hx_inner if Hx_inner is not None else None, "Y": Y, "error": error, "siteindicator": siteindicator, @@ -220,7 +268,7 @@ def make_inv_inputs( if bc_freq == "monthly": Hmbc = setup.monthly_bcs(start_date, end_date, site, fp_data) elif bc_freq is None: - Hmbc = fp_data[site].H_bc.values + Hmbc = _outer_ds(fp_data[site])["H_bc"].values else: Hmbc = setup.create_bc_sensitivity(start_date, end_date, site, fp_data, bc_freq) @@ -262,6 +310,8 @@ def fixedbasisMCMC( obs_store: str = "user", footprint_store: str = "user", emissions_store: str = "user", + inner_footprint_store: 
str = "user", + inner_emissions_store: str = "user", met_model: list | None = None, fp_model: str | None = None, # Changed to none. When "NAME" specified FPs are not found fp_height: list[str] | None = None, @@ -282,6 +332,7 @@ def fixedbasisMCMC( country_directory: str | None = None, country_file: str | None = None, bc_input: str | None = None, + inner_domain: str | None = None, basis_algorithm: str = "weighted", nbasis: int = 100, xprior: dict = {"pdf": "truncatednormal", "mu": 1.0, "sigma": 1.0, "lower": 0.0}, @@ -314,7 +365,7 @@ def fixedbasisMCMC( min_error_options: dict | None = None, output_format: Literal[ "hbmcmc", "paris", "basic", "merged_data", "inv_out", "mcmc_args", "mcmc_results" - ] = "hbmcmc", + ] = "paris", paris_postprocessing: bool = False, paris_postprocessing_kwargs: dict | None = None, power: dict | float = 1.99, @@ -501,6 +552,20 @@ def fixedbasisMCMC( print("Successfully read in merged data.\n") rerun_merge = False + if inner_domain is not None: + invalid_inner_sites = [ + s + for s, entry in fp_all.items() + if not s.startswith(".") + and not (isinstance(entry, xr.DataTree) and "inner" in entry.children) + ] + if invalid_inner_sites: + rerun_merge = True + print( + "Loaded merged data does not contain inner-domain DataTree entries " + f"for sites {invalid_inner_sites}; re-running data merge." + ) + # check if sites were dropped when merged data was saved sites_merged = [s for s in fp_all if "." 
not in s] @@ -538,6 +603,7 @@ def fixedbasisMCMC( species=species, sites=sites, domain=domain, + inner_domain=inner_domain, averaging_period=averaging_period, start_date=start_date, end_date=end_date, @@ -558,6 +624,8 @@ def fixedbasisMCMC( obs_store=obs_store, footprint_store=footprint_store, emissions_store=emissions_store, + inner_footprint_store=inner_footprint_store, + inner_emissions_store=inner_emissions_store, averagingerror=averaging_error, save_merged_data=save_merged_data, merged_data_name=merged_data_name, @@ -584,13 +652,14 @@ def fixedbasisMCMC( use_bc=use_bc, species=species, domain=domain, + inner_domain=inner_domain, start_date=start_date, fix_outer_regions=fix_basis_outer_regions, emissions_name=emissions_name, outputname=outputname, output_path=basis_output_path, ) - + print(f"Basis functions applied to data {domain}.\n ") # Apply named filters to the data if filters is not None: try: @@ -601,14 +670,26 @@ def fixedbasisMCMC( # # Apply compute before filtering to avoid dask issue for site in sites: - fp_data[site] = fp_data[site].compute() + entry = fp_data[site] + if isinstance(entry, xr.DataTree): + # compute dask arrays in standard and inner children + dt_dict = {} + if "standard" in entry.children: + dt_dict["/standard"] = entry["standard"].ds.compute() + elif entry.ds is not None: + dt_dict["/"] = entry.ds.compute() + if "inner" in entry.children: + dt_dict["/inner"] = entry["inner"].ds.compute() + fp_data[site] = xr.DataTree.from_dict(dt_dict) + else: + fp_data[site] = entry.compute() fp_data = filtering(fp_data, filters) - # check for sites dropped by filtering dropped_sites = [] for site in sites: # check if some datasets are empty due to filtering - if fp_data[site].time.values.shape[0] == 0: + site_ds = fp_data[site]["standard"].ds if isinstance(fp_data[site], xr.DataTree) and "standard" in fp_data[site].children else fp_data[site].ds + if site_ds.time.values.shape[0] == 0: dropped_sites.append(site) del fp_data[site] @@ -619,6 +700,22 
@@ def fixedbasisMCMC( for si, site in enumerate(sites): fp_data[site].attrs["Domain"] = domain + if inner_domain is not None: + invalid_inner_sites = [ + site + for site in sites + if not ( + isinstance(fp_data[site], xr.DataTree) + and "inner" in fp_data[site].children + and "H_inner" in fp_data[site]["inner"].ds.data_vars + ) + ] + if invalid_inner_sites: + raise ValueError( + "Inner-domain inversion requires DataTree entries with computed H_inner for all sites. " + f"Missing or invalid inner structure for: {invalid_inner_sites}." + ) + # Inverse models if use_tracer: raise ValueError("Model does not currently include tracer model. Watch this space") @@ -686,6 +783,8 @@ def fixedbasisMCMC( del post_process_args["power"] # add any additional kwargs to mcmc_args (these aren't needed for post processing) + kwargs.pop("mcmc_type", None) + mcmc_args.update(kwargs) end_data = time.time() diff --git a/openghg_inversions/hbmcmc/inversion_pymc.py b/openghg_inversions/hbmcmc/inversion_pymc.py index e2d480cd..60b4214a 100644 --- a/openghg_inversions/hbmcmc/inversion_pymc.py +++ b/openghg_inversions/hbmcmc/inversion_pymc.py @@ -121,6 +121,7 @@ def _make_coords( Hbc: np.ndarray | None = None, sites: list[str] | None = None, sigma_per_site: bool = False, + Hx_inner: np.ndarray | None = None, ) -> dict: result = { "nmeasure": np.arange(len(Y)), @@ -129,6 +130,8 @@ def _make_coords( "nsigma_time": np.unique(sigma_freq_indices), "nsigma_site": np.unique(site_indicator) if sigma_per_site else [0], } + if Hx_inner is not None: + result["nx_inner"] = np.arange(Hx_inner.shape[0]) if Hbc is not None: result["nbc"] = np.arange(Hbc.shape[0]) return result @@ -141,6 +144,7 @@ def inferpymc( siteindicator: np.ndarray, sigma_freq_index: np.ndarray, Hbc: np.ndarray | None = None, + Hx_inner: np.ndarray | None = None, xprior: dict = {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, bcprior: dict = {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, sigprior: dict = {"pdf": "uniform", "lower": 0.1, 
"upper": 3.0}, @@ -316,8 +320,19 @@ def inferpymc( sigma = parse_prior("sigma", sigprior, dims=("nsigma_site", "nsigma_time")) hx = pm.Data("hx", hx, dims=("nmeasure", "nx")) - mu = pm.Deterministic("mu", pt.dot(hx, x), dims="nmeasure") + if Hx_inner is not None: + hx_inner = pm.Data("hx_inner", Hx_inner.T, dims=("nmeasure", "nx_inner")) + x_inner = parse_prior("x_inner", xprior, dims="nx_inner") + step1_vars.append(x_inner) + mu = pm.Deterministic( + "mu", + pt.dot(hx, x) + pt.dot(hx_inner, x_inner), + dims="nmeasure", + ) + else: + mu = pm.Deterministic("mu", pt.dot(hx, x), dims="nmeasure") + if use_bc: hbc = pm.Data("hbc", hbc, dims=("nmeasure", "nbc")) mu_bc = pm.Deterministic("mu_bc", pt.dot(hbc, bc), dims="nmeasure") @@ -425,7 +440,9 @@ def inferpymc( "model": model, "trace": trace, } - + if Hx_inner is not None: + result["xouts_inner"] = posterior_burned.x_inner + if use_bc: result["bcouts"] = bcouts result["YBCtrace"] = YBCtrace.values.T @@ -475,6 +492,8 @@ def inferpymc_postprocessouts( rerun_file: xr.Dataset | None = None, use_bc: bool = False, min_error: float | np.ndarray = 0.0, + xouts_inner: np.ndarray | None = None, + Hx_inner: np.ndarray | None = None, ) -> xr.Dataset: r"""Take the output from inferpymc function along with other input information. @@ -491,6 +510,7 @@ def inferpymc_postprocessouts( Args: xouts: MCMC chain for emissions scaling factors for each basis function. + xouts_inner: MCMC chain for emissions scaling factors for each inner-domain basis function. sigouts: MCMC chain for model error. convergence: Passed/Failed convergence test as to whether multiple chains have a Gelman-Rubin diagnostic value <1.05. @@ -498,6 +518,7 @@ def inferpymc_postprocessouts( This is the same as what is given from fp_data[site].H.values, where fp_data is the output from e.g. footprint_data_merge, but where it has been stacked for all sites. + Hx_inner: Inner-domain sensitivity matrix with the same orientation as Hx. 
Y: Measurement vector containing all measurements. error: Measurement error vector, containing a value for each element of Y. Ytrace: Trace of modelled y values calculated from mcmc outputs and H matrices. @@ -565,6 +586,12 @@ def inferpymc_postprocessouts( nit = xouts.shape[0] nx = Hx.shape[0] ny = len(Y) + if xouts_inner is not None: + nx_inner = xouts_inner.shape[1] + elif Hx_inner is not None: + nx_inner = Hx_inner.shape[0] + else: + nx_inner = 0 if use_bc: nbc = Hbc.shape[0] @@ -574,6 +601,8 @@ def inferpymc_postprocessouts( steps = np.arange(nit) nmeasure = np.arange(ny) nparam = np.arange(nx) + if nx_inner > 0: + nparam_inner = np.arange(nx_inner) # OFFSET HYPERPARAMETER YmodmuOFF = np.mean(OFFSETtrace, axis=1) # mean @@ -636,6 +665,9 @@ def inferpymc_postprocessouts( else: Yapriori = np.sum(Hx.T, axis=1) + if Hx_inner is not None: + Yapriori = Yapriori + np.sum(Hx_inner.T, axis=1) + sitenum = np.arange(len(sites)) if fp_data is None and rerun_file is not None: @@ -770,6 +802,26 @@ def inferpymc_postprocessouts( min_error = min_error * np.ones_like(Y) # Make output netcdf file + xouts_values = xouts.values if hasattr(xouts, "values") else xouts + sigouts_values = sigouts.values if hasattr(sigouts, "values") else sigouts + xouts_inner_values = xouts_inner.values if hasattr(xouts_inner, "values") else xouts_inner + + if xouts_inner_values is not None: + xtrace_values = np.concatenate([xouts_inner_values, xouts_values], axis=1) + xsensitivity_values = np.concatenate([Hx_inner.T, Hx.T], axis=1) if Hx_inner is not None else Hx.T + nparam_values = np.arange(nx + nx_inner) + param_index = np.concatenate([np.arange(nx_inner), np.arange(nx)]) + param_domain = np.concatenate([ + np.full(nx_inner, "inner", dtype="U5"), + np.full(nx, "outer", dtype="U5"), + ]) + else: + xtrace_values = xouts_values + xsensitivity_values = Hx.T + nparam_values = np.arange(nx) + param_index = np.arange(nx) + param_domain = np.full(nx, "outer", dtype="U5") + data_vars = { "Yobs": 
(["nmeasure"], Y), "Yerror": (["nmeasure"], error), @@ -788,8 +840,8 @@ def inferpymc_postprocessouts( "Yoffmode": (["nmeasure"], YmodmodeOFF), "Yoff68": (["nmeasure", "nUI"], Ymod68OFF), "Yoff95": (["nmeasure", "nUI"], Ymod95OFF), - "xtrace": (["steps", "nparam"], xouts.values), - "sigtrace": (["steps", "nsigma_site", "nsigma_time"], sigouts.values), + "xtrace": (["steps", "nparam"], xtrace_values), + "sigtrace": (["steps", "nsigma_site", "nsigma_time"], sigouts_values), "siteindicator": (["nmeasure"], siteindicator), "sigmafreqindex": (["nmeasure"], sigma_freq_index), "sitenames": (["nsite"], sites), @@ -808,9 +860,12 @@ def inferpymc_postprocessouts( "country95": (["countrynames", "nUI"], cntry95), "countryapriori": (["countrynames"], cntryprior), "countrydefinition": (["lat", "lon"], cntrygrid), - "xsensitivity": (["nmeasure", "nparam"], Hx.T), + "xsensitivity": (["nmeasure", "nparam"], xsensitivity_values), } + if Hx_inner is not None: + data_vars["xsensitivity_inner"] = (["nmeasure", "nparam_inner"], Hx_inner.T) + coords = { "stepnum": (["steps"], steps), "paramnum": (["nlatent"], nparam), @@ -819,11 +874,17 @@ def inferpymc_postprocessouts( "nsites": (["nsite"], sitenum), "nsigma_time": (["nsigma_time"], np.unique(sigma_freq_index)), "nsigma_site": (["nsigma_site"], np.arange(sigouts.shape[1]).astype(int)), + "nparam": (["nparam"], nparam_values), + "param_index": (["nparam"], param_index), + "param_domain": (["nparam"], param_domain), "lat": (["lat"], lat), "lon": (["lon"], lon), "countrynames": (["countrynames"], cntrynames), } + if nx_inner > 0: + coords["nparam_inner"] = (["nparam_inner"], nparam_inner) + if use_bc: data_vars.update( { @@ -868,6 +929,8 @@ def inferpymc_postprocessouts( outds.countryapriori.attrs["units"] = country_units outds.xsensitivity.attrs["units"] = obs_units + " " + "mol/mol" outds.sigtrace.attrs["units"] = obs_units + " " + "mol/mol" + if "xsensitivity_inner" in outds: + outds.xsensitivity_inner.attrs["units"] = obs_units + " " + 
"mol/mol" outds.Yobs.attrs["longname"] = "observations" outds.Yerror.attrs["longname"] = "measurement error" @@ -888,7 +951,11 @@ def inferpymc_postprocessouts( outds.Yoff95.attrs["longname"] = ( " 0.95 Bayesian credible interval of posterior simulated offset between measurements" ) - outds.xtrace.attrs["longname"] = "trace of unitless scaling factors for emissions parameters" + outds.xtrace.attrs["longname"] = ( + "trace of unitless scaling factors for combined outer+inner emissions parameters" + ) + outds.param_index.attrs["longname"] = "parameter index within each domain for emissions trace" + outds.param_domain.attrs["longname"] = "domain label (outer or inner) for emissions parameters" outds.sigtrace.attrs["longname"] = "trace of model error parameters" outds.siteindicator.attrs["longname"] = "index of site of measurement corresponding to sitenames" outds.sigmafreqindex.attrs["longname"] = "perdiod over which the model error is estimated" @@ -909,6 +976,8 @@ def inferpymc_postprocessouts( outds.countryapriori.attrs["longname"] = "prior mean of ocean and country totals" outds.countrydefinition.attrs["longname"] = "grid definition of countries" outds.xsensitivity.attrs["longname"] = "emissions sensitivity timeseries" + if "xsensitivity_inner" in outds: + outds.xsensitivity_inner.attrs["longname"] = "inner-domain emissions sensitivity timeseries" if use_bc: outds.YmodmeanBC.attrs["units"] = obs_units + " " + "mol/mol" diff --git a/openghg_inversions/hbmcmc/inversionsetup.py b/openghg_inversions/hbmcmc/inversionsetup.py index 08a79ad8..4795aa61 100644 --- a/openghg_inversions/hbmcmc/inversionsetup.py +++ b/openghg_inversions/hbmcmc/inversionsetup.py @@ -2,6 +2,37 @@ import numpy as np import pandas as pd +import xarray as xr + + +def _site_ds(fp_data: dict, site: str): + entry = fp_data[site] + + if isinstance(entry, xr.DataTree): + if "standard" in entry.children and entry["standard"].ds is not None: + return entry["standard"].ds + + if entry.ds is not None: + 
return entry.ds + + # Fallback: use first non-empty child dataset if present. + for child in entry.children.values(): + if child.ds is not None: + return child.ds + + raise ValueError(f"Site '{site}' DataTree does not contain a dataset in root or child nodes.") + + return entry + + +def _site_hbc(fp_data: dict, site: str) -> xr.DataArray: + site_ds = _site_ds(fp_data, site) + if "H_bc" not in site_ds.data_vars: + raise ValueError( + f"Boundary-condition sensitivity 'H_bc' is missing for site '{site}'. " + f"Available variables: {list(site_ds.data_vars)}" + ) + return site_ds["H_bc"] def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np.ndarray: @@ -23,12 +54,14 @@ def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np. hmbc: Sensitivity matrix by month for observations """ + site_ds = _site_ds(fp_data, site) + h_bc = _site_hbc(fp_data, site).values allmonth = pd.date_range(start_date, end_date, freq="MS")[:-1] nmonth = len(allmonth) - curtime = pd.to_datetime(fp_data[site].time.values).to_period("M") - pmonth = pd.to_datetime(fp_data[site].resample(time="MS").mean().time.values) - nregions = fp_data[site].sizes["bc_region"] - hmbc = np.zeros((nregions * nmonth, len(fp_data[site].time.values))) + curtime = pd.to_datetime(site_ds.time.values).to_period("M") + pmonth = pd.to_datetime(site_ds.resample(time="MS").mean().time.values) + nregions = site_ds.sizes["bc_region"] + hmbc = np.zeros((nregions * nmonth, len(site_ds.time.values))) count = 0 for cord in range(nregions): for m in range(0, nmonth): @@ -38,7 +71,7 @@ def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np. 
mnth = allmonth[m].month yr = allmonth[m].year mnthloc = np.where(np.logical_and(curtime.month == mnth, curtime.year == yr))[0] - hmbc[count, mnthloc] = fp_data[site].H_bc.values[cord, mnthloc] + hmbc[count, mnthloc] = h_bc[cord, mnthloc] count += 1 return hmbc @@ -70,14 +103,16 @@ def create_bc_sensitivity(start_date: str, end_date: str, site: str, fp_data: di hmbc: Sensitivity matrix by for observations to boundary conditions """ + site_ds = _site_ds(fp_data, site) + h_bc = _site_hbc(fp_data, site).values dys = int("".join([s for s in freq if s.isdigit()])) alldates = pd.date_range( pd.to_datetime(start_date), pd.to_datetime(end_date) + pd.DateOffset(days=dys), freq=freq ) ndates = np.sum(alldates < pd.to_datetime(end_date)) - curdates = fp_data[site].time.values - nregions = fp_data[site].sizes["bc_region"] - hmbc = np.zeros((nregions * ndates, len(fp_data[site].time.values))) + curdates = site_ds.time.values + nregions = site_ds.sizes["bc_region"] + hmbc = np.zeros((nregions * ndates, len(site_ds.time.values))) count = 0 for cord in range(nregions): for m in range(0, ndates): @@ -89,7 +124,7 @@ def create_bc_sensitivity(start_date: str, end_date: str, site: str, fp_data: di if len(dateloc) == 0: count += 1 continue - hmbc[count, dateloc] = fp_data[site].H_bc.values[cord, dateloc] + hmbc[count, dateloc] = h_bc[cord, dateloc] count += 1 return hmbc diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index a0fd1546..cdeb5da4 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -60,69 +60,69 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr """ # TODO: do we want to fill missing values in repeatability or variability? 
for site in sites: - ds = fp_all[site] - - variability_missing = False - if "mf_variability" not in ds: - ds["mf_variability"] = xr.zeros_like(ds.mf) - ds["mf_variability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_variability" - variability_missing = True - - if "mf_repeatability" not in ds: - if variability_missing: - raise ValueError(f"Obs data for site {site} is missing both repeatability and variability.") - - ds["mf_repeatability"] = xr.zeros_like(ds.mf_variability) - ds["mf_repeatability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_repeatability" - - ds["mf_error"] = ds["mf_variability"] - - if add_averaging_error: - logger.info( - "`mf_repeatability` not present; using `mf_variability` for `mf_error` at site %s", site + for node_name in ["standard", "inner"]: + if node_name not in fp_all[site]: + continue + ds = fp_all[site][node_name].ds.copy() + variability_missing = False + if "mf_variability" not in ds: + ds["mf_variability"] = xr.zeros_like(ds.mf) + ds["mf_variability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_variability" + variability_missing = True + + if "mf_repeatability" not in ds: + if variability_missing: + raise ValueError(f"Obs data for site {site} ({node_name}) is missing both repeatability and variability.") + + ds["mf_repeatability"] = xr.zeros_like(ds.mf_variability) + ds["mf_repeatability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_repeatability" + + ds["mf_error"] = ds["mf_variability"] + + if add_averaging_error: + logger.info( + "`mf_repeatability` not present; using `mf_variability` for `mf_error` at site %s (%s)", site, node_name + ) + + elif add_averaging_error: + ds["mf_error"] = np.sqrt( + ds["mf_repeatability"].fillna(0) ** 2 + ds["mf_variability"].fillna(0) ** 2 + ) + else: + ds["mf_error"] = ds["mf_repeatability"] + + ds["mf_error"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_error" + ds["mf_error"].attrs["units"] = ds.mf.attrs.get("units", None) + + 
err0 = (ds["mf_error"] == 0) | (ds["mf_error"].isnull()) + + if err0.any(): + percent0 = 100 * err0.mean() + logger.warning( + ( + "`mf_error` is zero/nan for %.2f percent of times at site %s (%s);" + "filling with max(median(mf_error), std(mf))." + ), + percent0, + site, + node_name, ) - elif add_averaging_error: - # Fill with zeros so that if one of repeatability and variability is not NaN, then mf_error will not be NaN. - ds["mf_error"] = np.sqrt( - ds["mf_repeatability"].fillna(0) ** 2 + ds["mf_variability"].fillna(0) ** 2 - ) - else: - ds["mf_error"] = ds["mf_repeatability"] - - ds["mf_error"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_error" - ds["mf_error"].attrs["units"] = ds.mf.attrs.get("units", None) - - # warnings/info for debugging - err0 = (ds["mf_error"] == 0) | ( - ds["mf_error"].isnull() - ) # might have NaN if add_averaging_error is False - - if err0.any(): - percent0 = 100 * err0.mean() - logger.warning( - ( - "`mf_error` is zero/nan for %.2f percent of times at site %s;" - "filling with max(median(mf_error), std(mf))." - ), - percent0, - site, - ) - - mf_err_da = ds["mf_error"].as_numpy() # load into memory to avoid Dask issues - fill_value = np.nanmax( - [ - mf_err_da.where(mf_err_da != 0).dropna(dim="time").median(), - ds["mf"].std(dim="time"), - ] - ) - ds["mf_error"] = mf_err_da.where(mf_err_da != 0, fill_value) - info_msg = ( - "If `averaging_period` matches the frequency of the obs data, then `mf_variability` " - "will be zero. Try setting `averaging_period = None`." - ) - logger.info(info_msg) + mf_err_da = ds["mf_error"].as_numpy() + fill_value = np.nanmax( + [ + mf_err_da.where(mf_err_da != 0).dropna(dim="time").median(), + ds["mf"].std(dim="time"), + ] + ) + ds["mf_error"] = mf_err_da.where(mf_err_da != 0, fill_value) + info_msg = ( + "If `averaging_period` matches the frequency of the obs data, then `mf_variability` " + "will be zero. Try setting `averaging_period = None`." 
+ ) + logger.info(info_msg) + fp_all[site][node_name].ds = ds def convert_to_list( x: list[str | None] | str | None, length: int, name: str | None = None @@ -152,6 +152,108 @@ def convert_to_list( return x +def _interp_bool_mask_to_grid(mask: xr.DataArray, lat: xr.DataArray, lon: xr.DataArray) -> xr.DataArray: + """Interpolate a boolean mask to a target lat/lon grid. + + xarray's interp requires numeric dtype, so convert to float, interpolate, + then threshold back to bool. + """ + return ( + mask.astype(float) + .interp(lat=lat, lon=lon, method="nearest") + .assign_coords(lat=lat, lon=lon) + .fillna(0.0) + >= 0.5 + ) + + +def _apply_inner_mask_on_standard_domain(standard_footprint_data: xr.Dataset, inner_footprint_data: xr.Dataset): + """Mask standard-domain footprint values where inner-domain coverage exists. + + Args: + standard_footprint_data: Standard-domain footprint dataset. + inner_footprint_data: Inner-domain footprint dataset. + + Returns: + Copy of ``standard_footprint_data`` with ``fp`` zeroed in cells covered by the inner domain. + """ + fp_standard = standard_footprint_data.copy() + + if inner_footprint_data is None: + return fp_standard + + if "fp" not in fp_standard or "fp" not in inner_footprint_data: + return fp_standard + + inner_has_coverage = (inner_footprint_data["fp"] != 0).any("time") + inner_on_standard = _interp_bool_mask_to_grid( + mask=inner_has_coverage, + lat=fp_standard.lat, + lon=fp_standard.lon, + ) + + fp_standard["fp"] = fp_standard["fp"].where(~inner_on_standard, other=0.0) + return fp_standard + + +def _apply_inner_mask_on_standard_flux( + standard_flux_dict: dict, + inner_footprint_data: xr.Dataset, +) -> dict: + """Mask standard-domain flux values where inner-domain coverage exists. + + Args: + standard_flux_dict: Dict of FluxData-like objects with ``data.flux`` DataArray. + inner_footprint_data: Inner-domain footprint dataset. + + Returns: + New flux dict with ``flux`` zeroed in cells covered by the inner domain. 
+ """ + if inner_footprint_data is None or "fp" not in inner_footprint_data: + return standard_flux_dict + + first_flux = next(iter(standard_flux_dict.values())).data.flux + inner_has_coverage = (inner_footprint_data["fp"] != 0).any("time") + inner_on_flux = _interp_bool_mask_to_grid( + mask=inner_has_coverage, + lat=first_flux.lat, + lon=first_flux.lon, + ) + + return _apply_boolean_mask_on_standard_flux(standard_flux_dict, inner_on_flux) + + +def _apply_boolean_mask_on_standard_flux( + standard_flux_dict: dict, + mask_on_standard_grid: xr.DataArray, +) -> dict: + """Apply a precomputed boolean mask to all standard-domain flux sources. + + Args: + standard_flux_dict: Dict of FluxData-like objects with ``data.flux`` DataArray. + mask_on_standard_grid: Boolean mask on standard-domain lat/lon grid where True means "inner-domain coverage". + + Returns: + New flux dict with masked ``flux`` fields. + """ + masked_flux_dict = {} + + for source, flux_data in standard_flux_dict.items(): + flux_da = flux_data.data.flux + mask_for_flux = _interp_bool_mask_to_grid( + mask=mask_on_standard_grid, + lat=flux_da.lat, + lon=flux_da.lon, + ) + + masked_flux = flux_da.where(~mask_for_flux, other=0.0) + masked_ds = flux_data.data.copy() + masked_ds["flux"] = masked_flux + masked_flux_dict[source] = type(flux_data)(data=masked_ds, metadata=flux_data.metadata) + + return masked_flux_dict + + def data_processing_surface_notracer( species: str, sites: list | str, @@ -181,6 +283,9 @@ def data_processing_surface_notracer( merged_data_name: str | None = None, merged_data_dir: str | None = None, output_name: str | None = None, + inner_domain: str | None = None, + inner_footprint_store: str | list[str] | None = None, + inner_emissions_store: str | None = None, ) -> tuple[dict, list, list, list, list, list]: """Retrieve and prepare fixed-surface datasets from specified OpenGHG object stores. 
@@ -257,15 +362,27 @@ def data_processing_surface_notracer( raise ValueError("`emissions_name` must be specified") flux_dict = get_flux_data( - sources=emissions_name, - species=species, - domain=domain, - start_date=start_date, - end_date=end_date, - store=emissions_store, - ) + sources=emissions_name, + species=species, + domain=domain, + start_date=start_date, + end_date=end_date, + store=emissions_store, + ) + fp_all[".flux"] = flux_dict + if inner_emissions_store is not None: + inner_flux_dict = get_flux_data( + sources=emissions_name, + species=species, + domain=f"{domain}-{inner_domain}", + start_date=start_date, + end_date=end_date, + store=inner_emissions_store, + ) + fp_all[".inner_flux"] = inner_flux_dict + # Get BC data if use_bc is True: try: @@ -289,6 +406,7 @@ def data_processing_surface_notracer( check_scales = set() units = {} site_indices_to_keep = [] + standard_flux_inner_mask_union = None keep_variables = [ f"{species}", @@ -328,7 +446,8 @@ def data_processing_surface_notracer( continue # Get footprints data - footprint_data = get_footprint_data( + + standard_footprint_data = get_footprint_data( site=site, domain=domain, platform=platform[i], @@ -342,25 +461,83 @@ def data_processing_surface_notracer( obs_data=site_data, stores=footprint_store, ) - if footprint_data is None: + if standard_footprint_data is None: print( f"\nNo footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}.", f"Check these values.\nContinuing model run without {site}.\n", ) continue # skip this site + inner_footprint_data = None + standard_flux_for_site = flux_dict + if inner_domain is not None: + print(f"Inner domain {inner_domain} specified; attempting to retrieve inner footprint data for {site} ...") + inner_footprint_data = get_footprint_data( + site=site, + domain=f"{domain}-{inner_domain}", + platform=platform[i], + fp_height=fp_height[i], + start_date=start_date, + end_date=end_date, + model=fp_model, + 
met_model=met_model[i], + fp_species=fp_species, + averaging_period=averaging_period[i], + obs_data=site_data, + stores=inner_footprint_store if inner_footprint_store is not None else footprint_store, + ) + if inner_footprint_data is None: + print( + f"\nNo Inner footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}-{inner_domain}, starting from {start_date} to {end_date}, fp_species {fp_species} and met_model {met_model[i]}, obs_data {site_data} and averaging_period {averaging_period[i]}.", + f"Check these values.\nContinuing model run without {site}.Jai\n", + ) + continue # skip this site + # Mask the standard domain fp_x_flux with the inner domain mask BEFORE storing in DataTree + standard_footprint_data.data = _apply_inner_mask_on_standard_domain( + standard_footprint_data=standard_footprint_data.data, + inner_footprint_data=inner_footprint_data.data, + ) + standard_flux_for_site = _apply_inner_mask_on_standard_flux( + standard_flux_dict=flux_dict, + inner_footprint_data=inner_footprint_data.data, + ) + + first_flux = next(iter(flux_dict.values())).data.flux + inner_has_coverage = (inner_footprint_data.data["fp"] != 0).any("time") + inner_on_standard_flux = _interp_bool_mask_to_grid( + mask=inner_has_coverage, + lat=first_flux.lat, + lon=first_flux.lon, + ) + if standard_flux_inner_mask_union is None: + standard_flux_inner_mask_union = inner_on_standard_flux + else: + standard_flux_inner_mask_union = standard_flux_inner_mask_union | inner_on_standard_flux scenario_combined = merged_scenario_data( - site_data, footprint_data, flux_dict, bc_data, platform=platform[i], max_level=max_level + obs_data=site_data, footprint_data=standard_footprint_data, + flux_dict=standard_flux_for_site, bc_data=bc_data, inner_footprint_data=inner_footprint_data, + inner_flux_dict= inner_flux_dict if inner_emissions_store is not None else None, + platform=platform[i], max_level=max_level ) fp_all[site] = scenario_combined - 
units[site] = scenario_combined.mf.attrs.get("units") + # scenario_combined returns a datatree. The standard_domain scenario is stored at the root, and any inner domain scenario is stored at "/{inner_domain}". Here we take the root scenario to store the calibration scale and units, since these should be the same for the inner domain scenario. If inner domain scenario is None the root scenario is the only scenario, so this will still work. + + root_ds = scenario_combined["standard"].ds + units[site] = root_ds.mf.attrs.get("units") if "satellite" not in platform: - scales[site] = scenario_combined.scale - check_scales.add(scenario_combined.scale) + scales[site] = root_ds.scale + check_scales.add(root_ds.scale) site_indices_to_keep.append(i) + + if standard_flux_inner_mask_union is not None: + fp_all[".flux"] = _apply_boolean_mask_on_standard_flux( + standard_flux_dict=flux_dict, + mask_on_standard_grid=standard_flux_inner_mask_union, + ) + if len(site_indices_to_keep) == 0: raise SearchError("No site data found. Exiting process.") @@ -412,3 +589,4 @@ def data_processing_surface_notracer( print(f"\nfp_all saved in {merged_data_dir}\n") return fp_all, sites, inlet, fp_height, instrument, averaging_period + diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 35eb596d..b8dc26e2 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -3,12 +3,72 @@ from openghg.analyse import ModelScenario from openghg.dataobjects import ObsData, BoundaryConditionsData, FluxData, FootprintData +def _mask_flux_to_inner_domain( + flux_dict: dict[str, FluxData], + inner_footprint_data: FootprintData, +) -> dict[str, FluxData]: + """Mask EUROPE-domain flux values to zero outside the inner footprint extent. + + The inner footprint (``inner_fp``) is non-zero only within the inner + domain (e.g. 6 km grid). 
We use this spatial coverage to build a + boolean mask, regrid the EUROPE flux to the inner footprint lat/lon + coordinates, and then zero out any flux cells that fall outside the + inner domain extent. + + The masked flux is returned as a new dict of ``FluxData`` objects so + that ``ModelScenario`` can use it unmodified to compute ``fp_x_flux`` + on the correct inner grid. + + Args: + flux_dict: EUROPE-domain flux, keyed by source name. + inner_footprint_data: FootprintData for the inner domain whose + raw ``fp`` defines the spatial extent and lat/lon grid. + + Returns: + New dict of ``FluxData`` with flux regridded to the inner grid + and zeroed outside the inner domain footprint coverage. + """ + # Inner fp: dims (time, lat, lon) on the inner (e.g. 6 km) grid. + inner_fp: xr.DataArray = inner_footprint_data.data.fp + + # Boolean mask: True where inner domain has any non-zero fp at any time. + inner_domain_mask: xr.DataArray = (inner_fp != 0).any("time") # (lat, lon) + + inner_lat = inner_fp.lat + inner_lon = inner_fp.lon + + masked_flux_dict: dict[str, FluxData] = {} + + for source, flux_data in flux_dict.items(): + flux_da: xr.DataArray = flux_data.data.flux # EUROPE grid (time, lat, lon) + + # 1. Regrid EUROPE flux to inner footprint lat/lon grid + flux_on_inner = ( + flux_da + .interp(lat=inner_lat, lon=inner_lon, method="nearest") + .assign_coords(lat=inner_lat, lon=inner_lon) + .fillna(0.0) + ) + + # 2. Mask: zero out flux cells outside the inner domain extent + flux_masked = flux_on_inner.where(inner_domain_mask, other=0.0) + + # 3. Build a new FluxData with the masked flux dataset, + # preserving all original metadata and dataset attributes. 
+ masked_ds = flux_data.data.copy() + masked_ds["flux"] = flux_masked + + masked_flux_dict[source] = FluxData(data=masked_ds, metadata=flux_data.metadata) + + return masked_flux_dict def merged_scenario_data( obs_data: ObsData, footprint_data: FootprintData, flux_dict: dict[str, FluxData], + inner_flux_dict: dict[str, FluxData] | None, bc_data: BoundaryConditionsData | None = None, + inner_footprint_data: FootprintData | None = None, platform: str | None = None, max_level: int | None = None ) -> xr.Dataset: @@ -42,4 +102,30 @@ def merged_scenario_data( cache=False, ) - return scenario_combined + dt_dict: dict[str, xr.Dataset] = {"/standard": scenario_combined} + if inner_footprint_data is not None: + # Mask the EUROPE flux to the inner domain extent (zero outside), + # regridded to the inner footprint lat/lon grid. + # ModelScenario then computes fp_x_flux on the inner grid correctly. + # flux_dict_inner = _mask_flux_to_inner_domain(flux_dict, inner_footprint_data) + flux_dict_inner = inner_flux_dict + + inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict_inner, bc=None) + inner_domain_merged = inner_scenario.footprints_data_merge( + calc_fp_x_flux=True, + calc_bc_sensitivity=False, + cache=False, + + ) + + # Align inner to outer time axis. + # If inner footprint is missing any timestamps that exist in the outer + # scenario (e.g. sparse inner store coverage), fill those with 0 so + # both nodes share exactly the same time dimension in the DataTree. 
+ inner_domain_merged = inner_domain_merged.reindex( + time=scenario_combined.time, fill_value=0.0 + ) + + dt_dict["/inner"] = inner_domain_merged = inner_domain_merged + + return xr.DataTree.from_dict(dt_dict) diff --git a/openghg_inversions/inversion_data/serialise.py b/openghg_inversions/inversion_data/serialise.py index d8904671..96136ca5 100644 --- a/openghg_inversions/inversion_data/serialise.py +++ b/openghg_inversions/inversion_data/serialise.py @@ -475,7 +475,7 @@ def datatree_to_flux_dict(dt: xr.DataTree) -> dict[str, FluxData]: def fp_all_to_datatree(fp_all: dict, netcdf_safe_attrs: bool = False) -> xr.DataTree: dt_dict: dict[str, xr.Dataset | xr.DataTree] = {} - scenario_dict = {} + scenario_dict: dict[str, xr.Dataset] = {} dt_attrs = {} if ".flux" in fp_all: @@ -487,7 +487,15 @@ def fp_all_to_datatree(fp_all: dict, netcdf_safe_attrs: bool = False) -> xr.Data if isinstance(v, BoundaryConditionsData): dt_dict[k.removeprefix(".")] = openghg_data_to_dataset(v, netcdf_safe_attrs) elif not k.startswith(".") and isinstance(v, xr.Dataset): - scenario_dict[k] = v + scenario_dict[f"/{k}"] = v + elif not k.startswith(".") and isinstance(v, xr.DataTree): + for group in v.groups: + node = v[group] + if node.ds is None: + continue + + rel_group = "" if group == "/" else group + scenario_dict[f"/{k}{rel_group}"] = node.ds else: dt_attrs[k] = v @@ -512,7 +520,13 @@ def datatree_to_fp_all(dt: xr.DataTree) -> dict: fp_all[".bc"] = dataset_to_bc_data(dt.bc.to_dataset()) for k, v in dt.scenarios.items(): - fp_all[str(k)] = v.to_dataset() + if len(v.children) > 0: + scenario_children = { + f"/{child_name}": child.ds for child_name, child in v.children.items() if child.ds is not None + } + fp_all[str(k)] = xr.DataTree.from_dict(scenario_children) + else: + fp_all[str(k)] = v.to_dataset() fp_all.update({str(k): v for k, v in dt.attrs.items()}) diff --git a/openghg_inversions/postprocessing/countries.py b/openghg_inversions/postprocessing/countries.py index 598d435e..6a2f2909 
100644 --- a/openghg_inversions/postprocessing/countries.py +++ b/openghg_inversions/postprocessing/countries.py @@ -383,7 +383,14 @@ def get_country_trace( 1.0, based on how it is used in hbmcmc """ x_to_country_mat = self.get_x_to_country_mat(inv_out) - x_trace = inv_out.get_trace_dataset(var_names="x") + trace_all = inv_out.get_trace_dataset() + x_outer_vars = [ + "x_prior", + "x_posterior", + "x_prior_predictive", + "x_posterior_predictive", + ] + x_trace = trace_all[[var for var in x_outer_vars if var in trace_all.data_vars]] species = inv_out.species diff --git a/openghg_inversions/postprocessing/inversion_output.py b/openghg_inversions/postprocessing/inversion_output.py index a24915d0..fc8ea129 100644 --- a/openghg_inversions/postprocessing/inversion_output.py +++ b/openghg_inversions/postprocessing/inversion_output.py @@ -627,6 +627,24 @@ def make_inv_out_for_fixed_basis_mcmc( obs_prior_upper_level_factor: np.ndarray | None = None, ) -> InversionOutput: """Create InversionOutput in `fixedbasisMCMC`.""" + def _scenario_datasets(entry: xr.Dataset | xr.DataTree) -> list[xr.Dataset]: + if isinstance(entry, xr.DataTree): + datasets: list[xr.Dataset] = [] + if "inner" in entry.children: + datasets.append(entry["inner"].ds) + if "standard" in entry.children: + datasets.append(entry["standard"].ds) + if not datasets: + datasets.append(entry.ds) + return datasets + return [entry] + + def _first_attrs(datasets: list[xr.Dataset], var_name: str) -> dict: + for ds in datasets: + if var_name in ds.data_vars: + return ds[var_name].attrs + return {} + nmeasure = np.arange(len(Y)) y_obs = xr.DataArray(Y, dims=["nmeasure"], coords={"nmeasure": nmeasure}, name="Yobs") times = xr.DataArray(Ytime, dims=["nmeasure"], coords={"nmeasure": nmeasure}, name="times") @@ -666,11 +684,18 @@ def make_inv_out_for_fixed_basis_mcmc( basis = get_xr_dummies(fp_data[".basis"], cat_dim="nx", categories=nx) - scenarios = [v for k, v in fp_data.items() if not k.startswith(".")] + scenario_entries 
= [v for k, v in fp_data.items() if not k.startswith(".")] + scenario_datasets = [_scenario_datasets(entry) for entry in scenario_entries] + + flux = None + for entry in scenario_entries: + try: + flux = entry.flux_stacked + break + except AttributeError: + continue - try: - flux = scenarios[0].flux_stacked - except AttributeError: + if flux is None: flux = next(iter(fp_data[".flux"].values())).data # TODO: this only works if there is one flux used (or if multiple, but ModelScenario stacks them) @@ -680,18 +705,27 @@ def make_inv_out_for_fixed_basis_mcmc( if not isinstance(flux, xr.DataArray): raise ValueError("Flux from `fp_data` could not be converted to a xr.DataArray.") - # add attributes - scenario = scenarios[0] - y_obs.attrs = scenario.mf.attrs - times.attrs = scenario.time.attrs - y_error.attrs = scenario.mf_error.attrs - y_error_variability.attrs = scenario.mf_variability.attrs - y_error_repeatability.attrs = scenario.mf_repeatability.attrs - for scenario in scenarios: - if not ("mf_prior_factor" in scenario and "mf_prior_upper_level_factor" in scenario): + # add attributes (prefer inner metadata, then standard) + first_ds_group = scenario_datasets[0] + y_obs.attrs = _first_attrs(first_ds_group, "mf") + time_attrs = next((ds.time.attrs for ds in first_ds_group if "time" in ds.coords), {}) + times.attrs = time_attrs + y_error.attrs = _first_attrs(first_ds_group, "mf_error") + y_error_variability.attrs = _first_attrs(first_ds_group, "mf_variability") + y_error_repeatability.attrs = _first_attrs(first_ds_group, "mf_repeatability") + for ds_group in scenario_datasets: + ds_with_priors = next( + ( + ds + for ds in ds_group + if "mf_prior_factor" in ds.data_vars and "mf_prior_upper_level_factor" in ds.data_vars + ), + None, + ) + if ds_with_priors is None: continue - y_obs_prior_factor.attrs = scenario.mf_prior_factor.attrs - y_obs_prior_upper_level_factor.attrs = scenario.mf_prior_upper_level_factor.attrs + y_obs_prior_factor.attrs = 
ds_with_priors.mf_prior_factor.attrs + y_obs_prior_upper_level_factor.attrs = ds_with_priors.mf_prior_upper_level_factor.attrs return InversionOutput( obs=y_obs, diff --git a/openghg_inversions/postprocessing/make_outputs.py b/openghg_inversions/postprocessing/make_outputs.py index 2b2089ac..96dedeac 100644 --- a/openghg_inversions/postprocessing/make_outputs.py +++ b/openghg_inversions/postprocessing/make_outputs.py @@ -42,7 +42,14 @@ def make_flux_outputs( xr.Dataset with computed flux stats. """ - trace = inv_out.get_trace_dataset(var_names="x") + trace_all = inv_out.get_trace_dataset() + x_outer_vars = [ + "x_prior", + "x_posterior", + "x_prior_predictive", + "x_posterior_predictive", + ] + trace = trace_all[[var for var in x_outer_vars if var in trace_all.data_vars]] if stats_args is None: stats_args = {} diff --git a/tests/conftest.py b/tests/conftest.py index c29f9c9f..77c04e38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -344,3 +344,31 @@ def mhd_and_tac_ch4_data_args(): "averaging_period": ["1h", "1h"], } return data_args + +@pytest.fixture(scope="module") +def mhd_with_inner_domain_ch4_data_args(): + """Data args for MHD with an inner domain (europe-6km) footprint.""" + data_args = { + "species": "ch4", + "sites": ["MHD"], + "use_tracer" : False, + "start_date": "2023-01-01", + "end_date": "2023-10-01", + "bc_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "obs_store": "/group/chem/acrg/object_stores/paris/obs_icos_2025_08_store", + "footprint_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "emissions_store": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE", + "inner_footprint_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "inner_emissions_store": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE_6km", + "inlet": ["9m"], + "instrument": None, + "domain": "EUROPE", + "inner_domain": "6km", + "fp_height": ["10m"], + "fp_species": "inert", +
"fp_model": "NAME", + "emissions_name": ["edgar_wetcharts"], + "averaging_period": ["4h"], + "bc_input": "camsv22r2_daily", + } + return data_args \ No newline at end of file diff --git a/tests/test_full_inversion.py b/tests/test_full_inversion.py index 367524f2..6fdc855b 100644 --- a/tests/test_full_inversion.py +++ b/tests/test_full_inversion.py @@ -102,6 +102,15 @@ def test_full_inversion_paris_outputs(mcmc_args): assert "Yapost" in out +def test_full_inversion_6km_paris_outputs(mcmc_args): + """Test full inversion including loading data with PARIS output format.""" + mcmc_args["reload_merged_data"] = False + mcmc_args["output_format"] = "paris" + out = fixedbasisMCMC(**mcmc_args) + + assert "Yapost" in out + + def test_full_inversion_no_model_error(mcmc_args): mcmc_args["no_model_error"] = True fixedbasisMCMC(**mcmc_args) @@ -219,3 +228,59 @@ def test_full_inversion_long(mcmc_args): } ) fixedbasisMCMC(**mcmc_args) + +@pytest.fixture +def inner_domain_mcmc_args(tmp_path, mhd_with_inner_domain_ch4_data_args): + """MCMC args for a minimal inversion with inner domain enabled.""" + mcmc_args = mhd_with_inner_domain_ch4_data_args.copy() + mcmc_args.update( + { + "outputname": "inner_domain_test_run", + "outputpath": str(tmp_path), + "basis_algorithm": "quadtree", + "bc_basis_case" : "NESW", + "fp_basis_case" : None, + "basis_directory" : "/group/chem/acrg/LPDM/basis_functions/", + "bc_basis_directory" : "/group/chem/acrg/LPDM/bc_basis_functions/", + "country_file" : "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc", + "basis_output_path": str(tmp_path), + "nbasis": 4, + "nit": 1, + "burn": 1, + "tune": 1, + "nchain": 1, + "mcmc_type" : "fixed_basis", + "reload_merged_data": False, + "xprior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "bcprior": {"pdf": "normal", "mu": 1.0, "sigma": 0.5}, + "sigprior": {"pdf": "uniform", "lower": 0.5, "upper": 3.0}, +
"use_bc": True, + "averaging_error": True, + "min_error": 0.0, + "fix_basis_outer_regions" : False, + "nuts_sampler" : "numpyro", + "save_trace" : True, + "pollution_events_from_obs" : True, + "no_model_error" : False, + "bc_freq" : None, + "sigma_freq" : None, + "sigma_per_site" : True, + "inlet" : [slice(0,25)], + "instrument" : ['multiple'], + "filters" : {'MHD' : ['pblh_inlet_diff']} + } + ) + return mcmc_args + +def test_full_inversion_with_inner_domain(inner_domain_mcmc_args): + """Test that fixedbasisMCMC runs end-to-end with inner_domain specified. + + Verifies: + - The inversion completes without error. + - Standard output variables are present (Yerror_repeatability, Yerror_variability). + - TODO: the presence of H_inner in the per-site fp_data is not asserted here yet. + """ + out = fixedbasisMCMC(**inner_domain_mcmc_args) + + assert "Yerror_repeatability" in out + assert "Yerror_variability" in out \ No newline at end of file