Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions imod/common/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Store constants here that are used across the package. This is to avoid circular
imports and to have a single source of truth for these values.
"""

from dataclasses import dataclass

import numpy as np


@dataclass
class MaskValues:
"""
Stores mask values for nodata. Special sentinel values can be stored in
here, such as the -9999.0 for MetaSWAP.
"""

bool = False
float = np.nan
integer = 0
msw_default = -9999.0
25 changes: 13 additions & 12 deletions imod/common/utilities/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
clipped_zbound_linestrings_to_vertical_polygons,
vertical_polygons_to_zbound_linestrings,
)
from imod.common.utilities.mask import mask_da
from imod.common.utilities.value_filters import is_valid
from imod.typing import GeoDataFrameType, GridDataArray, GridDataset
from imod.typing.grid import bounding_polygon, is_spatial_grid
Expand Down Expand Up @@ -128,18 +129,18 @@ def _filter_inactive_cells(package: IPackage, active: GridDataArray):
return

package_vars = package.dataset.data_vars
for var in package_vars:
if package_vars[var].shape != ():
if is_spatial_grid(package.dataset[var]):
other = (
0
if np.issubdtype(package.dataset[var].dtype, np.integer)
else np.nan
)

package.dataset[var] = package.dataset[var].where(
active > 0, other=other
)
to_mask = [
var
for var in package_vars
if (package_vars[var].shape != () and is_spatial_grid(package.dataset[var]))
]
# Shortcut if nothing to mask to avoid computing mask unnecessarily.
if not to_mask:
return

mask = active > 0
for var in to_mask:
package.dataset[var] = mask_da(package.dataset[var], mask)


def _clip_linestring(
Expand Down
36 changes: 36 additions & 0 deletions imod/common/utilities/dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Module for utilities related to data types of different kinds of origins.
np.issubdtype unfortunately is too restrictive for some of our use cases, as it
does not for example consider pandas extension dtypes.
"""

import numbers

import numpy as np
from numpy.typing import DTypeLike


def is_float(dtype: DTypeLike) -> bool:
try:
return np.issubdtype(dtype, np.floating)
except TypeError:
# Catch cases where dtype is not a numpy dtype and check if subclass is
# not an integer. As numpy-style integers are also considered real
# numbers.
return issubclass(dtype.type, numbers.Real) and not issubclass(
dtype.type, numbers.Integral
)


def is_integer(dtype: DTypeLike) -> bool:
try:
return np.issubdtype(dtype, np.integer)
except TypeError:
return issubclass(dtype.type, numbers.Integral)


def is_bool(dtype: DTypeLike) -> bool:
try:
return np.issubdtype(dtype, np.bool_)
except TypeError:
return issubclass(dtype.type, np.bool)
77 changes: 59 additions & 18 deletions imod/common/utilities/mask.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numbers

import xarray as xr
from plum import Dispatcher
from xarray.core.utils import is_scalar

