diff --git a/CHANGELOG.md b/CHANGELOG.md index e1e0086..1f6d5b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning][]. ### Features - Added `groupby` support to {meth}`annbatch.DatasetCollection.add_adatas` to group observations per dataset before writing collections. When appending to an existing on-disk collection, groupby columns must already exist and categorical categories must be identical to those on-disk. +- Warn when building or extending a {class}`annbatch.DatasetCollection` if outer concatenation would introduce missing values into categorical `obs` columns because those columns are absent from some inputs. ## [0.1.3] diff --git a/src/annbatch/io.py b/src/annbatch/io.py index b76be75..b643174 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -255,12 +255,14 @@ def _validate_anndatas_and_maybe_get_bytes_per_row[T: zarr.Group | h5py.Group | ------- The average bytes per observation row when *estimate_bytes_per_obs_row* is ``True``, otherwise ``None``. """ + paths_or_anndatas = list(paths_or_anndatas) num_raw_in_adata = 0 found_keys: dict[str, defaultdict[str, int]] = { "layers": defaultdict(lambda: 0), "obsm": defaultdict(lambda: 0), "obs": defaultdict(lambda: 0), } + found_categorical_obs_cols: defaultdict[str, int] = defaultdict(lambda: 0) bytes_per_obs_samples: list[float] = [] for path_or_anndata in tqdm(paths_or_anndatas, desc="Validating anndatas"): if not isinstance(path_or_anndata, ad.AnnData): @@ -284,9 +286,23 @@ def _validate_anndatas_and_maybe_get_bytes_per_row[T: zarr.Group | h5py.Group | for key in curr_keys: if not (elem_name in {"var", "obs"} and key == "_index"): key_count[key] += 1 + categorical_obs_cols_in_adata = { + col + for col in adata.obs.columns + # src_path is an annbatch-internal annotation that is always per-dataset by construction, + # and should not participate in user-facing outer-join validation. + if adata.obs[col].dtype == "category" and col != "src_path" + } + if "dataset2d_categoricals_to_convert" in adata.uns: + categorical_obs_cols_in_adata.update( + col for col in adata.uns["dataset2d_categoricals_to_convert"] if col != "src_path" + ) + for col in categorical_obs_cols_in_adata: + found_categorical_obs_cols[col] += 1 if adata.raw is not None: num_raw_in_adata += 1 - if num_raw_in_adata != (num_anndatas := len(list(paths_or_anndatas))) and num_raw_in_adata != 0: + num_anndatas = len(paths_or_anndatas) + if num_raw_in_adata != num_anndatas and num_raw_in_adata != 0: warnings.warn( f"Found raw keys not present in all anndatas {paths_or_anndatas}, consider deleting raw or moving it to a shared layer/X location via `load_adata`", stacklevel=2, @@ -298,6 +314,14 @@ def _validate_anndatas_and_maybe_get_bytes_per_row[T: zarr.Group | h5py.Group | f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_anndatas}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.", stacklevel=2, ) + categorical_obs_cols_mismatched = [ + col for col, count in found_categorical_obs_cols.items() if count != num_anndatas + ] + if len(categorical_obs_cols_mismatched) > 0: + warnings.warn( + f"Found categorical obs columns {categorical_obs_cols_mismatched} not present in all anndatas {paths_or_anndatas}; outer concatenation may introduce missing values in those columns.", + stacklevel=2, + ) return float(np.mean(bytes_per_obs_samples)) if bytes_per_obs_samples else None diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 20d6be9..ff40b10 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -1,6 +1,8 @@ from __future__ import annotations import glob +import re +import warnings from contextlib import nullcontext from typing import TYPE_CHECKING, Literal @@ -22,6 +24,10 @@ from pathlib import Path +def _assert_warning_count(caught_warnings: list[warnings.WarningMessage], match: str, count: int) -> None: + assert sum(bool(re.search(match, str(warning.message))) for warning in caught_warnings) == count + + @pytest.mark.parametrize( ["chunk_size", "expected_shard_size"], [pytest.param(3, 9, id="n_obs_not_divisible_by_chunk"), pytest.param(5, 10, id="n_obs_divisible_by_chunk")], @@ -44,7 +50,8 @@ def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", path_2 = tmp_path / "with_extra_key.h5ad" adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) - with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") DatasetCollection(tmp_path / "collection.zarr").add_adatas( [path_1, path_2], n_obs_per_chunk=5, @@ -52,6 +59,7 @@ def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", dataset_size=10, shuffle_chunk_size=10, ) + _assert_warning_count(caught_warnings, rf"Found {elem_name} keys.* not present in all anndatas", 1) def test_store_creation_no_warnings_with_custom_load(tmp_path: Path): @@ -121,7 +129,8 @@ def test_store_addition_different_keys( adata = ad.AnnData(X=np.random.randn(10, 20), **extra_args) additional_path = tmp_path / "with_extra_key.h5ad" adata.write_h5ad(additional_path) - with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") collection.add_adatas( [additional_path], load_adata=load_adata, @@ -129,6 +138,7 @@ def test_store_addition_different_keys( shard_size=10, shuffle_chunk_size=2, ) + _assert_warning_count(caught_warnings, rf"Found {elem_name} keys.* not present in all anndatas", 1) def test_h5ad_and_zarr_simultaneously(tmp_path: Path): @@ -262,7 +272,7 @@ def _write_groupby_test_adata( obs["label"] = pd.Categorical(label_values, categories=label_categories) ad.AnnData( X=sp.csr_matrix(np.eye(n_obs, dtype="f4")), - obs=pd.DataFrame(obs, index=[f"cell_{i}" for i in range(n_obs)]), + obs=pd.DataFrame(obs, index=[f"{path.stem}_cell_{i}" for i in range(n_obs)]), var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_obs)]), ).write_h5ad(path, compression=None) return path @@ -287,6 +297,168 @@ def consistent_groupby_h5_paths(tmp_path: Path) -> list[Path]: ] +def test_store_creation_warns_when_outer_join_introduces_missing_categorical_values(tmp_path: Path): + first = _write_groupby_test_adata( + tmp_path / "first.h5ad", + label_values=["a", "b", "a"], + label_categories=["a", "b"], + ) + second = _write_groupby_test_adata(tmp_path / "second.h5ad") + with warnings.catch_warnings(): + # annbatch writes Zarr v3 stores and consolidates metadata during writes, + # which currently triggers a known zarr warning unrelated to the behavior under test. + warnings.filterwarnings( + "ignore", + message="Consolidated metadata is currently not part.*", + category=UserWarning, + ) + with pytest.warns(UserWarning) as caught_warnings: + DatasetCollection(tmp_path / "collection.zarr").add_adatas( + [first, second], + n_obs_per_chunk=2, + shard_size=2, + dataset_size=3, + shuffle_chunk_size=1, + shuffle=False, + rng=np.random.default_rng(0), + ) + _assert_warning_count(caught_warnings, r"Found obs keys", 1) + _assert_warning_count(caught_warnings, r"categorical obs columns", 1) + + +def test_store_addition_warns_when_outer_join_introduces_missing_categorical_values(tmp_path: Path): + initial = _write_groupby_test_adata(tmp_path / "initial.h5ad") + additional = _write_groupby_test_adata( + tmp_path / "additional.h5ad", + label_values=["a", "b", "a"], + label_categories=["a", "b"], + ) + with warnings.catch_warnings(): + # annbatch writes Zarr v3 stores and consolidates metadata during writes, + # which currently triggers a known zarr warning unrelated to the behavior under test. + warnings.filterwarnings( + "ignore", + message="Consolidated metadata is currently not part.*", + category=UserWarning, + ) + collection = DatasetCollection(tmp_path / "collection.zarr").add_adatas( + [initial], + n_obs_per_chunk=2, + shard_size=2, + dataset_size=3, + shuffle_chunk_size=1, + shuffle=False, + rng=np.random.default_rng(0), + ) + with warnings.catch_warnings(): + # annbatch writes Zarr v3 stores and consolidates metadata during writes, + # which currently triggers a known zarr warning unrelated to the behavior under test. + warnings.filterwarnings( + "ignore", + message="Consolidated metadata is currently not part.*", + category=UserWarning, + ) + with pytest.warns(UserWarning) as caught_warnings: + collection.add_adatas( + [additional], + n_obs_per_chunk=2, + shard_size=2, + shuffle_chunk_size=1, + shuffle=False, + rng=np.random.default_rng(0), + ) + _assert_warning_count(caught_warnings, r"Found obs keys", 1) + _assert_warning_count(caught_warnings, r"categorical obs columns", 1) + + +@pytest.mark.parametrize( + ("first_kwargs", "second_kwargs"), + [ + pytest.param( + {"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]}, + {"label_values": ["b", "a", "b"], "label_categories": ["a", "b", "c"]}, + id="different_categories", + ), + pytest.param( + {"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]}, + {"label_values": ["b", "a", "b"], "label_categories": ["b", "a"]}, + id="different_category_order", + ), + ], +) +def test_store_creation_does_not_warn_for_categorical_category_expansion( + tmp_path: Path, + first_kwargs: dict, + second_kwargs: dict, +): + first = _write_groupby_test_adata(tmp_path / "first.h5ad", **first_kwargs) + second = _write_groupby_test_adata(tmp_path / "second.h5ad", **second_kwargs) + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + DatasetCollection(tmp_path / "collection.zarr").add_adatas( + [first, second], + n_obs_per_chunk=2, + shard_size=2, + dataset_size=3, + shuffle_chunk_size=1, + shuffle=False, + rng=np.random.default_rng(0), + ) + _assert_warning_count(caught_warnings, r"categorical obs columns", 0) + + +@pytest.mark.parametrize( + ("initial_kwargs", "additional_kwargs"), + [ + pytest.param( + {"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]}, + {"label_values": ["b", "a", "b"], "label_categories": ["a", "b", "c"]}, + id="different_categories", + ), + pytest.param( + {"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]}, + {"label_values": ["b", "a", "b"], "label_categories": ["b", "a"]}, + id="different_category_order", + ), + ], +) +def test_store_addition_does_not_warn_for_categorical_category_expansion( + tmp_path: Path, + initial_kwargs: dict, + additional_kwargs: dict, +): + initial = _write_groupby_test_adata(tmp_path / "initial.h5ad", **initial_kwargs) + additional = _write_groupby_test_adata(tmp_path / "additional.h5ad", **additional_kwargs) + with warnings.catch_warnings(): + # annbatch writes Zarr v3 stores and consolidates metadata during writes, + # which currently triggers a known zarr warning unrelated to the behavior under test. + warnings.filterwarnings( + "ignore", + message="Consolidated metadata is currently not part.*", + category=UserWarning, + ) + collection = DatasetCollection(tmp_path / "collection.zarr").add_adatas( + [initial], + n_obs_per_chunk=2, + shard_size=2, + dataset_size=3, + shuffle_chunk_size=1, + shuffle=False, + rng=np.random.default_rng(0), + ) + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + collection.add_adatas( + [additional], + n_obs_per_chunk=2, + shard_size=2, + shuffle_chunk_size=1, + shuffle=False, + rng=np.random.default_rng(0), + ) + _assert_warning_count(caught_warnings, r"categorical obs columns", 0) + + @pytest.mark.parametrize( ("groupby", "match"), [