19 changes: 4 additions & 15 deletions src/spatialdata/_io/io_raster.py
@@ -13,7 +13,7 @@
from ome_zarr.writer import write_labels as write_labels_ngff
from ome_zarr.writer import write_multiscale as write_multiscale_ngff
from ome_zarr.writer import write_multiscale_labels as write_multiscale_labels_ngff
from xarray import DataArray, Dataset, DataTree
from xarray import DataArray, DataTree

from spatialdata._io._utils import (
_get_transformations_from_ngff_dict,
@@ -27,6 +27,7 @@
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import get_channel_names
from spatialdata.models.models import ATTRS_KEY
from spatialdata.models.pyramids_utils import dask_arrays_to_datatree
from spatialdata.transformations._utils import (
_get_transformations,
_get_transformations_xarray,
@@ -91,20 +92,8 @@ def _read_multiscale(
channels = [d["label"] for d in omero_metadata["channels"]]
axes = [i["name"] for i in node.metadata["axes"]]
if len(datasets) > 1:
multiscale_image = {}
for i, d in enumerate(datasets):
data = node.load(Multiscales).array(resolution=d)
multiscale_image[f"scale{i}"] = Dataset(
{
"image": DataArray(
data,
name="image",
dims=axes,
coords={"c": channels} if channels is not None else {},
)
}
)
msi = DataTree.from_dict(multiscale_image)
arrays = [node.load(Multiscales).array(resolution=d) for d in datasets]
msi = dask_arrays_to_datatree(arrays, dims=axes, channels=channels)
_set_transformations(msi, transformations)
return compute_coordinates(msi)

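For reference, below is a minimal sketch of what the new `dask_arrays_to_datatree` helper (imported from `spatialdata.models.pyramids_utils`) might look like, reconstructed only from its call site and from the inline construction it replaces; the actual implementation in the PR may differ.

    from xarray import DataArray, Dataset, DataTree

    def dask_arrays_to_datatree(arrays, dims, channels=None):
        # Hypothetical reconstruction: wrap each pyramid level in a Dataset holding a
        # single "image" DataArray and assemble the levels into a DataTree, mirroring
        # the removed inline code in _read_multiscale.
        scales = {
            f"scale{i}": Dataset(
                {
                    "image": DataArray(
                        data,
                        name="image",
                        dims=dims,
                        coords={"c": channels} if channels is not None else {},
                    )
                }
            )
            for i, data in enumerate(arrays)
        }
        return DataTree.from_dict(scales)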
77 changes: 52 additions & 25 deletions src/spatialdata/datasets.py
@@ -20,7 +20,15 @@
from spatialdata._core.query.relational_query import get_element_instances
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel
from spatialdata.models import (
Image2DModel,
Image3DModel,
Labels2DModel,
Labels3DModel,
PointsModel,
ShapesModel,
TableModel,
)
from spatialdata.transformations import Identity

__all__ = ["blobs", "raccoon"]
@@ -143,10 +151,10 @@ def blobs(
"""Blobs dataset."""
image = self._image_blobs(self.transformations, self.length, self.n_channels, self.c_coords)
multiscale_image = self._image_blobs(
self.transformations, self.length, self.n_channels, self.c_coords, multiscale=True
self.transformations, self.length, self.n_channels, self.c_coords, scale_factors=[2, 2]
)
labels = self._labels_blobs(self.transformations, self.length)
multiscale_labels = self._labels_blobs(self.transformations, self.length, multiscale=True)
multiscale_labels = self._labels_blobs(self.transformations, self.length, scale_factors=[2, 2])
points = self._points_blobs(self.transformations, self.length, self.n_points)
circles = self._circles_blobs(self.transformations, self.length, self.n_shapes)
polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes)
@@ -171,38 +179,51 @@ def _image_blobs(
length: int = 512,
n_channels: int = 3,
c_coords: str | list[str] | None = None,
multiscale: bool = False,
scale_factors: list[int] | None = None,
ndim: int = 2,
) -> DataArray | DataTree:
masks = []
for i in range(n_channels):
mask = self._generate_blobs(length=length, seed=i)
mask = self._generate_blobs(length=length, seed=i, ndim=ndim)
mask = (mask - mask.min()) / np.ptp(mask)
masks.append(mask)

x = np.stack(masks, axis=0)
dims = ["c", "y", "x"]
if not multiscale:
return Image2DModel.parse(x, transformations=transformations, dims=dims, c_coords=c_coords)
return Image2DModel.parse(
x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2]
model: type[Image2DModel] | type[Image3DModel]
if ndim == 2:
dims = ["c", "y", "x"]
model = Image2DModel
else:
dims = ["c", "z", "y", "x"]
model = Image3DModel
if scale_factors is None:
return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords)
return model.parse(
x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=scale_factors
)

def _labels_blobs(
self, transformations: dict[str, Any] | None = None, length: int = 512, multiscale: bool = False
self,
transformations: dict[str, Any] | None = None,
length: int = 512,
scale_factors: list[int] | None = None,
ndim: int = 2,
) -> DataArray | DataTree:
"""Create a 2D labels."""
"""Create labels in 2D or 3D."""
from scipy.ndimage import watershed_ift

# from skimage
mask = self._generate_blobs(length=length)
mask = self._generate_blobs(length=length, ndim=ndim)
threshold = np.percentile(mask, 100 * (1 - 0.3))
inputs = np.logical_not(mask < threshold).astype(np.uint8)
# use watershed from scipy
xm, ym = np.ogrid[0:length:10, 0:length:10]
grid = np.ogrid[tuple(slice(0, length, 10) for _ in range(ndim))]
markers = np.zeros_like(inputs).astype(np.int16)
markers[xm, ym] = np.arange(xm.size * ym.size).reshape((xm.size, ym.size))
grid_shape = tuple(g.size for g in grid)
markers[tuple(grid)] = np.arange(np.prod(grid_shape)).reshape(grid_shape)
out = watershed_ift(inputs, markers)
out[xm, ym] = out[xm - 1, ym - 1] # remove the isolate seeds
shifted = tuple(g - 1 for g in grid)
out[tuple(grid)] = out[tuple(shifted)] # remove the isolated seeds
# reindex by frequency
val, counts = np.unique(out, return_counts=True)
sorted_idx = np.argsort(counts)
@@ -211,20 +232,26 @@ def _labels_blobs(
out[out == val[idx]] = 0
else:
out[out == val[idx]] = i
dims = ["y", "x"]
if not multiscale:
return Labels2DModel.parse(out, transformations=transformations, dims=dims)
return Labels2DModel.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2])

def _generate_blobs(self, length: int = 512, seed: int | None = None) -> ArrayLike:
model: type[Labels2DModel] | type[Labels3DModel]
if ndim == 2:
dims = ["y", "x"]
model = Labels2DModel
else:
dims = ["z", "y", "x"]
model = Labels3DModel
if scale_factors is None:
return model.parse(out, transformations=transformations, dims=dims)
return model.parse(out, transformations=transformations, dims=dims, scale_factors=scale_factors)

def _generate_blobs(self, length: int = 512, seed: int | None = None, ndim: int = 2) -> ArrayLike:
from scipy.ndimage import gaussian_filter

rng = default_rng(42) if seed is None else default_rng(seed)
# from skimage
shape = tuple([length] * 2)
shape = (length,) * ndim
mask = np.zeros(shape)
n_pts = max(int(1.0 / 0.1) ** 2, 1)
points = (length * rng.random((2, n_pts))).astype(int)
n_pts = max(int(1.0 / 0.1) ** ndim, 1)
points = (length * rng.random((ndim, n_pts))).astype(int)
mask[tuple(indices for indices in points)] = 1
mask = gaussian_filter(mask, sigma=0.25 * length * 0.1)
assert isinstance(mask, np.ndarray)
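As an aside (not part of the diff), a small standalone check of the ndim-generic seed grid introduced in `_labels_blobs` above, showing that the tuple-of-slices form reproduces the old hard-coded 2D `xm, ym` grid and extends to 3D; the length and step values here are illustrative.

    import numpy as np

    length, ndim = 64, 3  # illustrative values; the dataset uses length=512

    # ndim-generic open grid of watershed seed positions, one seed every 10 pixels
    grid = np.ogrid[tuple(slice(0, length, 10) for _ in range(ndim))]
    grid_shape = tuple(g.size for g in grid)  # (7, 7, 7) here

    markers = np.zeros((length,) * ndim, dtype=np.int16)
    markers[tuple(grid)] = np.arange(np.prod(grid_shape)).reshape(grid_shape)
    # one marker keeps the value 0 and blends into the zero background
    assert np.count_nonzero(markers) == np.prod(grid_shape) - 1

    # for ndim == 2 the generic form matches the previous hard-coded grid
    xm, ym = np.ogrid[0:length:10, 0:length:10]
    old = np.ogrid[tuple(slice(0, length, 10) for _ in range(2))]
    assert np.array_equal(old[0], xm) and np.array_equal(old[1], ym)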
60 changes: 60 additions & 0 deletions src/spatialdata/models/chunks_utils.py
@@ -0,0 +1,60 @@
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias

Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]


def normalize_chunks(
chunks: Chunks_t,
axes: Sequence[str],
) -> dict[str, None | int | tuple[int, ...]]:
"""Normalize chunk specification to dict format.

This function converts various chunk formats to a dict mapping dimension names
to chunk sizes. The dict format is preferred because it's explicit about which
dimension gets which chunk size.

Parameters
----------
chunks
Chunk specification. Can be:
- int: Applied to all axes
- tuple[int, ...]: Chunk sizes in order corresponding to axes
- tuple[tuple[int, ...], ...]: Explicit per-block chunk sizes per axis
- dict: Mapping of axis names to chunk sizes. Values can be:
- int: uniform chunk size for that axis
- tuple[int, ...]: explicit per-block chunk sizes
- None: keep existing chunks (or use full dimension when no chunks were available)
axes
Tuple of axis names that defines the expected dimensions (e.g., ('c', 'y', 'x')).

Returns
-------
dict[str, None | int | tuple[int, ...]]
Dict mapping axis names to chunk sizes. ``None`` values are preserved
with dask semantics (keep existing chunks, or use full dimension size if chunks
were not available and are being created).

Raises
------
ValueError
If chunks format is not supported or incompatible with axes.
"""
if isinstance(chunks, int):
return dict.fromkeys(axes, chunks)

if isinstance(chunks, Mapping):
chunks_dict = dict(chunks)
missing = set(axes) - set(chunks_dict.keys())
if missing:
raise ValueError(f"chunks dict missing keys for axes {missing}, got: {list(chunks_dict.keys())}")
return {ax: chunks_dict[ax] for ax in axes}

if isinstance(chunks, tuple):
if len(chunks) != len(axes):
raise ValueError(f"chunks tuple length {len(chunks)} doesn't match axes {axes} (length {len(axes)})")
if not all(isinstance(c, (int, tuple)) for c in chunks):
raise ValueError(f"All elements in chunks tuple must be int or tuple[int, ...], got: {chunks}")
return dict(zip(axes, chunks, strict=True)) # type: ignore[arg-type]

raise ValueError(f"Unsupported chunks type: {type(chunks)}. Expected int, tuple, dict, or None.")
34 changes: 18 additions & 16 deletions src/spatialdata/models/models.py
@@ -14,7 +14,7 @@
from dask.array.core import from_array
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame, GeoSeries
from multiscale_spatial_image import to_multiscale
from multiscale_spatial_image import to_multiscale as to_multiscale_msi
from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
from pandas import CategoricalDtype
from shapely._geometry import GeometryType
@@ -38,16 +38,17 @@
_validate_mapping_to_coordinate_system_type,
convert_region_column_to_categorical,
)
from spatialdata.models.chunks_utils import Chunks_t
from spatialdata.models.pyramids_utils import ScaleFactors_t # ozp -> ome-zarr-py
from spatialdata.models.pyramids_utils import to_multiscale as to_multiscale_ozp
from spatialdata.transformations._utils import (
_get_transformations,
_set_transformations,
compute_coordinates,
)
from spatialdata.transformations.transformations import Identity

# Types
Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
ScaleFactors_t = Sequence[dict[str, int] | int]
__all__ = ["Chunks_t", "ScaleFactors_t"]

ATTRS_KEY = "spatialdata_attrs"

@@ -225,12 +226,19 @@ def parse(
chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)}
if isinstance(chunks, float):
chunks = {dim: chunks for index, dim in data.dims}
data = to_multiscale(
data,
scale_factors=scale_factors,
method=method,
chunks=chunks,
)
if method is not None:
data = to_multiscale_msi(
data,
scale_factors=scale_factors,
method=method,
chunks=chunks,
)
else:
data = to_multiscale_ozp(
data,
scale_factors=scale_factors,
chunks=chunks,
)
_parse_transformations(data, parsed_transform)
else:
# Chunk single scale images
@@ -375,9 +383,6 @@ def parse( # noqa: D102
) -> DataArray | DataTree:
if kwargs.get("c_coords") is not None:
raise ValueError("`c_coords` is not supported for labels")
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
# Override default scaling method to preserve labels
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
return super().parse(*args, **kwargs)


@@ -388,9 +393,6 @@ class Labels3DModel(RasterSchema):
def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree: # noqa: D102
if kwargs.get("c_coords") is not None:
raise ValueError("`c_coords` is not supported for labels")
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
# Override default scaling method to preserve labels
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
return super().parse(*args, **kwargs)


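To illustrate the new dispatch, a hedged usage sketch (not from the PR): with `scale_factors` set and no `method`, `parse` now builds the pyramid via the ome-zarr-py-based `to_multiscale_ozp`; passing an explicit `method` keeps the multiscale-spatial-image path. The arrays and dims below are illustrative, and `Methods.DASK_IMAGE_NEAREST` is used only because it appears in the removed lines above.

    import numpy as np
    from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
    from spatialdata.models import Image2DModel, Labels2DModel

    rng = np.random.default_rng(0)
    image = rng.random((3, 512, 512))
    labels = (rng.random((512, 512)) > 0.5).astype(np.uint8)

    # method is None -> pyramid built by the ome-zarr-py helper (to_multiscale_ozp)
    img = Image2DModel.parse(image, dims=("c", "y", "x"), scale_factors=[2, 2])

    # explicit method -> falls back to multiscale-spatial-image (to_multiscale_msi)
    img_msi = Image2DModel.parse(
        image, dims=("c", "y", "x"), scale_factors=[2, 2], method=Methods.DASK_IMAGE_NEAREST
    )

    # labels no longer force method=Methods.DASK_IMAGE_NEAREST; with the default
    # method=None they also go through the ome-zarr-py path
    lab = Labels2DModel.parse(labels, dims=("y", "x"), scale_factors=[2, 2])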