From ba2cbe93fcc62909c2e8ae7cfad5cc4ebf4b24b6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 18 Jan 2026 19:37:26 +0100 Subject: [PATCH 01/56] resolve conflicts with main --- pyproject.toml | 4 + src/annbatch/__init__.py | 10 +- src/annbatch/loader.py | 300 +++++++++++++++---------------- src/annbatch/sampler/__init__.py | 11 ++ src/annbatch/sampler/_sampler.py | 242 +++++++++++++++++++++++++ src/annbatch/types.py | 17 ++ src/annbatch/utils.py | 44 +++-- tests/test_dataset.py | 157 ++++++++++++---- tests/test_sampler.py | 238 ++++++++++++++++++++++++ 9 files changed, 820 insertions(+), 203 deletions(-) create mode 100644 src/annbatch/sampler/__init__.py create mode 100644 src/annbatch/sampler/_sampler.py create mode 100644 tests/test_sampler.py diff --git a/pyproject.toml b/pyproject.toml index c78e1606..81f95cb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,6 +178,10 @@ omit = [ "**/test_*.py", ] +[[tool.mypy.overrides]] +module = [ "anndata.*", "cupyx.*", "cupy.*" ] +ignore_missing_imports = true + [tool.cruft] skip = [ "tests", diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index d53341c0..9b4f89c2 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -2,10 +2,16 @@ from importlib.metadata import version -from . import types +from . 
import sampler, types from .io import DatasetCollection, write_sharded from .loader import Loader __version__ = version("annbatch") -__all__ = ["Loader", "write_sharded", "DatasetCollection", "types"] +__all__ = [ + "Loader", + "DatasetCollection", + "types", + "sampler", + "write_sharded", +] diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index d3014421..77a903b8 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import math from collections import OrderedDict, defaultdict from functools import singledispatchmethod from importlib.util import find_spec @@ -16,18 +15,9 @@ from scipy import sparse as sp from zarr import Array as ZarrArray +from annbatch.sampler import ChunkSampler, Sampler from annbatch.types import BackingArray_T, InputInMemoryArray_T, LoaderOutput, OutputInMemoryArray_T -from annbatch.utils import ( - CSRContainer, - MultiBasicIndexer, - WorkerHandle, - _batched, - check_lt_1, - check_var_shapes, - load_x_and_obs, - split_given_size, - to_torch, -) +from annbatch.utils import CSRContainer, MultiBasicIndexer, check_lt_1, check_var_shapes, to_torch, validate_sampler from .compat import IterableDataset @@ -68,7 +58,9 @@ class Loader[ If `preload_to_gpu` to True and `to_torch` is False, the yielded type is a `cupy` matrix. If `to_torch` is True, the yielded type is a :class:`torch.Tensor`. - If both `preload_to_gpu` and `to_torch` are False, then the return type is the CPU class for the fiven data type. + If both `preload_to_gpu` and `to_torch` are False, then the return type is the CPU class for the given data type. + When providing a custom sampler, `chunk_size`, `preload_nchunks`, `batch_size`, + `shuffle`, and `drop_last` must not be set (they are controlled by the sampler). 
Parameters ---------- @@ -111,56 +103,65 @@ class Loader[ do_fit(batch) """ + _COMMON_SAMPLER_ARGS = { + "chunk_size": 512, + "preload_nchunks": 32, + "batch_size": 1, + "shuffle": False, + "drop_last": False, + } + _train_datasets: list[BackingArray] _obs: list[pd.DataFrame] | None = None _return_index: bool = False - _batch_size: int = 1 _shapes: list[tuple[int, int]] _preload_to_gpu: bool = True - _drop_last: bool = False _to_torch: bool = True - _shuffle: bool - _preload_nchunks: int - _worker_handle: WorkerHandle - _chunk_size: int _dataset_elem_cache: dict[int, CSRDatasetElems] + _batch_sampler: Sampler[list[slice]] def __init__( self, *, - chunk_size: int = 512, - preload_nchunks: int = 32, - shuffle: bool = True, + batch_sampler: Sampler[list[slice]] | None = None, + chunk_size: int | None = None, + preload_nchunks: int | None = None, + shuffle: bool | None = None, return_index: bool = False, - batch_size: int = 1, + batch_size: int | None = None, preload_to_gpu: bool = find_spec("cupy") is not None, - drop_last: bool = False, + drop_last: bool | None = None, to_torch: bool = find_spec("torch") is not None, ): - check_lt_1( - [ - chunk_size, - preload_nchunks, - ], - ["Chunk size", "Preload chunks"], - ) - if batch_size > (chunk_size * preload_nchunks): - raise NotImplementedError( - "Cannot yield batches bigger than the iterated in-memory size i.e., batch_size > (chunk_size * preload_nchunks)." - ) + sampler_args = { + "chunk_size": chunk_size, + "preload_nchunks": preload_nchunks, + "batch_size": batch_size, + "shuffle": shuffle, + "drop_last": drop_last, + } + if batch_sampler is not None: + if any(v is not None for v in sampler_args.values()): + provided_args = [name for name, val in sampler_args.items() if val is not None] + raise ValueError( + f"Cannot specify {', '.join(provided_args)} when providing a custom sampler. " + "These parameters are controlled by the sampler." 
+ ) + self._batch_sampler = batch_sampler + else: + sampler_args_processed = { + k: (v if v is not None else Loader._COMMON_SAMPLER_ARGS[k]) for k, v in sampler_args.items() + } + self._batch_sampler = ChunkSampler(**sampler_args_processed) + if to_torch and not find_spec("torch"): - raise ImportError("Could not find torch dependeny. Try `pip install torch`.") + raise ImportError("Could not find torch dependency. Try `pip install torch`.") if preload_to_gpu and not find_spec("cupy"): raise ImportError("Follow the directions at https://docs.cupy.dev/en/stable/install.html to install cupy.") + self._return_index = return_index - self._batch_size = batch_size self._preload_to_gpu = preload_to_gpu self._to_torch = to_torch - self._drop_last = drop_last - self._chunk_size = chunk_size - self._preload_nchunks = preload_nchunks - self._shuffle = shuffle - self._worker_handle = WorkerHandle() self._train_datasets = [] self._shapes = [] self._dataset_elem_cache = {} @@ -223,6 +224,8 @@ def n_var(self) -> int: ------- The number of variables. 
""" + if len(self._shapes) == 0: + raise ValueError("No datasets added yet") return self._shapes[0][1] def use_collection( @@ -251,6 +254,7 @@ def use_collection( self._collection_added = True return self + @validate_sampler(lambda self, adatas: sum(adata.n_obs for adata in adatas)) def add_anndatas( self, adatas: list[ad.AnnData], @@ -264,7 +268,11 @@ def add_anndatas( """ check_lt_1([len(adatas)], ["Number of anndatas"]) for adata in adatas: - self.add_anndata(adata) + dataset = adata.X + obs = adata.obs + if not isinstance(dataset, BackingArray_T.__value__): + raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") + self._add_dataset_unchecked(cast("BackingArray", dataset), obs) return self def add_anndata(self, adata: ad.AnnData) -> Self: @@ -282,6 +290,7 @@ def add_anndata(self, adata: ad.AnnData) -> Self: self.add_dataset(cast("BackingArray", dataset), obs) return self + @validate_sampler(lambda self, datasets, obs=None: sum(ds.shape[0] for ds in datasets)) def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | None = None) -> Self: """Append datasets to this dataset. @@ -296,9 +305,10 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N if obs is None: obs = [None] * len(datasets) for ds, o in zip(datasets, obs, strict=True): - self.add_dataset(ds, o) + self._add_dataset_unchecked(ds, o) return self + @validate_sampler(lambda self, dataset, obs=None: dataset.shape[0]) def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> Self: """Append a dataset to this dataset. @@ -309,6 +319,10 @@ def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> obs :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. 
""" + self._add_dataset_unchecked(dataset, obs) + return self + + def _add_dataset_unchecked(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> Self: if len(self._train_datasets) > 0: if self._obs is None and obs is not None: raise ValueError( @@ -334,9 +348,9 @@ def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> check_var_shapes(datasets) self._shapes = self._shapes + [dataset.shape] self._train_datasets = datasets - if self._obs is not None: # labels exist + if self._obs is not None: self._obs += [obs] - elif obs is not None: # labels dont exist yet, but are being added for the first time + elif obs is not None: self._obs = [obs] return self @@ -386,6 +400,8 @@ def _slices_to_slices_with_array_index( ) -> OrderedDict[int, list[slice]]: """Given a list of slices, give the lookup between on-disk datasets and slices relative to that dataset. + In the codebase we use slice and chunk interchangeably. Not to be confused with the zarr chunking/sharding terminology. + Parameters ---------- slices @@ -398,8 +414,8 @@ def _slices_to_slices_with_array_index( A lookup between the dataset and its indexing slices, ordered by keys. 
""" dataset_index_to_slices: defaultdict[int, list[slice]] = defaultdict(list) - for slice in slices: - for relative_obs_indices in self._get_relative_obs_indices(slice, use_original_space=use_original_space): + for slice_ in slices: + for relative_obs_indices in self._get_relative_obs_indices(slice_, use_original_space=use_original_space): dataset_index_to_slices[relative_obs_indices[1]] += [relative_obs_indices[0]] keys = sorted(dataset_index_to_slices.keys()) dataset_index_to_slices_sorted = OrderedDict() @@ -407,19 +423,6 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_chunks(self, chunk_size: int) -> np.ndarray: - """Get a potentially shuffled list of chunk ids, accounting for the fact that this dataset might be inside a worker. - - Returns - ------- - A :class:`numpy.ndarray` of chunk ids. - """ - chunks = np.arange(math.ceil(self.n_obs / chunk_size)) - if self._shuffle: - self._worker_handle.shuffle(chunks) - - return self._worker_handle.get_part_for_worker(chunks) - @singledispatchmethod async def _fetch_data(self, dataset: ZarrArray | CSRDatasetElems, slices: list[slice]) -> InputInMemoryArray: """Fetch data from an on-disk store. @@ -590,107 +593,92 @@ def __iter__( [len(self._train_datasets), self.n_obs], ["Number of datasets", "Number of observations"], ) - # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0 - # we must keep track of the leftover data. 
+ in_memory_data = None concatenated_obs = None in_memory_indices = None + mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np - for chunk_indices in _batched(self._get_chunks(self._chunk_size), self._preload_nchunks): - slices = [ - slice( - index * self._chunk_size, - min(self.n_obs, (index + 1) * self._chunk_size), - ) - for index in chunk_indices - ] - dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) + + for load_request in self._batch_sampler.sample(self.n_obs): + chunks_to_load = load_request["chunks"] + splits = load_request["splits"] + # Sampler yields a list of slices that sum to batch_size + dataset_index_to_slices = self._slices_to_slices_with_array_index(chunks_to_load, use_original_space=False) # Fetch the data over slices chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) - if any(isinstance(c, CSRContainer) for c in chunks): - chunks_converted: list[OutputInMemoryArray] = [ + chunks_converted = self._accumulate_chunks(chunks) + # Accumulate labels and indices if possible + obs: None | list[pd.DataFrame] = self._maybe_accumulate_labels(dataset_index_to_slices) + indices: None | list[np.ndarray] = self._maybe_accumulate_indices(chunks_to_load) + + in_memory_data = mod.vstack(chunks_converted) + if self._obs is not None and obs is not None: + concatenated_obs = pd.concat(obs) + if self._return_index and indices is not None: + in_memory_indices = np.concatenate(indices) + + for split in splits: + yield self._prepare_output( + in_memory_data=in_memory_data, + concatenated_obs=concatenated_obs, + in_memory_indices=in_memory_indices, + split=split, + ) + + def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInMemoryArray_T]: + """Convert fetched chunks to output array format (CSR or ndarray).""" + result: list[OutputInMemoryArray_T] = [] + for chunk in chunks: + if isinstance(chunk, CSRContainer): + result.append( 
self._sp_module.csr_matrix( - tuple(self._np_module.asarray(e) for e in c.elems), - shape=c.shape, - dtype="float64" if self._preload_to_gpu else c.dtype, - ) - for c in chunks - ] - else: - chunks_converted = [self._np_module.asarray(c) for c in chunks] - # Accumulate labels - obs: None | list[pd.DataFrame] = None - if self._obs is not None: - obs = [] - for dataset_idx in dataset_index_to_slices.keys(): - obs += [ - self._obs[dataset_idx].iloc[ - np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[dataset_idx]]) - ] - ] - # Accumulate indices if necessary - indices: None | list[np.ndarray] = None - if self._return_index: - dataset_index_to_slices = self._slices_to_slices_with_array_index(slices, use_original_space=True) - dataset_indices = dataset_index_to_slices.keys() - indices = [ - np.concatenate( - [ - np.arange( - s.start, - s.stop, - ) - for s in dataset_index_to_slices[index] - ] + tuple(self._np_module.asarray(e) for e in chunk.elems), + shape=chunk.shape, + dtype="float64" if self._preload_to_gpu else chunk.dtype, ) - for index in dataset_indices - ] - # Do batch returns, handling leftover data as necessary - in_memory_data = ( - mod.vstack(chunks_converted) - if in_memory_data is None - else mod.vstack([in_memory_data, *chunks_converted]) - ) - if self._obs is not None: - concatenated_obs = pd.concat(obs) if concatenated_obs is None else pd.concat([concatenated_obs, *obs]) - if self._return_index: - in_memory_indices = ( - np.concatenate(indices) - if in_memory_indices is None - else np.concatenate([in_memory_indices, *indices]) ) - # Create random indices into in_memory_data and then index into it - # If there is "leftover" at the end (see the modulo op), - # save it for the next iteration. 
- batch_indices = np.arange(in_memory_data.shape[0]) - if self._shuffle: - np.random.default_rng().shuffle(batch_indices) - splits = split_given_size(batch_indices, self._batch_size) - for i, s in enumerate(splits): - if s.shape[0] == self._batch_size: - output: LoaderOutput = { - "data": to_torch(in_memory_data[s], self._preload_to_gpu) - if self._to_torch - else in_memory_data[s], - "labels": concatenated_obs.iloc[s] if self._obs is not None else None, - "index": in_memory_indices[s] if self._return_index else None, - } - yield output - if i == (len(splits) - 1): # end of iteration, leftover data needs be kept - if (s.shape[0] % self._batch_size) != 0: - in_memory_data = in_memory_data[s] - if concatenated_obs is not None: - concatenated_obs = concatenated_obs.iloc[s] - if in_memory_indices is not None: - in_memory_indices = in_memory_indices[s] - else: - in_memory_data = None - concatenated_obs = None - in_memory_indices = None - if in_memory_data is not None and not self._drop_last: # handle any leftover data - output: LoaderOutput = { - "data": to_torch(in_memory_data, self._preload_to_gpu) if self._to_torch else in_memory_data, - "labels": concatenated_obs if self._obs is not None else None, - "index": in_memory_indices if self._return_index else None, - } - yield output + else: + result.append(self._np_module.asarray(chunk)) + return result + + def _maybe_accumulate_labels( + self, dataset_index_to_slices: OrderedDict[int, list[slice]] + ) -> list[pd.DataFrame] | None: + """Gather obs labels for the loaded slices if possible.""" + if self._obs is None: + return None + return [ + self._obs[idx].iloc[np.concatenate([np.arange(s.start, s.stop) for s in slices])] + for idx, slices in dataset_index_to_slices.items() + ] + + def _maybe_accumulate_indices(self, slices: list[slice]) -> list[np.ndarray] | None: + """Gather original indices for the loaded slices if possible.""" + if self._return_index is False: + return None + dataset_index_to_slices = 
self._slices_to_slices_with_array_index(slices, use_original_space=True) + return [ + np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[idx]]) + for idx in dataset_index_to_slices + ] + + def _prepare_output( + self, + *, + in_memory_data: OutputInMemoryArray_T, + concatenated_obs: pd.DataFrame | None, + in_memory_indices: np.ndarray | None, + split: np.ndarray, + ) -> LoaderOutput: + """Prepare the final output dict for a single batch.""" + index = None + labels = None + if self._obs is not None and concatenated_obs is not None: + labels = concatenated_obs.iloc[split] + if self._return_index and in_memory_indices is not None: + index = in_memory_indices[split] + data = in_memory_data[split] + if self._to_torch: + data = to_torch(data, self._preload_to_gpu) + return {"data": data, "labels": labels, "index": index} diff --git a/src/annbatch/sampler/__init__.py b/src/annbatch/sampler/__init__.py new file mode 100644 index 00000000..9a07f9aa --- /dev/null +++ b/src/annbatch/sampler/__init__.py @@ -0,0 +1,11 @@ +"""Sampler classes for efficient chunk-based data access. + +This module provides samplers optimized for chunk-based data access patterns. 
+""" + +from annbatch.sampler._sampler import ChunkSampler, Sampler + +__all__ = [ + "ChunkSampler", + "Sampler", +] diff --git a/src/annbatch/sampler/_sampler.py b/src/annbatch/sampler/_sampler.py new file mode 100644 index 00000000..e0aef84b --- /dev/null +++ b/src/annbatch/sampler/_sampler.py @@ -0,0 +1,242 @@ +"""Sampler classes for efficient chunk-based data access.""" + +from __future__ import annotations + +import math +from abc import ABC, abstractmethod +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import numpy as np + +from annbatch.utils import check_lt_1, split_given_size + +if TYPE_CHECKING: + from collections.abc import Iterator + + from annbatch.types import LoadRequest + from annbatch.utils import WorkerHandle + + +class Sampler(ABC): + """Base sampler class. + + Samplers control how data is batched and loaded from the underlying datasets. + """ + + def sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Sample load requests given the total number of observations. + + Parameters + ---------- + n_obs + The total number of observations available. + + Yields + ------ + LoadRequest + Load requests for batching data. + """ + self.validate(n_obs) + yield from self._sample(n_obs) + + @abstractmethod + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's state. + + This method is called when the sampler is set on a loader. + Override this method to add custom validation for sampler parameters. + + Parameters + ---------- + n_obs + The total number of observations in the loader. + + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. + """ + + @abstractmethod + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Implementation of the sample method. + + This method is called by the sample method to perform the actual sampling after + validation has passed. 
+ + Parameters + ---------- + n_obs + The total number of observations available. + + Yields + ------ + LoadRequest + Load requests for batching data. + """ + + +class ChunkSampler(Sampler): + """Chunk-based sampler for batched data access. + + Parameters + ---------- + batch_size + Number of observations per batch. + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + mask + A slice defining the observation range to sample from (start:stop). + shuffle + Whether to shuffle chunk and index order. + preload_nchunks + Number of chunks to load per iteration. + drop_last + Whether to drop the last incomplete batch. + rng + Random number generator for shuffling. + """ + + _batch_size: int + _chunk_size: int + _shuffle: bool + _preload_nchunks: int + _mask: slice + _n_chunks: int + _n_iters: int + _drop_last: bool + _rng: np.random.Generator + + def __init__( + self, + *, + batch_size: int, + chunk_size: int, + mask: slice | None = None, + shuffle: bool = False, + preload_nchunks: int, + drop_last: bool = False, + rng: np.random.Generator | None = None, + ): + if mask is None: + mask = slice(0, None) + if mask.step is not None and mask.step != 1: + raise ValueError(f"mask.step must be 1, but got {mask.step}") + start, stop = mask.start or 0, mask.stop + if start < 0: + raise ValueError("mask.start must be >= 0") + if stop is not None and start >= stop: + raise ValueError("mask.start must be < mask.stop when mask.stop is specified") + + check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) + preload_size = chunk_size * preload_nchunks + + if batch_size > preload_size: + raise ValueError( + "batch_size cannot exceed chunk_size * preload_nchunks. " + f"Got batch_size={batch_size}, but max is {preload_size}." + ) + if preload_size % batch_size != 0: + raise ValueError( + "chunk_size * preload_nchunks must be divisible by batch_size. " + f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." 
+ ) + self._rng = rng or np.random.default_rng() + self._batch_size, self._chunk_size, self._shuffle = batch_size, chunk_size, shuffle + self._preload_nchunks, self._mask, self._drop_last = ( + preload_nchunks, + slice(start, stop), + drop_last, + ) # stop can be None + + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's n_obs. + + Parameters + ---------- + n_obs + The total number of observations in the loader. + + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. + """ + start, stop = self._mask.start or 0, self._mask.stop or n_obs + if stop > n_obs: + raise ValueError( + f"Sampler mask.stop ({stop}) exceeds loader n_obs ({n_obs}). " + "The sampler range must be within the loader's observations." + ) + if start >= stop: + raise ValueError(f"Sampler mask.start ({start}) must be < mask.stop ({stop}).") + + def _get_worker_handle(self) -> WorkerHandle | None: + worker_handle = None + if find_spec("torch"): + from torch.utils.data import get_worker_info + + from annbatch.utils import WorkerHandle + + if get_worker_info() is not None: + worker_handle = WorkerHandle() + # Worker mode validation - only check when there are multiple workers + # With batch_size=1, every batch is exactly 1 item, so no partial batches exist + if ( + worker_handle is not None + and worker_handle.num_workers > 1 + and not self._drop_last + and self._batch_size != 1 + ): + raise ValueError("When using DataLoader with multiple workers drop_last=False is not supported.") + return worker_handle + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + worker_handle = self._get_worker_handle() + start, stop = self._mask.start or 0, self._mask.stop or n_obs + # Compute chunks directly from resolved mask range + # Create chunk indices for possible shuffling and worker sharding + chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size)) + if self._shuffle: + self._rng.shuffle(chunk_indices) 
# TODO: maybe this should be done in a worker-aware way? + chunks = self._compute_chunks(chunk_indices, start, stop) + # Worker sharding: each worker gets a disjoint subset of chunks + if worker_handle is not None: + chunks = worker_handle.get_part_for_worker(chunks) + # Set up the iterator for chunks and the batch indices for splits + in_memory_size = self._chunk_size * self._preload_nchunks + chunks_per_batch = split_given_size(chunks, self._preload_nchunks) + batch_indices = np.arange(in_memory_size) # to avoid copies use in-place shuffling + split_batch_indices = split_given_size(batch_indices, self._batch_size) + for batch_chunks in chunks_per_batch[:-1]: + if self._shuffle: + self._rng.shuffle(batch_indices) + split_batch_indices = split_given_size(batch_indices, self._batch_size) + yield {"chunks": batch_chunks, "splits": split_batch_indices} + # On the last yield, drop the last uneven batch (only when drop_last is set) and create new batch_indices, since the in-memory size of this last yield may be smaller than preload_nchunks * chunk_size even when it is divisible by batch_size + final_chunks = chunks_per_batch[-1] + total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks)) + if self._drop_last: + total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size + batch_indices = split_given_size( + (self._rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch), + self._batch_size, + ) + batch_indices.sort(key=len, reverse=True) + yield {"chunks": final_chunks, "splits": batch_indices} + + def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]: + """Compute chunks from start and stop indices. + + This function is used to compute the chunks for the data to load. + The chunks are computed such that the last chunk is the incomplete chunk if the total number of observations is not divisible by the chunk size. + Supposed to also work with shuffled chunk indices so that the last chunk computed isn't always the incomplete chunk.
+ """ + n_chunks, pivot_index = len(chunk_indices), chunk_indices[-1] + offsets = np.ones(n_chunks + 1, dtype=int) * self._chunk_size + offsets[0] = start + offsets[pivot_index + 1] = incomplete if (incomplete := (stop - start) % self._chunk_size) else self._chunk_size + offsets = np.cumsum(offsets) + starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices] + return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)] diff --git a/src/annbatch/types.py b/src/annbatch/types.py index 39c6a364..9d9a597a 100644 --- a/src/annbatch/types.py +++ b/src/annbatch/types.py @@ -16,6 +16,23 @@ type OutputInMemoryArray_T = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray | Tensor +class LoadRequest(TypedDict): + """Load request from sampler. + + Attributes + ---------- + chunks + Chunks to load - a list of slices, each spanning at most chunk_size observations. + splits + How the concatenation of chunks should be split into batches. + A list of splits, last one may be partial (< batch_size). + The loader yields each split as its own batch; partial batches are not carried over to the next load request.
+ """ + + chunks: list[slice] + splits: list[np.ndarray] + + class LoaderOutput[OutputInMemoryArray: OutputInMemoryArray_T](TypedDict): """The output of the loader, the "data matrix" with its labels, optional, and index, also optional.""" diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 84d2bc47..e558b861 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -2,9 +2,8 @@ import warnings from dataclasses import dataclass -from functools import cached_property +from functools import cached_property, wraps from importlib.util import find_spec -from itertools import islice from typing import TYPE_CHECKING, Protocol import anndata as ad @@ -15,8 +14,6 @@ from .compat import CupyArray, CupyCSRMatrix, Tensor if TYPE_CHECKING: - from collections.abc import Generator, Iterable - from annbatch.types import OutputInMemoryArray_T @@ -34,14 +31,6 @@ class CSRContainer: dtype: np.dtype -def _batched[T](iterable: Iterable[T], n: int) -> Generator[list[T], None, None]: - if n < 1: - raise ValueError("n must be >= 1") - it = iter(iterable) - while batch := list(islice(it, n)): - yield batch - - # TODO: make this part of the public zarr or zarrs-python API. # We can do chunk coalescing in zarrs based on integer arrays, so I think # there would make sense with ezclump or similar. @@ -75,6 +64,13 @@ def _worker_info(self): return get_worker_info() return None + @property + def num_workers(self) -> int: + """Return the number of workers, or 1 if not in a worker context.""" + if self._worker_info is None: + return 1 + return self._worker_info.num_workers + @cached_property def _rng(self): if self._worker_info is None: @@ -198,3 +194,27 @@ def load_x_and_obs(g: zarr.Group) -> ad.AnnData: return ad.AnnData( X=g["X"] if isinstance(g["X"], zarr.Array) else ad.io.sparse_dataset(g["X"]), obs=ad.io.read_elem(g["obs"]) ) + + +def validate_sampler(get_additional_n_obs): + """Decorator that validates n_obs before modifying state. 
+ + Parameters + ---------- + get_additional_n_obs + A callable (self, *args, **kwargs) -> int that returns the number + of additional observations that will be added by the decorated method. + For example, in add_datasets this would be lambda self, datasets: sum(dataset.shape[0] for dataset in datasets) + """ + + def decorator(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + additional_obs = get_additional_n_obs(self, *args, **kwargs) + prospective_n_obs = self.n_obs + additional_obs + self._batch_sampler.validate(prospective_n_obs) + return method(self, *args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/test_dataset.py b/tests/test_dataset.py index e30a38cf..16f83209 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -13,6 +13,7 @@ import zarr from annbatch import Loader +from annbatch.sampler import ChunkSampler try: from cupy import ndarray as CupyArray @@ -25,8 +26,6 @@ from collections.abc import Callable from pathlib import Path - from annbatch.io import DatasetCollection - class Data(TypedDict): dataset: ad.abc.CSRDataset | zarr.Array @@ -84,7 +83,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: "gen_loader", [ pytest.param( - lambda collection, + lambda path, shuffle, use_zarrs, chunk_size=chunk_size, @@ -99,15 +98,10 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: batch_size=batch_size, preload_to_gpu=preload_to_gpu, to_torch=False, - ).use_collection( - collection, - **( - {"load_adata": lambda group: open_func(group, use_zarrs=use_zarrs, use_anndata=True)} - if open_func is not None - else {} - ), + ).add_anndatas( + [open_func(p, use_zarrs=use_zarrs, use_anndata=True) for p in path.glob("*.zarr")], ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] +
id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-dataset_type={open_func.__name__[5:]}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] marks=pytest.mark.skipif( find_spec("cupy") is None and preload_to_gpu, reason="need cupy installed", @@ -116,7 +110,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem for preload_to_gpu in [True, False] - for open_func in [open_sparse, open_dense, None] + for open_func in [open_sparse, open_dense] for elem in [ [ 1, @@ -146,19 +140,12 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: 50, preload_to_gpu, ], # batch size equal to in-memory size loading - [ - 10, - 5, - open_func, - 14, - preload_to_gpu, - ], # batch size does not divide in memory size evenly ] ] ], ) def test_store_load_dataset( - simple_collection: tuple[ad.AnnData, DatasetCollection], *, shuffle: bool, gen_loader, use_zarrs + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], *, shuffle: bool, gen_loader, use_zarrs ): """ This test verifies that the DaskDataset works correctly: @@ -167,8 +154,8 @@ def test_store_load_dataset( 3. All samples from the dataset are processed 4. 
If the dataset is not shuffled, it returns the correct data """ - loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs) - adata = simple_collection[0] + loader: Loader = gen_loader(adata_with_zarr_path_same_var_space[1], shuffle, use_zarrs) + adata = adata_with_zarr_path_same_var_space[0] is_dense = loader.dataset_type is zarr.Array n_elems = 0 batches = [] @@ -231,11 +218,19 @@ def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, ds.add_dataset(**data) -def test_use_collection_twice(simple_collection: tuple[ad.AnnData, DatasetCollection]): - ds = Loader() - ds = ds.use_collection(simple_collection[1]) - with pytest.raises(RuntimeError, match="You should not add multiple collections"): - ds.use_collection(simple_collection[1]) +def test_batch_size_does_not_divide_evenly_fails(): + """Test that it fails if batch_size does not divide evenly into chunk_size * preload_nchunks.""" + # chunk_size=10, preload_nchunks=5 -> in-memory size = 50 + # batch_size=14 does not divide evenly into 50 + with pytest.raises(ValueError, match="must be divisible by batch_size"): + Loader( + shuffle=False, + chunk_size=10, + preload_nchunks=5, + batch_size=14, + preload_to_gpu=False, + to_torch=False, + ) @pytest.mark.skipif(not find_spec("torch"), reason="need torch installed") @@ -265,7 +260,7 @@ def test_to_torch( shuffle=False, chunk_size=5, preload_nchunks=10, - batch_size=42, + batch_size=25, preload_to_gpu=preload_to_gpu, return_index=True, to_torch=True, @@ -276,14 +271,16 @@ def test_to_torch( @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], drop_last: bool): - # batch_size guaranteed to have leftovers to drop - batch_size = 42 + # batch_size guaranteed to have last batch to drop + chunk_size = 14 + preload_nchunks = 3 + batch_size = 21 zarr_path = next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) adata = 
ad.read_zarr(zarr_path) ds = Loader( shuffle=False, - chunk_size=5, - preload_nchunks=10, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, batch_size=batch_size, preload_to_gpu=False, return_index=True, @@ -298,6 +295,7 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], indices += [batch["index"]] total_obs = adata.shape[0] leftover = total_obs % batch_size + assert leftover != 0, f"batch_size {batch_size} must not divide evenly into {total_obs} observations" for batch in batches[:-1]: assert batch.shape[0] == batch_size assert batches[-1].shape[0] == (batch_size if drop_last else leftover) @@ -412,7 +410,7 @@ def test_default_data_structures( ): # format is a smoke test for sparse ds = Loader( - chunk_size=10, preload_nchunks=4, batch_size=22, shuffle=True, return_index=False, **kwargs + chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, return_index=False, **kwargs ).add_dataset( **(open_sparse if issubclass(expected_cls, get_default_sparse()) else open_dense)( list(adata_with_zarr_path_same_var_space[1].iterdir())[0] @@ -420,3 +418,96 @@ def test_default_data_structures( ) for batch in ds: assert isinstance(batch["data"], expected_cls) + + +def test_add_dataset_validation_failure_preserves_state(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): + """Test that failed validation in add_dataset doesn't modify internal state.""" + from annbatch.sampler import Sampler + + class FailOnSecondValidateSampler(Sampler): + """A sampler that fails validation after the first call.""" + + def __init__(self): + self._validate_count = 0 + + def validate(self, n_obs: int) -> None: + self._validate_count += 1 + if self._validate_count > 1: + raise ValueError("Validation failed on second add") + + @property + def batch_size(self) -> int: + return 10 + + @property + def worker_handle(self): + return None + + def _sample(self, n_obs: int, worker_handle=None): + yield from [] + + paths = 
list(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) + data1 = open_dense(paths[0]) + data2 = open_dense(paths[1]) + + sampler = FailOnSecondValidateSampler() + loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to_torch=False) + + # First add succeeds + loader.add_dataset(**data1) + + # Capture state before failed add + n_datasets_before = len(loader._train_datasets) + shapes_before = loader._shapes.copy() + + # Second add should fail validation BEFORE modifying state + with pytest.raises(ValueError, match="Validation failed on second add"): + loader.add_dataset(**data2) + + # State should be unchanged + assert len(loader._train_datasets) == n_datasets_before + assert loader._shapes == shapes_before + + +def test_given_batch_sampler_samples_subset_of_combined_datasets( + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], +): + """Test given batch sampler that samples only a specific range from combined datasets. + + Uses multiple zarr files from fixture, combines them, and samples a subset. 
+ """ + paths = list(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) + datas = [open_dense(p) for p in paths] + + # Calculate expected n_obs before creating loader + expected_n_obs = sum(d["dataset"].shape[0] for d in datas) + start_idx, end_idx = expected_n_obs // 4, expected_n_obs // 2 + + sampler = ChunkSampler( + mask=slice(start_idx, end_idx), + batch_size=10, + chunk_size=10, + preload_nchunks=2, + ) + + loader = Loader(batch_sampler=sampler, preload_to_gpu=False, to_torch=False, return_index=True) + loader.add_datasets(**concat(datas)) + + # Collect all yielded indices + all_indices = [] + for batch in loader: + all_indices.append(batch["index"]) + + stacked_indices = np.concatenate(all_indices) + + # Verify we got exactly the expected range + assert set(stacked_indices) == set(range(start_idx, end_idx)) + assert len(stacked_indices) == end_idx - start_idx + + +@pytest.mark.parametrize("kwarg", [{"chunk_size": 10}, {"batch_size": 10}]) +def test_cannot_provide_batch_sampler_with_sampler_args(kwarg): + """Test that providing batch_sampler with sampler args raises in constructor.""" + chunk_sampler = ChunkSampler(mask=slice(0, 50), batch_size=5, chunk_size=10, preload_nchunks=2) + with pytest.raises(ValueError, match="Cannot specify.*when providing a custom sampler"): + Loader(batch_sampler=chunk_sampler, preload_to_gpu=False, to_torch=False, **kwarg) diff --git a/tests/test_sampler.py b/tests/test_sampler.py new file mode 100644 index 00000000..f5867356 --- /dev/null +++ b/tests/test_sampler.py @@ -0,0 +1,238 @@ +"""Tests for ChunkSampler.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from annbatch.sampler import ChunkSampler + +# TODO: Check for the validation within the _get_worker_handle method. Mock worker handle wouldn't make sense +# but overall one must also think about how validation can't be independent of the worker handle. 
+ + +def collect_indices(sampler, n_obs): + """Helper to collect all indices from sampler.""" + indices = set() + for load_request in sampler.sample(n_obs): + assert len(load_request["splits"]) > 0, "splits must be non-empty" + assert all(len(s) > 0 for s in load_request["splits"]), "splits must be non-empty" + for s in load_request["chunks"]: + indices.update(range(s.start, s.stop)) + return indices + + +class MockWorkerHandle: + """Simulates torch worker context for testing without actual DataLoader.""" + + def __init__(self, worker_id: int, num_workers: int, seed: int = 42): + self.worker_id = worker_id + self._num_workers = num_workers + self._rng = np.random.default_rng(seed) + + @property + def num_workers(self) -> int: + return self._num_workers + + def shuffle(self, obj): + self._rng.shuffle(obj) + + def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: + return np.array_split(obj, self._num_workers)[self.worker_id] + + +class ChunkSamplerWithMockWorkerHandle(ChunkSampler): + def set_worker_handle(self, worker_handle: MockWorkerHandle): + self.worker_handle = worker_handle + + def _get_worker_handle(self) -> MockWorkerHandle | None: + return self.worker_handle + + +# ============================================================================= +# Mask coverage tests +# ============================================================================= + + +@pytest.mark.parametrize( + "n_obs,chunk_size,start,stop,batch_size,preload_nchunks,shuffle", + [ + # Basic full dataset + pytest.param(100, 10, None, None, 5, 2, False, id="full_dataset"), + # mask.start only + pytest.param(100, 10, 30, None, 5, 2, False, id="start_at_chunk_boundary"), + pytest.param(100, 10, 35, None, 5, 2, False, id="start_not_at_chunk_boundary"), + pytest.param(120, 12, 90, None, 3, 1, False, id="start_near_end"), + pytest.param(100, 10, 20, None, 5, 2, False, id="start_mask_stop_none"), + # mask.stop only + pytest.param(50, 10, None, 50, 5, 2, False, id="stop_at_chunk_boundary"), + 
pytest.param(47, 10, None, 47, 5, 2, False, id="stop_not_at_chunk_boundary"), + # Both bounds + pytest.param(60, 10, 20, 60, 5, 2, False, id="both_at_chunk_boundaries"), + pytest.param(67, 10, 23, 67, 5, 2, False, id="both_not_at_chunk_boundaries"), + pytest.param(28, 10, 22, 28, 2, 1, False, id="single_chunk_span"), + pytest.param(100, 10, 15, 85, 5, 2, False, id="both_non_aligned"), + pytest.param(100, 10, 20, 80, 5, 2, False, id="both_aligned"), + # Edge cases + pytest.param(100, 10, 95, 100, 10, 1, False, id="very_small_mask"), + # With shuffle + pytest.param(100, 10, 30, None, 5, 2, True, id="shuffle_with_start"), + pytest.param(75, 10, 25, 75, 5, 2, True, id="shuffle_with_both_bounds"), + ], +) +def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchunks, shuffle): + """Test sampler covers exactly the expected range.""" + sampler = ChunkSampler( + mask=slice(start, stop), + batch_size=batch_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + shuffle=shuffle, + rng=np.random.default_rng(42) if shuffle else None, + ) + + all_indices = collect_indices(sampler, n_obs) + + expected_start = start if start is not None else 0 + expected_stop = stop if stop is not None else n_obs + assert all_indices == set(range(expected_start, expected_stop)) + sampler.validate(n_obs) + + +def test_batch_sizes_match_expected_pattern(): + """Test that batch sizes match expected pattern.""" + n_obs, chunk_size, preload_nchunks, batch_size = 103, 10, 2, 5 + # last chunk is incomplete and is also the last batch in the load request + expected_last_chunk_size = 3 + expected_last_batch_size = 3 + expected_last_num_splits = 1 + expected_num_load_requests = 6 + sampler = ChunkSampler( + mask=slice(0, None), + batch_size=batch_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + ) + + all_requests = list(sampler.sample(n_obs)) + assert len(all_requests) == expected_num_load_requests + for req_idx, load_request in 
enumerate(all_requests[:-1]): + assert np.all(len(chunk) == chunk_size for chunk in load_request["chunks"]), ( + f"chunk size mismatch at request {req_idx}:", + f"chunks: {load_request['chunks']}", + ) + assert np.all(len(split) == batch_size for split in load_request["splits"]), ( + f"batch size mismatch at request {req_idx}:splits: {load_request['splits']}" + ) + last_request = all_requests[-1] + assert len(last_request["splits"]) == expected_last_num_splits, "last request num splits mismatch" + assert np.all(len(chunk) == expected_last_chunk_size for chunk in last_request["chunks"]), ( + "last request chunk size mismatch", + f"chunks: {last_request['chunks']}", + ) + assert np.all(len(split) == expected_last_batch_size for split in last_request["splits"]), ( + "last request batch size mismatch", + f"splits: {last_request['splits']}", + ) + + +# ============================================================================= +# Worker tests +# ============================================================================= + + +@pytest.mark.parametrize( + "n_obs,chunk_size,preload_nchunks,batch_size,num_workers,drop_last", + [ + pytest.param(200, 10, 2, 10, 2, True, id="two_workers"), + pytest.param(300, 10, 3, 10, 3, True, id="three_workers"), + # checks how it works with batch_size=1 since it is the default case and might be used in torch later + pytest.param(600, 10, 4, 1, 4, False, id="batch_size_one_torch_dataloader_case"), + pytest.param(100, 10, 4, 1, 1, False, id="batch_size_one_single_worker_case"), + pytest.param(95, 10, 4, 1, 1, False, id="batch_size_one_non_divisible_obs_case"), + pytest.param(100, 10, 4, 1, 3, False, id="batch_size_one_three_workers_uneven_case"), + ], +) +def test_workers_cover_full_dataset_without_overlap( + n_obs, chunk_size, preload_nchunks, batch_size, num_workers, drop_last +): + """Test workers cover full dataset without overlap. 
Also checks if there are empty splits in any of the load requests.""" + all_worker_indices = [] + for worker_id in range(num_workers): + worker_handle = MockWorkerHandle(worker_id, num_workers) + sampler = ChunkSamplerWithMockWorkerHandle( + mask=slice(0, None), + batch_size=batch_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + drop_last=drop_last, + ) + sampler.set_worker_handle(worker_handle) + all_worker_indices.append(collect_indices(sampler, n_obs)) + + # All workers should have disjoint chunks + for i in range(num_workers): + for j in range(i + 1, num_workers): + assert all_worker_indices[i].isdisjoint(all_worker_indices[j]) + + # Together they cover the full dataset + assert set().union(*all_worker_indices) == set(range(n_obs)) + + +# ============================================================================= +# Validation tests +# ============================================================================= + + +@pytest.mark.parametrize( + "mask,n_obs,error_match", + [ + pytest.param(slice(0, 100), 100, None, id="valid_config"), + pytest.param(slice(0, 200), 100, "mask.stop.*exceeds loader n_obs", id="stop_exceeds_n_obs"), + ], +) +def test_validate(mask, n_obs, error_match): + """Test validate behavior for various configurations.""" + sampler = ChunkSampler(mask=mask, batch_size=5, chunk_size=10, preload_nchunks=2) + if error_match: + with pytest.raises(ValueError, match=error_match): + sampler.validate(n_obs=n_obs) + else: + sampler.validate(n_obs=n_obs) + + +@pytest.mark.parametrize( + "mask,error_match", + [ + pytest.param(slice(-1, 100), "mask.start must be >= 0", id="negative_start"), + pytest.param(slice(50, 50), "mask.start must be < mask.stop", id="start_equals_stop"), + pytest.param(slice(100, 50), "mask.start must be < mask.stop", id="start_greater_than_stop"), + pytest.param(slice(0, 100, 2), "mask.step must be 1, but got 2", id="step_not_one"), + ], +) +def test_invalid_mask_raises(mask, error_match): + """Test that 
invalid mask configurations raise ValueError at construction.""" + with pytest.raises(ValueError, match=error_match): + ChunkSampler(mask=mask, batch_size=5, chunk_size=10, preload_nchunks=2) + + +# ============================================================================= +# n_obs change tests (To verify nothing is cached between calls.) +# ============================================================================= + + +@pytest.mark.parametrize( + "n_obs_values,expected_ranges", + [ + pytest.param([50, 100], [range(50), range(100)], id="increase_changes_result"), + pytest.param([100, 100], [range(100), range(100)], id="same_gives_same_coverage"), + ], +) +def test_n_obs_coverage(n_obs_values, expected_ranges): + """Test that n_obs changes affect sampling results appropriately.""" + sampler = ChunkSampler(mask=slice(0, None), batch_size=5, chunk_size=10, preload_nchunks=2, shuffle=False) + + results = [collect_indices(sampler, n) for n in n_obs_values] + + for result, expected in zip(results, expected_ranges, strict=True): + assert result == set(expected) From d951c5611f6127199b14a0e5bd556e3454737332 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 18 Jan 2026 19:49:22 +0100 Subject: [PATCH 02/56] load_obs thing was removed by auto formatting --- src/annbatch/loader.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 77a903b8..7a2af448 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -17,7 +17,15 @@ from annbatch.sampler import ChunkSampler, Sampler from annbatch.types import BackingArray_T, InputInMemoryArray_T, LoaderOutput, OutputInMemoryArray_T -from annbatch.utils import CSRContainer, MultiBasicIndexer, check_lt_1, check_var_shapes, to_torch, validate_sampler +from annbatch.utils import ( + CSRContainer, + MultiBasicIndexer, + check_lt_1, + check_var_shapes, + load_x_and_obs, + to_torch, + validate_sampler, +) from .compat import IterableDataset From 
aa348fe72a74e3e336dfdd3d386779fdad25cc8b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 18 Jan 2026 19:58:29 +0100 Subject: [PATCH 03/56] update tests to resolve conflict --- .readthedocs.yaml | 2 +- tests/test_dataset.py | 41 ++++++++++++++++++++--------------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 4acb793c..c3f3f96f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-24.04 tools: - python: "3.14" + python: "3.12" jobs: create_environment: - asdf plugin add uv diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 16f83209..b3bcd877 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -26,6 +26,8 @@ from collections.abc import Callable from pathlib import Path + from annbatch.io import DatasetCollection + class Data(TypedDict): dataset: ad.abc.CSRDataset | zarr.Array @@ -83,7 +85,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: "gen_loader", [ pytest.param( - lambda path, + lambda collection, shuffle, use_zarrs, chunk_size=chunk_size, @@ -98,10 +100,15 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: batch_size=batch_size, preload_to_gpu=preload_to_gpu, to_torch=False, - ).add_anndatas( - [open_func(p, use_zarrs=use_zarrs, use_anndata=True) for p in path.glob("*.zarr")], + ).use_collection( + collection, + **( + {"load_adata": lambda group: open_func(group, use_zarrs=use_zarrs, use_anndata=True)} + if open_func is not None + else {} + ), ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-dataset_type={open_func.__name__[5:]}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] 
marks=pytest.mark.skipif( find_spec("cupy") is None and preload_to_gpu, reason="need cupy installed", @@ -110,7 +117,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem for preload_to_gpu in [True, False] - for open_func in [open_sparse, open_dense] + for open_func in [open_sparse, open_dense, None] for elem in [ [ 1, @@ -145,7 +152,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ], ) def test_store_load_dataset( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], *, shuffle: bool, gen_loader, use_zarrs + simple_collection: tuple[ad.AnnData, DatasetCollection], *, shuffle: bool, gen_loader, use_zarrs ): """ This test verifies that the DaskDataset works correctly: @@ -154,8 +161,8 @@ def test_store_load_dataset( 3. All samples from the dataset are processed 4. If the dataset is not shuffled, it returns the correct data """ - loader: Loader = gen_loader(adata_with_zarr_path_same_var_space[1], shuffle, use_zarrs) - adata = adata_with_zarr_path_same_var_space[0] + loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs) + adata = simple_collection[0] is_dense = loader.dataset_type is zarr.Array n_elems = 0 batches = [] @@ -218,19 +225,11 @@ def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, ds.add_dataset(**data) -def test_batch_size_does_not_divide_evenly_fails(): - """Test that it fails if batch_size does not divide evenly into chunk_size * preload_nchunks.""" - # chunk_size=10, preload_nchunks=5 -> in-memory size = 50 - # batch_size=14 does not divide evenly into 50 - with pytest.raises(ValueError, match="must be divisible by batch_size"): - Loader( - shuffle=False, - chunk_size=10, - preload_nchunks=5, - batch_size=14, - preload_to_gpu=False, - to_torch=False, - ) +def test_use_collection_twice(simple_collection: tuple[ad.AnnData, DatasetCollection]): + ds = Loader() + ds 
= ds.use_collection(simple_collection[1]) + with pytest.raises(RuntimeError, match="You should not add multiple collections"): + ds.use_collection(simple_collection[1]) @pytest.mark.skipif(not find_spec("torch"), reason="need torch installed") From 6220a8acc932015b134fadf1d9a173208a4afb5f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 18 Jan 2026 20:04:29 +0100 Subject: [PATCH 04/56] readthedocs merge --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index c3f3f96f..4acb793c 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-24.04 tools: - python: "3.12" + python: "3.14" jobs: create_environment: - asdf plugin add uv From 309d33eb9951f63f26f0d0994cd48219e21cbacd Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 19 Jan 2026 11:24:23 +0100 Subject: [PATCH 05/56] chore: clarify compatibility of `h5ad` + forward compat of old shuffled `zarr` (#114) * chore: clarify compatibility of `h5ad` + forward compat of old shuffled `zarr` * chore: version * fix: docs * clarify warning * fix: more * fix: `add_anndata` * fix: `h5ad` compat --- CHANGELOG.md | 5 +++ README.md | 6 +++- pyproject.toml | 2 +- src/annbatch/io.py | 77 +++++++++++++++++++++++++++------------- tests/test_preshuffle.py | 43 ++++++++++++++-------- 5 files changed, 93 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc396adf..ce9df334 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +# [0.0.3] + +- Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. 
`is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. + ## [0.0.2] ### Breaking @@ -15,6 +19,7 @@ and this project adheres to [Semantic Versioning][]. - `ZarrSparseDataset` and `ZarrDenseDataset` have been conslidated into {class}`annbatch.Loader` - `create_anndata_collection` and `add_to_collection` have been moved into the {meth}`annbatch.DatasetCollection.add_adatas` method - Default reading of input data is now fully lazy in {meth}`annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although have better memory properties. Use `load_adata` argument in {meth}`annbatch.DatasetCollection.add_adatas` to customize this behavior. +- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {class}`annbatch.Loader.use_collection` API. At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {class}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the {meth}`annbatch.Loader.add_anndatas` as before. ### Changed diff --git a/README.md b/README.md index 391303c1..c0a35929 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,9 @@ > [!CAUTION] > This package does not have a stable API. -> However, we do not anticipate the on-disk format to change in an incompatible manner. +> However, we do not anticipate the on-disk format to change in a fully incompatible manner. 
+> Small changes to how we store the shuffled data may occur but you should always be able to load your data somehow i.e., they will never be fully breaking. +> We will always provide lower-level APIs that should make this guarantee possible. [![Tests][badge-tests]][tests] [![Documentation][badge-docs]][documentation] @@ -111,6 +113,8 @@ zarr.config.set( def custom_load_func(g: zarr.Group) -> ad.AnnData: return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) +# A non empty collection +collection = DatasetCollection("path/to/output/collection.zarr") # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! with ad.settings.override(remove_unused_categories=False): ds = Loader( diff --git a/pyproject.toml b/pyproject.toml index 81f95cb1..59048a0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "hatchling" ] [project] name = "annbatch" -version = "0.0.2" +version = "0.0.3" description = "A minibatch loader for AnnData stores" readme = "README.md" license = { file = "LICENSE" } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index dd23f2f5..55eeff26 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -11,7 +11,6 @@ import anndata as ad import dask.array as da -import h5py import numpy as np import pandas as pd import scipy.sparse as sp @@ -28,6 +27,7 @@ from os import PathLike from typing import Any, Literal + import h5py from zarr.abc.codec import BytesBytesCodec V1_ENCODING = {"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"} @@ -304,12 +304,14 @@ def wrapper(*args, **kwargs): return wrapper -class DatasetCollection[T: (h5py.Group, zarr.Group)]: +class DatasetCollection: """A preshuffled collection object including functionality for creating, adding to, and loading collections shuffled by `annbatch`.""" - _group: T + _group: zarr.Group | Path - def __init__(self, group: zarr.Group | 
h5py.Group | str | Path, *, mode: Literal["a", "r", "r+"] = "a"): + def __init__( + self, group: zarr.Group | str | Path, *, mode: Literal["a", "r", "r+"] = "a", is_collection_h5ad: bool = False + ): """Initialization of the object at a given location. Note that if the group is a h5py/zarr object, it must have the correct permissions for any subsequent operations you plan to do. @@ -319,36 +321,59 @@ def __init__(self, group: zarr.Group | h5py.Group | str | Path, *, mode: Literal Parameters ---------- group - The base location for a preshuffled collection + The base location for a preshuffled collection. + A :class:`zarr.Group` or path ending in `.zarr` indicates zarr as the shuffled format and otherwise a directory of `h5ad` files will be created. """ - if not isinstance(group, h5py.Group | zarr.Group): + if not isinstance(group, zarr.Group): if isinstance(group, str | Path): - if str(group).endswith("h5ad"): - self._group = h5py.File(group, mode=mode) - elif str(group).endswith("zarr"): + if not is_collection_h5ad: + if not str(group).endswith(".zarr"): + warnings.warn( + f"It is highly recommended to make your collections have the `.zarr` suffix, got: {group}.", + stacklevel=2, + ) self._group = zarr.open_group(group, mode=mode) else: - raise ValueError("String argument must end in h5ad or zarr") + warnings.warn( + "Loading h5ad is currently not supported and thus we cannot guarantee the funcionality of the ecosystem with h5ad files." + "DatasetCollection should be able to handle shuffling but we guarantee little else." 
+ "Proceed with caution.", + stacklevel=2, + ) + self._group = Path(group) + self._group.mkdir(exist_ok=True) else: - raise TypeError("group must be a zarr or hdf5 group") + raise TypeError("Group must either be a zarr group or a path") else: + if is_collection_h5ad: + raise ValueError("Do not set `is_collection_h5ad` to True when also passing in a zarr Group.") self._group = group @property def _dataset_keys(self) -> list[str]: - return sorted( - [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None], - key=lambda x: int(x.split("_")[1]), - ) + if isinstance(self._group, zarr.Group): + return sorted( + [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None], + key=lambda x: int(x.split("_")[1]), + ) + else: + raise ValueError("Cannot iterate through folder of h5ad files") - def __iter__(self) -> Generator[T]: - for k in self._dataset_keys: - yield self._group[k] + def __iter__(self) -> Generator[zarr.Group]: + if isinstance(self._group, zarr.Group): + for k in self._dataset_keys: + yield self._group[k] + else: + raise ValueError("Cannot iterate through folder of h5ad files") @property def is_empty(self) -> bool: """Wether or not there is an existing store at the group location.""" - return not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0 + return ( + (not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0) + if isinstance(self._group, zarr.Group) + else (len(list(self._group.iterdir())) == 0) + ) @_with_settings def add_adatas( @@ -555,13 +580,13 @@ def _create_collection( key=f"{DATASET_PREFIX}_{i}", ) else: - ad.io.write_elem( - self._group, f"{DATASET_PREFIX}_{i}", adata_chunk, dataset_kwargs={"compression": h5ad_compressor} + ad.io.write_h5ad( + self._group / f"{DATASET_PREFIX}_{i}.h5ad", + adata_chunk, + dataset_kwargs={"compression": h5ad_compressor}, ) if isinstance(self._group, zarr.Group): 
self._group.update_attributes(V1_ENCODING) - else: - self._group.attrs.update(V1_ENCODING) def _add_to_collection( self, @@ -651,4 +676,8 @@ def _add_to_collection( key=dataset, ) else: - ad.io.write_elem(self._group, dataset, adata, dataset_kwargs={"compression": h5ad_compressor}) + ad.io.write_h5ad( + self._group / f"{dataset}.h5ad", + adata, + dataset_kwargs={"compression": h5ad_compressor}, + ) diff --git a/tests/test_preshuffle.py b/tests/test_preshuffle.py index 03faddb4..f0aa7956 100644 --- a/tests/test_preshuffle.py +++ b/tests/test_preshuffle.py @@ -1,6 +1,7 @@ from __future__ import annotations import glob +from contextlib import nullcontext from typing import TYPE_CHECKING, Literal import anndata as ad @@ -146,29 +147,43 @@ def test_store_addition_different_keys( ) -@pytest.mark.parametrize("open_store", [h5py.File, zarr.open_group]) +def test_h5ad_and_zarr_simultaneously(tmp_path: Path): + with pytest.raises(ValueError, match=r"Do not set `is_collection_h5ad` to True when also passing in a zarr Group."): + DatasetCollection(zarr.open_group(tmp_path / "foo.zarr"), is_collection_h5ad=True) + + +@pytest.mark.parametrize("is_collection_h5ad", [True, False], ids=["h5ad", "zarr"]) def test_store_creation_default( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], - open_store: Callable[[Path], zarr.Group | h5py.Group], + is_collection_h5ad: bool, ): h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) output_path = ( adata_with_h5_path_different_var_space[1].parent - / f"zarr_store_creation_test_default.{'h5ad' if open_store is h5py.File else 'zarr'}" + / f"{'h5ad' if is_collection_h5ad else 'zarr'}_store_creation_test_default" ) - store = open_store(output_path, mode="w") - collection = DatasetCollection(store).add_adatas( - [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], + with pytest.warns( + UserWarning, + match=r"collections have the `.zarr` suffix" + if (is_zarr := not 
is_collection_h5ad) + else r"Loading h5ad is currently not supported", + ): + kwargs = {} if is_zarr else {"is_collection_h5ad": True} + collection = DatasetCollection(output_path, **kwargs).add_adatas( + [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] + ) + assert isinstance( + ad.io.read_elem(next(iter(collection)) if is_zarr else h5py.File(next(output_path.iterdir()))).X, sp.csr_matrix ) - assert len(list(iter(collection))) == 1 # default n_obs_per_dataset is much more than total obs - assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) + assert len(list(iter(collection) if is_zarr else output_path.iterdir())) == 1 # Test directory structure to make sure nothing extraneous was written - if isinstance(store, zarr.Group): - assert sorted(glob.glob(str(output_path / "dataset_*"))) == sorted( - str(p) for p in (output_path).iterdir() if p.is_dir() - ) - assert list(iter(collection)) == [store[k] for k in sorted(store.keys())] - assert V1_ENCODING.items() <= store.attrs.items() + assert sorted(glob.glob(str(output_path / f"dataset_*{'.h5ad' if is_collection_h5ad else ''}"))) == sorted( + str(p) for p in (output_path).iterdir() if ((p.is_dir() and is_zarr) or not is_zarr) + ) + store = zarr.open(output_path) + with nullcontext() if is_zarr else pytest.raises(ValueError, match=r"Cannot iterate through"): + assert list(iter(collection)) == [store[k] for k in sorted(store.keys())] + assert V1_ENCODING.items() <= store.attrs.items() @pytest.mark.parametrize("shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no_shuffle")]) From 437f184ef8afae99db712d6201ece373cb9ff1b0 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 19 Jan 2026 12:12:31 +0100 Subject: [PATCH 06/56] breaking: clarify obs handling + change output keys (#115) * chore: clarify obs handling * chore: clearer docs * fix: rename --- CHANGELOG.md | 3 + README.md | 10 +- docs/index.md | 2 +- docs/notebooks/example.ipynb | 783 
++++++++++++++++++----------------- src/annbatch/loader.py | 25 +- src/annbatch/types.py | 6 +- src/annbatch/utils.py | 8 +- tests/test_dataset.py | 40 +- 8 files changed, 458 insertions(+), 419 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce9df334..fb3263ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,10 @@ and this project adheres to [Semantic Versioning][]. # [0.0.3] +### Breaking + - Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. `is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. +- Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. ## [0.0.2] diff --git a/README.md b/README.md index c0a35929..a87b26bf 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,9 @@ collection.add_adatas( Data loading: +> [!IMPORTANT] +> Without custom loading via `Loader.load_adata` *all* obs columns will be loaded and yielded potentially degrading performance. + ```python from pathlib import Path @@ -110,8 +113,9 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) +# WARNING: Without custom loading *all* obs columns will be loaded and yielded potentially degrading performance. 
def custom_load_func(g: zarr.Group) -> ad.AnnData: - return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns_useful_for_training]) # A non empty collection collection = DatasetCollection("path/to/output/collection.zarr") @@ -125,11 +129,11 @@ with ad.settings.override(remove_unused_categories=False): # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior # (see `custom_load_func` above for an example of customization). - ds = ds.use_collection(collection) + ds = ds.use_collection(collection, load_adata = custom_load_func) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - ... + data, obs = batch["X"], batch["obs"] ``` > [!IMPORTANT] diff --git a/docs/index.md b/docs/index.md index ee945a50..d94024b5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,7 +41,7 @@ ds = Loader( # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - x, df, index = batch["data"], batch["labels"], batch["index"] + x, df, index = batch["X"], batch["obs"], batch["index"] ``` The data loader implements a chunked fetching strategy where `preload_nchunks` number of continguous-chunks of size `chunk_size` are loaded. diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index b20ccd1b..f7085129 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -1,386 +1,403 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Quickstart `annbatch`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook will walk you through the following steps:\n", - "1. 
How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", - "2. How to load the converted collection using `annbatch`\n", - "3. Extend an existing collection with new `anndata` datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# !pip install annbatch[zarrs, torch]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "zsh:1: command not found: wget\n", - "zsh:1: command not found: wget\n" - ] - } - ], - "source": [ - "# Download two example datasets from CELLxGENE\n", - "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT**: Configure zarrs\n", - "\n", - "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import zarr\n", - "\n", - "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "# Suppress zarr vlen-utf8 codec warnings\n", - "warnings.filterwarnings(\n", - " \"ignore\",\n", - " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", - " category=UserWarning,\n", - " 
module=\"zarr.codecs.vlen_utf8\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Converting existing `anndata` files into a shuffled collection" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The conversion code will take care of the following things:\n", - "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", - " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", - " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", - "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", - " * This is important for block-wise shuffling during data loading.\n", - "* Shuffle the input files across multiple output datasets:\n", - " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", - " * We recommend to choose a dataset size that comfortably fits into system memory.\n", - "\n", - "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ilangold/Projects/Theis/annbatch/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.02it/s]\n", - "loading: 2it [00:00, 2.34it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import anndata as ad\n", - "from annbatch import DatasetCollection\n", - "\n", - "\n", - "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", - "# To have a store that only contains raw counts, we can write the following load_adata function\n", - "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", - " adata_ = ad.experimental.read_lazy(path)\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(\n", - " X=x,\n", - " obs=adata_.obs.to_memory()[\n", - " [\"cell_type\"]\n", - " ], # let's only take cell type since it is shared - otherwise DatasetCollection will warn about all the columns we are missing\n", - " var=var.to_memory(),\n", - " )\n", - "\n", - "\n", - "collection = DatasetCollection(zarr.open(\"annbatch_collection\", mode=\"w\"))\n", - "collection.add_adatas(\n", - " # List all the h5ad files you want to include in the collection\n", - " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", - " # Path to store the output collection\n", - " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard, this number is 
much higher than available in these datasets but is generally a good target\n", - " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data loading example" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "\n", - "from annbatch import Loader\n", - "\n", - "ds = Loader(\n", - " batch_size=4096, # Total number of obs per yielded batch\n", - " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", - " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", - " preload_to_gpu=False,\n", - " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", - " to_torch=True,\n", - ")\n", - "\n", - "# Add in the shuffled data that should be used for training\n", - "ds.use_collection(collection)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT:**\n", - "* The `Loader` yields batches of sparse tensors.\n", - "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", - " * First call `.cuda()` and then `.to_dense()`\n", - " * E.g. 
`x = x.cuda().to_dense()`\n", - " * This is significantly faster than doing the dense conversion on the CPU.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 42/171792 [00:07<8:54:32, 5.36it/s] \n" - ] - } - ], - "source": [ - "# Iterate over dataloader\n", - "import tqdm\n", - "\n", - "for batch in tqdm.tqdm(ds):\n", - " x, obs = batch[\"data\"], batch[\"labels\"][\"cell_type\"]\n", - " # Important: Convert to dense on GPU\n", - " x = x.cuda().to_dense()\n", - " # Feed data into your model\n", - " ..." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optional: Extend an existing collection with a new dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You might want to extend an existing pre-shuffled collection with a new dataset.\n", - "This can be done using the `add` method again.\n", - "\n", - "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.65it/s]\n", - "loading: 1it [00:00, 1.77it/s]\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 13.66it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "collection.add_adatas(\n", - " adata_paths=[\n", - " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", - " ],\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quickstart `annbatch`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook will walk you through the following steps:\n", + "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", + "2. How to load the converted collection using `annbatch`\n", + "3. 
Extend an existing collection with new `anndata` datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "# !pip install annbatch[zarrs, torch]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "zsh:1: command not found: wget\n", + "zsh:1: command not found: wget\n" + ] + } + ], + "source": [ + "# Download two example datasets from CELLxGENE\n", + "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", + "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT**: Configure zarrs\n", + "\n", + "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import zarr\n", + "\n", + "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "# Suppress zarr vlen-utf8 codec warnings\n", + "warnings.filterwarnings(\n", + " \"ignore\",\n", + " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", + " category=UserWarning,\n", + " module=\"zarr.codecs.vlen_utf8\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting existing `anndata` files into a shuffled collection" + ] + }, + { 
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "The conversion code will take care of the following things:\n", + "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", + " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", + " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", + "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", + " * This is important for block-wise shuffling during data loading.\n", + "* Shuffle the input files across multiple output datasets:\n", + " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", + " * We recommend to choose a dataset size that comfortably fits into system memory.\n", + "\n", + "\n", + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ilangold/Projects/Theis/annbatch/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.19it/s]\n", + "loading: 2it [00:00, 2.19it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import anndata as ad\n", + "from annbatch import DatasetCollection\n", + "\n", + "# let's write out only shared colunms - otherwise DatasetCollection will warn about all the columns we are missing for good reason - mismatched columns can lead to unexpected data and missing values.\n", + "shared_columns = ad.experimental.read_lazy(\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\").obs.columns.intersection(\n", + " ad.experimental.read_lazy(\"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\").obs.columns\n", + ")\n", + "\n", + "\n", + "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", + "# To have a store that only contains raw counts, we can write the following load_adata function\n", + "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", + " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", + " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", + " adata_ = ad.experimental.read_lazy(path)\n", + " if adata_.raw is not None:\n", + " x = adata_.raw.X\n", + " var = adata_.raw.var\n", + " else:\n", + " x = adata_.X\n", + " var = adata_.var\n", + "\n", + " return ad.AnnData(\n", + " X=x,\n", + " obs=adata_.obs.to_memory()[shared_columns],\n", + " var=var.to_memory(),\n", + " )\n", + "\n", + "\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\", mode=\"w\"))\n", + "collection.add_adatas(\n", + " # List all the h5ad files you want to include in the collection\n", + " 
adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", + " # Path to store the output collection\n", + " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard, this number is much higher than available in these datasets but is generally a good target\n", + " var_subset=None, # Optionally subset the collection to a specific gene space\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data loading example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we create our data loader with the desired arguments.\n", + "\n", + "**WARNING**: Without `load_adata` argument in `use_collection`, the *entire* `obs` will be loaded and yielded, degrading performance. It is highly advised to use this argument." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "\n", + "from annbatch import Loader\n", + "\n", + "\n", + "def _load_adata(g: zarr.Group) -> ad.AnnData:\n", + " return ad.AnnData(X=ad.io.sparse_dataset(g[\"X\"]), obs=ad.experimental.read_lazy(g).obs[[\"cell_type\"]].to_memory())\n", + "\n", + "\n", + "ds = Loader(\n", + " batch_size=4096, # Total number of obs per yielded batch\n", + " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", + " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", + " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", + " preload_to_gpu=False,\n", + " 
to_torch=True,\n", + ")\n", + "\n", + "# Add in the shuffled data that should be used for training.\n", + "ds.use_collection(collection, load_adata=_load_adata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT:**\n", + "* The `Loader` yields batches of sparse tensors.\n", + "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", + " * First call `.cuda()` and then `.to_dense()`\n", + " * E.g. `x = x.cuda().to_dense()`\n", + " * This is significantly faster than doing the dense conversion on the CPU.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 42/171792 [00:07<8:54:32, 5.36it/s] \n" + ] + } + ], + "source": [ + "# Iterate over dataloader\n", + "import tqdm\n", + "\n", + "for batch in tqdm.tqdm(ds):\n", + " x, obs = batch[\"X\"], batch[\"obs\"][\"cell_type\"]\n", + " # Important: Convert to dense on GPU\n", + " x = x.cuda().to_dense()\n", + " # Feed data into your model\n", + " ..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optional: Extend an existing collection with a new dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You might want to extend an existing pre-shuffled collection with a new dataset.\n", + "This can be done using the `add` method again.\n", + "\n", + "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.65it/s]\n", + "loading: 1it [00:00, 1.77it/s]\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 13.66it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.add_adatas(\n", + " adata_paths=[\n", + " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", + " ],\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 7a2af448..b9292ef3 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -249,7 +249,8 @@ def use_collection( The collection who on-disk datasets should be used in this loader. load_adata A custom load function - recall that whatever is found in :attr:`~anndata.AnnData.X` and :attr:`~anndata.AnnData.obs` will be yielded in batches. - Default is to just load `X` and `obs`. + Default is to just load `X` and all of `obs`. + This default behavior can degrade performance if you don't need all columns in `obs` - it is recommended to use the `load_adata` argument. 
""" if collection.is_empty: raise ValueError("DatasetCollection is empty") @@ -272,7 +273,7 @@ def add_anndatas( Parameters ---------- adatas - List of :class:`anndata.AnnData` objects, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing labels to yield in a :class:`pandas.DataFrame`. + List of :class:`anndata.AnnData` objects, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. """ check_lt_1([len(adatas)], ["Number of anndatas"]) for adata in adatas: @@ -289,10 +290,12 @@ def add_anndata(self, adata: ad.AnnData) -> Self: Parameters ---------- adata - A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing labels to yield in a :class:`pandas.DataFrame`. + A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. """ dataset = adata.X obs = adata.obs + if len(obs.columns) == 0: + obs = None if not isinstance(dataset, BackingArray_T.__value__): raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") self.add_dataset(cast("BackingArray", dataset), obs) @@ -308,7 +311,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N List of :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` objects, generally from :attr:`anndata.AnnData.X`. They must all be of the same type and match that of any already added datasets. obs - List of :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. 
+ List of :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. """ if obs is None: obs = [None] * len(datasets) @@ -325,7 +328,7 @@ def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> dataset A :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` object, generally from :attr:`anndata.AnnData.X`. obs - :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. + :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. """ self._add_dataset_unchecked(dataset, obs) return self @@ -334,11 +337,11 @@ def _add_dataset_unchecked(self, dataset: BackingArray, obs: pd.DataFrame | None if len(self._train_datasets) > 0: if self._obs is None and obs is not None: raise ValueError( - f"Cannot add a dataset with obs label {obs} when training datasets have already been added without labels" + f"Cannot add a dataset with obs label {obs} when training datasets have already been added without obs" ) if self._obs is not None and obs is None: raise ValueError( - "Cannot add a dataset with no obs label when training datasets have already been added without labels" + "Cannot add a dataset with no obs label when training datasets have already been added without obs" ) if not isinstance(dataset, self.dataset_type): raise ValueError( @@ -350,15 +353,15 @@ def _add_dataset_unchecked(self, dataset: BackingArray, obs: pd.DataFrame | None raise TypeError( "Cannot add CSRDataset backed by h5ad at the moment: see https://github.com/zarr-developers/VirtualiZarr/pull/790" ) - if not isinstance(obs, pd.DataFrame): + if not isinstance(obs, pd.DataFrame) and obs is not None: raise TypeError("obs must be a pandas DataFrame") datasets = self._train_datasets + [dataset] check_var_shapes(datasets) self._shapes = self._shapes + [dataset.shape] self._train_datasets = datasets - if self._obs is not None: + if self._obs is not None: # obs exist self._obs += [obs] - elif obs is not None: + elif obs is not None: # obs 
dont exist yet, but are being added for the first time self._obs = [obs] return self @@ -595,7 +598,7 @@ def __iter__( Yields ------ - A batch of data along with its labels and index (both optional). + A batch of data along with its obs and index (both optional). """ check_lt_1( [len(self._train_datasets), self.n_obs], diff --git a/src/annbatch/types.py b/src/annbatch/types.py index 9d9a597a..508ff8a1 100644 --- a/src/annbatch/types.py +++ b/src/annbatch/types.py @@ -34,8 +34,8 @@ class LoadRequest(TypedDict): class LoaderOutput[OutputInMemoryArray: OutputInMemoryArray_T](TypedDict): - """The output of the loader, the "data matrix" with its labels, optional, and index, also optional.""" + """The output of the loader, the "data matrix" with its obs, optional, and index, also optional.""" - data: OutputInMemoryArray_T.__value__ # TODO: remove after sphinx 9 - myst compat - labels: pd.DataFrame | None + X: OutputInMemoryArray_T.__value__ # TODO: remove after sphinx 9 - myst compat + obs: pd.DataFrame | None index: np.ndarray | None diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index e558b861..78d23cbd 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -110,17 +110,17 @@ def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: return chunks_split[worker_id] -def check_lt_1(vals: list[int], labels: list[str]) -> None: +def check_lt_1(vals: list[int], obs: list[str]) -> None: """Raise a ValueError if any of the values are less than one. - The format of the error is "{labels[i]} must be greater than 1, got {values[i]}" + The format of the error is "{obs[i]} must be greater than 1, got {values[i]}" and is raised based on the first found less than one value. Parameters ---------- vals The values to check < 1 - labels + obs The label for the value in the error if the value is less than one. 
Raises @@ -131,7 +131,7 @@ def check_lt_1(vals: list[int], labels: list[str]) -> None: label, value = next( (label, value) for label, value, check in zip( - labels, + obs, vals, is_lt_1, strict=True, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b3bcd877..089b380b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -166,17 +166,17 @@ def test_store_load_dataset( is_dense = loader.dataset_type is zarr.Array n_elems = 0 batches = [] - labels = [] + obs = [] indices = [] expected_data = adata.X if is_dense else adata.layers["sparse"].toarray() for batch in loader: - x, label, index = batch["data"], batch["labels"], batch["index"] + x, label, index = batch["X"], batch["obs"], batch["index"] n_elems += x.shape[0] # Check feature dimension assert x.shape[1] == 100 batches += [x.get() if isinstance(x, CupyCSRMatrix | CupyArray) else x] if label is not None: - labels += [label] + obs += [label] if index is not None: indices += [index] # check that we yield all samples from the dataset @@ -186,10 +186,10 @@ def test_store_load_dataset( stacked = stacked.toarray() if not shuffle: np.testing.assert_allclose(stacked, expected_data) - if len(labels) > 0: + if len(obs) > 0: expected_labels = adata.obs pd.testing.assert_frame_equal( - pd.concat(labels), + pd.concat(obs), expected_labels, ) else: @@ -265,7 +265,7 @@ def test_to_torch( to_torch=True, ) ds.add_dataset(**open_func(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")))) - assert isinstance(next(iter(ds))["data"], torch.Tensor) + assert isinstance(next(iter(ds))["X"], torch.Tensor) @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) @@ -290,7 +290,7 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], batches = [] indices = [] for batch in ds: - batches += [batch["data"]] + batches += [batch["X"]] indices += [batch["index"]] total_obs = adata.shape[0] leftover = total_obs % batch_size @@ -315,12 +315,12 @@ def 
test_bad_adata_X_hdf5(adata_with_h5_path_different_var_space: tuple[ad.AnnDa def _custom_collate_fn(elems): import torch - if isinstance(elems[0]["data"], torch.Tensor): - x = torch.vstack([v["data"].to_dense() for v in elems]) - elif isinstance(elems[0]["data"], sp.csr_matrix): - x = sp.vstack([v["data"] for v in elems]).toarray() + if isinstance(elems[0]["X"], torch.Tensor): + x = torch.vstack([v["X"].to_dense() for v in elems]) + elif isinstance(elems[0]["X"], sp.csr_matrix): + x = sp.vstack([v["X"] for v in elems]).toarray() else: - x = np.vstack([v["data"] for v in elems]) + x = np.vstack([v["X"] for v in elems]) y = np.array([v["index"] for v in elems]) @@ -415,8 +415,20 @@ def test_default_data_structures( list(adata_with_zarr_path_same_var_space[1].iterdir())[0] ) ) - for batch in ds: - assert isinstance(batch["data"], expected_cls) + assert isinstance(next(iter(ds))["X"], expected_cls) + + +def test_no_obs(simple_collection: tuple[ad.AnnData, DatasetCollection]): + # No obs loaded is actually None + ds = Loader( + chunk_size=10, + preload_nchunks=4, + batch_size=22, + ).use_collection( + simple_collection[1], + load_adata=lambda g: ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["sparse"])), + ) + assert next(iter(ds))["obs"] is None def test_add_dataset_validation_failure_preserves_state(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): From 679ee5018150663ae5c50d285302e442b8158384 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 19 Jan 2026 12:24:45 +0100 Subject: [PATCH 07/56] fix: header level (#116) --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb3263ea..247d2b03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning][]. 
[keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -# [0.0.3] +## [0.0.3] ### Breaking From ad1ec5544a1e878ce4dc58bfa697202ab01cdf49 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 18:56:35 +0100 Subject: [PATCH 08/56] merge changes --- src/annbatch/loader.py | 17 +++++++---------- tests/test_dataset.py | 4 ++-- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index b9292ef3..a4129cff 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -290,7 +290,7 @@ def add_anndata(self, adata: ad.AnnData) -> Self: Parameters ---------- adata - A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. + A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing labels to yield in a :class:`pandas.DataFrame`. """ dataset = adata.X obs = adata.obs @@ -311,7 +311,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N List of :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` objects, generally from :attr:`anndata.AnnData.X`. They must all be of the same type and match that of any already added datasets. obs - List of :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. + List of :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. 
""" if obs is None: obs = [None] * len(datasets) @@ -604,11 +604,6 @@ def __iter__( [len(self._train_datasets), self.n_obs], ["Number of datasets", "Number of observations"], ) - - in_memory_data = None - concatenated_obs = None - in_memory_indices = None - mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np for load_request in self._batch_sampler.sample(self.n_obs): @@ -624,6 +619,8 @@ def __iter__( indices: None | list[np.ndarray] = self._maybe_accumulate_indices(chunks_to_load) in_memory_data = mod.vstack(chunks_converted) + concatenated_obs = None + in_memory_indices = None if self._obs is not None and obs is not None: concatenated_obs = pd.concat(obs) if self._return_index and indices is not None: @@ -684,12 +681,12 @@ def _prepare_output( ) -> LoaderOutput: """Prepare the final output dict for a single batch.""" index = None - labels = None + obs = None if self._obs is not None and concatenated_obs is not None: - labels = concatenated_obs.iloc[split] + obs = concatenated_obs.iloc[split] if self._return_index and in_memory_indices is not None: index = in_memory_indices[split] data = in_memory_data[split] if self._to_torch: data = to_torch(data, self._preload_to_gpu) - return {"data": data, "labels": labels, "index": index} + return {"X": data, "obs": obs, "index": index} diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 089b380b..c4d250c7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -265,7 +265,7 @@ def test_to_torch( to_torch=True, ) ds.add_dataset(**open_func(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")))) - assert isinstance(next(iter(ds))["X"], torch.Tensor) + assert isinstance(next(iter(ds))["data"], torch.Tensor) @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) @@ -423,7 +423,7 @@ def test_no_obs(simple_collection: tuple[ad.AnnData, DatasetCollection]): ds = Loader( chunk_size=10, preload_nchunks=4, - batch_size=22, + batch_size=20, ).use_collection( 
simple_collection[1], load_adata=lambda g: ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["sparse"])), From c8b43959fc720b5b9b30ff9683eb0f6150a43cee Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 19:11:21 +0100 Subject: [PATCH 09/56] apply suggestions --- src/annbatch/sampler/_sampler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/annbatch/sampler/_sampler.py b/src/annbatch/sampler/_sampler.py index e0aef84b..d418275c 100644 --- a/src/annbatch/sampler/_sampler.py +++ b/src/annbatch/sampler/_sampler.py @@ -148,7 +148,7 @@ def __init__( preload_nchunks, slice(start, stop), drop_last, - ) # stop can be None + ) def validate(self, n_obs: int) -> None: """Validate the sampler configuration against the loader's n_obs. @@ -207,10 +207,11 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Set up the iterator for chunks and the batch indices for splits in_memory_size = self._chunk_size * self._preload_nchunks chunks_per_batch = split_given_size(chunks, self._preload_nchunks) - batch_indices = np.arange(in_memory_size) # to avoid copies use in-place shuffling + batch_indices = np.arange(in_memory_size) split_batch_indices = split_given_size(batch_indices, self._batch_size) for batch_chunks in chunks_per_batch[:-1]: if self._shuffle: + # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training self._rng.shuffle(batch_indices) split_batch_indices = split_given_size(batch_indices, self._batch_size) yield {"chunks": batch_chunks, "splits": split_batch_indices} From 62d7d48f58f3ce7325ed414030be4d2a5913b0e8 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 19:12:22 +0100 Subject: [PATCH 10/56] checkout readme from main --- README.md | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index a87b26bf..391303c1 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,7 @@ > [!CAUTION] > This package does not have a stable 
API. -> However, we do not anticipate the on-disk format to change in a fully incompatible manner. -> Small changes to how we store the shuffled data may occur but you should always be able to load your data somehow i.e., they will never be fully breaking. -> We will always provide lower-level APIs that should make this guarantee possible. +> However, we do not anticipate the on-disk format to change in an incompatible manner. [![Tests][badge-tests]][tests] [![Documentation][badge-docs]][documentation] @@ -97,9 +95,6 @@ collection.add_adatas( Data loading: -> [!IMPORTANT] -> Without custom loading via `Loader.load_adata` *all* obs columns will be loaded and yielded potentially degrading performance. - ```python from pathlib import Path @@ -113,12 +108,9 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) -# WARNING: Without custom loading *all* obs columns will be loaded and yielded potentially degrading performance. def custom_load_func(g: zarr.Group) -> ad.AnnData: - return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns_useful_for_training]) + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) -# A non empty collection -collection = DatasetCollection("path/to/output/collection.zarr") # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! with ad.settings.override(remove_unused_categories=False): ds = Loader( @@ -129,11 +121,11 @@ with ad.settings.override(remove_unused_categories=False): # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior # (see `custom_load_func` above for an example of customization). 
- ds = ds.use_collection(collection, load_adata = custom_load_func) + ds = ds.use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - data, obs = batch["X"], batch["obs"] + ... ``` > [!IMPORTANT] From d7539bf3ea2a11d20590b20ffeb7facc4e923f45 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 19 Jan 2026 12:12:31 +0100 Subject: [PATCH 11/56] breaking: clarify obs handling + change output keys (#115) * chore: clarify obs handling * chore: clearer docs * fix: rename --- CHANGELOG.md | 2 ++ README.md | 10 +++++++--- src/annbatch/loader.py | 4 ++-- tests/test_dataset.py | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 247d2b03..9852d6cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning][]. ### Breaking +### Breaking + - Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. `is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. - Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. diff --git a/README.md b/README.md index 391303c1..95e0713d 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,9 @@ collection.add_adatas( Data loading: +> [!IMPORTANT] +> Without custom loading via `Loader.load_adata` *all* obs columns will be loaded and yielded potentially degrading performance. + ```python from pathlib import Path @@ -108,8 +111,9 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) +# WARNING: Without custom loading *all* obs columns will be loaded and yielded potentially degrading performance. 
def custom_load_func(g: zarr.Group) -> ad.AnnData: - return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns_useful_for_training]) # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! with ad.settings.override(remove_unused_categories=False): @@ -121,11 +125,11 @@ with ad.settings.override(remove_unused_categories=False): # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior # (see `custom_load_func` above for an example of customization). - ds = ds.use_collection(collection) + ds = ds.use_collection(collection, load_adata = custom_load_func) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - ... + data, obs = batch["X"], batch["obs"] ``` > [!IMPORTANT] diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index a4129cff..db769dfb 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -290,7 +290,7 @@ def add_anndata(self, adata: ad.AnnData) -> Self: Parameters ---------- adata - A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing labels to yield in a :class:`pandas.DataFrame`. + A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. 
""" dataset = adata.X obs = adata.obs @@ -311,7 +311,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N List of :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` objects, generally from :attr:`anndata.AnnData.X`. They must all be of the same type and match that of any already added datasets. obs - List of :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. + List of :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. """ if obs is None: obs = [None] * len(datasets) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c4d250c7..127cd974 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -265,7 +265,7 @@ def test_to_torch( to_torch=True, ) ds.add_dataset(**open_func(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")))) - assert isinstance(next(iter(ds))["data"], torch.Tensor) + assert isinstance(next(iter(ds))["X"], torch.Tensor) @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) From 0d7764d432464ce1a75349e212896c0668e7eb3e Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 18 Jan 2026 19:37:26 +0100 Subject: [PATCH 12/56] parent 627eb08d699b9cb07ab24fa67775e1e794c07245 author selmanozleyen 1768761446 +0100 committer selmanozleyen 1768846971 +0100 resolve conflicts with main update tests to resolve conflict readthedocs merge merge changes apply suggestions checkout readme from main --- README.md | 10 +++------- src/annbatch/loader.py | 4 ++-- tests/test_dataset.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 95e0713d..391303c1 100644 --- a/README.md +++ b/README.md @@ -95,9 +95,6 @@ collection.add_adatas( Data loading: -> [!IMPORTANT] -> Without custom loading via `Loader.load_adata` *all* obs columns will be loaded and yielded potentially degrading performance. 
- ```python from pathlib import Path @@ -111,9 +108,8 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) -# WARNING: Without custom loading *all* obs columns will be loaded and yielded potentially degrading performance. def custom_load_func(g: zarr.Group) -> ad.AnnData: - return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns_useful_for_training]) + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! with ad.settings.override(remove_unused_categories=False): @@ -125,11 +121,11 @@ with ad.settings.override(remove_unused_categories=False): # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior # (see `custom_load_func` above for an example of customization). - ds = ds.use_collection(collection, load_adata = custom_load_func) + ds = ds.use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - data, obs = batch["X"], batch["obs"] + ... ``` > [!IMPORTANT] diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index db769dfb..a4129cff 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -290,7 +290,7 @@ def add_anndata(self, adata: ad.AnnData) -> Self: Parameters ---------- adata - A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. 
+ A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing labels to yield in a :class:`pandas.DataFrame`. """ dataset = adata.X obs = adata.obs @@ -311,7 +311,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N List of :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` objects, generally from :attr:`anndata.AnnData.X`. They must all be of the same type and match that of any already added datasets. obs - List of :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. + List of :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. """ if obs is None: obs = [None] * len(datasets) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 127cd974..c4d250c7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -265,7 +265,7 @@ def test_to_torch( to_torch=True, ) ds.add_dataset(**open_func(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")))) - assert isinstance(next(iter(ds))["X"], torch.Tensor) + assert isinstance(next(iter(ds))["data"], torch.Tensor) @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) From f7742a1c05c45a3e135249f714c1b7fec1b1f863 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 19:32:39 +0100 Subject: [PATCH 13/56] restore from main --- CHANGELOG.md | 2 -- README.md | 16 ++++++++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9852d6cc..247d2b03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,6 @@ and this project adheres to [Semantic Versioning][]. ### Breaking -### Breaking - - Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. 
`is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. - Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. diff --git a/README.md b/README.md index 391303c1..a87b26bf 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,9 @@ > [!CAUTION] > This package does not have a stable API. -> However, we do not anticipate the on-disk format to change in an incompatible manner. +> However, we do not anticipate the on-disk format to change in a fully incompatible manner. +> Small changes to how we store the shuffled data may occur but you should always be able to load your data somehow i.e., they will never be fully breaking. +> We will always provide lower-level APIs that should make this guarantee possible. [![Tests][badge-tests]][tests] [![Documentation][badge-docs]][documentation] @@ -95,6 +97,9 @@ collection.add_adatas( Data loading: +> [!IMPORTANT] +> Without custom loading via `Loader.load_adata` *all* obs columns will be loaded and yielded potentially degrading performance. + ```python from pathlib import Path @@ -108,9 +113,12 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) +# WARNING: Without custom loading *all* obs columns will be loaded and yielded potentially degrading performance. def custom_load_func(g: zarr.Group) -> ad.AnnData: - return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns_useful_for_training]) +# A non empty collection +collection = DatasetCollection("path/to/output/collection.zarr") # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! 
with ad.settings.override(remove_unused_categories=False): ds = Loader( @@ -121,11 +129,11 @@ with ad.settings.override(remove_unused_categories=False): # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` # but the `load_adata` arg can override this behavior # (see `custom_load_func` above for an example of customization). - ds = ds.use_collection(collection) + ds = ds.use_collection(collection, load_adata = custom_load_func) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - ... + data, obs = batch["X"], batch["obs"] ``` > [!IMPORTANT] From 3d73e3a37c4fcf0646cce2faa9cb8069a5749522 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 19:46:01 +0100 Subject: [PATCH 14/56] fix: checking out: confused origin and upstream again... --- CHANGELOG.md | 16 +- docs/api.md | 5 +- docs/conf.py | 18 +- docs/index.md | 18 +- docs/notebooks/example.ipynb | 798 +++++++++++++++++------------------ src/annbatch/io.py | 733 ++++++++++++-------------------- 6 files changed, 707 insertions(+), 881 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 247d2b03..1f6f802e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,21 +8,7 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -## [0.0.3] - -### Breaking - -- Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. `is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. -- Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. 
- -## [0.0.2] - -### Breaking - -- `ZarrSparseDataset` and `ZarrDenseDataset` have been conslidated into {class}`annbatch.Loader` -- `create_anndata_collection` and `add_to_collection` have been moved into the {meth}`annbatch.DatasetCollection.add_adatas` method -- Default reading of input data is now fully lazy in {meth}`annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although have better memory properties. Use `load_adata` argument in {meth}`annbatch.DatasetCollection.add_adatas` to customize this behavior. -- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {class}`annbatch.Loader.use_collection` API. At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {class}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the {meth}`annbatch.Loader.add_anndatas` as before. +## [Unreleased] ### Changed diff --git a/docs/api.md b/docs/api.md index cf399fd6..263a9c52 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,7 +25,8 @@ :toctree: generated/ write_sharded - DatasetCollection + add_to_collection + create_anndata_collection ``` (types)= @@ -35,5 +36,5 @@ .. autosummary:: :toctree: generated/ - types.LoaderOutput + types.BackingArray_T ``` diff --git a/docs/conf.py b/docs/conf.py index 10b83dd5..98ac5f1f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,14 +1,23 @@ -from __future__ import annotations +# Configuration file for the Sphinx documentation builder. +# This file only contains a selection of the most common options. 
For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- import sys from datetime import datetime from importlib.metadata import metadata - -# -- Path setup -------------------------------------------------------------- from pathlib import Path +# For some reason doing this prevents autodoc_mock_import = ["torch"] from not being able to find the module i.e., it's not in sys.modules. +# TODO: Bug report +import annbatch # noqa: F401 + HERE = Path(__file__).parent sys.path.insert(0, str(HERE / "extensions")) + + # -- Project information ----------------------------------------------------- # NOTE: If you installed your project in editable mode, this might be stale. @@ -55,13 +64,13 @@ "IPython.sphinxext.ipython_console_highlighting", "sphinxext.opengraph", "sphinx_issues", - "sphinx_toolbox.more_autodoc.autotypeddict", "scanpydoc", # needs to be before linkcode *[p.stem for p in (HERE / "extensions").glob("*.py")], ] autosummary_generate = True autodoc_member_order = "groupwise" +autodoc_mock_imports = ["torch"] default_role = "literal" napoleon_google_docstring = False napoleon_numpy_docstring = True @@ -100,7 +109,6 @@ "scipy": ("https://docs.scipy.org/doc/scipy", None), "cupy": ("https://docs.cupy.dev/en/stable/", None), "zarrs": ("https://zarrs-python.readthedocs.io/en/latest/", None), - "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), } # List of patterns, relative to source directory, that match files and diff --git a/docs/index.md b/docs/index.md index d94024b5..9043ce97 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,11 +9,12 @@ Let's go through the above example: ### Preprocessing ```python -colleciton = DatasetCollection("path/to/output/store.zarr").add_adatas( +create_anndata_collection( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], + output_path="path/to/output/store", # a 
directory containing `chunk_{i}.zarr` shuffle=True, # shuffling is needed if you want to use chunked access ) ``` @@ -32,16 +33,25 @@ See the [zarr docs on sharding][] for more information. #### Chunked access ```python -# `use_collection` will automatically get everything in `X` and `obs` and yield it. ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, -).use_collection(collection) +).add_anndatas( + [ + ad.AnnData( + # note that you can open an anndata file using any type of zarr store + X=ad.io.sparse_dataset(zarr.open(p)["X"]), + obs=ad.io.read_elem(zarr.open(p)["obs"]), + ) + for p in PATH_TO_STORE.glob("*.zarr") + ], + obs_keys="label_column", +) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - x, df, index = batch["X"], batch["obs"], batch["index"] + ... ``` The data loader implements a chunked fetching strategy where `preload_nchunks` number of continguous-chunks of size `chunk_size` are loaded. diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index f7085129..3f8aef26 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -1,403 +1,401 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Quickstart `annbatch`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook will walk you through the following steps:\n", - "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", - "2. How to load the converted collection using `annbatch`\n", - "3. 
Extend an existing collection with new `anndata` datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# !pip install annbatch[zarrs, torch]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "zsh:1: command not found: wget\n", - "zsh:1: command not found: wget\n" - ] - } - ], - "source": [ - "# Download two example datasets from CELLxGENE\n", - "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT**: Configure zarrs\n", - "\n", - "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import zarr\n", - "\n", - "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "# Suppress zarr vlen-utf8 codec warnings\n", - "warnings.filterwarnings(\n", - " \"ignore\",\n", - " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", - " category=UserWarning,\n", - " module=\"zarr.codecs.vlen_utf8\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Converting existing `anndata` files into a shuffled collection" - ] - }, - { 
- "cell_type": "markdown", - "metadata": {}, - "source": [ - "The conversion code will take care of the following things:\n", - "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", - " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", - " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", - "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", - " * This is important for block-wise shuffling during data loading.\n", - "* Shuffle the input files across multiple output datasets:\n", - " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", - " * We recommend to choose a dataset size that comfortably fits into system memory.\n", - "\n", - "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ilangold/Projects/Theis/annbatch/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.19it/s]\n", - "loading: 2it [00:00, 2.19it/s]\n", - "processing chunks: 0%| | 0/1 [00:00] 737.43M 398MB/s in 1.9s \n", + "\n", + "2025-10-09 09:43:21 (398 MB/s) - ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’ saved [773247972/773247972]\n", + "\n", + "--2025-10-09 09:43:22-- https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\n", + "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 18.64.79.73, 18.64.79.80, 18.64.79.72, ...\n", + "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1631759823 (1.5G) [binary/octet-stream]\n", + "Saving to: ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’\n", + "\n", + "f81463b8-4986-4904- 100%[===================>] 1.52G 425MB/s in 3.9s \n", + "\n", + "2025-10-09 09:43:26 (403 MB/s) - ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’ saved [1631759823/1631759823]\n", + "\n" + ] + } + ], + "source": [ + "# Download two example datasets from CELLxGENE\n", + "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", + "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT**: Configure zarrs\n", + "\n", + "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "import zarr\n", + "import zarrs # noqa\n", + "\n", + "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "# Suppress zarr vlen-utf8 codec warnings\n", + "warnings.filterwarnings(\n", + " \"ignore\",\n", + " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", + " category=UserWarning,\n", + " module=\"zarr.codecs.vlen_utf8\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting existing `anndata` files into a shuffled collection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The conversion code will take care of the following things:\n", + "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", + " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", + " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", + "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", + " * This is important for block-wise shuffling during data loading.\n", + "* Shuffle the input files across multiple output datasets:\n", + " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", + " * We recommend to choose a dataset size that comfortably fits into system memory.\n", + "\n", + "\n", + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `create_anndata_collection`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "import anndata as ad\n", + "from annbatch import 
create_anndata_collection\n", + "\n", + "\n", + "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", + "# To have a store that only contains raw counts, we can write the following load_adata function\n", + "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", + " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", + " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", + " adata_ = ad.experimental.read_lazy(path)\n", + " if adata_.raw is not None:\n", + " x = adata_.raw.X\n", + " var = adata_.raw.var\n", + " else:\n", + " x = adata_.X\n", + " var = adata_.var\n", + "\n", + " return ad.AnnData(\n", + " X=x,\n", + " obs=adata_.obs.to_memory(),\n", + " var=var.to_memory(),\n", + " )\n", + "\n", + "\n", + "create_anndata_collection(\n", + " # List all the h5ad files you want to include in the collection\n", + " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", + " # Path to store the output collection\n", + " output_path=\"annbatch_collection\",\n", + " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", + " var_subset=None, # Optionally subset the collection to a specific gene space\n", + " should_denseify=False,\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data loading example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "COLLECTION_PATH = Path(\"annbatch_collection/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "\n", + "from annbatch import Loader\n", + "\n", + "ds = Loader(\n", + " batch_size=4096, # Total number of obs per yielded batch\n", + " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", + " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", + " preload_to_gpu=False,\n", + " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", + " to_torch=True,\n", + ")\n", + "\n", + "# Add dataset that should be used for training\n", + "ds.add_anndatas(\n", + " [\n", + " ad.AnnData(\n", + " X=ad.io.sparse_dataset(zarr.open(p)[\"X\"]),\n", + " obs=ad.io.read_elem(zarr.open(p)[\"obs\"]),\n", + " )\n", + " for p in COLLECTION_PATH.glob(\"*.zarr\")\n", + " ],\n", + " obs_keys=\"cell_type\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT:**\n", + "* The `Loader` yields batches of sparse tensors.\n", + "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", + " * First call `.cuda()` and then `.to_dense()`\n", + " * E.g. 
`x = x.cuda().to_dense()`\n", + " * This is significantly faster than doing the dense conversion on the CPU.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/171792 [00:00 ad.AnnData:\n", + " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", + " # As it's only a small dataset, we can load the full dataset into memory to speed up computations\n", + " adata_ = ad.read_h5ad(path) # Replace with ad.experimental.read_lazy if data does not fit into memory anymore\n", + " if adata_.raw is not None:\n", + " x = adata_.raw.X\n", + " var = adata_.raw.var\n", + " else:\n", + " x = adata_.X\n", + " var = adata_.var\n", + "\n", + " return ad.AnnData(X=x, obs=adata_.obs, var=var)\n", + "\n", + "\n", + "add_to_collection(\n", + " adata_paths=[\n", + " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", + " ],\n", + " output_path=\"annbatch_collection\",\n", + " load_adata=read_x_and_obs_only,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import anndata as ad\n", - "from annbatch import DatasetCollection\n", - "\n", - "# let's write out only shared colunms - otherwise DatasetCollection will warn about all the columns we are missing for good reason - mismatched columns can lead to unexpected data and missing values.\n", - "shared_columns = 
ad.experimental.read_lazy(\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\").obs.columns.intersection(\n", - " ad.experimental.read_lazy(\"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\").obs.columns\n", - ")\n", - "\n", - "\n", - "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", - "# To have a store that only contains raw counts, we can write the following load_adata function\n", - "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", - " adata_ = ad.experimental.read_lazy(path)\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(\n", - " X=x,\n", - " obs=adata_.obs.to_memory()[shared_columns],\n", - " var=var.to_memory(),\n", - " )\n", - "\n", - "\n", - "collection = DatasetCollection(zarr.open(\"annbatch_collection\", mode=\"w\"))\n", - "collection.add_adatas(\n", - " # List all the h5ad files you want to include in the collection\n", - " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", - " # Path to store the output collection\n", - " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard, this number is much higher than available in these datasets but is generally a good target\n", - " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data loading example" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we create our data loader with the desired arguments.\n", - "\n", - 
"**WARNING**: Without `load_adata` argument in `use_collection`, the *entire* `obs` will be loaded and yielded, degrading performance. It is highly advised to use this argument." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "\n", - "from annbatch import Loader\n", - "\n", - "\n", - "def _load_adata(g: zarr.Group) -> ad.AnnData:\n", - " return ad.AnnData(X=ad.io.sparse_dataset(g[\"X\"]), obs=ad.experimental.read_lazy(g).obs[[\"cell_type\"]].to_memory())\n", - "\n", - "\n", - "ds = Loader(\n", - " batch_size=4096, # Total number of obs per yielded batch\n", - " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", - " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", - " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", - " preload_to_gpu=False,\n", - " to_torch=True,\n", - ")\n", - "\n", - "# Add in the shuffled data that should be used for training.\n", - "ds.use_collection(collection, load_adata=_load_adata)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT:**\n", - "* The `Loader` yields batches of sparse tensors.\n", - "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", - " * First call `.cuda()` and then `.to_dense()`\n", - " * E.g. 
`x = x.cuda().to_dense()`\n", - " * This is significantly faster than doing the dense conversion on the CPU.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 42/171792 [00:07<8:54:32, 5.36it/s] \n" - ] - } - ], - "source": [ - "# Iterate over dataloader\n", - "import tqdm\n", - "\n", - "for batch in tqdm.tqdm(ds):\n", - " x, obs = batch[\"X\"], batch[\"obs\"][\"cell_type\"]\n", - " # Important: Convert to dense on GPU\n", - " x = x.cuda().to_dense()\n", - " # Feed data into your model\n", - " ..." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optional: Extend an existing collection with a new dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You might want to extend an existing pre-shuffled collection with a new dataset.\n", - "This can be done using the `add` method again.\n", - "\n", - "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.65it/s]\n", - "loading: 1it [00:00, 1.77it/s]\n", - "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 13.66it/s]\n", - "processing chunks: 0%| | 0/1 [00:00" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "collection.add_adatas(\n", - " adata_paths=[\n", - " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", - " ],\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 55eeff26..7ce99bfc 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -1,13 +1,12 @@ from __future__ import annotations -import math +import json import random -import re import warnings from collections import defaultdict from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING import anndata as ad import dask.array as da @@ -20,18 +19,13 @@ from tqdm.auto import tqdm from zarr.codecs import BloscCodec, BloscShuffle -from annbatch.utils import split_given_size - if TYPE_CHECKING: - from collections.abc import Callable, Generator, Iterable, Mapping + from collections.abc 
import Callable, Iterable, Mapping from os import PathLike from typing import Any, Literal - import h5py from zarr.abc.codec import BytesBytesCodec -V1_ENCODING = {"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"} - def _round_down(num: int, divisor: int): return num - (num % divisor) @@ -46,7 +40,6 @@ def write_sharded( dense_chunk_size: int = 1024, dense_shard_size: int = 4194304, compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - key: str | None = None, ): """Write a sharded zarr store from a single AnnData object. @@ -66,8 +59,6 @@ def write_sharded( Number of obs elements per dense shard along the first axis compressors The compressors to pass to `zarr`. - key - The key to which this object should be written - by default the root, in which case the *entire* store (not just the group) is cleared first. """ ad.settings.zarr_write_format = 3 @@ -108,17 +99,11 @@ def callback( } write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs) - ad.experimental.write_dispatched(group, "/" if key is None else key, adata, callback=callback) + ad.experimental.write_dispatched(group, "/", adata, callback=callback) zarr.consolidate_metadata(group.store) -def _check_for_mismatched_keys( - paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData | zarr.Group | h5py.Group] | Iterable[str | ad.AnnData], - *, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( - x, load_annotation_index=False - ), -): +def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData] | Iterable[str | ad.AnnData]): num_raw_in_adata = 0 found_keys: dict[str, defaultdict[str, int]] = { "layers": defaultdict(lambda: 0), @@ -127,14 +112,13 @@ def _check_for_mismatched_keys( } for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"): if not isinstance(path_or_anndata, ad.AnnData): - adata = load_adata(path_or_anndata) + adata = 
ad.experimental.read_lazy(path_or_anndata) else: adata = path_or_anndata for elem_name, key_count in found_keys.items(): curr_keys = set(getattr(adata, elem_name).keys()) for key in curr_keys: - if not (elem_name in {"var", "obs"} and key == "_index"): - key_count[key] += 1 + key_count[key] += 1 if adata.raw is not None: num_raw_in_adata += 1 if num_raw_in_adata != len(paths_or_anndatas) and num_raw_in_adata != 0: @@ -155,12 +139,10 @@ def _check_for_mismatched_keys( def _lazy_load_anndatas( paths: Iterable[PathLike[str]] | Iterable[str], - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( - x, load_annotation_index=False - ), + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy, ): adatas = [] - categoricals_in_all_adatas: dict[str, pd.Index] = {} + categoricals_in_all_adatas = {} for i, path in tqdm(enumerate(paths), desc="loading"): adata = load_adata(path) # Track the source file for this given anndata object @@ -170,23 +152,20 @@ def _lazy_load_anndatas( # Concatenating Dataset2D drops categoricals so we need to track them if isinstance(adata.obs, Dataset2D): categorical_cols_in_this_adata = { - col: adata.obs[col].dtype.categories for col in adata.obs.columns if adata.obs[col].dtype == "category" + col: set(adata.obs[col].dtype.categories) + for col in adata.obs.columns + if adata.obs[col].dtype == "category" } if not categoricals_in_all_adatas: categoricals_in_all_adatas = { **categorical_cols_in_this_adata, - "src_path": adata.obs["src_path"].dtype.categories, + "src_path": set(adata.obs["src_path"].dtype.categories), } else: for k in categoricals_in_all_adatas.keys() & categorical_cols_in_this_adata.keys(): - categoricals_in_all_adatas[k] = categoricals_in_all_adatas[k].union( - categorical_cols_in_this_adata[k] + categoricals_in_all_adatas[k] = set(categoricals_in_all_adatas[k]).union( + set(categorical_cols_in_this_adata[k]) ) - # TODO: Probably bug in anndata, need the true 
index for proper outer joins (can't skirt this with fake indexes, at least not in the mixed-type regime). - if isinstance(adata.var, Dataset2D): - adata.var.index = adata.var.true_index - if adata.raw is not None and isinstance(adata.raw.var, Dataset2D): - adata.raw.var.index = adata.raw.var.true_index adatas.append(adata) if len(adatas) == 1: return adatas[0] @@ -196,38 +175,17 @@ def _lazy_load_anndatas( return adata -def _create_chunks_for_shuffling( - n_obs: int, - shuffle_chunk_size: int = 1000, - shuffle: bool = True, - *, - shuffle_n_obs_per_dataset: int | None = None, - n_chunkings: int | None = None, -) -> list[np.ndarray]: - # this splits the array up into `shuffle_chunk_size` contiguous runs - idxs = split_given_size(np.arange(n_obs), shuffle_chunk_size) - if shuffle: - random.shuffle(idxs) - match shuffle_n_obs_per_dataset is not None, n_chunkings is not None: - case True, False: - n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_chunk_size) - use_single_chunking = n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1 - case False, True: - n_slices_per_dataset = (n_obs // n_chunkings) // shuffle_chunk_size - use_single_chunking = n_chunkings == 1 - case _, _: - raise ValueError("Cannot provide both shuffle_n_obs_per_dataset and n_chunkings or neither") - # In this case `shuffle_n_obs_per_dataset` is bigger than the size of the dataset or the slice size is probably too big. - if use_single_chunking: - return [np.concatenate(idxs)] - # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. 
- idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) - return [ - np.concatenate([np.arange(s.start, s.stop) for s in idx]) - for idx in ( - split_given_size(idxs, n_slices_per_dataset) if n_chunkings is None else np.array_split(idxs, n_chunkings) - ) +def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True): + chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0])) + slices = [ + slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True) ] + if shuffle: + random.shuffle(slices) + idxs = np.concatenate([np.arange(s.start, s.stop) for s in slices]) + idxs = np.array_split(idxs, np.ceil(len(idxs) / shuffle_n_obs_per_dataset)) + + return idxs def _compute_blockwise(x: DaskArray) -> sp.spmatrix: @@ -251,14 +209,9 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata.X = _compute_blockwise(adata.X) if isinstance(adata.obs, Dataset2D): adata.obs = adata.obs.to_memory() - # TODO: This is a bug in anndata? 
- if "_index" in adata.obs.columns: - del adata.obs["_index"] adata = _to_categorical_obs(adata) if isinstance(adata.var, Dataset2D): adata.var = adata.var.to_memory() - if "_index" in adata.var.columns: - del adata.var["_index"] if adata.raw is not None: adata_raw = adata.raw.to_adata() @@ -266,28 +219,19 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata_raw.X = _compute_blockwise(adata_raw.X) if isinstance(adata_raw.var, Dataset2D): adata_raw.var = adata_raw.var.to_memory() - if "_index" in adata_raw.var.columns: - del adata_raw.var["_index"] if isinstance(adata_raw.obs, Dataset2D): adata_raw.obs = adata_raw.obs.to_memory() del adata.raw adata.raw = adata_raw - for axis_name in ["layers", "obsm", "varm", "obsp", "varp"]: - for k, elem in getattr(adata, axis_name).items(): - # TODO: handle `Dataset2D` in `obsm` and `varm` that are - if isinstance(elem, DaskArray): - getattr(adata, axis_name)[k] = _compute_blockwise(elem) - if isinstance(elem, Dataset2D): - elem = elem.to_memory() - if "_index" in elem.columns: - del elem["_index"] - # TODO: Bug in anndata - if "obs" in axis_name: - elem.index = adata.obs_names - getattr(adata, axis_name)[k] = elem - - return adata.to_memory() + for k, elem in adata.obsm.items(): + # TODO: handle `Dataset2D` in `obsm` and `varm` that are + if isinstance(elem, DaskArray): + adata.obsm[k] = _compute_blockwise(elem) + + for k, elem in adata.layers.items(): + if isinstance(elem, DaskArray): + adata.obsm[k] = _compute_blockwise(elem) return adata @@ -304,380 +248,259 @@ def wrapper(*args, **kwargs): return wrapper -class DatasetCollection: - """A preshuffled collection object including functionality for creating, adding to, and loading collections shuffled by `annbatch`.""" +@_with_settings +def create_anndata_collection( + adata_paths: Iterable[PathLike[str]] | Iterable[str], + output_path: PathLike[str] | str, + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy, + 
var_subset: Iterable[str] | None = None, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle: bool = True, + should_denseify: bool = False, + output_format: Literal["h5ad", "zarr"] = "zarr", +): + """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. - _group: zarr.Group | Path + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. + The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets and also unshuffled datasets as well. + The var space is by default outer-joined, but can be subsetted by `var_subset`. + A key `src_path` is added to `obs` to indicate where individual row came from. + We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. - def __init__( - self, group: zarr.Group | str | Path, *, mode: Literal["a", "r", "r+"] = "a", is_collection_h5ad: bool = False - ): - """Initialization of the object at a given location. - - Note that if the group is a h5py/zarr object, it must have the correct permissions for any subsequent operations you plan to do. - Otherwise, the store will be opened according to the mode argument. - - - Parameters - ---------- - group - The base location for a preshuffled collection. 
- A :class:`zarr.Group` or path ending in `.zarr` indicates zarr as the shuffled format and otherwise a directory of `h5ad` files will be created. - """ - if not isinstance(group, zarr.Group): - if isinstance(group, str | Path): - if not is_collection_h5ad: - if not str(group).endswith(".zarr"): - warnings.warn( - f"It is highly recommended to make your collections have the `.zarr` suffix, got: {group}.", - stacklevel=2, - ) - self._group = zarr.open_group(group, mode=mode) - else: - warnings.warn( - "Loading h5ad is currently not supported and thus we cannot guarantee the funcionality of the ecosystem with h5ad files." - "DatasetCollection should be able to handle shuffling but we guarantee little else." - "Proceed with caution.", - stacklevel=2, - ) - self._group = Path(group) - self._group.mkdir(exist_ok=True) - else: - raise TypeError("Group must either be a zarr group or a path") - else: - if is_collection_h5ad: - raise ValueError("Do not set `is_collection_h5ad` to True when also passing in a zarr Group.") - self._group = group - - @property - def _dataset_keys(self) -> list[str]: - if isinstance(self._group, zarr.Group): - return sorted( - [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None], - key=lambda x: int(x.split("_")[1]), + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + output_path + Path to the output zarr store. + load_adata + Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. + var_subset + Subset of gene names to include in the store. If None, all genes are included. 
+ Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + shuffle + Whether to shuffle the data before writing it to the store. + should_denseify + Whether to write as dense on disk. There's no need to set this for sparse data, it is only for testing. + output_format + Format of the output store. Can be either "zarr" or "h5ad". + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import create_anndata_collection + # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store + >>> def read_lazy_x_and_obs_only(path): + ... adata = ad.experimental.read_lazy(path) + ... return ad.AnnData( + ... X=adata.X, + ... obs=adata.obs.to_memory(), + ... var=adata.var.to_memory(), + ...) + + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... ] + >>> create_anndata_collection( + ... datasets, + ... "path/to/output/zarr_store", + ... load_adata=read_lazy_x_and_obs_only, + ...) 
+ """ + Path(output_path).mkdir(parents=True, exist_ok=True) + _check_for_mismatched_keys(adata_paths) + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + adata_concat.obs_names_make_unique() + chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle) + + if var_subset is None: + var_subset = adata_concat.var_names + + for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): + var_mask = adata_concat.var_names.isin(var_subset) + # np.sort: It's more efficient to access elements sequentially from dask arrays + # The data will be shuffled later on, we just want the elements at this point + adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() + adata_chunk = _persist_adata_in_memory(adata_chunk) + if shuffle: + # shuffle adata in memory to break up individual chunks + idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) + adata_chunk = adata_chunk[idxs] + # convert to dense format before writing to disk + if should_denseify: + # Need to convert back to dask array to avoid memory issues when converting large sparse matrices to dense + adata_chunk = adata_chunk.copy() + adata_chunk.X = da.from_array( + adata_chunk.X, chunks=(zarr_dense_chunk_size, -1), meta=adata_chunk.X + ).map_blocks(lambda xx: xx.toarray(), dtype=adata_chunk.X.dtype) + + if output_format == "zarr": + f = zarr.open_group(Path(output_path) / f"{DATASET_PREFIX}_{i}.zarr", mode="w") + write_sharded( + f, + adata_chunk, + sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=zarr_dense_chunk_size, + dense_shard_size=zarr_dense_shard_size, + compressors=zarr_compressor, ) + elif output_format == "h5ad": + adata_chunk.write_h5ad(Path(output_path) / f"{DATASET_PREFIX}_{i}.h5ad", compression=h5ad_compressor) else: - raise ValueError("Cannot iterate through folder of h5ad files") + raise ValueError(f"Unrecognized output_format: {output_format}. 
Only 'zarr' and 'h5ad' are supported.") - def __iter__(self) -> Generator[zarr.Group]: - if isinstance(self._group, zarr.Group): - for k in self._dataset_keys: - yield self._group[k] - else: - raise ValueError("Cannot iterate through folder of h5ad files") - - @property - def is_empty(self) -> bool: - """Wether or not there is an existing store at the group location.""" - return ( - (not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0) - if isinstance(self._group, zarr.Group) - else (len(list(self._group.iterdir())) == 0) - ) - @_with_settings - def add_adatas( - self, - adata_paths: Iterable[PathLike[str]] | Iterable[str], - *, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( - x, load_annotation_index=False - ), - var_subset: Iterable[str] | None = None, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - n_obs_per_dataset: int = 2_097_152, - shuffle_chunk_size: int = 1000, - shuffle: bool = True, - ) -> Self: - """Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per dataset if running for the first time). - - The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. - The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. - However, this function can also output h5 datasets and also unshuffled datasets as well. 
- The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. - A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. - Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_chunk_size` as so many rows will be read into memory before writing to disk. - After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. - This is not a stable API but only for internal purposes at the moment. - - Parameters - ---------- - adata_paths - Paths to the AnnData files used to create the zarr store. - load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an :class:`anndata.AnnData` object. - var_subset - Subset of gene names to include in the store. If None, all genes are included. - Genes are subset based on the `var_names` attribute of the concatenated AnnData object. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. 
- zarr_compressor - Compressors to use to compress the data in the zarr store. - h5ad_compressor - Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. - n_obs_per_dataset - Number of observations to load into memory at once for shuffling / pre-processing. - The higher this number, the more memory is used, but the better the shuffling. - This corresponds to the size of the shards created. - Only applicable when adding datasets for the first time, otherwise ignored. - shuffle - Whether to shuffle the data before writing it to the store. - Ignored once the store is non-empty. - shuffle_chunk_size - How many contiguous rows to load into memory before shuffling at once. - `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import DatasetCollection - # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store - >>> def read_lazy_x_and_obs_only(path): - ... adata = ad.experimental.read_lazy(path) - ... return ad.AnnData( - ... X=adata.X, - ... obs=adata.obs.to_memory(), - ... var=adata.var.to_memory(), - ...) - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> DatasetCollection("path/to/output/zarr_store.zarr").add_adatas( - ... datasets, - ... load_adata=read_lazy_x_and_obs_only, - ...) 
- """ - if shuffle_chunk_size > n_obs_per_dataset: - raise ValueError("Cannot have a large slice size than observations per dataset") - shared_kwargs = { - "adata_paths": adata_paths, - "load_adata": load_adata, - "zarr_sparse_chunk_size": zarr_sparse_chunk_size, - "zarr_sparse_shard_size": zarr_sparse_shard_size, - "zarr_dense_chunk_size": zarr_dense_chunk_size, - "zarr_dense_shard_size": zarr_dense_shard_size, - "zarr_compressor": zarr_compressor, - "h5ad_compressor": h5ad_compressor, - "shuffle_chunk_size": shuffle_chunk_size, - "shuffle": shuffle, - } - if self.is_empty: - self._create_collection(**shared_kwargs, n_obs_per_dataset=n_obs_per_dataset, var_subset=var_subset) - else: - self._add_to_collection(**shared_kwargs) - return self +def _get_array_encoding_type(path: PathLike[str] | str) -> str: + shards = list(Path(path).glob(f"{DATASET_PREFIX}_*.zarr")) + with open(shards[0] / "X" / "zarr.json") as f: + encoding = json.load(f) + return encoding["attributes"]["encoding-type"] - def _create_collection( - self, - *, - adata_paths: Iterable[PathLike[str]] | Iterable[str], - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( - x, load_annotation_index=False - ), - var_subset: Iterable[str] | None = None, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - n_obs_per_dataset: int = 2_097_152, - shuffle_chunk_size: int = 1000, - shuffle: bool = True, - ) -> None: - """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per dataset. 
- - The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. - The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. - However, this function can also output h5 datasets and also unshuffled datasets as well. - The var space is by default outer-joined, but can be subsetted by `var_subset`. - A key `src_path` is added to `obs` to indicate where individual row came from. - We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. - Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. - - Parameters - ---------- - adata_paths - Paths to the AnnData files used to create the zarr store. - load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a :class:`dask.array.Array`. - var_subset - Subset of gene names to include in the store. If None, all genes are included. - Genes are subset based on the `var_names` attribute of the concatenated AnnData object. - Only applicable when adding datasets for the first time, otherwise ignored and the incoming data's var space is subsetted to that of the existing collection. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. 
- zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - h5ad_compressor - Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. - n_obs_per_dataset - Number of observations to load into memory at once for shuffling / pre-processing. - The higher this number, the more memory is used, but the better the shuffling. - This corresponds to the size of the shards created. - Only applicable when adding datasets for the first time, otherwise ignored. - shuffle - Whether to shuffle the data before writing it to the store. - shuffle_chunk_size - How many contiguous rows to load into memory before shuffling at once. - `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. 
- """ - if not self.is_empty: - raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") - _check_for_mismatched_keys(adata_paths, load_adata=load_adata) - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - adata_concat.obs_names_make_unique() - n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset) - chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, shuffle_n_obs_per_dataset=n_obs_per_dataset - ) - if var_subset is None: - var_subset = adata_concat.var_names - for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): - var_mask = adata_concat.var_names.isin(var_subset) - # np.sort: It's more efficient to access elements sequentially from dask arrays - # The data will be shuffled later on, we just want the elements at this point - adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() - adata_chunk = _persist_adata_in_memory(adata_chunk) - if shuffle: - # shuffle adata in memory to break up individual chunks - idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) - adata_chunk = adata_chunk[idxs] - if isinstance(self._group, zarr.Group): - write_sharded( - self._group, - adata_chunk, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=min(adata_chunk.shape[0], zarr_dense_chunk_size), - dense_shard_size=min(adata_chunk.shape[0], zarr_dense_shard_size), - compressors=zarr_compressor, - key=f"{DATASET_PREFIX}_{i}", - ) - else: - ad.io.write_h5ad( - self._group / f"{DATASET_PREFIX}_{i}.h5ad", - adata_chunk, - dataset_kwargs={"compression": h5ad_compressor}, - ) - if isinstance(self._group, zarr.Group): - self._group.update_attributes(V1_ENCODING) +@_with_settings +def add_to_collection( + adata_paths: Iterable[PathLike[str]] | Iterable[str], + output_path: PathLike[str] | str, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, + 
zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + should_sparsify_output_in_memory: bool = False, +) -> None: + """Add anndata files to an existing collection of sharded anndata zarr datasets. - def _add_to_collection( - self, - *, - adata_paths: Iterable[PathLike[str]] | Iterable[str], - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - shuffle_chunk_size: int = 1000, - shuffle: bool = True, - ) -> None: - """Add anndata files to an existing collection of sharded anndata zarr datasets. - - The var space of the source anndata files will be adapted to the target store. - - Parameters - ---------- - adata_paths - Paths to the anndata files to be appended to the collection of output chunks. - load_adata - Function to customize loading the invidiual input anndata files. By default, :func:`anndata.read_h5ad` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object. - If the input data is too large to fit into memory, you should use :func:`annndata.experimental.read_lazy` instead. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. 
- zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - should_sparsify_output_in_memory - This option is for testing only appending sparse files to dense stores. - To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. - shuffle_chunk_size - How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. - shuffle - Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. - """ - if self.is_empty: - raise ValueError("Store is empty. Please run `DatasetCollection.add` first.") - # Check for mismatched keys among the inputs. - _check_for_mismatched_keys(adata_paths, load_adata=load_adata) - - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys): - raise ValueError( - f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores." - "Open an issue if the incoming anndata is so small it cannot be distributed across the on-disk data" - ) - # Check for mismatched keys between datasets and the inputs. - _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) - chunks = _create_chunks_for_shuffling( - adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, n_chunkings=len(self._dataset_keys) - ) + The var space of the source anndata files will be adapted to the target store. 
- adata_concat.obs_names_make_unique() - for dataset, chunk in tqdm( - zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="processing chunks" - ): - adata_dataset = ad.io.read_elem(self._group[dataset]) - subset_adata = _to_categorical_obs( - adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)] - ) - adata = ad.concat([adata_dataset, subset_adata], join="outer") - if shuffle: - idxs = np.random.default_rng().permutation(adata.shape[0]) - else: - idxs = np.arange(adata.shape[0]) - adata = _persist_adata_in_memory(adata[idxs, :].copy()) - if isinstance(self._group, zarr.Group): - write_sharded( - self._group, - adata, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=min(adata.shape[0], zarr_dense_chunk_size), - dense_shard_size=min(adata.shape[0], zarr_dense_shard_size), - compressors=zarr_compressor, - key=dataset, - ) - else: - ad.io.write_h5ad( - self._group / f"{dataset}.h5ad", - adata, - dataset_kwargs={"compression": h5ad_compressor}, + Parameters + ---------- + adata_paths + Paths to the anndata files to be appended to the collection of output chunks. + output_path + Path to the output zarr store. + load_adata + Function to customize loading the invidiual input anndata files. By default, {func}`anndata.read_h5ad` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object. + If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. 
+ zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + should_sparsify_output_in_memory + This option is for testing only appending sparse files to dense stores. + To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import add_to_collection + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... ] + >>> add_to_collection( + ... datasets, + ... "path/to/output/zarr_store", + ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory + ...) + """ + shards = list(Path(output_path).glob(f"{DATASET_PREFIX}_*.zarr")) + if len(shards) == 0: + raise ValueError( + "Store at `output_path` does not exist or is empty. Please run `create_anndata_collection` first." + ) + encoding = _get_array_encoding_type(output_path) + if encoding == "array": + print("Detected array encoding type. Will convert to dense format before writing.") + # Check for mismatched keys among the inputs. + _check_for_mismatched_keys(adata_paths) + + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + # Check for mismatched keys between shards and the inputs. 
+ _check_for_mismatched_keys([adata_concat] + shards) + if isinstance(adata_concat.X, DaskArray): + chunks = _create_chunks_for_shuffling(adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True) + else: + chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards)) + + adata_concat.obs_names_make_unique() + if encoding == "array": + if not should_sparsify_output_in_memory: + if isinstance(adata_concat.X, sp.spmatrix): + adata_concat.X = adata_concat.X.toarray() + elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, sp.spmatrix): + adata_concat.X = adata_concat.X.map_blocks( + lambda x: x.toarray(), meta=np.ndarray, dtype=adata_concat.X.dtype ) + elif encoding == "csr_matrix": + if isinstance(adata_concat.X, np.ndarray): + adata_concat.X = sp.csr_matrix(adata_concat.X) + elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, np.ndarray): + adata_concat.X = adata_concat.X.map_blocks( + sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype)) + ) + + for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards), desc="processing chunks"): + if should_sparsify_output_in_memory and encoding == "array": + adata_shard = _lazy_load_anndatas([shard]) + adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute() + else: + adata_shard = ad.read_zarr(shard) + subset_adata = _to_categorical_obs( + adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)] + ) + adata = ad.concat([adata_shard, subset_adata], join="outer") + idxs_shuffled = np.random.default_rng().permutation(len(adata)) + adata = adata[idxs_shuffled, :].copy() # this significantly speeds up writing to disk + if should_sparsify_output_in_memory and encoding == "array": + adata.X = adata.X.map_blocks(lambda x: x.toarray(), meta=np.array([0], dtype=adata.X.dtype)).compute() + + f = zarr.open_group(shard, mode="w") + write_sharded( + f, + adata, + 
sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=zarr_dense_chunk_size, + dense_shard_size=zarr_dense_shard_size, + compressors=zarr_compressor, + ) From dff2a014a072bd0c4301af10b8e87dba4a5bcceb Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 19:50:55 +0100 Subject: [PATCH 15/56] continuation of the upstream origin confusion fix --- CHANGELOG.md | 16 +- docs/api.md | 5 +- docs/conf.py | 18 +- docs/index.md | 18 +- docs/notebooks/example.ipynb | 798 ++++++++++++++++++----------------- src/annbatch/io.py | 733 ++++++++++++++++++++------------ 6 files changed, 881 insertions(+), 707 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f6f802e..247d2b03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,21 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -## [Unreleased] +## [0.0.3] + +### Breaking + +- Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. `is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing in order to use a preshuffled collection of `h5ad` files, reading or writing. +- Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. + +## [0.0.2] + +### Breaking + +- `ZarrSparseDataset` and `ZarrDenseDataset` have been consolidated into {class}`annbatch.Loader` +- `create_anndata_collection` and `add_to_collection` have been moved into the {meth}`annbatch.DatasetCollection.add_adatas` method +- Default reading of input data is now fully lazy in {meth}`annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although it has better memory properties.
Use `load_adata` argument in {meth}`annbatch.DatasetCollection.add_adatas` to customize this behavior. +- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {meth}`annbatch.Loader.use_collection` API. At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {meth}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the {meth}`annbatch.Loader.add_anndatas` as before. ### Changed diff --git a/docs/api.md b/docs/api.md index 263a9c52..cf399fd6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,8 +25,7 @@ :toctree: generated/ write_sharded - add_to_collection - create_anndata_collection + DatasetCollection ``` (types)= @@ -36,5 +35,5 @@ .. autosummary:: :toctree: generated/ - types.BackingArray_T + types.LoaderOutput ``` diff --git a/docs/conf.py b/docs/conf.py index 98ac5f1f..10b83dd5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,23 +1,14 @@ -# Configuration file for the Sphinx documentation builder. +from __future__ import annotations -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- import sys from datetime import datetime from importlib.metadata import metadata -from pathlib import Path -# For some reason doing this prevents autodoc_mock_import = ["torch"] from not being able to find the module i.e., it's not in sys.modules.
-# TODO: Bug report -import annbatch # noqa: F401 +# -- Path setup -------------------------------------------------------------- +from pathlib import Path HERE = Path(__file__).parent sys.path.insert(0, str(HERE / "extensions")) - - # -- Project information ----------------------------------------------------- # NOTE: If you installed your project in editable mode, this might be stale. @@ -64,13 +55,13 @@ "IPython.sphinxext.ipython_console_highlighting", "sphinxext.opengraph", "sphinx_issues", + "sphinx_toolbox.more_autodoc.autotypeddict", "scanpydoc", # needs to be before linkcode *[p.stem for p in (HERE / "extensions").glob("*.py")], ] autosummary_generate = True autodoc_member_order = "groupwise" -autodoc_mock_imports = ["torch"] default_role = "literal" napoleon_google_docstring = False napoleon_numpy_docstring = True @@ -109,6 +100,7 @@ "scipy": ("https://docs.scipy.org/doc/scipy", None), "cupy": ("https://docs.cupy.dev/en/stable/", None), "zarrs": ("https://zarrs-python.readthedocs.io/en/latest/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), } # List of patterns, relative to source directory, that match files and diff --git a/docs/index.md b/docs/index.md index 9043ce97..d94024b5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,12 +9,11 @@ Let's go through the above example: ### Preprocessing ```python -create_anndata_collection( +collection = DatasetCollection("path/to/output/store.zarr").add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - output_path="path/to/output/store", # a directory containing `chunk_{i}.zarr` shuffle=True, # shuffling is needed if you want to use chunked access ) ``` @@ -33,25 +32,16 @@ See the [zarr docs on sharding][] for more information. #### Chunked access ```python +# `use_collection` will automatically get everything in `X` and `obs` and yield it.
ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, -).add_anndatas( - [ - ad.AnnData( - # note that you can open an anndata file using any type of zarr store - X=ad.io.sparse_dataset(zarr.open(p)["X"]), - obs=ad.io.read_elem(zarr.open(p)["obs"]), - ) - for p in PATH_TO_STORE.glob("*.zarr") - ], - obs_keys="label_column", -) +).use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: - ... + x, df, index = batch["X"], batch["obs"], batch["index"] ``` The data loader implements a chunked fetching strategy where `preload_nchunks` number of continguous-chunks of size `chunk_size` are loaded. diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 3f8aef26..f7085129 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -1,401 +1,403 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Quickstart `annbatch`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook will walk you through the following steps:\n", - "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", - "2. How to load the converted collection using `annbatch`\n", - "3. Extend an existing collection with new `anndata` datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# !pip install annbatch[zarrs, torch]" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2025-10-09 09:43:19-- https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 
18.64.79.73, 18.64.79.80, 18.64.79.109, ...\n", - "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 773247972 (737M) [binary/octet-stream]\n", - "Saving to: ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’\n", - "\n", - "866d7d5e-436b-4dbd- 100%[===================>] 737.43M 398MB/s in 1.9s \n", - "\n", - "2025-10-09 09:43:21 (398 MB/s) - ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’ saved [773247972/773247972]\n", - "\n", - "--2025-10-09 09:43:22-- https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\n", - "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 18.64.79.73, 18.64.79.80, 18.64.79.72, ...\n", - "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 1631759823 (1.5G) [binary/octet-stream]\n", - "Saving to: ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’\n", - "\n", - "f81463b8-4986-4904- 100%[===================>] 1.52G 425MB/s in 3.9s \n", - "\n", - "2025-10-09 09:43:26 (403 MB/s) - ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’ saved [1631759823/1631759823]\n", - "\n" - ] - } - ], - "source": [ - "# Download two example datasets from CELLxGENE\n", - "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT**: Configure zarrs\n", - "\n", - "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "data": { - 
"text/plain": [ - "" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import zarr\n", - "import zarrs # noqa\n", - "\n", - "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "# Suppress zarr vlen-utf8 codec warnings\n", - "warnings.filterwarnings(\n", - " \"ignore\",\n", - " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", - " category=UserWarning,\n", - " module=\"zarr.codecs.vlen_utf8\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Converting existing `anndata` files into a shuffled collection" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The conversion code will take care of the following things:\n", - "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", - " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", - " * If you want to subset to specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", - "* Shuffle the cells across all datasets (this works on larger than memory datasets as well).\n", - " * This is important for block-wise shuffling during data loading.\n", - "* Shuffle the input files across multiple output datasets:\n", - " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", - " * We recommend to choose a dataset size that comfortably fits into system memory.\n", - "\n", - "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `create_anndata_collection`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - 
"hide-output" - ] - }, - "outputs": [], - "source": [ - "import anndata as ad\n", - "from annbatch import create_anndata_collection\n", - "\n", - "\n", - "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", - "# To have a store that only contains raw counts, we can write the following load_adata function\n", - "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", - " adata_ = ad.experimental.read_lazy(path)\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(\n", - " X=x,\n", - " obs=adata_.obs.to_memory(),\n", - " var=var.to_memory(),\n", - " )\n", - "\n", - "\n", - "create_anndata_collection(\n", - " # List all the h5ad files you want to include in the collection\n", - " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", - " # Path to store the output collection\n", - " output_path=\"annbatch_collection\",\n", - " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", - " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " should_denseify=False,\n", - " load_adata=read_lazy_x_and_obs_only,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data loading example" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "COLLECTION_PATH = Path(\"annbatch_collection/\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ 
- { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import anndata as ad\n", - "\n", - "from annbatch import Loader\n", - "\n", - "ds = Loader(\n", - " batch_size=4096, # Total number of obs per yielded batch\n", - " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", - " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", - " preload_to_gpu=False,\n", - " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", - " to_torch=True,\n", - ")\n", - "\n", - "# Add dataset that should be used for training\n", - "ds.add_anndatas(\n", - " [\n", - " ad.AnnData(\n", - " X=ad.io.sparse_dataset(zarr.open(p)[\"X\"]),\n", - " obs=ad.io.read_elem(zarr.open(p)[\"obs\"]),\n", - " )\n", - " for p in COLLECTION_PATH.glob(\"*.zarr\")\n", - " ],\n", - " obs_keys=\"cell_type\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**IMPORTANT:**\n", - "* The `Loader` yields batches of sparse tensors.\n", - "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", - " * First call `.cuda()` and then `.to_dense()`\n", - " * E.g. 
`x = x.cuda().to_dense()`\n", - " * This is significantly faster than doing the dense conversion on the CPU.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/171792 [00:00 ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # As it's only a small dataset, we can load the full dataset into memory to speed up computations\n", - " adata_ = ad.read_h5ad(path) # Replace with ad.experimental.read_lazy if data does not fit into memory anymore\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(X=x, obs=adata_.obs, var=var)\n", - "\n", - "\n", - "add_to_collection(\n", - " adata_paths=[\n", - " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", - " ],\n", - " output_path=\"annbatch_collection\",\n", - " load_adata=read_x_and_obs_only,\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quickstart `annbatch`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook will walk you through the following steps:\n", + "1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based, collection of `anndata` datasets\n", + "2. How to load the converted collection using `annbatch`\n", + "3. 
Extend an existing collection with new `anndata` datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "# !pip install annbatch[zarrs, torch]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "zsh:1: command not found: wget\n", + "zsh:1: command not found: wget\n" + ] + } + ], + "source": [ + "# Download two example datasets from CELLxGENE\n", + "!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", + "!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT**: Configure zarrs\n", + "\n", + "This step is both required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini batch loading" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import zarr\n", + "\n", + "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "# Suppress zarr vlen-utf8 codec warnings\n", + "warnings.filterwarnings(\n", + " \"ignore\",\n", + " message=\"The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*\",\n", + " category=UserWarning,\n", + " module=\"zarr.codecs.vlen_utf8\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting existing `anndata` files into a shuffled collection" + ] + }, + { 
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "The conversion code will take care of the following things:\n", + "* Align (outer join) the gene spaces across all datasets listed in `adata_paths`\n", + " * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.\n", + " * If you want to subset to a specific gene space, you can provide a list of gene names via the `var_subset` parameter.\n", + "* Shuffle the cells across all datasets (this works on larger-than-memory datasets as well).\n", + " * This is important for block-wise shuffling during data loading.\n", + "* Shuffle the input files across multiple output datasets:\n", + " * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.\n", + " * We recommend choosing a dataset size that comfortably fits into system memory.\n", + "\n", + "\n", + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add_adatas`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ilangold/Projects/Theis/annbatch/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.19it/s]\n", + "loading: 2it [00:00, 2.19it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import anndata as ad\n", + "from annbatch import DatasetCollection\n", + "\n", + "# let's write out only shared colunms - otherwise DatasetCollection will warn about all the columns we are missing for good reason - mismatched columns can lead to unexpected data and missing values.\n", + "shared_columns = ad.experimental.read_lazy(\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\").obs.columns.intersection(\n", + " ad.experimental.read_lazy(\"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\").obs.columns\n", + ")\n", + "\n", + "\n", + "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", + "# To have a store that only contains raw counts, we can write the following load_adata function\n", + "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", + " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", + " # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n", + " adata_ = ad.experimental.read_lazy(path)\n", + " if adata_.raw is not None:\n", + " x = adata_.raw.X\n", + " var = adata_.raw.var\n", + " else:\n", + " x = adata_.X\n", + " var = adata_.var\n", + "\n", + " return ad.AnnData(\n", + " X=x,\n", + " obs=adata_.obs.to_memory()[shared_columns],\n", + " var=var.to_memory(),\n", + " )\n", + "\n", + "\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\", mode=\"w\"))\n", + "collection.add_adatas(\n", + " # List all the h5ad files you want to include in the collection\n", + " 
adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", + " # Path to store the output collection\n", + " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard, this number is much higher than available in these datasets but is generally a good target\n", + " var_subset=None, # Optionally subset the collection to a specific gene space\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data loading example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we create our data loader with the desired arguments.\n", + "\n", + "**WARNING**: Without `load_adata` argument in `use_collection`, the *entire* `obs` will be loaded and yielded, degrading performance. It is highly advised to use this argument." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anndata as ad\n", + "\n", + "from annbatch import Loader\n", + "\n", + "\n", + "def _load_adata(g: zarr.Group) -> ad.AnnData:\n", + " return ad.AnnData(X=ad.io.sparse_dataset(g[\"X\"]), obs=ad.experimental.read_lazy(g).obs[[\"cell_type\"]].to_memory())\n", + "\n", + "\n", + "ds = Loader(\n", + " batch_size=4096, # Total number of obs per yielded batch\n", + " chunk_size=256, # Number of obs to load from disk contiguously - default settings should work well\n", + " preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n", + " # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n", + " preload_to_gpu=False,\n", + " 
to_torch=True,\n", + ")\n", + "\n", + "# Add in the shuffled data that should be used for training.\n", + "ds.use_collection(collection, load_adata=_load_adata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**IMPORTANT:**\n", + "* The `Loader` yields batches of sparse tensors.\n", + "* The conversion to dense tensors should be done on the GPU, as shown in the example below.\n", + " * First call `.cuda()` and then `.to_dense()`\n", + " * E.g. `x = x.cuda().to_dense()`\n", + " * This is significantly faster than doing the dense conversion on the CPU.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 42/171792 [00:07<8:54:32, 5.36it/s] \n" + ] + } + ], + "source": [ + "# Iterate over dataloader\n", + "import tqdm\n", + "\n", + "for batch in tqdm.tqdm(ds):\n", + " x, obs = batch[\"X\"], batch[\"obs\"][\"cell_type\"]\n", + " # Important: Convert to dense on GPU\n", + " x = x.cuda().to_dense()\n", + " # Feed data into your model\n", + " ..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optional: Extend an existing collection with a new dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You might want to extend an existing pre-shuffled collection with a new dataset.\n", + "This can be done using the `add_adatas` method again.\n", + "\n", + "This function will take care of shuffling the new dataset into the existing collection without having to re-shuffle the entire collection."
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00, 1.65it/s]\n", + "loading: 1it [00:00, 1.77it/s]\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 13.66it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.add_adatas(\n", + " adata_paths=[\n", + " \"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", + " ],\n", + " load_adata=read_lazy_x_and_obs_only,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 7ce99bfc..55eeff26 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -1,12 +1,13 @@ from __future__ import annotations -import json +import math import random +import re import warnings from collections import defaultdict from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import anndata as ad import dask.array as da @@ -19,13 +20,18 @@ from tqdm.auto import tqdm from zarr.codecs import BloscCodec, BloscShuffle +from annbatch.utils import split_given_size + if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping + from collections.abc import Callable, Generator, Iterable, Mapping from 
os import PathLike from typing import Any, Literal + import h5py from zarr.abc.codec import BytesBytesCodec +V1_ENCODING = {"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"} + def _round_down(num: int, divisor: int): return num - (num % divisor) @@ -40,6 +46,7 @@ def write_sharded( dense_chunk_size: int = 1024, dense_shard_size: int = 4194304, compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + key: str | None = None, ): """Write a sharded zarr store from a single AnnData object. @@ -59,6 +66,8 @@ def write_sharded( Number of obs elements per dense shard along the first axis compressors The compressors to pass to `zarr`. + key + The key to which this object should be written - by default the root, in which case the *entire* store (not just the group) is cleared first. """ ad.settings.zarr_write_format = 3 @@ -99,11 +108,17 @@ def callback( } write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs) - ad.experimental.write_dispatched(group, "/", adata, callback=callback) + ad.experimental.write_dispatched(group, "/" if key is None else key, adata, callback=callback) zarr.consolidate_metadata(group.store) -def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData] | Iterable[str | ad.AnnData]): +def _check_for_mismatched_keys( + paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData | zarr.Group | h5py.Group] | Iterable[str | ad.AnnData], + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), +): num_raw_in_adata = 0 found_keys: dict[str, defaultdict[str, int]] = { "layers": defaultdict(lambda: 0), @@ -112,13 +127,14 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An } for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"): if not isinstance(path_or_anndata, ad.AnnData): - adata = 
ad.experimental.read_lazy(path_or_anndata) + adata = load_adata(path_or_anndata) else: adata = path_or_anndata for elem_name, key_count in found_keys.items(): curr_keys = set(getattr(adata, elem_name).keys()) for key in curr_keys: - key_count[key] += 1 + if not (elem_name in {"var", "obs"} and key == "_index"): + key_count[key] += 1 if adata.raw is not None: num_raw_in_adata += 1 if num_raw_in_adata != len(paths_or_anndatas) and num_raw_in_adata != 0: @@ -139,10 +155,12 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An def _lazy_load_anndatas( paths: Iterable[PathLike[str]] | Iterable[str], - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), ): adatas = [] - categoricals_in_all_adatas = {} + categoricals_in_all_adatas: dict[str, pd.Index] = {} for i, path in tqdm(enumerate(paths), desc="loading"): adata = load_adata(path) # Track the source file for this given anndata object @@ -152,20 +170,23 @@ def _lazy_load_anndatas( # Concatenating Dataset2D drops categoricals so we need to track them if isinstance(adata.obs, Dataset2D): categorical_cols_in_this_adata = { - col: set(adata.obs[col].dtype.categories) - for col in adata.obs.columns - if adata.obs[col].dtype == "category" + col: adata.obs[col].dtype.categories for col in adata.obs.columns if adata.obs[col].dtype == "category" } if not categoricals_in_all_adatas: categoricals_in_all_adatas = { **categorical_cols_in_this_adata, - "src_path": set(adata.obs["src_path"].dtype.categories), + "src_path": adata.obs["src_path"].dtype.categories, } else: for k in categoricals_in_all_adatas.keys() & categorical_cols_in_this_adata.keys(): - categoricals_in_all_adatas[k] = set(categoricals_in_all_adatas[k]).union( - set(categorical_cols_in_this_adata[k]) + categoricals_in_all_adatas[k] = categoricals_in_all_adatas[k].union( + 
categorical_cols_in_this_adata[k] ) + # TODO: Probably bug in anndata, need the true index for proper outer joins (can't skirt this with fake indexes, at least not in the mixed-type regime). + if isinstance(adata.var, Dataset2D): + adata.var.index = adata.var.true_index + if adata.raw is not None and isinstance(adata.raw.var, Dataset2D): + adata.raw.var.index = adata.raw.var.true_index adatas.append(adata) if len(adatas) == 1: return adatas[0] @@ -175,17 +196,38 @@ def _lazy_load_anndatas( return adata -def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True): - chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0])) - slices = [ - slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True) - ] +def _create_chunks_for_shuffling( + n_obs: int, + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + *, + shuffle_n_obs_per_dataset: int | None = None, + n_chunkings: int | None = None, +) -> list[np.ndarray]: + # this splits the array up into `shuffle_chunk_size` contiguous runs + idxs = split_given_size(np.arange(n_obs), shuffle_chunk_size) if shuffle: - random.shuffle(slices) - idxs = np.concatenate([np.arange(s.start, s.stop) for s in slices]) - idxs = np.array_split(idxs, np.ceil(len(idxs) / shuffle_n_obs_per_dataset)) - - return idxs + random.shuffle(idxs) + match shuffle_n_obs_per_dataset is not None, n_chunkings is not None: + case True, False: + n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_chunk_size) + use_single_chunking = n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1 + case False, True: + n_slices_per_dataset = (n_obs // n_chunkings) // shuffle_chunk_size + use_single_chunking = n_chunkings == 1 + case _, _: + raise ValueError("Cannot provide both shuffle_n_obs_per_dataset and n_chunkings or neither") + # In this case `shuffle_n_obs_per_dataset` is bigger than the size of the dataset or the slice size is 
probably too big. + if use_single_chunking: + return [np.concatenate(idxs)] + # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. + idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) + return [ + np.concatenate([np.arange(s.start, s.stop) for s in idx]) + for idx in ( + split_given_size(idxs, n_slices_per_dataset) if n_chunkings is None else np.array_split(idxs, n_chunkings) + ) + ] def _compute_blockwise(x: DaskArray) -> sp.spmatrix: @@ -209,9 +251,14 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata.X = _compute_blockwise(adata.X) if isinstance(adata.obs, Dataset2D): adata.obs = adata.obs.to_memory() + # TODO: This is a bug in anndata? + if "_index" in adata.obs.columns: + del adata.obs["_index"] adata = _to_categorical_obs(adata) if isinstance(adata.var, Dataset2D): adata.var = adata.var.to_memory() + if "_index" in adata.var.columns: + del adata.var["_index"] if adata.raw is not None: adata_raw = adata.raw.to_adata() @@ -219,19 +266,28 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata_raw.X = _compute_blockwise(adata_raw.X) if isinstance(adata_raw.var, Dataset2D): adata_raw.var = adata_raw.var.to_memory() + if "_index" in adata_raw.var.columns: + del adata_raw.var["_index"] if isinstance(adata_raw.obs, Dataset2D): adata_raw.obs = adata_raw.obs.to_memory() del adata.raw adata.raw = adata_raw - for k, elem in adata.obsm.items(): - # TODO: handle `Dataset2D` in `obsm` and `varm` that are - if isinstance(elem, DaskArray): - adata.obsm[k] = _compute_blockwise(elem) - - for k, elem in adata.layers.items(): - if isinstance(elem, DaskArray): - adata.obsm[k] = _compute_blockwise(elem) + for axis_name in ["layers", "obsm", "varm", "obsp", "varp"]: + for k, elem in getattr(adata, axis_name).items(): + # TODO: handle `Dataset2D` in `obsm` and `varm` that are + if isinstance(elem, DaskArray): + getattr(adata, axis_name)[k] = 
_compute_blockwise(elem) + if isinstance(elem, Dataset2D): + elem = elem.to_memory() + if "_index" in elem.columns: + del elem["_index"] + # TODO: Bug in anndata + if "obs" in axis_name: + elem.index = adata.obs_names + getattr(adata, axis_name)[k] = elem + + return adata.to_memory() return adata @@ -248,259 +304,380 @@ def wrapper(*args, **kwargs): return wrapper -@_with_settings -def create_anndata_collection( - adata_paths: Iterable[PathLike[str]] | Iterable[str], - output_path: PathLike[str] | str, - *, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy, - var_subset: Iterable[str] | None = None, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - n_obs_per_dataset: int = 2_097_152, - shuffle: bool = True, - should_denseify: bool = False, - output_format: Literal["h5ad", "zarr"] = "zarr", -): - """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. +class DatasetCollection: + """A preshuffled collection object including functionality for creating, adding to, and loading collections shuffled by `annbatch`.""" - The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. - The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. - However, this function can also output h5 datasets and also unshuffled datasets as well. - The var space is by default outer-joined, but can be subsetted by `var_subset`. - A key `src_path` is added to `obs` to indicate where individual row came from. 
- We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. - Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. + _group: zarr.Group | Path - Parameters - ---------- - adata_paths - Paths to the AnnData files used to create the zarr store. - output_path - Path to the output zarr store. - load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. - var_subset - Subset of gene names to include in the store. If None, all genes are included. - Genes are subset based on the `var_names` attribute of the concatenated AnnData object. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - h5ad_compressor - Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. - n_obs_per_dataset - Number of observations to load into memory at once for shuffling / pre-processing. - The higher this number, the more memory is used, but the better the shuffling. - This corresponds to the size of the shards created. 
- shuffle - Whether to shuffle the data before writing it to the store. - should_denseify - Whether to write as dense on disk. There's no need to set this for sparse data, it is only for testing. - output_format - Format of the output store. Can be either "zarr" or "h5ad". - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import create_anndata_collection - # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store - >>> def read_lazy_x_and_obs_only(path): - ... adata = ad.experimental.read_lazy(path) - ... return ad.AnnData( - ... X=adata.X, - ... obs=adata.obs.to_memory(), - ... var=adata.var.to_memory(), - ...) - - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> create_anndata_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... load_adata=read_lazy_x_and_obs_only, - ...) - """ - Path(output_path).mkdir(parents=True, exist_ok=True) - _check_for_mismatched_keys(adata_paths) - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - adata_concat.obs_names_make_unique() - chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle) - - if var_subset is None: - var_subset = adata_concat.var_names - - for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): - var_mask = adata_concat.var_names.isin(var_subset) - # np.sort: It's more efficient to access elements sequentially from dask arrays - # The data will be shuffled later on, we just want the elements at this point - adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() - adata_chunk = _persist_adata_in_memory(adata_chunk) - if shuffle: - # shuffle adata in memory to break up individual chunks - idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) - adata_chunk = adata_chunk[idxs] - # convert to dense format before writing to disk - if should_denseify: - # Need to convert back to 
dask array to avoid memory issues when converting large sparse matrices to dense - adata_chunk = adata_chunk.copy() - adata_chunk.X = da.from_array( - adata_chunk.X, chunks=(zarr_dense_chunk_size, -1), meta=adata_chunk.X - ).map_blocks(lambda xx: xx.toarray(), dtype=adata_chunk.X.dtype) - - if output_format == "zarr": - f = zarr.open_group(Path(output_path) / f"{DATASET_PREFIX}_{i}.zarr", mode="w") - write_sharded( - f, - adata_chunk, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, - compressors=zarr_compressor, + def __init__( + self, group: zarr.Group | str | Path, *, mode: Literal["a", "r", "r+"] = "a", is_collection_h5ad: bool = False + ): + """Initialization of the object at a given location. + + Note that if the group is a h5py/zarr object, it must have the correct permissions for any subsequent operations you plan to do. + Otherwise, the store will be opened according to the mode argument. + + + Parameters + ---------- + group + The base location for a preshuffled collection. + A :class:`zarr.Group` or path ending in `.zarr` indicates zarr as the shuffled format and otherwise a directory of `h5ad` files will be created. + """ + if not isinstance(group, zarr.Group): + if isinstance(group, str | Path): + if not is_collection_h5ad: + if not str(group).endswith(".zarr"): + warnings.warn( + f"It is highly recommended to make your collections have the `.zarr` suffix, got: {group}.", + stacklevel=2, + ) + self._group = zarr.open_group(group, mode=mode) + else: + warnings.warn( + "Loading h5ad is currently not supported and thus we cannot guarantee the funcionality of the ecosystem with h5ad files." + "DatasetCollection should be able to handle shuffling but we guarantee little else." 
+ "Proceed with caution.", + stacklevel=2, + ) + self._group = Path(group) + self._group.mkdir(exist_ok=True) + else: + raise TypeError("Group must either be a zarr group or a path") + else: + if is_collection_h5ad: + raise ValueError("Do not set `is_collection_h5ad` to True when also passing in a zarr Group.") + self._group = group + + @property + def _dataset_keys(self) -> list[str]: + if isinstance(self._group, zarr.Group): + return sorted( + [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None], + key=lambda x: int(x.split("_")[1]), ) - elif output_format == "h5ad": - adata_chunk.write_h5ad(Path(output_path) / f"{DATASET_PREFIX}_{i}.h5ad", compression=h5ad_compressor) else: - raise ValueError(f"Unrecognized output_format: {output_format}. Only 'zarr' and 'h5ad' are supported.") - - -def _get_array_encoding_type(path: PathLike[str] | str) -> str: - shards = list(Path(path).glob(f"{DATASET_PREFIX}_*.zarr")) - with open(shards[0] / "X" / "zarr.json") as f: - encoding = json.load(f) - return encoding["attributes"]["encoding-type"] - + raise ValueError("Cannot iterate through folder of h5ad files") -@_with_settings -def add_to_collection( - adata_paths: Iterable[PathLike[str]] | Iterable[str], - output_path: PathLike[str] | str, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - should_sparsify_output_in_memory: bool = False, -) -> None: - """Add anndata files to an existing collection of sharded anndata zarr datasets. 
+ def __iter__(self) -> Generator[zarr.Group]: + if isinstance(self._group, zarr.Group): + for k in self._dataset_keys: + yield self._group[k] + else: + raise ValueError("Cannot iterate through folder of h5ad files") + + @property + def is_empty(self) -> bool: + """Whether or not the collection is empty, i.e., there is no existing store at the group location.""" + return ( + (not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0) + if isinstance(self._group, zarr.Group) + else (len(list(self._group.iterdir())) == 0) + ) - The var space of the source anndata files will be adapted to the target store. + @_with_settings + def add_adatas( + self, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), + var_subset: Iterable[str] | None = None, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + ) -> Self: + """Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per dataset if running for the first time). + + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. + The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets and also unshuffled datasets as well.
+ The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. + A key `src_path` is added to `obs` to indicate where individual row came from. + We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_chunk_size` as so many rows will be read into memory before writing to disk. + After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. + This is not a stable API but only for internal purposes at the moment. + + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + load_adata + Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an :class:`anndata.AnnData` object. + var_subset + Subset of gene names to include in the store. If None, all genes are included. + Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. 
+ zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + Only applicable when adding datasets for the first time, otherwise ignored. + shuffle + Whether to shuffle the data before writing it to the store. + Ignored once the store is non-empty. + shuffle_chunk_size + How many contiguous rows to load into memory before shuffling at once. + `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import DatasetCollection + # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store + >>> def read_lazy_x_and_obs_only(path): + ... adata = ad.experimental.read_lazy(path) + ... return ad.AnnData( + ... X=adata.X, + ... obs=adata.obs.to_memory(), + ... var=adata.var.to_memory(), + ...) + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... ] + >>> DatasetCollection("path/to/output/zarr_store.zarr").add_adatas( + ... datasets, + ... load_adata=read_lazy_x_and_obs_only, + ...) 
+ """ + if shuffle_chunk_size > n_obs_per_dataset: + raise ValueError("Cannot have a large slice size than observations per dataset") + shared_kwargs = { + "adata_paths": adata_paths, + "load_adata": load_adata, + "zarr_sparse_chunk_size": zarr_sparse_chunk_size, + "zarr_sparse_shard_size": zarr_sparse_shard_size, + "zarr_dense_chunk_size": zarr_dense_chunk_size, + "zarr_dense_shard_size": zarr_dense_shard_size, + "zarr_compressor": zarr_compressor, + "h5ad_compressor": h5ad_compressor, + "shuffle_chunk_size": shuffle_chunk_size, + "shuffle": shuffle, + } + if self.is_empty: + self._create_collection(**shared_kwargs, n_obs_per_dataset=n_obs_per_dataset, var_subset=var_subset) + else: + self._add_to_collection(**shared_kwargs) + return self - Parameters - ---------- - adata_paths - Paths to the anndata files to be appended to the collection of output chunks. - output_path - Path to the output zarr store. - load_adata - Function to customize loading the invidiual input anndata files. By default, {func}`anndata.read_h5ad` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object. - If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. 
- zarr_compressor - Compressors to use to compress the data in the zarr store. - should_sparsify_output_in_memory - This option is for testing only appending sparse files to dense stores. - To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import add_to_collection - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> add_to_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory - ...) - """ - shards = list(Path(output_path).glob(f"{DATASET_PREFIX}_*.zarr")) - if len(shards) == 0: - raise ValueError( - "Store at `output_path` does not exist or is empty. Please run `create_anndata_collection` first." + def _create_collection( + self, + *, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), + var_subset: Iterable[str] | None = None, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + ) -> None: + """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per dataset. + + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. 
+ The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets and also unshuffled datasets as well. + The var space is by default outer-joined, but can be subsetted by `var_subset`. + A key `src_path` is added to `obs` to indicate where individual row came from. + We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. + + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + load_adata + Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a :class:`dask.array.Array`. + var_subset + Subset of gene names to include in the store. If None, all genes are included. + Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + Only applicable when adding datasets for the first time, otherwise ignored and the incoming data's var space is subsetted to that of the existing collection. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. 
+ zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + Only applicable when adding datasets for the first time, otherwise ignored. + shuffle + Whether to shuffle the data before writing it to the store. + shuffle_chunk_size + How many contiguous rows to load into memory before shuffling at once. + `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. + """ + if not self.is_empty: + raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") + _check_for_mismatched_keys(adata_paths, load_adata=load_adata) + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + adata_concat.obs_names_make_unique() + n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset) + chunks = _create_chunks_for_shuffling( + adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, shuffle_n_obs_per_dataset=n_obs_per_dataset ) - encoding = _get_array_encoding_type(output_path) - if encoding == "array": - print("Detected array encoding type. Will convert to dense format before writing.") - # Check for mismatched keys among the inputs. - _check_for_mismatched_keys(adata_paths) - - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - # Check for mismatched keys between shards and the inputs. 
- _check_for_mismatched_keys([adata_concat] + shards) - if isinstance(adata_concat.X, DaskArray): - chunks = _create_chunks_for_shuffling(adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True) - else: - chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards)) - - adata_concat.obs_names_make_unique() - if encoding == "array": - if not should_sparsify_output_in_memory: - if isinstance(adata_concat.X, sp.spmatrix): - adata_concat.X = adata_concat.X.toarray() - elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, sp.spmatrix): - adata_concat.X = adata_concat.X.map_blocks( - lambda x: x.toarray(), meta=np.ndarray, dtype=adata_concat.X.dtype + + if var_subset is None: + var_subset = adata_concat.var_names + for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): + var_mask = adata_concat.var_names.isin(var_subset) + # np.sort: It's more efficient to access elements sequentially from dask arrays + # The data will be shuffled later on, we just want the elements at this point + adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() + adata_chunk = _persist_adata_in_memory(adata_chunk) + if shuffle: + # shuffle adata in memory to break up individual chunks + idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) + adata_chunk = adata_chunk[idxs] + if isinstance(self._group, zarr.Group): + write_sharded( + self._group, + adata_chunk, + sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=min(adata_chunk.shape[0], zarr_dense_chunk_size), + dense_shard_size=min(adata_chunk.shape[0], zarr_dense_shard_size), + compressors=zarr_compressor, + key=f"{DATASET_PREFIX}_{i}", ) - elif encoding == "csr_matrix": - if isinstance(adata_concat.X, np.ndarray): - adata_concat.X = sp.csr_matrix(adata_concat.X) - elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, np.ndarray): - adata_concat.X = 
adata_concat.X.map_blocks( - sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype)) - ) + else: + ad.io.write_h5ad( + self._group / f"{DATASET_PREFIX}_{i}.h5ad", + adata_chunk, + dataset_kwargs={"compression": h5ad_compressor}, + ) + if isinstance(self._group, zarr.Group): + self._group.update_attributes(V1_ENCODING) - for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards), desc="processing chunks"): - if should_sparsify_output_in_memory and encoding == "array": - adata_shard = _lazy_load_anndatas([shard]) - adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute() - else: - adata_shard = ad.read_zarr(shard) - subset_adata = _to_categorical_obs( - adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)] - ) - adata = ad.concat([adata_shard, subset_adata], join="outer") - idxs_shuffled = np.random.default_rng().permutation(len(adata)) - adata = adata[idxs_shuffled, :].copy() # this significantly speeds up writing to disk - if should_sparsify_output_in_memory and encoding == "array": - adata.X = adata.X.map_blocks(lambda x: x.toarray(), meta=np.array([0], dtype=adata.X.dtype)).compute() - - f = zarr.open_group(shard, mode="w") - write_sharded( - f, - adata, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, - compressors=zarr_compressor, + def _add_to_collection( + self, + *, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + shuffle_chunk_size: int = 1000, + 
shuffle: bool = True, + ) -> None: + """Add anndata files to an existing collection of sharded anndata zarr datasets. + + The var space of the source anndata files will be adapted to the target store. + + Parameters + ---------- + adata_paths + Paths to the anndata files to be appended to the collection of output chunks. + load_adata + Function to customize loading the invidiual input anndata files. By default, :func:`anndata.read_h5ad` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object. + If the input data is too large to fit into memory, you should use :func:`annndata.experimental.read_lazy` instead. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + should_sparsify_output_in_memory + This option is for testing only appending sparse files to dense stores. + To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. + shuffle_chunk_size + How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. + shuffle + Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. + """ + if self.is_empty: + raise ValueError("Store is empty. 
Please run `DatasetCollection.add` first.") + # Check for mismatched keys among the inputs. + _check_for_mismatched_keys(adata_paths, load_adata=load_adata) + + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys): + raise ValueError( + f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores." + "Open an issue if the incoming anndata is so small it cannot be distributed across the on-disk data" + ) + # Check for mismatched keys between datasets and the inputs. + _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) + chunks = _create_chunks_for_shuffling( + adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, n_chunkings=len(self._dataset_keys) ) + + adata_concat.obs_names_make_unique() + for dataset, chunk in tqdm( + zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="processing chunks" + ): + adata_dataset = ad.io.read_elem(self._group[dataset]) + subset_adata = _to_categorical_obs( + adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)] + ) + adata = ad.concat([adata_dataset, subset_adata], join="outer") + if shuffle: + idxs = np.random.default_rng().permutation(adata.shape[0]) + else: + idxs = np.arange(adata.shape[0]) + adata = _persist_adata_in_memory(adata[idxs, :].copy()) + if isinstance(self._group, zarr.Group): + write_sharded( + self._group, + adata, + sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=min(adata.shape[0], zarr_dense_chunk_size), + dense_shard_size=min(adata.shape[0], zarr_dense_shard_size), + compressors=zarr_compressor, + key=dataset, + ) + else: + ad.io.write_h5ad( + self._group / f"{dataset}.h5ad", + adata, + dataset_kwargs={"compression": h5ad_compressor}, + ) From 
0c13efa0655c2f8414a5daf893ace8b4dd5d7ae8 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 19 Jan 2026 12:12:31 +0100 Subject: [PATCH 16/56] breaking: clarify obs handling + change output keys (#115) * chore: clarify obs handling * chore: clearer docs * fix: rename --- CHANGELOG.md | 2 ++ src/annbatch/loader.py | 4 ++-- tests/test_dataset.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 247d2b03..9852d6cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning][]. ### Breaking +### Breaking + - Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. `is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. - Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index a4129cff..db769dfb 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -290,7 +290,7 @@ def add_anndata(self, adata: ad.AnnData) -> Self: Parameters ---------- adata - A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing labels to yield in a :class:`pandas.DataFrame`. + A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. 
""" dataset = adata.X obs = adata.obs @@ -311,7 +311,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N List of :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` objects, generally from :attr:`anndata.AnnData.X`. They must all be of the same type and match that of any already added datasets. obs - List of :class:`~pandas.DataFrame` labels, generally from :attr:`anndata.AnnData.obs`. + List of :class:`~pandas.DataFrame` obs, generally from :attr:`anndata.AnnData.obs`. """ if obs is None: obs = [None] * len(datasets) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c4d250c7..127cd974 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -265,7 +265,7 @@ def test_to_torch( to_torch=True, ) ds.add_dataset(**open_func(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")))) - assert isinstance(next(iter(ds))["data"], torch.Tensor) + assert isinstance(next(iter(ds))["X"], torch.Tensor) @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) From ffe23be82eb57538a131824abc90f451592649c8 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 19 Jan 2026 12:24:45 +0100 Subject: [PATCH 17/56] fix: header level (#116) --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9852d6cc..247d2b03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,6 @@ and this project adheres to [Semantic Versioning][]. ### Breaking -### Breaking - - Revert `h5ad` shuffling into one big store (i.e., go back to sharding into individual files) and add warning that `h5ad` is not fully supported by `annbatch`. `is_collection_h5ad` argument to initialization of {class}`annbatch.DatasetCollection` must be passed when initializing into to use a preshuffled collection of `h5ad` files, reading or writing. - Renamed {class}`annbatch.types.LoaderOutput` `["labels"]` and `["data"]` to `["obs"]` and `["X"]` respectively. 
From 5d522fe891501dcfb729c75986fa0393a5a08689 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 20:55:31 +0100 Subject: [PATCH 18/56] refactor _prepare_dataset_and_obs --- src/annbatch/loader.py | 46 +++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index db769dfb..122549b0 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -277,11 +277,8 @@ def add_anndatas( """ check_lt_1([len(adatas)], ["Number of anndatas"]) for adata in adatas: - dataset = adata.X - obs = adata.obs - if not isinstance(dataset, BackingArray_T.__value__): - raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") - self._add_dataset_unchecked(cast("BackingArray", dataset), obs) + dataset, obs = self._prepare_dataset_and_obs(adata) + self._add_dataset_unchecked(dataset, obs) return self def add_anndata(self, adata: ad.AnnData) -> Self: @@ -292,14 +289,18 @@ def add_anndata(self, adata: ad.AnnData) -> Self: adata A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. 
""" + dataset, obs = self._prepare_dataset_and_obs(adata) + self.add_dataset(dataset, obs) + return self + + def _prepare_dataset_and_obs(self, adata: ad.AnnData) -> tuple[BackingArray, pd.DataFrame | None]: dataset = adata.X obs = adata.obs if len(obs.columns) == 0: obs = None if not isinstance(dataset, BackingArray_T.__value__): raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") - self.add_dataset(cast("BackingArray", dataset), obs) - return self + return cast("BackingArray", dataset), obs @validate_sampler(lambda self, datasets, obs=None: sum(ds.shape[0] for ds in datasets)) def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | None = None) -> Self: @@ -615,16 +616,10 @@ def __iter__( chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) chunks_converted = self._accumulate_chunks(chunks) # Accumulate labels and indices if possible - obs: None | list[pd.DataFrame] = self._maybe_accumulate_labels(dataset_index_to_slices) - indices: None | list[np.ndarray] = self._maybe_accumulate_indices(chunks_to_load) + concatenated_obs: None | list[pd.DataFrame] = self._maybe_accumulate_labels(dataset_index_to_slices) + in_memory_indices: None | list[np.ndarray] = self._maybe_accumulate_indices(chunks_to_load) in_memory_data = mod.vstack(chunks_converted) - concatenated_obs = None - in_memory_indices = None - if self._obs is not None and obs is not None: - concatenated_obs = pd.concat(obs) - if self._return_index and indices is not None: - in_memory_indices = np.concatenate(indices) for split in splits: yield self._prepare_output( @@ -656,20 +651,24 @@ def _maybe_accumulate_labels( """Gather obs labels for the loaded slices if possible.""" if self._obs is None: return None - return [ - self._obs[idx].iloc[np.concatenate([np.arange(s.start, s.stop) for s in slices])] - for idx, slices in dataset_index_to_slices.items() - ] + return pd.concat( + [ + 
self._obs[idx].iloc[np.concatenate([np.arange(s.start, s.stop) for s in slices])] + for idx, slices in dataset_index_to_slices.items() + ] + ) def _maybe_accumulate_indices(self, slices: list[slice]) -> list[np.ndarray] | None: """Gather original indices for the loaded slices if possible.""" if self._return_index is False: return None dataset_index_to_slices = self._slices_to_slices_with_array_index(slices, use_original_space=True) - return [ - np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[idx]]) - for idx in dataset_index_to_slices - ] + return np.concatenate( + [ + np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[idx]]) + for idx in dataset_index_to_slices + ] + ) def _prepare_output( self, @@ -689,4 +688,5 @@ def _prepare_output( data = in_memory_data[split] if self._to_torch: data = to_torch(data, self._preload_to_gpu) + print(obs) return {"X": data, "obs": obs, "index": index} From d8168c1fc635b2203b0fc49b5f14eda4e6de886c Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 21:00:57 +0100 Subject: [PATCH 19/56] update docstring for loadrequest --- src/annbatch/types.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/annbatch/types.py b/src/annbatch/types.py index 508ff8a1..668a96ae 100644 --- a/src/annbatch/types.py +++ b/src/annbatch/types.py @@ -19,14 +19,16 @@ class LoadRequest(TypedDict): """Load request from sampler. + This is the request format Loader will expect from the sampler. + Not satisfying the constrains documented here may result in unexpected behavior. + Attributes ---------- chunks - Chunks to load - a list of at most chunk_size ranged slices. + Chunks to load - a list of slices with a range of chunk_size except the last one which may be smaller but not empty. splits How the concatenation of chunks should be split into batches. - A list of splits, last one may be partial (< batch_size). 
- The loader carries over partial batches to the next iteration. + A list of splits, last one may be partial but not empty i.e. 1 <= len(last_split) <= batch_size. """ chunks: list[slice] From c501646cae8a0c9dbb78e07f5fe640bc17e8100b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 21:09:19 +0100 Subject: [PATCH 20/56] separate files for samplers --- src/annbatch/sampler/__init__.py | 3 +- .../{_sampler.py => _chunk_sampler.py} | 61 +--------------- src/annbatch/sampler/abc/__init__.py | 5 ++ src/annbatch/sampler/abc/_sampler.py | 70 +++++++++++++++++++ 4 files changed, 78 insertions(+), 61 deletions(-) rename src/annbatch/sampler/{_sampler.py => _chunk_sampler.py} (82%) create mode 100644 src/annbatch/sampler/abc/__init__.py create mode 100644 src/annbatch/sampler/abc/_sampler.py diff --git a/src/annbatch/sampler/__init__.py b/src/annbatch/sampler/__init__.py index 9a07f9aa..40a8f79b 100644 --- a/src/annbatch/sampler/__init__.py +++ b/src/annbatch/sampler/__init__.py @@ -3,7 +3,8 @@ This module provides samplers optimized for chunk-based data access patterns. 
""" -from annbatch.sampler._sampler import ChunkSampler, Sampler +from annbatch.sampler._chunk_sampler import ChunkSampler +from annbatch.sampler.abc import Sampler __all__ = [ "ChunkSampler", diff --git a/src/annbatch/sampler/_sampler.py b/src/annbatch/sampler/_chunk_sampler.py similarity index 82% rename from src/annbatch/sampler/_sampler.py rename to src/annbatch/sampler/_chunk_sampler.py index d418275c..49181269 100644 --- a/src/annbatch/sampler/_sampler.py +++ b/src/annbatch/sampler/_chunk_sampler.py @@ -3,12 +3,12 @@ from __future__ import annotations import math -from abc import ABC, abstractmethod from importlib.util import find_spec from typing import TYPE_CHECKING import numpy as np +from annbatch.sampler.abc import Sampler from annbatch.utils import check_lt_1, split_given_size if TYPE_CHECKING: @@ -18,65 +18,6 @@ from annbatch.utils import WorkerHandle -class Sampler(ABC): - """Base sampler class. - - Samplers control how data is batched and loaded from the underlying datasets. - """ - - def sample(self, n_obs: int) -> Iterator[LoadRequest]: - """Sample load requests given the total number of observations. - - Parameters - ---------- - n_obs - The total number of observations available. - - Yields - ------ - LoadRequest - Load requests for batching data. - """ - self.validate(n_obs) - yield from self._sample(n_obs) - - @abstractmethod - def validate(self, n_obs: int) -> None: - """Validate the sampler configuration against the loader's state. - - This method is called when the sampler is set on a loader. - Override this method to add custom validation for sampler parameters. - - Parameters - ---------- - n_obs - The total number of observations in the loader. - - Raises - ------ - ValueError - If the sampler configuration is invalid for the given n_obs. - """ - - @abstractmethod - def _sample(self, n_obs: int) -> Iterator[LoadRequest]: - """Implementation of the sample method. 
- - This method is called by the sample method to perform the actual sampling after - validation has passed. - - Parameters - ---------- - n_obs - The total number of observations available. - - Yields - ------ - LoadRequest - Load requests for batching data. - """ - - class ChunkSampler(Sampler): """Chunk-based sampler for batched data access. diff --git a/src/annbatch/sampler/abc/__init__.py b/src/annbatch/sampler/abc/__init__.py new file mode 100644 index 00000000..776affab --- /dev/null +++ b/src/annbatch/sampler/abc/__init__.py @@ -0,0 +1,5 @@ +from annbatch.sampler.abc._sampler import Sampler + +__all__ = [ + "Sampler", +] diff --git a/src/annbatch/sampler/abc/_sampler.py b/src/annbatch/sampler/abc/_sampler.py new file mode 100644 index 00000000..25b4846d --- /dev/null +++ b/src/annbatch/sampler/abc/_sampler.py @@ -0,0 +1,70 @@ +"""Sampler classes for efficient chunk-based data access.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator + + from annbatch.types import LoadRequest + + +class Sampler(ABC): + """Base sampler class. + + Samplers control how data is batched and loaded from the underlying datasets. + """ + + def sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Sample load requests given the total number of observations. + + Parameters + ---------- + n_obs + The total number of observations available. + + Yields + ------ + LoadRequest + Load requests for batching data. + """ + self.validate(n_obs) + yield from self._sample(n_obs) + + @abstractmethod + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's state. + + This method is called when the sampler is set on a loader. + Override this method to add custom validation for sampler parameters. + + Parameters + ---------- + n_obs + The total number of observations in the loader. 
+ + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. + """ + + @abstractmethod + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + """Implementation of the sample method. + + This method is called by the sample method to perform the actual sampling after + validation has passed. + + Parameters + ---------- + n_obs + The total number of observations available. + + Yields + ------ + LoadRequest + Load requests for batching data. + """ From f9862b0790f1726a790d203d3654f90de1955845 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 21:18:57 +0100 Subject: [PATCH 21/56] prepare_output is no longer needed --- src/annbatch/loader.py | 37 +++++++------------------------------ 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 122549b0..3b918b2f 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -23,7 +23,6 @@ check_lt_1, check_var_shapes, load_x_and_obs, - to_torch, validate_sampler, ) @@ -616,18 +615,17 @@ def __iter__( chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) chunks_converted = self._accumulate_chunks(chunks) # Accumulate labels and indices if possible - concatenated_obs: None | list[pd.DataFrame] = self._maybe_accumulate_labels(dataset_index_to_slices) + concatenated_obs: None | list[pd.DataFrame] = self._maybe_accumulate_obs(dataset_index_to_slices) in_memory_indices: None | list[np.ndarray] = self._maybe_accumulate_indices(chunks_to_load) in_memory_data = mod.vstack(chunks_converted) for split in splits: - yield self._prepare_output( - in_memory_data=in_memory_data, - concatenated_obs=concatenated_obs, - in_memory_indices=in_memory_indices, - split=split, - ) + yield { + "X": in_memory_data[split], + "obs": concatenated_obs.iloc[split] if concatenated_obs is not None else None, + "index": in_memory_indices[split] if in_memory_indices is not None else None, + } def 
_accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInMemoryArray_T]: """Convert fetched chunks to output array format (CSR or ndarray).""" @@ -645,7 +643,7 @@ def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInM result.append(self._np_module.asarray(chunk)) return result - def _maybe_accumulate_labels( + def _maybe_accumulate_obs( self, dataset_index_to_slices: OrderedDict[int, list[slice]] ) -> list[pd.DataFrame] | None: """Gather obs labels for the loaded slices if possible.""" @@ -669,24 +667,3 @@ def _maybe_accumulate_indices(self, slices: list[slice]) -> list[np.ndarray] | N for idx in dataset_index_to_slices ] ) - - def _prepare_output( - self, - *, - in_memory_data: OutputInMemoryArray_T, - concatenated_obs: pd.DataFrame | None, - in_memory_indices: np.ndarray | None, - split: np.ndarray, - ) -> LoaderOutput: - """Prepare the final output dict for a single batch.""" - index = None - obs = None - if self._obs is not None and concatenated_obs is not None: - obs = concatenated_obs.iloc[split] - if self._return_index and in_memory_indices is not None: - index = in_memory_indices[split] - data = in_memory_data[split] - if self._to_torch: - data = to_torch(data, self._preload_to_gpu) - print(obs) - return {"X": data, "obs": obs, "index": index} From 6a1153ee221d7107205deaefb6114f3b735d7d8d Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 21:29:58 +0100 Subject: [PATCH 22/56] clarify docs --- src/annbatch/loader.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 3b918b2f..dd99f2f0 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -67,30 +67,40 @@ class Loader[ If `to_torch` is True, the yielded type is a :class:`torch.Tensor`. If both `preload_to_gpu` and `to_torch` are False, then the return type is the CPU class for the given data type. 
When providing a custom sampler, `chunk_size`, `preload_nchunks`, `batch_size`, - `shuffle`, and `drop_last` must not be set (they are controlled by the sampler). + `shuffle`, and `drop_last` must not be set (they are controlled by the `batch_sampler` instead). Parameters ---------- + batch_sampler + A custom sampler to use for batching the data. + If not provided, a :class:`ChunkSampler` will be used with the following defaults: + - `chunk_size`: 512 + - `preload_nchunks`: 32 + - `batch_size`: 1 + - `shuffle`: False + - `drop_last`: False + Mutually exclusive with the following arguments: `chunk_size`, `preload_nchunks`, `batch_size`, `shuffle`, and `drop_last`. chunk_size - The obs size (i.e., axis 0) of contiguous array data to fetch. + The obs size (i.e., axis 0) of contiguous array data to fetch. When `batch_sampler` is not provided, this is used to determine the chunk size. Mutually exclusive with `batch_sampler`. Defaults to 512. preload_nchunks - The number of chunks of contiguous array data to fetch. + The number of chunks of contiguous array data to fetch. When batch_sampler is not provided, this is used to determine the preload_nchunks. Mutually exclusive with `batch_sampler`. Defaults to 32. shuffle - Whether or not to shuffle the data. + Whether or not to shuffle the data. When batch_sampler is not provided, this is used to determine the shuffle. Mutually exclusive with `batch_sampler`. Defaults to False. + batch_size + Batch size to yield from the dataset. When batch_sampler is not provided, this is used to determine the batch size. Mutually exclusive with `batch_sampler`. Defaults to 1. + drop_last + Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. + If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`. 
+ When batch_sampler is not provided, this is used to determine the drop_last. Mutually exclusive with `batch_sampler`. Defaults to False. return_index Whether or not to yield the index on each iteration. - batch_size - Batch size to yield from the dataset. preload_to_gpu Whether or not to use cupy for non-io array operations like vstack and indexing once the data is in memory internally. This option entails greater GPU memory usage, but is faster at least for sparse operations. :func:`torch.vstack` does not support CSR sparse matrices, hence the current use of cupy internally. Setting this to `False` is advisable when using the :class:`torch.utils.data.DataLoader` wrapper or potentially with dense data. - For top performance, this should be used in conjuction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to denseify. - drop_last - Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. - If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. - Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`. + For top performance, this should be used in conjunction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to densify. to_torch Whether to return `torch.Tensor` as the output. Data transferred should be 0-copy independent of source, and transfer to cuda when applicable is non-blocking.
From 418f79a2d720ebb6b41f195ee079bdb606ebc9e6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 23:12:25 +0100 Subject: [PATCH 23/56] fix overlook: already sorted batch_indices no need to resort them --- src/annbatch/sampler/_chunk_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/annbatch/sampler/_chunk_sampler.py b/src/annbatch/sampler/_chunk_sampler.py index 49181269..6ad6321f 100644 --- a/src/annbatch/sampler/_chunk_sampler.py +++ b/src/annbatch/sampler/_chunk_sampler.py @@ -165,7 +165,6 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: (self._rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch), self._batch_size, ) - batch_indices.sort(key=len, reverse=True) yield {"chunks": final_chunks, "splits": batch_indices} def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> list[slice]: From 9b786f399198e6133c46410f526684581b7af6ac Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 23:26:21 +0100 Subject: [PATCH 24/56] fix prepare_output refactor --- src/annbatch/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index dd99f2f0..c9e28e94 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -23,6 +23,7 @@ check_lt_1, check_var_shapes, load_x_and_obs, + to_torch, validate_sampler, ) @@ -631,8 +632,9 @@ def __iter__( in_memory_data = mod.vstack(chunks_converted) for split in splits: + data = in_memory_data[split] yield { - "X": in_memory_data[split], + "X": data if not self._to_torch else to_torch(data, self._preload_to_gpu), "obs": concatenated_obs.iloc[split] if concatenated_obs is not None else None, "index": in_memory_indices[split] if in_memory_indices is not None else None, } From fc1661ebb9d2604629e3547641a984b9b3568106 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 23:30:12 +0100 Subject: [PATCH 25/56] add todo --- src/annbatch/loader.py | 3 +++ 1 
file changed, 3 insertions(+) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index c9e28e94..6a2f4f28 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -128,6 +128,9 @@ class Loader[ "shuffle": False, "drop_last": False, } + # TODO: these should be also presented in the documentation + # but this is not ideal since they are hardcoded into the docstrings + # maybe we should make this a public class field? _train_datasets: list[BackingArray] _obs: list[pd.DataFrame] | None = None From d764adc43fd34511e4c3442b9d08e5eb9ad7fc0c Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 19 Jan 2026 23:34:30 +0100 Subject: [PATCH 26/56] rename from leftover to remainder for clarity. since there is no leftover --- tests/test_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 127cd974..4bcac5fb 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -293,13 +293,13 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], batches += [batch["X"]] indices += [batch["index"]] total_obs = adata.shape[0] - leftover = total_obs % batch_size - assert leftover != 0, f"batch_size {batch_size} must not divide evenly into {total_obs} observations" + remainder = total_obs % batch_size + assert remainder != 0, f"batch_size {batch_size} must not divide evenly into {total_obs} observations" for batch in batches[:-1]: assert batch.shape[0] == batch_size - assert batches[-1].shape[0] == (batch_size if drop_last else leftover) + assert batches[-1].shape[0] == (batch_size if drop_last else remainder) X = sp.vstack(batches).toarray() - assert X.shape[0] == (total_obs - leftover if drop_last else total_obs) + assert X.shape[0] == (total_obs - remainder if drop_last else total_obs) X_expected = adata[np.concatenate(indices)].layers["sparse"].toarray() np.testing.assert_allclose(X, X_expected) From 5bc27514fc646e98f387f15373630dd9412407bd Mon Sep 
17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 00:20:04 +0100 Subject: [PATCH 27/56] simplify validate_sampler --- src/annbatch/loader.py | 16 +++++++--------- src/annbatch/utils.py | 14 ++++++-------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 6a2f4f28..8b030f3f 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -276,7 +276,7 @@ def use_collection( self._collection_added = True return self - @validate_sampler(lambda self, adatas: sum(adata.n_obs for adata in adatas)) + @validate_sampler(lambda adatas, obs=None: sum(adata.n_obs for adata in adatas)) def add_anndatas( self, adatas: list[ad.AnnData], @@ -315,7 +315,7 @@ def _prepare_dataset_and_obs(self, adata: ad.AnnData) -> tuple[BackingArray, pd. raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") return cast("BackingArray", dataset), obs - @validate_sampler(lambda self, datasets, obs=None: sum(ds.shape[0] for ds in datasets)) + @validate_sampler(lambda datasets, obs=None: sum(ds.shape[0] for ds in datasets)) def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | None = None) -> Self: """Append datasets to this dataset. @@ -333,7 +333,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N self._add_dataset_unchecked(ds, o) return self - @validate_sampler(lambda self, dataset, obs=None: dataset.shape[0]) + @validate_sampler(lambda dataset, obs=None: dataset.shape[0]) def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> Self: """Append a dataset to this dataset. 
@@ -629,8 +629,8 @@ def __iter__( chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) chunks_converted = self._accumulate_chunks(chunks) # Accumulate labels and indices if possible - concatenated_obs: None | list[pd.DataFrame] = self._maybe_accumulate_obs(dataset_index_to_slices) - in_memory_indices: None | list[np.ndarray] = self._maybe_accumulate_indices(chunks_to_load) + concatenated_obs: None | pd.DataFrame = self._maybe_accumulate_obs(dataset_index_to_slices) + in_memory_indices: None | np.ndarray = self._maybe_accumulate_indices(chunks_to_load) in_memory_data = mod.vstack(chunks_converted) @@ -658,9 +658,7 @@ def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInM result.append(self._np_module.asarray(chunk)) return result - def _maybe_accumulate_obs( - self, dataset_index_to_slices: OrderedDict[int, list[slice]] - ) -> list[pd.DataFrame] | None: + def _maybe_accumulate_obs(self, dataset_index_to_slices: OrderedDict[int, list[slice]]) -> pd.DataFrame | None: """Gather obs labels for the loaded slices if possible.""" if self._obs is None: return None @@ -671,7 +669,7 @@ def _maybe_accumulate_obs( ] ) - def _maybe_accumulate_indices(self, slices: list[slice]) -> list[np.ndarray] | None: + def _maybe_accumulate_indices(self, slices: list[slice]) -> np.ndarray | None: """Gather original indices for the loaded slices if possible.""" if self._return_index is False: return None diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 78d23cbd..225ecd5b 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -196,23 +196,21 @@ def load_x_and_obs(g: zarr.Group) -> ad.AnnData: ) -def validate_sampler(get_additional_n_obs): +def validate_sampler(get_n_obs): """Decorator that validates n_obs before modifying state. 
Parameters ---------- - get_additional_n_obs - A callable (self, *args, **kwargs) -> int that returns the number - of additional observations that will be added by the decorated method.' - For example in add_datasets, this would be lambda self, datasets: sum(dataset.shape[0] for dataset in datasets) + get_n_obs + A callable ( *args, **kwargs) -> int that returns the number of observations that will be added by the decorated method. + For example in add_datasets, this would be lambda datasets: sum(dataset.shape[0] for dataset in datasets) """ def decorator(method): @wraps(method) def wrapper(self, *args, **kwargs): - additional_obs = get_additional_n_obs(self, *args, **kwargs) - prospective_n_obs = self.n_obs + additional_obs - self._batch_sampler.validate(prospective_n_obs) + n_obs = get_n_obs(*args, **kwargs) + self._batch_sampler.validate(n_obs) return method(self, *args, **kwargs) return wrapper From 66d5d3c9741c3403a6c056a4ac162ebe9be218d3 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 00:26:56 +0100 Subject: [PATCH 28/56] remove old generic params --- src/annbatch/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 8b030f3f..678df3a6 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -139,12 +139,12 @@ class Loader[ _preload_to_gpu: bool = True _to_torch: bool = True _dataset_elem_cache: dict[int, CSRDatasetElems] - _batch_sampler: Sampler[list[slice]] + _batch_sampler: Sampler def __init__( self, *, - batch_sampler: Sampler[list[slice]] | None = None, + batch_sampler: Sampler | None = None, chunk_size: int | None = None, preload_nchunks: int | None = None, shuffle: bool | None = None, From 0e8a472fe34aba5a65977398c545c50539cfd858 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 00:32:22 +0100 Subject: [PATCH 29/56] add broad typing --- src/annbatch/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/src/annbatch/utils.py b/src/annbatch/utils.py index 225ecd5b..ce6cc21d 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -14,6 +14,8 @@ from .compat import CupyArray, CupyCSRMatrix, Tensor if TYPE_CHECKING: + from collections.abc import Callable + from annbatch.types import OutputInMemoryArray_T @@ -196,7 +198,7 @@ def load_x_and_obs(g: zarr.Group) -> ad.AnnData: ) -def validate_sampler(get_n_obs): +def validate_sampler(get_n_obs: Callable[..., int]): """Decorator that validates n_obs before modifying state. Parameters From ae3e1bc6b78b5d8c60f64a3868ce1ad265ebd914 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 16:00:42 +0100 Subject: [PATCH 30/56] clarify todos and add username --- src/annbatch/loader.py | 4 ++-- src/annbatch/sampler/_chunk_sampler.py | 3 ++- src/annbatch/utils.py | 5 +++-- tests/test_sampler.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 678df3a6..b2db88ea 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -128,9 +128,9 @@ class Loader[ "shuffle": False, "drop_last": False, } - # TODO: these should be also presented in the documentation + # TODO(selmanozleyen): these should be also presented in the documentation # but this is not ideal since they are hardcoded into the docstrings - # maybe we should make this a public class field? + # maybe we should make _COMMON_SAMPLER_ARGS a public class field? 
_train_datasets: list[BackingArray] _obs: list[pd.DataFrame] | None = None diff --git a/src/annbatch/sampler/_chunk_sampler.py b/src/annbatch/sampler/_chunk_sampler.py index 6ad6321f..c0c53ea7 100644 --- a/src/annbatch/sampler/_chunk_sampler.py +++ b/src/annbatch/sampler/_chunk_sampler.py @@ -140,7 +140,8 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Create chunk indices for possible shuffling and worker sharding chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size)) if self._shuffle: - self._rng.shuffle(chunk_indices) # TODO: maybe this should be done worker-aware way? + # TODO(selmanozleyen): maybe this should be done worker-aware way? + self._rng.shuffle(chunk_indices) chunks = self._compute_chunks(chunk_indices, start, stop) # Worker sharding: each worker gets a disjoint subset of chunks if worker_handle is not None: diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index ce6cc21d..ce78ab39 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from annbatch.types import OutputInMemoryArray_T + from annbatch.types import BackingArray_T, OutputInMemoryArray_T def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: @@ -198,7 +198,7 @@ def load_x_and_obs(g: zarr.Group) -> ad.AnnData: ) -def validate_sampler(get_n_obs: Callable[..., int]): +def validate_sampler(get_n_obs: Callable[[list[ad.AnnData | BackingArray_T] | BackingArray_T | ad.AnnData], int]): """Decorator that validates n_obs before modifying state. Parameters @@ -212,6 +212,7 @@ def decorator(method): @wraps(method) def wrapper(self, *args, **kwargs): n_obs = get_n_obs(*args, **kwargs) + # TODO(selmanozleyen): maybe batch sampler should be public? 
self._batch_sampler.validate(n_obs) return method(self, *args, **kwargs) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index f5867356..bc013a7d 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -7,7 +7,7 @@ from annbatch.sampler import ChunkSampler -# TODO: Check for the validation within the _get_worker_handle method. Mock worker handle wouldn't make sense +# TODO(selmanozleyen): Check for the validation within the _get_worker_handle method. Mock worker handle wouldn't make sense # but overall one must also think about how validation can't be independent of the worker handle. From 742605a85c3b0f14196b306c2a43c358953e0f6d Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 16:13:13 +0100 Subject: [PATCH 31/56] type and modify decorator --- src/annbatch/utils.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index ce78ab39..87a11bdc 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -14,9 +14,7 @@ from .compat import CupyArray, CupyCSRMatrix, Tensor if TYPE_CHECKING: - from collections.abc import Callable - - from annbatch.types import BackingArray_T, OutputInMemoryArray_T + from annbatch.types import OutputInMemoryArray_T def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: @@ -198,24 +196,21 @@ def load_x_and_obs(g: zarr.Group) -> ad.AnnData: ) -def validate_sampler(get_n_obs: Callable[[list[ad.AnnData | BackingArray_T] | BackingArray_T | ad.AnnData], int]): +def validate_sampler(method): """Decorator that validates n_obs before modifying state. - Parameters - ---------- - get_n_obs - A callable ( *args, **kwargs) -> int that returns the number of observations that will be added by the decorated method. 
- For example in add_datasets, this would be lambda datasets: sum(dataset.shape[0] for dataset in datasets) - """ + Expects the first positional argument to be either: + - A single object with a `.shape` attribute + - A list of objects with `.shape` attributes - def decorator(method): - @wraps(method) - def wrapper(self, *args, **kwargs): - n_obs = get_n_obs(*args, **kwargs) - # TODO(selmanozleyen): maybe batch sampler should be public? - self._batch_sampler.validate(n_obs) - return method(self, *args, **kwargs) + The total n_obs is computed as sum of shape[0] values for a list of objects or the shape[0] value for a single object. + """ - return wrapper + @wraps(method) + def wrapper(self, first_arg: SupportsShape | list[SupportsShape], /, *args, **kwargs): + n_obs = sum(item.shape[0] for item in first_arg) if isinstance(first_arg, list) else first_arg.shape[0] + # TODO(selmanozleyen): maybe batch sampler should be public? + self._batch_sampler.validate(n_obs) + return method(self, first_arg, *args, **kwargs) - return decorator + return wrapper From cf30686ca528c90964584c4794189b1d4b627637 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 16:22:59 +0100 Subject: [PATCH 32/56] no need for lambdas in decorators --- src/annbatch/loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index b2db88ea..e23eb601 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -276,7 +276,7 @@ def use_collection( self._collection_added = True return self - @validate_sampler(lambda adatas, obs=None: sum(adata.n_obs for adata in adatas)) + @validate_sampler def add_anndatas( self, adatas: list[ad.AnnData], @@ -315,7 +315,7 @@ def _prepare_dataset_and_obs(self, adata: ad.AnnData) -> tuple[BackingArray, pd. 
raise TypeError(f"Found {type(dataset)} but only {BackingArray_T.__value__} are usable") return cast("BackingArray", dataset), obs - @validate_sampler(lambda datasets, obs=None: sum(ds.shape[0] for ds in datasets)) + @validate_sampler def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | None = None) -> Self: """Append datasets to this dataset. @@ -333,7 +333,7 @@ def add_datasets(self, datasets: list[BackingArray], obs: list[pd.DataFrame] | N self._add_dataset_unchecked(ds, o) return self - @validate_sampler(lambda dataset, obs=None: dataset.shape[0]) + @validate_sampler def add_dataset(self, dataset: BackingArray, obs: pd.DataFrame | None = None) -> Self: """Append a dataset to this dataset. From 261c5e8d2f1f10fbb9a875d27499a3a82779949b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 16:44:38 +0100 Subject: [PATCH 33/56] make decorator compatible in multiple cases --- src/annbatch/utils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 87a11bdc..1509c941 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import warnings from dataclasses import dataclass from functools import cached_property, wraps @@ -205,12 +206,22 @@ def validate_sampler(method): The total n_obs is computed as sum of shape[0] values for a list of objects or the shape[0] value for a single object. 
""" + # Get the first parameter name (after 'self') at decoration time + sig = inspect.signature(method) + if len(sig.parameters) < 2: + raise ValueError("validate_sampler decorator expects at least two positional arguments after 'self'") + first_param_name = list(sig.parameters.keys())[1] # [0] is 'self' @wraps(method) - def wrapper(self, first_arg: SupportsShape | list[SupportsShape], /, *args, **kwargs): + def wrapper(self, *args, **kwargs): + # Extract from args if positional, otherwise from kwargs by name + if len(args) > 0: + first_arg = args[0] + else: + first_arg = kwargs[first_param_name] + n_obs = sum(item.shape[0] for item in first_arg) if isinstance(first_arg, list) else first_arg.shape[0] - # TODO(selmanozleyen): maybe batch sampler should be public? self._batch_sampler.validate(n_obs) - return method(self, first_arg, *args, **kwargs) + return method(self, *args, **kwargs) return wrapper From 4402d4ebf9720f95bd49a6094ccee6726854ddc2 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 17:11:07 +0100 Subject: [PATCH 34/56] put ABC in abc folder --- src/annbatch/loader.py | 6 ++++-- src/annbatch/sampler/__init__.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index e23eb601..793d822a 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -15,7 +15,7 @@ from scipy import sparse as sp from zarr import Array as ZarrArray -from annbatch.sampler import ChunkSampler, Sampler +from annbatch.sampler import ChunkSampler from annbatch.types import BackingArray_T, InputInMemoryArray_T, LoaderOutput, OutputInMemoryArray_T from annbatch.utils import ( CSRContainer, @@ -34,6 +34,7 @@ from types import ModuleType from annbatch.io import DatasetCollection + from annbatch.sampler.abc import Sampler # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T @@ -74,7 +75,7 @@ class Loader[ ---------- batch_sampler A custom sampler to use for batching the 
data. - If not provided, a :class:`ChunkSampler` will be used with the following defaults: + If not provided, a default chunk sampler will be used with the following defaults: - `chunk_size`: 512 - `preload_nchunks`: 32 - `batch_size`: 1 @@ -132,6 +133,7 @@ class Loader[ # but this is not ideal since they are hardcoded into the docstrings # maybe we should make _COMMON_SAMPLER_ARGS a public class field? + # TODO (selmanozleyen): can't link chunk sampler to the docstring because it's not exposed in the public API _train_datasets: list[BackingArray] _obs: list[pd.DataFrame] | None = None _return_index: bool = False diff --git a/src/annbatch/sampler/__init__.py b/src/annbatch/sampler/__init__.py index 40a8f79b..d264a8d4 100644 --- a/src/annbatch/sampler/__init__.py +++ b/src/annbatch/sampler/__init__.py @@ -3,10 +3,10 @@ This module provides samplers optimized for chunk-based data access patterns. """ +from annbatch.sampler import abc from annbatch.sampler._chunk_sampler import ChunkSampler -from annbatch.sampler.abc import Sampler __all__ = [ "ChunkSampler", - "Sampler", + "abc", ] From 0929849bc3b7f811d8337452b8d12b85d12519d6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 17:13:22 +0100 Subject: [PATCH 35/56] update test with the fix --- tests/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4bcac5fb..bbd5ab09 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -433,9 +433,9 @@ def test_no_obs(simple_collection: tuple[ad.AnnData, DatasetCollection]): def test_add_dataset_validation_failure_preserves_state(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): """Test that failed validation in add_dataset doesn't modify internal state.""" - from annbatch.sampler import Sampler + from annbatch.sampler import abc - class FailOnSecondValidateSampler(Sampler): + class FailOnSecondValidateSampler(abc.Sampler): """A sampler that fails 
validation after the first call.""" def __init__(self): From 899cc1875839b32ca58c5902368b758b68d14a62 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 17:19:47 +0100 Subject: [PATCH 36/56] qualname for fix. no sampler in public API --- docs/conf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 10b83dd5..472973d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,7 +133,9 @@ # If building the documentation fails because of a missing link that is outside your control, # you can add an exception to this list. # ("py:class", "igraph.Graph"), - ("py:class", "annbatch.types.TypeAliasType") + ("py:class", "annbatch.types.TypeAliasType"), + # this is not exposed in the public API + ("py:class", "annbatch.sampler.abc._sampler.Sampler"), ] qualname_overrides = { From 89e7ccdcc45ad60c43b857da5aa96317d9090a62 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 17:57:13 +0100 Subject: [PATCH 37/56] check coverage when shuffled otherwise also check order --- tests/test_sampler.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index bc013a7d..5c4a5d03 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -13,12 +13,12 @@ def collect_indices(sampler, n_obs): """Helper to collect all indices from sampler.""" - indices = set() + indices = [] for load_request in sampler.sample(n_obs): assert len(load_request["splits"]) > 0, "splits must be non-empty" assert all(len(s) > 0 for s in load_request["splits"]), "splits must be non-empty" for s in load_request["chunks"]: - indices.update(range(s.start, s.stop)) + indices.extend(range(s.start, s.stop)) return indices @@ -81,7 +81,7 @@ def _get_worker_handle(self) -> MockWorkerHandle | None: ], ) def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchunks, shuffle): - """Test sampler covers exactly the expected range.""" + 
"""Test sampler covers exactly the expected range, and ordering is correct when not shuffled.""" sampler = ChunkSampler( mask=slice(start, stop), batch_size=batch_size, @@ -91,11 +91,18 @@ def test_mask_coverage(n_obs, chunk_size, start, stop, batch_size, preload_nchun rng=np.random.default_rng(42) if shuffle else None, ) - all_indices = collect_indices(sampler, n_obs) - expected_start = start if start is not None else 0 expected_stop = stop if stop is not None else n_obs - assert all_indices == set(range(expected_start, expected_stop)) + expected_indices = list(range(expected_start, expected_stop)) + + all_indices = collect_indices(sampler, n_obs, preserve_order=True) + + # Always check coverage + if shuffle: + assert set(all_indices) == set(expected_indices), "Sampler should cover all expected indices" + else: + assert all_indices == expected_indices, f"all_indices: {all_indices} != expected_indices: {expected_indices}" + sampler.validate(n_obs) @@ -232,7 +239,7 @@ def test_n_obs_coverage(n_obs_values, expected_ranges): """Test that n_obs changes affect sampling results appropriately.""" sampler = ChunkSampler(mask=slice(0, None), batch_size=5, chunk_size=10, preload_nchunks=2, shuffle=False) - results = [collect_indices(sampler, n) for n in n_obs_values] + results = [collect_indices(sampler, n, preserve_order=True) for n in n_obs_values] for result, expected in zip(results, expected_ranges, strict=True): - assert result == set(expected) + assert result == list(expected), f"result: {result} != expected: {expected}" From 0356374415b4ad07de938f7c9bdcf8a2270a9cc7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 17:57:47 +0100 Subject: [PATCH 38/56] fix to prev commit --- tests/test_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 5c4a5d03..02f46fc1 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -95,7 +95,7 @@ def test_mask_coverage(n_obs, 
chunk_size, start, stop, batch_size, preload_nchun expected_stop = stop if stop is not None else n_obs expected_indices = list(range(expected_start, expected_stop)) - all_indices = collect_indices(sampler, n_obs, preserve_order=True) + all_indices = collect_indices(sampler, n_obs) # Always check coverage if shuffle: @@ -239,7 +239,7 @@ def test_n_obs_coverage(n_obs_values, expected_ranges): """Test that n_obs changes affect sampling results appropriately.""" sampler = ChunkSampler(mask=slice(0, None), batch_size=5, chunk_size=10, preload_nchunks=2, shuffle=False) - results = [collect_indices(sampler, n, preserve_order=True) for n in n_obs_values] + results = [collect_indices(sampler, n) for n in n_obs_values] for result, expected in zip(results, expected_ranges, strict=True): assert result == list(expected), f"result: {result} != expected: {expected}" From 87f1ccbfd5a43d62e37076e53c1c9f6dbab43053 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 17:58:16 +0100 Subject: [PATCH 39/56] clarify doc --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 472973d3..e00e6cba 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -134,7 +134,7 @@ # you can add an exception to this list. 
# ("py:class", "igraph.Graph"), ("py:class", "annbatch.types.TypeAliasType"), - # this is not exposed in the public API + # this is not exposed in the public API yet ("py:class", "annbatch.sampler.abc._sampler.Sampler"), ] From 8a7f8c29601d4c40a7c970e622ca805c1ce29225 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 18:05:13 +0100 Subject: [PATCH 40/56] update worker tests --- tests/test_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 02f46fc1..29f74d89 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -180,7 +180,7 @@ def test_workers_cover_full_dataset_without_overlap( # All workers should have disjoint chunks for i in range(num_workers): for j in range(i + 1, num_workers): - assert all_worker_indices[i].isdisjoint(all_worker_indices[j]) + assert set(all_worker_indices[i]).isdisjoint(all_worker_indices[j]) # Together they cover the full dataset assert set().union(*all_worker_indices) == set(range(n_obs)) From a85634a14c7521022c5c2a2876576b0f1d7410cc Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 18:15:44 +0100 Subject: [PATCH 41/56] new * location for ChunkSampler --- src/annbatch/sampler/_chunk_sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/annbatch/sampler/_chunk_sampler.py b/src/annbatch/sampler/_chunk_sampler.py index c0c53ea7..ca02f920 100644 --- a/src/annbatch/sampler/_chunk_sampler.py +++ b/src/annbatch/sampler/_chunk_sampler.py @@ -51,12 +51,12 @@ class ChunkSampler(Sampler): def __init__( self, - *, - batch_size: int, chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, mask: slice | None = None, shuffle: bool = False, - preload_nchunks: int, drop_last: bool = False, rng: np.random.Generator | None = None, ): From 61efe81c7cac5457683e2308340bb5228cb7b707 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 20:18:50 +0100 Subject: [PATCH 42/56] add typing 
but can revert if too verbose --- src/annbatch/utils.py | 71 ++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 1509c941..55166499 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import cached_property, wraps from importlib.util import find_spec -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Concatenate, Protocol import anndata as ad import numpy as np @@ -15,9 +15,48 @@ from .compat import CupyArray, CupyCSRMatrix, Tensor if TYPE_CHECKING: + from collections.abc import Callable + + from annbatch.sampler.abc import Sampler from annbatch.types import OutputInMemoryArray_T +# typing with decorators and self: +# https://stackoverflow.com/a/68290080 +class HasBatchSampler(Protocol): + _batch_sampler: Sampler + + +def validate_sampler[Self: HasBatchSampler, **Param, RetType]( + method: Callable[Concatenate[Self, Param], RetType], +) -> Callable[Concatenate[Self, Param], RetType]: + """Decorator that validates n_obs before modifying state. + + Expects the first positional argument to be either: + - A single object with a `.shape` attribute + - A list of objects with `.shape` attributes + + The total n_obs is computed as sum of shape[0] values for a list of objects or the shape[0] value for a single object. 
+ """ + sig = inspect.signature(method) + if len(sig.parameters) < 2: + raise ValueError("validate_sampler decorator expects at least two positional arguments after 'self'") + first_param_name = list(sig.parameters.keys())[1] + + @wraps(method) + def wrapper(self: Self, *args: Param.args, **kwargs: Param.kwargs) -> RetType: + if len(args) > 0: + first_arg = args[0] + else: + first_arg = kwargs[first_param_name] + + n_obs = sum(item.shape[0] for item in first_arg) if isinstance(first_arg, list) else first_arg.shape[0] + self._batch_sampler.validate(n_obs) + return method(self, *args, **kwargs) + + return wrapper + + def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: """Wrapper around `np.split` to split up an array into `size` chunks""" return np.split(a, np.arange(size, len(a), size)) @@ -195,33 +234,3 @@ def load_x_and_obs(g: zarr.Group) -> ad.AnnData: return ad.AnnData( X=g["X"] if isinstance(g["X"], zarr.Array) else ad.io.sparse_dataset(g["X"]), obs=ad.io.read_elem(g["obs"]) ) - - -def validate_sampler(method): - """Decorator that validates n_obs before modifying state. - - Expects the first positional argument to be either: - - A single object with a `.shape` attribute - - A list of objects with `.shape` attributes - - The total n_obs is computed as sum of shape[0] values for a list of objects or the shape[0] value for a single object. 
- """ - # Get the first parameter name (after 'self') at decoration time - sig = inspect.signature(method) - if len(sig.parameters) < 2: - raise ValueError("validate_sampler decorator expects at least two positional arguments after 'self'") - first_param_name = list(sig.parameters.keys())[1] # [0] is 'self' - - @wraps(method) - def wrapper(self, *args, **kwargs): - # Extract from args if positional, otherwise from kwargs by name - if len(args) > 0: - first_arg = args[0] - else: - first_arg = kwargs[first_param_name] - - n_obs = sum(item.shape[0] for item in first_arg) if isinstance(first_arg, list) else first_arg.shape[0] - self._batch_sampler.validate(n_obs) - return method(self, *args, **kwargs) - - return wrapper From faaf5250fc475a856eca62f1164ac6a623e2207e Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 20:19:22 +0100 Subject: [PATCH 43/56] remove unused fields. (maybe linter check can be added) --- src/annbatch/sampler/_chunk_sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/annbatch/sampler/_chunk_sampler.py b/src/annbatch/sampler/_chunk_sampler.py index ca02f920..55b12cf3 100644 --- a/src/annbatch/sampler/_chunk_sampler.py +++ b/src/annbatch/sampler/_chunk_sampler.py @@ -44,8 +44,6 @@ class ChunkSampler(Sampler): _shuffle: bool _preload_nchunks: int _mask: slice - _n_chunks: int - _n_iters: int _drop_last: bool _rng: np.random.Generator From eecb0b15c43e359430c71fe6058ac2fa1d777a2f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 20:20:36 +0100 Subject: [PATCH 44/56] remove old SO link --- src/annbatch/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 55166499..72e09013 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -21,8 +21,6 @@ from annbatch.types import OutputInMemoryArray_T -# typing with decorators and self: -# https://stackoverflow.com/a/68290080 class HasBatchSampler(Protocol): _batch_sampler: Sampler From 
2cd09caac0e6fe48e1c6427f5304f4f06cb9d81f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 20 Jan 2026 20:30:00 +0100 Subject: [PATCH 45/56] don't put generators into np.all !! --- src/annbatch/sampler/abc/_sampler.py | 4 ++-- tests/test_sampler.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/annbatch/sampler/abc/_sampler.py b/src/annbatch/sampler/abc/_sampler.py index 25b4846d..b22b7b17 100644 --- a/src/annbatch/sampler/abc/_sampler.py +++ b/src/annbatch/sampler/abc/_sampler.py @@ -35,9 +35,9 @@ def sample(self, n_obs: int) -> Iterator[LoadRequest]: @abstractmethod def validate(self, n_obs: int) -> None: - """Validate the sampler configuration against the loader's state. + """Validate the sampler configuration against the given n_obs. - This method is called when the sampler is set on a loader. + This method is called at the start of each `sample()` call. Override this method to add custom validation for sampler parameters. Parameters diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 29f74d89..75174202 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -124,20 +124,20 @@ def test_batch_sizes_match_expected_pattern(): all_requests = list(sampler.sample(n_obs)) assert len(all_requests) == expected_num_load_requests for req_idx, load_request in enumerate(all_requests[:-1]): - assert np.all(len(chunk) == chunk_size for chunk in load_request["chunks"]), ( + assert all(chunk.stop - chunk.start == chunk_size for chunk in load_request["chunks"]), ( f"chunk size mismatch at request {req_idx}:", f"chunks: {load_request['chunks']}", ) - assert np.all(len(split) == batch_size for split in load_request["splits"]), ( + assert all(len(split) == batch_size for split in load_request["splits"]), ( f"batch size mismatch at request {req_idx}:splits: {load_request['splits']}" ) last_request = all_requests[-1] assert len(last_request["splits"]) == expected_last_num_splits, "last request num splits mismatch" - assert 
np.all(len(chunk) == expected_last_chunk_size for chunk in last_request["chunks"]), ( + assert all(chunk.stop - chunk.start == expected_last_chunk_size for chunk in last_request["chunks"]), ( "last request chunk size mismatch", f"chunks: {last_request['chunks']}", ) - assert np.all(len(split) == expected_last_batch_size for split in last_request["splits"]), ( + assert all(len(split) == expected_last_batch_size for split in last_request["splits"]), ( "last request batch size mismatch", f"splits: {last_request['splits']}", ) From 4fcb55368d6f1c7f5329a37b6e2612fbfff85794 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 12:33:56 +0100 Subject: [PATCH 46/56] apply typing and docstring suggestion --- src/annbatch/loader.py | 13 ++----------- src/annbatch/types.py | 2 +- src/annbatch/utils.py | 14 +++++--------- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 793d822a..580afe33 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -70,21 +70,12 @@ class Loader[ If both `preload_to_gpu` and `to_torch` are False, then the return type is the CPU class for the given data type. When providing a custom sampler, `chunk_size`, `preload_nchunks`, `batch_size`, `shuffle`, and `drop_last` must not be set (they are controlled by the `batch_sampler` instead). + When providing these arguments and no `batch_sampler`, they are used to construct a :class:`annbatch.ChunkSampler`. Parameters ---------- batch_sampler - A custom sampler to use for batching the data. - If not provided, a default chunk sampler will be used with the following defaults: - - `chunk_size`: 512 - - `preload_nchunks`: 32 - - `batch_size`: 1 - - `shuffle`: False - - `drop_last`: False - Mutually exclusive with the following arguments: `chunk_size`, `preload_nchunks`, `batch_size`, `shuffle`, and `drop_last`. - chunk_size - The obs size (i.e., axis 0) of contiguous array data to fetch. 
When `batch_sampler` is not provided, this is used to determine the chunk size. Mutually exclusive with `batch_sampler`. Defaults to 512. - preload_nchunks + If not provided, a default :class:`annbatch.ChunkSampler` will be used with the same defaults below. The number of chunks of contiguous array data to fetch. When batch_sampler is not provided, this is used to determine the preload_nchunks. Mutually exclusive with `batch_sampler`. Defaults to 32. shuffle Whether or not to shuffle the data. When batch_sampler is not provided, this is used to determine the shuffle. Mutually exclusive with `batch_sampler`. Defaults to False. diff --git a/src/annbatch/types.py b/src/annbatch/types.py index 668a96ae..ba39b1b3 100644 --- a/src/annbatch/types.py +++ b/src/annbatch/types.py @@ -27,7 +27,7 @@ class LoadRequest(TypedDict): chunks Chunks to load - a list of slices with a range of chunk_size except the last one which may be smaller but not empty. splits - How the concatenation of chunks should be split into batches. + How the in-memory data should be split into batches after it is read off disk and concatenated in-memory. A list of splits, last one may be partial but not empty i.e. 1 <= len(last_split) <= batch_size. 
""" diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 72e09013..3dd04e66 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -17,17 +17,13 @@ if TYPE_CHECKING: from collections.abc import Callable - from annbatch.sampler.abc import Sampler + from annbatch.loader import Loader from annbatch.types import OutputInMemoryArray_T -class HasBatchSampler(Protocol): - _batch_sampler: Sampler - - -def validate_sampler[Self: HasBatchSampler, **Param, RetType]( - method: Callable[Concatenate[Self, Param], RetType], -) -> Callable[Concatenate[Self, Param], RetType]: +def validate_sampler[**Param, RetType]( + method: Callable[Concatenate[Loader, Param], RetType], +) -> Callable[Concatenate[Loader, Param], RetType]: """Decorator that validates n_obs before modifying state. Expects the first positional argument to be either: @@ -42,7 +38,7 @@ def validate_sampler[Self: HasBatchSampler, **Param, RetType]( first_param_name = list(sig.parameters.keys())[1] @wraps(method) - def wrapper(self: Self, *args: Param.args, **kwargs: Param.kwargs) -> RetType: + def wrapper(self: Loader, *args: Param.args, **kwargs: Param.kwargs) -> RetType: if len(args) > 0: first_arg = args[0] else: From b53d685c6c172d7c551dc6ecf3d4d0fd806c49c9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 13:20:54 +0100 Subject: [PATCH 47/56] change in folder structure --- docs/api.md | 1 + docs/conf.py | 2 -- src/annbatch/__init__.py | 6 ++++-- src/annbatch/abc/__init__.py | 5 +++++ .../{sampler/abc/_sampler.py => abc/sampler.py} | 0 src/annbatch/loader.py | 5 ++--- src/annbatch/sampler/__init__.py | 12 ------------ src/annbatch/sampler/abc/__init__.py | 5 ----- src/annbatch/samplers/__init__.py | 5 +++++ src/annbatch/{sampler => samplers}/_chunk_sampler.py | 2 +- tests/test_dataset.py | 7 +++---- tests/test_sampler.py | 2 +- 12 files changed, 22 insertions(+), 30 deletions(-) create mode 100644 src/annbatch/abc/__init__.py rename src/annbatch/{sampler/abc/_sampler.py => 
abc/sampler.py} (100%) delete mode 100644 src/annbatch/sampler/__init__.py delete mode 100644 src/annbatch/sampler/abc/__init__.py create mode 100644 src/annbatch/samplers/__init__.py rename src/annbatch/{sampler => samplers}/_chunk_sampler.py (99%) diff --git a/docs/api.md b/docs/api.md index cf399fd6..22e04d66 100644 --- a/docs/api.md +++ b/docs/api.md @@ -4,6 +4,7 @@ .. module:: annbatch ``` + (loaders)= ## Loaders diff --git a/docs/conf.py b/docs/conf.py index e00e6cba..4247d757 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -134,8 +134,6 @@ # you can add an exception to this list. # ("py:class", "igraph.Graph"), ("py:class", "annbatch.types.TypeAliasType"), - # this is not exposed in the public API yet - ("py:class", "annbatch.sampler.abc._sampler.Sampler"), ] qualname_overrides = { diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 9b4f89c2..39180c0b 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -2,9 +2,10 @@ from importlib.metadata import version -from . import sampler, types +from . 
import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader +from .samplers._chunk_sampler import ChunkSampler __version__ = version("annbatch") @@ -12,6 +13,7 @@ "Loader", "DatasetCollection", "types", - "sampler", "write_sharded", + "ChunkSampler", + "abc", ] diff --git a/src/annbatch/abc/__init__.py b/src/annbatch/abc/__init__.py new file mode 100644 index 00000000..9a3f765f --- /dev/null +++ b/src/annbatch/abc/__init__.py @@ -0,0 +1,5 @@ +from .sampler import Sampler + +__all__ = [ + "Sampler", +] diff --git a/src/annbatch/sampler/abc/_sampler.py b/src/annbatch/abc/sampler.py similarity index 100% rename from src/annbatch/sampler/abc/_sampler.py rename to src/annbatch/abc/sampler.py diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 580afe33..5703adb8 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -15,7 +15,7 @@ from scipy import sparse as sp from zarr import Array as ZarrArray -from annbatch.sampler import ChunkSampler +from annbatch.samplers import ChunkSampler from annbatch.types import BackingArray_T, InputInMemoryArray_T, LoaderOutput, OutputInMemoryArray_T from annbatch.utils import ( CSRContainer, @@ -33,8 +33,8 @@ from collections.abc import Callable, Iterator from types import ModuleType + from annbatch.abc import Sampler from annbatch.io import DatasetCollection - from annbatch.sampler.abc import Sampler # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T @@ -124,7 +124,6 @@ class Loader[ # but this is not ideal since they are hardcoded into the docstrings # maybe we should make _COMMON_SAMPLER_ARGS a public class field? 
- # TODO (selmanozleyen): can't link chunk sampler to the docstring because it's not exposed in the public API _train_datasets: list[BackingArray] _obs: list[pd.DataFrame] | None = None _return_index: bool = False diff --git a/src/annbatch/sampler/__init__.py b/src/annbatch/sampler/__init__.py deleted file mode 100644 index d264a8d4..00000000 --- a/src/annbatch/sampler/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Sampler classes for efficient chunk-based data access. - -This module provides samplers optimized for chunk-based data access patterns. -""" - -from annbatch.sampler import abc -from annbatch.sampler._chunk_sampler import ChunkSampler - -__all__ = [ - "ChunkSampler", - "abc", -] diff --git a/src/annbatch/sampler/abc/__init__.py b/src/annbatch/sampler/abc/__init__.py deleted file mode 100644 index 776affab..00000000 --- a/src/annbatch/sampler/abc/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from annbatch.sampler.abc._sampler import Sampler - -__all__ = [ - "Sampler", -] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py new file mode 100644 index 00000000..9f92bbf0 --- /dev/null +++ b/src/annbatch/samplers/__init__.py @@ -0,0 +1,5 @@ +from ._chunk_sampler import ChunkSampler + +__all__ = [ + "ChunkSampler", +] diff --git a/src/annbatch/sampler/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py similarity index 99% rename from src/annbatch/sampler/_chunk_sampler.py rename to src/annbatch/samplers/_chunk_sampler.py index 55b12cf3..88a7c6d2 100644 --- a/src/annbatch/sampler/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -8,7 +8,7 @@ import numpy as np -from annbatch.sampler.abc import Sampler +from annbatch.abc import Sampler from annbatch.utils import check_lt_1, split_given_size if TYPE_CHECKING: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index bbd5ab09..b6210786 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -12,8 +12,8 @@ import scipy.sparse as sp import zarr 
-from annbatch import Loader -from annbatch.sampler import ChunkSampler +from annbatch import ChunkSampler, Loader +from annbatch.abc import Sampler try: from cupy import ndarray as CupyArray @@ -433,9 +433,8 @@ def test_no_obs(simple_collection: tuple[ad.AnnData, DatasetCollection]): def test_add_dataset_validation_failure_preserves_state(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): """Test that failed validation in add_dataset doesn't modify internal state.""" - from annbatch.sampler import abc - class FailOnSecondValidateSampler(abc.Sampler): + class FailOnSecondValidateSampler(Sampler): """A sampler that fails validation after the first call.""" def __init__(self): diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 75174202..1ef6b2d9 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from annbatch.sampler import ChunkSampler +from annbatch import ChunkSampler # TODO(selmanozleyen): Check for the validation within the _get_worker_handle method. Mock worker handle wouldn't make sense # but overall one must also think about how validation can't be independent of the worker handle. From d464caafe56bd5a1703808b6b7f667beb98d995b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 13:31:32 +0100 Subject: [PATCH 48/56] make batch sampler getter --- docs/api.md | 12 ++++++++++++ src/annbatch/loader.py | 10 ++++++++++ src/annbatch/utils.py | 2 +- 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 22e04d66..1da69c65 100644 --- a/docs/api.md +++ b/docs/api.md @@ -15,6 +15,8 @@ Loader Loader.__iter__ + + ChunkSampler ``` (io-helpers)= @@ -29,6 +31,15 @@ DatasetCollection ``` +(abc)= +## abc +```{eval-rst} +.. 
autosummary:: + :toctree: generated/ + + abc.Sampler +``` + (types)= ## types @@ -37,4 +48,5 @@ :toctree: generated/ types.LoaderOutput + types.LoadRequest ``` diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 5703adb8..9f794df3 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -241,6 +241,16 @@ def n_var(self) -> int: raise ValueError("No datasets added yet") return self._shapes[0][1] + @property + def batch_sampler(self) -> Sampler: + """The sampler used to generate batches. + + Returns + ------- + The sampler. + """ + return self._batch_sampler + def use_collection( self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs ) -> Self: diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index 3dd04e66..0bf27cf8 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -45,7 +45,7 @@ def wrapper(self: Loader, *args: Param.args, **kwargs: Param.kwargs) -> RetType: first_arg = kwargs[first_param_name] n_obs = sum(item.shape[0] for item in first_arg) if isinstance(first_arg, list) else first_arg.shape[0] - self._batch_sampler.validate(n_obs) + self.batch_sampler.validate(n_obs) return method(self, *args, **kwargs) return wrapper From 0b4883bbeb73e10a2a6f9c89f0a228f474f79f3a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 13:33:09 +0100 Subject: [PATCH 49/56] remove empty line --- docs/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 1da69c65..42309c0d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -15,7 +15,6 @@ Loader Loader.__iter__ - ChunkSampler ``` From c4505862af9c45664623621d8490e0e23cffb849 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 13:44:43 +0100 Subject: [PATCH 50/56] apply docstring suggestions for Loader args --- src/annbatch/loader.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 
9f794df3..50dd61b7 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -76,16 +76,19 @@ class Loader[ ---------- batch_sampler If not provided, a default :class:`annbatch.ChunkSampler` will be used with the same defaults below. - The number of chunks of contiguous array data to fetch. When batch_sampler is not provided, this is used to determine the preload_nchunks. Mutually exclusive with `batch_sampler`. Defaults to 32. + chunk_size + The obs size (i.e., axis 0) of contiguous array data to fetch. Mutually exclusive with `batch_sampler`. Defaults to 512. + preload_nchunks + The number of chunks of contiguous array data to fetch. Mutually exclusive with `batch_sampler`. Defaults to 32. shuffle - Whether or not to shuffle the data. When batch_sampler is not provided, this is used to determine the shuffle. Mutually exclusive with `batch_sampler`. Defaults to False. + Whether or not to shuffle the data. Mutually exclusive with `batch_sampler`. Defaults to False. batch_size - Batch size to yield from the dataset. When batch_sampler is not provided, this is used to determine the batch size. Mutually exclusive with `batch_sampler`. Defaults to 1. + Batch size to yield from the dataset. Mutually exclusive with `batch_sampler`. Defaults to 1. drop_last Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Leave as False when using in conjunction with a :class:`torch.utils.data.DataLoader`. - When batch_sampler is not provided, this is used to determine the drop_last. Mutually exclusive with `batch_sampler`. Defaults to False. + Mutually exclusive with `batch_sampler`. Defaults to False. return_index Whether or not to yield the index on each iteration. 
preload_to_gpu From 15a11b968ca4dbb18d8827ddc4f80dddc813b2bb Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 13:49:19 +0100 Subject: [PATCH 51/56] remove empty line --- docs/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 42309c0d..38d1139c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -4,7 +4,6 @@ .. module:: annbatch ``` - (loaders)= ## Loaders From 84c91244a5a8fef768cf898bc7ffe2d5b1dc7abb Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 13:52:28 +0100 Subject: [PATCH 52/56] conf.py is same as main --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 4247d757..10b83dd5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,7 +133,7 @@ # If building the documentation fails because of a missing link that is outside your control, # you can add an exception to this list. # ("py:class", "igraph.Graph"), - ("py:class", "annbatch.types.TypeAliasType"), + ("py:class", "annbatch.types.TypeAliasType") ] qualname_overrides = { From c38160dc4fc6e734aff0f14ffbcdbfef1ef1d148 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 14:14:53 +0100 Subject: [PATCH 53/56] change shuffle --- src/annbatch/samplers/_chunk_sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 88a7c6d2..4428e5cb 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -139,7 +139,10 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size)) if self._shuffle: # TODO(selmanozleyen): maybe this should be done worker-aware way? 
- self._rng.shuffle(chunk_indices) + if worker_handle is None: + self._rng.shuffle(chunk_indices) + else: + worker_handle.shuffle(chunk_indices) chunks = self._compute_chunks(chunk_indices, start, stop) # Worker sharding: each worker gets a disjoint subset of chunks if worker_handle is not None: From b8472e3bfb6a0d36bf7c36beac974e93bc3959bf Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 14:29:49 +0100 Subject: [PATCH 54/56] remove todo --- src/annbatch/samplers/_chunk_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 4428e5cb..0872a3e2 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -138,7 +138,6 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: # Create chunk indices for possible shuffling and worker sharding chunk_indices = np.arange(math.ceil((stop - start) / self._chunk_size)) if self._shuffle: - # TODO(selmanozleyen): maybe this should be done worker-aware way? 
if worker_handle is None: self._rng.shuffle(chunk_indices) else: From e620194fb5dcc4e9285a904ee6b09394b14cb40b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 14:43:04 +0100 Subject: [PATCH 55/56] update to match old behaviour --- src/annbatch/samplers/_chunk_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 0872a3e2..aebf05c3 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -154,7 +154,7 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: for batch_chunks in chunks_per_batch[:-1]: if self._shuffle: # Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training - self._rng.shuffle(batch_indices) + np.random.default_rng().shuffle(batch_indices) split_batch_indices = split_given_size(batch_indices, self._batch_size) yield {"chunks": batch_chunks, "splits": split_batch_indices} # On the last yield, drop the last uneven batch and create new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size @@ -163,7 +163,7 @@ def _sample(self, n_obs: int) -> Iterator[LoadRequest]: if self._drop_last: total_obs_in_last_batch -= total_obs_in_last_batch % self._batch_size batch_indices = split_given_size( - (self._rng.permutation if self._shuffle else np.arange)(total_obs_in_last_batch), + (np.random.default_rng().permutation if self._shuffle else np.arange)(total_obs_in_last_batch), self._batch_size, ) yield {"chunks": final_chunks, "splits": batch_indices} From 72370230a5d1be11fed4a1ec1cbedad2ea889733 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 21 Jan 2026 21:31:54 +0100 Subject: [PATCH 56/56] put vstack inside accumulate chunks --- src/annbatch/loader.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/annbatch/loader.py 
b/src/annbatch/loader.py index 50dd61b7..6f22d5d2 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -623,7 +623,6 @@ def __iter__( [len(self._train_datasets), self.n_obs], ["Number of datasets", "Number of observations"], ) - mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np for load_request in self._batch_sampler.sample(self.n_obs): chunks_to_load = load_request["chunks"] @@ -632,13 +631,11 @@ def __iter__( dataset_index_to_slices = self._slices_to_slices_with_array_index(chunks_to_load, use_original_space=False) # Fetch the data over slices chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) - chunks_converted = self._accumulate_chunks(chunks) + in_memory_data: OutputInMemoryArray_T = self._accumulate_chunks(chunks) # Accumulate labels and indices if possible concatenated_obs: None | pd.DataFrame = self._maybe_accumulate_obs(dataset_index_to_slices) in_memory_indices: None | np.ndarray = self._maybe_accumulate_indices(chunks_to_load) - in_memory_data = mod.vstack(chunks_converted) - for split in splits: data = in_memory_data[split] yield { @@ -647,7 +644,7 @@ def __iter__( "index": in_memory_indices[split] if in_memory_indices is not None else None, } - def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInMemoryArray_T]: + def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> OutputInMemoryArray_T: """Convert fetched chunks to output array format (CSR or ndarray).""" result: list[OutputInMemoryArray_T] = [] for chunk in chunks: @@ -661,7 +658,8 @@ def _accumulate_chunks(self, chunks: list[InputInMemoryArray]) -> list[OutputInM ) else: result.append(self._np_module.asarray(chunk)) - return result + mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np + return mod.vstack(result) def _maybe_accumulate_obs(self, dataset_index_to_slices: OrderedDict[int, list[slice]]) -> pd.DataFrame | None: """Gather obs labels 
for the loaded slices if possible."""