Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning][].

### Features
- Added `groupby` support to {meth}`annbatch.DatasetCollection.add_adatas` to group observations per dataset before writing collections. When appending to an existing on-disk collection, groupby columns must already exist and categorical categories must be identical to those on-disk.
- Warn when building or extending a {class}`annbatch.DatasetCollection` if outer concatenation would introduce missing values into categorical `obs` columns because those columns are absent from some inputs.

## [0.1.3]

Expand Down
26 changes: 25 additions & 1 deletion src/annbatch/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,14 @@ def _validate_anndatas_and_maybe_get_bytes_per_row[T: zarr.Group | h5py.Group |
-------
The average bytes per observation row when *estimate_bytes_per_obs_row* is ``True``, otherwise ``None``.
"""
paths_or_anndatas = list(paths_or_anndatas)
num_raw_in_adata = 0
found_keys: dict[str, defaultdict[str, int]] = {
"layers": defaultdict(lambda: 0),
"obsm": defaultdict(lambda: 0),
"obs": defaultdict(lambda: 0),
}
found_categorical_obs_cols: defaultdict[str, int] = defaultdict(lambda: 0)
bytes_per_obs_samples: list[float] = []
for path_or_anndata in tqdm(paths_or_anndatas, desc="Validating anndatas"):
if not isinstance(path_or_anndata, ad.AnnData):
Expand All @@ -284,9 +286,23 @@ def _validate_anndatas_and_maybe_get_bytes_per_row[T: zarr.Group | h5py.Group |
for key in curr_keys:
if not (elem_name in {"var", "obs"} and key == "_index"):
key_count[key] += 1
Comment on lines 286 to 288
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just get away with adding obs and var to this check. Why not? It would would go beyond categoricals, but I'm not sure that is so bad. I am not sure why obs and var are excluded. This check is relatively naive as well so we could also add some behavior around e.g., mismatched dtypes. But that appears out-of-scope for this PR, aside from maybe just adding the special warning about categoricals.

categorical_obs_cols_in_adata = {
col
for col in adata.obs.columns
# src_path is an annbatch-internal annotation that is always per-dataset by construction,
# and should not participate in user-facing outer-join validation.
if adata.obs[col].dtype == "category" and col != "src_path"
}
if "dataset2d_categoricals_to_convert" in adata.uns:
categorical_obs_cols_in_adata.update(
col for col in adata.uns["dataset2d_categoricals_to_convert"] if col != "src_path"
)
for col in categorical_obs_cols_in_adata:
found_categorical_obs_cols[col] += 1
if adata.raw is not None:
num_raw_in_adata += 1
if num_raw_in_adata != (num_anndatas := len(list(paths_or_anndatas))) and num_raw_in_adata != 0:
num_anndatas = len(paths_or_anndatas)
if num_raw_in_adata != num_anndatas and num_raw_in_adata != 0:
warnings.warn(
f"Found raw keys not present in all anndatas {paths_or_anndatas}, consider deleting raw or moving it to a shared layer/X location via `load_adata`",
stacklevel=2,
Expand All @@ -298,6 +314,14 @@ def _validate_anndatas_and_maybe_get_bytes_per_row[T: zarr.Group | h5py.Group |
f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_anndatas}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.",
stacklevel=2,
)
categorical_obs_cols_mismatched = [
col for col, count in found_categorical_obs_cols.items() if count != num_anndatas
]
if len(categorical_obs_cols_mismatched) > 0:
warnings.warn(
f"Found categorical obs columns {categorical_obs_cols_mismatched} not present in all anndatas {paths_or_anndatas}; outer concatenation may introduce missing values in those columns.",
stacklevel=2,
)
return float(np.mean(bytes_per_obs_samples)) if bytes_per_obs_samples else None


Expand Down
178 changes: 175 additions & 3 deletions tests/test_preshuffle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import glob
import re
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING, Literal

Expand All @@ -22,6 +24,10 @@
from pathlib import Path


def _assert_warning_count(caught_warnings: list[warnings.WarningMessage], match: str, count: int) -> None:
assert sum(bool(re.search(match, str(warning.message))) for warning in caught_warnings) == count


@pytest.mark.parametrize(
["chunk_size", "expected_shard_size"],
[pytest.param(3, 9, id="n_obs_not_divisible_by_chunk"), pytest.param(5, 10, id="n_obs_divisible_by_chunk")],
Expand All @@ -44,14 +50,16 @@ def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm",
path_2 = tmp_path / "with_extra_key.h5ad"
adata_1.write_h5ad(path_1)
adata_2.write_h5ad(path_2)
with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"):
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
Comment on lines +53 to +54
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just blanket ignoring warnings? Seems bad

DatasetCollection(tmp_path / "collection.zarr").add_adatas(
[path_1, path_2],
n_obs_per_chunk=5,
shard_size=10,
dataset_size=10,
shuffle_chunk_size=10,
)
_assert_warning_count(caught_warnings, rf"Found {elem_name} keys.* not present in all anndatas", 1)


def test_store_creation_no_warnings_with_custom_load(tmp_path: Path):
Expand Down Expand Up @@ -121,14 +129,16 @@ def test_store_addition_different_keys(
adata = ad.AnnData(X=np.random.randn(10, 20), **extra_args)
additional_path = tmp_path / "with_extra_key.h5ad"
adata.write_h5ad(additional_path)
with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"):
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
collection.add_adatas(
[additional_path],
load_adata=load_adata,
n_obs_per_chunk=5,
shard_size=10,
shuffle_chunk_size=2,
)
_assert_warning_count(caught_warnings, rf"Found {elem_name} keys.* not present in all anndatas", 1)


def test_h5ad_and_zarr_simultaneously(tmp_path: Path):
Expand Down Expand Up @@ -262,7 +272,7 @@ def _write_groupby_test_adata(
obs["label"] = pd.Categorical(label_values, categories=label_categories)
ad.AnnData(
X=sp.csr_matrix(np.eye(n_obs, dtype="f4")),
obs=pd.DataFrame(obs, index=[f"cell_{i}" for i in range(n_obs)]),
obs=pd.DataFrame(obs, index=[f"{path.stem}_cell_{i}" for i in range(n_obs)]),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_obs)]),
).write_h5ad(path, compression=None)
return path
Expand All @@ -287,6 +297,168 @@ def consistent_groupby_h5_paths(tmp_path: Path) -> list[Path]:
]


def test_store_creation_warns_when_outer_join_introduces_missing_categorical_values(tmp_path: Path):
first = _write_groupby_test_adata(
tmp_path / "first.h5ad",
label_values=["a", "b", "a"],
label_categories=["a", "b"],
)
second = _write_groupby_test_adata(tmp_path / "second.h5ad")
with warnings.catch_warnings():
# annbatch writes Zarr v3 stores and consolidates metadata during writes,
# which currently triggers a known zarr warning unrelated to the behavior under test.
warnings.filterwarnings(
"ignore",
message="Consolidated metadata is currently not part.*",
category=UserWarning,
)
with pytest.warns(UserWarning) as caught_warnings:
DatasetCollection(tmp_path / "collection.zarr").add_adatas(
[first, second],
n_obs_per_chunk=2,
shard_size=2,
dataset_size=3,
shuffle_chunk_size=1,
shuffle=False,
rng=np.random.default_rng(0),
)
_assert_warning_count(caught_warnings, r"Found obs keys", 1)
_assert_warning_count(caught_warnings, r"categorical obs columns", 1)


def test_store_addition_warns_when_outer_join_introduces_missing_categorical_values(tmp_path: Path):
initial = _write_groupby_test_adata(tmp_path / "initial.h5ad")
additional = _write_groupby_test_adata(
tmp_path / "additional.h5ad",
label_values=["a", "b", "a"],
label_categories=["a", "b"],
)
with warnings.catch_warnings():
# annbatch writes Zarr v3 stores and consolidates metadata during writes,
# which currently triggers a known zarr warning unrelated to the behavior under test.
warnings.filterwarnings(
"ignore",
message="Consolidated metadata is currently not part.*",
category=UserWarning,
)
collection = DatasetCollection(tmp_path / "collection.zarr").add_adatas(
[initial],
n_obs_per_chunk=2,
shard_size=2,
dataset_size=3,
shuffle_chunk_size=1,
shuffle=False,
rng=np.random.default_rng(0),
)
with warnings.catch_warnings():
# annbatch writes Zarr v3 stores and consolidates metadata during writes,
# which currently triggers a known zarr warning unrelated to the behavior under test.
warnings.filterwarnings(
"ignore",
message="Consolidated metadata is currently not part.*",
category=UserWarning,
)
with pytest.warns(UserWarning) as caught_warnings:
collection.add_adatas(
[additional],
n_obs_per_chunk=2,
shard_size=2,
shuffle_chunk_size=1,
shuffle=False,
rng=np.random.default_rng(0),
)
_assert_warning_count(caught_warnings, r"Found obs keys", 1)
_assert_warning_count(caught_warnings, r"categorical obs columns", 1)


@pytest.mark.parametrize(
("first_kwargs", "second_kwargs"),
[
pytest.param(
{"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]},
{"label_values": ["b", "a", "b"], "label_categories": ["a", "b", "c"]},
id="different_categories",
),
pytest.param(
{"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]},
{"label_values": ["b", "a", "b"], "label_categories": ["b", "a"]},
id="different_category_order",
),
],
)
def test_store_creation_does_not_warn_for_categorical_category_expansion(
tmp_path: Path,
first_kwargs: dict,
second_kwargs: dict,
):
first = _write_groupby_test_adata(tmp_path / "first.h5ad", **first_kwargs)
second = _write_groupby_test_adata(tmp_path / "second.h5ad", **second_kwargs)
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
DatasetCollection(tmp_path / "collection.zarr").add_adatas(
[first, second],
n_obs_per_chunk=2,
shard_size=2,
dataset_size=3,
shuffle_chunk_size=1,
shuffle=False,
rng=np.random.default_rng(0),
)
_assert_warning_count(caught_warnings, r"categorical obs columns", 0)


@pytest.mark.parametrize(
("initial_kwargs", "additional_kwargs"),
[
pytest.param(
{"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]},
{"label_values": ["b", "a", "b"], "label_categories": ["a", "b", "c"]},
id="different_categories",
),
pytest.param(
{"label_values": ["a", "b", "a"], "label_categories": ["a", "b"]},
{"label_values": ["b", "a", "b"], "label_categories": ["b", "a"]},
id="different_category_order",
),
],
)
def test_store_addition_does_not_warn_for_categorical_category_expansion(
tmp_path: Path,
initial_kwargs: dict,
additional_kwargs: dict,
):
initial = _write_groupby_test_adata(tmp_path / "initial.h5ad", **initial_kwargs)
additional = _write_groupby_test_adata(tmp_path / "additional.h5ad", **additional_kwargs)
with warnings.catch_warnings():
# annbatch writes Zarr v3 stores and consolidates metadata during writes,
# which currently triggers a known zarr warning unrelated to the behavior under test.
warnings.filterwarnings(
"ignore",
message="Consolidated metadata is currently not part.*",
category=UserWarning,
)
collection = DatasetCollection(tmp_path / "collection.zarr").add_adatas(
[initial],
n_obs_per_chunk=2,
shard_size=2,
dataset_size=3,
shuffle_chunk_size=1,
shuffle=False,
rng=np.random.default_rng(0),
)
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
collection.add_adatas(
[additional],
n_obs_per_chunk=2,
shard_size=2,
shuffle_chunk_size=1,
shuffle=False,
rng=np.random.default_rng(0),
)
_assert_warning_count(caught_warnings, r"categorical obs columns", 0)


@pytest.mark.parametrize(
("groupby", "match"),
[
Expand Down
Loading