Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
07a2dfc
inner_domain
SutarPrasad Feb 24, 2026
3b08f8b
not_merge
SutarPrasad Feb 24, 2026
e744f2e
paris hardcoded for test
SutarPrasad Feb 24, 2026
874b8a1
pulling fp and innner domain fp separately and passing to merged_scen…
SutarPrasad Mar 4, 2026
4e92b33
flux is kept to fetch europe domain, calls to get_footprint_data are …
SutarPrasad Mar 4, 2026
1cd0a8d
6km fp_x_flux is added as a separate variable
SutarPrasad Mar 4, 2026
2e7d17e
inner_fp_store and inner_domain is added to fixedbasisMCMC
SutarPrasad Mar 4, 2026
b7bd81d
combine_inner_outer_fp_x_flux
SutarPrasad Mar 5, 2026
6d55254
fp_x_flux fetched
SutarPrasad Mar 5, 2026
96c6d7c
adding extra print statements
SutarPrasad Mar 10, 2026
c733114
commenting inner-domain as we are regridding to europe domain before …
SutarPrasad Mar 10, 2026
ab727b4
removed hardcoded 6km
SutarPrasad Mar 13, 2026
27622d2
inner domain flux
SutarPrasad Mar 13, 2026
db9af9b
getting inner flux for inner scenario
SutarPrasad Mar 13, 2026
bfa1735
calculating H_inner
SutarPrasad Mar 16, 2026
e8a9af7
removed inner flux
SutarPrasad Mar 16, 2026
0541897
inner flux removal
SutarPrasad Mar 16, 2026
e7a872b
flux to inner
SutarPrasad Mar 17, 2026
c42bbaa
datatree only in merged_scenario
SutarPrasad Mar 18, 2026
ab1da6e
fetching outer_ds as root_ds
SutarPrasad Mar 18, 2026
499fcda
merged_scenario is datatree fp_all is still a dictionary
SutarPrasad Mar 18, 2026
920cc6d
root data fetch
SutarPrasad Mar 18, 2026
8ba74ef
changes based on datatree syntax
SutarPrasad Mar 20, 2026
889c6dc
hardcoded test
SutarPrasad Mar 20, 2026
9921f76
renamed lat lon of child dataset
SutarPrasad Mar 21, 2026
a0ef644
updates to handle fp_x_flux
SutarPrasad Mar 21, 2026
8828b49
compatibility with datatree changes
SutarPrasad Mar 23, 2026
2602046
datatree compatibility hbmcmc
SutarPrasad Mar 23, 2026
be015b0
datatree comaptibility
SutarPrasad Mar 24, 2026
a3332f8
writing back to fp_and_data from copied
SutarPrasad Mar 24, 2026
28caa48
copying the raw data modifying and writing back
SutarPrasad Mar 24, 2026
aa929ea
datatree compatiblity
SutarPrasad Mar 25, 2026
5e1f565
removed copy
SutarPrasad Mar 30, 2026
fd106fd
getting H_inner
SutarPrasad Mar 30, 2026
5c8bdc7
to original data
SutarPrasad Mar 30, 2026
431ae08
datgatree compatibility
SutarPrasad Mar 30, 2026
59f0211
ingesting H_inner in pymc
SutarPrasad Mar 30, 2026
e809ab4
Initial plan
Copilot Apr 1, 2026
7a0247a
Fix DataTree sensitivity matrix persistence and dask compute loop for…
Copilot Apr 1, 2026
bfad995
Merge pull request #385 from openghg/copilot/fix-sensitivity-matrix-p…
SutarPrasad Apr 1, 2026
12c6d5e
fp_x_flux_inner
SutarPrasad Apr 9, 2026
c17badc
fp_x_flux_inner
SutarPrasad Apr 9, 2026
b0322b9
test
SutarPrasad Apr 9, 2026
3720ac3
masking inner domain from standard domain
SutarPrasad Apr 9, 2026
0b940b7
empty root two separate nodes and same dim names
SutarPrasad Apr 13, 2026
4584cfe
handling of separate nodes in the code
SutarPrasad Apr 23, 2026
85cab69
considers hx solved separately and 0 at inner domain
SutarPrasad Apr 23, 2026
189082c
inner domain passing to basis function wrapper
SutarPrasad Apr 23, 2026
53408e0
passing inner domain
SutarPrasad Apr 23, 2026
2bec153
aaplying mask on standard footprint if inner_domain is not none
SutarPrasad Apr 28, 2026
ee6856a
removed masking from fp_sensitivity as fp is masked in get_data
SutarPrasad Apr 28, 2026
86334d9
applying mask on standard_fp f inner domain is present
SutarPrasad Apr 28, 2026
29a4726
calculating inner basis data array for inner fp and flux by checking …
SutarPrasad May 12, 2026
a5dea90
inner flux dic accepted by merged scenario
SutarPrasad May 12, 2026
359e9fb
inner emissions store is passed to fetch flux for inner domain
SutarPrasad May 12, 2026
fbe9be5
temp ini needs update for default paths to stores
SutarPrasad May 12, 2026
278ec94
passing inner emissions store
SutarPrasad May 12, 2026
309d7ba
inner basis func is accepted in fp_sensitivity and applies fp basis f…
SutarPrasad May 12, 2026
bf81d5b
flux and fp based on standard and inner scenario are fetched
SutarPrasad May 12, 2026
55df001
requires update to paths in conftest too
SutarPrasad May 12, 2026
c6fcd7b
datatree handling across codebase
SutarPrasad May 13, 2026
6de2cc3
x out var traces and comning of x inner and x into traces output of t…
SutarPrasad May 13, 2026
4a0d28d
time masking other dims issue raised due to datatree passing values w…
SutarPrasad May 13, 2026
e9bbc29
reducing test probably havin this test might not be the good solution
SutarPrasad May 13, 2026
96252eb
improved handling of inner and standrd domain in datatree as well as …
SutarPrasad May 20, 2026
49ff459
imporved handling of datatree
SutarPrasad May 20, 2026
13e5b09
masking flux too
SutarPrasad May 21, 2026
826cc8f
dt handling
SutarPrasad May 21, 2026
0b132d4
dt handling
SutarPrasad May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions openghg_inversions/basis/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def basis_boundary_conditions(domain: str, basis_case: str, bc_basis_directory:


def _flux_fp_from_fp_all(
fp_all: dict, emissions_name: list[str] | None = None
fp_all: dict, emissions_name: list[str] | None = None, scenario = "standard"
) -> tuple[xr.DataArray, list[xr.DataArray]]:
"""Get flux and list of footprints from `fp_all` dictionary and optional list of emissions names.

Expand All @@ -149,15 +149,38 @@ def _flux_fp_from_fp_all(
footprints (list):
List of xarray DataArray containing the footprints of each sites.
"""
flux_key = ".inner_flux" if scenario == "inner" and ".inner_flux" in fp_all else ".flux"

if emissions_name is not None:
flux = fp_all[".flux"][emissions_name[0]].data.flux
flux = fp_all[flux_key][emissions_name[0]].data.flux
else:
first_flux = next(iter(fp_all[".flux"].values()))
first_flux = next(iter(fp_all[flux_key].values()))
flux = first_flux.data.flux

flux = cast(xr.DataArray, flux)

footprints: list[xr.DataArray] = [v.fp for k, v in fp_all.items() if not k.startswith(".")]
footprints = []
for k, v in fp_all.items():
if k.startswith("."):
continue

# Need to discuss this further
# fp_x_flux is guaranteed to be on the
# EUROPE grid (it comes from the standard scenario). Raw .fp may be
# on the 6km grid if OpenGHG snapped it to the footprint resolution.
if scenario == "inner" and isinstance(v, xr.DataTree) and "inner" in v.children:
fp = v["inner"].ds["fp_x_flux"]
elif "fp_x_flux" in v:
fp = v["fp_x_flux"]
else:
fp = v[scenario].ds.fp

# if grid still doesn't match flux, regrid to flux grid
if fp.sizes.get("lat") != flux.sizes.get("lat") or fp.sizes.get("lon") != flux.sizes.get("lon"):
fp = fp.interp(lat=flux.lat, lon=flux.lon, method="nearest").fillna(0.0)
fp = fp.assign_coords(lat=flux.lat, lon=flux.lon)

footprints.append(fp)

return flux, footprints

Expand Down Expand Up @@ -213,6 +236,7 @@ def quadtreebasisfunction(
fp_all: dict,
start_date: str,
domain : str,
scenario: str = "standard",
emissions_name: list[str] | None = None,
nbasis: int = 100,
country_directory: str | None = None,
Expand Down Expand Up @@ -258,7 +282,7 @@ def quadtreebasisfunction(
quad_basis (xarray.DataArray):
Array with lat/lon dimensions and basis regions encoded by integers.
"""
flux, footprints = _flux_fp_from_fp_all(fp_all, emissions_name)
flux, footprints = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario)
fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy()

# use xr.apply_ufunc to keep xarray coords
Expand All @@ -283,7 +307,8 @@ def bucketbasisfunction(
nbasis: int = 100,
country_directory: str | None = None,
abs_flux: bool = False,
mask: xr.DataArray | None = None
mask: xr.DataArray | None = None,
scenario: str = "standard",
) -> xr.DataArray:
"""Basis functions calculated using a weighted region approach
where each basis function / scaling region contains approximately
Expand Down Expand Up @@ -311,12 +336,15 @@ def bucketbasisfunction(
Default None
country_directory (str):
Directory containing land-sea files. If None, will use default files.
scenario (str):
Scenario for retrieving emissions files.
Default "standard"

Returns:
bucket_basis (xarray.DataArray):
Array with lat/lon dimensions and basis regions encoded by integers.
"""
flux, footprints = _flux_fp_from_fp_all(fp_all, emissions_name)
flux, footprints = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario)
fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy()
fps = fps / fps.max()

Expand Down Expand Up @@ -351,6 +379,7 @@ def fixed_outer_regions_basis(
nbasis: int = 100,
country_directory: str | None = None,
abs_flux: bool = False,
scenario: str = "standard"
) -> xr.DataArray:
"""Fix outer region of basis functions to InTEM regions, and fit the inner regions using `basis_algorithm`.

Expand Down Expand Up @@ -389,15 +418,15 @@ def fixed_outer_regions_basis(
intem_regions = xr.open_dataset(intem_regions_path).region

# force intem_regions to use flux coordinates
flux, _ = _flux_fp_from_fp_all(fp_all, emissions_name)
flux, _ = _flux_fp_from_fp_all(fp_all=fp_all, emissions_name=emissions_name, scenario=scenario)
_, intem_regions = xr.align(flux, intem_regions, join="override")

inner_index = intem_regions.values.max()

mask = intem_regions == inner_index

basis_function = basis_functions[basis_algorithm].algorithm
inner_region = basis_function(fp_all, start_date, domain, emissions_name, nbasis, country_directory, abs_flux, mask=mask)
inner_region = basis_function(fp_all=fp_all, start_date=start_date, domain=domain, emissions_name=emissions_name, nbasis=nbasis, country_directory=country_directory, abs_flux=abs_flux, mask=mask)

basis = intem_regions.rename("basis")

Expand Down
142 changes: 128 additions & 14 deletions openghg_inversions/basis/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from ._functions import basis_boundary_conditions


def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.DataArray]) -> dict:
def fp_sensitivity(
fp_and_data: dict,
basis_func: xr.DataArray | dict[str, xr.DataArray],
inner_basis_func: xr.DataArray | None = None,
) -> dict:
"""Add a sensitivity matrix, H, to each site xr.Dataset in fp_and_data.

The sensitivity matrix H takes the footprint sensitivities (the `fp` variable),
Expand Down Expand Up @@ -38,7 +42,6 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da
if len(flux_sources) == 1:
if not isinstance(basis_func, xr.DataArray):
basis_func = next(iter(basis_func.values()))

fp_x_flux_name = "fp_x_flux"

else:
Expand All @@ -63,18 +66,107 @@ def fp_sensitivity(fp_and_data: dict, basis_func: xr.DataArray | dict[str, xr.Da
if "time" in basis_func.dims and basis_func.sizes["time"] <= 1:
basis_func = basis_func.squeeze("time")

if inner_basis_func is not None and "time" in inner_basis_func.dims and inner_basis_func.sizes["time"] <= 1:
inner_basis_func = inner_basis_func.squeeze("time")

fp_and_data[".basis"] = basis_func
if inner_basis_func is not None:
fp_and_data[".basis_inner"] = inner_basis_func

if inner_basis_func is not None:
invalid_inner_sites = [
site
for site in sites
if not (
isinstance(fp_and_data[site], xr.DataTree) and "inner" in fp_and_data[site].children
)
]
if invalid_inner_sites:
raise ValueError(
"Inner basis supplied, but some sites are not DataTree entries with an inner child: "
f"{invalid_inner_sites}."
)

for site in sites:
sensitivity = apply_fp_basis_functions(
fp_x_flux=fp_and_data[site][fp_x_flux_name],
basis_func=basis_func,
)
fp_and_data[site]["H"] = sensitivity
entry = fp_and_data[site]


# extract root fp_x_flux (already masked if inner domain exists)
if isinstance(entry, xr.DataTree):
if "standard" in entry.children:
root_ds = entry["standard"].ds
else:
root_ds = entry.ds
fp_x_flux_outer = root_ds[fp_x_flux_name]

# Compute outer H from the (already masked) fp_x_flux
sensitivity = apply_fp_basis_functions(
fp_x_flux=fp_x_flux_outer,
basis_func=basis_func,
)

# Compute H_inner from the inner child's fp_x_flux (its own lat/lon grid)
if "inner" in entry.children:
if inner_basis_func is None:
raise ValueError("Inner-domain data exists but no inner basis function was provided.")

inner_fp_x_flux = entry["inner"].ds[fp_x_flux_name]
H_inner = apply_fp_basis_functions(
fp_x_flux=inner_fp_x_flux,
basis_func=inner_basis_func,
)
# Write both back into the DataTree
new_root = root_ds.assign({"H": sensitivity})
new_inner = entry["inner"].ds.assign({"H_inner": H_inner})
fp_and_data[site] = xr.DataTree.from_dict({
"/standard": new_root,
"/inner": new_inner,
})
else:
fp_and_data[site] = xr.DataTree.from_dict({
"/standard": root_ds.assign({"H": sensitivity})
})

else:
if inner_basis_func is not None:
raise ValueError(
"Inner-domain inversion requires DataTree site entries with an inner child. "
f"Site '{site}' is a plain Dataset."
)

# Legacy: plain xr.Dataset path — unchanged
sensitivity = apply_fp_basis_functions(
fp_x_flux=entry[fp_x_flux_name],
basis_func=basis_func,
)
fp_and_data[site]["H"] = sensitivity

return fp_and_data


def combine_inner_outer_fp_x_flux(
Comment thread
SutarPrasad marked this conversation as resolved.
inner_fp_x_flux: xr.DataArray,
outer_fp_x_flux: xr.DataArray,
) -> xr.DataArray:
"""Merge inner (6km) and outer (EUROPE) fp_x_flux."""
# Regrid inner fp_x_flux to the same grid as outer fp_x_flux, and then patch it in where the inner domain mask is True.
# regrid inner to EUROPE lat/lon coords
inner_regridded = inner_fp_x_flux.interp(lat=outer_fp_x_flux.lat, lon=outer_fp_x_flux.lon, method="nearest")

# force coordinates to exactly match outer (avoids float precision
# mismatches that prevent xr.align / xr.where from working correctly)
inner_regridded = inner_regridded.assign_coords(lat=outer_fp_x_flux.lat, lon=outer_fp_x_flux.lon)

# fill NaN (points outside the inner domain extent) with 0
inner_regridded = inner_regridded.fillna(0.0)

# True where the inner domain contributed non-zero values at any timestep
inner_has_coverage = (inner_regridded != 0).any("time")

# Both arrays are now on the EUROPE grid so xr.where is safe
return xr.where(inner_has_coverage, inner_regridded, outer_fp_x_flux)


def apply_fp_basis_functions(
fp_x_flux: xr.DataArray,
basis_func: xr.DataArray,
Expand All @@ -100,7 +192,9 @@ def apply_fp_basis_functions(

_, basis_aligned = xr.align(fp_x_flux.isel(time=0), basis_func, join="override")
basis_mat = get_xr_dummies(basis_aligned, cat_dim="region")
sensitivity = sparse_xr_dot(basis_mat, fp_x_flux.fillna(0.0), dim=["lat", "lon"])
spatial_dims = ["lat", "lon"]

sensitivity = sparse_xr_dot(basis_mat, fp_x_flux.fillna(0.0), dim=spatial_dims)

if sensitivity.dims[:2] != ("region", "time"):
sensitivity = sensitivity.transpose("region", "time", ...)
Expand All @@ -126,14 +220,33 @@ def bc_sensitivity(
dict of xr.Datasets in same format as fp_and_data with `H_bc` sensitivity matrix added.

"""
def _outer_ds(entry):
if isinstance(entry, xr.DataTree):
if "standard" in entry.children:
return entry["standard"].ds
return entry.ds
return entry

def _with_updated_outer(entry, updated_outer_ds):
if not isinstance(entry, xr.DataTree):
return updated_outer_ds

if "standard" in entry.children:
tree_dict = {f"/{name}": child.ds for name, child in entry.children.items() if child.ds is not None}
tree_dict["/standard"] = updated_outer_ds
return xr.DataTree.from_dict(tree_dict)

return xr.DataTree(dataset=updated_outer_ds, children=dict(entry.children))

sites = [key for key in list(fp_and_data.keys()) if key[0] != "."]

if basis_case.lower() == "nesw":
for site in sites:
ds = fp_and_data[site]
bc_ds = ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"})
entry = fp_and_data[site]
outer_ds = _outer_ds(entry)
bc_ds = outer_ds[[f"bc_{d}" for d in "nesw"]].rename({f"bc_{d}": d for d in "nesw"})
sensitivity = bc_ds.sum(["lat", "lon", "height"]).to_dataarray(dim="bc_region")
fp_and_data[site]["H_bc"] = sensitivity
fp_and_data[site] = _with_updated_outer(entry, outer_ds.assign({"H_bc": sensitivity}))

return fp_and_data

Expand All @@ -151,10 +264,11 @@ def bc_sensitivity(
bc_basis = basis_func.rename({dv: str(dv).replace("basis_", "") for dv in basis_func.data_vars})

for site in sites:
ds = fp_and_data[site]
bc_ds = ds[[f"bc_{d}" for d in "nesw"]]
entry = fp_and_data[site]
outer_ds = _outer_ds(entry)
bc_ds = outer_ds[[f"bc_{d}" for d in "nesw"]]
sensitivity = (bc_ds * bc_basis).sum(["lat", "lon", "height"]).to_dataarray(dim="__newdim__").sum("__newdim__")
sensitivity = sensitivity.rename(region="bc_region")
fp_and_data[site]["H_bc"] = sensitivity
fp_and_data[site] = _with_updated_outer(entry, outer_ds.assign({"H_bc": sensitivity}))

return fp_and_data
34 changes: 31 additions & 3 deletions openghg_inversions/basis/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def basis_functions_wrapper(
country_directory: str | None = None,
outputname: str | None = None,
output_path: str | None = None,
inner_domain: str | None = None,
):
"""Wrapper function for selecting basis function
algorithm.
Expand Down Expand Up @@ -81,6 +82,7 @@ def basis_functions_wrapper(
Dictionary object similar to fp_all but with information
on basis functions and sensitivities
"""
inner_basis_data_array = None
if use_bc is True and bc_basis_case is None:
raise ValueError("If `use_bc` is True, you must specify `bc_basis_case`.")

Expand Down Expand Up @@ -118,13 +120,39 @@ def basis_functions_wrapper(
"Basis algorithm not recognised. Please use either 'quadtree' or 'weighted', or input a basis function file"
) from e
print(f"Using {basis_function.description} to derive basis functions.")
basis_data_array = basis_function.algorithm(fp_all, start_date, domain, emissions_name, nbasis, country_directory=country_directory)

if inner_domain is not None:
inner_basis_data_array = basis_function.algorithm(
fp_all=fp_all,
start_date=start_date,
domain=f"{domain}-{inner_domain}",
emissions_name=emissions_name,
nbasis=nbasis,
country_directory=country_directory,
scenario="inner",
)

print(f"Computing inner basis took {time() - basis_start}s.")

basis_data_array = basis_function.algorithm(
fp_all=fp_all,
start_date=start_date,
domain=domain,
emissions_name=emissions_name,
nbasis=nbasis,
country_directory=country_directory,
)

print(f"Computing basis took {time() - basis_start}s.")

fp_sens_start = time()
fp_data = fp_sensitivity(fp_all, basis_func=basis_data_array)
print(f"Computing fp sensitivity took {time() - fp_sens_start}s.")
if basis_data_array is not None:
fp_data = fp_sensitivity(
fp_all,
basis_func=basis_data_array,
inner_basis_func=inner_basis_data_array,
)
print(f"Computing fp sensitivity took {time() - fp_sens_start}s.")

if use_bc is True:
bc_sens_start = time()
Expand Down
Loading