from imod.common.constants import MaskValues
from imod.common.interfaces.imaskingsettings import IMaskingSettings
from imod.common.interfaces.imodel import IModel
from imod.common.interfaces.ipackage import IPackage
from imod.common.interfaces.isimulation import ISimulation
from imod.common.utilities.dtype import is_bool, is_float, is_integer
from imod.typing.grid import (
GridDataArray,
concat,
Expand Down Expand Up @@ -80,7 +80,7 @@ def mask_package(package: IPackage, mask: GridDataArray) -> IPackage:
if _skip_dataarray(package.dataset[var]) or _skip_variable(package, var):
masked[var] = package.dataset[var]
else:
masked[var] = _mask_spatial_var(package, var, mask)
masked[var] = _mask_spatial_var_pkg(package, var, mask)

return type(package)(**masked)

Expand Down Expand Up @@ -108,22 +108,43 @@ def _skip_variable(package: IMaskingSettings, var: str) -> bool:
return var in package.skip_variables


def _mask_spatial_var(self, var: str, mask: GridDataArray) -> GridDataArray:
da = self.dataset[var]
array_mask = _adjust_mask_for_unlayered_data(da, mask)
active = array_mask > 0
def mask_da(da: GridDataArray, mask: GridDataArray) -> GridDataArray:
"""
Mask a DataArray with a boolean mask. Function attempts to preserve the
dtype of the original DataArray. It will set the
value to 0 for integers, np.nan for floats, and False for booleans.
"""

if issubclass(da.dtype.type, numbers.Integral):
if var == "idomain":
return da.where(active, other=array_mask)
else:
return da.where(active, other=0)
elif issubclass(da.dtype.type, numbers.Real):
return da.where(active)
if is_integer(da.dtype):
other = MaskValues.integer
elif is_float(da.dtype):
other = MaskValues.float
elif is_bool(da.dtype):
other = MaskValues.bool
else:
raise TypeError(
f"Expected dtype float or integer. Received instead: {da.dtype}"
f"Expected dtype float, integer, or bool. Received instead: {da.dtype}"
)
# Align the mask, as calling where with "other" specified does not
# automatically align the mask to the DataArray.
_, mask = xr.align(da, mask, join="left", copy=False)
return da.where(mask, other=other)


def _mask_spatial_var_pkg(
package: IPackage, var: str, mask: GridDataArray
) -> GridDataArray:
"""
Mask a spatial variable in a package. There is some additional logic for the
MF6 DIS/DISV packages to work with unlayered grids for the "top" value.
"""
da = package.dataset[var]
array_mask = _adjust_mask_for_unlayered_data(da, mask)
active = array_mask > 0

if var == "idomain":
return da.where(active, other=array_mask)
return mask_da(da, active)


def _adjust_mask_for_unlayered_data(
Expand All @@ -145,17 +166,37 @@ def _adjust_mask_for_unlayered_data(
return array_mask


def make_mask(da: GridDataArray):
"""
Make a boolean mask from a DataArray. The mask is True where the values are
not equal to the nodata value. The nodata value is determined by the dtype
of the DataArray. For integers, the nodata value is 0. For floats, the
nodata value is np.nan. For booleans, the nodata value is False.
"""
if is_integer(da.dtype):
return da != MaskValues.integer
elif is_float(da.dtype):
return notnull(da)
elif is_bool(da.dtype):
return da != MaskValues.bool
else:
raise TypeError(
f"Expected dtype float, integer, or bool. Received instead: {da.dtype}"
)


def mask_arrays(arrays: dict[str, GridDataArray]) -> dict[str, GridDataArray]:
"""
This function takes a dictionary of xr.DataArrays. The arrays are assumed to have the same
coordinates. When a np.nan value is found in any array, the other arrays are also
set to np.nan at the same coordinates.
"""
masks = [notnull(array) for array in arrays.values()]
# Get total mask across all arrays

masks = [make_mask(array) for array in arrays.values()]
# Get total mask across all arrays.
total_mask = concat(masks, dim="arrays").all("arrays")
# Mask arrays with total mask
arrays_masked = {key: array.where(total_mask) for key, array in arrays.items()}
arrays_masked = {key: mask_da(array, total_mask) for key, array in arrays.items()}
return arrays_masked


Expand Down
12 changes: 8 additions & 4 deletions imod/common/utilities/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from xarray.core.utils import is_scalar
from xugrid.regrid.regridder import BaseRegridder

from imod.common.constants import MaskValues
from imod.common.interfaces.ilinedatapackage import ILineDataPackage
from imod.common.interfaces.imodel import IModel
from imod.common.interfaces.ipackage import IPackage
Expand All @@ -17,6 +18,7 @@
from imod.common.interfaces.isimulation import ISimulation
from imod.common.utilities.clip import clip_by_grid
from imod.common.utilities.dataclass_type import DataclassType, EmptyRegridMethod
from imod.common.utilities.dtype import is_integer
from imod.common.utilities.value_filters import is_valid
from imod.typing.grid import (
GridDataArray,
Expand Down Expand Up @@ -107,8 +109,8 @@ def _regrid_array(
# Nans can be introduced when the source data has a nan value, or when the
# target grid has a larger domain. Fill nans with 0 for integer, as this is
# mainly important for the idomain array where 0 indicates an inactive cell.
if np.issubdtype(original_dtype, np.integer):
regridded_array = regridded_array.fillna(0)
if is_integer(original_dtype):
regridded_array = regridded_array.fillna(MaskValues.integer)
return regridded_array.astype(original_dtype)


Expand Down Expand Up @@ -438,7 +440,7 @@ def _get_regridding_domain(
# to track nodata cells and verify that the regridding process does not
# inadvertently affect them. The use of np.abs() simplifies the logic by
# avoiding additional conditional checks for -1 and 1 separately.
is_active = np.abs(idomain.where(idomain != 0, other=np.nan))
is_active = np.abs(idomain.where(idomain != 0, other=MaskValues.float))
included_in_all = ones_like(target_grid)
# Take the first regridder function, as each regridder type handles nans
# consistently amongst methods.
Expand All @@ -452,6 +454,8 @@ def _get_regridding_domain(
idomain_regridder_type = regridder.regrid(is_active)
included_in_all = included_in_all.where(idomain_regridder_type.notnull())

new_idomain = regridded_domain.where(included_in_all.notnull(), other=0).astype(int)
new_idomain = regridded_domain.where(
included_in_all.notnull(), other=MaskValues.integer
).astype(int)

return new_idomain
6 changes: 3 additions & 3 deletions imod/msw/infiltration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import xarray as xr

from imod.common.constants import MaskValues
from imod.common.interfaces.iregridpackage import IRegridPackage
from imod.logging import LogLevel, logger
from imod.msw.fixed_format import VariableMetaData
from imod.msw.pkgbase import MetaSwapPackage
from imod.msw.regrid.regrid_schemes import InfiltrationRegridMethod
from imod.msw.utilities.common import concat_imod5
from imod.msw.utilities.mask import MaskValues
from imod.typing import GridDataDict, Imod5DataDict
from imod.typing.grid import ones_like

Expand All @@ -29,7 +29,7 @@ def deactivate_small_resistances_in_data(data: GridDataDict):
message=message.format(var=var),
additional_depth=1,
)
data[var] = data[var].where(~to_deactivate, MaskValues.default)
data[var] = data[var].where(~to_deactivate, MaskValues.msw_default)
return data


Expand Down Expand Up @@ -142,7 +142,7 @@ def from_imod5_data(cls, imod5_data: Imod5DataDict) -> "Infiltration":
data = deactivate_small_resistances_in_data(data)

like = data["downward_resistance"].isel(subunit=0, drop=True)
data["bottom_resistance"] = ones_like(like) * MaskValues.default
data["bottom_resistance"] = ones_like(like) * MaskValues.msw_default
data["extra_storage_coefficient"] = ones_like(like)

return cls(**data)
4 changes: 2 additions & 2 deletions imod/msw/meteo_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import xarray as xr

import imod
from imod.common.constants import MaskValues
from imod.common.interfaces.iregridpackage import IRegridPackage
from imod.common.utilities.dataclass_type import DataclassType, EmptyRegridMethod
from imod.msw.pkgbase import MetaSwapPackage
from imod.msw.regrid.regrid_schemes import MeteoGridRegridMethod
from imod.msw.timeutil import to_metaswap_timeformat
from imod.msw.utilities.common import find_in_file_list
from imod.msw.utilities.mask import MaskValues
from imod.typing import Imod5DataDict


Expand Down Expand Up @@ -178,7 +178,7 @@ def write(self, directory: Union[str, Path], *args):
".asc"
)
imod.rasterio.save(
path, self.dataset[str(varname)], nodata=MaskValues.default
path, self.dataset[str(varname)], nodata=MaskValues.msw_default
)

def _pkgcheck(self):
Expand Down
5 changes: 3 additions & 2 deletions imod/msw/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import xarray as xr

from imod.common.constants import MaskValues
from imod.common.utilities.clip import clip_by_grid
from imod.common.utilities.partitioninfo import create_partition_info
from imod.common.utilities.value_filters import enforce_scalar
Expand Down Expand Up @@ -42,7 +43,7 @@
from imod.msw.timeutil import to_metaswap_timeformat
from imod.msw.utilities.common import find_in_file_list
from imod.msw.utilities.imod5_converter import has_active_scaling_factor
from imod.msw.utilities.mask import MaskValues, mask_and_broadcast_cap_data
from imod.msw.utilities.mask import mask_and_broadcast_cap_data
from imod.msw.utilities.parse import read_para_sim
from imod.msw.vegetation import AnnualCropFactors
from imod.typing import GridDataArray, Imod5DataDict
Expand Down Expand Up @@ -652,7 +653,7 @@ def from_imod5_data(
if has_active_scaling_factor(imod5_cap_no_layer["cap"]):
model["scaling_factor"] = ScalingFactors.from_imod5_data(imod5_masked)
area = model["grid"]["area"].isel(subunit=0, drop=True)
model["idf_mapping"] = IdfMapping(area, MaskValues.default)
model["idf_mapping"] = IdfMapping(area, MaskValues.msw_default)
model["coupling"] = CouplerMapping()
model["extra_files"] = FileCopier.from_imod5_data(imod5_masked)

Expand Down
5 changes: 3 additions & 2 deletions imod/msw/utilities/imod5_converter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from xarray.core.utils import is_scalar

from imod.common.constants import MaskValues
from imod.logging import LogLevel, logger
from imod.mf6 import StructuredDiscretization
from imod.msw.utilities.common import concat_imod5
from imod.msw.utilities.mask import MaskValues, MetaSwapActive
from imod.msw.utilities.mask import MetaSwapActive
from imod.typing import GridDataArray, GridDataDict
from imod.typing.grid import ones_like
from imod.util.spatial import get_cell_area
Expand Down Expand Up @@ -103,7 +104,7 @@ def has_active_scaling_factor(imod5_cap: GridDataDict):
function shortcuts if data is provided as constant.
"""
variable_inactive_mapping = {
"perched_water_table_level": MaskValues.default,
"perched_water_table_level": MaskValues.msw_default,
"soil_moisture_fraction": 1.0,
"conductivitiy_factor": 1.0,
}
Expand Down
Loading
Loading