From 28090444a0f4733f1f665488816c2470004f3bc6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Jan 2026 16:34:02 +0100 Subject: [PATCH 01/39] feat: move to a class based API --- src/annbatch/__init__.py | 4 +- src/annbatch/io.py | 520 +++++++++--------- ...t_store_creation.py => test_preshuffle.py} | 91 ++- 3 files changed, 307 insertions(+), 308 deletions(-) rename tests/{test_store_creation.py => test_preshuffle.py} (83%) diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 306cbb99..fd9ec611 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -3,9 +3,9 @@ from importlib.metadata import version from . import types -from .io import add_to_collection, create_anndata_collection, write_sharded +from .io import PreShuffledCollection, write_sharded from .loader import Loader __version__ = version("annbatch") -__all__ = ["Loader", "write_sharded", "add_to_collection", "create_anndata_collection", "types"] +__all__ = ["Loader", "write_sharded", "PreShuffledCollection", "types"] diff --git a/src/annbatch/io.py b/src/annbatch/io.py index f197cd14..fbb7ad04 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -1,7 +1,7 @@ from __future__ import annotations -import json import random +import re import warnings from collections import defaultdict from functools import wraps @@ -10,6 +10,7 @@ import anndata as ad import dask.array as da +import h5py import numpy as np import pandas as pd import scipy.sparse as sp @@ -40,6 +41,7 @@ def write_sharded( dense_chunk_size: int = 1024, dense_shard_size: int = 4194304, compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + key: str | None = None, ): """Write a sharded zarr store from a single AnnData object. 
@@ -99,11 +101,13 @@ def callback( } write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs) - ad.experimental.write_dispatched(group, "/", adata, callback=callback) + ad.experimental.write_dispatched(group, "/" if key is None else key, adata, callback=callback) zarr.consolidate_metadata(group.store) -def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData] | Iterable[str | ad.AnnData]): +def _check_for_mismatched_keys( + paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData | zarr.Group | h5py.Group] | Iterable[str | ad.AnnData], +): num_raw_in_adata = 0 found_keys: dict[str, defaultdict[str, int]] = { "layers": defaultdict(lambda: 0), @@ -118,7 +122,7 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An for elem_name, key_count in found_keys.items(): curr_keys = set(getattr(adata, elem_name).keys()) for key in curr_keys: - if not (elem_name in { "var", "obs" } and key == "_index"): + if not (elem_name in {"var", "obs"} and key == "_index"): key_count[key] += 1 if adata.raw is not None: num_raw_in_adata += 1 @@ -140,7 +144,9 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An def _lazy_load_anndatas( paths: Iterable[PathLike[str]] | Iterable[str], - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy(x, load_annotation_index=False), + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), ): adatas = [] categoricals_in_all_adatas = {} @@ -167,6 +173,11 @@ def _lazy_load_anndatas( categoricals_in_all_adatas[k] = set(categoricals_in_all_adatas[k]).union( set(categorical_cols_in_this_adata[k]) ) + # TODO: Probably bug in anndata, need the true index for proper outer joins (can't skirt this with fake indexes, at least not in the mixed-type regime). 
+ if isinstance(adata.var, Dataset2D): + adata.var.index = adata.var.true_index + if adata.raw is not None and isinstance(adata.raw.var, Dataset2D): + adata.raw.var.index = adata.raw.var.true_index adatas.append(adata) if len(adatas) == 1: return adatas[0] @@ -210,9 +221,14 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata.X = _compute_blockwise(adata.X) if isinstance(adata.obs, Dataset2D): adata.obs = adata.obs.to_memory() + # TODO: This is a bug in anndata? + if "_index" in adata.obs.columns: + del adata.obs["_index"] adata = _to_categorical_obs(adata) if isinstance(adata.var, Dataset2D): adata.var = adata.var.to_memory() + if "_index" in adata.var.columns: + del adata.var["_index"] if adata.raw is not None: adata_raw = adata.raw.to_adata() @@ -220,6 +236,8 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata_raw.X = _compute_blockwise(adata_raw.X) if isinstance(adata_raw.var, Dataset2D): adata_raw.var = adata_raw.var.to_memory() + if "_index" in adata_raw.var.columns: + del adata_raw.var["_index"] if isinstance(adata_raw.obs, Dataset2D): adata_raw.obs = adata_raw.obs.to_memory() del adata.raw @@ -249,259 +267,261 @@ def wrapper(*args, **kwargs): return wrapper -@_with_settings -def create_anndata_collection( - adata_paths: Iterable[PathLike[str]] | Iterable[str], - output_path: PathLike[str] | str, - *, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy, - var_subset: Iterable[str] | None = None, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - n_obs_per_dataset: int = 2_097_152, - shuffle: bool = True, - should_denseify: bool = False, - output_format: Literal["h5ad", "zarr"] = "zarr", -): - """Take 
AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. +class PreShuffledCollection[T: h5py.Group | zarr.Group]: + """A preshuffled collection object including functionality for creating, adding to, and loading collections.""" - The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. - The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. - However, this function can also output h5 datasets and also unshuffled datasets as well. - The var space is by default outer-joined, but can be subsetted by `var_subset`. - A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. - Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. + _group: T - Parameters - ---------- - adata_paths - Paths to the AnnData files used to create the zarr store. - output_path - Path to the output zarr store. - load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. - var_subset - Subset of gene names to include in the store. If None, all genes are included. - Genes are subset based on the `var_names` attribute of the concatenated AnnData object. 
- zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - h5ad_compressor - Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. - n_obs_per_dataset - Number of observations to load into memory at once for shuffling / pre-processing. - The higher this number, the more memory is used, but the better the shuffling. - This corresponds to the size of the shards created. - shuffle - Whether to shuffle the data before writing it to the store. - should_denseify - Whether to write as dense on disk. There's no need to set this for sparse data, it is only for testing. - output_format - Format of the output store. Can be either "zarr" or "h5ad". - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import create_anndata_collection - # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store - >>> def read_lazy_x_and_obs_only(path): - ... adata = ad.experimental.read_lazy(path) - ... return ad.AnnData( - ... X=adata.X, - ... obs=adata.obs.to_memory(), - ... var=adata.var.to_memory(), - ...) - - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> create_anndata_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... load_adata=read_lazy_x_and_obs_only, - ...) 
- """ - Path(output_path).mkdir(parents=True, exist_ok=True) - _check_for_mismatched_keys(adata_paths) - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - adata_concat.obs_names_make_unique() - chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle) - - if var_subset is None: - var_subset = adata_concat.var_names - - for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): - var_mask = adata_concat.var_names.isin(var_subset) - # np.sort: It's more efficient to access elements sequentially from dask arrays - # The data will be shuffled later on, we just want the elements at this point - adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() - adata_chunk = _persist_adata_in_memory(adata_chunk) - if shuffle: - # shuffle adata in memory to break up individual chunks - idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) - adata_chunk = adata_chunk[idxs] - # convert to dense format before writing to disk - if should_denseify: - # Need to convert back to dask array to avoid memory issues when converting large sparse matrices to dense - adata_chunk = adata_chunk.copy() - adata_chunk.X = da.from_array( - adata_chunk.X, chunks=(zarr_dense_chunk_size, -1), meta=adata_chunk.X - ).map_blocks(lambda xx: xx.toarray(), dtype=adata_chunk.X.dtype) - - if output_format == "zarr": - f = zarr.open_group(Path(output_path) / f"{DATASET_PREFIX}_{i}.zarr", mode="w") - write_sharded( - f, - adata_chunk, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, - compressors=zarr_compressor, - ) - elif output_format == "h5ad": - adata_chunk.write_h5ad(Path(output_path) / f"{DATASET_PREFIX}_{i}.h5ad", compression=h5ad_compressor) - else: - raise ValueError(f"Unrecognized output_format: {output_format}. 
Only 'zarr' and 'h5ad' are supported.") + def __init__(self, group: T | str | Path, *, mode: Literal["a", "r", "r+"] = "a"): + """Initialization of the object at a given location. + Note that if the group is a h5py/zarr object, it must have the correct permissions for any subsequent operations you plan to do. + Otherwise, the store will be opened according to the mode argument. -def _get_array_encoding_type(path: PathLike[str] | str) -> str: - shards = list(Path(path).glob(f"{DATASET_PREFIX}_*.zarr")) - with open(shards[0] / "X" / "zarr.json") as f: - encoding = json.load(f) - return encoding["attributes"]["encoding-type"] + Parameters + ---------- + group + The base location for a preshuffled collection + """ + if not isinstance(group, h5py.Group | zarr.Group): + if isinstance(group, str | Path): + if str(group).endswith("h5ad"): + self._group = h5py.File(group, mode=mode) + elif str(group).endswith("zarr"): + self._group = zarr.open_group(group, mode=mode) + else: + raise ValueError("String argument must end in h5ad or zarr") + else: + raise TypeError("group must be a zarr or hdf5 group") + else: + self._group = group -@_with_settings -def add_to_collection( - adata_paths: Iterable[PathLike[str]] | Iterable[str], - output_path: PathLike[str] | str, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - should_sparsify_output_in_memory: bool = False, -) -> None: - """Add anndata files to an existing collection of sharded anndata zarr datasets. - - The var space of the source anndata files will be adapted to the target store. 
+ @property + def _dataset_keys(self) -> list[str]: + return [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None] - Parameters - ---------- - adata_paths - Paths to the anndata files to be appended to the collection of output chunks. - output_path - Path to the output zarr store. - load_adata - Function to customize loading the invidiual input anndata files. By default, {func}`anndata.read_h5ad` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object. - If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - should_sparsify_output_in_memory - This option is for testing only appending sparse files to dense stores. - To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import add_to_collection - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> add_to_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... 
load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory - ...) - """ - shards = list(Path(output_path).glob(f"{DATASET_PREFIX}_*.zarr")) - if len(shards) == 0: - raise ValueError( - "Store at `output_path` does not exist or is empty. Please run `create_anndata_collection` first." + @property + def is_empty(self) -> bool: + """Whether or not there is an existing store at the group location.""" + return "annbatch-shuffled" not in self._group.attrs or ( + not self._group.attrs["annbatch-shuffled"] and len(self._dataset_keys) == 0 ) - encoding = _get_array_encoding_type(output_path) - if encoding == "array": - print("Detected array encoding type. Will convert to dense format before writing.") - # Check for mismatched keys among the inputs. - _check_for_mismatched_keys(adata_paths) - - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - # Check for mismatched keys between shards and the inputs. - _check_for_mismatched_keys([adata_concat] + shards) - if isinstance(adata_concat.X, DaskArray): - chunks = _create_chunks_for_shuffling(adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True) - else: - chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards)) - - adata_concat.obs_names_make_unique() - if encoding == "array": - if not should_sparsify_output_in_memory: - if isinstance(adata_concat.X, sp.spmatrix): - adata_concat.X = adata_concat.X.toarray() - elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, sp.spmatrix): - adata_concat.X = adata_concat.X.map_blocks( - lambda x: x.toarray(), meta=np.ndarray, dtype=adata_concat.X.dtype + + @_with_settings + def create_anndata_collection( + self, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), + var_subset: Iterable[str] | None = None, + 
zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle: bool = True, + ): + """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. + + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. + The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets as well as unshuffled datasets. + The var space is by default outer-joined, but can be subsetted by `var_subset`. + A key `src_path` is added to `obs` to indicate where each individual row came from. + We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. + + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + output_path + Path to the output zarr store. + load_adata + Function to customize lazy-loading the individual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. + var_subset + Subset of gene names to include in the store. 
If None, all genes are included. + Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + shuffle + Whether to shuffle the data before writing it to the store. + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import PreShuffledCollection + # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store + >>> def read_lazy_x_and_obs_only(path): + ... adata = ad.experimental.read_lazy(path) + ... return ad.AnnData( + ... X=adata.X, + ... obs=adata.obs.to_memory(), + ... var=adata.var.to_memory(), + ...) + + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... ] + >>> PreShuffledCollection("path/to/output/zarr_store.zarr").create_anndata_collection( + ... datasets, + ... load_adata=read_lazy_x_and_obs_only, + ...) 
+ """ + if not self.is_empty: + raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") + _check_for_mismatched_keys(adata_paths) + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + adata_concat.obs_names_make_unique() + chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle) + + if var_subset is None: + var_subset = adata_concat.var_names + + for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): + var_mask = adata_concat.var_names.isin(var_subset) + # np.sort: It's more efficient to access elements sequentially from dask arrays + # The data will be shuffled later on, we just want the elements at this point + adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() + adata_chunk = _persist_adata_in_memory(adata_chunk) + if shuffle: + # shuffle adata in memory to break up individual chunks + idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) + adata_chunk = adata_chunk[idxs] + if isinstance(self._group, zarr.Group): + write_sharded( + self._group, + adata_chunk, + sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=zarr_dense_chunk_size, + dense_shard_size=zarr_dense_shard_size, + compressors=zarr_compressor, + key=f"{DATASET_PREFIX}_{i}", + ) + else: + ad.io.write_elem( + self._group, f"{DATASET_PREFIX}_{i}", adata_chunk, dataset_kwargs={"compression": h5ad_compressor} ) - elif encoding == "csr_matrix": - if isinstance(adata_concat.X, np.ndarray): - adata_concat.X = sp.csr_matrix(adata_concat.X) - elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, np.ndarray): - adata_concat.X = adata_concat.X.map_blocks( - sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype)) + if isinstance(self._group, zarr.Group): + self._group.update_attributes({"annbatch-shuffled": True}) + else: + self._group.attrs["annbatch-shuffled"] = True + + 
@_with_settings + def add_to_collection( + self, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + ) -> None: + """Add anndata files to an existing collection of sharded anndata zarr datasets. + + The var space of the source anndata files will be adapted to the target store. + + Parameters + ---------- + adata_paths + Paths to the anndata files to be appended to the collection of output chunks. + load_adata + Function to customize loading the individual input anndata files. By default, {func}`anndata.read_h5ad` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object. + If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. 
+ h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + Only used when the collection is backed by an HDF5 group rather than a zarr group. + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import PreShuffledCollection + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... ] + >>> collection = PreShuffledCollection("path/to/output/zarr_store.zarr") + >>> collection.add_to_collection( + ... datasets, + ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory + ...) + """ + if self.is_empty: + raise ValueError("Store is empty. Please run `PreShuffledCollection.create_anndata_collection` first.") + # Check for mismatched keys among the inputs. + _check_for_mismatched_keys(adata_paths) + + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + # Check for mismatched keys between datasets and the inputs. + _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) + if isinstance(adata_concat.X, DaskArray): + chunks = _create_chunks_for_shuffling( + adata_concat, np.ceil(len(adata_concat) / len(self._dataset_keys)), shuffle=True ) - - for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards), desc="processing chunks"): - if should_sparsify_output_in_memory and encoding == "array": - adata_shard = _lazy_load_anndatas([shard]) - adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute() else: - adata_shard = ad.read_zarr(shard) - subset_adata = _to_categorical_obs( - adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)] - ) - adata = ad.concat([adata_shard, subset_adata], join="outer") - idxs_shuffled = np.random.default_rng().permutation(len(adata)) - adata = adata[idxs_shuffled, :].copy() # this significantly speeds up writing to disk - if should_sparsify_output_in_memory and encoding == "array": - adata.X = 
adata.X.map_blocks(lambda x: x.toarray(), meta=np.array([0], dtype=adata.X.dtype)).compute() - - f = zarr.open_group(shard, mode="w") - write_sharded( - f, - adata, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, - compressors=zarr_compressor, - ) + chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(self._dataset_keys)) + + adata_concat.obs_names_make_unique() + + for dataset, chunk in tqdm( + zip(self._dataset_keys, chunks, strict=False), total=len(self._dataset_keys), desc="processing chunks" + ): + adata_dataset = ad.io.read_elem(self._group[dataset]) + subset_adata = _to_categorical_obs( + adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)] + ) + adata = ad.concat([adata_dataset, subset_adata], join="outer") + idxs_shuffled = np.random.default_rng().permutation(len(adata)) + adata = _persist_adata_in_memory(adata[idxs_shuffled, :]) + if isinstance(self._group, zarr.Group): + write_sharded( + self._group, + adata, + sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=zarr_dense_chunk_size, + dense_shard_size=zarr_dense_shard_size, + compressors=zarr_compressor, + key=dataset, + ) + else: + ad.io.write_elem(self._group, dataset, adata, dataset_kwargs={"compression": h5ad_compressor}) diff --git a/tests/test_store_creation.py b/tests/test_preshuffle.py similarity index 83% rename from tests/test_store_creation.py rename to tests/test_preshuffle.py index 5e9f099c..24663b2f 100644 --- a/tests/test_store_creation.py +++ b/tests/test_preshuffle.py @@ -10,7 +10,7 @@ import scipy.sparse as sp import zarr -from annbatch import add_to_collection, create_anndata_collection, write_sharded +from annbatch import PreShuffledCollection, write_sharded if TYPE_CHECKING: from collections.abc import Callable @@ -48,9 +48,8 @@ def 
test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "l adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - create_anndata_collection( + PreShuffledCollection(tmp_path / "collection.zarr").create_anndata_collection( [path_1, path_2], - tmp_path / "collection", zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, @@ -67,10 +66,9 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) paths = [path_1, path_2] - output_dir = tmp_path / "path_src_collection" - create_anndata_collection( + output_dir = tmp_path / "path_src_collection.zarr" + PreShuffledCollection(output_dir).create_anndata_collection( paths, - output_dir, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, @@ -78,7 +76,9 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): n_obs_per_dataset=10, shuffle=False, ) - adata_result = ad.concat([ad.read_zarr(path) for path in sorted((output_dir).iterdir())], join="outer") + adata_result = ad.concat( + [ad.read_zarr(path) for path in sorted((output_dir).iterdir()) if path.is_dir()], join="outer" + ) pd.testing.assert_extension_array_equal( adata_result.obs["src_path"].array, pd.Categorical(([str(path_1)] * 10) + ([str(path_2)] * 10), categories=[str(p) for p in paths]), @@ -95,11 +95,10 @@ def test_store_addition_different_keys( adata_orig = ad.AnnData(X=np.random.randn(100, 20)) orig_path = tmp_path / "orig.h5ad" adata_orig.write_h5ad(orig_path) - output_path = tmp_path / "zarr_store_addition_different_keys" - output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( + output_path = tmp_path / "zarr_store_addition_different_keys.zarr" + collection = PreShuffledCollection(output_path) + collection.create_anndata_collection( [orig_path], - output_path, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, 
zarr_dense_chunk_size=10, @@ -113,9 +112,8 @@ def test_store_addition_different_keys( additional_path = tmp_path / "with_extra_key.h5ad" adata.write_h5ad(additional_path) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - add_to_collection( + collection.add_to_collection( [additional_path], - output_path, load_adata=load_adata, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -145,11 +143,9 @@ def test_store_creation_default( ): var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default" - output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( + output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default.zarr" + PreShuffledCollection(output_path).create_anndata_collection( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - output_path, var_subset=var_subset, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -157,8 +153,10 @@ def test_store_creation_default( zarr_dense_shard_size=20, n_obs_per_dataset=60, ) - assert isinstance(ad.read_zarr(next((output_path).iterdir())).X, sp.csr_matrix) - assert sorted(glob.glob(str(output_path / "dataset_*.zarr"))) == sorted(str(p) for p in (output_path).iterdir()) + assert isinstance(ad.read_zarr(next(p for p in (output_path).iterdir() if p.is_dir())).X, sp.csr_matrix) + assert sorted(glob.glob(str(output_path / "dataset_*"))) == sorted( + str(p) for p in (output_path).iterdir() if p.is_dir() + ) def test_store_creation_drop_elem( @@ -166,12 +164,11 @@ def test_store_creation_drop_elem( ): var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems" + 
output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems.zarr" output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( + PreShuffledCollection(output_path).create_anndata_collection( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - output_path, var_subset=var_subset, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -186,19 +183,15 @@ def test_store_creation_drop_elem( @pytest.mark.parametrize("shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no_shuffle")]) -@pytest.mark.parametrize("densify", [pytest.param(True, id="densify"), pytest.param(False, id="no_densify")]) def test_store_creation( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], shuffle: bool, - densify: bool, ): var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}_{densify}" - output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( + output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}.zarr" + PreShuffledCollection(output_path).create_anndata_collection( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - output_path, var_subset=var_subset, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -206,12 +199,11 @@ def test_store_creation( zarr_dense_shard_size=10, n_obs_per_dataset=60, shuffle=shuffle, - should_denseify=densify, ) adata_orig = adata_with_h5_path_different_var_space[0] # make sure all category dtypes match - adatas_shuffled = [ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir())] + adatas_shuffled = [ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir()) if zarr_path.is_dir()] for adata in adatas_shuffled: assert 
adata.obs["label"].dtype == adata_orig.obs["label"].dtype # subset to var_subset @@ -245,12 +237,9 @@ def test_store_creation( adata.obs["label"] = adata.obs["label"].cat.reorder_categories(adata_orig.obs["label"].dtype.categories) pd.testing.assert_frame_equal(adata.obs, adata_orig.obs) - z = zarr.open(output_path / "dataset_0.zarr") + z = zarr.open(output_path / "dataset_0") assert z["obsm"]["arr"].chunks[0] == 5, z["obsm"]["arr"] - if not densify: - assert z["X"]["indices"].chunks[0] == 10 - else: - assert z["X"].chunks[0] == 5, z["X"] + assert z["X"]["indices"].chunks[0] == 10 @pytest.mark.parametrize( @@ -262,13 +251,11 @@ def test_mismatched_raw_concat( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], ): h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous" - output_path.mkdir(parents=True, exist_ok=True) + output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - create_anndata_collection( + PreShuffledCollection(output_path).create_anndata_collection( h5_paths, - output_path, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=10, @@ -291,50 +278,45 @@ def test_mismatched_raw_concat( adata_orig = ad.concat(adatas_orig, join="outer") adata_orig.obs_names_make_unique() - adata = ad.concat([ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir())]) + adata = ad.concat([ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir()) if zarr_path.is_dir()]) del adata.obs["src_path"] pd.testing.assert_frame_equal(adata_orig.var, adata.var) pd.testing.assert_frame_equal(adata_orig.obs, adata.obs) 
np.testing.assert_array_equal(adata_orig.X.toarray(), adata.X.toarray()) -@pytest.mark.parametrize("densify", [True, False]) @pytest.mark.parametrize("load_adata", [ad.read_h5ad, ad.experimental.read_lazy]) def test_store_extension( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], - densify: bool, load_adata: Callable[[PathLike[str] | str], ad.AnnData], ): - all_h5_paths = sorted(adata_with_h5_path_different_var_space[1].iterdir()) + all_h5_paths = sorted(p for p in adata_with_h5_path_different_var_space[1].iterdir() if p.suffix == ".h5ad") store_path = ( - adata_with_h5_path_different_var_space[1].parent / f"zarr_store_extension_test_{densify}_{load_adata.__name__}" + adata_with_h5_path_different_var_space[1].parent / f"zarr_store_extension_test_{load_adata.__name__}.zarr" ) original = all_h5_paths additional = all_h5_paths[4:] # don't add everything to get a "different" var space # create new store - create_anndata_collection( + collection = PreShuffledCollection(store_path) + collection.create_anndata_collection( original, - store_path, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=60, shuffle=True, - should_denseify=densify, ) # add h5ads to existing store - add_to_collection( + collection.add_to_collection( additional, - store_path, load_adata=load_adata, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, ) - - adatas_on_disk = [ad.read_zarr(zarr_path) for zarr_path in sorted(store_path.iterdir())] + adatas_on_disk = [ad.read_zarr(zarr_path) for zarr_path in sorted(store_path.iterdir()) if zarr_path.is_dir()] adata = ad.concat(adatas_on_disk) adata_orig = adata_with_h5_path_different_var_space[0] expected_adata = ad.concat([adata_orig, adata_orig[adata_orig.obs["store_id"] >= 4]], join="outer") @@ -344,9 +326,6 @@ def test_store_extension( for a in [*adatas_on_disk, adata]: assert a.obs["label"].dtype == 
expected_adata.obs["label"].dtype assert "arr" in adata.obsm - z = zarr.open(store_path / "dataset_0.zarr") + z = zarr.open(store_path / "dataset_0") assert z["obsm"]["arr"].chunks == (5, z["obsm"]["arr"].shape[1]) - if not densify: - assert z["X"]["indices"].chunks[0] == 10 - else: - assert z["X"].chunks == (5, z["X"].shape[1]) + assert z["X"]["indices"].chunks[0] == 10 From 94ed2ae62f33cc8e77bae97fa8cbd9ec7c9bad4b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Jan 2026 16:52:58 +0100 Subject: [PATCH 02/39] fix: update docs --- README.md | 7 ++++--- docs/index.md | 3 +-- docs/notebooks/example.ipynb | 17 ++++++----------- src/annbatch/io.py | 13 +++++-------- tests/test_preshuffle.py | 12 ++++++++++++ 5 files changed, 28 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index b7a4d116..4dd8e033 100644 --- a/README.md +++ b/README.md @@ -81,13 +81,14 @@ from pathlib import Path zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) - -create_anndata_collection( +# a directory containing `dataset_{i}.zarr` +collection = PreShuffledCollection("path/to/output/collection.zarr") +collection.create_anndata_collection( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - output_path="path/to/output/collection", # a directory containing `dataset_{i}.zarr` + output_path=, shuffle=True, # shuffling is needed if you want to use chunked access ) ``` diff --git a/docs/index.md b/docs/index.md index 46bf0470..6abbf53c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,12 +9,11 @@ Let's go through the above example: ### Preprocessing ```python -create_anndata_collection( +PreShuffledCollection("path/to/output/store.zarr").create_anndata_collection( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - output_path="path/to/output/store", # a directory containing `chunk_{i}.zarr` shuffle=True, # shuffling is needed if you want to use chunked access ) ``` diff --git a/docs/notebooks/example.ipynb 
b/docs/notebooks/example.ipynb index 3f8aef26..6dbe73bf 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -149,7 +149,7 @@ " * We recommend to choose a dataset size that comfortably fits into system memory.\n", "\n", "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `create_anndata_collection`" + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `PreShuffledCollection.create_anndata_collection`" ] }, { @@ -163,7 +163,7 @@ "outputs": [], "source": [ "import anndata as ad\n", - "from annbatch import create_anndata_collection\n", + "from annbatch import PreShuffledCollection\n", "\n", "\n", "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", @@ -186,15 +186,14 @@ " )\n", "\n", "\n", - "create_anndata_collection(\n", + "collection = PreShuffledCollection(zarr.open(\"annbatch_collection\"))\n", + "collection.create_anndata_collection(\n", " # List all the h5ad files you want to include in the collection\n", " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", " # Path to store the output collection\n", - " output_path=\"annbatch_collection\",\n", " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " should_denseify=False,\n", " load_adata=read_lazy_x_and_obs_only,\n", ")" ] @@ -326,7 +325,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "tags": [ "hide-output" @@ -350,9 +349,6 @@ } ], "source": [ - "from annbatch import add_to_collection\n", - "\n", - "\n", "def read_x_and_obs_only(path) -> ad.AnnData:\n", " \"\"\"Custom load function to only load raw counts from CxG 
data.\"\"\"\n", " # As it's only a small dataset, we can load the full dataset into memory to speed up computations\n", @@ -367,11 +363,10 @@ " return ad.AnnData(X=x, obs=adata_.obs, var=var)\n", "\n", "\n", - "add_to_collection(\n", + "collection.add_to_collection(\n", " adata_paths=[\n", " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", " ],\n", - " output_path=\"annbatch_collection\",\n", " load_adata=read_x_and_obs_only,\n", ")" ] diff --git a/src/annbatch/io.py b/src/annbatch/io.py index fbb7ad04..41a16602 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -340,8 +340,6 @@ def create_anndata_collection( ---------- adata_paths Paths to the AnnData files used to create the zarr store. - output_path - Path to the output zarr store. load_adata Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. @@ -371,7 +369,7 @@ def create_anndata_collection( Examples -------- >>> import anndata as ad - >>> from annbatch import create_anndata_collection + >>> from annbatch import PreShuffledCollection # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store >>> def read_lazy_x_and_obs_only(path): ... adata = ad.experimental.read_lazy(path) @@ -473,16 +471,15 @@ def add_to_collection( Examples -------- >>> import anndata as ad - >>> from annbatch import add_to_collection + >>> from annbatch import PreShuffledCollection >>> datasets = [ ... "path/to/first_adata.h5ad", ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> add_to_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... 
load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory + >>> PreShuffledCollection("path/to/existing/preshuffled_collection.zarr").add_to_collection( + ... datasets, + ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory ...) """ if self.is_empty: diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 24663b2f..3f24416f 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -200,6 +200,8 @@ def test_store_creation( n_obs_per_dataset=60, shuffle=shuffle, ) + assert not PreShuffledCollection(output_path).is_empty + assert zarr.open(output_path).attrs["annbatch-shuffled"] adata_orig = adata_with_h5_path_different_var_space[0] # make sure all category dtypes match @@ -329,3 +331,13 @@ def test_store_extension( z = zarr.open(store_path / "dataset_0") assert z["obsm"]["arr"].chunks == (5, z["obsm"]["arr"].shape[1]) assert z["X"]["indices"].chunks[0] == 10 + + +def test_empty(tmp_path: Path): + g = zarr.open(tmp_path / "empty.zarr") + collection = PreShuffledCollection(g) + assert collection.is_empty + # Doesn't matter what errors as long as this function runs, but not to completion + with pytest.raises(TypeError): + collection.add_to_collection() + assert "annbatch-shuffled" not in g.attrs From 5d8c640831bc9f1b926a94c8f5b3d5f0be261ee0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Jan 2026 16:54:21 +0100 Subject: [PATCH 03/39] fix: readme --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 4dd8e033..b93fd475 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,6 @@ collection.create_anndata_collection( "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - output_path=, shuffle=True, # shuffling is needed if you want to use chunked access ) ``` From 2aa6b7557316e7202bbefa99e990ed12c77b39a3 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Jan 2026 16:55:22 +0100 Subject: [PATCH 04/39] fix: 
docs --- docs/api.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/api.md b/docs/api.md index 7656af65..750b0030 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,8 +25,7 @@ :toctree: generated/ write_sharded - add_to_collection - create_anndata_collection + PreShuffledCollection ``` (types)= From 11c7f413c25a61952439c98469e2aee3373b14ec Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Jan 2026 16:57:29 +0100 Subject: [PATCH 05/39] fix: docs env --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index c3f3f96f..4acb793c 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-24.04 tools: - python: "3.12" + python: "3.14" jobs: create_environment: - asdf plugin add uv From 03dd3dd68adb40225733e198478cdf932bc3cb8e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 10:48:59 +0100 Subject: [PATCH 06/39] feat: add explicit shuffle size --- .readthedocs.yaml | 2 +- src/annbatch/io.py | 33 +++++++++++++++++++++----------- src/annbatch/utils.py | 44 ------------------------------------------- 3 files changed, 23 insertions(+), 56 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 4acb793c..c3f3f96f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-24.04 tools: - python: "3.14" + python: "3.12" jobs: create_environment: - asdf plugin add uv diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 41a16602..cea81649 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -187,16 +187,15 @@ def _lazy_load_anndatas( return adata -def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True): - chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0])) - slices = [ - slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True) - ] +def 
_create_chunks_for_shuffling( + n_obs: int, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle_slice_size: int = 1000, shuffle: bool = True +): + # this splits the array up into `shuffle_slice_size` contiguous runs + idxs = np.array_split(np.arange(n_obs), np.ceil(n_obs / shuffle_slice_size)) if shuffle: - random.shuffle(slices) - idxs = np.concatenate([np.arange(s.start, s.stop) for s in slices]) - idxs = np.array_split(idxs, np.ceil(len(idxs) / shuffle_n_obs_per_dataset)) - + random.shuffle(idxs) + idxs = np.concatenate(idxs) + idxs = np.array_split(idxs, np.ceil(n_obs / shuffle_n_obs_per_dataset)) return idxs @@ -324,6 +323,7 @@ def create_anndata_collection( zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", n_obs_per_dataset: int = 2_097_152, + shuffle_slice_size: int = 1000, shuffle: bool = True, ): """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. @@ -365,6 +365,9 @@ def create_anndata_collection( This corresponds to the size of the shards created. shuffle Whether to shuffle the data before writing it to the store. + shuffle_slice_size + How many contiguous rows to load into memory before shuffling at once. + `(shuffle_slice_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_slice_size`. 
Examples -------- @@ -394,7 +397,9 @@ def create_anndata_collection( _check_for_mismatched_keys(adata_paths) adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) adata_concat.obs_names_make_unique() - chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle) + chunks = _create_chunks_for_shuffling( + adata_concat.shape[0], n_obs_per_dataset, shuffle_slice_size, shuffle=shuffle + ) if var_subset is None: var_subset = adata_concat.var_names @@ -440,6 +445,7 @@ def add_to_collection( zarr_dense_shard_size: int = 4_194_304, zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + shuffle_slice_size: int = 1000, ) -> None: """Add anndata files to an existing collection of sharded anndata zarr datasets. @@ -467,6 +473,8 @@ def add_to_collection( should_sparsify_output_in_memory This option is for testing only appending sparse files to dense stores. To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. + shuffle_slice_size + How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. 
Examples -------- @@ -492,7 +500,10 @@ def add_to_collection( _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) if isinstance(adata_concat.X, DaskArray): chunks = _create_chunks_for_shuffling( - adata_concat, np.ceil(len(adata_concat) / len(self._dataset_keys)), shuffle=True + adata_concat.shape[0], + np.ceil(len(adata_concat) / len(self._dataset_keys)), + shuffle_slice_size, + shuffle=True, ) else: chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(self._dataset_keys)) diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index a0a827ba..8d4f8f79 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -65,50 +65,6 @@ def __iter__(self): total += gap -def sample_rows( - x_list: list[np.ndarray], - obs_list: list[np.ndarray] | None, - indices: list[np.ndarray] | None = None, - *, - shuffle: bool = True, -) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]: - """Samples rows from multiple arrays and their corresponding observation arrays. - - Parameters - ---------- - x_list - A list of numpy arrays containing the data to sample from. - obs_list - A list of numpy arrays containing the corresponding observations. - indices - the list of indexes for each element in `x_list/` - shuffle - Whether to shuffle the rows before sampling. - - Yields - ------ - tuple - A tuple containing a row from `x_list` and the corresponding row from `obs_list`. 
- """ - lengths = np.fromiter((x.shape[0] for x in x_list), dtype=int) - cum = np.concatenate(([0], np.cumsum(lengths))) - total = cum[-1] - idxs = np.arange(total) - if shuffle: - np.random.default_rng().shuffle(idxs) - arr_idxs = np.searchsorted(cum, idxs, side="right") - 1 - row_idxs = idxs - cum[arr_idxs] - for ai, ri in zip(arr_idxs, row_idxs, strict=True): - res = [ - x_list[ai][ri], - obs_list[ai][ri] if obs_list is not None else None, - ] - if indices is not None: - yield (*res, indices[ai][ri]) - else: - yield tuple(res) - - class WorkerHandle: # noqa: D101 @cached_property def _worker_info(self): From df6365d66498f0b42417fac5f2059b01a2a0adb8 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 10:50:07 +0100 Subject: [PATCH 07/39] chore: `PreShuffledCollection` -> `Collection` --- README.md | 2 +- docs/api.md | 2 +- docs/index.md | 2 +- docs/notebooks/example.ipynb | 6 +++--- src/annbatch/__init__.py | 4 ++-- src/annbatch/io.py | 12 ++++++------ tests/test_preshuffle.py | 22 +++++++++++----------- 7 files changed, 25 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index b93fd475..d74a2b1d 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) # a directory containing `dataset_{i}.zarr` -collection = PreShuffledCollection("path/to/output/collection.zarr") +collection = Collection("path/to/output/collection.zarr") collection.create_anndata_collection( adata_paths=[ "path/to/your/file1.h5ad", diff --git a/docs/api.md b/docs/api.md index 750b0030..662dd5fa 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,7 +25,7 @@ :toctree: generated/ write_sharded - PreShuffledCollection + Collection ``` (types)= diff --git a/docs/index.md b/docs/index.md index 6abbf53c..896edb26 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Let's go through the above example: ### Preprocessing ```python 
-PreShuffledCollection("path/to/output/store.zarr").create_anndata_collection( +Collection("path/to/output/store.zarr").create_anndata_collection( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 6dbe73bf..315ec408 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -149,7 +149,7 @@ " * We recommend to choose a dataset size that comfortably fits into system memory.\n", "\n", "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `PreShuffledCollection.create_anndata_collection`" + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `Collection.create_anndata_collection`" ] }, { @@ -163,7 +163,7 @@ "outputs": [], "source": [ "import anndata as ad\n", - "from annbatch import PreShuffledCollection\n", + "from annbatch import Collection\n", "\n", "\n", "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", @@ -186,7 +186,7 @@ " )\n", "\n", "\n", - "collection = PreShuffledCollection(zarr.open(\"annbatch_collection\"))\n", + "collection = Collection(zarr.open(\"annbatch_collection\"))\n", "collection.create_anndata_collection(\n", " # List all the h5ad files you want to include in the collection\n", " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index fd9ec611..3a699bb3 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -3,9 +3,9 @@ from importlib.metadata import version from . 
import types -from .io import PreShuffledCollection, write_sharded +from .io import Collection, write_sharded from .loader import Loader __version__ = version("annbatch") -__all__ = ["Loader", "write_sharded", "PreShuffledCollection", "types"] +__all__ = ["Loader", "write_sharded", "Collection", "types"] diff --git a/src/annbatch/io.py b/src/annbatch/io.py index cea81649..cc76232b 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -266,7 +266,7 @@ def wrapper(*args, **kwargs): return wrapper -class PreShuffledCollection[T: h5py.Group | zarr.Group]: +class Collection[T: h5py.Group | zarr.Group]: """A preshuffled collection object including functionality for creating, adding to, and loading collections.""" _group: T @@ -372,7 +372,7 @@ def create_anndata_collection( Examples -------- >>> import anndata as ad - >>> from annbatch import PreShuffledCollection + >>> from annbatch import Collection # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store >>> def read_lazy_x_and_obs_only(path): ... adata = ad.experimental.read_lazy(path) @@ -387,7 +387,7 @@ def create_anndata_collection( ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> PreShuffledCollection("path/to/output/zarr_store.zarr").create_anndata_collection( + >>> Collection("path/to/output/zarr_store.zarr").create_anndata_collection( ... datasets, ... load_adata=read_lazy_x_and_obs_only, ...) @@ -479,19 +479,19 @@ def add_to_collection( Examples -------- >>> import anndata as ad - >>> from annbatch import PreShuffledCollection + >>> from annbatch import Collection >>> datasets = [ ... "path/to/first_adata.h5ad", ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> PreShuffledCollection("path/to/existing/preshuffled_collection.zarr").add_to_collection( + >>> Collection("path/to/existing/preshuffled_collection.zarr").add_to_collection( ... datasets, ... 
load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory ...) """ if self.is_empty: - raise ValueError("Store is empty. Please run `PreShuffledCollection.create_anndata_collection` first.") + raise ValueError("Store is empty. Please run `Collection.create_anndata_collection` first.") # Check for mismatched keys among the inputs. _check_for_mismatched_keys(adata_paths) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 3f24416f..8a890cd7 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -10,7 +10,7 @@ import scipy.sparse as sp import zarr -from annbatch import PreShuffledCollection, write_sharded +from annbatch import Collection, write_sharded if TYPE_CHECKING: from collections.abc import Callable @@ -48,7 +48,7 @@ def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "l adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - PreShuffledCollection(tmp_path / "collection.zarr").create_anndata_collection( + Collection(tmp_path / "collection.zarr").create_anndata_collection( [path_1, path_2], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -67,7 +67,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): adata_2.write_h5ad(path_2) paths = [path_1, path_2] output_dir = tmp_path / "path_src_collection.zarr" - PreShuffledCollection(output_dir).create_anndata_collection( + Collection(output_dir).create_anndata_collection( paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -96,7 +96,7 @@ def test_store_addition_different_keys( orig_path = tmp_path / "orig.h5ad" adata_orig.write_h5ad(orig_path) output_path = tmp_path / "zarr_store_addition_different_keys.zarr" - collection = PreShuffledCollection(output_path) + collection = Collection(output_path) collection.create_anndata_collection( [orig_path], zarr_sparse_chunk_size=10, @@ -144,7 +144,7 @@ 
def test_store_creation_default( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default.zarr" - PreShuffledCollection(output_path).create_anndata_collection( + Collection(output_path).create_anndata_collection( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -167,7 +167,7 @@ def test_store_creation_drop_elem( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems.zarr" output_path.mkdir(parents=True, exist_ok=True) - PreShuffledCollection(output_path).create_anndata_collection( + Collection(output_path).create_anndata_collection( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -190,7 +190,7 @@ def test_store_creation( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}.zarr" - PreShuffledCollection(output_path).create_anndata_collection( + Collection(output_path).create_anndata_collection( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -200,7 +200,7 @@ def test_store_creation( n_obs_per_dataset=60, shuffle=shuffle, ) - assert not PreShuffledCollection(output_path).is_empty + assert not Collection(output_path).is_empty assert zarr.open(output_path).attrs["annbatch-shuffled"] adata_orig = adata_with_h5_path_different_var_space[0] @@ -256,7 +256,7 @@ def test_mismatched_raw_concat( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = 
[adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - PreShuffledCollection(output_path).create_anndata_collection( + Collection(output_path).create_anndata_collection( h5_paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -299,7 +299,7 @@ def test_store_extension( original = all_h5_paths additional = all_h5_paths[4:] # don't add everything to get a "different" var space # create new store - collection = PreShuffledCollection(store_path) + collection = Collection(store_path) collection.create_anndata_collection( original, zarr_sparse_chunk_size=10, @@ -335,7 +335,7 @@ def test_store_extension( def test_empty(tmp_path: Path): g = zarr.open(tmp_path / "empty.zarr") - collection = PreShuffledCollection(g) + collection = Collection(g) assert collection.is_empty # Doesn't matter what errors as long as this function runs, but not to completion with pytest.raises(TypeError): From 3e7fc19569619a54ad2e9e6d6af8a7578b3b8d21 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 11:11:23 +0100 Subject: [PATCH 08/39] chore: move to one `add` function --- README.md | 6 +- docs/index.md | 2 +- docs/notebooks/example.ipynb | 6 +- src/annbatch/io.py | 129 ++++++++++++++++++++++++++++++++--- tests/test_preshuffle.py | 22 +++--- 5 files changed, 138 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index d74a2b1d..84557c25 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ For a detailed tutorial, please see the [in-depth section of our docs][] Basic preprocessing: ```python -from annbatch import create_anndata_collection +from annbatch import Collection import zarr from pathlib import Path @@ -83,12 +83,12 @@ zarr.config.set( ) # a directory containing `dataset_{i}.zarr` collection = Collection("path/to/output/collection.zarr") -collection.create_anndata_collection( +collection.add( adata_paths=[ 
"path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - shuffle=True, # shuffling is needed if you want to use chunked access + shuffle=True, # shuffling is needed if you want to use chunked access, but is the default ) ``` diff --git a/docs/index.md b/docs/index.md index 896edb26..41a3ff0a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Let's go through the above example: ### Preprocessing ```python -Collection("path/to/output/store.zarr").create_anndata_collection( +Collection("path/to/output/store.zarr").add( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 315ec408..8eb8adf2 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -149,7 +149,7 @@ " * We recommend to choose a dataset size that comfortably fits into system memory.\n", "\n", "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `Collection.create_anndata_collection`" + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `Collection.add`" ] }, { @@ -187,7 +187,7 @@ "\n", "\n", "collection = Collection(zarr.open(\"annbatch_collection\"))\n", - "collection.create_anndata_collection(\n", + "collection.add(\n", " # List all the h5ad files you want to include in the collection\n", " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", " # Path to store the output collection\n", @@ -363,7 +363,7 @@ " return ad.AnnData(X=x, obs=adata_.obs, var=var)\n", "\n", "\n", - "collection.add_to_collection(\n", + "collection.add(\n", " adata_paths=[\n", " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", " ],\n", diff --git a/src/annbatch/io.py b/src/annbatch/io.py index cc76232b..dbb2c275 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -308,7 +308,7 @@ def is_empty(self) -> bool: ) 
@_with_settings - def create_anndata_collection( + def add( self, adata_paths: Iterable[PathLike[str]] | Iterable[str], *, @@ -325,6 +325,109 @@ def create_anndata_collection( n_obs_per_dataset: int = 2_097_152, shuffle_slice_size: int = 1000, shuffle: bool = True, + ): + """Take AnnData paths, create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per store if running for the first time). + + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. + The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets and also unshuffled datasets as well. + The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. + A key `src_path` is added to `obs` to indicate where individual row came from. + We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_slice_size` as so many rows will be read into memory before writing to disk. + + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + load_adata + Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. + var_subset + Subset of gene names to include in the store. 
If None, all genes are included. + Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + Only applicable when adding datasets for the first time, otherwise ignored. + shuffle + Whether to shuffle the data before writing it to the store. + Ignored once the store is non-empty. + shuffle_slice_size + How many contiguous rows to load into memory before shuffling at once. + `(shuffle_slice_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_slice_size`. + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import Collection + # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store + >>> def read_lazy_x_and_obs_only(path): + ... adata = ad.experimental.read_lazy(path) + ... return ad.AnnData( + ... X=adata.X, + ... obs=adata.obs.to_memory(), + ... var=adata.var.to_memory(), + ...) + + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... 
] + >>> Collection("path/to/output/zarr_store.zarr").add( + ... datasets, + ... load_adata=read_lazy_x_and_obs_only, + ...) + """ + shared_kwargs = { + "adata_paths": adata_paths, + "load_adata": load_adata, + "zarr_sparse_chunk_size": zarr_sparse_chunk_size, + "zarr_sparse_shard_size": zarr_sparse_shard_size, + "zarr_dense_chunk_size": zarr_dense_chunk_size, + "zarr_dense_shard_size": zarr_dense_shard_size, + "zarr_compressor": zarr_compressor, + "h5ad_compressor": h5ad_compressor, + "shuffle_slice_size": shuffle_slice_size, + "shuffle": shuffle, + } + if self.is_empty: + self._create_collection(**shared_kwargs, n_obs_per_dataset=n_obs_per_dataset, var_subset=var_subset) + else: + self._add_to_collection(**shared_kwargs) + + def _create_collection( + self, + *, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), + var_subset: Iterable[str] | None = None, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle_slice_size: int = 1000, + shuffle: bool = True, ): """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. @@ -347,6 +450,7 @@ def create_anndata_collection( var_subset Subset of gene names to include in the store. If None, all genes are included. Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + Only applicable when adding datasets for the first time, otherwise ignored and the incoming data's var space is subsetted to that of the existing collection. 
zarr_sparse_chunk_size Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. zarr_sparse_shard_size @@ -363,6 +467,7 @@ def create_anndata_collection( Number of observations to load into memory at once for shuffling / pre-processing. The higher this number, the more memory is used, but the better the shuffling. This corresponds to the size of the shards created. + Only applicable when adding datasets for the first time, otherwise ignored. shuffle Whether to shuffle the data before writing it to the store. shuffle_slice_size @@ -387,7 +492,7 @@ def create_anndata_collection( ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> Collection("path/to/output/zarr_store.zarr").create_anndata_collection( + >>> Collection("path/to/output/zarr_store.zarr").add( ... datasets, ... load_adata=read_lazy_x_and_obs_only, ...) @@ -434,9 +539,9 @@ def create_anndata_collection( else: self._group.attrs["annbatch-shuffled"] = True - @_with_settings - def add_to_collection( + def _add_to_collection( self, + *, adata_paths: Iterable[PathLike[str]] | Iterable[str], load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, zarr_sparse_chunk_size: int = 32768, @@ -446,6 +551,7 @@ def add_to_collection( zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", shuffle_slice_size: int = 1000, + shuffle: bool = True, ) -> None: """Add anndata files to an existing collection of sharded anndata zarr datasets. @@ -475,6 +581,8 @@ def add_to_collection( To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. shuffle_slice_size How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. + shuffle + Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. 
Examples -------- @@ -485,13 +593,13 @@ def add_to_collection( ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> Collection("path/to/existing/preshuffled_collection.zarr").add_to_collection( + >>> Collection("path/to/existing/preshuffled_collection.zarr").add( ... datasets, ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory ...) """ if self.is_empty: - raise ValueError("Store is empty. Please run `Collection.create_anndata_collection` first.") + raise ValueError("Store is empty. Please run `Collection.add` first.") # Check for mismatched keys among the inputs. _check_for_mismatched_keys(adata_paths) @@ -503,7 +611,7 @@ def add_to_collection( adata_concat.shape[0], np.ceil(len(adata_concat) / len(self._dataset_keys)), shuffle_slice_size, - shuffle=True, + shuffle=shuffle, ) else: chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(self._dataset_keys)) @@ -518,8 +626,11 @@ def add_to_collection( adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)] ) adata = ad.concat([adata_dataset, subset_adata], join="outer") - idxs_shuffled = np.random.default_rng().permutation(len(adata)) - adata = _persist_adata_in_memory(adata[idxs_shuffled, :]) + if shuffle: + idxs = np.random.default_rng().permutation(len(adata)) + else: + idxs = np.arange(len(adata)) + adata = _persist_adata_in_memory(adata[idxs, :]) if isinstance(self._group, zarr.Group): write_sharded( self._group, diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 8a890cd7..90c92ac3 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -48,7 +48,7 @@ def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "l adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - Collection(tmp_path / "collection.zarr").create_anndata_collection( + 
Collection(tmp_path / "collection.zarr").add( [path_1, path_2], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -67,7 +67,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): adata_2.write_h5ad(path_2) paths = [path_1, path_2] output_dir = tmp_path / "path_src_collection.zarr" - Collection(output_dir).create_anndata_collection( + Collection(output_dir).add( paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -97,7 +97,7 @@ def test_store_addition_different_keys( adata_orig.write_h5ad(orig_path) output_path = tmp_path / "zarr_store_addition_different_keys.zarr" collection = Collection(output_path) - collection.create_anndata_collection( + collection.add( [orig_path], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -112,7 +112,7 @@ def test_store_addition_different_keys( additional_path = tmp_path / "with_extra_key.h5ad" adata.write_h5ad(additional_path) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - collection.add_to_collection( + collection.add( [additional_path], load_adata=load_adata, zarr_sparse_chunk_size=10, @@ -144,7 +144,7 @@ def test_store_creation_default( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default.zarr" - Collection(output_path).create_anndata_collection( + Collection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -167,7 +167,7 @@ def test_store_creation_drop_elem( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems.zarr" output_path.mkdir(parents=True, exist_ok=True) - Collection(output_path).create_anndata_collection( + Collection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if 
str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -190,7 +190,7 @@ def test_store_creation( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}.zarr" - Collection(output_path).create_anndata_collection( + Collection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -256,7 +256,7 @@ def test_mismatched_raw_concat( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - Collection(output_path).create_anndata_collection( + Collection(output_path).add( h5_paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -300,7 +300,7 @@ def test_store_extension( additional = all_h5_paths[4:] # don't add everything to get a "different" var space # create new store collection = Collection(store_path) - collection.create_anndata_collection( + collection.add( original, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -310,7 +310,7 @@ def test_store_extension( shuffle=True, ) # add h5ads to existing store - collection.add_to_collection( + collection.add( additional, load_adata=load_adata, zarr_sparse_chunk_size=10, @@ -339,5 +339,5 @@ def test_empty(tmp_path: Path): assert collection.is_empty # Doesn't matter what errors as long as this function runs, but not to completion with pytest.raises(TypeError): - collection.add_to_collection() + collection.add() assert "annbatch-shuffled" not in g.attrs From 61a692f59f50c3d59e8ddcd419b1463e9fab496f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 11:50:34 +0100 
Subject: [PATCH 09/39] feat: add `Loader` API --- README.md | 14 ++++-------- docs/notebooks/example.ipynb | 19 +++++----------- src/annbatch/io.py | 9 ++++++-- src/annbatch/loader.py | 26 +++++++++++++++++++++- src/annbatch/utils.py | 8 +++++++ tests/conftest.py | 18 ++++++++++++++++ tests/test_dataset.py | 42 +++++++++++++++++++++--------------- 7 files changed, 92 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 84557c25..4c41f383 100644 --- a/README.md +++ b/README.md @@ -107,22 +107,16 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) +def custom_load_func(g: zarr.Group) -> ad.AnnData: + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) + # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! with ad.settings.override(remove_unused_categories=False): ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, - ).add_anndatas( - [ - ad.AnnData( - # note that you can open an AnnData file using any type of zarr store - X=ad.io.sparse_dataset(zarr.open(p)["X"]), - obs=ad.io.read_elem(zarr.open(p)["obs"]), - ) - for p in Path("path/to/output/collection").glob("*.zarr") - ] - ) + ).add_collection(collection) # Automatically uses the on-disk `X` and full `obs` in the `Loader` but the `load_adata` arg can override this behavior (see `custom_load_func` above) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 8eb8adf2..97e83495 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -218,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "tags": [ "hide-output" @@ -250,17 +250,8 @@ " to_torch=True,\n", ")\n", "\n", - "# Add dataset that should be used for training\n", - "ds.add_anndatas(\n", - " [\n", - 
" ad.AnnData(\n", - " X=ad.io.sparse_dataset(zarr.open(p)[\"X\"]),\n", - " obs=ad.io.read_elem(zarr.open(p)[\"obs\"]),\n", - " )\n", - " for p in COLLECTION_PATH.glob(\"*.zarr\")\n", - " ],\n", - " obs_keys=\"cell_type\",\n", - ")" + "# Add in the shuffled data that should be used for training\n", + "ds.add_collection(collection)" ] }, { @@ -277,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "tags": [ "hide-output" @@ -299,7 +290,7 @@ "import tqdm\n", "\n", "for batch in tqdm.tqdm(ds):\n", - " x, obs = batch\n", + " x, obs = batch.data, batch.labels[\"cell_type\"]\n", " # Important: Convert to dense on GPU\n", " x = x.cuda().to_dense()\n", " # Feed data into your model\n", diff --git a/src/annbatch/io.py b/src/annbatch/io.py index dbb2c275..b74ee717 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -21,7 +21,7 @@ from zarr.codecs import BloscCodec, BloscShuffle if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping + from collections.abc import Callable, Generator, Iterable, Mapping from os import PathLike from typing import Any, Literal @@ -298,7 +298,11 @@ def __init__(self, group: T | str | Path, *, mode: Literal["a", "r", "r+"] = "a" @property def _dataset_keys(self) -> list[str]: - return [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None] + return sorted([k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None]) + + def __iter__(self) -> Generator[zarr.Group]: + for k in self._dataset_keys: + yield self._group[k] @property def is_empty(self) -> bool: @@ -410,6 +414,7 @@ def add( self._create_collection(**shared_kwargs, n_obs_per_dataset=n_obs_per_dataset, var_subset=var_subset) else: self._add_to_collection(**shared_kwargs) + return self def _create_collection( self, diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 2c93e20e..32b04aea 100644 --- a/src/annbatch/loader.py +++ 
b/src/annbatch/loader.py @@ -24,6 +24,7 @@ _batched, check_lt_1, check_var_shapes, + load_x_and_obs, split_given_size, to_torch, ) @@ -31,9 +32,11 @@ from .compat import IterableDataset if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from types import ModuleType + from annbatch.io import Collection + # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T OutputInMemoryArray = OutputInMemoryArray_T @@ -222,6 +225,27 @@ def n_var(self) -> int: """ return self._shapes[0][1] + def add_collection( + self, collection: Collection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs + ): + """Load from an existing {class}`annbatch.Collection` + + Parameters + ---------- + collection + _description_ + load_adata, optional + _description_, by default load_x_and_obs + + Returns + ------- + _description_ + """ + if collection.is_empty: + raise ValueError("Collection is empty") + adatas = [load_adata(g) for g in collection] + return self.add_anndatas(adatas) + def add_anndatas( self, adatas: list[ad.AnnData], diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 8d4f8f79..84d2bc47 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -7,6 +7,7 @@ from itertools import islice from typing import TYPE_CHECKING, Protocol +import anndata as ad import numpy as np import scipy as sp import zarr @@ -190,3 +191,10 @@ def to_torch(input: OutputInMemoryArray_T, preload_to_gpu: bool) -> Tensor: input.shape, ) raise TypeError(f"Cannot convert {type(input)} to torch.Tensor") + + +def load_x_and_obs(g: zarr.Group) -> ad.AnnData: + """Load X as a sparse array or dense zarr array and obs from a group""" + return ad.AnnData( + X=g["X"] if isinstance(g["X"], zarr.Array) else ad.io.sparse_dataset(g["X"]), obs=ad.io.read_elem(g["obs"]) + ) diff --git a/tests/conftest.py b/tests/conftest.py index aa012e34..6211751f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ 
from scipy.sparse import random as sparse_random from annbatch import write_sharded +from annbatch.io import Collection if TYPE_CHECKING: from collections.abc import Generator @@ -102,3 +103,20 @@ def adata_with_h5_path_different_var_space( [ad.read_h5ad(tmp_path / shard) for shard in sorted(tmp_path.iterdir()) if str(shard).endswith(".h5ad")], join="outer", ), tmp_path + + +@pytest.fixture(scope="session") +def simple_collection( + tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] +) -> tuple[Collection, ad.AnnData]: + zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) + output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "simple_fixture.zarr" + collection = Collection(output_path).add( + zarr_stores, + zarr_sparse_chunk_size=10, + zarr_sparse_shard_size=20, + zarr_dense_chunk_size=10, + zarr_dense_shard_size=20, + n_obs_per_dataset=60, + ) + return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9687ae58..77fdce0a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -25,6 +25,8 @@ from collections.abc import Callable from pathlib import Path + from annbatch.io import Collection + class Data(TypedDict): dataset: ad.abc.CSRDataset | zarr.Array @@ -36,31 +38,34 @@ class ListData: obs: list[np.ndarray] -def open_sparse(path: Path, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: +def open_sparse(path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: old_pipeline = zarr.config.get("codec_pipeline.path") with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + if not isinstance(path, zarr.Group): + path = zarr.open(path) data = { - "dataset": ad.io.sparse_dataset(zarr.open(path)["layers"]["sparse"]), - "obs": ad.io.read_elem(zarr.open(path)["obs"]), + 
"dataset": ad.io.sparse_dataset(path["layers"]["sparse"]), + "obs": ad.io.read_elem(path["obs"]), } if use_anndata: return ad.AnnData(X=data["dataset"], obs=data["obs"]) return data -def open_dense(path: Path, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: +def open_dense(path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: old_pipeline = zarr.config.get("codec_pipeline.path") with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + if not isinstance(path, zarr.Group): + path = zarr.open(path) data = { - "dataset": zarr.open(path)["X"], - "obs": ad.io.read_elem(zarr.open(path)["obs"]), + "dataset": path["X"], + "obs": ad.io.read_elem(path["obs"]), } if use_anndata: return ad.AnnData(X=data["dataset"], obs=data["obs"]) return data - return data def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: @@ -79,7 +84,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: "gen_loader", [ pytest.param( - lambda path, + lambda collection, shuffle, use_zarrs, chunk_size=chunk_size, @@ -94,10 +99,15 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: batch_size=batch_size, preload_to_gpu=preload_to_gpu, to_torch=False, - ).add_anndatas( - [open_func(p, use_zarrs=use_zarrs, use_anndata=True) for p in path.glob("*.zarr")], + ).add_collection( + collection, + **( + {"load_adata": lambda group: open_func(group, use_zarrs=use_zarrs, use_anndata=True)} + if open_func is not None + else {} + ), ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-dataset_type={open_func.__name__[5:]}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # 
type: ignore[attr-defined] marks=pytest.mark.skipif( find_spec("cupy") is None and preload_to_gpu, reason="need cupy installed", @@ -106,7 +116,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem for preload_to_gpu in [True, False] - for open_func in [open_sparse, open_dense] + for open_func in [open_sparse, open_dense, None] for elem in [ [ 1, @@ -147,9 +157,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ] ], ) -def test_store_load_dataset( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], *, shuffle: bool, gen_loader, use_zarrs -): +def test_store_load_dataset(simple_collection: tuple[ad.AnnData, Collection], *, shuffle: bool, gen_loader, use_zarrs): """ This test verifies that the DaskDataset works correctly: 1. The DaskDataset correctly loads data from the mock store @@ -157,8 +165,8 @@ def test_store_load_dataset( 3. All samples from the dataset are processed 4. 
If the dataset is not shuffled, it returns the correct data """ - loader: Loader = gen_loader(adata_with_zarr_path_same_var_space[1], shuffle, use_zarrs) - adata = adata_with_zarr_path_same_var_space[0] + loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs) + adata = simple_collection[0] is_dense = loader.dataset_type is zarr.Array n_elems = 0 batches = [] From a8b3cb99f565e3bca0f4ca8f4c895aa5748556b9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 11:51:11 +0100 Subject: [PATCH 10/39] fix: `Collection` is part of the `collections` built-in lib --- README.md | 4 ++-- docs/api.md | 2 +- docs/index.md | 2 +- docs/notebooks/example.ipynb | 6 +++--- src/annbatch/__init__.py | 4 ++-- src/annbatch/io.py | 16 ++++++++-------- src/annbatch/loader.py | 8 ++++---- tests/conftest.py | 6 +++--- tests/test_dataset.py | 6 ++++-- tests/test_preshuffle.py | 22 +++++++++++----------- 10 files changed, 39 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 4c41f383..1b26dbc5 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ For a detailed tutorial, please see the [in-depth section of our docs][] Basic preprocessing: ```python -from annbatch import Collection +from annbatch import DatasetCollection import zarr from pathlib import Path @@ -82,7 +82,7 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) # a directory containing `dataset_{i}.zarr` -collection = Collection("path/to/output/collection.zarr") +collection = DatasetCollection("path/to/output/collection.zarr") collection.add( adata_paths=[ "path/to/your/file1.h5ad", diff --git a/docs/api.md b/docs/api.md index 662dd5fa..cf399fd6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,7 +25,7 @@ :toctree: generated/ write_sharded - Collection + DatasetCollection ``` (types)= diff --git a/docs/index.md b/docs/index.md index 41a3ff0a..8207e2af 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Let's go through the above example: ### 
Preprocessing ```python -Collection("path/to/output/store.zarr").add( +DatasetCollection("path/to/output/store.zarr").add( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 97e83495..698b1da5 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -149,7 +149,7 @@ " * We recommend to choose a dataset size that comfortably fits into system memory.\n", "\n", "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `Collection.add`" + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" ] }, { @@ -163,7 +163,7 @@ "outputs": [], "source": [ "import anndata as ad\n", - "from annbatch import Collection\n", + "from annbatch import DatasetCollection\n", "\n", "\n", "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", @@ -186,7 +186,7 @@ " )\n", "\n", "\n", - "collection = Collection(zarr.open(\"annbatch_collection\"))\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\"))\n", "collection.add(\n", " # List all the h5ad files you want to include in the collection\n", " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 3a699bb3..d53341c0 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -3,9 +3,9 @@ from importlib.metadata import version from . 
import types -from .io import Collection, write_sharded +from .io import DatasetCollection, write_sharded from .loader import Loader __version__ = version("annbatch") -__all__ = ["Loader", "write_sharded", "Collection", "types"] +__all__ = ["Loader", "write_sharded", "DatasetCollection", "types"] diff --git a/src/annbatch/io.py b/src/annbatch/io.py index b74ee717..d444640a 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -266,7 +266,7 @@ def wrapper(*args, **kwargs): return wrapper -class Collection[T: h5py.Group | zarr.Group]: +class DatasetCollection[T: h5py.Group | zarr.Group]: """A preshuffled collection object including functionality for creating, adding to, and loading collections.""" _group: T @@ -378,7 +378,7 @@ def add( Examples -------- >>> import anndata as ad - >>> from annbatch import Collection + >>> from annbatch import DatasetCollection # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store >>> def read_lazy_x_and_obs_only(path): ... adata = ad.experimental.read_lazy(path) @@ -393,7 +393,7 @@ def add( ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> Collection("path/to/output/zarr_store.zarr").add( + >>> DatasetCollection("path/to/output/zarr_store.zarr").add( ... datasets, ... load_adata=read_lazy_x_and_obs_only, ...) @@ -482,7 +482,7 @@ def _create_collection( Examples -------- >>> import anndata as ad - >>> from annbatch import Collection + >>> from annbatch import DatasetCollection # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store >>> def read_lazy_x_and_obs_only(path): ... adata = ad.experimental.read_lazy(path) @@ -497,7 +497,7 @@ def _create_collection( ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> Collection("path/to/output/zarr_store.zarr").add( + >>> DatasetCollection("path/to/output/zarr_store.zarr").add( ... datasets, ... load_adata=read_lazy_x_and_obs_only, ...) 
@@ -592,19 +592,19 @@ def _add_to_collection( Examples -------- >>> import anndata as ad - >>> from annbatch import Collection + >>> from annbatch import DatasetCollection >>> datasets = [ ... "path/to/first_adata.h5ad", ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> Collection("path/to/existing/preshuffled_collection.zarr").add( + >>> DatasetCollection("path/to/existing/preshuffled_collection.zarr").add( ... datasets, ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory ...) """ if self.is_empty: - raise ValueError("Store is empty. Please run `Collection.add` first.") + raise ValueError("Store is empty. Please run `DatasetCollection.add` first.") # Check for mismatched keys among the inputs. _check_for_mismatched_keys(adata_paths) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 32b04aea..11000df1 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -35,7 +35,7 @@ from collections.abc import Callable, Iterator from types import ModuleType - from annbatch.io import Collection + from annbatch.io import DatasetCollection # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T @@ -226,9 +226,9 @@ def n_var(self) -> int: return self._shapes[0][1] def add_collection( - self, collection: Collection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs + self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs ): - """Load from an existing {class}`annbatch.Collection` + """Load from an existing {class}`annbatch.DatasetCollection` Parameters ---------- @@ -242,7 +242,7 @@ def add_collection( _description_ """ if collection.is_empty: - raise ValueError("Collection is empty") + raise ValueError("DatasetCollection is empty") adatas = [load_adata(g) for g in collection] return self.add_anndatas(adatas) diff --git a/tests/conftest.py b/tests/conftest.py index 6211751f..73acaff9 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from scipy.sparse import random as sparse_random from annbatch import write_sharded -from annbatch.io import Collection +from annbatch.io import DatasetCollection if TYPE_CHECKING: from collections.abc import Generator @@ -108,10 +108,10 @@ def adata_with_h5_path_different_var_space( @pytest.fixture(scope="session") def simple_collection( tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] -) -> tuple[Collection, ad.AnnData]: +) -> tuple[DatasetCollection, ad.AnnData]: zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "simple_fixture.zarr" - collection = Collection(output_path).add( + collection = DatasetCollection(output_path).add( zarr_stores, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 77fdce0a..bebbe3c7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -25,7 +25,7 @@ from collections.abc import Callable from pathlib import Path - from annbatch.io import Collection + from annbatch.io import DatasetCollection class Data(TypedDict): @@ -157,7 +157,9 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ] ], ) -def test_store_load_dataset(simple_collection: tuple[ad.AnnData, Collection], *, shuffle: bool, gen_loader, use_zarrs): +def test_store_load_dataset( + simple_collection: tuple[ad.AnnData, DatasetCollection], *, shuffle: bool, gen_loader, use_zarrs +): """ This test verifies that the DaskDataset works correctly: 1. 
The DaskDataset correctly loads data from the mock store diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 90c92ac3..ad412baa 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -10,7 +10,7 @@ import scipy.sparse as sp import zarr -from annbatch import Collection, write_sharded +from annbatch import DatasetCollection, write_sharded if TYPE_CHECKING: from collections.abc import Callable @@ -48,7 +48,7 @@ def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "l adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - Collection(tmp_path / "collection.zarr").add( + DatasetCollection(tmp_path / "collection.zarr").add( [path_1, path_2], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -67,7 +67,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): adata_2.write_h5ad(path_2) paths = [path_1, path_2] output_dir = tmp_path / "path_src_collection.zarr" - Collection(output_dir).add( + DatasetCollection(output_dir).add( paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -96,7 +96,7 @@ def test_store_addition_different_keys( orig_path = tmp_path / "orig.h5ad" adata_orig.write_h5ad(orig_path) output_path = tmp_path / "zarr_store_addition_different_keys.zarr" - collection = Collection(output_path) + collection = DatasetCollection(output_path) collection.add( [orig_path], zarr_sparse_chunk_size=10, @@ -144,7 +144,7 @@ def test_store_creation_default( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default.zarr" - Collection(output_path).add( + DatasetCollection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -167,7 
+167,7 @@ def test_store_creation_drop_elem( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems.zarr" output_path.mkdir(parents=True, exist_ok=True) - Collection(output_path).add( + DatasetCollection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -190,7 +190,7 @@ def test_store_creation( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}.zarr" - Collection(output_path).add( + DatasetCollection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -200,7 +200,7 @@ def test_store_creation( n_obs_per_dataset=60, shuffle=shuffle, ) - assert not Collection(output_path).is_empty + assert not DatasetCollection(output_path).is_empty assert zarr.open(output_path).attrs["annbatch-shuffled"] adata_orig = adata_with_h5_path_different_var_space[0] @@ -256,7 +256,7 @@ def test_mismatched_raw_concat( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - Collection(output_path).add( + DatasetCollection(output_path).add( h5_paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -299,7 +299,7 @@ def test_store_extension( original = all_h5_paths additional = all_h5_paths[4:] # don't add everything to get a "different" var space # create new store - collection = Collection(store_path) + collection = DatasetCollection(store_path) collection.add( original, zarr_sparse_chunk_size=10, @@ -335,7 
+335,7 @@ def test_store_extension( def test_empty(tmp_path: Path): g = zarr.open(tmp_path / "empty.zarr") - collection = Collection(g) + collection = DatasetCollection(g) assert collection.is_empty # Doesn't matter what errors as long as this function runs, but not to completion with pytest.raises(TypeError): From f6141c8b7f2b8fa98c7a9a4761b45ad206698463 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 16:37:36 +0100 Subject: [PATCH 11/39] fix: try getting directly from colletion --- tests/test_preshuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index ad412baa..accc996f 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -167,7 +167,7 @@ def test_store_creation_drop_elem( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems.zarr" output_path.mkdir(parents=True, exist_ok=True) - DatasetCollection(output_path).add( + collection = DatasetCollection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -177,7 +177,7 @@ def test_store_creation_drop_elem( n_obs_per_dataset=60, load_adata=_read_lazy_x_and_obs_only, ) - adata_output = ad.read_zarr(next(output_path.iterdir())) + adata_output = ad.io.read_elem(next(iter(collection))) assert "arr" not in adata_output.obsm assert adata_output.raw is None From e02a424b5cd048805213b899463f8ce2d530ed8c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 16:40:33 +0100 Subject: [PATCH 12/39] chore: use collection iteration --- tests/test_preshuffle.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index accc996f..40a6db26 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -67,7 +67,7 @@ def 
test_store_creation_path_added_to_obs(tmp_path: Path): adata_2.write_h5ad(path_2) paths = [path_1, path_2] output_dir = tmp_path / "path_src_collection.zarr" - DatasetCollection(output_dir).add( + collection = DatasetCollection(output_dir).add( paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -76,9 +76,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): n_obs_per_dataset=10, shuffle=False, ) - adata_result = ad.concat( - [ad.read_zarr(path) for path in sorted((output_dir).iterdir()) if path.is_dir()], join="outer" - ) + adata_result = ad.concat([ad.io.read_elem(g) for g in collection], join="outer") pd.testing.assert_extension_array_equal( adata_result.obs["src_path"].array, pd.Categorical(([str(path_1)] * 10) + ([str(path_2)] * 10), categories=[str(p) for p in paths]), @@ -144,7 +142,7 @@ def test_store_creation_default( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default.zarr" - DatasetCollection(output_path).add( + collection = DatasetCollection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -153,7 +151,7 @@ def test_store_creation_default( zarr_dense_shard_size=20, n_obs_per_dataset=60, ) - assert isinstance(ad.read_zarr(next(p for p in (output_path).iterdir() if p.is_dir())).X, sp.csr_matrix) + assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) assert sorted(glob.glob(str(output_path / "dataset_*"))) == sorted( str(p) for p in (output_path).iterdir() if p.is_dir() ) @@ -190,7 +188,7 @@ def test_store_creation( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / 
f"zarr_store_creation_test_{shuffle}.zarr" - DatasetCollection(output_path).add( + collection = DatasetCollection(output_path).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -205,7 +203,7 @@ def test_store_creation( adata_orig = adata_with_h5_path_different_var_space[0] # make sure all category dtypes match - adatas_shuffled = [ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir()) if zarr_path.is_dir()] + adatas_shuffled = [ad.io.read_elem(g) for g in collection] for adata in adatas_shuffled: assert adata.obs["label"].dtype == adata_orig.obs["label"].dtype # subset to var_subset @@ -256,7 +254,7 @@ def test_mismatched_raw_concat( output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - DatasetCollection(output_path).add( + collection = DatasetCollection(output_path).add( h5_paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -280,7 +278,7 @@ def test_mismatched_raw_concat( adata_orig = ad.concat(adatas_orig, join="outer") adata_orig.obs_names_make_unique() - adata = ad.concat([ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir()) if zarr_path.is_dir()]) + adata = ad.concat([ad.io.read_elem(g) for g in collection]) del adata.obs["src_path"] pd.testing.assert_frame_equal(adata_orig.var, adata.var) pd.testing.assert_frame_equal(adata_orig.obs, adata.obs) @@ -318,7 +316,7 @@ def test_store_extension( zarr_dense_chunk_size=5, zarr_dense_shard_size=10, ) - adatas_on_disk = [ad.read_zarr(zarr_path) for zarr_path in sorted(store_path.iterdir()) if zarr_path.is_dir()] + adatas_on_disk = [ad.io.read_elem(g) for g in collection] adata = ad.concat(adatas_on_disk) adata_orig = 
adata_with_h5_path_different_var_space[0] expected_adata = ad.concat([adata_orig, adata_orig[adata_orig.obs["store_id"] >= 4]], join="outer") From 34ba111c4bfed613845aba8ba0e4893f22c4ed84 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 16:55:05 +0100 Subject: [PATCH 13/39] fix: forward ref --- src/annbatch/io.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index d444640a..899cd9d3 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -6,7 +6,7 @@ from collections import defaultdict from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import anndata as ad import dask.array as da @@ -266,12 +266,12 @@ def wrapper(*args, **kwargs): return wrapper -class DatasetCollection[T: h5py.Group | zarr.Group]: +class DatasetCollection[T: (h5py.Group, zarr.Group)]: """A preshuffled collection object including functionality for creating, adding to, and loading collections.""" _group: T - def __init__(self, group: T | str | Path, *, mode: Literal["a", "r", "r+"] = "a"): + def __init__(self, group: zarr.Group | h5py.Group | str | Path, *, mode: Literal["a", "r", "r+"] = "a"): """Initialization of the object at a given location. Note that if the group is a h5py/zarr object, it must have the correct permissions for any subsequent operations you plan to do. 
@@ -300,7 +300,7 @@ def __init__(self, group: T | str | Path, *, mode: Literal["a", "r", "r+"] = "a" def _dataset_keys(self) -> list[str]: return sorted([k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None]) - def __iter__(self) -> Generator[zarr.Group]: + def __iter__(self) -> Generator[T]: for k in self._dataset_keys: yield self._group[k] @@ -329,7 +329,7 @@ def add( n_obs_per_dataset: int = 2_097_152, shuffle_slice_size: int = 1000, shuffle: bool = True, - ): + ) -> Self: """Take AnnData paths, create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per store if running for the first time). The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. @@ -433,7 +433,7 @@ def _create_collection( n_obs_per_dataset: int = 2_097_152, shuffle_slice_size: int = 1000, shuffle: bool = True, - ): + ) -> None: """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. 
From 80ce3a1e4fc561e4cc0c2c74b33ae2862f642a56 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 17:19:16 +0100 Subject: [PATCH 14/39] fix: notebook --- docs/notebooks/example.ipynb | 121 +++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 49 deletions(-) diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 698b1da5..78e404ee 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [ "hide-output" @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "tags": [ "hide-output" @@ -43,28 +43,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2025-10-09 09:43:19-- https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 18.64.79.73, 18.64.79.80, 18.64.79.109, ...\n", - "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 773247972 (737M) [binary/octet-stream]\n", - "Saving to: ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’\n", - "\n", - "866d7d5e-436b-4dbd- 100%[===================>] 737.43M 398MB/s in 1.9s \n", - "\n", - "2025-10-09 09:43:21 (398 MB/s) - ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’ saved [773247972/773247972]\n", - "\n", - "--2025-10-09 09:43:22-- https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\n", - "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 18.64.79.73, 18.64.79.80, 18.64.79.72, ...\n", - "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", - "HTTP request sent, awaiting response... 
200 OK\n", - "Length: 1631759823 (1.5G) [binary/octet-stream]\n", - "Saving to: ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’\n", - "\n", - "f81463b8-4986-4904- 100%[===================>] 1.52G 425MB/s in 3.9s \n", - "\n", - "2025-10-09 09:43:26 (403 MB/s) - ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’ saved [1631759823/1631759823]\n", - "\n" + "zsh:1: command not found: wget\n", + "zsh:1: command not found: wget\n" ] } ], @@ -85,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "metadata": { "tags": [ "hide-output" @@ -95,24 +75,23 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 1, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import zarr\n", - "import zarrs # noqa\n", "\n", "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -154,13 +133,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "tags": [ "hide-output" ] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 2/2 [00:01<00:00, 1.73it/s]\n", + "/Users/ilangold/Projects/Theis/annbatch/src/annbatch/io.py:507: UserWarning: Found layers keys ['soupX'] not present in all anndatas ['866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad', 'f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad'], consider stopping and using the `load_adata` argument to alter layers accordingly.\n", + " _check_for_mismatched_keys(adata_paths)\n", + "/Users/ilangold/Projects/Theis/annbatch/src/annbatch/io.py:507: UserWarning: Found obs keys ['Eye', 'alignment_software', 'nCount_RNA', 'Region', 'donor_age', 'Post_mortemtime', 'reference_genome', 'library_id_repository', 'sequenced_fragment', 'donor_cause_of_death', 'sampleid', 'sequencing_platform', 'sample_id', 'nFeature_RNA', 
'library_id', 'gene_annotation_version', 'Study', 'sample_collection_year', 'percent.mt', 'institute', 'Developmental', 'tissue_source', 'PMT_in_hrs', 'tissue_handling_interval', 'sample_collection_method', 'intronic_reads_counted', 'sample_source', 'leiden_scVI', 'library_starting_quantity', 'library_uuid', 'sample_uuid', 'sample_derivation_process', 'suspension_dissociation_reagent', 'donor_BMI_at_collection', 'mapped_reference_annotation', 'suspension_dissociation_time'] not present in all anndatas ['866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad', 'f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad'], consider stopping and using the `load_adata` argument to alter obs accordingly.\n", + " _check_for_mismatched_keys(adata_paths)\n", + "loading: 2it [00:00, 2.26it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import anndata as ad\n", "from annbatch import DatasetCollection\n", @@ -207,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -218,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "tags": [ "hide-output" @@ -228,10 +233,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -268,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": { "tags": [ "hide-output" @@ -279,9 +284,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/171792 [00:00], consider stopping and using the `load_adata` argument to alter obs accordingly.\n", + " _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys])\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -361,6 
+377,13 @@ " load_adata=read_x_and_obs_only,\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -379,7 +402,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.12.3" } }, "nbformat": 4, From a15e7996992dac1e5b6f0d5a7a74df18730041eb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 17:38:47 +0100 Subject: [PATCH 15/39] fix: dont warn for keys that are ignored by custom loading functions --- docs/notebooks/example.ipynb | 783 +++++++++++++++++------------------ src/annbatch/io.py | 10 +- tests/test_preshuffle.py | 22 +- 3 files changed, 404 insertions(+), 411 deletions(-) diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 78e404ee..15a6675c 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -1,410 +1,379 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Quickstart `annbatch`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook will walk you through the following steps:\n", - "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", - "2. How to load the converted collection using `annbatch`\n", - "3. 
Extend an existing collection with new `anndata` datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# !pip install annbatch[zarrs, torch]" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "zsh:1: command not found: wget\n", - "zsh:1: command not found: wget\n" - ] - } - ], - "source": [ - "# Download two example datasets from CELLxGENE\n", - "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT**: Configure zarrs\n", - "\n", - "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import zarr\n", - "\n", - "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "# Suppress zarr vlen-utf8 codec warnings\n", - "warnings.filterwarnings(\n", - " \"ignore\",\n", - " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", - " category=UserWarning,\n", - " module=\"zarr.codecs.vlen_utf8\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Converting existing `anndata` files into a shuffled collection" - ] - }, - { 
- "cell_type": "markdown", - "metadata": {}, - "source": [ - "The conversion code will take care of the following things:\n", - "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", - " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", - " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", - "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", - " * This is important for block-wise shuffling during data loading.\n", - "* Shuffle the input files across multiple output datasets:\n", - " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", - " * We recommend to choose a dataset size that comfortably fits into system memory.\n", - "\n", - "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "checking for mismatched keys: 100%|██████████| 2/2 [00:01<00:00, 1.73it/s]\n", - "/Users/ilangold/Projects/Theis/annbatch/src/annbatch/io.py:507: UserWarning: Found layers keys ['soupX'] not present in all anndatas ['866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad', 'f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad'], consider stopping and using the `load_adata` argument to alter layers accordingly.\n", - " _check_for_mismatched_keys(adata_paths)\n", - "/Users/ilangold/Projects/Theis/annbatch/src/annbatch/io.py:507: UserWarning: Found obs keys ['Eye', 'alignment_software', 'nCount_RNA', 'Region', 'donor_age', 'Post_mortemtime', 'reference_genome', 'library_id_repository', 'sequenced_fragment', 'donor_cause_of_death', 'sampleid', 
'sequencing_platform', 'sample_id', 'nFeature_RNA', 'library_id', 'gene_annotation_version', 'Study', 'sample_collection_year', 'percent.mt', 'institute', 'Developmental', 'tissue_source', 'PMT_in_hrs', 'tissue_handling_interval', 'sample_collection_method', 'intronic_reads_counted', 'sample_source', 'leiden_scVI', 'library_starting_quantity', 'library_uuid', 'sample_uuid', 'sample_derivation_process', 'suspension_dissociation_reagent', 'donor_BMI_at_collection', 'mapped_reference_annotation', 'suspension_dissociation_time'] not present in all anndatas ['866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad', 'f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad'], consider stopping and using the `load_adata` argument to alter obs accordingly.\n", - " _check_for_mismatched_keys(adata_paths)\n", - "loading: 2it [00:00, 2.26it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "from annbatch import DatasetCollection\n", - "\n", - "\n", - "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", - "# To have a store that only contains raw counts, we can write the following load_adata function\n", - "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", - " adata_ = ad.experimental.read_lazy(path)\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(\n", - " X=x,\n", - " obs=adata_.obs.to_memory(),\n", - " var=var.to_memory(),\n", - " )\n", - "\n", - "\n", - "collection = DatasetCollection(zarr.open(\"annbatch_collection\"))\n", - "collection.add(\n", - " # List all the h5ad files you want to include in 
the collection\n", - " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", - " # Path to store the output collection\n", - " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", - " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data loading example" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "COLLECTION_PATH = Path(\"annbatch_collection/\")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "\n", - "from annbatch import Loader\n", - "\n", - "ds = Loader(\n", - " batch_size=4096, # Total number of obs per yielded batch\n", - " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", - " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", - " preload_to_gpu=False,\n", - " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", - " to_torch=True,\n", - ")\n", - "\n", - "# Add in the shuffled data that should be used for training\n", - "ds.add_collection(collection)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT:**\n", - "* The `Loader` yields batches of sparse tensors.\n", - "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", - " * First 
call `.cuda()` and then `.to_dense()`\n", - " * E.g. `x = x.cuda().to_dense()`\n", - " * This is significantly faster than doing the dense conversion on the CPU.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 42/171792 [00:10<12:13:16, 3.90it/s]\n" - ] - } - ], - "source": [ - "# Iterate over dataloader\n", - "import tqdm\n", - "\n", - "for batch in tqdm.tqdm(ds):\n", - " x, obs = batch[\"data\"], batch[\"labels\"][\"cell_type\"]\n", - " # Important: Convert to dense on GPU\n", - " x = x.cuda().to_dense()\n", - " # Feed data into your model\n", - " ..." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optional: Extend an existing collection with a new dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You might want to extend an existing pre-shuffled collection with a new dataset.\n", - "This can be done using the `add_to_collection` function.\n", - "\n", - "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
- ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 2.09it/s]\n", - "loading: 1it [00:10, 10.77s/it]\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:01<00:00, 1.18it/s]\n", - "/Users/ilangold/Projects/Theis/annbatch/src/annbatch/io.py:613: UserWarning: Found obs keys ['sample_source', 'leiden_scVI', 'library_starting_quantity', 'library_uuid', 'sample_uuid', 'sample_derivation_process', 'suspension_dissociation_reagent', 'donor_BMI_at_collection', 'mapped_reference_annotation', 'suspension_dissociation_time'] not present in all anndatas [AnnData object with n_obs × n_vars = 99457 × 35475\n", - " obs: 'reference_genome', 'gene_annotation_version', 'alignment_software', 'intronic_reads_counted', 'donor_id', 'donor_age', 'self_reported_ethnicity_ontology_term_id', 'donor_cause_of_death', 'donor_living_at_sample_collection', 'sample_id', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'sample_collection_method', 'tissue_source', 'tissue_type', 'sample_collection_year', 'suspension_derivation_process', 'suspension_uuid', 'suspension_type', 'tissue_handling_interval', 'library_id', 'assay_ontology_term_id', 'sequenced_fragment', 'institute', 'library_id_repository', 'sequencing_platform', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'reported_diseases', 'sex_ontology_term_id', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'Study', 'Developmental', 'Post_mortemtime', 'PMT_in_hrs', 'Eye', 'Region', 'sampleid', 'cell_type', 'assay', 'disease', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'src_path'\n", - " var: 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', ], consider stopping 
and using the `load_adata` argument to alter obs accordingly.\n", - " _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys])\n", - "processing chunks: 0%| | 0/1 [00:00" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def read_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # As it's only a small dataset, we can load the full dataset into memory to speed up computations\n", - " adata_ = ad.read_h5ad(path) # Replace with ad.experimental.read_lazy if data does not fit into memory anymore\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(X=x, obs=adata_.obs, var=var)\n", - "\n", - "\n", - "collection.add(\n", - " adata_paths=[\n", - " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", - " ],\n", - " load_adata=read_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quickstart `annbatch`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook will walk you through the following steps:\n", + "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", + "2. How to load the converted collection using `annbatch`\n", + "3. 
Extend an existing collection with new `anndata` datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "# !pip install annbatch[zarrs, torch]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "zsh:1: command not found: wget\n", + "zsh:1: command not found: wget\n" + ] + } + ], + "source": [ + "# Download two example datasets from CELLxGENE\n", + "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", + "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT**: Configure zarrs\n", + "\n", + "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import zarr\n", + "\n", + "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "# Suppress zarr vlen-utf8 codec warnings\n", + "warnings.filterwarnings(\n", + " \"ignore\",\n", + " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", + " category=UserWarning,\n", + " module=\"zarr.codecs.vlen_utf8\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting existing `anndata` files into a shuffled collection" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The conversion code will take care of the following things:\n", + "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", + " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", + " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", + "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", + " * This is important for block-wise shuffling during data loading.\n", + "* Shuffle the input files across multiple output datasets:\n", + " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", + " * We recommend to choose a dataset size that comfortably fits into system memory.\n", + "\n", + "\n", + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.03it/s]\n", + "loading: 2it [00:00, 2.16it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "from annbatch import DatasetCollection\n", + "\n", + "\n", + "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", + "# To have a store that only contains raw counts, we can write the following load_adata function\n", + "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", + " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", + " # IMPORTANT: 
Large data should always be loaded lazily to reduce the memory footprint\n", + " adata_ = ad.experimental.read_lazy(path)\n", + " if adata_.raw is not None:\n", + " x = adata_.raw.X\n", + " var = adata_.raw.var\n", + " else:\n", + " x = adata_.X\n", + " var = adata_.var\n", + "\n", + " return ad.AnnData(\n", + " X=x,\n", + " obs=adata_.obs.to_memory()[\n", + " [\"cell_type\"]\n", + " ], # let's only take cell type since it is shared - otherwise DatasetCollection will warn about all the columns we are missing\n", + " var=var.to_memory(),\n", + " )\n", + "\n", + "\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\"))\n", + "collection.add(\n", + " # List all the h5ad files you want to include in the collection\n", + " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", + " # Path to store the output collection\n", + " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", + " var_subset=None, # Optionally subset the collection to a specific gene space\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data loading example" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "\n", + "from annbatch import Loader\n", + "\n", + "ds = Loader(\n", + " batch_size=4096, # Total number of obs per yielded batch\n", + " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", + " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", + " preload_to_gpu=False,\n", + " # If True, preloaded 
chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", + " to_torch=True,\n", + ")\n", + "\n", + "# Add in the shuffled data that should be used for training\n", + "ds.add_collection(collection)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT:**\n", + "* The `Loader` yields batches of sparse tensors.\n", + "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", + " * First call `.cuda()` and then `.to_dense()`\n", + " * E.g. `x = x.cuda().to_dense()`\n", + " * This is significantly faster than doing the dense conversion on the CPU.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 42/171792 [00:08<9:13:24, 5.17it/s]\n" + ] + } + ], + "source": [ + "# Iterate over dataloader\n", + "import tqdm\n", + "\n", + "for batch in tqdm.tqdm(ds):\n", + " x, obs = batch[\"data\"], batch[\"labels\"][\"cell_type\"]\n", + " # Important: Convert to dense on GPU\n", + " x = x.cuda().to_dense()\n", + " # Feed data into your model\n", + " ..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optional: Extend an existing collection with a new dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You might want to extend an existing pre-shuffled collection with a new dataset.\n", + "This can be done using the `add` method again.\n", + "\n", + "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]\n", + "loading: 1it [00:00, 1.73it/s]\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 14.22it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.add(\n", + " adata_paths=[\n", + " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", + " ],\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 899cd9d3..e60813e2 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -107,6 +107,10 @@ def callback( def _check_for_mismatched_keys( paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData | zarr.Group | h5py.Group] | Iterable[str | ad.AnnData], + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), ): num_raw_in_adata = 0 found_keys: dict[str, defaultdict[str, int]] = { @@ -116,7 +120,7 @@ def _check_for_mismatched_keys( } for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"): if not isinstance(path_or_anndata, ad.AnnData): - adata = 
ad.experimental.read_lazy(path_or_anndata, load_annotation_index=False) + adata = load_adata(path_or_anndata) else: adata = path_or_anndata for elem_name, key_count in found_keys.items(): @@ -504,7 +508,7 @@ def _create_collection( """ if not self.is_empty: raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") - _check_for_mismatched_keys(adata_paths) + _check_for_mismatched_keys(adata_paths, load_adata=load_adata) adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) adata_concat.obs_names_make_unique() chunks = _create_chunks_for_shuffling( @@ -606,7 +610,7 @@ def _add_to_collection( if self.is_empty: raise ValueError("Store is empty. Please run `DatasetCollection.add` first.") # Check for mismatched keys among the inputs. - _check_for_mismatched_keys(adata_paths) + _check_for_mismatched_keys(adata_paths, load_adata=load_adata) adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) # Check for mismatched keys between datasets and the inputs. 
diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 40a6db26..db5833c2 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal import anndata as ad +import h5py import numpy as np import pandas as pd import pytest @@ -37,7 +38,7 @@ def test_write_sharded_shard_size_too_big(tmp_path: Path, chunk_size: int, expec @pytest.mark.parametrize("elem_name", ["obsm", "layers", "raw", "obs"]) -def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "layers", "raw"], tmp_path: Path): +def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", "layers", "raw"], tmp_path: Path): adata_1 = ad.AnnData(X=np.random.randn(10, 20)) extra_args = { elem_name: {"arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else ["a"] * 10} @@ -58,6 +59,25 @@ def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "l ) +def test_store_creation_no_warnings_with_custom_load(tmp_path: Path): + adata_1 = ad.AnnData(X=np.random.randn(10, 20)) + adata_2 = ad.AnnData(X=np.random.randn(10, 20), layers={"arr": np.random.randn(10, 20)}) + path_1 = tmp_path / "just_x.h5ad" + path_2 = tmp_path / "with_extra_key.h5ad" + adata_1.write_h5ad(path_1) + adata_2.write_h5ad(path_2) + collection = DatasetCollection(tmp_path / "collection.zarr").add( + [path_1, path_2], + zarr_sparse_chunk_size=10, + zarr_sparse_shard_size=20, + zarr_dense_chunk_size=5, + zarr_dense_shard_size=10, + n_obs_per_dataset=10, + load_adata=lambda x: ad.AnnData(X=ad.io.read_elem(h5py.File(x)["X"])), + ) + assert len(ad.read_zarr(next(iter(collection))).layers.keys()) == 0 + + def test_store_creation_path_added_to_obs(tmp_path: Path): adata_1 = ad.AnnData(X=np.random.randn(10, 20)) adata_2 = adata_1.copy() From 6eee41eecf65d28cfdc55bc1506d08d3aed39cd0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 Jan 2026 17:47:07 +0100 Subject: [PATCH 16/39] fix: 
remove now-useless test --- tests/test_preshuffle.py | 76 ++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 50 deletions(-) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index db5833c2..9f4b5683 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -140,22 +140,6 @@ def test_store_addition_different_keys( ) -def _read_lazy_x_and_obs_only(path) -> ad.AnnData: - adata_ = ad.experimental.read_lazy(path) - if adata_.raw is not None: - x = adata_.raw.X - var = adata_.raw.var - else: - x = adata_.X - var = adata_.var - - return ad.AnnData( - X=x, - obs=adata_.obs.to_memory(), - var=var.to_memory(), - ) - - def test_store_creation_default( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], ): @@ -177,29 +161,6 @@ def test_store_creation_default( ) -def test_store_creation_drop_elem( - adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], -): - var_subset = [f"gene_{i}" for i in range(100)] - h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems.zarr" - output_path.mkdir(parents=True, exist_ok=True) - - collection = DatasetCollection(output_path).add( - [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - var_subset=var_subset, - zarr_sparse_chunk_size=10, - zarr_sparse_shard_size=20, - zarr_dense_chunk_size=10, - zarr_dense_shard_size=20, - n_obs_per_dataset=60, - load_adata=_read_lazy_x_and_obs_only, - ) - adata_output = ad.io.read_elem(next(iter(collection))) - assert "arr" not in adata_output.obsm - assert adata_output.raw is None - - @pytest.mark.parametrize("shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no_shuffle")]) def test_store_creation( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], @@ -262,6 +223,22 @@ def test_store_creation( assert z["X"]["indices"].chunks[0] == 10 +def 
_read_lazy_x_and_obs_only_from_raw(path) -> ad.AnnData: + adata_ = ad.experimental.read_lazy(path) + if adata_.raw is not None: + x = adata_.raw.X + var = adata_.raw.var + else: + x = adata_.X + var = adata_.var + + return ad.AnnData( + X=x, + obs=adata_.obs.to_memory(), + var=var.to_memory(), + ) + + @pytest.mark.parametrize( "adata_with_h5_path_different_var_space", [{"all_adatas_have_raw": False}], @@ -273,17 +250,16 @@ def test_mismatched_raw_concat( h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] - with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - collection = DatasetCollection(output_path).add( - h5_paths, - zarr_sparse_chunk_size=10, - zarr_sparse_shard_size=20, - zarr_dense_chunk_size=10, - zarr_dense_shard_size=20, - n_obs_per_dataset=60, - load_adata=_read_lazy_x_and_obs_only, - shuffle=False, # don't shuffle -> want to check if the right attributes get taken - ) + collection = DatasetCollection(output_path).add( + h5_paths, + zarr_sparse_chunk_size=10, + zarr_sparse_shard_size=20, + zarr_dense_chunk_size=10, + zarr_dense_shard_size=20, + n_obs_per_dataset=60, + shuffle=False, # don't shuffle -> want to check if the right attributes get taken + load_adata=_read_lazy_x_and_obs_only_from_raw, + ) adatas_orig = [] for file in h5_paths: From 88733903913e059a2fa7ec537b7a2d7be295c2e9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 13:51:55 +0100 Subject: [PATCH 17/39] fix: handle dataframes in obs + varm --- src/annbatch/io.py | 23 +++++++++++++++-------- tests/conftest.py | 11 ++++++----- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index e60813e2..5706a478 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ 
-246,14 +246,21 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: del adata.raw adata.raw = adata_raw - for k, elem in adata.obsm.items(): - # TODO: handle `Dataset2D` in `obsm` and `varm` that are - if isinstance(elem, DaskArray): - adata.obsm[k] = _compute_blockwise(elem) - - for k, elem in adata.layers.items(): - if isinstance(elem, DaskArray): - adata.obsm[k] = _compute_blockwise(elem) + for axis_name in ["layers", "obsm", "varm", "obsp", "varp"]: + for k, elem in getattr(adata, axis_name).items(): + # TODO: handle `Dataset2D` in `obsm` and `varm` that are + if isinstance(elem, DaskArray): + getattr(adata, axis_name)[k] = _compute_blockwise(elem) + if isinstance(elem, Dataset2D): + elem = elem.to_memory() + if "_index" in elem.columns: + del elem["_index"] + # TODO: Bug in anndata + if "obs" in axis_name: + elem.index = adata.obs_names + getattr(adata, axis_name)[k] = elem + + return adata.to_memory() return adata diff --git a/tests/conftest.py b/tests/conftest.py index 73acaff9..d5acac47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,6 +78,8 @@ def adata_with_h5_path_different_var_space( n_cells = [random.randint(50, 100) for _ in range(n_adatas)] adatas = [] for i, (m, n) in enumerate(zip(n_cells, n_features, strict=True)): + var_idx = [f"gene_{gene}" for gene in range(n // 2)] + [f"gene_{gene}_{i}" for gene in range(n // 2, n)] + obs_idx = np.arange(m).astype(str) adata = ad.AnnData( X=sparse_random(m, n, density=0.1, format="csr", dtype="f4"), obs=pd.DataFrame( @@ -86,12 +88,11 @@ def adata_with_h5_path_different_var_space( "store_id": [i] * m, "numeric": np.arange(m), }, - index=np.arange(m).astype(str), + index=obs_idx, ), - var=pd.DataFrame( - index=[f"gene_{gene}" for gene in range(n // 2)] + [f"gene_{gene}_{i}" for gene in range(n // 2, n)] - ), - obsm={"arr": np.random.randn(m, 10)}, + var=pd.DataFrame(index=var_idx), + obsm={"arr": np.random.randn(m, 10), "df": pd.DataFrame({"numeric": np.arange(m)}, index=obs_idx)}, + 
varm={"arr": np.random.randn(n, 10), "df": pd.DataFrame({"numeric": np.arange(n)}, index=var_idx)}, ) if all_adatas_have_raw or (i % 2 == 0): adata_raw = adata[:, adata.var.index[: (n // 2)]].copy() From 471690c1626c46ef4309c68f0a5652e414397bdc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:02:04 +0100 Subject: [PATCH 18/39] fix: try bumping python --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index c3f3f96f..4acb793c 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-24.04 tools: - python: "3.12" + python: "3.14" jobs: create_environment: - asdf plugin add uv From 578a81183432085b0ed3efef25c3349f68e290ec Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:09:44 +0100 Subject: [PATCH 19/39] fix: bound sphinx --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9a8e219e..db74bebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ optional-dependencies.doc = [ "myst-nb>=1.1", "pandas", "scanpydoc[theme,typehints]>=0.15.3", - "sphinx>=8.1", + # https://github.com/sphinx-toolbox/sphinx-toolbox/issues/201 + "sphinx>=8.1,<=8.2.3", "sphinx-autodoc-typehints", "sphinx-book-theme>=1", "sphinx-copybutton", From 5c6b670a02a39980530b628c40685da11739a028 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:27:09 +0100 Subject: [PATCH 20/39] fix: docs --- src/annbatch/io.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 5706a478..442fc9d2 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -278,7 +278,7 @@ def wrapper(*args, **kwargs): class DatasetCollection[T: (h5py.Group, zarr.Group)]: - """A preshuffled collection object including functionality for creating, adding to, and loading collections.""" + """A preshuffled collection object 
including functionality for creating, adding to, and loading collections shuffled by `annbatch`.""" _group: T @@ -350,15 +350,17 @@ def add( A key `src_path` is added to `obs` to indicate where individual row came from. We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_slice_size` as so many rows will be read into memory before writing to disk. + After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. + This is not a stable API but only for internal purposes at the moment. Parameters ---------- adata_paths Paths to the AnnData files used to create the zarr store. load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. + Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. + The input to the function is a path to an anndata file, and the output is an :class:`anndata.AnnData` object. var_subset Subset of gene names to include in the store. If None, all genes are included. Genes are subset based on the `var_names` attribute of the concatenated AnnData object. 
From b70f8d6b845b20cbc54c1bade6da4ac14c02e3e9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:39:46 +0100 Subject: [PATCH 21/39] fix: on-disk encoding --- src/annbatch/io.py | 8 +++++--- tests/test_preshuffle.py | 25 ++++++++++++++++++------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 442fc9d2..71a1d7d5 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -27,6 +27,8 @@ from zarr.abc.codec import BytesBytesCodec +V1_ENCODING = {"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"} + def _round_down(num: int, divisor: int): return num - (num % divisor) @@ -318,7 +320,7 @@ def __iter__(self) -> Generator[T]: @property def is_empty(self) -> bool: """Wether or not there is an existing store at the group location.""" - return "annbatch-shuffled" not in self._group.attrs or ( + return not (V1_ENCODING.items() <= self._group.attrs.items()) or ( not self._group.attrs["annbatch-shuffled"] and len(self._dataset_keys) == 0 ) @@ -553,9 +555,9 @@ def _create_collection( self._group, f"{DATASET_PREFIX}_{i}", adata_chunk, dataset_kwargs={"compression": h5ad_compressor} ) if isinstance(self._group, zarr.Group): - self._group.update_attributes({"annbatch-shuffled": True}) + self._group.update_attributes(V1_ENCODING) else: - self._group.attrs["annbatch-shuffled"] = True + self._group.attrs.update(V1_ENCODING) def _add_to_collection( self, diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 9f4b5683..dde59499 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -12,6 +12,7 @@ import zarr from annbatch import DatasetCollection, write_sharded +from annbatch.io import V1_ENCODING if TYPE_CHECKING: from collections.abc import Callable @@ -140,13 +141,19 @@ def test_store_addition_different_keys( ) +@pytest.mark.parametrize("open_store", [h5py.File, zarr.open_group]) def test_store_creation_default( adata_with_h5_path_different_var_space: 
tuple[ad.AnnData, Path], + open_store: Callable[[Path], zarr.Group | h5py.Group], ): var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default.zarr" - collection = DatasetCollection(output_path).add( + output_path = ( + adata_with_h5_path_different_var_space[1].parent + / f"zarr_store_creation_test_default.{'h5ad' if open_store is h5py.File else 'zarr'}" + ) + store = open_store(output_path, mode="w") + collection = DatasetCollection(store).add( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -156,9 +163,13 @@ def test_store_creation_default( n_obs_per_dataset=60, ) assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) - assert sorted(glob.glob(str(output_path / "dataset_*"))) == sorted( - str(p) for p in (output_path).iterdir() if p.is_dir() - ) + # Test directory structure to make sure nothing extraneous was written + if isinstance(store, zarr.Group): + assert sorted(glob.glob(str(output_path / "dataset_*"))) == sorted( + str(p) for p in (output_path).iterdir() if p.is_dir() + ) + assert list(iter(collection)) == [store[k] for k in sorted(store.keys())] + assert V1_ENCODING.items() <= store.attrs.items() @pytest.mark.parametrize("shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no_shuffle")]) @@ -180,7 +191,7 @@ def test_store_creation( shuffle=shuffle, ) assert not DatasetCollection(output_path).is_empty - assert zarr.open(output_path).attrs["annbatch-shuffled"] + assert V1_ENCODING.items() <= zarr.open(output_path).attrs.items() adata_orig = adata_with_h5_path_different_var_space[0] # make sure all category dtypes match @@ -334,4 +345,4 @@ def test_empty(tmp_path: Path): # Doesn't matter what errors as long as this function runs, but not to completion with 
pytest.raises(TypeError): collection.add() - assert "annbatch-shuffled" not in g.attrs + assert not (V1_ENCODING.items() <= g.attrs.items()) From 52880568884f979dd68f2232940b7ee0fcec237d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:44:03 +0100 Subject: [PATCH 22/39] fix: more docs fixes --- src/annbatch/io.py | 52 ++++++------------------------------------ src/annbatch/loader.py | 9 ++++---- 2 files changed, 12 insertions(+), 49 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 71a1d7d5..54774642 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -343,14 +343,14 @@ def add( shuffle_slice_size: int = 1000, shuffle: bool = True, ) -> Self: - """Take AnnData paths, create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per store if running for the first time). + """Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per store if running for the first time). The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. However, this function can also output h5 datasets and also unshuffled datasets as well. The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. + We highly recommend making your indexes unique across files, and this function will call :meth:`AnnData.obs_names_make_unique`. 
Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_slice_size` as so many rows will be read into memory before writing to disk. After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. This is not a stable API but only for internal purposes at the moment. @@ -402,7 +402,6 @@ def add( ... obs=adata.obs.to_memory(), ... var=adata.var.to_memory(), ...) - >>> datasets = [ ... "path/to/first_adata.h5ad", ... "path/to/second_adata.h5ad", @@ -456,7 +455,7 @@ def _create_collection( However, this function can also output h5 datasets and also unshuffled datasets as well. The var space is by default outer-joined, but can be subsetted by `var_subset`. A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. + We highly recommend making your indexes unique across files, and this function will call :meth:`AnnData.obs_names_make_unique`. Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. Parameters @@ -464,9 +463,9 @@ def _create_collection( adata_paths Paths to the AnnData files used to create the zarr store. load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. + Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. 
+ The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a :class:`dask.array.Array`. var_subset Subset of gene names to include in the store. If None, all genes are included. Genes are subset based on the `var_names` attribute of the concatenated AnnData object. @@ -493,29 +492,6 @@ def _create_collection( shuffle_slice_size How many contiguous rows to load into memory before shuffling at once. `(shuffle_slice_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_slice_size`. - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import DatasetCollection - # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store - >>> def read_lazy_x_and_obs_only(path): - ... adata = ad.experimental.read_lazy(path) - ... return ad.AnnData( - ... X=adata.X, - ... obs=adata.obs.to_memory(), - ... var=adata.var.to_memory(), - ...) - - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> DatasetCollection("path/to/output/zarr_store.zarr").add( - ... datasets, - ... load_adata=read_lazy_x_and_obs_only, - ...) """ if not self.is_empty: raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") @@ -582,10 +558,10 @@ def _add_to_collection( adata_paths Paths to the anndata files to be appended to the collection of output chunks. load_adata - Function to customize loading the invidiual input anndata files. By default, {func}`anndata.read_h5ad` is used. + Function to customize loading the invidiual input anndata files. By default, :func:`anndata.read_h5ad` is used. If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. The input to the function is a path to an anndata file, and the output is an anndata object. 
- If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead. + If the input data is too large to fit into memory, you should use :func:`annndata.experimental.read_lazy` instead. zarr_sparse_chunk_size Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. zarr_sparse_shard_size @@ -603,20 +579,6 @@ def _add_to_collection( How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. shuffle Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import DatasetCollection - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> DatasetCollection("path/to/existing/preshuffled_collection.zarr").add( - ... datasets, - ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory - ...) """ if self.is_empty: raise ValueError("Store is empty. Please run `DatasetCollection.add` first.") diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 11000df1..9a17e75b 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -228,14 +228,15 @@ def n_var(self) -> int: def add_collection( self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs ): - """Load from an existing {class}`annbatch.DatasetCollection` + """Load from an existing :class:`annbatch.DatasetCollection` Parameters ---------- collection - _description_ - load_adata, optional - _description_, by default load_x_and_obs + The collection who on-disk datasets should be used in this loader. + load_adata + A custom load function - recall that whatever is found in :attr:`AnnData.X` and :attr:`AnnData.obs` will be yielded in batches. + Default is to just load `X` and `obs`. 
Returns ------- From 82a4330f5d4325ba5c16df785304fb03ffe5fd8c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:44:55 +0100 Subject: [PATCH 23/39] fix: `is_empty` --- src/annbatch/io.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 54774642..b9bb76a3 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -320,9 +320,7 @@ def __iter__(self) -> Generator[T]: @property def is_empty(self) -> bool: """Wether or not there is an existing store at the group location.""" - return not (V1_ENCODING.items() <= self._group.attrs.items()) or ( - not self._group.attrs["annbatch-shuffled"] and len(self._dataset_keys) == 0 - ) + return not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0 @_with_settings def add( From 00c19e01d773f98ccd3ef8312b4d254821318c77 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:45:39 +0100 Subject: [PATCH 24/39] fix: no doc string for `obs_names_make_unique` --- src/annbatch/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index b9bb76a3..91b5f658 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -348,7 +348,7 @@ def add( However, this function can also output h5 datasets and also unshuffled datasets as well. The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call :meth:`AnnData.obs_names_make_unique`. + We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. 
Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_slice_size` as so many rows will be read into memory before writing to disk. After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. This is not a stable API but only for internal purposes at the moment. @@ -453,7 +453,7 @@ def _create_collection( However, this function can also output h5 datasets and also unshuffled datasets as well. The var space is by default outer-joined, but can be subsetted by `var_subset`. A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call :meth:`AnnData.obs_names_make_unique`. + We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. Parameters From 21fac275e9a312c4506437f88ea2dc7c6ff7d545 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 14:55:02 +0100 Subject: [PATCH 25/39] fix: intersphinx --- src/annbatch/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 9a17e75b..28a3b5b2 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -235,7 +235,7 @@ def add_collection( collection The collection who on-disk datasets should be used in this loader. load_adata - A custom load function - recall that whatever is found in :attr:`AnnData.X` and :attr:`AnnData.obs` will be yielded in batches. + A custom load function - recall that whatever is found in :attr:`~anndata.AnnData.X` and :attr:`~anndata.AnnData.obs` will be yielded in batches. Default is to just load `X` and `obs`. 
Returns From 7b882b504e7484135107d280ecec9e049dbb5b9f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 15:07:15 +0100 Subject: [PATCH 26/39] fix: more docs --- README.md | 9 +++++++-- docs/index.md | 14 +++----------- src/annbatch/loader.py | 6 +----- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 1b26dbc5..dd9f736e 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,8 @@ from pathlib import Path zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) -# a directory containing `dataset_{i}.zarr` + +# Create a collection at the given path. The subgroups will all be anndata stores. collection = DatasetCollection("path/to/output/collection.zarr") collection.add( adata_paths=[ @@ -116,7 +117,11 @@ with ad.settings.override(remove_unused_categories=False): batch_size=4096, chunk_size=32, preload_nchunks=256, - ).add_collection(collection) # Automatically uses the on-disk `X` and full `obs` in the `Loader` but the `load_adata` arg can override this behavior (see `custom_load_func` above) + ) + # `add_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` + # but the `load_adata` arg can override this behavior + # (see `custom_load_func` above for an example of customization). + ds = ds.add_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/docs/index.md b/docs/index.md index 8207e2af..6c5e8208 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Let's go through the above example: ### Preprocessing ```python -DatasetCollection("path/to/output/store.zarr").add( +colleciton = DatasetCollection("path/to/output/store.zarr").add( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" @@ -32,20 +32,12 @@ See the [zarr docs on sharding][] for more information. #### Chunked access ```python +# `add_collection` will automatically get everything in `X` and `obs` and yield it. 
ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, -).add_anndatas( - [ - ad.AnnData( - # note that you can open an anndata file using any type of zarr store - X=ad.io.sparse_dataset(zarr.open(p)["X"]), - obs=ad.io.read_elem(zarr.open(p)["obs"]), - ) - for p in PATH_TO_STORE.glob("*.zarr") - ] -) +).add_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 28a3b5b2..b12532bd 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -227,7 +227,7 @@ def n_var(self) -> int: def add_collection( self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs - ): + ) -> Self: """Load from an existing :class:`annbatch.DatasetCollection` Parameters @@ -237,10 +237,6 @@ def add_collection( load_adata A custom load function - recall that whatever is found in :attr:`~anndata.AnnData.X` and :attr:`~anndata.AnnData.obs` will be yielded in batches. Default is to just load `X` and `obs`. - - Returns - ------- - _description_ """ if collection.is_empty: raise ValueError("DatasetCollection is empty") From ee912fd240b10513f013be17f69dbd3b66a9d600 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 16:01:33 +0100 Subject: [PATCH 27/39] fix: api changes according to felix --- README.md | 6 +- docs/index.md | 6 +- docs/notebooks/example.ipynb | 752 +++++++++++++++++------------------ src/annbatch/io.py | 4 +- src/annbatch/loader.py | 14 +- tests/conftest.py | 2 +- tests/test_dataset.py | 2 +- tests/test_preshuffle.py | 22 +- 8 files changed, 408 insertions(+), 400 deletions(-) diff --git a/README.md b/README.md index dd9f736e..391303c1 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ zarr.config.set( # Create a collection at the given path. The subgroups will all be anndata stores. 
collection = DatasetCollection("path/to/output/collection.zarr") -collection.add( +collection.add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" @@ -118,10 +118,10 @@ with ad.settings.override(remove_unused_categories=False): chunk_size=32, preload_nchunks=256, ) - # `add_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` + # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior # (see `custom_load_func` above for an example of customization). - ds = ds.add_collection(collection) + ds = ds.use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/docs/index.md b/docs/index.md index 6c5e8208..ee945a50 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Let's go through the above example: ### Preprocessing ```python -colleciton = DatasetCollection("path/to/output/store.zarr").add( +colleciton = DatasetCollection("path/to/output/store.zarr").add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" @@ -32,12 +32,12 @@ See the [zarr docs on sharding][] for more information. #### Chunked access ```python -# `add_collection` will automatically get everything in `X` and `obs` and yield it. +# `use_collection` will automatically get everything in `X` and `obs` and yield it. 
ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, -).add_collection(collection) +).use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 15a6675c..d517c9ed 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -1,379 +1,379 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Quickstart `annbatch`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook will walk you through the following steps:\n", - "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", - "2. How to load the converted collection using `annbatch`\n", - "3. Extend an existing collection with new `anndata` datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# !pip install annbatch[zarrs, torch]" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "zsh:1: command not found: wget\n", - "zsh:1: command not found: wget\n" - ] - } - ], - "source": [ - "# Download two example datasets from CELLxGENE\n", - "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT**: Configure zarrs\n", - "\n", - "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "tags": [ - 
"hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import zarr\n", - "\n", - "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "# Suppress zarr vlen-utf8 codec warnings\n", - "warnings.filterwarnings(\n", - " \"ignore\",\n", - " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", - " category=UserWarning,\n", - " module=\"zarr.codecs.vlen_utf8\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Converting existing `anndata` files into a shuffled collection" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The conversion code will take care of the following things:\n", - "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", - " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", - " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", - "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", - " * This is important for block-wise shuffling during data loading.\n", - "* Shuffle the input files across multiple output datasets:\n", - " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", - " * We recommend to choose a dataset size that comfortably fits into system memory.\n", - "\n", - "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - 
"metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.03it/s]\n", - "loading: 2it [00:00, 2.16it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import zarr\n", + "\n", + "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "# Suppress zarr vlen-utf8 codec warnings\n", + "warnings.filterwarnings(\n", + " \"ignore\",\n", + " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", + " category=UserWarning,\n", + " module=\"zarr.codecs.vlen_utf8\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting existing `anndata` files into a shuffled collection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The conversion code will take care of the following things:\n", + "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", + " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", + " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", + "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", + " * This is important for block-wise shuffling during data loading.\n", + "* Shuffle the input files across multiple output datasets:\n", + " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", + " * We recommend to choose a dataset size that comfortably fits into system memory.\n", + 
"\n", + "\n", + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.03it/s]\n", + "loading: 2it [00:00, 2.16it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "from annbatch import DatasetCollection\n", + "\n", + "\n", + "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", + "# To have a store that only contains raw counts, we can write the following load_adata function\n", + "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", + " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", + " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", + " adata_ = ad.experimental.read_lazy(path)\n", + " if adata_.raw is not None:\n", + " x = adata_.raw.X\n", + " var = adata_.raw.var\n", + " else:\n", + " x = adata_.X\n", + " var = adata_.var\n", + "\n", + " return ad.AnnData(\n", + " X=x,\n", + " obs=adata_.obs.to_memory()[\n", + " [\"cell_type\"]\n", + " ], # let's only take cell type since it is shared - otherwise DatasetCollection will warn about all the columns we are missing\n", + " var=var.to_memory(),\n", + " )\n", + "\n", + "\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\"))\n", + "collection.add_adatas(\n", + " # List all the h5ad files you want to include in the collection\n", + " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", + " # Path to store the output collection\n", + " 
shuffle=True, # Whether to pre-shuffle the cells of the collection\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", + " var_subset=None, # Optionally subset the collection to a specific gene space\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data loading example" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "\n", + "from annbatch import Loader\n", + "\n", + "ds = Loader(\n", + " batch_size=4096, # Total number of obs per yielded batch\n", + " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", + " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", + " preload_to_gpu=False,\n", + " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", + " to_torch=True,\n", + ")\n", + "\n", + "# Add in the shuffled data that should be used for training\n", + "ds.use_collection(collection)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT:**\n", + "* The `Loader` yields batches of sparse tensors.\n", + "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", + " * First call `.cuda()` and then `.to_dense()`\n", + " * E.g. 
`x = x.cuda().to_dense()`\n", + " * This is significantly faster than doing the dense conversion on the CPU.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 42/171792 [00:08<9:13:24, 5.17it/s]\n" + ] + } + ], + "source": [ + "# Iterate over dataloader\n", + "import tqdm\n", + "\n", + "for batch in tqdm.tqdm(ds):\n", + " x, obs = batch[\"data\"], batch[\"labels\"][\"cell_type\"]\n", + " # Important: Convert to dense on GPU\n", + " x = x.cuda().to_dense()\n", + " # Feed data into your model\n", + " ..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optional: Extend an existing collection with a new dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You might want to extend an existing pre-shuffled collection with a new dataset.\n", + "This can be done using the `add` method again.\n", + "\n", + "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]\n", + "loading: 1it [00:00, 1.73it/s]\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 14.22it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.add_adatas(\n", + " adata_paths=[\n", + " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", + " ],\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "from annbatch import DatasetCollection\n", - "\n", - "\n", - "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", - "# To have a store that only contains raw counts, we can write the following load_adata function\n", - "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", - " adata_ = ad.experimental.read_lazy(path)\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", 
- " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(\n", - " X=x,\n", - " obs=adata_.obs.to_memory()[\n", - " [\"cell_type\"]\n", - " ], # let's only take cell type since it is shared - otherwise DatasetCollection will warn about all the columns we are missing\n", - " var=var.to_memory(),\n", - " )\n", - "\n", - "\n", - "collection = DatasetCollection(zarr.open(\"annbatch_collection\"))\n", - "collection.add(\n", - " # List all the h5ad files you want to include in the collection\n", - " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", - " # Path to store the output collection\n", - " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", - " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data loading example" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "\n", - "from annbatch import Loader\n", - "\n", - "ds = Loader(\n", - " batch_size=4096, # Total number of obs per yielded batch\n", - " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", - " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", - " preload_to_gpu=False,\n", - " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", - " to_torch=True,\n", - ")\n", - "\n", - "# Add in the shuffled 
data that should be used for training\n", - "ds.add_collection(collection)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT:**\n", - "* The `Loader` yields batches of sparse tensors.\n", - "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", - " * First call `.cuda()` and then `.to_dense()`\n", - " * E.g. `x = x.cuda().to_dense()`\n", - " * This is significantly faster than doing the dense conversion on the CPU.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 42/171792 [00:08<9:13:24, 5.17it/s]\n" - ] - } - ], - "source": [ - "# Iterate over dataloader\n", - "import tqdm\n", - "\n", - "for batch in tqdm.tqdm(ds):\n", - " x, obs = batch[\"data\"], batch[\"labels\"][\"cell_type\"]\n", - " # Important: Convert to dense on GPU\n", - " x = x.cuda().to_dense()\n", - " # Feed data into your model\n", - " ..." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optional: Extend an existing collection with a new dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You might want to extend an existing pre-shuffled collection with a new dataset.\n", - "This can be done using the `add` method again.\n", - "\n", - "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
- ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]\n", - "loading: 1it [00:00, 1.73it/s]\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 14.22it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "collection.add(\n", - " adata_paths=[\n", - " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", - " ],\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 91b5f658..a66c6236 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -323,7 +323,7 @@ def is_empty(self) -> bool: return not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0 @_with_settings - def add( + def add_adatas( self, adata_paths: Iterable[PathLike[str]] | Iterable[str], *, @@ -405,7 +405,7 @@ def add( ... "path/to/second_adata.h5ad", ... "path/to/third_adata.h5ad", ... ] - >>> DatasetCollection("path/to/output/zarr_store.zarr").add( + >>> DatasetCollection("path/to/output/zarr_store.zarr").add_adatas( ... datasets, ... load_adata=read_lazy_x_and_obs_only, ...) 
diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index b12532bd..d3014421 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -225,10 +225,12 @@ def n_var(self) -> int: """ return self._shapes[0][1] - def add_collection( + def use_collection( self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs ) -> Self: - """Load from an existing :class:`annbatch.DatasetCollection` + """Load from an existing :class:`annbatch.DatasetCollection`. + + This function can only be called once. If you want to manually add more data, use :meth:`Loader.add_anndatas` or open an issue. Parameters ---------- @@ -240,8 +242,14 @@ def add_collection( """ if collection.is_empty: raise ValueError("DatasetCollection is empty") + if getattr(self, "_collection_added", False): + raise RuntimeError( + "You should not add multiple collections, independently shuffled - please preshuffle multiple collections, use `add_anndatas` manually if you know what you are doing, or open an issue if you believe that this should be supported at an API level higher than `add_anndatas`." 
+ ) adatas = [load_adata(g) for g in collection] - return self.add_anndatas(adatas) + self.add_anndatas(adatas) + self._collection_added = True + return self def add_anndatas( self, diff --git a/tests/conftest.py b/tests/conftest.py index d5acac47..b2a41fce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,7 +112,7 @@ def simple_collection( ) -> tuple[DatasetCollection, ad.AnnData]: zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "simple_fixture.zarr" - collection = DatasetCollection(output_path).add( + collection = DatasetCollection(output_path).add_adatas( zarr_stores, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index bebbe3c7..cb3aa591 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -99,7 +99,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: batch_size=batch_size, preload_to_gpu=preload_to_gpu, to_torch=False, - ).add_collection( + ).use_collection( collection, **( {"load_adata": lambda group: open_func(group, use_zarrs=use_zarrs, use_anndata=True)} diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index dde59499..22817ba4 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -50,7 +50,7 @@ def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - DatasetCollection(tmp_path / "collection.zarr").add( + DatasetCollection(tmp_path / "collection.zarr").add_adatas( [path_1, path_2], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -67,7 +67,7 @@ def test_store_creation_no_warnings_with_custom_load(tmp_path: Path): path_2 = tmp_path / "with_extra_key.h5ad" adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) - collection = 
DatasetCollection(tmp_path / "collection.zarr").add( + collection = DatasetCollection(tmp_path / "collection.zarr").add_adatas( [path_1, path_2], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -88,7 +88,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): adata_2.write_h5ad(path_2) paths = [path_1, path_2] output_dir = tmp_path / "path_src_collection.zarr" - collection = DatasetCollection(output_dir).add( + collection = DatasetCollection(output_dir).add_adatas( paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -116,7 +116,7 @@ def test_store_addition_different_keys( adata_orig.write_h5ad(orig_path) output_path = tmp_path / "zarr_store_addition_different_keys.zarr" collection = DatasetCollection(output_path) - collection.add( + collection.add_adatas( [orig_path], zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -131,7 +131,7 @@ def test_store_addition_different_keys( additional_path = tmp_path / "with_extra_key.h5ad" adata.write_h5ad(additional_path) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - collection.add( + collection.add_adatas( [additional_path], load_adata=load_adata, zarr_sparse_chunk_size=10, @@ -153,7 +153,7 @@ def test_store_creation_default( / f"zarr_store_creation_test_default.{'h5ad' if open_store is h5py.File else 'zarr'}" ) store = open_store(output_path, mode="w") - collection = DatasetCollection(store).add( + collection = DatasetCollection(store).add_adatas( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -180,7 +180,7 @@ def test_store_creation( var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}.zarr" - collection = DatasetCollection(output_path).add( + collection = 
DatasetCollection(output_path).add_adatas( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], var_subset=var_subset, zarr_sparse_chunk_size=10, @@ -261,7 +261,7 @@ def test_mismatched_raw_concat( h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] - collection = DatasetCollection(output_path).add( + collection = DatasetCollection(output_path).add_adatas( h5_paths, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -305,7 +305,7 @@ def test_store_extension( additional = all_h5_paths[4:] # don't add everything to get a "different" var space # create new store collection = DatasetCollection(store_path) - collection.add( + collection.add_adatas( original, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, @@ -315,7 +315,7 @@ def test_store_extension( shuffle=True, ) # add h5ads to existing store - collection.add( + collection.add_adatas( additional, load_adata=load_adata, zarr_sparse_chunk_size=10, @@ -344,5 +344,5 @@ def test_empty(tmp_path: Path): assert collection.is_empty # Doesn't matter what errors as long as this function runs, but not to completion with pytest.raises(TypeError): - collection.add() + collection.add_adatas() assert not (V1_ENCODING.items() <= g.attrs.items()) From 23aeaf9f22d3da3dc6bfd9bbf9c7a7c82c33ef1a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 16:04:17 +0100 Subject: [PATCH 28/39] fix: add test --- tests/test_dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index cb3aa591..e30a38cf 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -231,6 +231,13 @@ def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, ds.add_dataset(**data) +def 
test_use_collection_twice(simple_collection: tuple[ad.AnnData, DatasetCollection]): + ds = Loader() + ds = ds.use_collection(simple_collection[1]) + with pytest.raises(RuntimeError, match="You should not add multiple collections"): + ds.use_collection(simple_collection[1]) + + @pytest.mark.skipif(not find_spec("torch"), reason="need torch installed") @pytest.mark.parametrize( "preload_to_gpu", From 03024164aa9fc64e29528c6c99f8ec53df13e5be Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 17:32:49 +0100 Subject: [PATCH 29/39] fix: `_create_chunks_for_shuffling` bug --- src/annbatch/io.py | 12 +++++++----- tests/conftest.py | 1 + tests/test_preshuffle.py | 5 +++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index a66c6236..53d43035 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -20,6 +20,8 @@ from tqdm.auto import tqdm from zarr.codecs import BloscCodec, BloscShuffle +from annbatch.utils import split_given_size + if TYPE_CHECKING: from collections.abc import Callable, Generator, Iterable, Mapping from os import PathLike @@ -63,6 +65,8 @@ def write_sharded( Number of obs elements per dense shard along the first axis compressors The compressors to pass to `zarr`. + key + The key to which this object should be written - by default the root, in which case the *entire* store (not just the group) is cleared first. 
""" ad.settings.zarr_write_format = 3 @@ -195,14 +199,12 @@ def _lazy_load_anndatas( def _create_chunks_for_shuffling( n_obs: int, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle_slice_size: int = 1000, shuffle: bool = True -): +) -> list[np.ndarray]: # this splits the array up into `shuffle_slice_size` contiguous runs - idxs = np.array_split(np.arange(n_obs), np.ceil(n_obs / shuffle_slice_size)) + idxs = split_given_size(np.arange(n_obs), shuffle_slice_size) if shuffle: random.shuffle(idxs) - idxs = np.concatenate(idxs) - idxs = np.array_split(idxs, np.ceil(n_obs / shuffle_n_obs_per_dataset)) - return idxs + return [idx.ravel() for idx in split_given_size(idxs, max(1, shuffle_n_obs_per_dataset // shuffle_slice_size))] def _compute_blockwise(x: DaskArray) -> sp.spmatrix: diff --git a/tests/conftest.py b/tests/conftest.py index b2a41fce..e5538c71 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,5 +119,6 @@ def simple_collection( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=60, + shuffle_slice_size=10, ) return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 22817ba4..95383274 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -214,7 +214,12 @@ def test_store_creation( ) assert "arr" in adata.obsm if shuffle: + # If it's shuffled I'd expect more than 80% of elements to be out of order + assert sum(adata_orig.obs_names != adata.obs_names) > (0.8 * adata.shape[0]) + assert adata_orig.obs_names.isin(adata.obs_names).all() adata = adata[adata_orig.obs_names].copy() + else: + assert (adata_orig.obs_names == adata.obs_names).all() np.testing.assert_array_equal( adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray(), adata_orig.X if isinstance(adata_orig.X, np.ndarray) else adata_orig.X.toarray(), From b25657efbe39acee34b3ef7eb57be11e2984cc0f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 
Jan 2026 18:50:55 +0100 Subject: [PATCH 30/39] fix: actually shuffle chunks --- docs/notebooks/example.ipynb | 58 ++++++++++++++++++++---------------- src/annbatch/io.py | 41 ++++++++++++++++--------- tests/test_preshuffle.py | 18 +++++++++-- 3 files changed, 75 insertions(+), 42 deletions(-) diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index d517c9ed..72a7da6b 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 28, "metadata": { "tags": [ "hide-output" @@ -65,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 1, "metadata": { "tags": [ "hide-output" @@ -75,10 +75,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 21, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -91,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 3, "metadata": { "tags": [ "hide-output" @@ -144,20 +144,24 @@ "name": "stderr", "output_type": "stream", "text": [ - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.03it/s]\n", - "loading: 2it [00:00, 2.16it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" + "" ] }, - "execution_count": 23, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -189,7 +193,7 @@ " )\n", "\n", "\n", - "collection = DatasetCollection(zarr.open(\"annbatch_collection\"))\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\", mode=\"w\"))\n", "collection.add_adatas(\n", " # List all the h5ad files you want to include in the collection\n", " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", @@ -210,7 +214,7 @@ }, { "cell_type": "code", - "execution_count": 24, + 
"execution_count": 4, "metadata": { "tags": [ "hide-output" @@ -220,10 +224,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 24, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -260,7 +264,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 6, "metadata": { "tags": [ "hide-output" @@ -271,7 +275,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 42/171792 [00:08<9:13:24, 5.17it/s]\n" + " 0%| | 42/171792 [00:08<9:35:18, 4.98it/s] \n" ] } ], @@ -306,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 5, "metadata": { "tags": [ "hide-output" @@ -318,22 +322,26 @@ "output_type": "stream", "text": [ "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]\n", - "loading: 1it [00:00, 1.73it/s]\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 14.22it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" + "" ] }, - "execution_count": 27, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 53d43035..fc92d3bb 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -204,7 +204,20 @@ def _create_chunks_for_shuffling( idxs = split_given_size(np.arange(n_obs), shuffle_slice_size) if shuffle: random.shuffle(idxs) - return [idx.ravel() for idx in split_given_size(idxs, max(1, shuffle_n_obs_per_dataset // shuffle_slice_size))] + n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_slice_size) + # In this case `shuffle_n_obs_per_dataset` is bigger than the size of the dataset or the slice size is probably too big. + if n_obs < shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1: + chunks = [np.concatenate(idxs)] + else: + # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. 
+ idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) + chunks = [ + np.concatenate([np.arange(s.start, s.stop) for s in idx]) + for idx in split_given_size(idxs, n_slices_per_dataset) + ] + if sum(len(idx) for idx in chunks) != n_obs or (np.sort(np.concatenate(chunks)) != np.arange(n_obs)).any(): + raise RuntimeError(f"This should not happen, please open an issue, {np.sort(np.concatenate(chunks))}") + return chunks def _compute_blockwise(x: DaskArray) -> sp.spmatrix: @@ -412,6 +425,8 @@ def add_adatas( ... load_adata=read_lazy_x_and_obs_only, ...) """ + if shuffle_slice_size > n_obs_per_dataset: + raise ValueError("Cannot have a large slice size than observations per dataset") shared_kwargs = { "adata_paths": adata_paths, "load_adata": load_adata, @@ -498,6 +513,7 @@ def _create_collection( _check_for_mismatched_keys(adata_paths, load_adata=load_adata) adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) adata_concat.obs_names_make_unique() + n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset) chunks = _create_chunks_for_shuffling( adata_concat.shape[0], n_obs_per_dataset, shuffle_slice_size, shuffle=shuffle ) @@ -521,8 +537,8 @@ def _create_collection( adata_chunk, sparse_chunk_size=zarr_sparse_chunk_size, sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, + dense_chunk_size=min(adata_chunk.shape[0], zarr_dense_chunk_size), + dense_shard_size=min(adata_chunk.shape[0], zarr_dense_shard_size), compressors=zarr_compressor, key=f"{DATASET_PREFIX}_{i}", ) @@ -588,15 +604,12 @@ def _add_to_collection( adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) # Check for mismatched keys between datasets and the inputs. 
_check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) - if isinstance(adata_concat.X, DaskArray): - chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], - np.ceil(len(adata_concat) / len(self._dataset_keys)), - shuffle_slice_size, - shuffle=shuffle, - ) - else: - chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(self._dataset_keys)) + chunks = _create_chunks_for_shuffling( + adata_concat.shape[0], + np.ceil(len(adata_concat) / len(self._dataset_keys)), + shuffle_slice_size, + shuffle=shuffle, + ) adata_concat.obs_names_make_unique() @@ -619,8 +632,8 @@ def _add_to_collection( adata, sparse_chunk_size=zarr_sparse_chunk_size, sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, + dense_chunk_size=min(adata.shape[0], zarr_dense_chunk_size), + dense_shard_size=min(adata.shape[0], zarr_dense_shard_size), compressors=zarr_compressor, key=dataset, ) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 95383274..6129bd3c 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -57,6 +57,7 @@ def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, + shuffle_slice_size=10, ) @@ -74,6 +75,7 @@ def test_store_creation_no_warnings_with_custom_load(tmp_path: Path): zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, + shuffle_slice_size=5, load_adata=lambda x: ad.AnnData(X=ad.io.read_elem(h5py.File(x)["X"])), ) assert len(ad.read_zarr(next(iter(collection))).layers.keys()) == 0 @@ -95,6 +97,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, + shuffle_slice_size=5, shuffle=False, ) adata_result = ad.concat([ad.io.read_elem(g) for g in collection], join="outer") @@ -123,6 +126,7 @@ def 
test_store_addition_different_keys( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=50, + shuffle_slice_size=10, ) extra_args = { elem_name: {"arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else ["a"] * 10} @@ -138,6 +142,8 @@ def test_store_addition_different_keys( zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, + n_obs_per_dataset=50, + shuffle_slice_size=10, ) @@ -160,7 +166,8 @@ def test_store_creation_default( zarr_sparse_shard_size=20, zarr_dense_chunk_size=10, zarr_dense_shard_size=20, - n_obs_per_dataset=60, + n_obs_per_dataset=50, + shuffle_slice_size=10, ) assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) # Test directory structure to make sure nothing extraneous was written @@ -187,7 +194,8 @@ def test_store_creation( zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, - n_obs_per_dataset=60, + n_obs_per_dataset=50, + shuffle_slice_size=10, shuffle=shuffle, ) assert not DatasetCollection(output_path).is_empty @@ -272,7 +280,8 @@ def test_mismatched_raw_concat( zarr_sparse_shard_size=20, zarr_dense_chunk_size=10, zarr_dense_shard_size=20, - n_obs_per_dataset=60, + n_obs_per_dataset=50, + shuffle_slice_size=10, shuffle=False, # don't shuffle -> want to check if the right attributes get taken load_adata=_read_lazy_x_and_obs_only_from_raw, ) @@ -317,6 +326,7 @@ def test_store_extension( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=60, + shuffle_slice_size=10, shuffle=True, ) # add h5ads to existing store @@ -327,6 +337,8 @@ def test_store_extension( zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, + n_obs_per_dataset=50, + shuffle_slice_size=10, ) adatas_on_disk = [ad.io.read_elem(g) for g in collection] adata = ad.concat(adatas_on_disk) From e03c7083140c3c15f9ce6441958bc9cabf3a94c5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 21:27:36 +0100 
Subject: [PATCH 31/39] fix: key sorting --- src/annbatch/io.py | 7 ++++--- tests/conftest.py | 2 +- tests/test_preshuffle.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index fc92d3bb..efe4fee6 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -215,8 +215,6 @@ def _create_chunks_for_shuffling( np.concatenate([np.arange(s.start, s.stop) for s in idx]) for idx in split_given_size(idxs, n_slices_per_dataset) ] - if sum(len(idx) for idx in chunks) != n_obs or (np.sort(np.concatenate(chunks)) != np.arange(n_obs)).any(): - raise RuntimeError(f"This should not happen, please open an issue, {np.sort(np.concatenate(chunks))}") return chunks @@ -326,7 +324,10 @@ def __init__(self, group: zarr.Group | h5py.Group | str | Path, *, mode: Literal @property def _dataset_keys(self) -> list[str]: - return sorted([k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None]) + return sorted( + [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None], + key=lambda x: int(x.split("_")[1]), + ) def __iter__(self) -> Generator[T]: for k in self._dataset_keys: diff --git a/tests/conftest.py b/tests/conftest.py index e5538c71..24321bbd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,7 +79,7 @@ def adata_with_h5_path_different_var_space( adatas = [] for i, (m, n) in enumerate(zip(n_cells, n_features, strict=True)): var_idx = [f"gene_{gene}" for gene in range(n // 2)] + [f"gene_{gene}_{i}" for gene in range(n // 2, n)] - obs_idx = np.arange(m).astype(str) + obs_idx = np.arange(m).astype(str) + f"-{i}" adata = ad.AnnData( X=sparse_random(m, n, density=0.1, format="csr", dtype="f4"), obs=pd.DataFrame( diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 6129bd3c..ced65ae3 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -280,7 +280,7 @@ def test_mismatched_raw_concat( zarr_sparse_shard_size=20, 
zarr_dense_chunk_size=10, zarr_dense_shard_size=20, - n_obs_per_dataset=50, + n_obs_per_dataset=30, shuffle_slice_size=10, shuffle=False, # don't shuffle -> want to check if the right attributes get taken load_adata=_read_lazy_x_and_obs_only_from_raw, From 19d165b28e58015b7155d22a107b8ce483ceba6a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 21:31:31 +0100 Subject: [PATCH 32/39] fix: remove parameters in default test --- tests/test_preshuffle.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index ced65ae3..853ee5d6 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -152,7 +152,6 @@ def test_store_creation_default( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], open_store: Callable[[Path], zarr.Group | h5py.Group], ): - var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = ( adata_with_h5_path_different_var_space[1].parent @@ -161,13 +160,6 @@ def test_store_creation_default( store = open_store(output_path, mode="w") collection = DatasetCollection(store).add_adatas( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - var_subset=var_subset, - zarr_sparse_chunk_size=10, - zarr_sparse_shard_size=20, - zarr_dense_chunk_size=10, - zarr_dense_shard_size=20, - n_obs_per_dataset=50, - shuffle_slice_size=10, ) assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) # Test directory structure to make sure nothing extraneous was written From e753c9ec77db373fcecc551656c9408ba845df49 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 21:44:52 +0100 Subject: [PATCH 33/39] fix: `<=` for checking --- docs/notebooks/example.ipynb | 51 ++++++++++++++++++------------------ src/annbatch/io.py | 3 ++- tests/test_preshuffle.py | 1 + 3 files changed, 28 insertions(+), 27 deletions(-) diff --git 
a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 72a7da6b..b20ccd1b 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "tags": [ "hide-output" @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 3, "metadata": { "tags": [ "hide-output" @@ -75,7 +75,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 1, @@ -146,19 +146,17 @@ "text": [ "/Users/ilangold/Projects/Theis/annbatch/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.20it/s]\n", - "loading: 2it [00:00, 2.24it/s]\n", - "processing chunks: 0%| | 0/2 [00:00" + "" ] }, "execution_count": 3, @@ -167,6 +165,9 @@ } ], "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", "import anndata as ad\n", "from annbatch import DatasetCollection\n", "\n", @@ -199,7 +200,7 @@ " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", " # Path to store the output collection\n", " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard, this number is much higher than available in these datasets but is generally a good target\n", " var_subset=None, # Optionally subset the collection to a specific gene space\n", " load_adata=read_lazy_x_and_obs_only,\n", ")" @@ -214,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": { "tags": [ "hide-output" @@ -224,10 +225,10 @@ { "data": { "text/plain": [ - "" + 
"" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -264,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "tags": [ "hide-output" @@ -275,7 +276,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 42/171792 [00:08<9:35:18, 4.98it/s] \n" + " 0%| | 42/171792 [00:07<8:54:32, 5.36it/s] \n" ] } ], @@ -310,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": { "tags": [ "hide-output" @@ -321,27 +322,25 @@ "name": "stderr", "output_type": "stream", "text": [ - "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]\n", - "loading: 1it [00:00, 1.89it/s]\n", - "checking for mismatched keys: 100%|██████████| 3/3 [00:00<00:00, 16.47it/s]\n", - "processing chunks: 0%| | 0/2 [00:00" + "" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index efe4fee6..c4af5d79 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -206,7 +206,7 @@ def _create_chunks_for_shuffling( random.shuffle(idxs) n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_slice_size) # In this case `shuffle_n_obs_per_dataset` is bigger than the size of the dataset or the slice size is probably too big. - if n_obs < shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1: + if n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1: chunks = [np.concatenate(idxs)] else: # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. 
@@ -523,6 +523,7 @@ def _create_collection( var_subset = adata_concat.var_names for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): + print(chunk) var_mask = adata_concat.var_names.isin(var_subset) # np.sort: It's more efficient to access elements sequentially from dask arrays # The data will be shuffled later on, we just want the elements at this point diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 853ee5d6..f78227e8 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -161,6 +161,7 @@ def test_store_creation_default( collection = DatasetCollection(store).add_adatas( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], ) + assert len(list(iter(collection))) == 1 # default n_obs_per_dataset is much more than total obs assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) # Test directory structure to make sure nothing extraneous was written if isinstance(store, zarr.Group): From e29e2b360387b0f019ea319d38a3ff6953cd800f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 21:54:27 +0100 Subject: [PATCH 34/39] fix: remove printing --- src/annbatch/io.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index c4af5d79..b14f2c49 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -523,7 +523,6 @@ def _create_collection( var_subset = adata_concat.var_names for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): - print(chunk) var_mask = adata_concat.var_names.isin(var_subset) # np.sort: It's more efficient to access elements sequentially from dask arrays # The data will be shuffled later on, we just want the elements at this point From 1759c9af5b952f184848baa8a13f7dc90c991e94 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 22:37:45 +0100 Subject: [PATCH 35/39] fix: ensure datasets are properly exhausted when adding --- src/annbatch/io.py | 52 
++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index b14f2c49..ab427260 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -198,24 +198,37 @@ def _lazy_load_anndatas( def _create_chunks_for_shuffling( - n_obs: int, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle_slice_size: int = 1000, shuffle: bool = True + n_obs: int, + shuffle_slice_size: int = 1000, + shuffle: bool = True, + *, + shuffle_n_obs_per_dataset: int | None = None, + n_chunkings: int | None = None, ) -> list[np.ndarray]: # this splits the array up into `shuffle_slice_size` contiguous runs idxs = split_given_size(np.arange(n_obs), shuffle_slice_size) if shuffle: random.shuffle(idxs) - n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_slice_size) + match shuffle_n_obs_per_dataset is not None, n_chunkings is not None: + case True, False: + n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_slice_size) + use_single_chunking = n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1 + case False, True: + n_slices_per_dataset = (n_obs // n_chunkings) // shuffle_slice_size + use_single_chunking = n_chunkings == 1 + case _, _: + raise ValueError("Cannot provide both shuffle_n_obs_per_dataset and n_chunkings or neither") # In this case `shuffle_n_obs_per_dataset` is bigger than the size of the dataset or the slice size is probably too big. - if n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1: - chunks = [np.concatenate(idxs)] - else: - # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. 
- idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) - chunks = [ - np.concatenate([np.arange(s.start, s.stop) for s in idx]) - for idx in split_given_size(idxs, n_slices_per_dataset) - ] - return chunks + if use_single_chunking: + return [np.concatenate(idxs)] + # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. + idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) + return [ + np.concatenate([np.arange(s.start, s.stop) for s in idx]) + for idx in ( + split_given_size(idxs, n_slices_per_dataset) if n_chunkings is None else np.array_split(idxs, n_chunkings) + ) + ] def _compute_blockwise(x: DaskArray) -> sp.spmatrix: @@ -516,12 +529,11 @@ def _create_collection( adata_concat.obs_names_make_unique() n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset) chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], n_obs_per_dataset, shuffle_slice_size, shuffle=shuffle + adata_concat.shape[0], shuffle_slice_size, shuffle=shuffle, shuffle_n_obs_per_dataset=n_obs_per_dataset ) if var_subset is None: var_subset = adata_concat.var_names - for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): var_mask = adata_concat.var_names.isin(var_subset) # np.sort: It's more efficient to access elements sequentially from dask arrays @@ -606,16 +618,12 @@ def _add_to_collection( # Check for mismatched keys between datasets and the inputs. 
_check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], - np.ceil(len(adata_concat) / len(self._dataset_keys)), - shuffle_slice_size, - shuffle=shuffle, + adata_concat.shape[0], shuffle_slice_size, shuffle=shuffle, n_chunkings=len(self._dataset_keys) ) adata_concat.obs_names_make_unique() - for dataset, chunk in tqdm( - zip(self._dataset_keys, chunks, strict=False), total=len(self._dataset_keys), desc="processing chunks" + zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="processing chunks" ): adata_dataset = ad.io.read_elem(self._group[dataset]) subset_adata = _to_categorical_obs( @@ -623,9 +631,9 @@ def _add_to_collection( ) adata = ad.concat([adata_dataset, subset_adata], join="outer") if shuffle: - idxs = np.random.default_rng().permutation(len(adata)) + idxs = np.random.default_rng().permutation(adata.shape[0]) else: - idxs = np.arange(len(adata)) + idxs = np.arange(adata.shape[0]) adata = _persist_adata_in_memory(adata[idxs, :]) if isinstance(self._group, zarr.Group): write_sharded( From 29d71e9ce2a8d3b962c12316a22ca92e5505f429 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Jan 2026 22:56:23 +0100 Subject: [PATCH 36/39] fix: add error --- src/annbatch/io.py | 6 ++++++ tests/test_preshuffle.py | 3 +-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index ab427260..0c72585d 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import random import re import warnings @@ -615,6 +616,11 @@ def _add_to_collection( _check_for_mismatched_keys(adata_paths, load_adata=load_adata) adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + if math.ceil(adata_concat.shape[0] / shuffle_slice_size) < len(self._dataset_keys): + raise ValueError( + f"Use a shuffle size small enough to distribute the input 
data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores." + "Open an issue if the incoming anndata is so small it cannot be distributed across the on-disk data" + ) # Check for mismatched keys between datasets and the inputs. _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) chunks = _create_chunks_for_shuffling( diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index f78227e8..f3527961 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -142,8 +142,7 @@ def test_store_addition_different_keys( zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, - n_obs_per_dataset=50, - shuffle_slice_size=10, + shuffle_slice_size=2, ) From 5641610616e98243b00e69a64b7a8aa572bd4a00 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 16 Jan 2026 00:12:59 +0100 Subject: [PATCH 37/39] fix: use pandas index for categoricals --- src/annbatch/io.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 0c72585d..d897d9ed 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -160,7 +160,7 @@ def _lazy_load_anndatas( ), ): adatas = [] - categoricals_in_all_adatas = {} + categoricals_in_all_adatas: dict[str, pd.Index] = {} for i, path in tqdm(enumerate(paths), desc="loading"): adata = load_adata(path) # Track the source file for this given anndata object @@ -170,19 +170,17 @@ def _lazy_load_anndatas( # Concatenating Dataset2D drops categoricals so we need to track them if isinstance(adata.obs, Dataset2D): categorical_cols_in_this_adata = { - col: set(adata.obs[col].dtype.categories) - for col in adata.obs.columns - if adata.obs[col].dtype == "category" + col: adata.obs[col].dtype.categories for col in adata.obs.columns if adata.obs[col].dtype == "category" } if not categoricals_in_all_adatas: categoricals_in_all_adatas = { **categorical_cols_in_this_adata, - "src_path": 
set(adata.obs["src_path"].dtype.categories), + "src_path": adata.obs["src_path"].dtype.categories, } else: for k in categoricals_in_all_adatas.keys() & categorical_cols_in_this_adata.keys(): - categoricals_in_all_adatas[k] = set(categoricals_in_all_adatas[k]).union( - set(categorical_cols_in_this_adata[k]) + categoricals_in_all_adatas[k] = categoricals_in_all_adatas[k].union( + categorical_cols_in_this_adata[k] ) # TODO: Probably bug in anndata, need the true index for proper outer joins (can't skirt this with fake indexes, at least not in the mixed-type regime). if isinstance(adata.var, Dataset2D): @@ -640,7 +638,7 @@ def _add_to_collection( idxs = np.random.default_rng().permutation(adata.shape[0]) else: idxs = np.arange(adata.shape[0]) - adata = _persist_adata_in_memory(adata[idxs, :]) + adata = _persist_adata_in_memory(adata[idxs, :].copy()) if isinstance(self._group, zarr.Group): write_sharded( self._group, From 535a2182dd9bad00eb209db7264f91ae8953cc08 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 16 Jan 2026 12:00:22 +0100 Subject: [PATCH 38/39] fix: doc fixes +`slice_size` -> `chunk_size` --- src/annbatch/io.py | 42 ++++++++++++++++++++-------------------- tests/conftest.py | 2 +- tests/test_preshuffle.py | 18 ++++++++--------- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/annbatch/io.py b/src/annbatch/io.py index d897d9ed..dd23f2f5 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -198,22 +198,22 @@ def _lazy_load_anndatas( def _create_chunks_for_shuffling( n_obs: int, - shuffle_slice_size: int = 1000, + shuffle_chunk_size: int = 1000, shuffle: bool = True, *, shuffle_n_obs_per_dataset: int | None = None, n_chunkings: int | None = None, ) -> list[np.ndarray]: - # this splits the array up into `shuffle_slice_size` contiguous runs - idxs = split_given_size(np.arange(n_obs), shuffle_slice_size) + # this splits the array up into `shuffle_chunk_size` contiguous runs + idxs = split_given_size(np.arange(n_obs), 
shuffle_chunk_size) if shuffle: random.shuffle(idxs) match shuffle_n_obs_per_dataset is not None, n_chunkings is not None: case True, False: - n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_slice_size) + n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_chunk_size) use_single_chunking = n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1 case False, True: - n_slices_per_dataset = (n_obs // n_chunkings) // shuffle_slice_size + n_slices_per_dataset = (n_obs // n_chunkings) // shuffle_chunk_size use_single_chunking = n_chunkings == 1 case _, _: raise ValueError("Cannot provide both shuffle_n_obs_per_dataset and n_chunkings or neither") @@ -366,10 +366,10 @@ def add_adatas( zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", n_obs_per_dataset: int = 2_097_152, - shuffle_slice_size: int = 1000, + shuffle_chunk_size: int = 1000, shuffle: bool = True, ) -> Self: - """Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per store if running for the first time). + """Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per dataset if running for the first time). The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. @@ -377,7 +377,7 @@ def add_adatas( The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. A key `src_path` is added to `obs` to indicate where individual row came from. 
We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. - Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_slice_size` as so many rows will be read into memory before writing to disk. + Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_chunk_size` as so many rows will be read into memory before writing to disk. After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. This is not a stable API but only for internal purposes at the moment. @@ -412,9 +412,9 @@ def add_adatas( shuffle Whether to shuffle the data before writing it to the store. Ignored once the store is non-empty. - shuffle_slice_size + shuffle_chunk_size How many contiguous rows to load into memory before shuffling at once. - `(shuffle_slice_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_slice_size`. + `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. Examples -------- @@ -438,7 +438,7 @@ def add_adatas( ... load_adata=read_lazy_x_and_obs_only, ...) 
""" - if shuffle_slice_size > n_obs_per_dataset: + if shuffle_chunk_size > n_obs_per_dataset: raise ValueError("Cannot have a large slice size than observations per dataset") shared_kwargs = { "adata_paths": adata_paths, @@ -449,7 +449,7 @@ def add_adatas( "zarr_dense_shard_size": zarr_dense_shard_size, "zarr_compressor": zarr_compressor, "h5ad_compressor": h5ad_compressor, - "shuffle_slice_size": shuffle_slice_size, + "shuffle_chunk_size": shuffle_chunk_size, "shuffle": shuffle, } if self.is_empty: @@ -473,10 +473,10 @@ def _create_collection( zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", n_obs_per_dataset: int = 2_097_152, - shuffle_slice_size: int = 1000, + shuffle_chunk_size: int = 1000, shuffle: bool = True, ) -> None: - """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. + """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per dataset. The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. @@ -517,9 +517,9 @@ def _create_collection( Only applicable when adding datasets for the first time, otherwise ignored. shuffle Whether to shuffle the data before writing it to the store. - shuffle_slice_size + shuffle_chunk_size How many contiguous rows to load into memory before shuffling at once. - `(shuffle_slice_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_slice_size`. + `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. 
""" if not self.is_empty: raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") @@ -528,7 +528,7 @@ def _create_collection( adata_concat.obs_names_make_unique() n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset) chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], shuffle_slice_size, shuffle=shuffle, shuffle_n_obs_per_dataset=n_obs_per_dataset + adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, shuffle_n_obs_per_dataset=n_obs_per_dataset ) if var_subset is None: @@ -574,7 +574,7 @@ def _add_to_collection( zarr_dense_shard_size: int = 4_194_304, zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - shuffle_slice_size: int = 1000, + shuffle_chunk_size: int = 1000, shuffle: bool = True, ) -> None: """Add anndata files to an existing collection of sharded anndata zarr datasets. @@ -603,7 +603,7 @@ def _add_to_collection( should_sparsify_output_in_memory This option is for testing only appending sparse files to dense stores. To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. - shuffle_slice_size + shuffle_chunk_size How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. shuffle Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. 
@@ -614,7 +614,7 @@ def _add_to_collection( _check_for_mismatched_keys(adata_paths, load_adata=load_adata) adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - if math.ceil(adata_concat.shape[0] / shuffle_slice_size) < len(self._dataset_keys): + if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys): raise ValueError( f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores." "Open an issue if the incoming anndata is so small it cannot be distributed across the on-disk data" @@ -622,7 +622,7 @@ def _add_to_collection( # Check for mismatched keys between datasets and the inputs. _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], shuffle_slice_size, shuffle=shuffle, n_chunkings=len(self._dataset_keys) + adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, n_chunkings=len(self._dataset_keys) ) adata_concat.obs_names_make_unique() diff --git a/tests/conftest.py b/tests/conftest.py index 24321bbd..c34d9a73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,6 +119,6 @@ def simple_collection( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=60, - shuffle_slice_size=10, + shuffle_chunk_size=10, ) return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index f3527961..2e28a6f8 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -57,7 +57,7 @@ def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, - shuffle_slice_size=10, + shuffle_chunk_size=10, ) @@ -75,7 +75,7 @@ def test_store_creation_no_warnings_with_custom_load(tmp_path: Path): zarr_dense_chunk_size=5, zarr_dense_shard_size=10, 
n_obs_per_dataset=10, - shuffle_slice_size=5, + shuffle_chunk_size=5, load_adata=lambda x: ad.AnnData(X=ad.io.read_elem(h5py.File(x)["X"])), ) assert len(ad.read_zarr(next(iter(collection))).layers.keys()) == 0 @@ -97,7 +97,7 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, - shuffle_slice_size=5, + shuffle_chunk_size=5, shuffle=False, ) adata_result = ad.concat([ad.io.read_elem(g) for g in collection], join="outer") @@ -126,7 +126,7 @@ def test_store_addition_different_keys( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=50, - shuffle_slice_size=10, + shuffle_chunk_size=10, ) extra_args = { elem_name: {"arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else ["a"] * 10} @@ -142,7 +142,7 @@ def test_store_addition_different_keys( zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, - shuffle_slice_size=2, + shuffle_chunk_size=2, ) @@ -187,7 +187,7 @@ def test_store_creation( zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=50, - shuffle_slice_size=10, + shuffle_chunk_size=10, shuffle=shuffle, ) assert not DatasetCollection(output_path).is_empty @@ -273,7 +273,7 @@ def test_mismatched_raw_concat( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=30, - shuffle_slice_size=10, + shuffle_chunk_size=10, shuffle=False, # don't shuffle -> want to check if the right attributes get taken load_adata=_read_lazy_x_and_obs_only_from_raw, ) @@ -318,7 +318,7 @@ def test_store_extension( zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=60, - shuffle_slice_size=10, + shuffle_chunk_size=10, shuffle=True, ) # add h5ads to existing store @@ -330,7 +330,7 @@ def test_store_extension( zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=50, - shuffle_slice_size=10, + shuffle_chunk_size=10, ) adatas_on_disk = [ad.io.read_elem(g) for g in collection] 
adata = ad.concat(adatas_on_disk) From b8b96019112dd51fffc077257f1d0087ee888e4f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 16 Jan 2026 12:09:09 +0100 Subject: [PATCH 39/39] chore: stronger shuffle bounds --- tests/test_preshuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 2e28a6f8..03faddb4 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -214,8 +214,8 @@ def test_store_creation( ) assert "arr" in adata.obsm if shuffle: - # If it's shuffled I'd expect more than 80% of elements to be out of order - assert sum(adata_orig.obs_names != adata.obs_names) > (0.8 * adata.shape[0]) + # If it's shuffled I'd expect more than 90% of elements to be out of order + assert sum(adata_orig.obs_names != adata.obs_names) > (0.9 * adata.shape[0]) assert adata_orig.obs_names.isin(adata.obs_names).all() adata = adata[adata_orig.obs_names].copy() else: