diff --git a/.readthedocs.yaml b/.readthedocs.yaml index c3f3f96f..4acb793c 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-24.04 tools: - python: "3.12" + python: "3.14" jobs: create_environment: - asdf plugin add uv diff --git a/README.md b/README.md index b7a4d116..391303c1 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ For a detailed tutorial, please see the [in-depth section of our docs][] Basic preprocessing: ```python -from annbatch import create_anndata_collection +from annbatch import DatasetCollection import zarr from pathlib import Path @@ -82,13 +82,14 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) -create_anndata_collection( +# Create a collection at the given path. The subgroups will all be anndata stores. +collection = DatasetCollection("path/to/output/collection.zarr") +collection.add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - output_path="path/to/output/collection", # a directory containing `dataset_{i}.zarr` - shuffle=True, # shuffling is needed if you want to use chunked access + shuffle=True, # shuffling is needed if you want to use chunked access, but is the default ) ``` @@ -107,22 +108,20 @@ zarr.config.set( {"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"} ) +def custom_load_func(g: zarr.Group) -> ad.AnnData: + return ad.AnnData(X=ad.io.sparse_dataset(g["layers"]["counts"]), obs=ad.io.read_elem(g["obs"])[some_subset_of_columns]) + # This settings override ensures that you don't lose/alter your categorical codes when reading the data in! 
with ad.settings.override(remove_unused_categories=False): ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, - ).add_anndatas( - [ - ad.AnnData( - # note that you can open an AnnData file using any type of zarr store - X=ad.io.sparse_dataset(zarr.open(p)["X"]), - obs=ad.io.read_elem(zarr.open(p)["obs"]), - ) - for p in Path("path/to/output/collection").glob("*.zarr") - ] ) + # `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader` + # but the `load_adata` arg can override this behavior + # (see `custom_load_func` above for an example of customization). + ds = ds.use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/docs/api.md b/docs/api.md index 7656af65..cf399fd6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,8 +25,7 @@ :toctree: generated/ write_sharded - add_to_collection - create_anndata_collection + DatasetCollection ``` (types)= diff --git a/docs/index.md b/docs/index.md index 46bf0470..ee945a50 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,12 +9,11 @@ Let's go through the above example: ### Preprocessing ```python -create_anndata_collection( +collection = DatasetCollection("path/to/output/store.zarr").add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" ], - output_path="path/to/output/store", # a directory containing `chunk_{i}.zarr` shuffle=True, # shuffling is needed if you want to use chunked access ) ``` @@ -33,20 +32,12 @@ See the [zarr docs on sharding][] for more information. #### Chunked access ```python +# `use_collection` will automatically get everything in `X` and `obs` and yield it. 
ds = Loader( batch_size=4096, chunk_size=32, preload_nchunks=256, -).add_anndatas( - [ - ad.AnnData( - # note that you can open an anndata file using any type of zarr store - X=ad.io.sparse_dataset(zarr.open(p)["X"]), - obs=ad.io.read_elem(zarr.open(p)["obs"]), - ) - for p in PATH_TO_STORE.glob("*.zarr") - ] -) +).use_collection(collection) # Iterate over dataloader (plugin replacement for torch.utils.DataLoader) for batch in ds: diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 3f8aef26..b20ccd1b 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "tags": [ "hide-output" @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": { "tags": [ "hide-output" @@ -43,28 +43,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2025-10-09 09:43:19-- https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\n", - "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 18.64.79.73, 18.64.79.80, 18.64.79.109, ...\n", - "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 773247972 (737M) [binary/octet-stream]\n", - "Saving to: ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’\n", - "\n", - "866d7d5e-436b-4dbd- 100%[===================>] 737.43M 398MB/s in 1.9s \n", - "\n", - "2025-10-09 09:43:21 (398 MB/s) - ‘866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad’ saved [773247972/773247972]\n", - "\n", - "--2025-10-09 09:43:22-- https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\n", - "Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 
18.64.79.73, 18.64.79.80, 18.64.79.72, ...\n", - "Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|18.64.79.73|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 1631759823 (1.5G) [binary/octet-stream]\n", - "Saving to: ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’\n", - "\n", - "f81463b8-4986-4904- 100%[===================>] 1.52G 425MB/s in 3.9s \n", - "\n", - "2025-10-09 09:43:26 (403 MB/s) - ‘f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad’ saved [1631759823/1631759823]\n", - "\n" + "zsh:1: command not found: wget\n", + "zsh:1: command not found: wget\n" ] } ], @@ -95,7 +75,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 1, @@ -105,7 +85,6 @@ ], "source": [ "import zarr\n", - "import zarrs # noqa\n", "\n", "zarr.config.set({\"codec_pipeline.path\": \"zarrs.ZarrsCodecPipeline\"})" ] @@ -149,21 +128,48 @@ " * We recommend to choose a dataset size that comfortably fits into system memory.\n", "\n", "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `create_anndata_collection`" + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add_adatas`" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "tags": [ "hide-output" ] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ilangold/Projects/Theis/annbatch/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "checking for mismatched keys: 100%|██████████| 2/2 [00:00<00:00, 2.02it/s]\n", + "loading: 2it [00:00, 2.34it/s]\n", + "processing chunks: 0%| | 0/1 [00:00" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", "import anndata as ad\n", - "from annbatch import create_anndata_collection\n", + "from annbatch import DatasetCollection\n", "\n", "\n", "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", @@ -181,20 +187,21 @@ "\n", " return ad.AnnData(\n", " X=x,\n", - " obs=adata_.obs.to_memory(),\n", + " obs=adata_.obs.to_memory()[\n", + " [\"cell_type\"]\n", + " ], # let's only take cell type since it is shared - otherwise DatasetCollection will warn about all the columns we are missing\n", " var=var.to_memory(),\n", " )\n", "\n", "\n", - "create_anndata_collection(\n", + "collection = DatasetCollection(zarr.open(\"annbatch_collection\", mode=\"w\"))\n", + "collection.add_adatas(\n", " # List all the h5ad files you want to include in the collection\n", " adata_paths=[\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\", \"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\"],\n", " # Path to store the output collection\n", - " output_path=\"annbatch_collection\",\n", " shuffle=True, # Whether to pre-shuffle the cells of the collection\n", - " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard\n", + " n_obs_per_dataset=2_097_152, # Number of cells per dataset shard, this number is much higher than available in these datasets but is generally a good target\n", " var_subset=None, # Optionally subset the collection to a specific gene space\n", - " should_denseify=False,\n", " load_adata=read_lazy_x_and_obs_only,\n", ")" ] @@ -208,18 +215,7 @@ }, { "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "COLLECTION_PATH = Path(\"annbatch_collection/\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "tags": [ "hide-output" @@ -229,10 +225,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -251,17 +247,8 @@ " to_torch=True,\n", ")\n", "\n", - "# Add dataset that should be used for training\n", - "ds.add_anndatas(\n", - " [\n", - " ad.AnnData(\n", - " X=ad.io.sparse_dataset(zarr.open(p)[\"X\"]),\n", - " obs=ad.io.read_elem(zarr.open(p)[\"obs\"]),\n", - " )\n", - " for p in COLLECTION_PATH.glob(\"*.zarr\")\n", - " ],\n", - " obs_keys=\"cell_type\",\n", - ")" + "# Add in the shuffled data that should be used for training\n", + "ds.use_collection(collection)" ] }, { @@ -278,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "tags": [ "hide-output" @@ -289,9 +276,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/171792 [00:00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from annbatch import add_to_collection\n", - "\n", - "\n", - "def read_x_and_obs_only(path) -> ad.AnnData:\n", - " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", - " # As it's only a small dataset, we can load the full dataset into memory to speed up computations\n", - " adata_ = ad.read_h5ad(path) # Replace with ad.experimental.read_lazy if data does not fit into memory anymore\n", - " if adata_.raw is not None:\n", - " x = adata_.raw.X\n", - " var = adata_.raw.var\n", - " else:\n", - " x = adata_.X\n", - " var = adata_.var\n", - "\n", - " return ad.AnnData(X=x, obs=adata_.obs, var=var)\n", - "\n", - "\n", - "add_to_collection(\n", + "collection.add_adatas(\n", " adata_paths=[\n", " 
\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\",\n", " ],\n", - " output_path=\"annbatch_collection\",\n", - " load_adata=read_x_and_obs_only,\n", + " load_adata=read_lazy_x_and_obs_only,\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -393,7 +378,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 9a8e219e..db74bebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ optional-dependencies.doc = [ "myst-nb>=1.1", "pandas", "scanpydoc[theme,typehints]>=0.15.3", - "sphinx>=8.1", + # https://github.com/sphinx-toolbox/sphinx-toolbox/issues/201 + "sphinx>=8.1,<=8.2.3", "sphinx-autodoc-typehints", "sphinx-book-theme>=1", "sphinx-copybutton", diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 306cbb99..d53341c0 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -3,9 +3,9 @@ from importlib.metadata import version from . 
import types -from .io import add_to_collection, create_anndata_collection, write_sharded +from .io import DatasetCollection, write_sharded from .loader import Loader __version__ = version("annbatch") -__all__ = ["Loader", "write_sharded", "add_to_collection", "create_anndata_collection", "types"] +__all__ = ["Loader", "write_sharded", "DatasetCollection", "types"] diff --git a/src/annbatch/io.py b/src/annbatch/io.py index f197cd14..dd23f2f5 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -1,15 +1,17 @@ from __future__ import annotations -import json +import math import random +import re import warnings from collections import defaultdict from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import anndata as ad import dask.array as da +import h5py import numpy as np import pandas as pd import scipy.sparse as sp @@ -19,13 +21,17 @@ from tqdm.auto import tqdm from zarr.codecs import BloscCodec, BloscShuffle +from annbatch.utils import split_given_size + if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping + from collections.abc import Callable, Generator, Iterable, Mapping from os import PathLike from typing import Any, Literal from zarr.abc.codec import BytesBytesCodec +V1_ENCODING = {"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"} + def _round_down(num: int, divisor: int): return num - (num % divisor) @@ -40,6 +46,7 @@ def write_sharded( dense_chunk_size: int = 1024, dense_shard_size: int = 4194304, compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + key: str | None = None, ): """Write a sharded zarr store from a single AnnData object. @@ -59,6 +66,8 @@ def write_sharded( Number of obs elements per dense shard along the first axis compressors The compressors to pass to `zarr`. 
+ key + The key to which this object should be written - by default the root, in which case the *entire* store (not just the group) is cleared first. """ ad.settings.zarr_write_format = 3 @@ -99,11 +108,17 @@ def callback( } write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs) - ad.experimental.write_dispatched(group, "/", adata, callback=callback) + ad.experimental.write_dispatched(group, "/" if key is None else key, adata, callback=callback) zarr.consolidate_metadata(group.store) -def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData] | Iterable[str | ad.AnnData]): +def _check_for_mismatched_keys( + paths_or_anndatas: Iterable[PathLike[str] | ad.AnnData | zarr.Group | h5py.Group] | Iterable[str | ad.AnnData], + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), +): num_raw_in_adata = 0 found_keys: dict[str, defaultdict[str, int]] = { "layers": defaultdict(lambda: 0), @@ -112,13 +127,13 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An } for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"): if not isinstance(path_or_anndata, ad.AnnData): - adata = ad.experimental.read_lazy(path_or_anndata, load_annotation_index=False) + adata = load_adata(path_or_anndata) else: adata = path_or_anndata for elem_name, key_count in found_keys.items(): curr_keys = set(getattr(adata, elem_name).keys()) for key in curr_keys: - if not (elem_name in { "var", "obs" } and key == "_index"): + if not (elem_name in {"var", "obs"} and key == "_index"): key_count[key] += 1 if adata.raw is not None: num_raw_in_adata += 1 @@ -140,10 +155,12 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An def _lazy_load_anndatas( paths: Iterable[PathLike[str]] | Iterable[str], - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy(x, 
load_annotation_index=False), + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), ): adatas = [] - categoricals_in_all_adatas = {} + categoricals_in_all_adatas: dict[str, pd.Index] = {} for i, path in tqdm(enumerate(paths), desc="loading"): adata = load_adata(path) # Track the source file for this given anndata object @@ -153,20 +170,23 @@ def _lazy_load_anndatas( # Concatenating Dataset2D drops categoricals so we need to track them if isinstance(adata.obs, Dataset2D): categorical_cols_in_this_adata = { - col: set(adata.obs[col].dtype.categories) - for col in adata.obs.columns - if adata.obs[col].dtype == "category" + col: adata.obs[col].dtype.categories for col in adata.obs.columns if adata.obs[col].dtype == "category" } if not categoricals_in_all_adatas: categoricals_in_all_adatas = { **categorical_cols_in_this_adata, - "src_path": set(adata.obs["src_path"].dtype.categories), + "src_path": adata.obs["src_path"].dtype.categories, } else: for k in categoricals_in_all_adatas.keys() & categorical_cols_in_this_adata.keys(): - categoricals_in_all_adatas[k] = set(categoricals_in_all_adatas[k]).union( - set(categorical_cols_in_this_adata[k]) + categoricals_in_all_adatas[k] = categoricals_in_all_adatas[k].union( + categorical_cols_in_this_adata[k] ) + # TODO: Probably bug in anndata, need the true index for proper outer joins (can't skirt this with fake indexes, at least not in the mixed-type regime). 
+ if isinstance(adata.var, Dataset2D): + adata.var.index = adata.var.true_index + if adata.raw is not None and isinstance(adata.raw.var, Dataset2D): + adata.raw.var.index = adata.raw.var.true_index adatas.append(adata) if len(adatas) == 1: return adatas[0] @@ -176,17 +196,38 @@ def _lazy_load_anndatas( return adata -def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True): - chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0])) - slices = [ - slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True) - ] +def _create_chunks_for_shuffling( + n_obs: int, + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + *, + shuffle_n_obs_per_dataset: int | None = None, + n_chunkings: int | None = None, +) -> list[np.ndarray]: + # this splits the array up into `shuffle_chunk_size` contiguous runs + idxs = split_given_size(np.arange(n_obs), shuffle_chunk_size) if shuffle: - random.shuffle(slices) - idxs = np.concatenate([np.arange(s.start, s.stop) for s in slices]) - idxs = np.array_split(idxs, np.ceil(len(idxs) / shuffle_n_obs_per_dataset)) - - return idxs + random.shuffle(idxs) + match shuffle_n_obs_per_dataset is not None, n_chunkings is not None: + case True, False: + n_slices_per_dataset = int(shuffle_n_obs_per_dataset // shuffle_chunk_size) + use_single_chunking = n_obs <= shuffle_n_obs_per_dataset or n_slices_per_dataset <= 1 + case False, True: + n_slices_per_dataset = (n_obs // n_chunkings) // shuffle_chunk_size + use_single_chunking = n_chunkings == 1 + case _, _: + raise ValueError("Cannot provide both shuffle_n_obs_per_dataset and n_chunkings or neither") + # In this case `shuffle_n_obs_per_dataset` is bigger than the size of the dataset or the slice size is probably too big. 
+ if use_single_chunking: + return [np.concatenate(idxs)] + # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. + idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) + return [ + np.concatenate([np.arange(s.start, s.stop) for s in idx]) + for idx in ( + split_given_size(idxs, n_slices_per_dataset) if n_chunkings is None else np.array_split(idxs, n_chunkings) + ) + ] def _compute_blockwise(x: DaskArray) -> sp.spmatrix: @@ -210,9 +251,14 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata.X = _compute_blockwise(adata.X) if isinstance(adata.obs, Dataset2D): adata.obs = adata.obs.to_memory() + # TODO: This is a bug in anndata? + if "_index" in adata.obs.columns: + del adata.obs["_index"] adata = _to_categorical_obs(adata) if isinstance(adata.var, Dataset2D): adata.var = adata.var.to_memory() + if "_index" in adata.var.columns: + del adata.var["_index"] if adata.raw is not None: adata_raw = adata.raw.to_adata() @@ -220,19 +266,28 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData: adata_raw.X = _compute_blockwise(adata_raw.X) if isinstance(adata_raw.var, Dataset2D): adata_raw.var = adata_raw.var.to_memory() + if "_index" in adata_raw.var.columns: + del adata_raw.var["_index"] if isinstance(adata_raw.obs, Dataset2D): adata_raw.obs = adata_raw.obs.to_memory() del adata.raw adata.raw = adata_raw - for k, elem in adata.obsm.items(): - # TODO: handle `Dataset2D` in `obsm` and `varm` that are - if isinstance(elem, DaskArray): - adata.obsm[k] = _compute_blockwise(elem) - - for k, elem in adata.layers.items(): - if isinstance(elem, DaskArray): - adata.obsm[k] = _compute_blockwise(elem) + for axis_name in ["layers", "obsm", "varm", "obsp", "varp"]: + for k, elem in getattr(adata, axis_name).items(): + # TODO: handle `Dataset2D` in `obsm` and `varm` that are + if isinstance(elem, DaskArray): + getattr(adata, axis_name)[k] = _compute_blockwise(elem) 
+ if isinstance(elem, Dataset2D): + elem = elem.to_memory() + if "_index" in elem.columns: + del elem["_index"] + # TODO: Bug in anndata + if "obs" in axis_name: + elem.index = adata.obs_names + getattr(adata, axis_name)[k] = elem + + return adata.to_memory() return adata @@ -249,259 +304,351 @@ def wrapper(*args, **kwargs): return wrapper -@_with_settings -def create_anndata_collection( - adata_paths: Iterable[PathLike[str]] | Iterable[str], - output_path: PathLike[str] | str, - *, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy, - var_subset: Iterable[str] | None = None, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", - n_obs_per_dataset: int = 2_097_152, - shuffle: bool = True, - should_denseify: bool = False, - output_format: Literal["h5ad", "zarr"] = "zarr", -): - """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per store. +class DatasetCollection[T: (h5py.Group, zarr.Group)]: + """A preshuffled collection object including functionality for creating, adding to, and loading collections shuffled by `annbatch`.""" - The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. - The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. - However, this function can also output h5 datasets and also unshuffled datasets as well. - The var space is by default outer-joined, but can be subsetted by `var_subset`. - A key `src_path` is added to `obs` to indicate where individual row came from. 
- We highly recommend making your indexes unique across files, and this function will call {meth}`AnnData.obs_names_make_unique`. - Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. + _group: T - Parameters - ---------- - adata_paths - Paths to the AnnData files used to create the zarr store. - output_path - Path to the output zarr store. - load_adata - Function to customize lazy-loading the invidiual input anndata files. By default, {func}`anndata.experimental.read_lazy` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a {class}`dask.array.Array`. - var_subset - Subset of gene names to include in the store. If None, all genes are included. - Genes are subset based on the `var_names` attribute of the concatenated AnnData object. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - h5ad_compressor - Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. - n_obs_per_dataset - Number of observations to load into memory at once for shuffling / pre-processing. - The higher this number, the more memory is used, but the better the shuffling. - This corresponds to the size of the shards created. 
- shuffle - Whether to shuffle the data before writing it to the store. - should_denseify - Whether to write as dense on disk. There's no need to set this for sparse data, it is only for testing. - output_format - Format of the output store. Can be either "zarr" or "h5ad". - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import create_anndata_collection - # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store - >>> def read_lazy_x_and_obs_only(path): - ... adata = ad.experimental.read_lazy(path) - ... return ad.AnnData( - ... X=adata.X, - ... obs=adata.obs.to_memory(), - ... var=adata.var.to_memory(), - ...) - - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> create_anndata_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... load_adata=read_lazy_x_and_obs_only, - ...) - """ - Path(output_path).mkdir(parents=True, exist_ok=True) - _check_for_mismatched_keys(adata_paths) - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - adata_concat.obs_names_make_unique() - chunks = _create_chunks_for_shuffling(adata_concat, n_obs_per_dataset, shuffle=shuffle) - - if var_subset is None: - var_subset = adata_concat.var_names - - for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): - var_mask = adata_concat.var_names.isin(var_subset) - # np.sort: It's more efficient to access elements sequentially from dask arrays - # The data will be shuffled later on, we just want the elements at this point - adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() - adata_chunk = _persist_adata_in_memory(adata_chunk) - if shuffle: - # shuffle adata in memory to break up individual chunks - idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) - adata_chunk = adata_chunk[idxs] - # convert to dense format before writing to disk - if should_denseify: - # Need to convert back to 
dask array to avoid memory issues when converting large sparse matrices to dense - adata_chunk = adata_chunk.copy() - adata_chunk.X = da.from_array( - adata_chunk.X, chunks=(zarr_dense_chunk_size, -1), meta=adata_chunk.X - ).map_blocks(lambda xx: xx.toarray(), dtype=adata_chunk.X.dtype) - - if output_format == "zarr": - f = zarr.open_group(Path(output_path) / f"{DATASET_PREFIX}_{i}.zarr", mode="w") - write_sharded( - f, - adata_chunk, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, - compressors=zarr_compressor, - ) - elif output_format == "h5ad": - adata_chunk.write_h5ad(Path(output_path) / f"{DATASET_PREFIX}_{i}.h5ad", compression=h5ad_compressor) - else: - raise ValueError(f"Unrecognized output_format: {output_format}. Only 'zarr' and 'h5ad' are supported.") + def __init__(self, group: zarr.Group | h5py.Group | str | Path, *, mode: Literal["a", "r", "r+"] = "a"): + """Initialization of the object at a given location. + Note that if the group is a h5py/zarr object, it must have the correct permissions for any subsequent operations you plan to do. + Otherwise, the store will be opened according to the mode argument. 
-def _get_array_encoding_type(path: PathLike[str] | str) -> str: - shards = list(Path(path).glob(f"{DATASET_PREFIX}_*.zarr")) - with open(shards[0] / "X" / "zarr.json") as f: - encoding = json.load(f) - return encoding["attributes"]["encoding-type"] + Parameters + ---------- + group + The base location for a preshuffled collection + """ + if not isinstance(group, h5py.Group | zarr.Group): + if isinstance(group, str | Path): + if str(group).endswith("h5ad"): + self._group = h5py.File(group, mode=mode) + elif str(group).endswith("zarr"): + self._group = zarr.open_group(group, mode=mode) + else: + raise ValueError("String argument must end in h5ad or zarr") + else: + raise TypeError("group must be a zarr or hdf5 group") + else: + self._group = group -@_with_settings -def add_to_collection( - adata_paths: Iterable[PathLike[str]] | Iterable[str], - output_path: PathLike[str] | str, - load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, - zarr_sparse_chunk_size: int = 32768, - zarr_sparse_shard_size: int = 134_217_728, - zarr_dense_chunk_size: int = 1024, - zarr_dense_shard_size: int = 4_194_304, - zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), - should_sparsify_output_in_memory: bool = False, -) -> None: - """Add anndata files to an existing collection of sharded anndata zarr datasets. + @property + def _dataset_keys(self) -> list[str]: + return sorted( + [k for k in self._group.keys() if re.match(rf"{DATASET_PREFIX}_([0-9]*)", k) is not None], + key=lambda x: int(x.split("_")[1]), + ) - The var space of the source anndata files will be adapted to the target store. + def __iter__(self) -> Generator[T]: + for k in self._dataset_keys: + yield self._group[k] - Parameters - ---------- - adata_paths - Paths to the anndata files to be appended to the collection of output chunks. - output_path - Path to the output zarr store. 
- load_adata - Function to customize loading the invidiual input anndata files. By default, {func}`anndata.read_h5ad` is used. - If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. - The input to the function is a path to an anndata file, and the output is an anndata object. - If the input data is too large to fit into memory, you should use `ad.experimental.read_lazy` instead. - zarr_sparse_chunk_size - Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_sparse_shard_size - Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. - zarr_dense_chunk_size - Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. - zarr_dense_shard_size - Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. - zarr_compressor - Compressors to use to compress the data in the zarr store. - should_sparsify_output_in_memory - This option is for testing only appending sparse files to dense stores. - To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. - - Examples - -------- - >>> import anndata as ad - >>> from annbatch import add_to_collection - >>> datasets = [ - ... "path/to/first_adata.h5ad", - ... "path/to/second_adata.h5ad", - ... "path/to/third_adata.h5ad", - ... ] - >>> add_to_collection( - ... datasets, - ... "path/to/output/zarr_store", - ... load_adata=ad.read_h5ad, # replace with ad.experimental.read_lazy if data does not fit into memory - ...) - """ - shards = list(Path(output_path).glob(f"{DATASET_PREFIX}_*.zarr")) - if len(shards) == 0: - raise ValueError( - "Store at `output_path` does not exist or is empty. Please run `create_anndata_collection` first." 
- ) - encoding = _get_array_encoding_type(output_path) - if encoding == "array": - print("Detected array encoding type. Will convert to dense format before writing.") - # Check for mismatched keys among the inputs. - _check_for_mismatched_keys(adata_paths) - - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) - # Check for mismatched keys between shards and the inputs. - _check_for_mismatched_keys([adata_concat] + shards) - if isinstance(adata_concat.X, DaskArray): - chunks = _create_chunks_for_shuffling(adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True) - else: - chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards)) - - adata_concat.obs_names_make_unique() - if encoding == "array": - if not should_sparsify_output_in_memory: - if isinstance(adata_concat.X, sp.spmatrix): - adata_concat.X = adata_concat.X.toarray() - elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, sp.spmatrix): - adata_concat.X = adata_concat.X.map_blocks( - lambda x: x.toarray(), meta=np.ndarray, dtype=adata_concat.X.dtype - ) - elif encoding == "csr_matrix": - if isinstance(adata_concat.X, np.ndarray): - adata_concat.X = sp.csr_matrix(adata_concat.X) - elif isinstance(adata_concat.X, DaskArray) and isinstance(adata_concat.X._meta, np.ndarray): - adata_concat.X = adata_concat.X.map_blocks( - sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype)) - ) + @property + def is_empty(self) -> bool: + """Wether or not there is an existing store at the group location.""" + return not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0 - for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards), desc="processing chunks"): - if should_sparsify_output_in_memory and encoding == "array": - adata_shard = _lazy_load_anndatas([shard]) - adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute() + @_with_settings + def add_adatas( 
+ self, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + *, + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), + var_subset: Iterable[str] | None = None, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + ) -> Self: + """Take AnnData paths and create or add to an on-disk set of AnnData datasets with uniform var spaces at the desired path (with `n_obs_per_dataset` rows per dataset if running for the first time). + + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. + The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets and also unshuffled datasets as well. + The var space is by default outer-joined initially, and then subsequently added datasets (i.e., on second calls to this function) are subsetted, but this behavior can be controlled by `var_subset`. + A key `src_path` is added to `obs` to indicate where individual row came from. + We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` + `shuffle_chunk_size` as so many rows will be read into memory before writing to disk. + After the dataset completes, a marker is added to the group's `attrs` to note that this dataset has been shuffled by `annbatch`. + This is not a stable API but only for internal purposes at the moment. 
+ + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + load_adata + Function to customize lazy-loading the individual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an :class:`anndata.AnnData` object. + var_subset + Subset of gene names to include in the store. If None, all genes are included. + Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., chunking is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., sharding is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + Only applicable when adding datasets for the first time, otherwise ignored. + shuffle + Whether to shuffle the data before writing it to the store. + When the store is non-empty, controls whether incoming rows are shuffled into the existing datasets or just split up and appended. + shuffle_chunk_size + How many contiguous rows to load into memory before shuffling at once.
+ `(shuffle_chunk_size // n_obs_per_dataset)` slices will be loaded of size `shuffle_chunk_size`. + + Examples + -------- + >>> import anndata as ad + >>> from annbatch import DatasetCollection + # create a custom load function to only keep `.X`, `.obs` and `.var` in the output store + >>> def read_lazy_x_and_obs_only(path): + ... adata = ad.experimental.read_lazy(path) + ... return ad.AnnData( + ... X=adata.X, + ... obs=adata.obs.to_memory(), + ... var=adata.var.to_memory(), + ...) + >>> datasets = [ + ... "path/to/first_adata.h5ad", + ... "path/to/second_adata.h5ad", + ... "path/to/third_adata.h5ad", + ... ] + >>> DatasetCollection("path/to/output/zarr_store.zarr").add_adatas( + ... datasets, + ... load_adata=read_lazy_x_and_obs_only, + ...) + """ + if shuffle_chunk_size > n_obs_per_dataset: + raise ValueError("Cannot have a large slice size than observations per dataset") + shared_kwargs = { + "adata_paths": adata_paths, + "load_adata": load_adata, + "zarr_sparse_chunk_size": zarr_sparse_chunk_size, + "zarr_sparse_shard_size": zarr_sparse_shard_size, + "zarr_dense_chunk_size": zarr_dense_chunk_size, + "zarr_dense_shard_size": zarr_dense_shard_size, + "zarr_compressor": zarr_compressor, + "h5ad_compressor": h5ad_compressor, + "shuffle_chunk_size": shuffle_chunk_size, + "shuffle": shuffle, + } + if self.is_empty: + self._create_collection(**shared_kwargs, n_obs_per_dataset=n_obs_per_dataset, var_subset=var_subset) else: - adata_shard = ad.read_zarr(shard) - subset_adata = _to_categorical_obs( - adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)] + self._add_to_collection(**shared_kwargs) + return self + + def _create_collection( + self, + *, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = lambda x: ad.experimental.read_lazy( + x, load_annotation_index=False + ), + var_subset: Iterable[str] | None = None, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int 
= 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + n_obs_per_dataset: int = 2_097_152, + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + ) -> None: + """Take AnnData paths, create an on-disk set of AnnData datasets with uniform var spaces at the desired path with `n_obs_per_dataset` rows per dataset. + + The set of AnnData datasets is collectively referred to as a "collection" where each dataset is called `dataset_i.{zarr,h5ad}`. + The main purpose of this function is to create shuffled sharded zarr datasets, which is the default behavior of this function. + However, this function can also output h5 datasets and also unshuffled datasets as well. + The var space is by default outer-joined, but can be subsetted by `var_subset`. + A key `src_path` is added to `obs` to indicate where individual row came from. + We highly recommend making your indexes unique across files, and this function will call `AnnData.obs_names_make_unique`. + Memory usage should be controlled by `n_obs_per_dataset` as so many rows will be read into memory before writing to disk. + + Parameters + ---------- + adata_paths + Paths to the AnnData files used to create the zarr store. + load_adata + Function to customize lazy-loading the invidiual input anndata files. By default, :func:`anndata.experimental.read_lazy` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object which has `X` as a :class:`dask.array.Array`. + var_subset + Subset of gene names to include in the store. If None, all genes are included. 
+ Genes are subset based on the `var_names` attribute of the concatenated AnnData object. + Only applicable when adding datasets for the first time, otherwise ignored and the incoming data's var space is subsetted to that of the existing collection. + zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., chunking is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., sharding is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + h5ad_compressor + Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad. + n_obs_per_dataset + Number of observations to load into memory at once for shuffling / pre-processing. + The higher this number, the more memory is used, but the better the shuffling. + This corresponds to the size of the shards created. + Only applicable when adding datasets for the first time, otherwise ignored. + shuffle + Whether to shuffle the data before writing it to the store. + shuffle_chunk_size + How many contiguous rows to load into memory before shuffling at once. + `(n_obs_per_dataset // shuffle_chunk_size)` slices of size `shuffle_chunk_size` will be loaded.
+ """ + if not self.is_empty: + raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") + _check_for_mismatched_keys(adata_paths, load_adata=load_adata) + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + adata_concat.obs_names_make_unique() + n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset) + chunks = _create_chunks_for_shuffling( + adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, shuffle_n_obs_per_dataset=n_obs_per_dataset ) - adata = ad.concat([adata_shard, subset_adata], join="outer") - idxs_shuffled = np.random.default_rng().permutation(len(adata)) - adata = adata[idxs_shuffled, :].copy() # this significantly speeds up writing to disk - if should_sparsify_output_in_memory and encoding == "array": - adata.X = adata.X.map_blocks(lambda x: x.toarray(), meta=np.array([0], dtype=adata.X.dtype)).compute() - - f = zarr.open_group(shard, mode="w") - write_sharded( - f, - adata, - sparse_chunk_size=zarr_sparse_chunk_size, - sparse_shard_size=zarr_sparse_shard_size, - dense_chunk_size=zarr_dense_chunk_size, - dense_shard_size=zarr_dense_shard_size, - compressors=zarr_compressor, + + if var_subset is None: + var_subset = adata_concat.var_names + for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): + var_mask = adata_concat.var_names.isin(var_subset) + # np.sort: It's more efficient to access elements sequentially from dask arrays + # The data will be shuffled later on, we just want the elements at this point + adata_chunk = adata_concat[np.sort(chunk), :][:, var_mask].copy() + adata_chunk = _persist_adata_in_memory(adata_chunk) + if shuffle: + # shuffle adata in memory to break up individual chunks + idxs = np.random.default_rng().permutation(np.arange(len(adata_chunk))) + adata_chunk = adata_chunk[idxs] + if isinstance(self._group, zarr.Group): + write_sharded( + self._group, + adata_chunk, + sparse_chunk_size=zarr_sparse_chunk_size, + 
sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=min(adata_chunk.shape[0], zarr_dense_chunk_size), + dense_shard_size=min(adata_chunk.shape[0], zarr_dense_shard_size), + compressors=zarr_compressor, + key=f"{DATASET_PREFIX}_{i}", + ) + else: + ad.io.write_elem( + self._group, f"{DATASET_PREFIX}_{i}", adata_chunk, dataset_kwargs={"compression": h5ad_compressor} + ) + if isinstance(self._group, zarr.Group): + self._group.update_attributes(V1_ENCODING) + else: + self._group.attrs.update(V1_ENCODING) + + def _add_to_collection( + self, + *, + adata_paths: Iterable[PathLike[str]] | Iterable[str], + load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad, + zarr_sparse_chunk_size: int = 32768, + zarr_sparse_shard_size: int = 134_217_728, + zarr_dense_chunk_size: int = 1024, + zarr_dense_shard_size: int = 4_194_304, + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), + h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", + shuffle_chunk_size: int = 1000, + shuffle: bool = True, + ) -> None: + """Add anndata files to an existing collection of sharded anndata zarr datasets. + + The var space of the source anndata files will be adapted to the target store. + + Parameters + ---------- + adata_paths + Paths to the anndata files to be appended to the collection of output chunks. + load_adata + Function to customize loading the invidiual input anndata files. By default, :func:`anndata.read_h5ad` is used. + If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data. + The input to the function is a path to an anndata file, and the output is an anndata object. + If the input data is too large to fit into memory, you should use :func:`annndata.experimental.read_lazy` instead. 
+ zarr_sparse_chunk_size + Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_sparse_shard_size + Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store. + zarr_dense_chunk_size + Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array. + zarr_dense_shard_size + Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array. + zarr_compressor + Compressors to use to compress the data in the zarr store. + should_sparsify_output_in_memory + This option is for testing only appending sparse files to dense stores. + To save memory, the blocks of a dense on-disk store can be sparsified for in-memory processing. + shuffle_chunk_size + How many contiguous rows to load into memory of the input data for pseudo-blockshuffling into the existing datasets. + shuffle + Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. + """ + if self.is_empty: + raise ValueError("Store is empty. Please run `DatasetCollection.add` first.") + # Check for mismatched keys among the inputs. + _check_for_mismatched_keys(adata_paths, load_adata=load_adata) + + adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys): + raise ValueError( + f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores." + "Open an issue if the incoming anndata is so small it cannot be distributed across the on-disk data" + ) + # Check for mismatched keys between datasets and the inputs. 
+ _check_for_mismatched_keys([adata_concat] + [self._group[k] for k in self._dataset_keys]) + chunks = _create_chunks_for_shuffling( + adata_concat.shape[0], shuffle_chunk_size, shuffle=shuffle, n_chunkings=len(self._dataset_keys) ) + + adata_concat.obs_names_make_unique() + for dataset, chunk in tqdm( + zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="processing chunks" + ): + adata_dataset = ad.io.read_elem(self._group[dataset]) + subset_adata = _to_categorical_obs( + adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_dataset.var.index)] + ) + adata = ad.concat([adata_dataset, subset_adata], join="outer") + if shuffle: + idxs = np.random.default_rng().permutation(adata.shape[0]) + else: + idxs = np.arange(adata.shape[0]) + adata = _persist_adata_in_memory(adata[idxs, :].copy()) + if isinstance(self._group, zarr.Group): + write_sharded( + self._group, + adata, + sparse_chunk_size=zarr_sparse_chunk_size, + sparse_shard_size=zarr_sparse_shard_size, + dense_chunk_size=min(adata.shape[0], zarr_dense_chunk_size), + dense_shard_size=min(adata.shape[0], zarr_dense_shard_size), + compressors=zarr_compressor, + key=dataset, + ) + else: + ad.io.write_elem(self._group, dataset, adata, dataset_kwargs={"compression": h5ad_compressor}) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 2c93e20e..d3014421 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -24,6 +24,7 @@ _batched, check_lt_1, check_var_shapes, + load_x_and_obs, split_given_size, to_torch, ) @@ -31,9 +32,11 @@ from .compat import IterableDataset if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from types import ModuleType + from annbatch.io import DatasetCollection + # TODO: remove after sphinx 9 - myst compat BackingArray = BackingArray_T OutputInMemoryArray = OutputInMemoryArray_T @@ -222,6 +225,32 @@ def n_var(self) -> int: """ return self._shapes[0][1] + def use_collection( + 
self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs + ) -> Self: + """Load from an existing :class:`annbatch.DatasetCollection`. + + This function can only be called once. If you want to manually add more data, use :meth:`Loader.add_anndatas` or open an issue. + + Parameters + ---------- + collection + The collection who on-disk datasets should be used in this loader. + load_adata + A custom load function - recall that whatever is found in :attr:`~anndata.AnnData.X` and :attr:`~anndata.AnnData.obs` will be yielded in batches. + Default is to just load `X` and `obs`. + """ + if collection.is_empty: + raise ValueError("DatasetCollection is empty") + if getattr(self, "_collection_added", False): + raise RuntimeError( + "You should not add multiple collections, independently shuffled - please preshuffle multiple collections, use `add_anndatas` manually if you know what you are doing, or open an issue if you believe that this should be supported at an API level higher than `add_anndatas`." + ) + adatas = [load_adata(g) for g in collection] + self.add_anndatas(adatas) + self._collection_added = True + return self + def add_anndatas( self, adatas: list[ad.AnnData], diff --git a/src/annbatch/utils.py b/src/annbatch/utils.py index a0a827ba..84d2bc47 100644 --- a/src/annbatch/utils.py +++ b/src/annbatch/utils.py @@ -7,6 +7,7 @@ from itertools import islice from typing import TYPE_CHECKING, Protocol +import anndata as ad import numpy as np import scipy as sp import zarr @@ -65,50 +66,6 @@ def __iter__(self): total += gap -def sample_rows( - x_list: list[np.ndarray], - obs_list: list[np.ndarray] | None, - indices: list[np.ndarray] | None = None, - *, - shuffle: bool = True, -) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]: - """Samples rows from multiple arrays and their corresponding observation arrays. - - Parameters - ---------- - x_list - A list of numpy arrays containing the data to sample from. 
- obs_list - A list of numpy arrays containing the corresponding observations. - indices - the list of indexes for each element in `x_list/` - shuffle - Whether to shuffle the rows before sampling. - - Yields - ------ - tuple - A tuple containing a row from `x_list` and the corresponding row from `obs_list`. - """ - lengths = np.fromiter((x.shape[0] for x in x_list), dtype=int) - cum = np.concatenate(([0], np.cumsum(lengths))) - total = cum[-1] - idxs = np.arange(total) - if shuffle: - np.random.default_rng().shuffle(idxs) - arr_idxs = np.searchsorted(cum, idxs, side="right") - 1 - row_idxs = idxs - cum[arr_idxs] - for ai, ri in zip(arr_idxs, row_idxs, strict=True): - res = [ - x_list[ai][ri], - obs_list[ai][ri] if obs_list is not None else None, - ] - if indices is not None: - yield (*res, indices[ai][ri]) - else: - yield tuple(res) - - class WorkerHandle: # noqa: D101 @cached_property def _worker_info(self): @@ -234,3 +191,10 @@ def to_torch(input: OutputInMemoryArray_T, preload_to_gpu: bool) -> Tensor: input.shape, ) raise TypeError(f"Cannot convert {type(input)} to torch.Tensor") + + +def load_x_and_obs(g: zarr.Group) -> ad.AnnData: + """Load X as a sparse array or dense zarr array and obs from a group""" + return ad.AnnData( + X=g["X"] if isinstance(g["X"], zarr.Array) else ad.io.sparse_dataset(g["X"]), obs=ad.io.read_elem(g["obs"]) + ) diff --git a/tests/conftest.py b/tests/conftest.py index aa012e34..c34d9a73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from scipy.sparse import random as sparse_random from annbatch import write_sharded +from annbatch.io import DatasetCollection if TYPE_CHECKING: from collections.abc import Generator @@ -77,6 +78,8 @@ def adata_with_h5_path_different_var_space( n_cells = [random.randint(50, 100) for _ in range(n_adatas)] adatas = [] for i, (m, n) in enumerate(zip(n_cells, n_features, strict=True)): + var_idx = [f"gene_{gene}" for gene in range(n // 2)] + [f"gene_{gene}_{i}" for gene in range(n // 
2, n)] + obs_idx = np.arange(m).astype(str) + f"-{i}" adata = ad.AnnData( X=sparse_random(m, n, density=0.1, format="csr", dtype="f4"), obs=pd.DataFrame( @@ -85,12 +88,11 @@ def adata_with_h5_path_different_var_space( "store_id": [i] * m, "numeric": np.arange(m), }, - index=np.arange(m).astype(str), + index=obs_idx, ), - var=pd.DataFrame( - index=[f"gene_{gene}" for gene in range(n // 2)] + [f"gene_{gene}_{i}" for gene in range(n // 2, n)] - ), - obsm={"arr": np.random.randn(m, 10)}, + var=pd.DataFrame(index=var_idx), + obsm={"arr": np.random.randn(m, 10), "df": pd.DataFrame({"numeric": np.arange(m)}, index=obs_idx)}, + varm={"arr": np.random.randn(n, 10), "df": pd.DataFrame({"numeric": np.arange(n)}, index=var_idx)}, ) if all_adatas_have_raw or (i % 2 == 0): adata_raw = adata[:, adata.var.index[: (n // 2)]].copy() @@ -102,3 +104,21 @@ def adata_with_h5_path_different_var_space( [ad.read_h5ad(tmp_path / shard) for shard in sorted(tmp_path.iterdir()) if str(shard).endswith(".h5ad")], join="outer", ), tmp_path + + +@pytest.fixture(scope="session") +def simple_collection( + tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] +) -> tuple[DatasetCollection, ad.AnnData]: + zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) + output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "simple_fixture.zarr" + collection = DatasetCollection(output_path).add_adatas( + zarr_stores, + zarr_sparse_chunk_size=10, + zarr_sparse_shard_size=20, + zarr_dense_chunk_size=10, + zarr_dense_shard_size=20, + n_obs_per_dataset=60, + shuffle_chunk_size=10, + ) + return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9687ae58..e30a38cf 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -25,6 +25,8 @@ from collections.abc import Callable from pathlib import Path + from annbatch.io import DatasetCollection + class 
Data(TypedDict): dataset: ad.abc.CSRDataset | zarr.Array @@ -36,31 +38,34 @@ class ListData: obs: list[np.ndarray] -def open_sparse(path: Path, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: +def open_sparse(path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: old_pipeline = zarr.config.get("codec_pipeline.path") with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + if not isinstance(path, zarr.Group): + path = zarr.open(path) data = { - "dataset": ad.io.sparse_dataset(zarr.open(path)["layers"]["sparse"]), - "obs": ad.io.read_elem(zarr.open(path)["obs"]), + "dataset": ad.io.sparse_dataset(path["layers"]["sparse"]), + "obs": ad.io.read_elem(path["obs"]), } if use_anndata: return ad.AnnData(X=data["dataset"], obs=data["obs"]) return data -def open_dense(path: Path, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: +def open_dense(path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: old_pipeline = zarr.config.get("codec_pipeline.path") with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + if not isinstance(path, zarr.Group): + path = zarr.open(path) data = { - "dataset": zarr.open(path)["X"], - "obs": ad.io.read_elem(zarr.open(path)["obs"]), + "dataset": path["X"], + "obs": ad.io.read_elem(path["obs"]), } if use_anndata: return ad.AnnData(X=data["dataset"], obs=data["obs"]) return data - return data def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: @@ -79,7 +84,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: "gen_loader", [ pytest.param( - lambda path, + lambda collection, shuffle, use_zarrs, chunk_size=chunk_size, @@ -94,10 +99,15 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: batch_size=batch_size, 
preload_to_gpu=preload_to_gpu, to_torch=False, - ).add_anndatas( - [open_func(p, use_zarrs=use_zarrs, use_anndata=True) for p in path.glob("*.zarr")], + ).use_collection( + collection, + **( + {"load_adata": lambda group: open_func(group, use_zarrs=use_zarrs, use_anndata=True)} + if open_func is not None + else {} + ), ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-dataset_type={open_func.__name__[5:]}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] marks=pytest.mark.skipif( find_spec("cupy") is None and preload_to_gpu, reason="need cupy installed", @@ -106,7 +116,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem for preload_to_gpu in [True, False] - for open_func in [open_sparse, open_dense] + for open_func in [open_sparse, open_dense, None] for elem in [ [ 1, @@ -148,7 +158,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ], ) def test_store_load_dataset( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], *, shuffle: bool, gen_loader, use_zarrs + simple_collection: tuple[ad.AnnData, DatasetCollection], *, shuffle: bool, gen_loader, use_zarrs ): """ This test verifies that the DaskDataset works correctly: @@ -157,8 +167,8 @@ def test_store_load_dataset( 3. All samples from the dataset are processed 4. 
If the dataset is not shuffled, it returns the correct data """ - loader: Loader = gen_loader(adata_with_zarr_path_same_var_space[1], shuffle, use_zarrs) - adata = adata_with_zarr_path_same_var_space[0] + loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs) + adata = simple_collection[0] is_dense = loader.dataset_type is zarr.Array n_elems = 0 batches = [] @@ -221,6 +231,13 @@ def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, ds.add_dataset(**data) +def test_use_collection_twice(simple_collection: tuple[ad.AnnData, DatasetCollection]): + ds = Loader() + ds = ds.use_collection(simple_collection[1]) + with pytest.raises(RuntimeError, match="You should not add multiple collections"): + ds.use_collection(simple_collection[1]) + + @pytest.mark.skipif(not find_spec("torch"), reason="need torch installed") @pytest.mark.parametrize( "preload_to_gpu", diff --git a/tests/test_store_creation.py b/tests/test_preshuffle.py similarity index 68% rename from tests/test_store_creation.py rename to tests/test_preshuffle.py index 5e9f099c..03faddb4 100644 --- a/tests/test_store_creation.py +++ b/tests/test_preshuffle.py @@ -4,13 +4,15 @@ from typing import TYPE_CHECKING, Literal import anndata as ad +import h5py import numpy as np import pandas as pd import pytest import scipy.sparse as sp import zarr -from annbatch import add_to_collection, create_anndata_collection, write_sharded +from annbatch import DatasetCollection, write_sharded +from annbatch.io import V1_ENCODING if TYPE_CHECKING: from collections.abc import Callable @@ -37,7 +39,7 @@ def test_write_sharded_shard_size_too_big(tmp_path: Path, chunk_size: int, expec @pytest.mark.parametrize("elem_name", ["obsm", "layers", "raw", "obs"]) -def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "layers", "raw"], tmp_path: Path): +def test_store_creation_warnings_with_different_keys(elem_name: Literal["obsm", "layers", "raw"], tmp_path: Path): adata_1 = 
ad.AnnData(X=np.random.randn(10, 20)) extra_args = { elem_name: {"arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else ["a"] * 10} @@ -48,17 +50,37 @@ def test_store_creation_warngs_with_different_keys(elem_name: Literal["obsm", "l adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - create_anndata_collection( + DatasetCollection(tmp_path / "collection.zarr").add_adatas( [path_1, path_2], - tmp_path / "collection", zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, + shuffle_chunk_size=10, ) +def test_store_creation_no_warnings_with_custom_load(tmp_path: Path): + adata_1 = ad.AnnData(X=np.random.randn(10, 20)) + adata_2 = ad.AnnData(X=np.random.randn(10, 20), layers={"arr": np.random.randn(10, 20)}) + path_1 = tmp_path / "just_x.h5ad" + path_2 = tmp_path / "with_extra_key.h5ad" + adata_1.write_h5ad(path_1) + adata_2.write_h5ad(path_2) + collection = DatasetCollection(tmp_path / "collection.zarr").add_adatas( + [path_1, path_2], + zarr_sparse_chunk_size=10, + zarr_sparse_shard_size=20, + zarr_dense_chunk_size=5, + zarr_dense_shard_size=10, + n_obs_per_dataset=10, + shuffle_chunk_size=5, + load_adata=lambda x: ad.AnnData(X=ad.io.read_elem(h5py.File(x)["X"])), + ) + assert len(ad.read_zarr(next(iter(collection))).layers.keys()) == 0 + + def test_store_creation_path_added_to_obs(tmp_path: Path): adata_1 = ad.AnnData(X=np.random.randn(10, 20)) adata_2 = adata_1.copy() @@ -67,18 +89,18 @@ def test_store_creation_path_added_to_obs(tmp_path: Path): adata_1.write_h5ad(path_1) adata_2.write_h5ad(path_2) paths = [path_1, path_2] - output_dir = tmp_path / "path_src_collection" - create_anndata_collection( + output_dir = tmp_path / "path_src_collection.zarr" + collection = DatasetCollection(output_dir).add_adatas( paths, - output_dir, zarr_sparse_chunk_size=10, 
zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, n_obs_per_dataset=10, + shuffle_chunk_size=5, shuffle=False, ) - adata_result = ad.concat([ad.read_zarr(path) for path in sorted((output_dir).iterdir())], join="outer") + adata_result = ad.concat([ad.io.read_elem(g) for g in collection], join="outer") pd.testing.assert_extension_array_equal( adata_result.obs["src_path"].array, pd.Categorical(([str(path_1)] * 10) + ([str(path_2)] * 10), categories=[str(p) for p in paths]), @@ -95,16 +117,16 @@ def test_store_addition_different_keys( adata_orig = ad.AnnData(X=np.random.randn(100, 20)) orig_path = tmp_path / "orig.h5ad" adata_orig.write_h5ad(orig_path) - output_path = tmp_path / "zarr_store_addition_different_keys" - output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( + output_path = tmp_path / "zarr_store_addition_different_keys.zarr" + collection = DatasetCollection(output_path) + collection.add_adatas( [orig_path], - output_path, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=50, + shuffle_chunk_size=10, ) extra_args = { elem_name: {"arr" if elem_name != "raw" else "X": np.random.randn(10, 20) if elem_name != "obs" else ["a"] * 10} @@ -113,105 +135,67 @@ def test_store_addition_different_keys( additional_path = tmp_path / "with_extra_key.h5ad" adata.write_h5ad(additional_path) with pytest.warns(UserWarning, match=rf"Found {elem_name} keys.* not present in all anndatas"): - add_to_collection( + collection.add_adatas( [additional_path], - output_path, load_adata=load_adata, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, + shuffle_chunk_size=2, ) -def _read_lazy_x_and_obs_only(path) -> ad.AnnData: - adata_ = ad.experimental.read_lazy(path) - if adata_.raw is not None: - x = adata_.raw.X - var = adata_.raw.var - else: - x = adata_.X - var = adata_.var - - return ad.AnnData( - X=x, - 
obs=adata_.obs.to_memory(), - var=var.to_memory(), - ) - - +@pytest.mark.parametrize("open_store", [h5py.File, zarr.open_group]) def test_store_creation_default( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], + open_store: Callable[[Path], zarr.Group | h5py.Group], ): - var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_default" - output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( - [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - output_path, - var_subset=var_subset, - zarr_sparse_chunk_size=10, - zarr_sparse_shard_size=20, - zarr_dense_chunk_size=10, - zarr_dense_shard_size=20, - n_obs_per_dataset=60, + output_path = ( + adata_with_h5_path_different_var_space[1].parent + / f"zarr_store_creation_test_default.{'h5ad' if open_store is h5py.File else 'zarr'}" ) - assert isinstance(ad.read_zarr(next((output_path).iterdir())).X, sp.csr_matrix) - assert sorted(glob.glob(str(output_path / "dataset_*.zarr"))) == sorted(str(p) for p in (output_path).iterdir()) - - -def test_store_creation_drop_elem( - adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], -): - var_subset = [f"gene_{i}" for i in range(100)] - h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_drop_elems" - output_path.mkdir(parents=True, exist_ok=True) - - create_anndata_collection( + store = open_store(output_path, mode="w") + collection = DatasetCollection(store).add_adatas( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - output_path, - var_subset=var_subset, - zarr_sparse_chunk_size=10, - zarr_sparse_shard_size=20, - zarr_dense_chunk_size=10, - zarr_dense_shard_size=20, - n_obs_per_dataset=60, 
- load_adata=_read_lazy_x_and_obs_only, ) - adata_output = ad.read_zarr(next(output_path.iterdir())) - assert "arr" not in adata_output.obsm - assert adata_output.raw is None + assert len(list(iter(collection))) == 1 # default n_obs_per_dataset is much more than total obs + assert isinstance(ad.io.read_elem(next(iter(collection))).X, sp.csr_matrix) + # Test directory structure to make sure nothing extraneous was written + if isinstance(store, zarr.Group): + assert sorted(glob.glob(str(output_path / "dataset_*"))) == sorted( + str(p) for p in (output_path).iterdir() if p.is_dir() + ) + assert list(iter(collection)) == [store[k] for k in sorted(store.keys())] + assert V1_ENCODING.items() <= store.attrs.items() @pytest.mark.parametrize("shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no_shuffle")]) -@pytest.mark.parametrize("densify", [pytest.param(True, id="densify"), pytest.param(False, id="no_densify")]) def test_store_creation( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], shuffle: bool, - densify: bool, ): var_subset = [f"gene_{i}" for i in range(100)] h5_files = sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}_{densify}" - output_path.mkdir(parents=True, exist_ok=True) - create_anndata_collection( + output_path = adata_with_h5_path_different_var_space[1].parent / f"zarr_store_creation_test_{shuffle}.zarr" + collection = DatasetCollection(output_path).add_adatas( [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")], - output_path, var_subset=var_subset, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, - n_obs_per_dataset=60, + n_obs_per_dataset=50, + shuffle_chunk_size=10, shuffle=shuffle, - should_denseify=densify, ) + assert not DatasetCollection(output_path).is_empty + assert V1_ENCODING.items() <= 
zarr.open(output_path).attrs.items() adata_orig = adata_with_h5_path_different_var_space[0] # make sure all category dtypes match - adatas_shuffled = [ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir())] + adatas_shuffled = [ad.io.read_elem(g) for g in collection] for adata in adatas_shuffled: assert adata.obs["label"].dtype == adata_orig.obs["label"].dtype # subset to var_subset @@ -230,7 +214,12 @@ def test_store_creation( ) assert "arr" in adata.obsm if shuffle: + # If it's shuffled I'd expect more than 90% of elements to be out of order + assert sum(adata_orig.obs_names != adata.obs_names) > (0.9 * adata.shape[0]) + assert adata_orig.obs_names.isin(adata.obs_names).all() adata = adata[adata_orig.obs_names].copy() + else: + assert (adata_orig.obs_names == adata.obs_names).all() np.testing.assert_array_equal( adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray(), adata_orig.X if isinstance(adata_orig.X, np.ndarray) else adata_orig.X.toarray(), @@ -245,12 +234,25 @@ def test_store_creation( adata.obs["label"] = adata.obs["label"].cat.reorder_categories(adata_orig.obs["label"].dtype.categories) pd.testing.assert_frame_equal(adata.obs, adata_orig.obs) - z = zarr.open(output_path / "dataset_0.zarr") + z = zarr.open(output_path / "dataset_0") assert z["obsm"]["arr"].chunks[0] == 5, z["obsm"]["arr"] - if not densify: - assert z["X"]["indices"].chunks[0] == 10 + assert z["X"]["indices"].chunks[0] == 10 + + +def _read_lazy_x_and_obs_only_from_raw(path) -> ad.AnnData: + adata_ = ad.experimental.read_lazy(path) + if adata_.raw is not None: + x = adata_.raw.X + var = adata_.raw.var else: - assert z["X"].chunks[0] == 5, z["X"] + x = adata_.X + var = adata_.var + + return ad.AnnData( + X=x, + obs=adata_.obs.to_memory(), + var=var.to_memory(), + ) @pytest.mark.parametrize( @@ -262,21 +264,19 @@ def test_mismatched_raw_concat( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], ): h5_files = 
sorted(adata_with_h5_path_different_var_space[1].iterdir()) - output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous" - output_path.mkdir(parents=True, exist_ok=True) + output_path = adata_with_h5_path_different_var_space[1].parent / "zarr_store_creation_test_heterogeneous.zarr" h5_paths = [adata_with_h5_path_different_var_space[1] / f for f in h5_files if str(f).endswith(".h5ad")] - with pytest.warns(UserWarning, match=r"Found raw keys not present in all anndatas"): - create_anndata_collection( - h5_paths, - output_path, - zarr_sparse_chunk_size=10, - zarr_sparse_shard_size=20, - zarr_dense_chunk_size=10, - zarr_dense_shard_size=20, - n_obs_per_dataset=60, - load_adata=_read_lazy_x_and_obs_only, - shuffle=False, # don't shuffle -> want to check if the right attributes get taken - ) + collection = DatasetCollection(output_path).add_adatas( + h5_paths, + zarr_sparse_chunk_size=10, + zarr_sparse_shard_size=20, + zarr_dense_chunk_size=10, + zarr_dense_shard_size=20, + n_obs_per_dataset=30, + shuffle_chunk_size=10, + shuffle=False, # don't shuffle -> want to check if the right attributes get taken + load_adata=_read_lazy_x_and_obs_only_from_raw, + ) adatas_orig = [] for file in h5_paths: @@ -291,50 +291,48 @@ def test_mismatched_raw_concat( adata_orig = ad.concat(adatas_orig, join="outer") adata_orig.obs_names_make_unique() - adata = ad.concat([ad.read_zarr(zarr_path) for zarr_path in sorted(output_path.iterdir())]) + adata = ad.concat([ad.io.read_elem(g) for g in collection]) del adata.obs["src_path"] pd.testing.assert_frame_equal(adata_orig.var, adata.var) pd.testing.assert_frame_equal(adata_orig.obs, adata.obs) np.testing.assert_array_equal(adata_orig.X.toarray(), adata.X.toarray()) -@pytest.mark.parametrize("densify", [True, False]) @pytest.mark.parametrize("load_adata", [ad.read_h5ad, ad.experimental.read_lazy]) def test_store_extension( adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], - densify: 
bool, load_adata: Callable[[PathLike[str] | str], ad.AnnData], ): - all_h5_paths = sorted(adata_with_h5_path_different_var_space[1].iterdir()) + all_h5_paths = sorted(p for p in adata_with_h5_path_different_var_space[1].iterdir() if p.suffix == ".h5ad") store_path = ( - adata_with_h5_path_different_var_space[1].parent / f"zarr_store_extension_test_{densify}_{load_adata.__name__}" + adata_with_h5_path_different_var_space[1].parent / f"zarr_store_extension_test_{load_adata.__name__}.zarr" ) original = all_h5_paths additional = all_h5_paths[4:] # don't add everything to get a "different" var space # create new store - create_anndata_collection( + collection = DatasetCollection(store_path) + collection.add_adatas( original, - store_path, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=10, zarr_dense_shard_size=20, n_obs_per_dataset=60, + shuffle_chunk_size=10, shuffle=True, - should_denseify=densify, ) # add h5ads to existing store - add_to_collection( + collection.add_adatas( additional, - store_path, load_adata=load_adata, zarr_sparse_chunk_size=10, zarr_sparse_shard_size=20, zarr_dense_chunk_size=5, zarr_dense_shard_size=10, + n_obs_per_dataset=50, + shuffle_chunk_size=10, ) - - adatas_on_disk = [ad.read_zarr(zarr_path) for zarr_path in sorted(store_path.iterdir())] + adatas_on_disk = [ad.io.read_elem(g) for g in collection] adata = ad.concat(adatas_on_disk) adata_orig = adata_with_h5_path_different_var_space[0] expected_adata = ad.concat([adata_orig, adata_orig[adata_orig.obs["store_id"] >= 4]], join="outer") @@ -344,9 +342,16 @@ def test_store_extension( for a in [*adatas_on_disk, adata]: assert a.obs["label"].dtype == expected_adata.obs["label"].dtype assert "arr" in adata.obsm - z = zarr.open(store_path / "dataset_0.zarr") + z = zarr.open(store_path / "dataset_0") assert z["obsm"]["arr"].chunks == (5, z["obsm"]["arr"].shape[1]) - if not densify: - assert z["X"]["indices"].chunks[0] == 10 - else: - assert z["X"].chunks == (5, 
z["X"].shape[1]) + assert z["X"]["indices"].chunks[0] == 10 + + +def test_empty(tmp_path: Path): + g = zarr.open(tmp_path / "empty.zarr") + collection = DatasetCollection(g) + assert collection.is_empty + # Doesn't matter what errors as long as this function runs, but not to completion + with pytest.raises(TypeError): + collection.add_adatas() + assert not (V1_ENCODING.items() <= g.attrs.items())