From 701cc094ad5d3fc2b34377beecfcc974529513a2 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 1 Jun 2026 09:23:18 -0500 Subject: [PATCH] fix: keep lazy runtime state out of DataArray attrs Avoid storing the lazycogs backend and DuckDB client in xarray attrs so DataArray.copy(), sortby(), and downstream write paths can deep-copy metadata without trying to pickle live runtime objects. Recover the backend for lazycogs.explain() from the lazy backing array instead, and preserve explain behavior for sliced and reordered arrays. --- ARCHITECTURE.md | 4 +- src/lazycogs/_backend.py | 9 +++ src/lazycogs/_core.py | 21 +---- src/lazycogs/_explain.py | 169 +++++++++++++++++++++++++++++++-------- tests/test_core.py | 31 ++++++- tests/test_explain.py | 49 +++++++++--- 6 files changed, 214 insertions(+), 69 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 6c2ea76..929952e 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -56,7 +56,7 @@ src/lazycogs/ 8. Creates a single `MultiBandStacBackendArray` (a dataclass) with shape `(band, time, y, x)` holding all the parameters needed to materialise any chunk later, then wraps it in one `xarray.core.indexing.LazilyIndexedArray`. This avoids `xr.concat` (used internally by `ds.to_array()`), which would eagerly load `LazilyIndexedArray`-backed objects. 9. Uses `rasterix.RasterIndex` for spatial indexing, but materialises the x/y coordinate variables eagerly as numpy arrays so chunked scalar spatial selections compute reliably. 10. Constructs the `xr.DataArray` directly from the 4-D variable. If `chunks` is provided, calls `.chunk(chunks)` to convert to a dask-backed array; otherwise the `LazilyIndexedArray` remains in play so narrow slices (e.g. a single pixel) translate to minimal I/O. When output nodata is known, the returned array sets `da.attrs["_FillValue"]` and `da.encoding["_FillValue"]` for downstream serialization. When unknown, no `_FillValue` metadata is attached. -11. Stores `_stac_backend` (the `MultiBandStacBackendArray` instance) and `_stac_time_coords` (the full time coordinate array) in `da.attrs` so that `da.lazycogs.explain()` can reconstruct the explain plan without re-specifying `open()` parameters. +11. Keeps lazy runtime state on the backing array rather than in `da.attrs`. This lets xarray operations such as `sortby()` and deep copies clone metadata safely without trying to pickle live objects like `DuckdbClient`. ## Explain: dry-run read estimator @@ -69,7 +69,7 @@ print(plan.summary()) df = plan.to_dataframe() ``` -The accessor reads `_stac_backend` and `_stac_time_coords` from `da.attrs` and respects the DataArray's current extent and chunk sizes, so explaining a sliced DataArray (`da.isel(time=0).lazycogs.explain()`) queries only the reads needed for that slice. +The accessor discovers the `MultiBandStacBackendArray` from the DataArray's lazy backing array and respects the current extent, indexing, and chunk sizes, so explaining a sliced or reordered DataArray (`da.isel(time=0).lazycogs.explain()` or `da.sortby("time").lazycogs.explain()`) queries only the reads needed for that view. If the array has been materialized or transformed into a different backing array, `explain()` raises and asks for a still-lazy lazycogs array. `ExplainPlan` exposes: - `total_chunk_reads` — number of `(band, time, spatial tile)` combinations diff --git a/src/lazycogs/_backend.py b/src/lazycogs/_backend.py index 0c40c48..ca944c1 100644 --- a/src/lazycogs/_backend.py +++ b/src/lazycogs/_backend.py @@ -403,6 +403,15 @@ def __repr__(self) -> str: """Return a compact string representation.""" return f"MultiBandStacBackendArray(bands={self.bands!r}, shape={self.shape})" + def __copy__(self) -> MultiBandStacBackendArray: + """Return ``self`` because backend arrays are immutable runtime state.""" + return self + + def __deepcopy__(self, memo: dict[int, object]) -> MultiBandStacBackendArray: + """Return ``self`` so xarray copies do not try to pickle DuckDB state.""" + memo[id(self)] = self + return self + def _resolve_spatial_window( self, y_key: int | np.integer | slice, diff --git a/src/lazycogs/_core.py b/src/lazycogs/_core.py index 3035f79..d04e371 100644 --- a/src/lazycogs/_core.py +++ b/src/lazycogs/_core.py @@ -6,7 +6,7 @@ import logging import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any import numpy as np from async_geotiff import GeoTIFF @@ -34,23 +34,6 @@ _INT_WIDTHS = (8, 16, 32, 64) -class _CompactDateArray(np.ndarray): - """Numpy datetime64 array subclass with a compact display for xarray HTML repr.""" - - def __new__(cls, values: np.ndarray) -> Self: - return np.asarray(values, dtype="datetime64[D]").view(cls) - - def __str__(self) -> str: - arr = self.view(np.ndarray) - n = len(arr) - if n == 1: - return str(arr[0]) - return f"{arr[0]} \u2026 {arr[-1]} (n={n})" - - def __repr__(self) -> str: - return self.__str__() - - @dataclass(frozen=True) class _ItemInspection: """Sampled startup metadata from one representative STAC item.""" @@ -570,8 +553,6 @@ def _build_dataarray( "spatial:transform": gt, "spatial:shape": [dst_height, dst_width], "spatial:registration": "pixel", - "_stac_backend": multi, - "_stac_time_coords": _CompactDateArray(time_coord), } # Zarr geo-proj convention diff --git a/src/lazycogs/_explain.py b/src/lazycogs/_explain.py index 97f318b..dc540ef 100644 --- a/src/lazycogs/_explain.py +++ b/src/lazycogs/_explain.py @@ -13,7 +13,9 @@ from affine import Affine from pandas import DataFrame from pyproj import CRS, Transformer +from xarray.core import indexing +from lazycogs._backend import MultiBandStacBackendArray from lazycogs._chunk_reader import _ChunkContext, _open_and_window from lazycogs._executor import run_duckdb, run_on_loop @@ -22,13 +24,131 @@ from async_geotiff import Store - from lazycogs._backend import MultiBandStacBackendArray logger = logging.getLogger(__name__) _COG_MULTI_THRESHOLD = 2 +def _backend_search_children(value: object) -> list[object]: + """Return nested objects that may contain a lazycogs backend.""" + children: list[object] = [] + + if isinstance(value, indexing.LazilyIndexedArray): + children.append(value.array) + return children + + if isinstance(value, indexing.ImplicitToExplicitIndexingAdapter): + children.append(value.array) + return children + + for attr in ("array", "_array", "dask", "layers"): + nested = getattr(value, attr, None) + if nested is not None: + children.append(nested) + + if isinstance(value, dict): + children.extend(value.values()) + return children + + values = getattr(value, "values", None) + if callable(values): + children.extend(values()) + + if isinstance(value, tuple | list): + children.extend(value) + + return children + + +def _find_backend_array( + value: object, +) -> tuple[MultiBandStacBackendArray, indexing.ExplicitIndexer | None] | None: + """Return the first lazycogs backend found inside an xarray data wrapper.""" + stack: list[object] = [value] + seen: set[int] = set() + + while stack: + current = stack.pop() + obj_id = id(current) + if obj_id in seen: + continue + seen.add(obj_id) + + if isinstance(current, MultiBandStacBackendArray): + return current, None + + if isinstance(current, indexing.LazilyIndexedArray) and isinstance( + current.array, + MultiBandStacBackendArray, + ): + return current.array, current.key + + stack.extend(reversed(_backend_search_children(current))) + + return None + + +def _indexer_positions( + indexer: int | np.integer | slice | np.ndarray, + size: int, +) -> list[int]: + """Normalise one xarray indexer component to explicit integer positions.""" + if isinstance(indexer, int | np.integer): + return [int(indexer)] + if isinstance(indexer, slice): + return list(range(*indexer.indices(size))) + + values = np.asarray(indexer) + if values.dtype == np.bool_: + return np.flatnonzero(values).astype(int).tolist() + return values.astype(int).tolist() + + +def _current_time_items( + da: xr.DataArray, + backend: MultiBandStacBackendArray, + key: indexing.ExplicitIndexer | None, +) -> list[tuple[int, str, np.datetime64]]: + """Return backend time indices and current coordinate values in array order.""" + if "time" not in da.coords: + raise ValueError( + "This DataArray no longer exposes a time coordinate. " + "lazycogs.explain() only supports arrays that still retain their " + "lazycogs time axis.", + ) + + current_time_coords = np.atleast_1d(da.coords["time"].values).astype( + "datetime64[D]", + ) + + if key is None: + if len(current_time_coords) != len(backend.dates): + raise ValueError( + "Could not recover lazycogs time-step indexing from this DataArray. " + "lazycogs.explain() currently supports arrays that still retain " + "their original lazy indexing structure.", + ) + time_positions = list(range(len(backend.dates))) + else: + time_positions = _indexer_positions(key.tuple[1], len(backend.dates)) + + if len(time_positions) != len(current_time_coords): + raise ValueError( + "Could not align the current time coordinates with the underlying " + "lazycogs backend indexing.", + ) + + return [ + (backend_index, backend.dates[backend_index], time_coord) + for backend_index, time_coord in zip( + time_positions, + current_time_coords, + strict=False, + ) + ] + + @dataclass class CogRead: """Read details for one COG file within one chunk. @@ -500,6 +620,7 @@ async def _inspect_item_async( async def _explain_async( da: xr.DataArray, backend: MultiBandStacBackendArray, + key: indexing.ExplicitIndexer | None, *, fetch_headers: bool, ) -> ExplainPlan: @@ -515,8 +636,10 @@ async def _explain_async( Args: da: DataArray whose extent and chunking define the explain scope. - backend: :class:`MultiBandStacBackendArray` stored in the DataArray - attrs by :func:`~lazycogs._core._build_dataarray`. + backend: :class:`MultiBandStacBackendArray` discovered from the + DataArray's lazy backing array. + key: The xarray indexer associated with the discovered backend, when + available. fetch_headers: When ``True``, open each matched COG header. Returns: @@ -543,28 +666,7 @@ async def _explain_async( dst_crs = backend.dst_crs - # Identify which time steps to explain based on current DataArray coords. - full_time_coords: np.ndarray = np.asarray( - da.attrs["_stac_time_coords"], - dtype="datetime64[D]", - ) - full_time_filters: list[str] = backend.dates - - if "time" in da.coords: - current_times: set[np.datetime64] = set( - np.atleast_1d(da.coords["time"].values).astype("datetime64[D]"), - ) - time_items = [ - (i, f, tc) - for i, (f, tc) in enumerate( - zip(full_time_filters, full_time_coords, strict=False), - ) - if tc.astype("datetime64[D]") in current_times - ] - else: - time_items = [ - (i, f, full_time_coords[i]) for i, f in enumerate(full_time_filters) - ] + time_items = _current_time_items(da, backend, key) chunk_h, chunk_w = _infer_chunk_sizes(da) @@ -738,17 +840,18 @@ def explain(self, *, fetch_headers: bool = False) -> ExplainPlan: tile) reads for the current DataArray extent and chunking. Raises: - ValueError: If the DataArray was not produced by - ``lazycogs.open()`` (missing explain metadata in - ``attrs``). + ValueError: If the DataArray is not still backed by lazycogs' + lazy backend array. """ - backend: MultiBandStacBackendArray | None = self._da.attrs.get("_stac_backend") - if backend is None: + backend_and_key = _find_backend_array(self._da.variable._data) # noqa: SLF001 + if backend_and_key is None: raise ValueError( - "This DataArray does not have lazycogs explain metadata. " - "Ensure it was created by lazycogs.open().", + "This DataArray is not backed by lazycogs' lazy array. " + "Ensure it was created by lazycogs.open() and has not been " + "materialized or transformed into a different backing array.", ) + backend, key = backend_and_key return run_on_loop( - _explain_async(self._da, backend, fetch_headers=fetch_headers), + _explain_async(self._da, backend, key, fetch_headers=fetch_headers), ) diff --git a/tests/test_core.py b/tests/test_core.py index 831c49d..bf59e08 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -21,6 +21,7 @@ _resolve_output_dtype, _resolve_output_nodata, ) +from lazycogs._explain import _find_backend_array from lazycogs._mosaic_methods import FirstMethod, MeanMethod, MedianMethod from lazycogs._temporal import _DayGrouper, _FixedDayGrouper, _MonthGrouper @@ -379,9 +380,12 @@ def test_open_sets_expected_dataarray_attributes(opened_dataarray): assert "missing_value" not in da.attrs assert da.encoding["_FillValue"] == da.attrs["_FillValue"] == 0 - # Internal bookkeeping attributes - assert isinstance(da.attrs["_stac_backend"], MultiBandStacBackendArray) - assert da.attrs["_stac_time_coords"].dtype == np.dtype("datetime64[D]") + # Internal runtime state is kept off attrs so xarray can deep-copy safely. + assert "_stac_backend" not in da.attrs + assert "_stac_time_coords" not in da.attrs + + backend, _ = _find_backend_array(da.variable._data) + assert isinstance(backend, MultiBandStacBackendArray) def test_chunked_spatial_selection_computes_scalar_spatial_coords(opened_dataarray): @@ -410,6 +414,24 @@ def test_chunked_spatial_selection_full_compute_succeeds(opened_dataarray): assert computed.shape == (1, 1) +def test_opened_dataarray_sortby_time_deepcopy_safe(opened_dataarray): + """sortby(time) works because runtime state is not stored in attrs.""" + sorted_da = opened_dataarray.sortby("time") + + assert sorted_da.dims == opened_dataarray.dims + assert "_stac_backend" not in sorted_da.attrs + assert "_stac_time_coords" not in sorted_da.attrs + + +def test_opened_dataarray_deep_copy_safe(opened_dataarray): + """Deep copying a lazycogs DataArray does not try to pickle DuckDB state.""" + copied = opened_dataarray.copy(deep=True) + + assert copied.dims == opened_dataarray.dims + assert "_stac_backend" not in copied.attrs + assert "_stac_time_coords" not in copied.attrs + + # --------------------------------------------------------------------------- # Startup inspection and output contract helpers # --------------------------------------------------------------------------- @@ -691,7 +713,8 @@ async def fake_open(path: str, *, store): path_from_href=lambda href: href.split("/", 3)[-1], ) - assert da.attrs["_stac_backend"].store is store + backend, _ = _find_backend_array(da.variable._data) + assert backend.store is store def test_open_auto_promotes_inferred_integer_dtype_for_float_method(tmp_path): diff --git a/tests/test_explain.py b/tests/test_explain.py index 7da60f4..f25f774 100644 --- a/tests/test_explain.py +++ b/tests/test_explain.py @@ -10,6 +10,7 @@ from affine import Affine from pyproj import CRS from rustac import DuckdbClient +from xarray.core import indexing from lazycogs._backend import MultiBandStacBackendArray from lazycogs._explain import ( @@ -17,6 +18,7 @@ CogRead, ExplainPlan, _compute_chunk_bbox_4326, + _find_backend_array, _infer_chunk_sizes, _iter_spatial_chunks, _roi_pixel_offsets, @@ -84,7 +86,7 @@ def _make_da_with_backends( height: int = 10, affine: Affine | None = None, ) -> xr.DataArray: - """Return a minimal DataArray with lazycogs explain attrs attached.""" + """Return a minimal lazycogs-backed DataArray for explain tests.""" if affine is None: resolution = 1.0 affine = Affine(resolution, 0.0, 0.0, 0.0, -resolution, float(height)) @@ -105,19 +107,19 @@ def _make_da_with_backends( x_coords = np.array([affine.c + (i + 0.5) * resolution for i in range(width)]) y_coords = np.array([affine.f + (i + 0.5) * affine.e for i in range(height)]) - da = xr.DataArray( - np.zeros((len(bands), len(dates), height, width), dtype="float32"), + variable = xr.Variable( + ("band", "time", "y", "x"), + indexing.LazilyIndexedArray(backend), + ) + return xr.DataArray( + variable, coords={ "band": bands, "time": time_coord, "y": y_coords, "x": x_coords, }, - dims=("band", "time", "y", "x"), ) - da.attrs["_stac_backend"] = backend - da.attrs["_stac_time_coords"] = time_coord - return da # --------------------------------------------------------------------------- @@ -269,7 +271,7 @@ def test_roi_pixel_offsets_full_extent(wgs84): height=10, affine=affine, ) - backend = da.attrs["_stac_backend"] + backend, _ = _find_backend_array(da.variable._data) x_start, y_start_physical, roi_w, roi_h = _roi_pixel_offsets(da, backend) assert x_start == 0 assert y_start_physical == 0 @@ -457,9 +459,9 @@ def _fake_items(band: str, n: int) -> list[dict]: def test_accessor_raises_on_non_stac_da(): - """explain() raises ValueError when the DataArray has no explain metadata.""" + """explain() raises ValueError when the array is not lazycogs-backed.""" da = xr.DataArray(np.zeros((3, 3))) - with pytest.raises(ValueError, match=r"lazycogs\.open"): + with pytest.raises(ValueError, match=r"backed by lazycogs"): da.lazycogs.explain() @@ -628,6 +630,33 @@ def test_accessor_explain_time_slice(wgs84): assert plan.total_chunk_reads == 2 # 1 band * 2 time steps * 1 spatial tile +def test_accessor_explain_time_sort_preserves_current_order(wgs84): + """explain() follows the current DataArray time order after sortby.""" + dates = ["2023-01-01/2023-01-01", "2023-01-02/2023-01-02"] + time_coords = [np.datetime64("2023-01-02", "D"), np.datetime64("2023-01-01", "D")] + da = _make_da_with_backends( + wgs84, + dates=dates, + time_coords=time_coords, + bands=["red"], + width=4, + height=4, + ) + da_sorted = da.sortby("time") + + with patch("rustac.DuckdbClient.search", return_value=[]): + plan = da_sorted.lazycogs.explain() + + assert plan.time_coords == [ + np.datetime64("2023-01-01", "D"), + np.datetime64("2023-01-02", "D"), + ] + assert [chunk.date_filter for chunk in plan.chunk_reads] == [ + "2023-01-02/2023-01-02", + "2023-01-01/2023-01-01", + ] + + def test_accessor_explain_query_count_not_multiplied_by_bands(wgs84): """DuckDB is queried once per (time, tile) — not once per (band, time, tile).""" dates = ["2023-01-01/2023-01-01", "2023-01-02/2023-01-02"]