From 07a2dfcce123a256aebc3cb030b00c0102a7bd85 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 24 Feb 2026 10:41:01 +0000 Subject: [PATCH 01/68] inner_domain --- openghg_inversions/inversion_data/getters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openghg_inversions/inversion_data/getters.py b/openghg_inversions/inversion_data/getters.py index 9f4b950f..e374fa06 100644 --- a/openghg_inversions/inversion_data/getters.py +++ b/openghg_inversions/inversion_data/getters.py @@ -30,7 +30,7 @@ def adjust_flux_start_date( start_date: str, species: str, source: str, domain: str, store: str | None = None ) -> pd.Timestamp: """Adjusts the flux start_date to align with the flux data's temporal resolution.""" - flux_search = search_flux(species=species, source=source, domain=domain, store=store) + flux_search = search_flux(species=species, source=source, domain=f"{domain}-6km", store=store) if flux_search.results.empty: raise SearchError( f"No flux found with species={species}, source={source}, domain={domain}, store={store}." 
@@ -85,7 +85,7 @@ def get_flux_data( try: flux_data = get_flux( species=species, - domain=domain, + domain=f"{domain}-6km", source=source, start_date=None, end_date=end_date, @@ -422,7 +422,7 @@ def get_footprint_data( def get_func(store): return get_footprint_to_match( obs_data, - domain=domain, + domain=f"{domain}-6km", start_date=start_date, end_date=end_date, model=model, @@ -458,7 +458,7 @@ def get_func(store): return get_footprint( site=site, height=fp_height, - domain=domain, + domain=f"{domain}-6km", model=model, met_model=met_model, start_date=start_date, From 3b08f8b5da70c2cd70a15b96cbc2008235f2dfc7 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 24 Feb 2026 10:41:45 +0000 Subject: [PATCH 02/68] not_merge --- .../config/openghg_hbmcmc_input_template.ini | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini b/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini index ce6aab27..a2a7fc6b 100644 --- a/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini +++ b/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini @@ -16,12 +16,12 @@ ; start_date (str): Start of observations to extract (format YYYY-MM-DD) ; end_date (str): End of observations to extract (format YYYY-MM-DD) (non-inclusive) -species = "" ; (required) +species = "ch4" ; (required) use_tracer = False ; (required) -sites = [] ; (required) -averaging_period = [] ; (required) -start_date = " " ; (required - but can be specified on command line instead) -end_date = " " ; (required - but can be specified on command line instead) +sites = ["mhd"] ; (required) +averaging_period = ["1h"] ; (required) +start_date = "2022-01-01" ; (required - but can be specified on command line instead) +end_date = "2024-12-31" ; (required - but can be specified on command line instead) ; save_merged_data (bool): If True, saves merged data object ; reload_merged_data (bool): If 
True, reads merged data object rather than rerunning get_data @@ -37,7 +37,7 @@ merged_data_dir = " " ; obs_data_level (list/str): Measurement data quality level ; filters (str): Data filtering approach to apply -inlet = None +inlet = "9m" instrument = None calibration_scale = None obs_data_level = None @@ -52,9 +52,9 @@ filters = [] ; emissions_store (str): Name of flux emissions object store bc_store = " " ; (required) -obs_store = " " ; (required) -footprint_store = " " ; (required) -emissions_store = " " ; (required) +obs_store = "/group/chem/acrg/object_stores/paris/obs_icos_2025_08_store" ; (required) +footprint_store = "/group/chem/acrg/prasad/object_store_6km" ; (required) +emissions_store = "/group/chem/acrg/prasad/object_store_6km" ; (required) [INPUT.PRIORS] @@ -67,12 +67,12 @@ emissions_store = " " ; (required) ; emissions_name (list): Name of emissions sources as used when adding flux files to the object store ; bc_input (list/str): Name of boundary conditions data to use from object store -domain = " " ; (required) +domain = "europe" ; (required) met_model = None fp_model = None -fp_height = None +fp_height = "10m" fp_species = None -emissions_name = [None] ; (required) +emissions_name = ["anthro"] ; (required) bc_input = None [INPUT.BASIS_CASE] @@ -89,14 +89,13 @@ bc_input = None ; country_file (str/None): Directory with filename containing the indices of country boundaries in domain ; country_directory (str/None): Directory containing auxiliary country files for deriving basis functions (land-sea mask, InTEM outer regions) -basis_algorithm = " " +basis_algorithm = "quadtree" bc_basis_case = "NESW" fp_basis_case = None -nbasis = -basis_directory = None -bc_basis_directory = None -country_file = " " -country_directory = None +nbasis = 250 +basis_directory = "/group/chem/acrg/LPDM/basis_functions/" +bc_basis_directory = "/group/chem/acrg/LPDM/bc_basis_functions/" +country_file = "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc" 
[MCMC.TYPE] ; Which MCMC setup to use. This defines the function which will be called and the expected inputs. @@ -173,9 +172,9 @@ sigma_per_site = True ; burn (int): Number of iterations to burn/discard in MCMC ; tune (int): Number of iterations to use to tune step size -nit = ; (required) -burn = ; (required) -tune = ; (required) +nit = 100 ; (required) +burn = 10 ; (required) +tune = 100 ; (required) [MCMC.NCHAIN] ; nchain (int): Number of chains to run simultaneously. Must be >=2 to allow convergence to be checked. @@ -211,7 +210,7 @@ nchain = 2 averaging_error = True min_error = 0.0 fix_basis_outer_regions = False -use_bc = True +use_bc = False nuts_sampler = "numpyro" save_trace = True pollution_events_from_obs = True @@ -222,5 +221,5 @@ no_model_error = False ; outputpath (str): Directory to write output ; outputname (str): Unique identifier for output/run name. -outputpath = " " ; (required) -outputname = " " ; (required) +outputpath = "~/openghg_inversions" ; (required) +outputname = "mhd_openghg_inversions" ; (required) From e744f2e8a82b33183a9a0b553ea48e5f1772ec72 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 24 Feb 2026 20:01:07 +0000 Subject: [PATCH 03/68] paris hardcoded for test --- openghg_inversions/hbmcmc/hbmcmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 85e51c17..74066ca9 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -314,7 +314,7 @@ def fixedbasisMCMC( min_error_options: dict | None = None, output_format: Literal[ "hbmcmc", "paris", "basic", "merged_data", "inv_out", "mcmc_args", "mcmc_results" - ] = "hbmcmc", + ] = "paris", paris_postprocessing: bool = False, paris_postprocessing_kwargs: dict | None = None, power: dict | float = 1.99, From 874b8a1790ab0ea02f317d556ff619fa3ac2972b Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 4 Mar 2026 15:29:46 +0000 Subject: [PATCH 04/68] 
pulling fp and innner domain fp separately and passing to merged_scenario_data --- openghg_inversions/inversion_data/get_data.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index a0fd1546..70334e74 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -181,6 +181,8 @@ def data_processing_surface_notracer( merged_data_name: str | None = None, merged_data_dir: str | None = None, output_name: str | None = None, + inner_domain: str | None = None, + inner_footprint_store: str | list[str] | None = None, ) -> tuple[dict, list, list, list, list, list]: """Retrieve and prepare fixed-surface datasets from specified OpenGHG object stores. @@ -328,7 +330,8 @@ def data_processing_surface_notracer( continue # Get footprints data - footprint_data = get_footprint_data( + + standard_footprint_data = get_footprint_data( site=site, domain=domain, platform=platform[i], @@ -342,15 +345,39 @@ def data_processing_surface_notracer( obs_data=site_data, stores=footprint_store, ) - if footprint_data is None: + if standard_footprint_data is None: print( f"\nNo footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}.", f"Check these values.\nContinuing model run without {site}.\n", ) continue # skip this site + inner_footprint_data = None + if inner_domain is not None: + print(f"Inner domain {inner_domain} specified; attempting to retrieve inner footprint data for {site} ...") + inner_footprint_data = get_footprint_data( + site=site, + domain=f"{domain}-{inner_domain}", + platform=platform[i], + fp_height=fp_height[i], + start_date=start_date, + end_date=end_date, + model=fp_model, + met_model=met_model[i], + fp_species=fp_species, + averaging_period=averaging_period[i], + obs_data=site_data, + stores=inner_footprint_store if inner_footprint_store is 
not None else footprint_store, + ) + if inner_footprint_data is None: + print( + f"\nNo Inner footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}-{inner_domain}.", + f"Check these values.\nContinuing model run without {site}.Jai\n", + ) + continue # skip this site scenario_combined = merged_scenario_data( - site_data, footprint_data, flux_dict, bc_data, platform=platform[i], max_level=max_level + obs_data=site_data, footprint_data=standard_footprint_data, + flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, platform=platform[i], max_level=max_level ) fp_all[site] = scenario_combined From 4e92b33b083581a328885a734a45799864ff6fa6 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 4 Mar 2026 15:31:23 +0000 Subject: [PATCH 05/68] flux is kept to fetch europe domain, calls to get_footprint_data are done separately so inner domain fp are fetched correctly --- openghg_inversions/inversion_data/getters.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openghg_inversions/inversion_data/getters.py b/openghg_inversions/inversion_data/getters.py index e374fa06..7908c41d 100644 --- a/openghg_inversions/inversion_data/getters.py +++ b/openghg_inversions/inversion_data/getters.py @@ -85,7 +85,7 @@ def get_flux_data( try: flux_data = get_flux( species=species, - domain=f"{domain}-6km", + domain=domain, source=source, start_date=None, end_date=end_date, @@ -422,7 +422,7 @@ def get_footprint_data( def get_func(store): return get_footprint_to_match( obs_data, - domain=f"{domain}-6km", + domain=domain, start_date=start_date, end_date=end_date, model=model, @@ -458,7 +458,7 @@ def get_func(store): return get_footprint( site=site, height=fp_height, - domain=f"{domain}-6km", + domain=domain, model=model, met_model=met_model, start_date=start_date, From 1cd0a8db9d9a653921476485dac27b8c154e797f Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 4 Mar 2026 15:32:08 +0000 
Subject: [PATCH 06/68] 6km fp_x_flux is added as a separate variable --- openghg_inversions/inversion_data/scenario.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 35eb596d..c65700e2 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -9,6 +9,7 @@ def merged_scenario_data( footprint_data: FootprintData, flux_dict: dict[str, FluxData], bc_data: BoundaryConditionsData | None = None, + inner_footprint_data: FootprintData | None = None, platform: str | None = None, max_level: int | None = None ) -> xr.Dataset: @@ -42,4 +43,18 @@ def merged_scenario_data( cache=False, ) - return scenario_combined + if inner_footprint_data is not None: + inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict, bc=None) + inner_domain_merged = inner_scenario.footprints_data_merge( + calc_fp_x_flux=True, + calc_bc_sensitivity=False, + cache=False, + ) + scenario_combined = scenario_combined.copy() + + # 6km fp_x_flux is added as a separate variable to the combined dataset, and can be merged with the EUROPE fp_x_flux. 
+ scenario_combined["fp_x_flux_inner"] = inner_domain_merged["fp_x_flux"] + + return scenario_combined + else: + return scenario_combined From 2e7d17ec45712f49ec64dfe23a30e2dc9e468ebb Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 4 Mar 2026 15:32:58 +0000 Subject: [PATCH 07/68] inner_fp_store and inner_domain is added to fixedbasisMCMC --- openghg_inversions/hbmcmc/hbmcmc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 74066ca9..3b6fe5a4 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -262,6 +262,7 @@ def fixedbasisMCMC( obs_store: str = "user", footprint_store: str = "user", emissions_store: str = "user", + inner_footprint_store: str = "user", met_model: list | None = None, fp_model: str | None = None, # Changed to none. When "NAME" specified FPs are not found fp_height: list[str] | None = None, @@ -282,6 +283,7 @@ def fixedbasisMCMC( country_directory: str | None = None, country_file: str | None = None, bc_input: str | None = None, + inner_domain: str | None = None, basis_algorithm: str = "weighted", nbasis: int = 100, xprior: dict = {"pdf": "truncatednormal", "mu": 1.0, "sigma": 1.0, "lower": 0.0}, @@ -538,6 +540,7 @@ def fixedbasisMCMC( species=species, sites=sites, domain=domain, + inner_domain=inner_domain, averaging_period=averaging_period, start_date=start_date, end_date=end_date, @@ -558,6 +561,7 @@ def fixedbasisMCMC( obs_store=obs_store, footprint_store=footprint_store, emissions_store=emissions_store, + inner_footprint_store=inner_footprint_store, averagingerror=averaging_error, save_merged_data=save_merged_data, merged_data_name=merged_data_name, @@ -584,6 +588,7 @@ def fixedbasisMCMC( use_bc=use_bc, species=species, domain=domain, + inner_domain=inner_domain, start_date=start_date, fix_outer_regions=fix_basis_outer_regions, emissions_name=emissions_name, From b7bd81d5a457c72dda026376b92378de370e5aa9 Mon Sep 17 
00:00:00 2001 From: SutarPrasad Date: Thu, 5 Mar 2026 14:43:19 +0000 Subject: [PATCH 08/68] combine_inner_outer_fp_x_flux --- openghg_inversions/basis/_helpers.py | 31 ++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 91ddfe3e..ba31d053 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -66,6 +66,14 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_and_data[".basis"] = basis_func for site in sites: + fp_x_flux = fp_and_data[site]["fp_x_flux"] + # if inner domain fp_x_flux exists, blend it in + if "fp_x_flux_inner" in fp_and_data[site]: + fp_x_flux = combine_inner_outer_fp_x_flux( + inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], + outer_fp_x_flux=fp_x_flux, + ) + sensitivity = apply_fp_basis_functions( fp_x_flux=fp_and_data[site][fp_x_flux_name], basis_func=basis_func, @@ -75,6 +83,29 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da return fp_and_data +def combine_inner_outer_fp_x_flux( + inner_fp_x_flux: xr.DataArray, + outer_fp_x_flux: xr.DataArray, +) -> xr.DataArray: + """Merge inner (6km) and outer (EUROPE) fp_x_flux.""" + # Regrid inner fp_x_flux to the same grid as outer fp_x_flux, and then patch it in where the inner domain mask is True. 
+ # regrid inner to EUROPE lat/lon coords + inner_regridded = inner_fp_x_flux.interp(lat=outer_fp_x_flux.lat, lon=outer_fp_x_flux.lon, method="nearest") + + # force coordinates to exactly match outer (avoids float precision + # mismatches that prevent xr.align / xr.where from working correctly) + inner_regridded = inner_regridded.assign_coords(lat=outer_fp_x_flux.lat, lon=outer_fp_x_flux.lon) + + # fill NaN (points outside the inner domain extent) with 0 + inner_regridded = inner_regridded.fillna(0.0) + + # True where the inner domain contributed non-zero values at any timestep + inner_has_coverage = (inner_regridded != 0).any("time") + + # Both arrays are now on the EUROPE grid so xr.where is safe + return xr.where(inner_has_coverage, inner_regridded, outer_fp_x_flux) + + def apply_fp_basis_functions( fp_x_flux: xr.DataArray, basis_func: xr.DataArray, From 6d55254f04527f8a4ce90616860b5e06e67ad9e0 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 5 Mar 2026 14:50:44 +0000 Subject: [PATCH 09/68] fp_x_flux fetched --- openghg_inversions/basis/_functions.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/openghg_inversions/basis/_functions.py b/openghg_inversions/basis/_functions.py index 79b355a2..fdb04b0d 100644 --- a/openghg_inversions/basis/_functions.py +++ b/openghg_inversions/basis/_functions.py @@ -157,7 +157,26 @@ def _flux_fp_from_fp_all( flux = cast(xr.DataArray, flux) - footprints: list[xr.DataArray] = [v.fp for k, v in fp_all.items() if not k.startswith(".")] + footprints = [] + for k, v in fp_all.items(): + if k.startswith("."): + continue + + # Need to discuss this further + # fp_x_flux is guaranteed to be on the + # EUROPE grid (it comes from the standard scenario). Raw .fp may be + # on the 6km grid if OpenGHG snapped it to the footprint resolution. 
+ if "fp_x_flux" in v: + fp = v["fp_x_flux"] + else: + fp = v.fp + + # if grid still doesn't match flux, regrid to flux grid + if fp.sizes.get("lat") != flux.sizes.get("lat") or fp.sizes.get("lon") != flux.sizes.get("lon"): + fp = fp.interp(lat=flux.lat, lon=flux.lon, method="nearest").fillna(0.0) + fp = fp.assign_coords(lat=flux.lat, lon=flux.lon) + + footprints.append(fp) return flux, footprints From 96c6d7c77fb1ea92e112d97d414a9584f4d38487 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 10 Mar 2026 10:28:15 +0000 Subject: [PATCH 10/68] adding extra print statements --- openghg_inversions/inversion_data/get_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 70334e74..8cb2335f 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -370,7 +370,7 @@ def data_processing_surface_notracer( ) if inner_footprint_data is None: print( - f"\nNo Inner footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}-{inner_domain}.", + f"\nNo Inner footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}-{inner_domain}, starting from {start_date} to {end_date}, fp_species {fp_species} and met_model {met_model[i]}, obs_data {site_data} and averaging_period {averaging_period[i]}.", f"Check these values.\nContinuing model run without {site}.Jai\n", ) continue # skip this site From c7331140c4e70ad2f1d64d60306a919ab4a389d3 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 10 Mar 2026 17:15:12 +0000 Subject: [PATCH 11/68] commenting inner-domain as we are regridding to europe domain before this soon to be changed --- openghg_inversions/basis/_helpers.py | 1 + openghg_inversions/hbmcmc/hbmcmc.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py 
b/openghg_inversions/basis/_helpers.py index ba31d053..39dbc055 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -78,6 +78,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_x_flux=fp_and_data[site][fp_x_flux_name], basis_func=basis_func, ) + # TODO: store houter hinner here fp_and_data[site]["H"] = sensitivity return fp_and_data diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 3b6fe5a4..3efec960 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -588,14 +588,14 @@ def fixedbasisMCMC( use_bc=use_bc, species=species, domain=domain, - inner_domain=inner_domain, + # inner_domain=inner_domain, start_date=start_date, fix_outer_regions=fix_basis_outer_regions, emissions_name=emissions_name, outputname=outputname, output_path=basis_output_path, ) - + print(f"Basis functions applied to data {domain}.\n ") # Apply named filters to the data if filters is not None: try: From ab727b429fd05fdfdcd468dd819ba0bf95fd746e Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Fri, 13 Mar 2026 15:31:15 +0000 Subject: [PATCH 12/68] removed hardcoded 6km --- openghg_inversions/inversion_data/getters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/getters.py b/openghg_inversions/inversion_data/getters.py index 7908c41d..9f4b950f 100644 --- a/openghg_inversions/inversion_data/getters.py +++ b/openghg_inversions/inversion_data/getters.py @@ -30,7 +30,7 @@ def adjust_flux_start_date( start_date: str, species: str, source: str, domain: str, store: str | None = None ) -> pd.Timestamp: """Adjusts the flux start_date to align with the flux data's temporal resolution.""" - flux_search = search_flux(species=species, source=source, domain=f"{domain}-6km", store=store) + flux_search = search_flux(species=species, source=source, domain=domain, store=store) if 
flux_search.results.empty: raise SearchError( f"No flux found with species={species}, source={source}, domain={domain}, store={store}." From 27622d2c6d0518235193f10052c61fc7552ca580 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Fri, 13 Mar 2026 15:31:28 +0000 Subject: [PATCH 13/68] inner domain flux --- openghg_inversions/inversion_data/get_data.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 8cb2335f..564a71c1 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -259,15 +259,30 @@ def data_processing_surface_notracer( raise ValueError("`emissions_name` must be specified") flux_dict = get_flux_data( - sources=emissions_name, - species=species, - domain=domain, - start_date=start_date, - end_date=end_date, - store=emissions_store, - ) + sources=emissions_name, + species=species, + domain=f"{domain}-{inner_domain}" if inner_domain is not None else domain, + start_date=start_date, + end_date=end_date, + store=emissions_store, + ) + fp_all[".flux"] = flux_dict + if inner_domain is not None: + print(f"Inner domain {inner_domain} specified; attempting to retrieve flux data for both domains ...") + + flux_dict = get_flux_data( + sources=emissions_name, + species=species, + domain=f"{domain}-{inner_domain}", + start_date=start_date, + end_date=end_date, + store=emissions_store, + ) + + fp_all[".flux_inner"] = flux_dict + # Get BC data if use_bc is True: try: From db9af9b40bfd891b393a5c2408d75d1e8952b7cf Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Fri, 13 Mar 2026 16:16:34 +0000 Subject: [PATCH 14/68] getting inner flux for inner scenario --- openghg_inversions/inversion_data/get_data.py | 10 +++++----- openghg_inversions/inversion_data/scenario.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py 
b/openghg_inversions/inversion_data/get_data.py index 564a71c1..54f8e6e7 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -272,17 +272,16 @@ def data_processing_surface_notracer( if inner_domain is not None: print(f"Inner domain {inner_domain} specified; attempting to retrieve flux data for both domains ...") - flux_dict = get_flux_data( + inner_flux_dict = get_flux_data( sources=emissions_name, species=species, domain=f"{domain}-{inner_domain}", start_date=start_date, end_date=end_date, store=emissions_store, - ) + ) if inner_domain is not None else None - fp_all[".flux_inner"] = flux_dict - + fp_all[".inner_flux"] = inner_flux_dict if inner_flux_dict is not None else {} # Get BC data if use_bc is True: try: @@ -392,7 +391,8 @@ def data_processing_surface_notracer( scenario_combined = merged_scenario_data( obs_data=site_data, footprint_data=standard_footprint_data, - flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, platform=platform[i], max_level=max_level + flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, inner_flux_dict=inner_flux_dict, + platform=platform[i], max_level=max_level ) fp_all[site] = scenario_combined diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index c65700e2..4b31a74b 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -10,6 +10,7 @@ def merged_scenario_data( flux_dict: dict[str, FluxData], bc_data: BoundaryConditionsData | None = None, inner_footprint_data: FootprintData | None = None, + inner_flux_dict: dict[str, FluxData] | None = None, platform: str | None = None, max_level: int | None = None ) -> xr.Dataset: @@ -44,7 +45,7 @@ def merged_scenario_data( ) if inner_footprint_data is not None: - inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict, bc=None) 
+ inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=inner_flux_dict, bc=None) inner_domain_merged = inner_scenario.footprints_data_merge( calc_fp_x_flux=True, calc_bc_sensitivity=False, From bfa1735294d0ef811458e63cffc7b51d752861a8 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 16 Mar 2026 11:12:04 +0000 Subject: [PATCH 15/68] calculating H_inner --- openghg_inversions/basis/_helpers.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 39dbc055..03bb13a8 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -69,18 +69,25 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_x_flux = fp_and_data[site]["fp_x_flux"] # if inner domain fp_x_flux exists, blend it in if "fp_x_flux_inner" in fp_and_data[site]: - fp_x_flux = combine_inner_outer_fp_x_flux( - inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], - outer_fp_x_flux=fp_x_flux, - ) + # fp_x_flux = combine_inner_outer_fp_x_flux( + # inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], + # outer_fp_x_flux=fp_x_flux, + # ) + + sensitivity_inner = apply_fp_basis_functions( + fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], + basis_func=basis_func, + ) + fp_and_data[site]["H_inner"] = sensitivity_inner sensitivity = apply_fp_basis_functions( fp_x_flux=fp_and_data[site][fp_x_flux_name], basis_func=basis_func, ) + # TODO: store houter hinner here fp_and_data[site]["H"] = sensitivity - + print(fp_and_data[site]) return fp_and_data From e8a9af73d160b4424451ff413f2016de6b18041a Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 16 Mar 2026 11:13:07 +0000 Subject: [PATCH 16/68] removed inner flux --- openghg_inversions/inversion_data/get_data.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py 
b/openghg_inversions/inversion_data/get_data.py index 54f8e6e7..d36b79ee 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -261,7 +261,7 @@ def data_processing_surface_notracer( flux_dict = get_flux_data( sources=emissions_name, species=species, - domain=f"{domain}-{inner_domain}" if inner_domain is not None else domain, + domain=domain, start_date=start_date, end_date=end_date, store=emissions_store, @@ -269,19 +269,6 @@ def data_processing_surface_notracer( fp_all[".flux"] = flux_dict - if inner_domain is not None: - print(f"Inner domain {inner_domain} specified; attempting to retrieve flux data for both domains ...") - - inner_flux_dict = get_flux_data( - sources=emissions_name, - species=species, - domain=f"{domain}-{inner_domain}", - start_date=start_date, - end_date=end_date, - store=emissions_store, - ) if inner_domain is not None else None - - fp_all[".inner_flux"] = inner_flux_dict if inner_flux_dict is not None else {} # Get BC data if use_bc is True: try: @@ -391,7 +378,7 @@ def data_processing_surface_notracer( scenario_combined = merged_scenario_data( obs_data=site_data, footprint_data=standard_footprint_data, - flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, inner_flux_dict=inner_flux_dict, + flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, platform=platform[i], max_level=max_level ) fp_all[site] = scenario_combined From 054189713f97524ec36bce0b78e7ea7a40936608 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 16 Mar 2026 11:13:25 +0000 Subject: [PATCH 17/68] inner flux removal --- openghg_inversions/inversion_data/scenario.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 4b31a74b..c65700e2 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py 
@@ -10,7 +10,6 @@ def merged_scenario_data( flux_dict: dict[str, FluxData], bc_data: BoundaryConditionsData | None = None, inner_footprint_data: FootprintData | None = None, - inner_flux_dict: dict[str, FluxData] | None = None, platform: str | None = None, max_level: int | None = None ) -> xr.Dataset: @@ -45,7 +44,7 @@ def merged_scenario_data( ) if inner_footprint_data is not None: - inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=inner_flux_dict, bc=None) + inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict, bc=None) inner_domain_merged = inner_scenario.footprints_data_merge( calc_fp_x_flux=True, calc_bc_sensitivity=False, From e7a872b7b5c48700b20f8ec81cff3c7a79fdd96b Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 17 Mar 2026 14:07:22 +0000 Subject: [PATCH 18/68] flux to inner --- openghg_inversions/inversion_data/scenario.py | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index c65700e2..4d63aa7d 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -3,6 +3,64 @@ from openghg.analyse import ModelScenario from openghg.dataobjects import ObsData, BoundaryConditionsData, FluxData, FootprintData +def _mask_flux_to_inner_domain( + flux_dict: dict[str, FluxData], + inner_footprint_data: FootprintData, +) -> dict[str, FluxData]: + """Mask EUROPE-domain flux values to zero outside the inner footprint extent. + + The inner footprint (``inner_fp``) is non-zero only within the inner + domain (e.g. 6 km grid). We use this spatial coverage to build a + boolean mask, regrid the EUROPE flux to the inner footprint lat/lon + coordinates, and then zero out any flux cells that fall outside the + inner domain extent. 
+ + The masked flux is returned as a new dict of ``FluxData`` objects so + that ``ModelScenario`` can use it unmodified to compute ``fp_x_flux`` + on the correct inner grid. + + Args: + flux_dict: EUROPE-domain flux, keyed by source name. + inner_footprint_data: FootprintData for the inner domain whose + raw ``fp`` defines the spatial extent and lat/lon grid. + + Returns: + New dict of ``FluxData`` with flux regridded to the inner grid + and zeroed outside the inner domain footprint coverage. + """ + # Inner fp: dims (time, lat, lon) on the inner (e.g. 6 km) grid. + inner_fp: xr.DataArray = inner_footprint_data.data.fp + + # Boolean mask: True where inner domain has any non-zero fp at any time. + inner_domain_mask: xr.DataArray = (inner_fp != 0).any("time") # (lat, lon) + + inner_lat = inner_fp.lat + inner_lon = inner_fp.lon + + masked_flux_dict: dict[str, FluxData] = {} + + for source, flux_data in flux_dict.items(): + flux_da: xr.DataArray = flux_data.data.flux # EUROPE grid (time, lat, lon) + + # 1. Regrid EUROPE flux to inner footprint lat/lon grid + flux_on_inner = ( + flux_da + .interp(lat=inner_lat, lon=inner_lon, method="nearest") + .assign_coords(lat=inner_lat, lon=inner_lon) + .fillna(0.0) + ) + + # 2. Mask: zero out flux cells outside the inner domain extent + flux_masked = flux_on_inner.where(inner_domain_mask, other=0.0) + + # 3. Build a new FluxData with the masked flux dataset, + # preserving all original metadata and dataset attributes. 
+ masked_ds = flux_data.data.copy() + masked_ds["flux"] = flux_masked + + masked_flux_dict[source] = FluxData(data=masked_ds, metadata=flux_data.metadata) + + return masked_flux_dict def merged_scenario_data( obs_data: ObsData, @@ -44,11 +102,17 @@ def merged_scenario_data( ) if inner_footprint_data is not None: - inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict, bc=None) + # Mask the EUROPE flux to the inner domain extent (zero outside), + # regridded to the inner footprint lat/lon grid. + # ModelScenario then computes fp_x_flux on the inner grid correctly. + flux_dict_inner = _mask_flux_to_inner_domain(flux_dict, inner_footprint_data) + + inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict_inner, bc=None) inner_domain_merged = inner_scenario.footprints_data_merge( calc_fp_x_flux=True, calc_bc_sensitivity=False, cache=False, + ) scenario_combined = scenario_combined.copy() From c42bbaa0e8b6a9802ae764c195259f452c9a9f8e Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 18 Mar 2026 10:40:16 +0000 Subject: [PATCH 19/68] datatree only in merged_scenario --- openghg_inversions/inversion_data/scenario.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 4d63aa7d..870aabac 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -101,6 +101,7 @@ def merged_scenario_data( cache=False, ) + dt_dict: dict[str, xr.Dataset] = {"/": scenario_combined} if inner_footprint_data is not None: # Mask the EUROPE flux to the inner domain extent (zero outside), # regridded to the inner footprint lat/lon grid. 
@@ -114,11 +115,15 @@ def merged_scenario_data( cache=False, ) - scenario_combined = scenario_combined.copy() - # 6km fp_x_flux is added as a separate variable to the combined dataset, and can be merged with the EUROPE fp_x_flux. - scenario_combined["fp_x_flux_inner"] = inner_domain_merged["fp_x_flux"] + # Align inner to outer time axis. + # If inner footprint is missing any timestamps that exist in the outer + # scenario (e.g. sparse inner store coverage), fill those with 0 so + # both nodes share exactly the same time dimension in the DataTree. + inner_domain_merged = inner_domain_merged.reindex( + time=scenario_combined.time, fill_value=0.0 + ) + + dt_dict["inner_domain_merged"] = inner_domain_merged - return scenario_combined - else: - return scenario_combined + return xr.DataTree.from_dict(dt_dict) From ab1da6e266652e983a651581ba7a26257bad9ee3 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 18 Mar 2026 10:42:30 +0000 Subject: [PATCH 20/68] fetching outer_ds as root_ds --- openghg_inversions/inversion_data/get_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index d36b79ee..f22b9627 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -383,7 +383,8 @@ def data_processing_surface_notracer( ) fp_all[site] = scenario_combined - units[site] = scenario_combined.mf.attrs.get("units") + root_ds = scenario_combined.ds + units[site] = root_ds.mf.attrs.get("units") if "satellite" not in platform: scales[site] = scenario_combined.scale From 499fcdae6a21309641f3914aff53b831c69eb4bd Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 18 Mar 2026 10:50:04 +0000 Subject: [PATCH 21/68] merged_scenario is datatree fp_all is still a dictionary --- openghg_inversions/inversion_data/get_data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index f22b9627..3b49c942 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -383,12 +383,14 @@ def data_processing_surface_notracer( ) fp_all[site] = scenario_combined + # scenario_combined returns a datatree. The standard_domain scenario is stored at the root, and any inner domain scenario is stored at "/{inner_domain}". Here we take the root scenario to store the calibration scale and units, since these should be the same for the inner domain scenario. If inner domain scenario is None the root scenario is the only scenario, so this will still work. + root_ds = scenario_combined.ds units[site] = root_ds.mf.attrs.get("units") if "satellite" not in platform: - scales[site] = scenario_combined.scale - check_scales.add(scenario_combined.scale) + scales[site] = root_ds.scale + check_scales.add(root_ds.scale) site_indices_to_keep.append(i) if len(site_indices_to_keep) == 0: From 920cc6de0b6e3663603b4e36cb3be9e08c057f7d Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 18 Mar 2026 11:12:43 +0000 Subject: [PATCH 22/68] root data fetch --- openghg_inversions/inversion_data/get_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 3b49c942..d6233522 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -60,7 +60,7 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr """ # TODO: do we want to fill missing values in repeatability or variability? 
for site in sites: - ds = fp_all[site] + ds = fp_all[site].ds variability_missing = False if "mf_variability" not in ds: From 8ba74ef6eb7587ce275c6a6874ca2b1f21dcb8b5 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Fri, 20 Mar 2026 14:46:28 +0000 Subject: [PATCH 23/68] changes based on datatree syntax --- openghg_inversions/basis/_helpers.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 03bb13a8..c73983a6 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -66,22 +66,26 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_and_data[".basis"] = basis_func for site in sites: - fp_x_flux = fp_and_data[site]["fp_x_flux"] - # if inner domain fp_x_flux exists, blend it in - if "fp_x_flux_inner" in fp_and_data[site]: + root_ds = fp_and_data[site].ds + fp_x_flux = root_ds[fp_x_flux_name] + inner_domain_ds = fp_and_data[site].get("inner_domain_merged") + + # Verify the node exists AND the variable is inside it + if inner_domain_ds is not None and "fp_x_flux_inner" in inner_domain_ds: + # if inner domain fp_x_flux exists, blend it in # fp_x_flux = combine_inner_outer_fp_x_flux( # inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], # outer_fp_x_flux=fp_x_flux, # ) - + fp_x_flux_inner = inner_domain_ds["fp_x_flux_inner"] sensitivity_inner = apply_fp_basis_functions( - fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], + fp_x_flux=fp_x_flux_inner, basis_func=basis_func, ) - fp_and_data[site]["H_inner"] = sensitivity_inner + inner_domain_ds["H_inner"] = sensitivity_inner sensitivity = apply_fp_basis_functions( - fp_x_flux=fp_and_data[site][fp_x_flux_name], + fp_x_flux=fp_x_flux, basis_func=basis_func, ) From 889c6dc7a7ace8c45f5cd8e9f4e0855f67794d28 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Fri, 20 Mar 2026 16:58:23 +0000 Subject: [PATCH 24/68] hardcoded test --- 
tests/conftest.py | 27 +++++++++++++++ tests/test_full_inversion.py | 66 ++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index c29f9c9f..a928d340 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -344,3 +344,30 @@ def mhd_and_tac_ch4_data_args(): "averaging_period": ["1h", "1h"], } return data_args + +@pytest.fixture(scope="module") +def mhd_with_inner_domain_ch4_data_args(): + """Data args for MHD with an inner domain (europe-6km) footprint.""" + data_args = { + "species": "ch4", + "sites": ["MHD"], + "use_tracer" : False, + "start_date": "2023-01-01", + "end_date": "2023-10-01", + "bc_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "obs_store": "/group/chem/acrg/object_stores/paris/obs_icos_2025_08_store", + "footprint_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "emissions_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "inner_footprint_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "inlet": ["9m"], + "instrument": None, + "domain": "EUROPE", + "inner_domain": "6km", + "fp_height": ["10m"], + "fp_species": "inert", + "fp_model": "NAME", + "emissions_name": ["edgarv80_wetchartsv131"], + "averaging_period": ["4h"], + "bc_input": "camsv22r2_daily", + } + return data_args \ No newline at end of file diff --git a/tests/test_full_inversion.py b/tests/test_full_inversion.py index 367524f2..00b94a4f 100644 --- a/tests/test_full_inversion.py +++ b/tests/test_full_inversion.py @@ -102,6 +102,15 @@ def test_full_inversion_paris_outputs(mcmc_args): assert "Yapost" in out +def test_full_inversion_6km_paris_outputs(mcmc_args): + """Test full inversion including loading data with PARIS output format.""" + mcmc_args["reload_merged_data"] = False + mcmc_args["output_format"] = "paris" + out = fixedbasisMCMC(**mcmc_args) + + assert "Yapost" in out + + def test_full_inversion_no_model_error(mcmc_args): mcmc_args["no_model_error"] = True 
fixedbasisMCMC(**mcmc_args) @@ -219,3 +228,60 @@ def test_full_inversion_long(mcmc_args): } ) fixedbasisMCMC(**mcmc_args) + +@pytest.fixture +def inner_domain_mcmc_args(tmp_path, mhd_with_inner_domain_ch4_data_args): + """MCMC args for a minimal inversion with inner domain enabled.""" + mcmc_args = mhd_with_inner_domain_ch4_data_args.copy() + mcmc_args.update( + { + "outputname": "inner_domain_test_run", + "outputpath": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/inversion_outputs/paris_bc", + "basis_algorithm": "quadtree", + "bc_basis_case" : "NESW", + "fp_basis_case" : None, + "nbasis" : 250, + "basis_directory" : "/group/chem/acrg/LPDM/basis_functions/", + "bc_basis_directory" : "/group/chem/acrg/LPDM/bc_basis_functions/", + "country_file" : "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc", + "basis_output_path": str(tmp_path), + "nbasis": 4, + "nit": 1, + "burn": 10, + "tune": 10, + "nchain": 1, + "mcmc_type" : "fixed_basis", + "reload_merged_data": False, + "xprior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "bcprior": {"pdf": "normal", "mu": 1.0, "sigma": 0.5}, + "sigprior": {"pdf": "uniform", "lower": 0.5, "upper": 3.0}, + "use_bc": True, + "averaging_error": True, + "min_error": 0.0, + "fix_basis_outer_regions" : "False", + "nuts_sampler" : "numpyro", + "save_trace" : True, + "pollution_events_from_obs" : True, + "no_model_error" : False, + "bc_freq" : None, + "sigma_freq" : None, + "sigma_per_site" : True, + "inlet" : [slice(0,25)], + "instrument" : ['multiple'], + "filters" : {'MHD' : ['pblh_min','pblh_inlet_diff']} + } + ) + return mcmc_args + +def test_full_inversion_with_inner_domain(inner_domain_mcmc_args): + """Test that fixedbasisMCMC runs end-to-end with inner_domain specified. + + Verifies: + - The inversion completes without error. + - Standard output variables are present (Yerror_repeatability, Yerror_variability). + - The fp_data for each site contains H_inner (computed from the inner footprint). 
+ """ + out = fixedbasisMCMC(**inner_domain_mcmc_args) + + assert "Yerror_repeatability" in out + assert "Yerror_variability" in out \ No newline at end of file From 9921f76352fed76eb93c9e9d1c9818ae03fd2c10 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Sat, 21 Mar 2026 00:10:25 +0000 Subject: [PATCH 25/68] renamed lat lon of child dataset --- openghg_inversions/inversion_data/get_data.py | 1 + openghg_inversions/inversion_data/scenario.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index d6233522..b1aae29d 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -84,6 +84,7 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr elif add_averaging_error: # Fill with zeros so that if one of repeatability and variability is not NaN, then mf_error will not be NaN. + ds = ds.copy() ds["mf_error"] = np.sqrt( ds["mf_repeatability"].fillna(0) ** 2 + ds["mf_variability"].fillna(0) ** 2 ) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 870aabac..27b39fda 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -123,7 +123,8 @@ def merged_scenario_data( inner_domain_merged = inner_domain_merged.reindex( time=scenario_combined.time, fill_value=0.0 ) - - dt_dict["inner_domain_merged"] = inner_domain_merged + inner_domain_merged = inner_domain_merged.rename({"lat": "lat_inner", "lon": "lon_inner"}) + inner_domain_merged = inner_domain_merged.reset_index(["lat_inner", "lon_inner"]) + dt_dict["inner"] = inner_domain_merged return xr.DataTree.from_dict(dt_dict) From a0ef6445ab353089cbbfe8c507d2271a1cf3089c Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Sat, 21 Mar 2026 00:10:48 +0000 Subject: [PATCH 26/68] updates to handle fp_x_flux --- 
openghg_inversions/basis/_helpers.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index c73983a6..a389ce18 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -66,32 +66,29 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_and_data[".basis"] = basis_func for site in sites: + # if inner domain DataTree child is present, compute H_inner from its fp_x_flux root_ds = fp_and_data[site].ds - fp_x_flux = root_ds[fp_x_flux_name] - inner_domain_ds = fp_and_data[site].get("inner_domain_merged") - - # Verify the node exists AND the variable is inside it - if inner_domain_ds is not None and "fp_x_flux_inner" in inner_domain_ds: + if isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children: + inner_domain_ds = fp_and_data[site]["inner"] + inner_fp_x_flux = inner_domain_ds["fp_x_flux"] # if inner domain fp_x_flux exists, blend it in # fp_x_flux = combine_inner_outer_fp_x_flux( # inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], # outer_fp_x_flux=fp_x_flux, # ) - fp_x_flux_inner = inner_domain_ds["fp_x_flux_inner"] sensitivity_inner = apply_fp_basis_functions( - fp_x_flux=fp_x_flux_inner, + fp_x_flux=inner_fp_x_flux, basis_func=basis_func, ) inner_domain_ds["H_inner"] = sensitivity_inner sensitivity = apply_fp_basis_functions( - fp_x_flux=fp_x_flux, + fp_x_flux=root_ds[fp_x_flux_name], basis_func=basis_func, ) # TODO: store houter hinner here - fp_and_data[site]["H"] = sensitivity - print(fp_and_data[site]) + root_ds["H"] = sensitivity return fp_and_data From 8828b49abec7bd618a261abe227a9685ab651c9e Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 23 Mar 2026 10:57:02 +0000 Subject: [PATCH 27/68] compatibility with datatree changes --- openghg_inversions/basis/_helpers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index a389ce18..f552dd31 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -67,7 +67,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da for site in sites: # if inner domain DataTree child is present, compute H_inner from its fp_x_flux - root_ds = fp_and_data[site].ds + root_ds = fp_and_data[site].ds.copy() # avoid modifying original dataset in-place if isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children: inner_domain_ds = fp_and_data[site]["inner"] inner_fp_x_flux = inner_domain_ds["fp_x_flux"] @@ -88,7 +88,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da ) # TODO: store houter hinner here - root_ds["H"] = sensitivity + root_ds["H"] = (sensitivity.dims, sensitivity.data, {"long_name": "sensitivity"}) return fp_and_data @@ -171,7 +171,7 @@ def bc_sensitivity( if basis_case.lower() == "nesw": for site in sites: ds = fp_and_data[site] - bc_ds = ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) + bc_ds = ds.ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") fp_and_data[site]["H_bc"] = sensitivity @@ -192,7 +192,7 @@ def bc_sensitivity( for site in sites: ds = fp_and_data[site] - bc_ds = ds[[f"bc_{d}" for d in "nesw"]] + bc_ds = ds.ds[[f"bc_{d}" for d in "nesw"]] sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") sensitivity = sensitivity.rename(region="bc_region") fp_and_data[site]["H_bc"] = sensitivity From 26020464532963ea098b334dff47ba96bf9227dd Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 23 Mar 2026 10:59:35 +0000 Subject: [PATCH 28/68] datatree compatibility hbmcmc --- openghg_inversions/hbmcmc/hbmcmc.py | 10 ++++++++-- 1 file changed, 8 
insertions(+), 2 deletions(-) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 3efec960..fe75064d 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -104,7 +104,8 @@ def make_inv_inputs( ] for site in sites: to_compute_site = [dv for dv in to_compute if dv in fp_data[site].data_vars] - fp_data[site][to_compute_site] = fp_data[site][to_compute_site].compute() + for var in to_compute_site: + fp_data[site][var] = fp_data[site][var].compute() # Get inputs ready error = np.zeros(0) @@ -129,7 +130,12 @@ def make_inv_inputs( drop_vars.append(var) # pymc doesn't like NaNs, so drop them for the variables used below - fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) + # DataTree doesn't support dropna; use sel with valid time indices instead + if isinstance(fp_data[site], xr.DataTree): + valid_times = fp_data[site].ds.dropna("time", subset=drop_vars).time + fp_data[site] = fp_data[site].sel(time=valid_times) + else: + fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) # repeatability/variability chosen/combined into mf_error in `get_data.py` error = np.concatenate((error, fp_data[site].mf_error.values)) From be015b00736977f6879b18aeaeb474172cc9c4f6 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 24 Mar 2026 11:44:34 +0000 Subject: [PATCH 29/68] datatree comaptibility --- openghg_inversions/hbmcmc/hbmcmc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index fe75064d..6fc229f4 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -93,6 +93,7 @@ def make_inv_inputs( to_compute = [ "H", "H_bc", + "H_inner", "mf", "mf_error", "mf_repeatability", @@ -138,7 +139,7 @@ def make_inv_inputs( fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) # repeatability/variability chosen/combined into mf_error in `get_data.py` - 
error = np.concatenate((error, fp_data[site].mf_error.values)) + error = np.concatenate((error, fp_data[site].ds.mf_error.values)) # make repeatability and variability for outputs (not used directly in inversions) obs_repeatability = np.concatenate((obs_repeatability, fp_data[site].mf_repeatability.values)) @@ -163,7 +164,7 @@ def make_inv_inputs( else: Ytime = np.concatenate((Ytime, fp_data[site].time.values)) - Hx = fp_data[site].H.values if si == 0 else np.hstack((Hx, fp_data[site].H.values)) + Hx = fp_data[site].ds.H.values if si == 0 else np.hstack((Hx, fp_data[site].ds.H.values)) if np.isnan(Hx).any(): warnings.warn(f"Hx matrix contains {np.isnan(Hx).flatten().sum()} NaN values") @@ -619,7 +620,7 @@ def fixedbasisMCMC( dropped_sites = [] for site in sites: # check if some datasets are empty due to filtering - if fp_data[site].time.values.shape[0] == 0: + if fp_data[site].ds.time.values.shape[0] == 0: dropped_sites.append(site) del fp_data[site] From a3332f868fd9bbc95f102eba67c61cd0f8796965 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 24 Mar 2026 11:44:54 +0000 Subject: [PATCH 30/68] writing back to fp_and_data from copied --- openghg_inversions/basis/_helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index f552dd31..98ff1648 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -89,6 +89,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da # TODO: store houter hinner here root_ds["H"] = (sensitivity.dims, sensitivity.data, {"long_name": "sensitivity"}) + fp_and_data[site].ds = root_ds return fp_and_data From 28caa48179b740d603ee0e64cc7d6779a23ec6f5 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 24 Mar 2026 11:45:13 +0000 Subject: [PATCH 31/68] copying the raw data modifying and writing back --- openghg_inversions/inversion_data/get_data.py | 5 ++--- 1 file changed, 2 insertions(+), 3 
deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index b1aae29d..612f37af 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -60,8 +60,7 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr """ # TODO: do we want to fill missing values in repeatability or variability? for site in sites: - ds = fp_all[site].ds - + ds = fp_all[site].ds.copy() variability_missing = False if "mf_variability" not in ds: ds["mf_variability"] = xr.zeros_like(ds.mf) @@ -84,7 +83,6 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr elif add_averaging_error: # Fill with zeros so that if one of repeatability and variability is not NaN, then mf_error will not be NaN. - ds = ds.copy() ds["mf_error"] = np.sqrt( ds["mf_repeatability"].fillna(0) ** 2 + ds["mf_variability"].fillna(0) ** 2 ) @@ -124,6 +122,7 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr ) logger.info(info_msg) + fp_all[site].ds = ds def convert_to_list( x: list[str | None] | str | None, length: int, name: str | None = None From aa929ea9cf04b04fc7b5998162af85b4f7cebdf3 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 25 Mar 2026 18:11:16 +0000 Subject: [PATCH 32/68] datatree compatiblity --- openghg_inversions/filters.py | 48 +++++++++++++++++++++++------ openghg_inversions/hbmcmc/hbmcmc.py | 13 ++++++-- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/openghg_inversions/filters.py b/openghg_inversions/filters.py index a9308fab..3e26e3bd 100644 --- a/openghg_inversions/filters.py +++ b/openghg_inversions/filters.py @@ -134,15 +134,45 @@ def filtering( for site in filters: if filters[site] is not None and site in sites: for filt in filters[site]: - n_nofilter = datasets[site].time.values.shape[0] - - datasets[site] = filtering_functions[filt](datasets[site], 
keep_missing=keep_missing) - - n_filter = datasets[site].time.values.shape[0] - n_dropped = n_nofilter - n_filter - perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) - print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") - if n_filter == 0: break # no values left, so we won't apply remaining filters + + site_entry = datasets[site] + + # --- DataTree handling --- + if isinstance(site_entry, xr.DataTree): + outer_ds = site_entry.ds + n_nofilter = outer_ds.time.values.shape[0] + + filtered_outer = filtering_functions[filt](outer_ds, keep_missing=keep_missing) + + n_filter = filtered_outer.time.values.shape[0] + n_dropped = n_nofilter - n_filter + perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) + print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") + + # Rebuild DataTree: root = filtered outer, inner child reindexed to new time + dt_dict = {"/": filtered_outer} + if "inner" in site_entry.children: + inner_ds = site_entry["inner"].ds + # Keep inner time axis aligned with the (now filtered) outer time axis + dt_dict["inner"] = inner_ds.reindex( + time=filtered_outer.time, fill_value=0.0 + ) + datasets[site] = xr.DataTree.from_dict(dt_dict) + + if n_filter == 0: + break + + # --- original flat Dataset handling (unchanged) --- + else: + n_nofilter = datasets[site].time.values.shape[0] + + datasets[site] = filtering_functions[filt](datasets[site], keep_missing=keep_missing) + + n_filter = datasets[site].time.values.shape[0] + n_dropped = n_nofilter - n_filter + perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) + print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") + if n_filter == 0: break # no values left, so we won't apply remaining filters return datasets diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 6fc229f4..b613c8e0 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ 
-613,9 +613,16 @@ def fixedbasisMCMC( # # Apply compute before filtering to avoid dask issue for site in sites: - fp_data[site] = fp_data[site].compute() + entry = fp_data[site] + if isinstance(entry, xr.DataTree): + # compute dask arrays in both root and inner child + dt_dict = {"/": entry.ds.compute()} + if "inner" in entry.children: + dt_dict["inner"] = entry["inner"].ds.compute() + fp_data[site] = xr.DataTree.from_dict(dt_dict) + else: + fp_data[site] = entry.compute() fp_data = filtering(fp_data, filters) - # check for sites dropped by filtering dropped_sites = [] for site in sites: @@ -698,6 +705,8 @@ def fixedbasisMCMC( del post_process_args["power"] # add any additional kwargs to mcmc_args (these aren't needed for post processing) + kwargs.pop("mcmc_type", None) + mcmc_args.update(kwargs) end_data = time.time() From 5e1f565c21de1b2e8ac4bb88e6baf8f37a98cdc3 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 30 Mar 2026 10:15:18 +0100 Subject: [PATCH 33/68] removed copy --- openghg_inversions/basis/_helpers.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 98ff1648..2486ae58 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -69,19 +69,20 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da # if inner domain DataTree child is present, compute H_inner from its fp_x_flux root_ds = fp_and_data[site].ds.copy() # avoid modifying original dataset in-place if isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children: - inner_domain_ds = fp_and_data[site]["inner"] - inner_fp_x_flux = inner_domain_ds["fp_x_flux"] + inner_node = fp_and_data[site]["inner"] + # if inner domain fp_x_flux exists, blend it in # fp_x_flux = combine_inner_outer_fp_x_flux( # inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], # outer_fp_x_flux=fp_x_flux, # ) sensitivity_inner 
= apply_fp_basis_functions( - fp_x_flux=inner_fp_x_flux, + fp_x_flux=inner_node["fp_x_flux"], basis_func=basis_func, ) - inner_domain_ds["H_inner"] = sensitivity_inner - + + fp_and_data[site]["inner/H_inner"] = sensitivity_inner + sensitivity = apply_fp_basis_functions( fp_x_flux=root_ds[fp_x_flux_name], basis_func=basis_func, From fd106fd6f31d9384c1f7472a6807c88a456ae048 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 30 Mar 2026 10:15:43 +0100 Subject: [PATCH 34/68] getting H_inner --- openghg_inversions/hbmcmc/hbmcmc.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index b613c8e0..1fdee83c 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -118,6 +118,12 @@ def make_inv_inputs( Y = np.zeros(0) siteindicator = np.zeros(0) + Hx_inner = None + has_inner = any( + isinstance(fp_data[s], xr.DataTree) and "H_inner" in fp_data[s]["inner"].ds.data_vars + for s in sites if s in fp_data + ) + for si, site in enumerate(sites): # if site was dropped, skip; this makes the site indicator numbers consistent # even if a site is dropped @@ -166,6 +172,10 @@ def make_inv_inputs( Hx = fp_data[site].ds.H.values if si == 0 else np.hstack((Hx, fp_data[site].ds.H.values)) + if has_inner and isinstance(fp_data[site], xr.DataTree) and "inner" in fp_data[site].children: + h_inner_site = fp_data[site]["inner"].ds["H_inner"].values + Hx_inner = h_inner_site if si == 0 else np.hstack((Hx_inner, h_inner_site)) + if np.isnan(Hx).any(): warnings.warn(f"Hx matrix contains {np.isnan(Hx).flatten().sum()} NaN values") @@ -201,6 +211,7 @@ def make_inv_inputs( mcmc_args = { "Hx": Hx, + "Hx_inner": Hx_inner if Hx_inner is not None else None, "Y": Y, "error": error, "siteindicator": siteindicator, From 5c8bdc7e184fff518be3bc4f6c935616b892019f Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 30 Mar 2026 16:49:35 +0100 Subject: [PATCH 35/68] to original data --- 
openghg_inversions/basis/_helpers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 2486ae58..95046008 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -80,7 +80,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_x_flux=inner_node["fp_x_flux"], basis_func=basis_func, ) - + # store houter hinner here fp_and_data[site]["inner/H_inner"] = sensitivity_inner sensitivity = apply_fp_basis_functions( @@ -88,9 +88,10 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da basis_func=basis_func, ) - # TODO: store houter hinner here - root_ds["H"] = (sensitivity.dims, sensitivity.data, {"long_name": "sensitivity"}) - fp_and_data[site].ds = root_ds + fp_and_data[site]["H"] = xr.DataArray( + sensitivity.data, dims=sensitivity.dims, attrs={"long_name": "sensitivity"} +) + # fp_and_data[site].ds = root_ds return fp_and_data From 431ae08867fd0627782f878e4b2edfd0ba50bcfa Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 30 Mar 2026 16:49:48 +0100 Subject: [PATCH 36/68] datgatree compatibility --- openghg_inversions/hbmcmc/hbmcmc.py | 34 +++++++++++++---------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 1fdee83c..408330c3 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -146,35 +146,31 @@ def make_inv_inputs( # repeatability/variability chosen/combined into mf_error in `get_data.py` error = np.concatenate((error, fp_data[site].ds.mf_error.values)) + ds = fp_data[site].ds if isinstance(fp_data[site], xr.DataTree) else fp_data[site] - # make repeatability and variability for outputs (not used directly in inversions) - obs_repeatability = np.concatenate((obs_repeatability, fp_data[site].mf_repeatability.values)) - 
obs_variability = np.concatenate((obs_variability, fp_data[site].mf_variability.values)) + error = np.concatenate((error, ds["mf_error"].values)) + obs_repeatability = np.concatenate((obs_repeatability, ds["mf_repeatability"].values)) + obs_variability = np.concatenate((obs_variability, ds["mf_variability"].values)) + Y = np.concatenate((Y, ds["mf"].values)) - Y = np.concatenate((Y, fp_data[site].mf.values)) - if fp_data[site].attrs.get("inlet") == "column" or fp_data[site].attrs.get("platform") == "satellite": - obs_prior_factor = np.concatenate((obs_prior_factor, fp_data[site].mf_prior_factor.values)) + if ds.attrs.get("inlet") == "column" or ds.attrs.get("platform") == "satellite": + obs_prior_factor = np.concatenate((obs_prior_factor, ds["mf_prior_factor"].values)) obs_prior_upper_level_factor = np.concatenate( - (obs_prior_upper_level_factor, fp_data[site].mf_prior_upper_level_factor.values) + (obs_prior_upper_level_factor, ds["mf_prior_upper_level_factor"].values) ) else: - # If not a column/satellite measurement, set prior factors to zero - # This is required if there is mix of insitu and column measurements - obs_prior_factor = np.concatenate((obs_prior_factor, np.zeros(fp_data[site].mf.size))) + obs_prior_factor = np.concatenate((obs_prior_factor, np.zeros(ds["mf"].size))) obs_prior_upper_level_factor = np.concatenate( - (obs_prior_upper_level_factor, np.zeros(fp_data[site].mf.size)) + (obs_prior_upper_level_factor, np.zeros(ds["mf"].size)) ) - siteindicator = np.concatenate((siteindicator, np.ones_like(fp_data[site].mf.values) * si)) - if si == 0: - Ytime = fp_data[site].time.values - else: - Ytime = np.concatenate((Ytime, fp_data[site].time.values)) - Hx = fp_data[site].ds.H.values if si == 0 else np.hstack((Hx, fp_data[site].ds.H.values)) + siteindicator = np.concatenate((siteindicator, np.ones_like(ds["mf"].values) * si)) + Ytime = ds["time"].values if si == 0 else np.concatenate((Ytime, ds["time"].values)) + Hx = ds["H"].values if si == 0 else 
np.hstack((Hx, ds["H"].values)) if has_inner and isinstance(fp_data[site], xr.DataTree) and "inner" in fp_data[site].children: - h_inner_site = fp_data[site]["inner"].ds["H_inner"].values - Hx_inner = h_inner_site if si == 0 else np.hstack((Hx_inner, h_inner_site)) + h_inner = fp_data[site]["inner"].ds["H_inner"].values + Hx_inner = h_inner if Hx_inner is None else np.hstack((Hx_inner, h_inner)) if np.isnan(Hx).any(): warnings.warn(f"Hx matrix contains {np.isnan(Hx).flatten().sum()} NaN values") From 59f0211126cebca01fda3455f5f20b49aa914273 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 30 Mar 2026 16:50:20 +0100 Subject: [PATCH 37/68] ingesting H_inner in pymc --- openghg_inversions/hbmcmc/inversion_pymc.py | 29 ++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/openghg_inversions/hbmcmc/inversion_pymc.py b/openghg_inversions/hbmcmc/inversion_pymc.py index e2d480cd..a25d6aec 100644 --- a/openghg_inversions/hbmcmc/inversion_pymc.py +++ b/openghg_inversions/hbmcmc/inversion_pymc.py @@ -121,6 +121,7 @@ def _make_coords( Hbc: np.ndarray | None = None, sites: list[str] | None = None, sigma_per_site: bool = False, + Hx_inner: np.ndarray | None = None, ) -> dict: result = { "nmeasure": np.arange(len(Y)), @@ -129,6 +130,8 @@ def _make_coords( "nsigma_time": np.unique(sigma_freq_indices), "nsigma_site": np.unique(site_indicator) if sigma_per_site else [0], } + if Hx_inner is not None: + result["nx_inner"] = np.arange(Hx_inner.shape[0]) if Hbc is not None: result["nbc"] = np.arange(Hbc.shape[0]) return result @@ -141,6 +144,7 @@ def inferpymc( siteindicator: np.ndarray, sigma_freq_index: np.ndarray, Hbc: np.ndarray | None = None, + Hx_inner: np.ndarray | None = None, xprior: dict = {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, bcprior: dict = {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, sigprior: dict = {"pdf": "uniform", "lower": 0.1, "upper": 3.0}, @@ -318,6 +322,27 @@ def inferpymc( hx = pm.Data("hx", hx, dims=("nmeasure", "nx")) mu = 
pm.Deterministic("mu", pt.dot(hx, x), dims="nmeasure") + if Hx_inner is not None: + # Split outer contribution: H_outer - H_inner (everything outside inner domain) + # + H_inner · x_inner (inner domain resolved separately) + hx_outer_contrib = pm.Data( + "hx_outer_contrib", (Hx.T - Hx_inner.T), dims=("nmeasure", "nx") + ) + hx_inner = pm.Data("hx_inner", Hx_inner.T, dims=("nmeasure", "nx")) + x = parse_prior("x", xprior, dims="nx") + x_inner = parse_prior("x_inner", xprior, dims="nx") # same prior, separate RV + step1_vars += [x, x_inner] + mu = pm.Deterministic( + "mu", + pt.dot(hx_outer_contrib, x) + pt.dot(hx_inner, x_inner), + dims="nmeasure", + ) + else: + hx_data = pm.Data("hx", hx, dims=("nmeasure", "nx")) + x = parse_prior("x", xprior, dims="nx") + step1_vars.append(x) + mu = pm.Deterministic("mu", pt.dot(hx_data, x), dims="nmeasure") + if use_bc: hbc = pm.Data("hbc", hbc, dims=("nmeasure", "nbc")) mu_bc = pm.Deterministic("mu_bc", pt.dot(hbc, bc), dims="nmeasure") @@ -425,7 +450,9 @@ def inferpymc( "model": model, "trace": trace, } - + if Hx_inner is not None: + result["xouts_inner"] = posterior_burned.x_inner + if use_bc: result["bcouts"] = bcouts result["YBCtrace"] = YBCtrace.values.T From e809ab409b11eeb734ab0b8ac5e6cd52c7d1222b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:38:36 +0000 Subject: [PATCH 38/68] Initial plan From 7a0247ae604c9c60a261e093518117dba129c230 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 1 Apr 2026 11:07:08 +0000 Subject: [PATCH 39/68] Fix DataTree sensitivity matrix persistence and dask compute loop for inner domain Agent-Logs-Url: https://github.com/openghg/openghg_inversions/sessions/b559cdcb-596f-4b24-9aed-48e16eefcf27 Co-authored-by: SutarPrasad <75735315+SutarPrasad@users.noreply.github.com> --- openghg_inversions/basis/_helpers.py | 66 ++++++++++++++++++---------- 
openghg_inversions/hbmcmc/hbmcmc.py | 22 ++++++++-- tests/test_full_inversion.py | 30 ++----------- 3 files changed, 65 insertions(+), 53 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 95046008..a401f845 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -66,32 +66,54 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_and_data[".basis"] = basis_func for site in sites: - # if inner domain DataTree child is present, compute H_inner from its fp_x_flux - root_ds = fp_and_data[site].ds.copy() # avoid modifying original dataset in-place - if isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children: - inner_node = fp_and_data[site]["inner"] - - # if inner domain fp_x_flux exists, blend it in - # fp_x_flux = combine_inner_outer_fp_x_flux( - # inner_fp_x_flux=fp_and_data[site]["fp_x_flux_inner"], - # outer_fp_x_flux=fp_x_flux, - # ) - sensitivity_inner = apply_fp_basis_functions( - fp_x_flux=inner_node["fp_x_flux"], - basis_func=basis_func, - ) - # store houter hinner here - fp_and_data[site]["inner/H_inner"] = sensitivity_inner - + entry = fp_and_data[site] + + # Get root dataset (works for both DataTree and plain Dataset) + if isinstance(entry, xr.DataTree): + root_ds = entry.ds + else: + root_ds = entry + + # Compute H for the root (outer) node using root's fp_x_flux sensitivity = apply_fp_basis_functions( fp_x_flux=root_ds[fp_x_flux_name], basis_func=basis_func, ) - - fp_and_data[site]["H"] = xr.DataArray( - sensitivity.data, dims=sensitivity.dims, attrs={"long_name": "sensitivity"} -) - # fp_and_data[site].ds = root_ds + new_root_ds = root_ds.assign( + H=xr.DataArray( + sensitivity.data, + dims=sensitivity.dims, + attrs={"long_name": "sensitivity"}, + ) + ) + + # Compute H_inner for the inner child node if present + new_inner = None + if isinstance(entry, xr.DataTree) and "inner" in entry.children: + 
inner_ds = entry["inner"].ds + sensitivity_inner = apply_fp_basis_functions( + fp_x_flux=inner_ds["fp_x_flux"], + basis_func=basis_func, + ) + new_inner_ds = inner_ds.assign( + H_inner=xr.DataArray( + sensitivity_inner.data, + dims=sensitivity_inner.dims, + attrs={"long_name": "inner_sensitivity"}, + ) + ) + # Rebuild inner child DataTree preserving its own children + new_inner = xr.DataTree(dataset=new_inner_ds, children=dict(entry["inner"].children)) + + # Rebuild the site DataTree/Dataset with the updated root and children + if isinstance(entry, xr.DataTree): + children = {k: v for k, v in entry.children.items() if k != "inner"} + if new_inner is not None: + children["inner"] = new_inner + fp_and_data[site] = xr.DataTree(dataset=new_root_ds, children=children) + else: + fp_and_data[site] = new_root_ds + return fp_and_data diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 408330c3..71c0164a 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -104,9 +104,24 @@ def make_inv_inputs( "mf_mod", ] for site in sites: - to_compute_site = [dv for dv in to_compute if dv in fp_data[site].data_vars] - for var in to_compute_site: - fp_data[site][var] = fp_data[site][var].compute() + site_entry = fp_data[site] + # Use root node's data_vars only to avoid picking up inner child variables + if isinstance(site_entry, xr.DataTree): + root_data_vars = site_entry.ds.data_vars + else: + root_data_vars = site_entry.data_vars + + to_compute_site = [dv for dv in to_compute if dv in root_data_vars] + if to_compute_site: + if isinstance(site_entry, xr.DataTree): + # Compute the needed variables in the root dataset and rebuild the DataTree + computed_root = site_entry.ds.assign( + {var: site_entry.ds[var].compute() for var in to_compute_site} + ) + fp_data[site] = xr.DataTree(dataset=computed_root, children=dict(site_entry.children)) + else: + for var in to_compute_site: + fp_data[site][var] = 
fp_data[site][var].compute() # Get inputs ready error = np.zeros(0) @@ -145,7 +160,6 @@ def make_inv_inputs( fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) # repeatability/variability chosen/combined into mf_error in `get_data.py` - error = np.concatenate((error, fp_data[site].ds.mf_error.values)) ds = fp_data[site].ds if isinstance(fp_data[site], xr.DataTree) else fp_data[site] error = np.concatenate((error, ds["mf_error"].values)) diff --git a/tests/test_full_inversion.py b/tests/test_full_inversion.py index 00b94a4f..860478fe 100644 --- a/tests/test_full_inversion.py +++ b/tests/test_full_inversion.py @@ -236,39 +236,15 @@ def inner_domain_mcmc_args(tmp_path, mhd_with_inner_domain_ch4_data_args): mcmc_args.update( { "outputname": "inner_domain_test_run", - "outputpath": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/inversion_outputs/paris_bc", + "outputpath": str(tmp_path), "basis_algorithm": "quadtree", - "bc_basis_case" : "NESW", - "fp_basis_case" : None, - "nbasis" : 250, - "basis_directory" : "/group/chem/acrg/LPDM/basis_functions/", - "bc_basis_directory" : "/group/chem/acrg/LPDM/bc_basis_functions/", - "country_file" : "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc", "basis_output_path": str(tmp_path), "nbasis": 4, "nit": 1, - "burn": 10, - "tune": 10, + "burn": 0, + "tune": 0, "nchain": 1, - "mcmc_type" : "fixed_basis", "reload_merged_data": False, - "xprior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, - "bcprior": {"pdf": "normal", "mu": 1.0, "sigma": 0.5}, - "sigprior": {"pdf": "uniform", "lower": 0.5, "upper": 3.0}, - "use_bc": True, - "averaging_error": True, - "min_error": 0.0, - "fix_basis_outer_regions" : "False", - "nuts_sampler" : "numpyro", - "save_trace" : True, - "pollution_events_from_obs" : True, - "no_model_error" : False, - "bc_freq" : None, - "sigma_freq" : None, - "sigma_per_site" : True, - "inlet" : [slice(0,25)], - "instrument" : ['multiple'], - "filters" : {'MHD' : 
['pblh_min','pblh_inlet_diff']} } ) return mcmc_args From 12c6d5e920f9927c6146cb8bcca64ab037032638 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 9 Apr 2026 12:30:09 +0100 Subject: [PATCH 40/68] fp_x_flux_inner --- openghg_inversions/inversion_data/scenario.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 27b39fda..774d4a1f 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -125,6 +125,8 @@ def merged_scenario_data( ) inner_domain_merged = inner_domain_merged.rename({"lat": "lat_inner", "lon": "lon_inner"}) inner_domain_merged = inner_domain_merged.reset_index(["lat_inner", "lon_inner"]) + inner_domain_merged = inner_domain_merged.rename({"fp_x_flux": "fp_x_flux_inner"}) + dt_dict["inner"] = inner_domain_merged return xr.DataTree.from_dict(dt_dict) From c17badce0e5e43527cabc6ef076d78b11afc177d Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 9 Apr 2026 12:31:11 +0100 Subject: [PATCH 41/68] fp_x_flux_inner --- openghg_inversions/basis/_helpers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index a401f845..dd12d5bf 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -79,7 +79,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da fp_x_flux=root_ds[fp_x_flux_name], basis_func=basis_func, ) - new_root_ds = root_ds.assign( + root_ds = root_ds.assign( H=xr.DataArray( sensitivity.data, dims=sensitivity.dims, @@ -92,10 +92,10 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if isinstance(entry, xr.DataTree) and "inner" in entry.children: inner_ds = entry["inner"].ds sensitivity_inner = apply_fp_basis_functions( - fp_x_flux=inner_ds["fp_x_flux"], + 
fp_x_flux=inner_ds["fp_x_flux_inner"], basis_func=basis_func, ) - new_inner_ds = inner_ds.assign( + inner_ds = inner_ds.assign( H_inner=xr.DataArray( sensitivity_inner.data, dims=sensitivity_inner.dims, @@ -103,16 +103,16 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da ) ) # Rebuild inner child DataTree preserving its own children - new_inner = xr.DataTree(dataset=new_inner_ds, children=dict(entry["inner"].children)) + new_inner = xr.DataTree(dataset=inner_ds, children=dict(entry["inner"].children)) # Rebuild the site DataTree/Dataset with the updated root and children if isinstance(entry, xr.DataTree): children = {k: v for k, v in entry.children.items() if k != "inner"} if new_inner is not None: children["inner"] = new_inner - fp_and_data[site] = xr.DataTree(dataset=new_root_ds, children=children) + fp_and_data[site] = xr.DataTree(dataset=root_ds, children=children) else: - fp_and_data[site] = new_root_ds + fp_and_data[site] = root_ds return fp_and_data From b0322b9b4951873aca76b48d27cf2ee81461bd7b Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 9 Apr 2026 12:31:23 +0100 Subject: [PATCH 42/68] test --- tests/test_full_inversion.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/test_full_inversion.py b/tests/test_full_inversion.py index 860478fe..00b94a4f 100644 --- a/tests/test_full_inversion.py +++ b/tests/test_full_inversion.py @@ -236,15 +236,39 @@ def inner_domain_mcmc_args(tmp_path, mhd_with_inner_domain_ch4_data_args): mcmc_args.update( { "outputname": "inner_domain_test_run", - "outputpath": str(tmp_path), + "outputpath": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/inversion_outputs/paris_bc", "basis_algorithm": "quadtree", + "bc_basis_case" : "NESW", + "fp_basis_case" : None, + "nbasis" : 250, + "basis_directory" : "/group/chem/acrg/LPDM/basis_functions/", + "bc_basis_directory" : "/group/chem/acrg/LPDM/bc_basis_functions/", + "country_file" : 
"/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc", "basis_output_path": str(tmp_path), "nbasis": 4, "nit": 1, - "burn": 0, - "tune": 0, + "burn": 10, + "tune": 10, "nchain": 1, + "mcmc_type" : "fixed_basis", "reload_merged_data": False, + "xprior": {"pdf": "normal", "mu": 1.0, "sigma": 1.0}, + "bcprior": {"pdf": "normal", "mu": 1.0, "sigma": 0.5}, + "sigprior": {"pdf": "uniform", "lower": 0.5, "upper": 3.0}, + "use_bc": True, + "averaging_error": True, + "min_error": 0.0, + "fix_basis_outer_regions" : "False", + "nuts_sampler" : "numpyro", + "save_trace" : True, + "pollution_events_from_obs" : True, + "no_model_error" : False, + "bc_freq" : None, + "sigma_freq" : None, + "sigma_per_site" : True, + "inlet" : [slice(0,25)], + "instrument" : ['multiple'], + "filters" : {'MHD' : ['pblh_min','pblh_inlet_diff']} } ) return mcmc_args From 3720ac38d6891d13453f334289729ed3226543e5 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 9 Apr 2026 13:02:23 +0100 Subject: [PATCH 43/68] masking inner domain from standard domain --- openghg_inversions/basis/_helpers.py | 85 +++++++++++++++------------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index dd12d5bf..8f4f620c 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -63,56 +63,63 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if "time" in basis_func.dims and basis_func.sizes["time"] <= 1: basis_func = basis_func.squeeze("time") - fp_and_data[".basis"] = basis_func + fp_and_data[".basis"] = basis_func for site in sites: entry = fp_and_data[site] - # Get root dataset (works for both DataTree and plain Dataset) + # extract root fp_x_flux and mask inner-domain cells to zero if inner domain exists. This ensures that the outer sensitivity H only reflects the outer domain fluxes. 
if isinstance(entry, xr.DataTree): root_ds = entry.ds - else: - root_ds = entry - - # Compute H for the root (outer) node using root's fp_x_flux - sensitivity = apply_fp_basis_functions( - fp_x_flux=root_ds[fp_x_flux_name], - basis_func=basis_func, - ) - root_ds = root_ds.assign( - H=xr.DataArray( - sensitivity.data, - dims=sensitivity.dims, - attrs={"long_name": "sensitivity"}, - ) - ) - - # Compute H_inner for the inner child node if present - new_inner = None - if isinstance(entry, xr.DataTree) and "inner" in entry.children: - inner_ds = entry["inner"].ds - sensitivity_inner = apply_fp_basis_functions( - fp_x_flux=inner_ds["fp_x_flux_inner"], + fp_x_flux_outer = root_ds[fp_x_flux_name] + + if "inner" in entry.children: + inner_ds = entry["inner"].ds + # Build a boolean mask: True where lat/lon is inside inner domain bounds + inner_lat_min = float(inner_ds.lat.min()) + inner_lat_max = float(inner_ds.lat.max()) + inner_lon_min = float(inner_ds.lon.min()) + inner_lon_max = float(inner_ds.lon.max()) + + lat_mask = (fp_x_flux_outer.lat >= inner_lat_min) & (fp_x_flux_outer.lat <= inner_lat_max) + lon_mask = (fp_x_flux_outer.lon >= inner_lon_min) & (fp_x_flux_outer.lon <= inner_lon_max) + inner_region_mask = lat_mask & lon_mask # broadcasts over (lat, lon) + + # Zero out those cells in the outer fp_x_flux + fp_x_flux_for_H = fp_x_flux_outer.where(~inner_region_mask, other=0.0) + else: + fp_x_flux_for_H = fp_x_flux_outer + + # Compute outer H from the masked fp_x_flux + sensitivity = apply_fp_basis_functions( + fp_x_flux=fp_x_flux_for_H, basis_func=basis_func, ) - inner_ds = inner_ds.assign( - H_inner=xr.DataArray( - sensitivity_inner.data, - dims=sensitivity_inner.dims, - attrs={"long_name": "inner_sensitivity"}, + + # Compute H_inner from the inner child's fp_x_flux (its own lat/lon grid) + if "inner" in entry.children: + inner_fp_x_flux = entry["inner"].ds["fp_x_flux_inner"] + H_inner = apply_fp_basis_functions( + fp_x_flux=inner_fp_x_flux, + basis_func=basis_func, ) 
- ) - # Rebuild inner child DataTree preserving its own children - new_inner = xr.DataTree(dataset=inner_ds, children=dict(entry["inner"].children)) + # Write both back into the DataTree + new_root = root_ds.assign({"H": sensitivity}) + new_inner = entry["inner"].ds.assign({"H_inner": H_inner}) + fp_and_data[site] = xr.DataTree.from_dict({ + "/": new_root, + "inner": new_inner, + }) + else: + fp_and_data[site] = xr.DataTree(dataset=root_ds.assign({"H": sensitivity})) - # Rebuild the site DataTree/Dataset with the updated root and children - if isinstance(entry, xr.DataTree): - children = {k: v for k, v in entry.children.items() if k != "inner"} - if new_inner is not None: - children["inner"] = new_inner - fp_and_data[site] = xr.DataTree(dataset=root_ds, children=children) else: - fp_and_data[site] = root_ds + # Legacy: plain xr.Dataset path — unchanged + sensitivity = apply_fp_basis_functions( + fp_x_flux=entry[fp_x_flux_name], + basis_func=basis_func, + ) + fp_and_data[site]["H"] = sensitivity return fp_and_data From 0b940b7130ce06817903716927edb3156ea6dff4 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Mon, 13 Apr 2026 14:22:59 +0100 Subject: [PATCH 44/68] empty root two separate nodes and same dim names --- openghg_inversions/basis/_helpers.py | 8 +++++--- openghg_inversions/inversion_data/get_data.py | 2 +- openghg_inversions/inversion_data/scenario.py | 7 ++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 8f4f620c..769810bf 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -63,14 +63,14 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if "time" in basis_func.dims and basis_func.sizes["time"] <= 1: basis_func = basis_func.squeeze("time") - fp_and_data[".basis"] = basis_func + fp_and_data[".basis"] = basis_func for site in sites: entry = fp_and_data[site] # extract root fp_x_flux 
and mask inner-domain cells to zero if inner domain exists. This ensures that the outer sensitivity H only reflects the outer domain fluxes. if isinstance(entry, xr.DataTree): - root_ds = entry.ds + root_ds = entry["standard"].ds fp_x_flux_outer = root_ds[fp_x_flux_name] if "inner" in entry.children: @@ -172,7 +172,9 @@ def apply_fp_basis_functions( _, basis_aligned = xr.align(fp_x_flux.isel(time=0), basis_func, join="override") basis_mat = get_xr_dummies(basis_aligned, cat_dim="region") - sensitivity = sparse_xr_dot(basis_mat, fp_x_flux.fillna(0.0), dim=["lat", "lon"]) + spatial_dims = ["lat", "lon"] + + sensitivity = sparse_xr_dot(basis_mat, fp_x_flux.fillna(0.0), dim=spatial_dims) if sensitivity.dims[:2] != ("region", "time"): sensitivity = sensitivity.transpose("region", "time", ...) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 612f37af..847e8771 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -385,7 +385,7 @@ def data_processing_surface_notracer( # scenario_combined returns a datatree. The standard_domain scenario is stored at the root, and any inner domain scenario is stored at "/{inner_domain}". Here we take the root scenario to store the calibration scale and units, since these should be the same for the inner domain scenario. If inner domain scenario is None the root scenario is the only scenario, so this will still work. 
- root_ds = scenario_combined.ds + root_ds = scenario_combined["standard"].ds units[site] = root_ds.mf.attrs.get("units") if "satellite" not in platform: diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 774d4a1f..53c78e2f 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -101,7 +101,7 @@ def merged_scenario_data( cache=False, ) - dt_dict: dict[str, xr.Dataset] = {"/": scenario_combined} + dt_dict: dict[str, xr.Dataset] = {"/standard": scenario_combined} if inner_footprint_data is not None: # Mask the EUROPE flux to the inner domain extent (zero outside), # regridded to the inner footprint lat/lon grid. @@ -123,10 +123,7 @@ def merged_scenario_data( inner_domain_merged = inner_domain_merged.reindex( time=scenario_combined.time, fill_value=0.0 ) - inner_domain_merged = inner_domain_merged.rename({"lat": "lat_inner", "lon": "lon_inner"}) - inner_domain_merged = inner_domain_merged.reset_index(["lat_inner", "lon_inner"]) - inner_domain_merged = inner_domain_merged.rename({"fp_x_flux": "fp_x_flux_inner"}) - dt_dict["inner"] = inner_domain_merged + dt_dict["/inner"] = inner_domain_merged = inner_domain_merged return xr.DataTree.from_dict(dt_dict) From 4584cfe8f3b89919a645416743a20cce734e386f Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 23 Apr 2026 08:02:58 +0100 Subject: [PATCH 45/68] handling of separate nodes in the code --- openghg_inversions/inversion_data/get_data.py | 120 +++++++++--------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 847e8771..5b41f744 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -60,69 +60,69 @@ def add_obs_error(sites: list[str], fp_all: dict, add_averaging_error: bool = Tr """ # TODO: do we want to fill missing values in 
repeatability or variability? for site in sites: - ds = fp_all[site].ds.copy() - variability_missing = False - if "mf_variability" not in ds: - ds["mf_variability"] = xr.zeros_like(ds.mf) - ds["mf_variability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_variability" - variability_missing = True - - if "mf_repeatability" not in ds: - if variability_missing: - raise ValueError(f"Obs data for site {site} is missing both repeatability and variability.") - - ds["mf_repeatability"] = xr.zeros_like(ds.mf_variability) - ds["mf_repeatability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_repeatability" - - ds["mf_error"] = ds["mf_variability"] - - if add_averaging_error: - logger.info( - "`mf_repeatability` not present; using `mf_variability` for `mf_error` at site %s", site + for node_name in ["standard", "inner"]: + if node_name not in fp_all[site]: + continue + ds = fp_all[site][node_name].ds.copy() + variability_missing = False + if "mf_variability" not in ds: + ds["mf_variability"] = xr.zeros_like(ds.mf) + ds["mf_variability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_variability" + variability_missing = True + + if "mf_repeatability" not in ds: + if variability_missing: + raise ValueError(f"Obs data for site {site} ({node_name}) is missing both repeatability and variability.") + + ds["mf_repeatability"] = xr.zeros_like(ds.mf_variability) + ds["mf_repeatability"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_repeatability" + + ds["mf_error"] = ds["mf_variability"] + + if add_averaging_error: + logger.info( + "`mf_repeatability` not present; using `mf_variability` for `mf_error` at site %s (%s)", site, node_name + ) + + elif add_averaging_error: + ds["mf_error"] = np.sqrt( + ds["mf_repeatability"].fillna(0) ** 2 + ds["mf_variability"].fillna(0) ** 2 + ) + else: + ds["mf_error"] = ds["mf_repeatability"] + + ds["mf_error"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_error" + 
ds["mf_error"].attrs["units"] = ds.mf.attrs.get("units", None) + + err0 = (ds["mf_error"] == 0) | (ds["mf_error"].isnull()) + + if err0.any(): + percent0 = 100 * err0.mean() + logger.warning( + ( + "`mf_error` is zero/nan for %.2f percent of times at site %s (%s);" + "filling with max(median(mf_error), std(mf))." + ), + percent0, + site, + node_name, ) - elif add_averaging_error: - # Fill with zeros so that if one of repeatability and variability is not NaN, then mf_error will not be NaN. - ds["mf_error"] = np.sqrt( - ds["mf_repeatability"].fillna(0) ** 2 + ds["mf_variability"].fillna(0) ** 2 - ) - else: - ds["mf_error"] = ds["mf_repeatability"] - - ds["mf_error"].attrs["long_name"] = ds.mf.attrs.get("long_name", "") + "_error" - ds["mf_error"].attrs["units"] = ds.mf.attrs.get("units", None) - - # warnings/info for debugging - err0 = (ds["mf_error"] == 0) | ( - ds["mf_error"].isnull() - ) # might have NaN if add_averaging_error is False - - if err0.any(): - percent0 = 100 * err0.mean() - logger.warning( - ( - "`mf_error` is zero/nan for %.2f percent of times at site %s;" - "filling with max(median(mf_error), std(mf))." - ), - percent0, - site, - ) - - mf_err_da = ds["mf_error"].as_numpy() # load into memory to avoid Dask issues - fill_value = np.nanmax( - [ - mf_err_da.where(mf_err_da != 0).dropna(dim="time").median(), - ds["mf"].std(dim="time"), - ] - ) - ds["mf_error"] = mf_err_da.where(mf_err_da != 0, fill_value) - info_msg = ( - "If `averaging_period` matches the frequency of the obs data, then `mf_variability` " - "will be zero. Try setting `averaging_period = None`." - ) - logger.info(info_msg) + mf_err_da = ds["mf_error"].as_numpy() + fill_value = np.nanmax( + [ + mf_err_da.where(mf_err_da != 0).dropna(dim="time").median(), + ds["mf"].std(dim="time"), + ] + ) + ds["mf_error"] = mf_err_da.where(mf_err_da != 0, fill_value) + info_msg = ( + "If `averaging_period` matches the frequency of the obs data, then `mf_variability` " + "will be zero. 
Try setting `averaging_period = None`." + ) + logger.info(info_msg) - fp_all[site].ds = ds + fp_all[site][node_name].ds = ds def convert_to_list( x: list[str | None] | str | None, length: int, name: str | None = None From 85cab6929e954f00e7fe59e731467d3e00acea11 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 23 Apr 2026 08:05:34 +0100 Subject: [PATCH 46/68] considers hx solved separately and 0 at inner domain --- openghg_inversions/hbmcmc/inversion_pymc.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/openghg_inversions/hbmcmc/inversion_pymc.py b/openghg_inversions/hbmcmc/inversion_pymc.py index a25d6aec..e025a695 100644 --- a/openghg_inversions/hbmcmc/inversion_pymc.py +++ b/openghg_inversions/hbmcmc/inversion_pymc.py @@ -320,28 +320,23 @@ def inferpymc( sigma = parse_prior("sigma", sigprior, dims=("nsigma_site", "nsigma_time")) hx = pm.Data("hx", hx, dims=("nmeasure", "nx")) - mu = pm.Deterministic("mu", pt.dot(hx, x), dims="nmeasure") if Hx_inner is not None: - # Split outer contribution: H_outer - H_inner (everything outside inner domain) - # + H_inner · x_inner (inner domain resolved separately) - hx_outer_contrib = pm.Data( - "hx_outer_contrib", (Hx.T - Hx_inner.T), dims=("nmeasure", "nx") - ) - hx_inner = pm.Data("hx_inner", Hx_inner.T, dims=("nmeasure", "nx")) + # Hx_inner has shape (nx_inner, nmeasure); Hx has shape (nx, nmeasure) + # They have different leading dims so cannot be subtracted — keep separate + hx_inner = pm.Data("hx_inner", Hx_inner.T, dims=("nmeasure", "nx_inner")) x = parse_prior("x", xprior, dims="nx") - x_inner = parse_prior("x_inner", xprior, dims="nx") # same prior, separate RV + x_inner = parse_prior("x_inner", xprior, dims="nx_inner") step1_vars += [x, x_inner] mu = pm.Deterministic( "mu", - pt.dot(hx_outer_contrib, x) + pt.dot(hx_inner, x_inner), + pt.dot(hx, x) + pt.dot(hx_inner, x_inner), dims="nmeasure", ) else: - hx_data = pm.Data("hx", hx, dims=("nmeasure", "nx")) x = 
parse_prior("x", xprior, dims="nx") step1_vars.append(x) - mu = pm.Deterministic("mu", pt.dot(hx_data, x), dims="nmeasure") + mu = pm.Deterministic("mu", pt.dot(hx, x), dims="nmeasure") if use_bc: hbc = pm.Data("hbc", hbc, dims=("nmeasure", "nbc")) From 189082c9a9d2bd82c36b241cdbbadf3395e9028a Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 23 Apr 2026 08:06:01 +0100 Subject: [PATCH 47/68] inner domain passing to basis function wrapper --- openghg_inversions/basis/_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openghg_inversions/basis/_wrapper.py b/openghg_inversions/basis/_wrapper.py index 6d719e49..a1a0f367 100644 --- a/openghg_inversions/basis/_wrapper.py +++ b/openghg_inversions/basis/_wrapper.py @@ -26,6 +26,7 @@ def basis_functions_wrapper( country_directory: str | None = None, outputname: str | None = None, output_path: str | None = None, + inner_domain: str | None = None, ): """Wrapper function for selecting basis function algorithm. From 53408e0510c7b878cff06b04659d5bcf13717f01 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 23 Apr 2026 08:07:28 +0100 Subject: [PATCH 48/68] passing inner domain --- openghg_inversions/hbmcmc/hbmcmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 71c0164a..ab98b34b 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -616,7 +616,7 @@ def fixedbasisMCMC( use_bc=use_bc, species=species, domain=domain, - # inner_domain=inner_domain, + inner_domain=inner_domain, start_date=start_date, fix_outer_regions=fix_basis_outer_regions, emissions_name=emissions_name, From 2bec153edcec93a0aa11d4c8db5d822d1ad4464d Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 28 Apr 2026 13:26:15 +0100 Subject: [PATCH 49/68] aaplying mask on standard footprint if inner_domain is not none --- openghg_inversions/inversion_data/get_data.py | 35 +++++++++++++++++++ 1 file changed, 35 
insertions(+) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 5b41f744..095c527a 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -152,6 +152,37 @@ def convert_to_list( return x +def _apply_inner_mask_on_standard_domain(standard_footprint_data: xr.DataArray, inner_footprint_data: xr.DataArray)): + """Apply inner domain mask to standard domain fp_x_flux to ensure that the standard domain sensitivity H only reflects the outer domain fluxes. + + Args: + standard_footprint_data: xr.DataArray containing the standard domain footprint data, with lat/lon coordinates. + inner_footprint_data: xr.DataArray containing the inner domain footprint data, with lat/lon coordinates. + + Returns: + fp: xr.DataArray of the same shape as standard_footprint_data, but with values set to zero in the region covered by inner_footprint_data. + """ + + fp_standard = standard_footprint_data + + if inner_footprint_data is not None: + inner_ds = inner_footprint_data + # Build a boolean mask: True where lat/lon is inside inner domain bounds + inner_lat_min = float(inner_ds.lat.min()) + inner_lat_max = float(inner_ds.lat.max()) + inner_lon_min = float(inner_ds.lon.min()) + inner_lon_max = float(inner_ds.lon.max()) + + lat_mask = (fp_standard.lat >= inner_lat_min) & (fp_standard.lat <= inner_lat_max) + lon_mask = (fp_standard.lon >= inner_lon_min) & (fp_standard.lon <= inner_lon_max) + inner_region_mask = lat_mask & lon_mask # broadcasts over (lat, lon) + + # Zero out those cells in the outer fp + fp = fp_standard.where(~inner_region_mask, other=0.0) + + return fp + + def data_processing_surface_notracer( species: str, sites: list | str, @@ -369,6 +400,9 @@ def data_processing_surface_notracer( obs_data=site_data, stores=inner_footprint_store if inner_footprint_store is not None else footprint_store, ) + + standard_footprint_data = 
_apply_inner_mask_on_standard_domain(standard_footprint_data= standard_footprint_data, inner_footprint_data= inner_footprint_data) + if inner_footprint_data is None: print( f"\nNo Inner footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}-{inner_domain}, starting from {start_date} to {end_date}, fp_species {fp_species} and met_model {met_model[i]}, obs_data {site_data} and averaging_period {averaging_period[i]}.", @@ -444,3 +478,4 @@ def data_processing_surface_notracer( print(f"\nfp_all saved in {merged_data_dir}\n") return fp_all, sites, inlet, fp_height, instrument, averaging_period + From ee6856ad4263e3d6951fac7e6c03ad7c1fbdf0b3 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 28 Apr 2026 14:07:25 +0100 Subject: [PATCH 50/68] removed masking from fp_sensitivity as fp is masked in get_data --- openghg_inversions/basis/_helpers.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 769810bf..65b35192 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -73,22 +73,7 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da root_ds = entry["standard"].ds fp_x_flux_outer = root_ds[fp_x_flux_name] - if "inner" in entry.children: - inner_ds = entry["inner"].ds - # Build a boolean mask: True where lat/lon is inside inner domain bounds - inner_lat_min = float(inner_ds.lat.min()) - inner_lat_max = float(inner_ds.lat.max()) - inner_lon_min = float(inner_ds.lon.min()) - inner_lon_max = float(inner_ds.lon.max()) - - lat_mask = (fp_x_flux_outer.lat >= inner_lat_min) & (fp_x_flux_outer.lat <= inner_lat_max) - lon_mask = (fp_x_flux_outer.lon >= inner_lon_min) & (fp_x_flux_outer.lon <= inner_lon_max) - inner_region_mask = lat_mask & lon_mask # broadcasts over (lat, lon) - - # Zero out those cells in the outer fp_x_flux - fp_x_flux_for_H = 
fp_x_flux_outer.where(~inner_region_mask, other=0.0) - else: - fp_x_flux_for_H = fp_x_flux_outer + fp_x_flux_for_H = fp_x_flux_outer # Compute outer H from the masked fp_x_flux sensitivity = apply_fp_basis_functions( @@ -107,8 +92,8 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da new_root = root_ds.assign({"H": sensitivity}) new_inner = entry["inner"].ds.assign({"H_inner": H_inner}) fp_and_data[site] = xr.DataTree.from_dict({ - "/": new_root, - "inner": new_inner, + "/standard": new_root, + "/inner": new_inner, }) else: fp_and_data[site] = xr.DataTree(dataset=root_ds.assign({"H": sensitivity})) From 86334d9f52fa9a78906a0681cb7de943b27b099d Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 28 Apr 2026 14:07:59 +0100 Subject: [PATCH 51/68] applying mask on standard_fp f inner domain is present --- openghg_inversions/inversion_data/get_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 095c527a..ae6f9603 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -152,7 +152,7 @@ def convert_to_list( return x -def _apply_inner_mask_on_standard_domain(standard_footprint_data: xr.DataArray, inner_footprint_data: xr.DataArray)): +def _apply_inner_mask_on_standard_domain(standard_footprint_data: xr.Dataset, inner_footprint_data: xr.Dataset): """Apply inner domain mask to standard domain fp_x_flux to ensure that the standard domain sensitivity H only reflects the outer domain fluxes. Args: @@ -163,10 +163,10 @@ def _apply_inner_mask_on_standard_domain(standard_footprint_data: xr.DataArray, fp: xr.DataArray of the same shape as standard_footprint_data, but with values set to zero in the region covered by inner_footprint_data. 
""" - fp_standard = standard_footprint_data + fp_standard = standard_footprint_data.copy() if inner_footprint_data is not None: - inner_ds = inner_footprint_data + inner_ds = inner_footprint_data.copy() # Build a boolean mask: True where lat/lon is inside inner domain bounds inner_lat_min = float(inner_ds.lat.min()) inner_lat_max = float(inner_ds.lat.max()) @@ -401,7 +401,7 @@ def data_processing_surface_notracer( stores=inner_footprint_store if inner_footprint_store is not None else footprint_store, ) - standard_footprint_data = _apply_inner_mask_on_standard_domain(standard_footprint_data= standard_footprint_data, inner_footprint_data= inner_footprint_data) + standard_footprint_data.data = _apply_inner_mask_on_standard_domain(standard_footprint_data= standard_footprint_data.data, inner_footprint_data= inner_footprint_data.data) if inner_footprint_data is None: print( From 29a472678a47855291b7e4d928161c895ca31e31 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:20:50 +0100 Subject: [PATCH 52/68] calculating inner basis data array for inner fp and flux by checking if inner domain is not None Co-authored-by: Copilot --- openghg_inversions/basis/_wrapper.py | 33 +++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/openghg_inversions/basis/_wrapper.py b/openghg_inversions/basis/_wrapper.py index a1a0f367..eadfa480 100644 --- a/openghg_inversions/basis/_wrapper.py +++ b/openghg_inversions/basis/_wrapper.py @@ -82,6 +82,7 @@ def basis_functions_wrapper( Dictionary object similar to fp_all but with information on basis functions and sensitivities """ + inner_basis_data_array = None if use_bc is True and bc_basis_case is None: raise ValueError("If `use_bc` is True, you must specify `bc_basis_case`.") @@ -119,13 +120,39 @@ def basis_functions_wrapper( "Basis algorithm not recognised. 
Please use either 'quadtree' or 'weighted', or input a basis function file" ) from e print(f"Using {basis_function.description} to derive basis functions.") - basis_data_array = basis_function.algorithm(fp_all, start_date, domain, emissions_name, nbasis, country_directory=country_directory) + + if inner_domain is not None: + inner_basis_data_array = basis_function.algorithm( + fp_all=fp_all, + start_date=start_date, + domain=f"{domain}-{inner_domain}", + emissions_name=emissions_name, + nbasis=nbasis, + country_directory=country_directory, + scenario="inner", + ) + + print(f"Computing inner basis took {time() - basis_start}s.") + + basis_data_array = basis_function.algorithm( + fp_all=fp_all, + start_date=start_date, + domain=domain, + emissions_name=emissions_name, + nbasis=nbasis, + country_directory=country_directory, + ) print(f"Computing basis took {time() - basis_start}s.") fp_sens_start = time() - fp_data = fp_sensitivity(fp_all, basis_func=basis_data_array) - print(f"Computing fp sensitivity took {time() - fp_sens_start}s.") + if basis_data_array is not None: + fp_data = fp_sensitivity( + fp_all, + basis_func=basis_data_array, + inner_basis_func=inner_basis_data_array, + ) + print(f"Computing fp sensitivity took {time() - fp_sens_start}s.") if use_bc is True: bc_sens_start = time() From a5dea90b1e482b2de5774677761186959db684b9 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:21:49 +0100 Subject: [PATCH 53/68] inner flux dic accepted by merged scenario --- openghg_inversions/inversion_data/scenario.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/openghg_inversions/inversion_data/scenario.py b/openghg_inversions/inversion_data/scenario.py index 53c78e2f..b8dc26e2 100644 --- a/openghg_inversions/inversion_data/scenario.py +++ b/openghg_inversions/inversion_data/scenario.py @@ -66,6 +66,7 @@ def merged_scenario_data( obs_data: ObsData, footprint_data: FootprintData, flux_dict: dict[str, FluxData], + inner_flux_dict: 
dict[str, FluxData] | None, bc_data: BoundaryConditionsData | None = None, inner_footprint_data: FootprintData | None = None, platform: str | None = None, @@ -106,7 +107,8 @@ def merged_scenario_data( # Mask the EUROPE flux to the inner domain extent (zero outside), # regridded to the inner footprint lat/lon grid. # ModelScenario then computes fp_x_flux on the inner grid correctly. - flux_dict_inner = _mask_flux_to_inner_domain(flux_dict, inner_footprint_data) + # flux_dict_inner = _mask_flux_to_inner_domain(flux_dict, inner_footprint_data) + flux_dict_inner = inner_flux_dict inner_scenario = ModelScenario(obs=obs_data, footprint=inner_footprint_data, flux=flux_dict_inner, bc=None) inner_domain_merged = inner_scenario.footprints_data_merge( From 359e9fb7cfeefe1962389da2538b07af72053fb2 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:22:21 +0100 Subject: [PATCH 54/68] inner emissions store is passed to fetch flux for inner domain --- openghg_inversions/inversion_data/get_data.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index ae6f9603..3e0f80e6 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -214,6 +214,7 @@ def data_processing_surface_notracer( output_name: str | None = None, inner_domain: str | None = None, inner_footprint_store: str | list[str] | None = None, + inner_emissions_store: str | None = None, ) -> tuple[dict, list, list, list, list, list]: """Retrieve and prepare fixed-surface datasets from specified OpenGHG object stores. 
@@ -300,6 +301,17 @@ def data_processing_surface_notracer( fp_all[".flux"] = flux_dict + if inner_emissions_store is not None: + inner_flux_dict = get_flux_data( + sources=emissions_name, + species=species, + domain=f"{domain}-{inner_domain}", + start_date=start_date, + end_date=end_date, + store=inner_emissions_store, + ) + fp_all[".inner_flux"] = inner_flux_dict + # Get BC data if use_bc is True: try: @@ -413,6 +425,7 @@ def data_processing_surface_notracer( scenario_combined = merged_scenario_data( obs_data=site_data, footprint_data=standard_footprint_data, flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, + inner_flux_dict= inner_flux_dict if inner_emissions_store is not None else None, platform=platform[i], max_level=max_level ) fp_all[site] = scenario_combined From fbe9be56505ffcd053187ecf3a3dd1c2fdb9eafb Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:22:59 +0100 Subject: [PATCH 55/68] temp ini needs update for default paths to stores --- .../config/openghg_hbmcmc_input_template.ini | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini b/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini index a2a7fc6b..0399d9ac 100644 --- a/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini +++ b/openghg_inversions/hbmcmc/config/openghg_hbmcmc_input_template.ini @@ -19,8 +19,8 @@ species = "ch4" ; (required) use_tracer = False ; (required) sites = ["mhd"] ; (required) -averaging_period = ["1h"] ; (required) -start_date = "2022-01-01" ; (required - but can be specified on command line instead) +averaging_period = ["4h"] ; (required) +start_date = "2020-01-01" ; (required - but can be specified on command line instead) end_date = "2024-12-31" ; (required - but can be specified on command line instead) ; save_merged_data (bool): If True, saves merged data object @@ -51,10 +51,13 @@ filters = [] 
; footprint_store (str): Name of object store with footprints data ; emissions_store (str): Name of flux emissions object store -bc_store = " " ; (required) +bc_store = "/group/chem/acrg/object_stores/shared_store_zarr" ; (required) obs_store = "/group/chem/acrg/object_stores/paris/obs_icos_2025_08_store" ; (required) -footprint_store = "/group/chem/acrg/prasad/object_store_6km" ; (required) -emissions_store = "/group/chem/acrg/prasad/object_store_6km" ; (required) +footprint_store = "shared_store_zarr" ; (required) +emissions_store = "/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE " ; (required) +inner_footprint_store = "shared_store_zarr" +inner_emissions_store = "/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE_6km" + [INPUT.PRIORS] @@ -68,12 +71,13 @@ emissions_store = "/group/chem/acrg/prasad/object_store_6km" ; (required) ; bc_input (list/str): Name of boundary conditions data to use from object store domain = "europe" ; (required) +inner_domain = "6km" met_model = None fp_model = None fp_height = "10m" -fp_species = None -emissions_name = ["anthro"] ; (required) -bc_input = None +fp_species = "inert" +emissions_name = ["edgarv80_wetchartsv131"] ; (required) +bc_input = "camsv22r2_daily" [INPUT.BASIS_CASE] ; Input values to extract the basis cases to use within the inversion for boundary conditions and emissions @@ -95,7 +99,7 @@ fp_basis_case = None nbasis = 250 basis_directory = "/group/chem/acrg/LPDM/basis_functions/" bc_basis_directory = "/group/chem/acrg/LPDM/bc_basis_functions/" -country_file = "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc" +country_file = country_file = "/group/chem/acrg/LPDM/countries/country_EUROPE_EEZ_PARIS_gapfilled.nc" [MCMC.TYPE] ; Which MCMC setup to use. This defines the function which will be called and the expected inputs. 
@@ -172,9 +176,9 @@ sigma_per_site = True ; burn (int): Number of iterations to burn/discard in MCMC ; tune (int): Number of iterations to use to tune step size -nit = 100 ; (required) -burn = 10 ; (required) -tune = 100 ; (required) +nit = 1000 ; (required) +burn = 100 ; (required) +tune = 1000 ; (required) [MCMC.NCHAIN] ; nchain (int): Number of chains to run simultaneously. Must be >=2 to allow convergence to be checked. @@ -210,7 +214,7 @@ nchain = 2 averaging_error = True min_error = 0.0 fix_basis_outer_regions = False -use_bc = False +use_bc = True nuts_sampler = "numpyro" save_trace = True pollution_events_from_obs = True @@ -221,5 +225,5 @@ no_model_error = False ; outputpath (str): Directory to write output ; outputname (str): Unique identifier for output/run name. -outputpath = "~/openghg_inversions" ; (required) +outputpath = "/group/chem/acrg/prasad/job_scripts/openghg_inversions/inversion_outputs/paris_bc" ; (required) outputname = "mhd_openghg_inversions" ; (required) From 278ec947a398ac091b8f3d12b3cfc595f183ff79 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:27:49 +0100 Subject: [PATCH 56/68] passing inner emissions store --- openghg_inversions/hbmcmc/hbmcmc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index ab98b34b..dbd92b83 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -291,6 +291,7 @@ def fixedbasisMCMC( footprint_store: str = "user", emissions_store: str = "user", inner_footprint_store: str = "user", + inner_emissions_store: str = "user", met_model: list | None = None, fp_model: str | None = None, # Changed to none. 
When "NAME" specified FPs are not found fp_height: list[str] | None = None, @@ -590,6 +591,7 @@ def fixedbasisMCMC( footprint_store=footprint_store, emissions_store=emissions_store, inner_footprint_store=inner_footprint_store, + inner_emissions_store=inner_emissions_store, averagingerror=averaging_error, save_merged_data=save_merged_data, merged_data_name=merged_data_name, From 309d7ba298688d23752025c39f7a2d39b3bf1c3c Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:28:46 +0100 Subject: [PATCH 57/68] inner basis func is accepted in fp_sensitivity and applies fp basis function accordingly --- openghg_inversions/basis/_helpers.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 65b35192..67e0b489 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -6,7 +6,11 @@ from ._functions import basis_boundary_conditions -def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.DataArray]) -> dict: +def fp_sensitivity( + fp_and_data: dict, + basis_func: xr.DataArray | dict[str, xr.DataArray], + inner_basis_func: xr.DataArray | None = None, +) -> dict: """Add a sensitivity matrix, H, to each site xr.Dataset in fp_and_data. 
The sensitivity matrix H takes the footprint sensitivities (the `fp` variable), @@ -38,7 +42,6 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if len(flux_sources) == 1: if not isinstance(basis_func, xr.DataArray): basis_func = next(iter(basis_func.values())) - fp_x_flux_name = "fp_x_flux" else: @@ -63,14 +66,22 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da if "time" in basis_func.dims and basis_func.sizes["time"] <= 1: basis_func = basis_func.squeeze("time") + if inner_basis_func is not None and "time" in inner_basis_func.dims and inner_basis_func.sizes["time"] <= 1: + inner_basis_func = inner_basis_func.squeeze("time") + fp_and_data[".basis"] = basis_func + if inner_basis_func is not None: + fp_and_data[".basis_inner"] = inner_basis_func for site in sites: entry = fp_and_data[site] # extract root fp_x_flux and mask inner-domain cells to zero if inner domain exists. This ensures that the outer sensitivity H only reflects the outer domain fluxes. 
if isinstance(entry, xr.DataTree): - root_ds = entry["standard"].ds + if "standard" in entry.children: + root_ds = entry["standard"].ds + else: + root_ds = entry.ds fp_x_flux_outer = root_ds[fp_x_flux_name] fp_x_flux_for_H = fp_x_flux_outer @@ -83,10 +94,13 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da # Compute H_inner from the inner child's fp_x_flux (its own lat/lon grid) if "inner" in entry.children: - inner_fp_x_flux = entry["inner"].ds["fp_x_flux_inner"] + if inner_basis_func is None: + raise ValueError("Inner-domain data exists but no inner basis function was provided.") + + inner_fp_x_flux = entry["inner"].ds[fp_x_flux_name] H_inner = apply_fp_basis_functions( fp_x_flux=inner_fp_x_flux, - basis_func=basis_func, + basis_func=inner_basis_func, ) # Write both back into the DataTree new_root = root_ds.assign({"H": sensitivity}) @@ -190,7 +204,7 @@ def bc_sensitivity( if basis_case.lower() == "nesw": for site in sites: ds = fp_and_data[site] - bc_ds = ds.ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) + bc_ds = ds['standard'].ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") fp_and_data[site]["H_bc"] = sensitivity From bf81d5b9eae382a91204e02cdfa4e0f95df50376 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:30:13 +0100 Subject: [PATCH 58/68] flux and fp based on standard and inner scenario are fetched --- openghg_inversions/basis/_functions.py | 32 +++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/openghg_inversions/basis/_functions.py b/openghg_inversions/basis/_functions.py index fdb04b0d..91b8a55f 100644 --- a/openghg_inversions/basis/_functions.py +++ b/openghg_inversions/basis/_functions.py @@ -132,7 +132,7 @@ def basis_boundary_conditions(domain: str, basis_case: str, bc_basis_directory: def _flux_fp_from_fp_all( - fp_all: dict, 
emissions_name: list[str] | None = None + fp_all: dict, emissions_name: list[str] | None = None, scenario = "standard" ) -> tuple[xr.DataArray, list[xr.DataArray]]: """Get flux and list of footprints from `fp_all` dictionary and optional list of emissions names. @@ -149,10 +149,12 @@ def _flux_fp_from_fp_all( footprints (list): List of xarray DataArray containing the footprints of each sites. """ + flux_key = ".inner_flux" if scenario == "inner" and ".inner_flux" in fp_all else ".flux" + if emissions_name is not None: - flux = fp_all[".flux"][emissions_name[0]].data.flux + flux = fp_all[flux_key][emissions_name[0]].data.flux else: - first_flux = next(iter(fp_all[".flux"].values())) + first_flux = next(iter(fp_all[flux_key].values())) flux = first_flux.data.flux flux = cast(xr.DataArray, flux) @@ -162,14 +164,16 @@ def _flux_fp_from_fp_all( if k.startswith("."): continue - # Need to discuss this further + # Need to discuss this further # fp_x_flux is guaranteed to be on the # EUROPE grid (it comes from the standard scenario). Raw .fp may be # on the 6km grid if OpenGHG snapped it to the footprint resolution. - if "fp_x_flux" in v: + if scenario == "inner" and isinstance(v, xr.DataTree) and "inner" in v.children: + fp = v["inner"].ds["fp_x_flux"] + elif "fp_x_flux" in v: fp = v["fp_x_flux"] else: - fp = v.fp + fp = v[scenario].ds.fp # if grid still doesn't match flux, regrid to flux grid if fp.sizes.get("lat") != flux.sizes.get("lat") or fp.sizes.get("lon") != flux.sizes.get("lon"): @@ -232,6 +236,7 @@ def quadtreebasisfunction( fp_all: dict, start_date: str, domain : str, + scenario: str = "standard", emissions_name: list[str] | None = None, nbasis: int = 100, country_directory: str | None = None, @@ -277,7 +282,7 @@ def quadtreebasisfunction( quad_basis (xarray.DataArray): Array with lat/lon dimensions and basis regions encoded by integers. 
""" - flux, footprints = _flux_fp_from_fp_all(fp_all, emissions_name) + flux, footprints = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario) fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy() # use xr.apply_ufunc to keep xarray coords @@ -302,7 +307,8 @@ def bucketbasisfunction( nbasis: int = 100, country_directory: str | None = None, abs_flux: bool = False, - mask: xr.DataArray | None = None + mask: xr.DataArray | None = None, + scenario: str = "standard", ) -> xr.DataArray: """Basis functions calculated using a weighted region approach where each basis function / scaling region contains approximately @@ -330,12 +336,15 @@ def bucketbasisfunction( Default None country_directory (str): Directory containing land-sea files. If None, will use default files. + scenario (str): + Scenario for retrieving emissions files. + Default "standard" Returns: bucket_basis (xarray.DataArray): Array with lat/lon dimensions and basis regions encoded by integers. """ - flux, footprints = _flux_fp_from_fp_all(fp_all, emissions_name) + flux, footprints = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario) fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy() fps = fps / fps.max() @@ -370,6 +379,7 @@ def fixed_outer_regions_basis( nbasis: int = 100, country_directory: str | None = None, abs_flux: bool = False, + scenario: str = "standard" ) -> xr.DataArray: """Fix outer region of basis functions to InTEM regions, and fit the inner regions using `basis_algorithm`. 
@@ -408,7 +418,7 @@ def fixed_outer_regions_basis( intem_regions = xr.open_dataset(intem_regions_path).region # force intem_regions to use flux coordinates - flux, _ = _flux_fp_from_fp_all(fp_all, emissions_name) + flux, _ = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario) _, intem_regions = xr.align(flux, intem_regions, join="override") inner_index = intem_regions.values.max() @@ -416,7 +426,7 @@ def fixed_outer_regions_basis( mask = intem_regions == inner_index basis_function = basis_functions[basis_algorithm].algorithm - inner_region = basis_function(fp_all, start_date, domain, emissions_name, nbasis, country_directory, abs_flux, mask=mask) + inner_region = basis_function(fp_all=fp_all, start_date=start_date, domain=domain, emissions_name=emissions_name, nbasis=nbasis, country_directory=country_directory, abs_flux=abs_flux, mask=mask) basis = intem_regions.rename("basis") From 55df001822ed9df399ce0bf20f4af8cae2950a0a Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Tue, 12 May 2026 13:36:04 +0100 Subject: [PATCH 59/68] requires update to paths in conftest too --- tests/conftest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a928d340..77c04e38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -357,16 +357,17 @@ def mhd_with_inner_domain_ch4_data_args(): "bc_store": "/group/chem/acrg/object_stores/shared_store_zarr", "obs_store": "/group/chem/acrg/object_stores/paris/obs_icos_2025_08_store", "footprint_store": "/group/chem/acrg/object_stores/shared_store_zarr", - "emissions_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "emissions_store": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE", "inner_footprint_store": "/group/chem/acrg/object_stores/shared_store_zarr", + "inner_emissions_store": "/group/chem/acrg/prasad/job_scripts/openghg_inversions/prior_flux_2023_EUROPE_6km", "inlet": ["9m"], "instrument": 
None, - "domain": "EUROPE", + "domain": "EUROPE", "inner_domain": "6km", "fp_height": ["10m"], "fp_species": "inert", "fp_model": "NAME", - "emissions_name": ["edgarv80_wetchartsv131"], + "emissions_name": ["edgar_wetcharts"], "averaging_period": ["4h"], "bc_input": "camsv22r2_daily", } From c6fcd7bcff4bd6ddd73cd057dde81f451c203121 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 13 May 2026 16:10:40 +0100 Subject: [PATCH 60/68] datatree handling across codebase --- openghg_inversions/basis/_helpers.py | 50 ++++++++++--- openghg_inversions/hbmcmc/hbmcmc.py | 72 ++++++++++++------- openghg_inversions/hbmcmc/inversionsetup.py | 29 +++++--- .../postprocessing/countries.py | 9 ++- .../postprocessing/inversion_output.py | 64 +++++++++++++---- 5 files changed, 166 insertions(+), 58 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 67e0b489..2ff9accb 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -203,10 +203,25 @@ def bc_sensitivity( if basis_case.lower() == "nesw": for site in sites: - ds = fp_and_data[site] - bc_ds = ds['standard'].ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) - sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") - fp_and_data[site]["H_bc"] = sensitivity + site_entry = fp_and_data[site] + if isinstance(site_entry, xr.DataTree): + standard_node = site_entry["standard"] if "standard" in site_entry.children else site_entry + standard_ds = standard_node.ds + bc_ds = standard_ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) + sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") + updated_standard = standard_ds.assign({"H_bc": sensitivity}) + + if "standard" in site_entry.children: + dt_dict = {"/standard": updated_standard} + if "inner" in site_entry.children: + dt_dict["/inner"] = site_entry["inner"].ds + fp_and_data[site] = 
xr.DataTree.from_dict(dt_dict) + else: + fp_and_data[site] = xr.DataTree(dataset=updated_standard) + else: + bc_ds = site_entry[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) + sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") + site_entry["H_bc"] = sensitivity return fp_and_data @@ -224,10 +239,27 @@ def bc_sensitivity( bc_basis = basis_func.rename({dv: str(dv).replace("basis_", "") for dv in basis_func.data_vars}) for site in sites: - ds = fp_and_data[site] - bc_ds = ds.ds[[f"bc_{d}" for d in "nesw"]] - sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") - sensitivity = sensitivity.rename(region="bc_region") - fp_and_data[site]["H_bc"] = sensitivity + site_entry = fp_and_data[site] + if isinstance(site_entry, xr.DataTree): + standard_node = site_entry["standard"] if "standard" in site_entry.children else site_entry + standard_ds = standard_node.ds + bc_ds = standard_ds[[f"bc_{d}" for d in "nesw"]] + sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") + sensitivity = sensitivity.rename(region="bc_region") + updated_standard = standard_ds.assign({"H_bc": sensitivity}) + + if "standard" in site_entry.children: + dt_dict = {"/standard": updated_standard} + if "inner" in site_entry.children: + dt_dict["/inner"] = site_entry["inner"].ds + fp_and_data[site] = xr.DataTree.from_dict(dt_dict) + else: + fp_and_data[site] = xr.DataTree(dataset=updated_standard) + else: + site_ds = site_entry + bc_ds = site_ds[[f"bc_{d}" for d in "nesw"]] + sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") + sensitivity = sensitivity.rename(region="bc_region") + site_ds["H_bc"] = sensitivity return fp_and_data diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index dbd92b83..5ed6a32a 100644 --- 
a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -88,6 +88,13 @@ def make_inv_inputs( offset_args, power, ): + def _main_ds(entry: xr.Dataset | xr.DataTree) -> xr.Dataset: + if isinstance(entry, xr.DataTree): + if "standard" in entry.children: + return entry["standard"].ds + return entry.ds + return entry + # Trigger dask computations # we only compute the variables we need below to_compute = [ @@ -105,23 +112,34 @@ def make_inv_inputs( ] for site in sites: site_entry = fp_data[site] - # Use root node's data_vars only to avoid picking up inner child variables if isinstance(site_entry, xr.DataTree): - root_data_vars = site_entry.ds.data_vars + has_standard = "standard" in site_entry.children + standard_ds = site_entry["standard"].ds if has_standard else site_entry.ds + + to_compute_site = [dv for dv in to_compute if dv in standard_ds.data_vars] + updated_standard = ( + standard_ds.assign({var: standard_ds[var].compute() for var in to_compute_site}) + if to_compute_site + else standard_ds + ) + + updated_inner = None + if "inner" in site_entry.children: + inner_ds = site_entry["inner"].ds + if "H_inner" in inner_ds.data_vars: + updated_inner = inner_ds.assign({"H_inner": inner_ds["H_inner"].compute()}) + else: + updated_inner = inner_ds + + standard_key = "/standard" if has_standard else "/" + dt_dict = {standard_key: updated_standard} + if updated_inner is not None: + dt_dict["/inner"] = updated_inner + fp_data[site] = xr.DataTree.from_dict(dt_dict) else: - root_data_vars = site_entry.data_vars - - to_compute_site = [dv for dv in to_compute if dv in root_data_vars] - if to_compute_site: - if isinstance(site_entry, xr.DataTree): - # Compute the needed variables in the root dataset and rebuild the DataTree - computed_root = site_entry.ds.assign( - {var: site_entry.ds[var].compute() for var in to_compute_site} - ) - fp_data[site] = xr.DataTree(dataset=computed_root, children=dict(site_entry.children)) - else: - for var in 
to_compute_site: - fp_data[site][var] = fp_data[site][var].compute() + to_compute_site = [dv for dv in to_compute if dv in site_entry.data_vars] + if to_compute_site: + fp_data[site] = site_entry.assign({var: site_entry[var].compute() for var in to_compute_site}) # Get inputs ready error = np.zeros(0) @@ -145,22 +163,24 @@ def make_inv_inputs( if site in dropped_sites: continue + site_ds = _main_ds(fp_data[site]) + # select variables to drop NaNs from drop_vars = [] for var in ["H", "H_bc", "mf", "mf_error"]: - if var in fp_data[site].data_vars: + if var in site_ds.data_vars: drop_vars.append(var) # pymc doesn't like NaNs, so drop them for the variables used below # DataTree doesn't support dropna; use sel with valid time indices instead if isinstance(fp_data[site], xr.DataTree): - valid_times = fp_data[site].ds.dropna("time", subset=drop_vars).time + valid_times = site_ds.dropna("time", subset=drop_vars).time fp_data[site] = fp_data[site].sel(time=valid_times) else: - fp_data[site] = fp_data[site].dropna("time", subset=drop_vars) + fp_data[site] = site_ds.dropna("time", subset=drop_vars) # repeatability/variability chosen/combined into mf_error in `get_data.py` - ds = fp_data[site].ds if isinstance(fp_data[site], xr.DataTree) else fp_data[site] + ds = _main_ds(fp_data[site]) error = np.concatenate((error, ds["mf_error"].values)) obs_repeatability = np.concatenate((obs_repeatability, ds["mf_repeatability"].values)) @@ -248,7 +268,7 @@ def make_inv_inputs( if bc_freq == "monthly": Hmbc = setup.monthly_bcs(start_date, end_date, site, fp_data) elif bc_freq is None: - Hmbc = fp_data[site].H_bc.values + Hmbc = _main_ds(fp_data[site])["H_bc"].values else: Hmbc = setup.create_bc_sensitivity(start_date, end_date, site, fp_data, bc_freq) @@ -638,10 +658,12 @@ def fixedbasisMCMC( for site in sites: entry = fp_data[site] if isinstance(entry, xr.DataTree): - # compute dask arrays in both root and inner child - dt_dict = {"/": entry.ds.compute()} + # compute dask arrays while 
preserving standard/inner node layout + standard_ds = entry["standard"].ds if "standard" in entry.children else entry.ds + standard_key = "/standard" if "standard" in entry.children else "/" + dt_dict = {standard_key: standard_ds.compute()} if "inner" in entry.children: - dt_dict["inner"] = entry["inner"].ds.compute() + dt_dict["/inner"] = entry["inner"].ds.compute() fp_data[site] = xr.DataTree.from_dict(dt_dict) else: fp_data[site] = entry.compute() @@ -650,7 +672,9 @@ def fixedbasisMCMC( dropped_sites = [] for site in sites: # check if some datasets are empty due to filtering - if fp_data[site].ds.time.values.shape[0] == 0: + site_entry = fp_data[site] + site_ds = site_entry["standard"].ds if isinstance(site_entry, xr.DataTree) and "standard" in site_entry.children else (site_entry.ds if isinstance(site_entry, xr.DataTree) else site_entry) + if site_ds.time.values.shape[0] == 0: dropped_sites.append(site) del fp_data[site] diff --git a/openghg_inversions/hbmcmc/inversionsetup.py b/openghg_inversions/hbmcmc/inversionsetup.py index 08a79ad8..1feecc93 100644 --- a/openghg_inversions/hbmcmc/inversionsetup.py +++ b/openghg_inversions/hbmcmc/inversionsetup.py @@ -4,6 +4,15 @@ import pandas as pd +def _site_ds(fp_data: dict, site: str): + entry = fp_data[site] + if hasattr(entry, "children"): + if "standard" in entry.children: + return entry["standard"].ds + return entry.ds + return entry + + def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np.ndarray: """Creates a sensitivity matrix (H-matrix) for the boundary conditions, which will map monthly boundary condition @@ -23,12 +32,13 @@ def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np. 
hmbc: Sensitivity matrix by month for observations """ + site_ds = _site_ds(fp_data, site) allmonth = pd.date_range(start_date, end_date, freq="MS")[:-1] nmonth = len(allmonth) - curtime = pd.to_datetime(fp_data[site].time.values).to_period("M") - pmonth = pd.to_datetime(fp_data[site].resample(time="MS").mean().time.values) - nregions = fp_data[site].sizes["bc_region"] - hmbc = np.zeros((nregions * nmonth, len(fp_data[site].time.values))) + curtime = pd.to_datetime(site_ds.time.values).to_period("M") + pmonth = pd.to_datetime(site_ds.resample(time="MS").mean().time.values) + nregions = site_ds.sizes["bc_region"] + hmbc = np.zeros((nregions * nmonth, len(site_ds.time.values))) count = 0 for cord in range(nregions): for m in range(0, nmonth): @@ -38,7 +48,7 @@ def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np. mnth = allmonth[m].month yr = allmonth[m].year mnthloc = np.where(np.logical_and(curtime.month == mnth, curtime.year == yr))[0] - hmbc[count, mnthloc] = fp_data[site].H_bc.values[cord, mnthloc] + hmbc[count, mnthloc] = site_ds["H_bc"].values[cord, mnthloc] count += 1 return hmbc @@ -70,14 +80,15 @@ def create_bc_sensitivity(start_date: str, end_date: str, site: str, fp_data: di hmbc: Sensitivity matrix by for observations to boundary conditions """ + site_ds = _site_ds(fp_data, site) dys = int("".join([s for s in freq if s.isdigit()])) alldates = pd.date_range( pd.to_datetime(start_date), pd.to_datetime(end_date) + pd.DateOffset(days=dys), freq=freq ) ndates = np.sum(alldates < pd.to_datetime(end_date)) - curdates = fp_data[site].time.values - nregions = fp_data[site].sizes["bc_region"] - hmbc = np.zeros((nregions * ndates, len(fp_data[site].time.values))) + curdates = site_ds.time.values + nregions = site_ds.sizes["bc_region"] + hmbc = np.zeros((nregions * ndates, len(site_ds.time.values))) count = 0 for cord in range(nregions): for m in range(0, ndates): @@ -89,7 +100,7 @@ def create_bc_sensitivity(start_date: str, end_date: str, 
site: str, fp_data: di if len(dateloc) == 0: count += 1 continue - hmbc[count, dateloc] = fp_data[site].H_bc.values[cord, dateloc] + hmbc[count, dateloc] = site_ds["H_bc"].values[cord, dateloc] count += 1 return hmbc diff --git a/openghg_inversions/postprocessing/countries.py b/openghg_inversions/postprocessing/countries.py index 598d435e..6a2f2909 100644 --- a/openghg_inversions/postprocessing/countries.py +++ b/openghg_inversions/postprocessing/countries.py @@ -383,7 +383,14 @@ def get_country_trace( 1.0, based on how it is used in hbmcmc """ x_to_country_mat = self.get_x_to_country_mat(inv_out) - x_trace = inv_out.get_trace_dataset(var_names="x") + trace_all = inv_out.get_trace_dataset() + x_outer_vars = [ + "x_prior", + "x_posterior", + "x_prior_predictive", + "x_posterior_predictive", + ] + x_trace = trace_all[[var for var in x_outer_vars if var in trace_all.data_vars]] species = inv_out.species diff --git a/openghg_inversions/postprocessing/inversion_output.py b/openghg_inversions/postprocessing/inversion_output.py index a24915d0..fc8ea129 100644 --- a/openghg_inversions/postprocessing/inversion_output.py +++ b/openghg_inversions/postprocessing/inversion_output.py @@ -627,6 +627,24 @@ def make_inv_out_for_fixed_basis_mcmc( obs_prior_upper_level_factor: np.ndarray | None = None, ) -> InversionOutput: """Create InversionOutput in `fixedbasisMCMC`.""" + def _scenario_datasets(entry: xr.Dataset | xr.DataTree) -> list[xr.Dataset]: + if isinstance(entry, xr.DataTree): + datasets: list[xr.Dataset] = [] + if "inner" in entry.children: + datasets.append(entry["inner"].ds) + if "standard" in entry.children: + datasets.append(entry["standard"].ds) + if not datasets: + datasets.append(entry.ds) + return datasets + return [entry] + + def _first_attrs(datasets: list[xr.Dataset], var_name: str) -> dict: + for ds in datasets: + if var_name in ds.data_vars: + return ds[var_name].attrs + return {} + nmeasure = np.arange(len(Y)) y_obs = xr.DataArray(Y, dims=["nmeasure"], 
coords={"nmeasure": nmeasure}, name="Yobs") times = xr.DataArray(Ytime, dims=["nmeasure"], coords={"nmeasure": nmeasure}, name="times") @@ -666,11 +684,18 @@ def make_inv_out_for_fixed_basis_mcmc( basis = get_xr_dummies(fp_data[".basis"], cat_dim="nx", categories=nx) - scenarios = [v for k, v in fp_data.items() if not k.startswith(".")] + scenario_entries = [v for k, v in fp_data.items() if not k.startswith(".")] + scenario_datasets = [_scenario_datasets(entry) for entry in scenario_entries] + + flux = None + for entry in scenario_entries: + try: + flux = entry.flux_stacked + break + except AttributeError: + continue - try: - flux = scenarios[0].flux_stacked - except AttributeError: + if flux is None: flux = next(iter(fp_data[".flux"].values())).data # TODO: this only works if there is one flux used (or if multiple, but ModelScenario stacks them) @@ -680,18 +705,27 @@ def make_inv_out_for_fixed_basis_mcmc( if not isinstance(flux, xr.DataArray): raise ValueError("Flux from `fp_data` could not be converted to a xr.DataArray.") - # add attributes - scenario = scenarios[0] - y_obs.attrs = scenario.mf.attrs - times.attrs = scenario.time.attrs - y_error.attrs = scenario.mf_error.attrs - y_error_variability.attrs = scenario.mf_variability.attrs - y_error_repeatability.attrs = scenario.mf_repeatability.attrs - for scenario in scenarios: - if not ("mf_prior_factor" in scenario and "mf_prior_upper_level_factor" in scenario): + # add attributes (prefer inner metadata, then standard) + first_ds_group = scenario_datasets[0] + y_obs.attrs = _first_attrs(first_ds_group, "mf") + time_attrs = next((ds.time.attrs for ds in first_ds_group if "time" in ds.coords), {}) + times.attrs = time_attrs + y_error.attrs = _first_attrs(first_ds_group, "mf_error") + y_error_variability.attrs = _first_attrs(first_ds_group, "mf_variability") + y_error_repeatability.attrs = _first_attrs(first_ds_group, "mf_repeatability") + for ds_group in scenario_datasets: + ds_with_priors = next( + ( + ds + for 
ds in ds_group + if "mf_prior_factor" in ds.data_vars and "mf_prior_upper_level_factor" in ds.data_vars + ), + None, + ) + if ds_with_priors is None: continue - y_obs_prior_factor.attrs = scenario.mf_prior_factor.attrs - y_obs_prior_upper_level_factor.attrs = scenario.mf_prior_upper_level_factor.attrs + y_obs_prior_factor.attrs = ds_with_priors.mf_prior_factor.attrs + y_obs_prior_upper_level_factor.attrs = ds_with_priors.mf_prior_upper_level_factor.attrs return InversionOutput( obs=y_obs, From 6de2cc3bdf6ffee2d9a1966613714b55a45ea2ca Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 13 May 2026 16:11:48 +0100 Subject: [PATCH 61/68] x out var traces and comning of x inner and x into traces output of the inversions --- openghg_inversions/hbmcmc/inversion_pymc.py | 67 ++++++++++++++++--- .../postprocessing/make_outputs.py | 9 ++- 2 files changed, 65 insertions(+), 11 deletions(-) diff --git a/openghg_inversions/hbmcmc/inversion_pymc.py b/openghg_inversions/hbmcmc/inversion_pymc.py index e025a695..60b4214a 100644 --- a/openghg_inversions/hbmcmc/inversion_pymc.py +++ b/openghg_inversions/hbmcmc/inversion_pymc.py @@ -322,20 +322,15 @@ def inferpymc( hx = pm.Data("hx", hx, dims=("nmeasure", "nx")) if Hx_inner is not None: - # Hx_inner has shape (nx_inner, nmeasure); Hx has shape (nx, nmeasure) - # They have different leading dims so cannot be subtracted — keep separate hx_inner = pm.Data("hx_inner", Hx_inner.T, dims=("nmeasure", "nx_inner")) - x = parse_prior("x", xprior, dims="nx") x_inner = parse_prior("x_inner", xprior, dims="nx_inner") - step1_vars += [x, x_inner] + step1_vars.append(x_inner) mu = pm.Deterministic( "mu", pt.dot(hx, x) + pt.dot(hx_inner, x_inner), dims="nmeasure", ) else: - x = parse_prior("x", xprior, dims="nx") - step1_vars.append(x) mu = pm.Deterministic("mu", pt.dot(hx, x), dims="nmeasure") if use_bc: @@ -497,6 +492,8 @@ def inferpymc_postprocessouts( rerun_file: xr.Dataset | None = None, use_bc: bool = False, min_error: float | np.ndarray = 
0.0, + xouts_inner: np.ndarray | None = None, + Hx_inner: np.ndarray | None = None, ) -> xr.Dataset: r"""Take the output from inferpymc function along with other input information. @@ -513,6 +510,7 @@ def inferpymc_postprocessouts( Args: xouts: MCMC chain for emissions scaling factors for each basis function. + xouts_inner: MCMC chain for emissions scaling factors for each inner-domain basis function. sigouts: MCMC chain for model error. convergence: Passed/Failed convergence test as to whether multiple chains have a Gelman-Rubin diagnostic value <1.05. @@ -520,6 +518,7 @@ def inferpymc_postprocessouts( This is the same as what is given from fp_data[site].H.values, where fp_data is the output from e.g. footprint_data_merge, but where it has been stacked for all sites. + Hx_inner: Inner-domain sensitivity matrix with the same orientation as Hx. Y: Measurement vector containing all measurements. error: Measurement error vector, containing a value for each element of Y. Ytrace: Trace of modelled y values calculated from mcmc outputs and H matrices. 
@@ -587,6 +586,12 @@ def inferpymc_postprocessouts( nit = xouts.shape[0] nx = Hx.shape[0] ny = len(Y) + if xouts_inner is not None: + nx_inner = xouts_inner.shape[1] + elif Hx_inner is not None: + nx_inner = Hx_inner.shape[0] + else: + nx_inner = 0 if use_bc: nbc = Hbc.shape[0] @@ -596,6 +601,8 @@ def inferpymc_postprocessouts( steps = np.arange(nit) nmeasure = np.arange(ny) nparam = np.arange(nx) + if nx_inner > 0: + nparam_inner = np.arange(nx_inner) # OFFSET HYPERPARAMETER YmodmuOFF = np.mean(OFFSETtrace, axis=1) # mean @@ -658,6 +665,9 @@ def inferpymc_postprocessouts( else: Yapriori = np.sum(Hx.T, axis=1) + if Hx_inner is not None: + Yapriori = Yapriori + np.sum(Hx_inner.T, axis=1) + sitenum = np.arange(len(sites)) if fp_data is None and rerun_file is not None: @@ -792,6 +802,26 @@ def inferpymc_postprocessouts( min_error = min_error * np.ones_like(Y) # Make output netcdf file + xouts_values = xouts.values if hasattr(xouts, "values") else xouts + sigouts_values = sigouts.values if hasattr(sigouts, "values") else sigouts + xouts_inner_values = xouts_inner.values if hasattr(xouts_inner, "values") else xouts_inner + + if xouts_inner_values is not None: + xtrace_values = np.concatenate([xouts_inner_values, xouts_values], axis=1) + xsensitivity_values = np.concatenate([Hx_inner.T, Hx.T], axis=1) if Hx_inner is not None else Hx.T + nparam_values = np.arange(nx + nx_inner) + param_index = np.concatenate([np.arange(nx_inner), np.arange(nx)]) + param_domain = np.concatenate([ + np.full(nx_inner, "inner", dtype="U5"), + np.full(nx, "outer", dtype="U5"), + ]) + else: + xtrace_values = xouts_values + xsensitivity_values = Hx.T + nparam_values = np.arange(nx) + param_index = np.arange(nx) + param_domain = np.full(nx, "outer", dtype="U5") + data_vars = { "Yobs": (["nmeasure"], Y), "Yerror": (["nmeasure"], error), @@ -810,8 +840,8 @@ def inferpymc_postprocessouts( "Yoffmode": (["nmeasure"], YmodmodeOFF), "Yoff68": (["nmeasure", "nUI"], Ymod68OFF), "Yoff95": (["nmeasure", 
"nUI"], Ymod95OFF), - "xtrace": (["steps", "nparam"], xouts.values), - "sigtrace": (["steps", "nsigma_site", "nsigma_time"], sigouts.values), + "xtrace": (["steps", "nparam"], xtrace_values), + "sigtrace": (["steps", "nsigma_site", "nsigma_time"], sigouts_values), "siteindicator": (["nmeasure"], siteindicator), "sigmafreqindex": (["nmeasure"], sigma_freq_index), "sitenames": (["nsite"], sites), @@ -830,9 +860,12 @@ def inferpymc_postprocessouts( "country95": (["countrynames", "nUI"], cntry95), "countryapriori": (["countrynames"], cntryprior), "countrydefinition": (["lat", "lon"], cntrygrid), - "xsensitivity": (["nmeasure", "nparam"], Hx.T), + "xsensitivity": (["nmeasure", "nparam"], xsensitivity_values), } + if Hx_inner is not None: + data_vars["xsensitivity_inner"] = (["nmeasure", "nparam_inner"], Hx_inner.T) + coords = { "stepnum": (["steps"], steps), "paramnum": (["nlatent"], nparam), @@ -841,11 +874,17 @@ def inferpymc_postprocessouts( "nsites": (["nsite"], sitenum), "nsigma_time": (["nsigma_time"], np.unique(sigma_freq_index)), "nsigma_site": (["nsigma_site"], np.arange(sigouts.shape[1]).astype(int)), + "nparam": (["nparam"], nparam_values), + "param_index": (["nparam"], param_index), + "param_domain": (["nparam"], param_domain), "lat": (["lat"], lat), "lon": (["lon"], lon), "countrynames": (["countrynames"], cntrynames), } + if nx_inner > 0: + coords["nparam_inner"] = (["nparam_inner"], nparam_inner) + if use_bc: data_vars.update( { @@ -890,6 +929,8 @@ def inferpymc_postprocessouts( outds.countryapriori.attrs["units"] = country_units outds.xsensitivity.attrs["units"] = obs_units + " " + "mol/mol" outds.sigtrace.attrs["units"] = obs_units + " " + "mol/mol" + if "xsensitivity_inner" in outds: + outds.xsensitivity_inner.attrs["units"] = obs_units + " " + "mol/mol" outds.Yobs.attrs["longname"] = "observations" outds.Yerror.attrs["longname"] = "measurement error" @@ -910,7 +951,11 @@ def inferpymc_postprocessouts( outds.Yoff95.attrs["longname"] = ( " 0.95 Bayesian 
credible interval of posterior simulated offset between measurements" ) - outds.xtrace.attrs["longname"] = "trace of unitless scaling factors for emissions parameters" + outds.xtrace.attrs["longname"] = ( + "trace of unitless scaling factors for combined outer+inner emissions parameters" + ) + outds.param_index.attrs["longname"] = "parameter index within each domain for emissions trace" + outds.param_domain.attrs["longname"] = "domain label (outer or inner) for emissions parameters" outds.sigtrace.attrs["longname"] = "trace of model error parameters" outds.siteindicator.attrs["longname"] = "index of site of measurement corresponding to sitenames" outds.sigmafreqindex.attrs["longname"] = "perdiod over which the model error is estimated" @@ -931,6 +976,8 @@ def inferpymc_postprocessouts( outds.countryapriori.attrs["longname"] = "prior mean of ocean and country totals" outds.countrydefinition.attrs["longname"] = "grid definition of countries" outds.xsensitivity.attrs["longname"] = "emissions sensitivity timeseries" + if "xsensitivity_inner" in outds: + outds.xsensitivity_inner.attrs["longname"] = "inner-domain emissions sensitivity timeseries" if use_bc: outds.YmodmeanBC.attrs["units"] = obs_units + " " + "mol/mol" diff --git a/openghg_inversions/postprocessing/make_outputs.py b/openghg_inversions/postprocessing/make_outputs.py index 2b2089ac..96dedeac 100644 --- a/openghg_inversions/postprocessing/make_outputs.py +++ b/openghg_inversions/postprocessing/make_outputs.py @@ -42,7 +42,14 @@ def make_flux_outputs( xr.Dataset with computed flux stats. 
""" - trace = inv_out.get_trace_dataset(var_names="x") + trace_all = inv_out.get_trace_dataset() + x_outer_vars = [ + "x_prior", + "x_posterior", + "x_prior_predictive", + "x_posterior_predictive", + ] + trace = trace_all[[var for var in x_outer_vars if var in trace_all.data_vars]] if stats_args is None: stats_args = {} From 4a0d28de219c569619279158986b81dcacf618a2 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 13 May 2026 16:12:19 +0100 Subject: [PATCH 62/68] time masking other dims issue raised due to datatree passing values with extra dims --- openghg_inversions/filters.py | 43 +++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/openghg_inversions/filters.py b/openghg_inversions/filters.py index 3e26e3bd..eb8c6163 100644 --- a/openghg_inversions/filters.py +++ b/openghg_inversions/filters.py @@ -26,7 +26,14 @@ # this dictionary will be populated by using the decorator `register_filter` filtering_functions = {} - +def _time_mask(mask: xr.DataArray, filter_name: str) -> xr.DataArray: + if "time" not in mask.dims: + raise ValueError(f"{filter_name} expects a 'time' dimension.") + extra_dims = [dim for dim in mask.dims if dim != "time"] + if extra_dims: + mask = mask.any(dim=extra_dims) + return mask + def register_filter(filt: Callable) -> Callable: """Decorator function to register filters. 
@@ -139,7 +146,7 @@ def filtering( # --- DataTree handling --- if isinstance(site_entry, xr.DataTree): - outer_ds = site_entry.ds + outer_ds = site_entry["standard"].ds if "standard" in site_entry.children else site_entry.ds n_nofilter = outer_ds.time.values.shape[0] filtered_outer = filtering_functions[filt](outer_ds, keep_missing=keep_missing) @@ -149,12 +156,13 @@ def filtering( perc_dropped = np.round(n_dropped / n_nofilter * 100, 2) print(f"{filt} filter removed {n_dropped} ({perc_dropped} %) obs at site {site}") - # Rebuild DataTree: root = filtered outer, inner child reindexed to new time - dt_dict = {"/": filtered_outer} + # Rebuild DataTree preserving the standard/inner layout. + standard_key = "/standard" if "standard" in site_entry.children else "/" + dt_dict = {standard_key: filtered_outer} if "inner" in site_entry.children: inner_ds = site_entry["inner"].ds # Keep inner time axis aligned with the (now filtered) outer time axis - dt_dict["inner"] = inner_ds.reindex( + dt_dict["/inner"] = inner_ds.reindex( time=filtered_outer.time, fill_value=0.0 ) datasets[site] = xr.DataTree.from_dict(dt_dict) @@ -402,22 +410,17 @@ def pblh_min(dataset: xr.Dataset, pblh_threshold: float = 200.0, keep_missing: b """ pblh_da = dataset.PBLH if "PBLH" in dataset.data_vars else dataset.atmosphere_boundary_layer_thickness - ti = [i for i, pblh in enumerate(pblh_da) if pblh > pblh_threshold] - - if keep_missing is True: - mf_data_array = dataset.mf - dataset_temp = dataset.drop("mf") - - dataarray_temp = mf_data_array[dict(time=ti)] + filt = _time_mask(pblh_da > pblh_threshold, "pblh_min") - mf_ds = xr.Dataset( - {"mf": (["time"], dataarray_temp)}, coords={"time": (dataarray_temp.coords["time"])} - ) + # Some inputs can include extra dimensions; collapse to a per-time mask. 
+ if "time" not in filt.dims: + raise ValueError("PBLH filter expects a 'time' dimension.") + extra_dims = [dim for dim in filt.dims if dim != "time"] + if extra_dims: + filt = filt.any(dim=extra_dims) - dataset_out = combine_datasets(dataset_temp, mf_ds, method=None) - return dataset_out - else: - return dataset[dict(time=ti)] + drop = not keep_missing + return dataset.where(filt.compute(), drop=drop) @register_filter @@ -451,7 +454,7 @@ def pblh_inlet_diff( pblh_da = dataset.PBLH if "PBLH" in dataset.data_vars else dataset.atmosphere_boundary_layer_thickness - filt = pblh_da > inlet_height + diff_threshold + filt = _time_mask(pblh_da > inlet_height + diff_threshold, "pblh_inlet_diff") drop = not keep_missing return dataset.where(filt.compute(), drop=drop) From e9bbc296ff7385237b9144345388578479d125bf Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 13 May 2026 16:12:49 +0100 Subject: [PATCH 63/68] reducing test probably havin this test might not be the good solution --- tests/test_full_inversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_full_inversion.py b/tests/test_full_inversion.py index 00b94a4f..6fdc855b 100644 --- a/tests/test_full_inversion.py +++ b/tests/test_full_inversion.py @@ -247,8 +247,8 @@ def inner_domain_mcmc_args(tmp_path, mhd_with_inner_domain_ch4_data_args): "basis_output_path": str(tmp_path), "nbasis": 4, "nit": 1, - "burn": 10, - "tune": 10, + "burn": 1, + "tune": 1, "nchain": 1, "mcmc_type" : "fixed_basis", "reload_merged_data": False, @@ -268,7 +268,7 @@ def inner_domain_mcmc_args(tmp_path, mhd_with_inner_domain_ch4_data_args): "sigma_per_site" : True, "inlet" : [slice(0,25)], "instrument" : ['multiple'], - "filters" : {'MHD' : ['pblh_min','pblh_inlet_diff']} + "filters" : {'MHD' : ['pblh_inlet_diff']} } ) return mcmc_args From 96252eb869c6940844e13974683a50af0e493fd9 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 20 May 2026 20:02:35 +0100 Subject: [PATCH 64/68] improved handling of 
inner and standard domain in datatree as well as H and H inner by creating sub functions --- openghg_inversions/basis/_helpers.py | 103 +++++++++++++++------------ 1 file changed, 56 insertions(+), 47 deletions(-) diff --git a/openghg_inversions/basis/_helpers.py b/openghg_inversions/basis/_helpers.py index 2ff9accb..63c1e5fd 100644 --- a/openghg_inversions/basis/_helpers.py +++ b/openghg_inversions/basis/_helpers.py @@ -73,10 +73,25 @@ def fp_sensitivity( if inner_basis_func is not None: fp_and_data[".basis_inner"] = inner_basis_func + if inner_basis_func is not None: + invalid_inner_sites = [ + site + for site in sites + if not ( + isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children + ) + ] + if invalid_inner_sites: + raise ValueError( + "Inner basis supplied, but some sites are not DataTree entries with an inner child: " + f"{invalid_inner_sites}." + ) + for site in sites: entry = fp_and_data[site] - # extract root fp_x_flux and mask inner-domain cells to zero if inner domain exists. This ensures that the outer sensitivity H only reflects the outer domain fluxes.
+ + # extract root fp_x_flux (already masked if inner domain exists) if isinstance(entry, xr.DataTree): if "standard" in entry.children: root_ds = entry["standard"].ds @@ -84,11 +99,9 @@ def fp_sensitivity( root_ds = entry.ds fp_x_flux_outer = root_ds[fp_x_flux_name] - fp_x_flux_for_H = fp_x_flux_outer - - # Compute outer H from the masked fp_x_flux + # Compute outer H from the (already masked) fp_x_flux sensitivity = apply_fp_basis_functions( - fp_x_flux=fp_x_flux_for_H, + fp_x_flux=fp_x_flux_outer, basis_func=basis_func, ) @@ -110,9 +123,17 @@ def fp_sensitivity( "/inner": new_inner, }) else: - fp_and_data[site] = xr.DataTree(dataset=root_ds.assign({"H": sensitivity})) + fp_and_data[site] = xr.DataTree.from_dict({ + "/standard": root_ds.assign({"H": sensitivity}) + }) else: + if inner_basis_func is not None: + raise ValueError( + "Inner-domain inversion requires DataTree site entries with an inner child. " + f"Site '{site}' is a plain Dataset." + ) + # Legacy: plain xr.Dataset path — unchanged sensitivity = apply_fp_basis_functions( fp_x_flux=entry[fp_x_flux_name], @@ -199,29 +220,33 @@ def bc_sensitivity( dict of xr.Datasets in same format as fp_and_data with `H_bc` sensitivity matrix added. 
""" + def _outer_ds(entry): + if isinstance(entry, xr.DataTree): + if "standard" in entry.children: + return entry["standard"].ds + return entry.ds + return entry + + def _with_updated_outer(entry, updated_outer_ds): + if not isinstance(entry, xr.DataTree): + return updated_outer_ds + + if "standard" in entry.children: + tree_dict = {f"/{name}": child.ds for name, child in entry.children.items() if child.ds is not None} + tree_dict["/standard"] = updated_outer_ds + return xr.DataTree.from_dict(tree_dict) + + return xr.DataTree(dataset=updated_outer_ds, children=dict(entry.children)) + sites = [key for key in list(fp_and_data.keys()) if key[0] != "."] if basis_case.lower() == "nesw": for site in sites: - site_entry = fp_and_data[site] - if isinstance(site_entry, xr.DataTree): - standard_node = site_entry["standard"] if "standard" in site_entry.children else site_entry - standard_ds = standard_node.ds - bc_ds = standard_ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) - sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") - updated_standard = standard_ds.assign({"H_bc": sensitivity}) - - if "standard" in site_entry.children: - dt_dict = {"/standard": updated_standard} - if "inner" in site_entry.children: - dt_dict["/inner"] = site_entry["inner"].ds - fp_and_data[site] = xr.DataTree.from_dict(dt_dict) - else: - fp_and_data[site] = xr.DataTree(dataset=updated_standard) - else: - bc_ds = site_entry[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) - sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") - site_entry["H_bc"] = sensitivity + entry = fp_and_data[site] + outer_ds = _outer_ds(entry) + bc_ds = outer_ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"}) + sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region") + fp_and_data[site] = _with_updated_outer(entry, outer_ds.assign({"H_bc": sensitivity})) return fp_and_data @@ -239,27 
+264,11 @@ def bc_sensitivity( bc_basis = basis_func.rename({dv: str(dv).replace("basis_", "") for dv in basis_func.data_vars}) for site in sites: - site_entry = fp_and_data[site] - if isinstance(site_entry, xr.DataTree): - standard_node = site_entry["standard"] if "standard" in site_entry.children else site_entry - standard_ds = standard_node.ds - bc_ds = standard_ds[[f"bc_{d}" for d in "nesw"]] - sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") - sensitivity = sensitivity.rename(region="bc_region") - updated_standard = standard_ds.assign({"H_bc": sensitivity}) - - if "standard" in site_entry.children: - dt_dict = {"/standard": updated_standard} - if "inner" in site_entry.children: - dt_dict["/inner"] = site_entry["inner"].ds - fp_and_data[site] = xr.DataTree.from_dict(dt_dict) - else: - fp_and_data[site] = xr.DataTree(dataset=updated_standard) - else: - site_ds = site_entry - bc_ds = site_ds[[f"bc_{d}" for d in "nesw"]] - sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") - sensitivity = sensitivity.rename(region="bc_region") - site_ds["H_bc"] = sensitivity + entry = fp_and_data[site] + outer_ds = _outer_ds(entry) + bc_ds = outer_ds[[f"bc_{d}" for d in "nesw"]] + sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__") + sensitivity = sensitivity.rename(region="bc_region") + fp_and_data[site] = _with_updated_outer(entry, outer_ds.assign({"H_bc": sensitivity})) return fp_and_data From 49ff459af7a48d04434771a93aec2839dcb2534c Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Wed, 20 May 2026 20:02:59 +0100 Subject: [PATCH 65/68] improved handling of datatree --- openghg_inversions/hbmcmc/hbmcmc.py | 125 +++++++++++++++++----------- 1 file changed, 78 insertions(+), 47 deletions(-) diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index 5ed6a32a..b8ac7ffc
100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -88,12 +88,24 @@ def make_inv_inputs( offset_args, power, ): - def _main_ds(entry: xr.Dataset | xr.DataTree) -> xr.Dataset: - if isinstance(entry, xr.DataTree): - if "standard" in entry.children: - return entry["standard"].ds - return entry.ds - return entry + def _outer_ds(site_entry): + if isinstance(site_entry, xr.DataTree): + if "standard" in site_entry.children: + return site_entry["standard"].ds + return site_entry.ds + return site_entry + + def _with_updated_outer(site_entry, updated_outer_ds): + if not isinstance(site_entry, xr.DataTree): + return updated_outer_ds + + # Preserve empty-root DataTree structure by writing data to /standard. + if "standard" in site_entry.children: + tree_dict = {f"/{name}": child.ds for name, child in site_entry.children.items() if child.ds is not None} + tree_dict["/standard"] = updated_outer_ds + return xr.DataTree.from_dict(tree_dict) + + return xr.DataTree(dataset=updated_outer_ds, children=dict(site_entry.children)) # Trigger dask computations # we only compute the variables we need below @@ -112,34 +124,24 @@ def _main_ds(entry: xr.Dataset | xr.DataTree) -> xr.Dataset: ] for site in sites: site_entry = fp_data[site] - if isinstance(site_entry, xr.DataTree): - has_standard = "standard" in site_entry.children - standard_ds = site_entry["standard"].ds if has_standard else site_entry.ds - - to_compute_site = [dv for dv in to_compute if dv in standard_ds.data_vars] - updated_standard = ( - standard_ds.assign({var: standard_ds[var].compute() for var in to_compute_site}) - if to_compute_site - else standard_ds - ) - - updated_inner = None - if "inner" in site_entry.children: - inner_ds = site_entry["inner"].ds - if "H_inner" in inner_ds.data_vars: - updated_inner = inner_ds.assign({"H_inner": inner_ds["H_inner"].compute()}) - else: - updated_inner = inner_ds - - standard_key = "/standard" if has_standard else "/" - dt_dict = 
{standard_key: updated_standard} - if updated_inner is not None: - dt_dict["/inner"] = updated_inner - fp_data[site] = xr.DataTree.from_dict(dt_dict) - else: - to_compute_site = [dv for dv in to_compute if dv in site_entry.data_vars] - if to_compute_site: - fp_data[site] = site_entry.assign({var: site_entry[var].compute() for var in to_compute_site}) + outer_ds = _outer_ds(site_entry) + + to_compute_outer = [dv for dv in to_compute if dv in outer_ds.data_vars] + if to_compute_outer: + computed_outer = outer_ds.assign({var: outer_ds[var].compute() for var in to_compute_outer}) + fp_data[site] = _with_updated_outer(site_entry, computed_outer) + + if isinstance(fp_data[site], xr.DataTree) and "inner" in fp_data[site].children: + inner_ds = fp_data[site]["inner"].ds + if "H_inner" in inner_ds.data_vars: + computed_inner = inner_ds.assign({"H_inner": inner_ds["H_inner"].compute()}) + tree_dict = { + f"/{name}": child.ds + for name, child in fp_data[site].children.items() + if child.ds is not None and name != "inner" + } + tree_dict["/inner"] = computed_inner + fp_data[site] = xr.DataTree.from_dict(tree_dict) # Get inputs ready error = np.zeros(0) @@ -163,24 +165,22 @@ def _main_ds(entry: xr.Dataset | xr.DataTree) -> xr.Dataset: if site in dropped_sites: continue - site_ds = _main_ds(fp_data[site]) - # select variables to drop NaNs from drop_vars = [] for var in ["H", "H_bc", "mf", "mf_error"]: - if var in site_ds.data_vars: + if var in _outer_ds(fp_data[site]).data_vars: drop_vars.append(var) # pymc doesn't like NaNs, so drop them for the variables used below # DataTree doesn't support dropna; use sel with valid time indices instead if isinstance(fp_data[site], xr.DataTree): - valid_times = site_ds.dropna("time", subset=drop_vars).time + valid_times = _outer_ds(fp_data[site]).dropna("time", subset=drop_vars).time fp_data[site] = fp_data[site].sel(time=valid_times) else: - fp_data[site] = site_ds.dropna("time", subset=drop_vars) + fp_data[site] = 
fp_data[site].dropna("time", subset=drop_vars) # repeatability/variability chosen/combined into mf_error in `get_data.py` - ds = _main_ds(fp_data[site]) + ds = _outer_ds(fp_data[site]) error = np.concatenate((error, ds["mf_error"].values)) obs_repeatability = np.concatenate((obs_repeatability, ds["mf_repeatability"].values)) @@ -268,7 +268,7 @@ def _main_ds(entry: xr.Dataset | xr.DataTree) -> xr.Dataset: if bc_freq == "monthly": Hmbc = setup.monthly_bcs(start_date, end_date, site, fp_data) elif bc_freq is None: - Hmbc = _main_ds(fp_data[site])["H_bc"].values + Hmbc = _outer_ds(fp_data[site])["H_bc"].values else: Hmbc = setup.create_bc_sensitivity(start_date, end_date, site, fp_data, bc_freq) @@ -552,6 +552,20 @@ def fixedbasisMCMC( print("Successfully read in merged data.\n") rerun_merge = False + if inner_domain is not None: + invalid_inner_sites = [ + s + for s, entry in fp_all.items() + if not s.startswith(".") + and not (isinstance(entry, xr.DataTree) and "inner" in entry.children) + ] + if invalid_inner_sites: + rerun_merge = True + print( + "Loaded merged data does not contain inner-domain DataTree entries " + f"for sites {invalid_inner_sites}; re-running data merge." + ) + # check if sites were dropped when merged data was saved sites_merged = [s for s in fp_all if "." 
not in s] @@ -658,10 +672,12 @@ def fixedbasisMCMC( for site in sites: entry = fp_data[site] if isinstance(entry, xr.DataTree): - # compute dask arrays while preserving standard/inner node layout - standard_ds = entry["standard"].ds if "standard" in entry.children else entry.ds - standard_key = "/standard" if "standard" in entry.children else "/" - dt_dict = {standard_key: standard_ds.compute()} + # compute dask arrays in standard and inner children + dt_dict = {} + if "standard" in entry.children: + dt_dict["/standard"] = entry["standard"].ds.compute() + elif entry.ds is not None: + dt_dict["/"] = entry.ds.compute() if "inner" in entry.children: dt_dict["/inner"] = entry["inner"].ds.compute() fp_data[site] = xr.DataTree.from_dict(dt_dict) @@ -672,8 +688,7 @@ def fixedbasisMCMC( dropped_sites = [] for site in sites: # check if some datasets are empty due to filtering - site_entry = fp_data[site] - site_ds = site_entry["standard"].ds if isinstance(site_entry, xr.DataTree) and "standard" in site_entry.children else (site_entry.ds if isinstance(site_entry, xr.DataTree) else site_entry) + site_ds = fp_data[site]["standard"].ds if isinstance(fp_data[site], xr.DataTree) and "standard" in fp_data[site].children else fp_data[site].ds if site_ds.time.values.shape[0] == 0: dropped_sites.append(site) del fp_data[site] @@ -685,6 +700,22 @@ def fixedbasisMCMC( for si, site in enumerate(sites): fp_data[site].attrs["Domain"] = domain + if inner_domain is not None: + invalid_inner_sites = [ + site + for site in sites + if not ( + isinstance(fp_data[site], xr.DataTree) + and "inner" in fp_data[site].children + and "H_inner" in fp_data[site]["inner"].ds.data_vars + ) + ] + if invalid_inner_sites: + raise ValueError( + "Inner-domain inversion requires DataTree entries with computed H_inner for all sites. " + f"Missing or invalid inner structure for: {invalid_inner_sites}." + ) + # Inverse models if use_tracer: raise ValueError("Model does not currently include tracer model. 
Watch this space") From 13e5b096f5246008b8f0f00113bc6aa282393a24 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 21 May 2026 16:24:13 +0100 Subject: [PATCH 66/68] masking flux too --- openghg_inversions/inversion_data/get_data.py | 144 +++++++++++++++--- 1 file changed, 121 insertions(+), 23 deletions(-) diff --git a/openghg_inversions/inversion_data/get_data.py b/openghg_inversions/inversion_data/get_data.py index 3e0f80e6..cdeb5da4 100644 --- a/openghg_inversions/inversion_data/get_data.py +++ b/openghg_inversions/inversion_data/get_data.py @@ -152,35 +152,106 @@ def convert_to_list( return x +def _interp_bool_mask_to_grid(mask: xr.DataArray, lat: xr.DataArray, lon: xr.DataArray) -> xr.DataArray: + """Interpolate a boolean mask to a target lat/lon grid. + + xarray's interp requires numeric dtype, so convert to float, interpolate, + then threshold back to bool. + """ + return ( + mask.astype(float) + .interp(lat=lat, lon=lon, method="nearest") + .assign_coords(lat=lat, lon=lon) + .fillna(0.0) + >= 0.5 + ) + + def _apply_inner_mask_on_standard_domain(standard_footprint_data: xr.Dataset, inner_footprint_data: xr.Dataset): - """Apply inner domain mask to standard domain fp_x_flux to ensure that the standard domain sensitivity H only reflects the outer domain fluxes. - + """Mask standard-domain footprint values where inner-domain coverage exists. + Args: - standard_footprint_data: xr.DataArray containing the standard domain footprint data, with lat/lon coordinates. - inner_footprint_data: xr.DataArray containing the inner domain footprint data, with lat/lon coordinates. + standard_footprint_data: Standard-domain footprint dataset. + inner_footprint_data: Inner-domain footprint dataset. Returns: - fp: xr.DataArray of the same shape as standard_footprint_data, but with values set to zero in the region covered by inner_footprint_data. + Copy of ``standard_footprint_data`` with ``fp`` zeroed in cells covered by the inner domain. 
""" - fp_standard = standard_footprint_data.copy() - if inner_footprint_data is not None: - inner_ds = inner_footprint_data.copy() - # Build a boolean mask: True where lat/lon is inside inner domain bounds - inner_lat_min = float(inner_ds.lat.min()) - inner_lat_max = float(inner_ds.lat.max()) - inner_lon_min = float(inner_ds.lon.min()) - inner_lon_max = float(inner_ds.lon.max()) + if inner_footprint_data is None: + return fp_standard + + if "fp" not in fp_standard or "fp" not in inner_footprint_data: + return fp_standard + + inner_has_coverage = (inner_footprint_data["fp"] != 0).any("time") + inner_on_standard = _interp_bool_mask_to_grid( + mask=inner_has_coverage, + lat=fp_standard.lat, + lon=fp_standard.lon, + ) + + fp_standard["fp"] = fp_standard["fp"].where(~inner_on_standard, other=0.0) + return fp_standard + + +def _apply_inner_mask_on_standard_flux( + standard_flux_dict: dict, + inner_footprint_data: xr.Dataset, +) -> dict: + """Mask standard-domain flux values where inner-domain coverage exists. + + Args: + standard_flux_dict: Dict of FluxData-like objects with ``data.flux`` DataArray. + inner_footprint_data: Inner-domain footprint dataset. + + Returns: + New flux dict with ``flux`` zeroed in cells covered by the inner domain. 
+ """ + if inner_footprint_data is None or "fp" not in inner_footprint_data: + return standard_flux_dict + + first_flux = next(iter(standard_flux_dict.values())).data.flux + inner_has_coverage = (inner_footprint_data["fp"] != 0).any("time") + inner_on_flux = _interp_bool_mask_to_grid( + mask=inner_has_coverage, + lat=first_flux.lat, + lon=first_flux.lon, + ) - lat_mask = (fp_standard.lat >= inner_lat_min) & (fp_standard.lat <= inner_lat_max) - lon_mask = (fp_standard.lon >= inner_lon_min) & (fp_standard.lon <= inner_lon_max) - inner_region_mask = lat_mask & lon_mask # broadcasts over (lat, lon) + return _apply_boolean_mask_on_standard_flux(standard_flux_dict, inner_on_flux) - # Zero out those cells in the outer fp - fp = fp_standard.where(~inner_region_mask, other=0.0) - return fp +def _apply_boolean_mask_on_standard_flux( + standard_flux_dict: dict, + mask_on_standard_grid: xr.DataArray, +) -> dict: + """Apply a precomputed boolean mask to all standard-domain flux sources. + + Args: + standard_flux_dict: Dict of FluxData-like objects with ``data.flux`` DataArray. + mask_on_standard_grid: Boolean mask on standard-domain lat/lon grid where True means "inner-domain coverage". + + Returns: + New flux dict with masked ``flux`` fields. 
+ """ + masked_flux_dict = {} + + for source, flux_data in standard_flux_dict.items(): + flux_da = flux_data.data.flux + mask_for_flux = _interp_bool_mask_to_grid( + mask=mask_on_standard_grid, + lat=flux_da.lat, + lon=flux_da.lon, + ) + + masked_flux = flux_da.where(~mask_for_flux, other=0.0) + masked_ds = flux_data.data.copy() + masked_ds["flux"] = masked_flux + masked_flux_dict[source] = type(flux_data)(data=masked_ds, metadata=flux_data.metadata) + + return masked_flux_dict def data_processing_surface_notracer( @@ -335,6 +406,7 @@ def data_processing_surface_notracer( check_scales = set() units = {} site_indices_to_keep = [] + standard_flux_inner_mask_union = None keep_variables = [ f"{species}", @@ -396,6 +468,7 @@ def data_processing_surface_notracer( ) continue # skip this site inner_footprint_data = None + standard_flux_for_site = flux_dict if inner_domain is not None: print(f"Inner domain {inner_domain} specified; attempting to retrieve inner footprint data for {site} ...") inner_footprint_data = get_footprint_data( @@ -412,19 +485,37 @@ def data_processing_surface_notracer( obs_data=site_data, stores=inner_footprint_store if inner_footprint_store is not None else footprint_store, ) - - standard_footprint_data.data = _apply_inner_mask_on_standard_domain(standard_footprint_data= standard_footprint_data.data, inner_footprint_data= inner_footprint_data.data) - if inner_footprint_data is None: print( f"\nNo Inner footprint data found for {site} with inlet/height {fp_height[i]}, model {fp_model}, and domain {domain}-{inner_domain}, starting from {start_date} to {end_date}, fp_species {fp_species} and met_model {met_model[i]}, obs_data {site_data} and averaging_period {averaging_period[i]}.", f"Check these values.\nContinuing model run without {site}.Jai\n", ) continue # skip this site + # Mask the standard domain fp_x_flux with the inner domain mask BEFORE storing in DataTree + standard_footprint_data.data = _apply_inner_mask_on_standard_domain( + 
standard_footprint_data=standard_footprint_data.data, + inner_footprint_data=inner_footprint_data.data, + ) + standard_flux_for_site = _apply_inner_mask_on_standard_flux( + standard_flux_dict=flux_dict, + inner_footprint_data=inner_footprint_data.data, + ) + + first_flux = next(iter(flux_dict.values())).data.flux + inner_has_coverage = (inner_footprint_data.data["fp"] != 0).any("time") + inner_on_standard_flux = _interp_bool_mask_to_grid( + mask=inner_has_coverage, + lat=first_flux.lat, + lon=first_flux.lon, + ) + if standard_flux_inner_mask_union is None: + standard_flux_inner_mask_union = inner_on_standard_flux + else: + standard_flux_inner_mask_union = standard_flux_inner_mask_union | inner_on_standard_flux scenario_combined = merged_scenario_data( obs_data=site_data, footprint_data=standard_footprint_data, - flux_dict= flux_dict, bc_data=bc_data, inner_footprint_data=inner_footprint_data, + flux_dict=standard_flux_for_site, bc_data=bc_data, inner_footprint_data=inner_footprint_data, inner_flux_dict= inner_flux_dict if inner_emissions_store is not None else None, platform=platform[i], max_level=max_level ) @@ -440,6 +531,13 @@ def data_processing_surface_notracer( check_scales.add(root_ds.scale) site_indices_to_keep.append(i) + + if standard_flux_inner_mask_union is not None: + fp_all[".flux"] = _apply_boolean_mask_on_standard_flux( + standard_flux_dict=flux_dict, + mask_on_standard_grid=standard_flux_inner_mask_union, + ) + if len(site_indices_to_keep) == 0: raise SearchError("No site data found. 
Exiting process.") From 826cc8fa821aec3c5e147d810b29091f8e3a4887 Mon Sep 17 00:00:00 2001 From: SutarPrasad Date: Thu, 21 May 2026 16:24:38 +0100 Subject: [PATCH 67/68] dt handling --- .../inversion_data/serialise.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/openghg_inversions/inversion_data/serialise.py b/openghg_inversions/inversion_data/serialise.py index d8904671..96136ca5 100644 --- a/openghg_inversions/inversion_data/serialise.py +++ b/openghg_inversions/inversion_data/serialise.py @@ -475,7 +475,7 @@ def datatree_to_flux_dict(dt: xr.DataTree) -> dict[str, FluxData]: def fp_all_to_datatree(fp_all: dict, netcdf_safe_attrs: bool = False) -> xr.DataTree: dt_dict: dict[str, xr.Dataset | xr.DataTree] = {} - scenario_dict = {} + scenario_dict: dict[str, xr.Dataset] = {} dt_attrs = {} if ".flux" in fp_all: @@ -487,7 +487,15 @@ def fp_all_to_datatree(fp_all: dict, netcdf_safe_attrs: bool = False) -> xr.Data if isinstance(v, BoundaryConditionsData): dt_dict[k.removeprefix(".")] = openghg_data_to_dataset(v, netcdf_safe_attrs) elif not k.startswith(".") and isinstance(v, xr.Dataset): - scenario_dict[k] = v + scenario_dict[f"/{k}"] = v + elif not k.startswith(".") and isinstance(v, xr.DataTree): + for group in v.groups: + node = v[group] + if node.ds is None: + continue + + rel_group = "" if group == "/" else group + scenario_dict[f"/{k}{rel_group}"] = node.ds else: dt_attrs[k] = v @@ -512,7 +520,13 @@ def datatree_to_fp_all(dt: xr.DataTree) -> dict: fp_all[".bc"] = dataset_to_bc_data(dt.bc.to_dataset()) for k, v in dt.scenarios.items(): - fp_all[str(k)] = v.to_dataset() + if len(v.children) > 0: + scenario_children = { + f"/{child_name}": child.ds for child_name, child in v.children.items() if child.ds is not None + } + fp_all[str(k)] = xr.DataTree.from_dict(scenario_children) + else: + fp_all[str(k)] = v.to_dataset() fp_all.update({str(k): v for k, v in dt.attrs.items()}) From 0b132d46a0aedd89e2a24035bdbf3c578baa934d Mon Sep 
17 00:00:00 2001 From: SutarPrasad Date: Thu, 21 May 2026 16:24:51 +0100 Subject: [PATCH 68/68] dt handling --- openghg_inversions/hbmcmc/inversionsetup.py | 34 ++++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/openghg_inversions/hbmcmc/inversionsetup.py b/openghg_inversions/hbmcmc/inversionsetup.py index 1feecc93..4795aa61 100644 --- a/openghg_inversions/hbmcmc/inversionsetup.py +++ b/openghg_inversions/hbmcmc/inversionsetup.py @@ -2,17 +2,39 @@ import numpy as np import pandas as pd +import xarray as xr def _site_ds(fp_data: dict, site: str): entry = fp_data[site] - if hasattr(entry, "children"): - if "standard" in entry.children: + + if isinstance(entry, xr.DataTree): + if "standard" in entry.children and entry["standard"].ds is not None: return entry["standard"].ds - return entry.ds + + if entry.ds is not None: + return entry.ds + + # Fallback: use first non-empty child dataset if present. + for child in entry.children.values(): + if child.ds is not None: + return child.ds + + raise ValueError(f"Site '{site}' DataTree does not contain a dataset in root or child nodes.") + return entry +def _site_hbc(fp_data: dict, site: str) -> xr.DataArray: + site_ds = _site_ds(fp_data, site) + if "H_bc" not in site_ds.data_vars: + raise ValueError( + f"Boundary-condition sensitivity 'H_bc' is missing for site '{site}'. " + f"Available variables: {list(site_ds.data_vars)}" + ) + return site_ds["H_bc"] + + def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np.ndarray: """Creates a sensitivity matrix (H-matrix) for the boundary conditions, which will map monthly boundary condition @@ -33,6 +55,7 @@ def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np. 
Sensitivity matrix by month for observations """ site_ds = _site_ds(fp_data, site) + h_bc = _site_hbc(fp_data, site).values allmonth = pd.date_range(start_date, end_date, freq="MS")[:-1] nmonth = len(allmonth) curtime = pd.to_datetime(site_ds.time.values).to_period("M") @@ -48,7 +71,7 @@ def monthly_bcs(start_date: str, end_date: str, site: str, fp_data: dict) -> np. mnth = allmonth[m].month yr = allmonth[m].year mnthloc = np.where(np.logical_and(curtime.month == mnth, curtime.year == yr))[0] - hmbc[count, mnthloc] = site_ds["H_bc"].values[cord, mnthloc] + hmbc[count, mnthloc] = h_bc[cord, mnthloc] count += 1 return hmbc @@ -81,6 +104,7 @@ def create_bc_sensitivity(start_date: str, end_date: str, site: str, fp_data: di Sensitivity matrix by for observations to boundary conditions """ site_ds = _site_ds(fp_data, site) + h_bc = _site_hbc(fp_data, site).values dys = int("".join([s for s in freq if s.isdigit()])) alldates = pd.date_range( pd.to_datetime(start_date), pd.to_datetime(end_date) + pd.DateOffset(days=dys), freq=freq @@ -100,7 +124,7 @@ def create_bc_sensitivity(start_date: str, end_date: str, site: str, fp_data: di if len(dateloc) == 0: count += 1 continue - hmbc[count, dateloc] = site_ds["H_bc"].values[cord, dateloc] + hmbc[count, dateloc] = h_bc[cord, dateloc] count += 1 return hmbc