From 4ca82bc6c2c5c8bed9d2918228a2171878fc2dba Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:07:23 +0200 Subject: [PATCH 01/14] introduce the new class --- src/annbatch/samplers/_chunk_sampler.py | 115 +++++++----- .../samplers/_fragmented_random_sampler.py | 173 ++++++++++++++++++ 2 files changed, 242 insertions(+), 46 deletions(-) create mode 100644 src/annbatch/samplers/_fragmented_random_sampler.py diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 399536e..d57723f 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -58,19 +58,7 @@ def __init__( if stop is not None and start >= stop: raise ValueError("mask.start must be < mask.stop when mask.stop is specified") - check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) - preload_size = chunk_size * preload_nchunks - - if batch_size > preload_size: - raise ValueError( - "batch_size cannot exceed chunk_size * preload_nchunks. " - f"Got batch_size={batch_size}, but max is {preload_size}." - ) - if preload_size % batch_size != 0: - raise ValueError( - "chunk_size * preload_nchunks must be divisible by batch_size. " - f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." - ) + validate_chunk_batch_preload_sizes(chunk_size, preload_nchunks, batch_size) self._rng = rng or np.random.default_rng() self._replacement = replacement self._num_samples = num_samples @@ -168,7 +156,16 @@ def _iter_from_chunks( batch_rng: np.random.Generator, worker_info: WorkerInfo | None, ) -> Iterator[LoadRequest]: - base = self._iter_from_chunks_base(chunks, batch_rng, worker_info) + base = iter_from_chunks( + chunks=chunks, + batch_rng=batch_rng, + worker_info=worker_info, + preload_nchunks=self._preload_nchunks, + batch_size=self._batch_size, + drop_last=self._drop_last, + shuffle=self._shuffle, + chunk_size=self._chunk_size, + ) if not self._replacement and self._num_samples is None: yield from base return @@ -181,38 +178,6 @@ def _iter_from_chunks( "splits": load_request["splits"][:tail], } - def _iter_from_chunks_base( - self, - chunks: list[slice], - batch_rng: np.random.Generator, - worker_info: WorkerInfo | None, - ) -> Iterator[LoadRequest]: - # Worker sharding: each worker gets a disjoint subset of chunks - if worker_info is not None: - chunks = np.array_split(chunks, worker_info.num_workers)[worker_info.id] - # Set up the iterator for chunks and the batch indices for splits - chunks_per_request = split_given_size(chunks, self._preload_nchunks) - batch_indices = np.arange(self._in_memory_size) - split_batch_indices = split_given_size(batch_indices, self.batch_size) - for request_chunks in chunks_per_request[:-1]: - if self.shuffle: - # Avoid copies using in-place shuffling since `self.shuffle` should not change mid-training - batch_rng.shuffle(batch_indices) - split_batch_indices = split_given_size(batch_indices, self.batch_size) - yield {"chunks": request_chunks, "splits": split_batch_indices} - # On the last yield, drop the last uneven batch and create new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size - final_chunks = chunks_per_request[-1] - total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks)) - if total_obs_in_last_batch == 0: # pragma: no cover - raise RuntimeError("Last batch was found to have no observations. Please open an issue.") - if self._drop_last: - if total_obs_in_last_batch < self.batch_size: - return - total_obs_in_last_batch -= total_obs_in_last_batch % self.batch_size - indices = batch_rng.permutation(total_obs_in_last_batch) if self.shuffle else np.arange(total_obs_in_last_batch) - batch_indices = split_given_size(indices, self.batch_size) - yield {"chunks": final_chunks, "splits": batch_indices} - def _compute_chunks(self, n_obs: int, rng: np.random.Generator) -> list[slice]: """Compute chunks from start and stop indices. @@ -323,3 +288,61 @@ def __init__( mask=mask, rng=rng, ) + + +def iter_from_chunks( + chunks: list[slice], + batch_rng: np.random.Generator, + worker_info: WorkerInfo | None, + preload_nchunks: int, + batch_size: int, + drop_last: bool, + shuffle: bool, + chunk_size: int, +) -> Iterator[LoadRequest]: + # Worker sharding: each worker gets a disjoint subset of chunks + if worker_info is not None: + chunks = np.array_split(chunks, worker_info.num_workers)[worker_info.id] + # Set up the iterator for chunks and the batch indices for splits + chunks_per_request = split_given_size(chunks, preload_nchunks) + in_memory_size = preload_nchunks * chunk_size + batch_indices = np.arange(in_memory_size) + split_batch_indices = split_given_size(batch_indices, batch_size) + for request_chunks in chunks_per_request[:-1]: + if shuffle: + # Avoid copies using in-place shuffling since `self.shuffle` should not change mid-training + batch_rng.shuffle(batch_indices) + split_batch_indices = split_given_size(batch_indices, batch_size) + yield {"chunks": request_chunks, "splits": split_batch_indices} + # On the last yield, drop the last uneven batch and create new batch_indices since the in-memory size of this last yield could be divisible by batch_size but smaller than preload_nslices * slice_size + final_chunks = chunks_per_request[-1] + total_obs_in_last_batch = int(sum(s.stop - s.start for s in final_chunks)) + if total_obs_in_last_batch == 0: # pragma: no cover + raise RuntimeError("Last batch was found to have no observations. Please open an issue.") + if drop_last: + if total_obs_in_last_batch < batch_size: + return + total_obs_in_last_batch -= total_obs_in_last_batch % batch_size + indices = batch_rng.permutation(total_obs_in_last_batch) if shuffle else np.arange(total_obs_in_last_batch) + batch_indices = split_given_size(indices, batch_size) + yield {"chunks": final_chunks, "splits": batch_indices} + + +def validate_chunk_batch_preload_sizes( + chunk_size: int, + preload_nchunks: int, + batch_size: int, +) -> None: + check_lt_1([chunk_size, preload_nchunks], ["Chunk size", "Preloaded chunks"]) + preload_size = chunk_size * preload_nchunks + + if batch_size > preload_size: + raise ValueError( + "batch_size cannot exceed chunk_size * preload_nchunks. " + f"Got batch_size={batch_size}, but max is {preload_size}." + ) + if preload_size % batch_size != 0: + raise ValueError( + "chunk_size * preload_nchunks must be divisible by batch_size. " + f"Got {preload_size} % {batch_size} = {preload_size % batch_size}." + ) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py new file mode 100644 index 0000000..70d8f53 --- /dev/null +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -0,0 +1,173 @@ +"""SequentialSampler -- ordered chunk-based sampler.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np + +from annbatch.abc import Sampler +from annbatch.samplers._chunk_sampler import iter_from_chunks, validate_chunk_batch_preload_sizes +from annbatch.samplers._utils import get_torch_worker_info + +if TYPE_CHECKING: + from collections.abc import Iterator + + from annbatch.types import LoadRequest + + +if TYPE_CHECKING: + from collections.abc import Iterator + + from annbatch.types import LoadRequest + + +class FragmentedRandomSampler(Sampler): + """Chunk-based sampler implementation for batched data access.""" + + _batch_size: int + _chunk_size: int + _preload_nchunks: int + _in_memory_size: int + _num_samples: int + _masks: list[slice] + + def __init__( + self, + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + masks: list[slice], + num_samples: int, + drop_last: bool = False, + rng: np.random.Generator | None = None, + ): + validate_chunk_batch_preload_sizes(chunk_size, preload_nchunks, batch_size) + + # standard mask validation + if not all(mask.stop > mask.start and mask.start >= 0 for mask in masks): + raise ValueError("All masks must have mask.stop > mask.start and mask.start >= 0.") + if not all(mask.stop is not None and mask.start is not None for mask in masks): + raise ValueError("All masks must have non-None start and stop.") + if not all(mask.step == 1 or mask.step is None for mask in masks): + raise ValueError("mask.step must be 1 or None for all masks in FragmentedRandomSampler.") + + # enforce that it's non-overlapping and sorted by start index + # sorting by start index should be same with sorting by stop index otherwise there is an overlap + sorted_masks = sorted(masks, key=lambda m: m.start) + starts = np.array([m.start for m in sorted_masks], dtype=np.int64) + stops = np.array([m.stop for m in sorted_masks], dtype=np.int64) + if len(sorted_masks) > 1 and not np.all(stops[:-1] <= starts[1:]): + raise ValueError("Masks must be non-overlapping for FragmentedRandomSampler.") + + # now we will merge any two masks that are adjacent + is_adj = starts[1:] == stops[:-1] + if np.any(is_adj): + new_starts = np.concatenate(([starts[0]], starts[1:][~is_adj])) + new_stops = np.concatenate((stops[:-1][~is_adj], [stops[-1]])) + starts, stops = new_starts, new_stops + + if not np.all(stops - starts >= chunk_size): + raise ValueError("Each mask must cover at least one chunk (mask.stop - mask.start >= chunk_size).") + + # precompute cumulative sums for efficient chunk sampling + cumsum_centered = np.concatenate([np.array([0]), np.cumsum(stops - starts - self._chunk_size)]) + chunk_start_offsets = np.concatenate([np.array([0]), cumsum_centered[1:] - cumsum_centered[:-1]]) + + self._rng = rng or np.random.default_rng() + self._in_memory_size = chunk_size * preload_nchunks + + # self._masks = sorted_masks + self._cumsum_centered, self._chunk_start_offsets = cumsum_centered, chunk_start_offsets + self._starts, self._stops = starts, stops + self._num_samples = num_samples + self._drop_last = drop_last + self._batch_size, self._chunk_size, self._preload_nchunks = ( + batch_size, + chunk_size, + preload_nchunks, + ) + + @property + def mask(self) -> slice: + raise NotImplementedError( + "mask property is not implemented for FragmentedRandomSampler since it operates on multiple masks." + ) + + @mask.setter + def mask(self, value: slice) -> None: + raise NotImplementedError( + "mask property is not implemented for FragmentedRandomSampler since it operates on multiple masks." + ) + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def shuffle(self) -> bool: + return True + + def n_iters(self, n_obs: int) -> int: + del n_obs # not needed + return ( + self._num_samples // self.batch_size if self._drop_last else math.ceil(self._num_samples / self.batch_size) + ) + + def validate(self, n_obs: int) -> None: + """Validate the sampler configuration against the loader's n_obs. + + Parameters + ---------- + n_obs + The total number of observations in the loader. + + Raises + ------ + ValueError + If the sampler configuration is invalid for the given n_obs. + """ + if np.any(self._stops > n_obs): + raise ValueError( + f"Sampler has a mask from masks such that mask.stop exceeds loader n_obs ({n_obs}). " + "The masks given to the sampler must be within the loader's observations." + ) + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + del n_obs # not needed since we don't infer anything from n_obs + worker_info = get_torch_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise NotImplementedError("Multiple workers are not supported with FragmentedRandomSampler.") + + chunks = self._compute_chunks() + return iter_from_chunks( + chunks=chunks, + batch_rng=self._rng, + preload_nchunks=self._preload_nchunks, + batch_size=self._batch_size, + drop_last=self._drop_last, + chunk_size=self._chunk_size, + shuffle=True, + worker_info=None, + ) + + def _compute_chunks(self): + n_chunks, remainder = divmod(self._num_samples, self._chunk_size) + if remainder > 0 and not self._drop_last: + n_chunks += 1 + + num_possible_chunk_starts = self._cumsum_centered[-1] + + offsets = self._rng.integers(num_possible_chunk_starts, size=n_chunks) + frag_idx = np.searchsorted(self._cumsum_centered, offsets, side="left") + + # there is two layer of remapping here: + # we need to readjust the distances between masks: done by chunk_start_offsets[frag_idx - 1] + # adding the actual starts of the masks: done by self._starts[frag_idx - 1] + chunk_starts = offsets - self._chunk_start_offsets[frag_idx - 1] + self._starts[frag_idx - 1] + chunks = [slice(int(s), int(s + self._chunk_size)) for s in chunk_starts] + if remainder > 0 and not self._drop_last: + chunks[-1] = slice(int(chunk_starts[-1]), int(chunk_starts[-1] + remainder)) + return chunks From 00463e84deb3d9aa082d83cc592dcc36a739a368 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:33:14 +0200 Subject: [PATCH 02/14] add docstring --- .../samplers/_fragmented_random_sampler.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 70d8f53..7a0c95e 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -24,7 +24,38 @@ class FragmentedRandomSampler(Sampler): - """Chunk-based sampler implementation for batched data access.""" + """Random sampler for multiple non-overlapping data ranges. + + This sampler generates random chunks across multiple disjoint regions (masks) of a dataset, + enabling efficient random sampling from fragmented data regions. + + Adjacent masks are automatically merged internally. + For example, if masks=[slice(0, 10), slice(10, 20)], they will be merged into a single mask slice(0, 20). + After this merging step, the sampler will ensure that each mask from the merged list of masks + covers at least one full chunk and is within the dataset bounds. + + Multiple workers are not supported with this sampler. + + Parameters + ---------- + batch_size + Number of observations per batch. + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + masks + List of non-overlapping slices defining the data regions to sample from. + Each slice must have start >= 0, stop > start, and step is 1 or None. + preload_nchunks + Number of chunks to load per iteration. + drop_last + Whether to drop the last incomplete batch. + rng + Random number generator for shuffling. Note that :func:`torch.manual_seed` + has no effect on reproducibility here; pass a seeded + :class:`numpy.random.Generator` to control randomness. + num_samples + Total number of observations to draw. + """ _batch_size: int _chunk_size: int From e50e78e8a6cb4b2e3a1747876898f5161d07ebe7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:34:20 +0200 Subject: [PATCH 03/14] self._in_mem not needed --- src/annbatch/samplers/_fragmented_random_sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 7a0c95e..3ff311b 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -108,9 +108,7 @@ def __init__( chunk_start_offsets = np.concatenate([np.array([0]), cumsum_centered[1:] - cumsum_centered[:-1]]) self._rng = rng or np.random.default_rng() - self._in_memory_size = chunk_size * preload_nchunks - # self._masks = sorted_masks self._cumsum_centered, self._chunk_start_offsets = cumsum_centered, chunk_start_offsets self._starts, self._stops = starts, stops self._num_samples = num_samples From 248523393a028abcc9718496505df65a9a706171 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:35:08 +0200 Subject: [PATCH 04/14] simplify the validate method --- src/annbatch/samplers/_fragmented_random_sampler.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 3ff311b..134ea2f 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -146,18 +146,7 @@ def n_iters(self, n_obs: int) -> int: ) def validate(self, n_obs: int) -> None: - """Validate the sampler configuration against the loader's n_obs. - - Parameters - ---------- - n_obs - The total number of observations in the loader. - - Raises - ------ - ValueError - If the sampler configuration is invalid for the given n_obs. - """ + """Validate if there are any masks that exceed the loader's n_obs.""" if np.any(self._stops > n_obs): raise ValueError( f"Sampler has a mask from masks such that mask.stop exceeds loader n_obs ({n_obs}). " From 9ef05739aee2c48e9eb28e02e92997f366046005 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:36:17 +0200 Subject: [PATCH 05/14] simplify value error --- src/annbatch/samplers/_fragmented_random_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 134ea2f..3ecb4e3 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -91,7 +91,7 @@ def __init__( starts = np.array([m.start for m in sorted_masks], dtype=np.int64) stops = np.array([m.stop for m in sorted_masks], dtype=np.int64) if len(sorted_masks) > 1 and not np.all(stops[:-1] <= starts[1:]): - raise ValueError("Masks must be non-overlapping for FragmentedRandomSampler.") + raise ValueError("Masks must be non-overlapping.") # now we will merge any two masks that are adjacent is_adj = starts[1:] == stops[:-1] From 83b3b3a7e8463604e725bc6516da8809e4097714 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:37:56 +0200 Subject: [PATCH 06/14] arrange assignments --- src/annbatch/samplers/_fragmented_random_sampler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 3ecb4e3..9056a5d 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -103,14 +103,13 @@ def __init__( if not np.all(stops - starts >= chunk_size): raise ValueError("Each mask must cover at least one chunk (mask.stop - mask.start >= chunk_size).") - # precompute cumulative sums for efficient chunk sampling + # precompute and save cumulative sums to avoid recomputing them every time in _compute_chunks cumsum_centered = np.concatenate([np.array([0]), np.cumsum(stops - starts - self._chunk_size)]) chunk_start_offsets = np.concatenate([np.array([0]), cumsum_centered[1:] - cumsum_centered[:-1]]) - - self._rng = rng or np.random.default_rng() - self._cumsum_centered, self._chunk_start_offsets = cumsum_centered, chunk_start_offsets self._starts, self._stops = starts, stops + self._rng = rng or np.random.default_rng() + # assigned as is self._num_samples = num_samples self._drop_last = drop_last self._batch_size, self._chunk_size, self._preload_nchunks = ( From 1dfa6209836026ec65d4f05eeb61c2278d855664 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:39:12 +0200 Subject: [PATCH 07/14] add comment --- src/annbatch/samplers/_fragmented_random_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 9056a5d..89a3d8b 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -176,8 +176,8 @@ def _compute_chunks(self): n_chunks += 1 num_possible_chunk_starts = self._cumsum_centered[-1] - offsets = self._rng.integers(num_possible_chunk_starts, size=n_chunks) + # frag_idx will tell us for each draw which mask it belongs to frag_idx = np.searchsorted(self._cumsum_centered, offsets, side="left") # there is two layer of remapping here: From 2635cc9d83c40b4b82d22c17eb5aadf331d12397 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 14:46:53 +0200 Subject: [PATCH 08/14] fix of self.chunksize --- src/annbatch/samplers/_fragmented_random_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 89a3d8b..87c893f 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -104,7 +104,7 @@ def __init__( raise ValueError("Each mask must cover at least one chunk (mask.stop - mask.start >= chunk_size).") # precompute and save cumulative sums to avoid recomputing them every time in _compute_chunks - cumsum_centered = np.concatenate([np.array([0]), np.cumsum(stops - starts - self._chunk_size)]) + cumsum_centered = np.concatenate([np.array([0]), np.cumsum(stops - starts - chunk_size)]) chunk_start_offsets = np.concatenate([np.array([0]), cumsum_centered[1:] - cumsum_centered[:-1]]) self._cumsum_centered, self._chunk_start_offsets = cumsum_centered, chunk_start_offsets self._starts, self._stops = starts, stops From 5fc6d90854e1ed78fcf7b1cbacdc98f8bf06b507 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 15:25:48 +0200 Subject: [PATCH 09/14] merge tests --- tests/test_fragmented_random_sampler.py | 417 ++++++++++++++++++++++++ 1 file changed, 417 insertions(+) create mode 100644 tests/test_fragmented_random_sampler.py diff --git a/tests/test_fragmented_random_sampler.py b/tests/test_fragmented_random_sampler.py new file mode 100644 index 0000000..ba143f7 --- /dev/null +++ b/tests/test_fragmented_random_sampler.py @@ -0,0 +1,417 @@ +"""Tests for FragmentedRandomSampler.""" + +from __future__ import annotations + +import math +from unittest.mock import patch + +import numpy as np +import pytest + +from annbatch.samplers._fragmented_random_sampler import FragmentedRandomSampler +from annbatch.samplers._utils import WorkerInfo +from tests.test_sampler import collect_indices + +# ============================================================================= +# Mask Validation Tests +# ============================================================================= + + +@pytest.mark.parametrize( + ("masks", "error_match"), + [ + pytest.param([slice(0, 20)], None, id="valid_single_mask"), + pytest.param([slice(0, 20), slice(30, 50)], None, id="valid_two_masks"), + pytest.param([slice(0, 15), slice(15, 30)], None, id="adjacent_masks_merged"), + pytest.param([slice(30, 50), slice(0, 20)], None, id="unsorted_masks_sorted"), + pytest.param( + [slice(0, 20), slice(10, 30)], + "non-overlapping", + id="overlapping_masks", + ), + pytest.param( + [slice(0, 20, 2)], + "mask.step must be 1 or None", + id="step_not_one", + ), + pytest.param( + [slice(20, 10)], + "mask.stop > mask.start", + id="start_greater_than_stop", + ), + pytest.param( + [slice(-1, 10)], + "mask.start >= 0", + id="negative_start", + ), + pytest.param( + [slice(0, 5)], + "at least one chunk", + id="mask_smaller_than_chunk", + ), + pytest.param( + [slice(0, 20), slice(30, 35)], + "at least one chunk", + id="second_mask_too_small", + ), + ], +) +def test_mask_validation(masks: list[slice], error_match: str | None): + """Test mask validation at construction.""" + kwargs = { + "chunk_size": 10, + "preload_nchunks": 2, + "batch_size": 5, + "masks": masks, + "num_samples": 50, + } + if error_match: + with pytest.raises((ValueError, TypeError), match=error_match): + FragmentedRandomSampler(**kwargs) + else: + sampler = FragmentedRandomSampler(**kwargs) + assert sampler is not None + + +# ============================================================================= +# Mask Coverage Tests +# ============================================================================= + + +@pytest.mark.parametrize( + ("masks", "num_samples", "drop_last"), + [ + pytest.param([slice(0, 100)], 50, False, id="single_mask_partial"), + pytest.param([slice(0, 100)], 100, False, id="single_mask_full"), + pytest.param([slice(0, 100), slice(200, 300)], 100, False, id="two_masks"), + pytest.param( + [slice(0, 100), slice(200, 300)], + 90, + False, + id="three_masks", + ), + ], +) +def test_mask_coverage( + masks: list[slice], + num_samples: int, + drop_last: bool, +): + """Test that sampler yields correct number of samples from specified masks.""" + chunk_size, preload_nchunks, batch_size = 10, 2, 5 + sampler = FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + masks=masks, + num_samples=num_samples, + drop_last=drop_last, + rng=np.random.default_rng(42), + ) + + all_indices, all_chunks, splits = collect_indices(sampler, n_obs=max(m.stop for m in masks)) + + # Verify all indices are within at least one mask + mask_ranges = set() + for mask in masks: + mask_ranges.update(range(mask.start, mask.stop)) + + assert set(all_indices).issubset(mask_ranges), "All indices should be within mask ranges" + + # Verify batch count matches n_iters + expected_iters = sampler.n_iters(n_obs=500) + assert len(splits) == expected_iters, f"Expected {expected_iters} batches, got {len(splits)}" + + +@pytest.mark.parametrize( + ("masks", "num_samples"), + [ + pytest.param([slice(0, 100)], 50, id="single_mask"), + pytest.param([slice(0, 100), slice(200, 300)], 100, id="two_masks"), + pytest.param([slice(0, 50), slice(100, 150), slice(200, 250)], 75, id="three_masks"), + ], +) +def test_chunk_distribution_across_masks(masks: list[slice], num_samples: int): + """Test that chunks are distributed across all masks (not biased to one).""" + chunk_size, preload_nchunks, batch_size = 10, 2, 5 + sampler = FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + masks=masks, + num_samples=num_samples, + rng=np.random.default_rng(42), + ) + + all_indices, all_chunks, _ = collect_indices(sampler, n_obs=max(m.stop for m in masks)) + + # Count how many chunks belong to each mask + chunks_per_mask = dict.fromkeys(range(len(masks)), 0) + for chunk in all_chunks: + for i, mask in enumerate(masks): + if mask.start <= chunk.start and chunk.stop <= mask.stop: + chunks_per_mask[i] += 1 + break + + # All masks should have at least one chunk (if large enough) + for mask_idx, count in chunks_per_mask.items(): + if masks[mask_idx].stop - masks[mask_idx].start >= chunk_size * 2: + assert count > 0, f"Mask {mask_idx} has no chunks" + + +# ============================================================================= +# Batch and Iteration Count Tests +# ============================================================================= + + +@pytest.mark.parametrize( + ("masks", "num_samples", "batch_size", "drop_last", "expected_iters"), + [ + pytest.param([slice(0, 100)], 50, 5, False, 10, id="exact_division"), + pytest.param([slice(0, 100)], 52, 5, False, 11, id="ceil_division"), + pytest.param([slice(0, 100)], 52, 5, True, 10, id="floor_division"), + pytest.param([slice(0, 100), slice(200, 300)], 75, 10, False, 8, id="multiple_masks"), + pytest.param([slice(0, 100), slice(200, 300)], 75, 10, True, 7, id="multiple_masks_drop"), + ], +) +def test_n_iters_property( + masks: list[slice], + num_samples: int, + batch_size: int, + drop_last: bool, + expected_iters: int, +): + """Test n_iters returns correct batch count.""" + sampler = FragmentedRandomSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=batch_size, + masks=masks, + num_samples=num_samples, + drop_last=drop_last, + ) + assert sampler.n_iters(n_obs=500) == expected_iters + + assert sampler.shuffle is True, "Shuffle property should always return True" + assert sampler.batch_size == batch_size, "Batch size property should return the correct value" + + +# ============================================================================= +# Randomness and Reproducibility Tests +# ============================================================================= + + +@pytest.mark.parametrize( + ("masks", "num_samples"), + [ + pytest.param([slice(0, 100)], 50, id="single_mask"), + pytest.param([slice(0, 100), slice(200, 300)], 100, id="two_masks"), + ], +) +def test_reproducibility_with_same_seed(masks: list[slice], num_samples: int): + """Test same seed produces same chunk sequence.""" + kwargs = { + "chunk_size": 10, + "preload_nchunks": 2, + "batch_size": 5, + "masks": masks, + "num_samples": num_samples, + } + + indices1, _, _ = collect_indices( + FragmentedRandomSampler(**kwargs, rng=np.random.default_rng(42)), + n_obs=max(m.stop for m in masks), + ) + indices2, _, _ = collect_indices( + FragmentedRandomSampler(**kwargs, rng=np.random.default_rng(42)), + n_obs=max(m.stop for m in masks), + ) + indices3, _, _ = collect_indices( + FragmentedRandomSampler(**kwargs, rng=np.random.default_rng(99)), + n_obs=max(m.stop for m in masks), + ) + + assert indices1 == indices2, "Same seed should produce identical sequences" + assert indices1 != indices3, "Different seeds should produce different sequences" + + +def test_mask_property_not_implemented(): + """Test that mask property getter and setter raise NotImplementedError.""" + sampler = FragmentedRandomSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=5, + masks=[slice(0, 100)], + num_samples=50, + ) + with pytest.raises(NotImplementedError, match="mask property is not implemented"): + _ = sampler.mask + + with pytest.raises(NotImplementedError, match="mask property is not implemented"): + sampler.mask = slice(0, 50) + + +@pytest.mark.parametrize( + ("masks", "n_obs", "should_fail"), + [ + pytest.param([slice(0, 100)], 100, False, id="valid_n_obs"), + pytest.param([slice(0, 100)], 99, True, id="n_obs_too_small"), + pytest.param([slice(0, 100), slice(200, 300)], 300, False, id="multiple_masks_valid"), + pytest.param([slice(0, 100), slice(200, 300)], 299, True, id="multiple_masks_invalid"), + ], +) +def test_validate(masks: list[slice], n_obs: int, should_fail: bool): + """Test validate() checks n_obs against mask bounds.""" + sampler = FragmentedRandomSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=5, + masks=masks, + num_samples=50, + ) + if should_fail: + with pytest.raises(ValueError, match="mask.stop.*exceeds loader n_obs"): + sampler.validate(n_obs) + else: + sampler.validate(n_obs) + + +# ============================================================================= +# Error and Edge Case Tests +# ============================================================================= + + +def test_multiple_workers_not_supported(): + """Test that multiple workers raise NotImplementedError.""" + sampler = FragmentedRandomSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=5, + masks=[slice(0, 100)], + num_samples=50, + rng=np.random.default_rng(42), + ) + with ( + patch( + "annbatch.samplers._fragmented_random_sampler.get_torch_worker_info", + return_value=WorkerInfo(id=0, num_workers=2), + ), + pytest.raises(NotImplementedError, match="Multiple workers are not supported"), + ): + list(sampler.sample(n_obs=500)) + + +@pytest.mark.parametrize( + ("masks", "num_samples", "id_suffix"), + [ + pytest.param([slice(0, 100)], 50, "single_large"), + pytest.param([slice(0, 100)], 100, "single_full"), + pytest.param([slice(0, 100), slice(200, 300)], 100, "two_large"), + pytest.param( + [slice(0, 100), slice(200, 300), slice(400, 500)], + 150, + "three_masks", + ), + ], +) +def test_edge_case_mask_counts(masks: list[slice], num_samples: int, id_suffix: str): + """Test sampling with various mask and sample counts.""" + chunk_size, preload_nchunks, batch_size = 10, 2, 5 + sampler = FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + masks=masks, + num_samples=num_samples, + rng=np.random.default_rng(42), + ) + n_obs = max(m.stop for m in masks) + all_indices, all_chunks, splits = collect_indices(sampler, n_obs) + + # Verify batch count + expected_iters = math.ceil(num_samples / batch_size) + assert len(splits) == expected_iters + # Verify structure is sound + assert len(all_chunks) > 0 + for chunk in all_chunks: + assert chunk.stop - chunk.start > 0 + + +@pytest.mark.parametrize( + ("chunk_size", "preload_nchunks", "batch_size", "should_fail"), + [ + pytest.param(10, 0, 5, True, id="preload_zero"), + pytest.param(0, 2, 5, True, id="chunk_zero"), + pytest.param(10, 2, 5, False, id="valid_config"), + pytest.param(10, 2, 10, False, id="batch_equals_preload_size"), + ], +) +def test_chunk_batch_preload_size_validation( + chunk_size: int, + preload_nchunks: int, + batch_size: int, + should_fail: bool, +): + """Test validation of chunk, batch, and preload sizes.""" + if should_fail: + with pytest.raises((ValueError, ZeroDivisionError)): + FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + masks=[slice(0, 100)], + num_samples=50, + ) + else: + sampler = FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + masks=[slice(0, 100)], + num_samples=50, + ) + assert sampler is not None + + +def test_small_masks_with_multiple_chunks(): + """Test sampling from masks that span exactly 2 chunks.""" + chunk_size = 10 + sampler = FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=2, + batch_size=5, + masks=[slice(0, 20), slice(40, 60)], + num_samples=15, + rng=np.random.default_rng(42), + ) + all_indices, all_chunks, splits = collect_indices(sampler, n_obs=60) + + assert len(splits) == math.ceil(15 / 5) + # All indices must be in one of the two mask ranges + mask_set = set(range(0, 20)) | set(range(40, 60)) + assert set(all_indices).issubset(mask_set) + + +def test_remainder_handling(): + """Test that remainder samples are handled correctly with drop_last.""" + num_samples = 23 + batch_size = 5 + chunk_size = 10 + sampler = FragmentedRandomSampler( + chunk_size=chunk_size, + preload_nchunks=2, + batch_size=batch_size, + masks=[slice(0, 100)], + num_samples=num_samples, + drop_last=False, + rng=np.random.default_rng(42), + ) + + _, _, splits = collect_indices(sampler, n_obs=100) + + expected_batches = math.ceil(num_samples / batch_size) + assert len(splits) == expected_batches + + # Last batch should have the remainder + last_batch_size = num_samples % batch_size or batch_size + assert len(splits[-1]) == last_batch_size From 12e853093caf1ec9598facc3ed6dcb30fc5a9bf7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 15:33:12 +0200 Subject: [PATCH 10/14] merge some unit tests --- tests/test_fragmented_random_sampler.py | 90 +++++-------------------- 1 file changed, 16 insertions(+), 74 deletions(-) diff --git a/tests/test_fragmented_random_sampler.py b/tests/test_fragmented_random_sampler.py index ba143f7..3ce33ba 100644 --- a/tests/test_fragmented_random_sampler.py +++ b/tests/test_fragmented_random_sampler.py @@ -123,42 +123,6 @@ def test_mask_coverage( assert len(splits) == expected_iters, f"Expected {expected_iters} batches, got {len(splits)}" -@pytest.mark.parametrize( - ("masks", "num_samples"), - [ - pytest.param([slice(0, 100)], 50, id="single_mask"), - pytest.param([slice(0, 100), slice(200, 300)], 100, id="two_masks"), - pytest.param([slice(0, 50), slice(100, 150), slice(200, 250)], 75, id="three_masks"), - ], -) -def test_chunk_distribution_across_masks(masks: list[slice], num_samples: int): - """Test that chunks are distributed across all masks (not biased to one).""" - chunk_size, preload_nchunks, batch_size = 10, 2, 5 - sampler = FragmentedRandomSampler( - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - batch_size=batch_size, - masks=masks, - num_samples=num_samples, - rng=np.random.default_rng(42), - ) - - all_indices, all_chunks, _ = collect_indices(sampler, n_obs=max(m.stop for m in masks)) - - # Count how many chunks belong to each mask - chunks_per_mask = dict.fromkeys(range(len(masks)), 0) - for chunk in all_chunks: - for i, mask in enumerate(masks): - if mask.start <= chunk.start and chunk.stop <= mask.stop: - chunks_per_mask[i] += 1 - break - - # All masks should have at least one chunk (if large enough) - for mask_idx, count in chunks_per_mask.items(): - if masks[mask_idx].stop - masks[mask_idx].start >= chunk_size * 2: - assert count > 0, f"Mask {mask_idx} has no chunks" - - # ============================================================================= # Batch and Iteration Count Tests # ============================================================================= @@ -172,16 +136,19 @@ def test_chunk_distribution_across_masks(masks: list[slice], num_samples: int): pytest.param([slice(0, 100)], 52, 5, True, 10, id="floor_division"), pytest.param([slice(0, 100), slice(200, 300)], 75, 10, False, 8, id="multiple_masks"), pytest.param([slice(0, 100), slice(200, 300)], 75, 10, True, 7, id="multiple_masks_drop"), + pytest.param( + [slice(0, 100), slice(200, 300), slice(400, 500)], 150, 5, False, 30, id="three_masks" + ), ], ) -def test_n_iters_property( +def test_batch_and_iteration_counts( masks: list[slice], num_samples: int, batch_size: int, drop_last: bool, expected_iters: int, ): - """Test n_iters returns correct batch count.""" + """Test n_iters property and actual batch counts for various configurations.""" sampler = FragmentedRandomSampler( chunk_size=10, preload_nchunks=2, @@ -189,12 +156,23 @@ def test_n_iters_property( masks=masks, num_samples=num_samples, drop_last=drop_last, + rng=np.random.default_rng(42), ) assert sampler.n_iters(n_obs=500) == expected_iters assert sampler.shuffle is True, "Shuffle property should always return True" assert sampler.batch_size == batch_size, "Batch size property should return the correct value" + n_obs = max(m.stop for m in masks) + all_indices, all_chunks, splits = collect_indices(sampler, n_obs) + + # Verify actual batch count matches the property + assert len(splits) == expected_iters + # Verify structure is sound + assert len(all_chunks) > 0 + for chunk in all_chunks: + assert chunk.stop - chunk.start > 0 + # ============================================================================= # Randomness and Reproducibility Tests @@ -301,42 +279,6 @@ def test_multiple_workers_not_supported(): list(sampler.sample(n_obs=500)) -@pytest.mark.parametrize( - ("masks", "num_samples", "id_suffix"), - [ - pytest.param([slice(0, 100)], 50, "single_large"), - pytest.param([slice(0, 100)], 100, "single_full"), - pytest.param([slice(0, 100), slice(200, 300)], 100, "two_large"), - pytest.param( - [slice(0, 100), slice(200, 300), slice(400, 500)], - 150, - "three_masks", - ), - ], -) -def test_edge_case_mask_counts(masks: list[slice], num_samples: int, id_suffix: str): - """Test sampling with various mask and sample counts.""" - chunk_size, preload_nchunks, batch_size = 10, 2, 5 - sampler = FragmentedRandomSampler( - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - batch_size=batch_size, - masks=masks, - num_samples=num_samples, - rng=np.random.default_rng(42), - ) - n_obs = max(m.stop for m in masks) - all_indices, all_chunks, splits = collect_indices(sampler, n_obs) - - # Verify batch count - expected_iters = math.ceil(num_samples / batch_size) - assert len(splits) == expected_iters - # Verify structure is sound - assert len(all_chunks) > 0 - for chunk in all_chunks: - assert chunk.stop - chunk.start > 0 - - @pytest.mark.parametrize( ("chunk_size", "preload_nchunks", "batch_size", "should_fail"), [ From 808fb5599f34577c64877fdff3181066412b33b6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 15:33:22 +0200 Subject: [PATCH 11/14] merge some unit tests --- tests/test_fragmented_random_sampler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_fragmented_random_sampler.py b/tests/test_fragmented_random_sampler.py index 3ce33ba..49f136d 100644 --- a/tests/test_fragmented_random_sampler.py +++ b/tests/test_fragmented_random_sampler.py @@ -136,9 +136,7 @@ def test_mask_coverage( pytest.param([slice(0, 100)], 52, 5, True, 10, id="floor_division"), pytest.param([slice(0, 100), slice(200, 300)], 75, 10, False, 8, id="multiple_masks"), pytest.param([slice(0, 100), slice(200, 300)], 75, 10, True, 7, id="multiple_masks_drop"), - pytest.param( - [slice(0, 100), slice(200, 300), slice(400, 500)], 150, 5, False, 30, id="three_masks" - ), + pytest.param([slice(0, 100), slice(200, 300), slice(400, 500)], 150, 5, False, 30, id="three_masks"), ], ) def test_batch_and_iteration_counts( From 42d42a78f55960bd3a3fa4e4361cd8327e560670 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 26 May 2026 15:51:08 +0200 Subject: [PATCH 12/14] rewrite header --- src/annbatch/samplers/_fragmented_random_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/samplers/_fragmented_random_sampler.py b/src/annbatch/samplers/_fragmented_random_sampler.py index 87c893f..6a14561 100644 --- a/src/annbatch/samplers/_fragmented_random_sampler.py +++ b/src/annbatch/samplers/_fragmented_random_sampler.py @@ -1,4 +1,4 @@ -"""SequentialSampler -- ordered chunk-based sampler.""" +"""FragmentedRandomSampler""" from __future__ import annotations From 449d473c63d04aee7dbd5451693c19bb624180aa Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 29 May 2026 16:41:34 +0200 Subject: [PATCH 13/14] categorical sampler in similar fashion to demonstrate --- src/annbatch/samplers/__init__.py | 2 + src/annbatch/samplers/_categorical_sampler.py | 213 ++++++++++++++ tests/test_categorical_sampler.py | 260 ++++++++++++++++++ 3 files changed, 475 insertions(+) create mode 100644 src/annbatch/samplers/_categorical_sampler.py create mode 100644 tests/test_categorical_sampler.py diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index 3330dad..206f7ba 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,9 +1,11 @@ +from ._categorical_sampler import CategoricalSampler from ._chunk_sampler import ChunkSampler from ._distributed_sampler import DistributedSampler from ._random_sampler import RandomSampler from ._sequential_sampler import SequentialSampler __all__ = [ + "CategoricalSampler", "ChunkSampler", "DistributedSampler", "RandomSampler", diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py new file mode 100644 index 0000000..13cf5ed --- /dev/null +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -0,0 +1,213 @@ +"""CategoricalSampler -- single-pass, vectorized categorical chunk sampler. + +This is the "one sampler, not N composed samplers" alternative to wrapping a +:class:`~annbatch.samplers.FragmentedRandomSampler` per category. Instead of +holding a list of slices (or a per-category sampler object) it keeps only: + +* the run-length encoding (RLE) of contiguous category runs as numpy int arrays, and +* per-category offsets into a single flat prefix-sum of valid chunk-start positions. + +A whole epoch's chunks are then drawn in one vectorized numpy pass -- no python +loop over categories, no per-category :class:`numpy.random.Generator`, and memory +that scales with the number of *runs* (``<= n_obs / chunk_size``) rather than with +``n_categories * n_obs``. + +Only uniform balancing (each category drawn equally often) is implemented for now, +to keep the comparison with :class:`~annbatch.samplers.FragmentedRandomSampler` simple. +""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np + +from annbatch.abc import Sampler +from annbatch.samplers._chunk_sampler import iter_from_chunks, validate_chunk_batch_preload_sizes +from annbatch.samplers._utils import get_torch_worker_info + +if TYPE_CHECKING: + from collections.abc import Iterator + + from annbatch.types import LoadRequest + + +class CategoricalSampler(Sampler): + """Category-coherent random sampler over a fragmented categorical column. + + Every chunk that is yielded lies entirely within a single category, so each + on-disk read stays contiguous *and* its category label is known for free. + Categories are drawn uniformly (each equally often, regardless of size); the + distribution *within* a category is uniform over its valid chunk-start positions. + + Sampling is with replacement (chunks are drawn independently), mirroring + :class:`~annbatch.samplers.FragmentedRandomSampler`. ``num_samples`` controls + the total number of observations drawn per epoch. + + Multiple workers are not supported with this sampler. + + Parameters + ---------- + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + preload_nchunks + Number of chunks to load per iteration. + batch_size + Number of observations per batch. + codes + Integer category code per observation, e.g. ``df["cell_type"].cat.codes`` + from the input dataframe. Length must equal the loader's ``n_obs``. + Codes do not need to be contiguous on the obs axis -- a category may be + spread across many runs (fragments). + num_samples + Total number of observations to draw per epoch. + drop_last + Whether to drop the last incomplete batch. + rng + Random number generator. Note that :func:`torch.manual_seed` has no effect + here; pass a seeded :class:`numpy.random.Generator` to control randomness. + """ + + _batch_size: int + _chunk_size: int + _preload_nchunks: int + _num_samples: int + + def __init__( + self, + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + codes: np.ndarray, + num_samples: int, + drop_last: bool = False, + rng: np.random.Generator | None = None, + ): + validate_chunk_batch_preload_sizes(chunk_size, preload_nchunks, batch_size) + + codes = np.asarray(codes) + if codes.ndim != 1: + raise ValueError("codes must be a 1D array of category codes (one per observation).") + + self._n_obs = int(codes.shape[0]) + self._rng = rng or np.random.default_rng() + self._num_samples = num_samples + self._drop_last = drop_last + self._batch_size, self._chunk_size, self._preload_nchunks = batch_size, chunk_size, preload_nchunks + + self._build_runs(codes, chunk_size) + + def _build_runs(self, codes: np.ndarray, chunk_size: int) -> None: + """RLE the codes and precompute the per-category prefix-sum of chunk positions.""" + # contiguous runs of identical category code + boundaries = np.flatnonzero(np.diff(codes)) + 1 + run_start = np.concatenate([np.array([0], dtype=np.int64), boundaries.astype(np.int64)]) + run_len = np.concatenate([boundaries, np.array([codes.shape[0]])]).astype(np.int64) - run_start + run_cat = codes[run_start] + + # only runs that can host a full chunk contribute (the >= chunk_size invariant). + # runs shorter than chunk_size are dropped -- see the design note in the PR thread. + keep = run_len >= chunk_size + run_start, run_len, run_cat = run_start[keep], run_len[keep], run_cat[keep] + n_pos = run_len - chunk_size + 1 # number of valid chunk-start positions in the run + + # group runs by category so each category owns a contiguous span of `cum` + order = np.argsort(run_cat, kind="stable") + run_start, run_cat, n_pos = run_start[order], run_cat[order], n_pos[order] + + cum = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(n_pos)]) + cat_ids, first = np.unique(run_cat, return_index=True) + last = np.append(first[1:], len(run_cat)) + + if cat_ids.size == 0: + raise ValueError(f"No category has a run of at least chunk_size ({chunk_size}) observations.") + + self._run_start = run_start + self._cum = cum + self._cat_ids = cat_ids # category codes that have at least one full-chunk run + self._cat_base = cum[first] # offset into `cum` where each category begins + self._cat_total = cum[last] - cum[first] # # of valid chunk positions per category + + @property + def categories(self) -> np.ndarray: + """Category codes this sampler draws from.""" + return self._cat_ids + + @property + def mask(self) -> slice: + raise NotImplementedError( + "mask property is not implemented for CategoricalSampler since it operates on a categorical column." + ) + + @mask.setter + def mask(self, value: slice) -> None: + raise NotImplementedError( + "mask property is not implemented for CategoricalSampler since it operates on a categorical column." + ) + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def shuffle(self) -> bool: + return True + + def n_iters(self, n_obs: int) -> int: + del n_obs # determined by num_samples, not the loader size + return ( + self._num_samples // self.batch_size if self._drop_last else math.ceil(self._num_samples / self.batch_size) + ) + + def validate(self, n_obs: int) -> None: + """Validate that the codes describe exactly the loader's observations.""" + if n_obs != self._n_obs: + raise ValueError( + f"codes length ({self._n_obs}) does not match loader n_obs ({n_obs}). " + "The categorical column must describe exactly the loader's observations." + ) + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + del n_obs # nothing inferred from n_obs + worker_info = get_torch_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise NotImplementedError("Multiple workers are not supported with CategoricalSampler.") + + chunks = self._compute_chunks() + return iter_from_chunks( + chunks=chunks, + batch_rng=self._rng, + preload_nchunks=self._preload_nchunks, + batch_size=self._batch_size, + drop_last=self._drop_last, + chunk_size=self._chunk_size, + shuffle=True, + worker_info=None, + ) + + def _compute_chunks(self) -> list[slice]: + n_chunks, remainder = divmod(self._num_samples, self._chunk_size) + if remainder > 0 and not self._drop_last: + n_chunks += 1 + + # 1) pick a category for each chunk uniformly + cat_of_draw = self._rng.integers(len(self._cat_ids), size=n_chunks) + + # 2) one uniform draw within each chosen category's flat span of valid positions + local_off = (self._rng.random(n_chunks) * self._cat_total[cat_of_draw]).astype(np.int64) + global_off = self._cat_base[cat_of_draw] + local_off + + # 3) map the flat offset -> run -> absolute chunk start (the searchsorted trick, + # generalized across every category at once) + run_idx = np.searchsorted(self._cum, global_off, side="right") - 1 + within = global_off - self._cum[run_idx] + chunk_starts = self._run_start[run_idx] + within + # NB: self._cat_ids[cat_of_draw] is the category label of each chunk, available for free. + + chunks = [slice(int(s), int(s + self._chunk_size)) for s in chunk_starts] + if remainder > 0 and not self._drop_last: + last = int(chunk_starts[-1]) + chunks[-1] = slice(last, last + remainder) + return chunks diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py new file mode 100644 index 0000000..6bd6835 --- /dev/null +++ b/tests/test_categorical_sampler.py @@ -0,0 +1,260 @@ +"""Tests for CategoricalSampler. + +The passing tests check the sampler does what it promises: every chunk is +category-coherent, categories are drawn uniformly, and the bookkeeping +(``num_samples`` / ``n_iters`` / validation) is correct. + +The final test (``test_pure_categorical_batches_unsupported``) is expected to +**fail**. It is deliberately not marked ``xfail``: it documents, with the real +:class:`~annbatch.Loader` ordering contract, why the sampler cannot currently +yield *category-pure batches*. See the PR thread. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import numpy as np +import pytest + +from annbatch.samplers import CategoricalSampler +from annbatch.samplers._utils import WorkerInfo + + +def _chunk_categories(chunks: list[slice], codes: np.ndarray) -> list[int]: + """Category of each chunk, or -1 if the chunk straddles more than one category.""" + out = [] + for c in chunks: + u = np.unique(codes[c]) + out.append(int(u[0]) if u.size == 1 else -1) + return out + + +def _collect_chunks(sampler: CategoricalSampler, n_obs: int) -> list[slice]: + chunks: list[slice] = [] + for load_request in sampler.sample(n_obs): + chunks.extend(load_request["chunks"]) + return chunks + + +# ============================================================================= +# Construction / validation +# ============================================================================= + + +def test_codes_must_be_1d(): + with pytest.raises(ValueError, match="1D array"): + CategoricalSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=5, + codes=np.zeros((10, 2), dtype=int), + num_samples=50, + ) + + +def test_no_category_large_enough_raises(): + # every run is shorter than chunk_size + codes = np.array([0, 0, 1, 1, 2, 2], dtype=np.int64) + with pytest.raises(ValueError, match="at least chunk_size"): + CategoricalSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=5, + codes=codes, + num_samples=10, + ) + + +def test_validate_rejects_n_obs_mismatch(): + codes = np.repeat([0, 1], 50) + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=10, + codes=codes, + num_samples=50, + ) + with pytest.raises(ValueError, match="does not match loader n_obs"): + sampler.validate(n_obs=999) + + +def test_mask_property_not_supported(): + codes = np.repeat([0, 1], 50) + sampler = CategoricalSampler(chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=50) + with pytest.raises(NotImplementedError, match="mask property"): + _ = sampler.mask + with pytest.raises(NotImplementedError, match="mask property"): + sampler.mask = slice(0, 10) + + +def test_multiple_workers_not_supported(): + codes = np.repeat([0, 1], 50) + sampler = CategoricalSampler(chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=50) + with ( + patch( + "annbatch.samplers._categorical_sampler.get_torch_worker_info", + return_value=WorkerInfo(id=0, num_workers=2), + ), + pytest.raises(NotImplementedError, match="Multiple workers"), + ): + list(sampler.sample(len(codes))) + + +def test_runs_shorter_than_chunk_size_are_dropped(): + # cat 0 has one good run (>= chunk_size) and one tiny run; cat 1 only a good run. + codes = np.array([0] * 30 + [1] * 30 + [0] * 3, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=200, rng=np.random.default_rng(0) + ) + # both categories are still sampleable; the 3-row tail of cat 0 is never a chunk start + assert set(sampler.categories.tolist()) == {0, 1} + chunks = _collect_chunks(sampler, len(codes)) + assert all(c.stop <= 60 for c in chunks), "no chunk should reach into the dropped 3-row run" + + +# ============================================================================= +# Core behavior +# ============================================================================= + + +@pytest.mark.parametrize( + "codes", + [ + pytest.param(np.repeat([0, 1, 2, 3], 100), id="contiguous"), + # each category fragmented into two runs interleaved with the others + pytest.param(np.array([0] * 50 + [1] * 50 + [0] * 50 + [1] * 50), id="fragmented"), + pytest.param(np.array([2] * 40 + [0] * 40 + [1] * 40 + [0] * 40 + [2] * 40), id="fragmented_3cat"), + ], +) +def test_chunks_are_category_coherent(codes: np.ndarray): + sampler = CategoricalSampler( + chunk_size=10, preload_nchunks=4, batch_size=10, codes=codes, num_samples=1000, rng=np.random.default_rng(0) + ) + chunks = _collect_chunks(sampler, len(codes)) + cats = _chunk_categories(chunks, codes) + assert -1 not in cats, "every chunk must lie entirely within a single category" + # chunks stay in-bounds and are full size (num_samples is a multiple of chunk_size here) + assert all(0 <= c.start and c.stop <= len(codes) and c.stop - c.start == 10 for c in chunks) + + +def test_fragmented_category_samples_all_runs(): + # cat 0 lives in two separate runs; over many draws both should be hit. + codes = np.array([0] * 50 + [1] * 50 + [0] * 50, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, preload_nchunks=4, batch_size=10, codes=codes, num_samples=5000, rng=np.random.default_rng(0) + ) + chunks = _collect_chunks(sampler, len(codes)) + starts_cat0 = [c.start for c in chunks if codes[c.start] == 0] + hit_first_run = any(s < 50 for s in starts_cat0) + hit_second_run = any(s >= 100 for s in starts_cat0) + assert hit_first_run and hit_second_run, "both fragments of category 0 should be sampled" + + +def test_categories_drawn_uniformly(): + # 4 categories of different sizes -- uniform balancing means equal *chunk counts*, + # independent of category size. + codes = np.array([0] * 200 + [1] * 100 + [2] * 400 + [3] * 100, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, preload_nchunks=4, batch_size=10, codes=codes, num_samples=40_000, rng=np.random.default_rng(0) + ) + chunks = _collect_chunks(sampler, len(codes)) + cats = np.array(_chunk_categories(chunks, codes)) + _, counts = np.unique(cats, return_counts=True) + shares = counts / counts.sum() + assert np.allclose(shares, 0.25, atol=0.02), f"expected ~uniform 0.25 per category, got {shares}" + + +@pytest.mark.parametrize( + ("num_samples", "batch_size", "drop_last", "expected_iters"), + [ + pytest.param(100, 10, False, 10, id="exact"), + pytest.param(105, 10, False, 11, id="partial_kept"), + pytest.param(105, 10, True, 10, id="partial_dropped"), + ], +) +def test_n_iters(num_samples: int, batch_size: int, drop_last: bool, expected_iters: int): + codes = np.repeat([0, 1], 100) + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=batch_size, + codes=codes, + num_samples=num_samples, + drop_last=drop_last, + rng=np.random.default_rng(0), + ) + assert sampler.n_iters(len(codes)) == expected_iters + + +def test_num_samples_respected(): + codes = np.repeat([0, 1, 2], 100) + sampler = CategoricalSampler( + chunk_size=10, preload_nchunks=3, batch_size=10, codes=codes, num_samples=300, rng=np.random.default_rng(0) + ) + total = sum(len(s) for lr in sampler.sample(len(codes)) for s in lr["splits"]) + assert total == 300 + + +# ============================================================================= +# The inevitable failure: category-pure *batches* are not achievable today. +# ============================================================================= + + +def _loader_in_memory_global_index(chunks: list[slice], dataset_shapes: list[int]) -> np.ndarray: + """Replay ``Loader._slices_to_dataset_rows`` ordering (see src/annbatch/loader.py:481). + + The loader groups requested chunks by dataset index and concatenates the + in-memory buffer **in dataset order**, not in the chunk order the sampler + emitted. ``splits`` then index into this reordered buffer (loader.py:797). + This returns, for each buffer position, the global obs index living there. + """ + global_index = np.concatenate([np.arange(s.start, s.stop) for s in chunks]) + ordered: list[np.ndarray] = [] + b_start = 0 + for shape in dataset_shapes: + b_end = b_start + shape + mask = (global_index >= b_start) & (global_index < b_end) + if mask.any(): + ordered.append(global_index[mask]) + b_start = b_end + return np.concatenate(ordered) + + +def test_pure_categorical_batches_unsupported(): + """EXPECTED TO FAIL (not xfail) -- documents the pure-batch gap. + + Each category is fragmented across two datasets, and every individual chunk + is category-coherent. Yet the batches the loader yields are not category-pure: + the loader reorders the in-memory buffer by dataset index (which the sampler + is blind to) and the sampler shuffles across categories within a preload + window. To fix this the Sampler API needs dataset-boundary information and + per-category splits. + """ + # ds0 = rows [0, 40), ds1 = rows [40, 80); each category straddles both datasets + codes = np.array([0] * 20 + [1] * 20 + [0] * 20 + [1] * 20, dtype=np.int64) + dataset_shapes = [40, 40] + + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=2, + batch_size=10, + codes=codes, + num_samples=80, + rng=np.random.default_rng(0), + ) + + impure = 0 + for load_request in sampler.sample(len(codes)): + buffer_global = _loader_in_memory_global_index(load_request["chunks"], dataset_shapes) + for split in load_request["splits"]: + batch_cats = codes[buffer_global[split]] + if np.unique(batch_cats).size != 1: + impure += 1 + + assert impure == 0, ( + f"{impure} batch(es) mixed categories despite every chunk being category-coherent. " + "The loader concatenates the in-memory buffer by dataset index (loader.py:481) and the " + "sampler shuffles across categories within a preload window, so splits cannot carve out " + "category-pure batches without dataset-boundary awareness in the Sampler API." + ) From 27f995e33e480d50e294f7850d7bcf81225dd450 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Fri, 29 May 2026 23:18:47 +0200 Subject: [PATCH 14/14] add category weights and category subset features --- src/annbatch/samplers/_categorical_sampler.py | 86 +++++++++--- tests/test_categorical_sampler.py | 122 ++++++++++++++++-- 2 files changed, 178 insertions(+), 30 deletions(-) diff --git a/src/annbatch/samplers/_categorical_sampler.py b/src/annbatch/samplers/_categorical_sampler.py index 13cf5ed..4094635 100644 --- a/src/annbatch/samplers/_categorical_sampler.py +++ b/src/annbatch/samplers/_categorical_sampler.py @@ -11,9 +11,6 @@ loop over categories, no per-category :class:`numpy.random.Generator`, and memory that scales with the number of *runs* (``<= n_obs / chunk_size``) rather than with ``n_categories * n_obs``. - -Only uniform balancing (each category drawn equally often) is implemented for now, -to keep the comparison with :class:`~annbatch.samplers.FragmentedRandomSampler` simple. """ from __future__ import annotations @@ -38,8 +35,16 @@ class CategoricalSampler(Sampler): Every chunk that is yielded lies entirely within a single category, so each on-disk read stays contiguous *and* its category label is known for free. - Categories are drawn uniformly (each equally often, regardless of size); the - distribution *within* a category is uniform over its valid chunk-start positions. + The distribution *within* a category is uniform over its valid chunk-start + positions; the distribution *over* categories is uniform by default and can + be reshaped with ``category_weights`` (e.g. pass per-category observation + counts for proportional sampling). + + **Run-length rule.** Every contiguous run of a category must be at least + ``chunk_size`` observations long. Otherwise no chunk-size read could ever + land inside it, so rather than silently ignoring such a run the sampler + raises at construction and names the offending categories. Re-chunk the data + (so each category's fragments are large enough) or lower ``chunk_size``. Sampling is with replacement (chunks are drawn independently), mirroring :class:`~annbatch.samplers.FragmentedRandomSampler`. ``num_samples`` controls @@ -59,9 +64,20 @@ class CategoricalSampler(Sampler): Integer category code per observation, e.g. ``df["cell_type"].cat.codes`` from the input dataframe. Length must equal the loader's ``n_obs``. Codes do not need to be contiguous on the obs axis -- a category may be - spread across many runs (fragments). + spread across many runs (fragments) -- but every run must be at least + ``chunk_size`` long (see the run-length rule above). num_samples Total number of observations to draw per epoch. + categories + Optional subset of category codes to sample from. When ``None`` (the + default) every category present in ``codes`` is sampled. Requested codes + must exist in ``codes``, and the run-length rule is only enforced for the + selected categories (a short run in a category you do not sample is fine). + category_weights + Optional unnormalized weights aligned with :attr:`categories` controlling + how often each category is drawn. When ``None`` (the default) categories + are drawn uniformly. For proportional (≈ plain global random) sampling, + pass each category's observation count, e.g. ``np.bincount(codes)[sampler.categories]``. drop_last Whether to drop the last incomplete batch. rng @@ -82,6 +98,8 @@ def __init__( *, codes: np.ndarray, num_samples: int, + categories: np.ndarray | None = None, + category_weights: np.ndarray | None = None, drop_last: bool = False, rng: np.random.Generator | None = None, ): @@ -97,9 +115,10 @@ def __init__( self._drop_last = drop_last self._batch_size, self._chunk_size, self._preload_nchunks = batch_size, chunk_size, preload_nchunks - self._build_runs(codes, chunk_size) + self._build_runs(codes, chunk_size, categories) + self._build_category_probs(category_weights) - def _build_runs(self, codes: np.ndarray, chunk_size: int) -> None: + def _build_runs(self, codes: np.ndarray, chunk_size: int, categories: np.ndarray | None) -> None: """RLE the codes and precompute the per-category prefix-sum of chunk positions.""" # contiguous runs of identical category code boundaries = np.flatnonzero(np.diff(codes)) + 1 @@ -107,32 +126,59 @@ def _build_runs(self, codes: np.ndarray, chunk_size: int) -> None: run_len = np.concatenate([boundaries, np.array([codes.shape[0]])]).astype(np.int64) - run_start run_cat = codes[run_start] - # only runs that can host a full chunk contribute (the >= chunk_size invariant). - # runs shorter than chunk_size are dropped -- see the design note in the PR thread. - keep = run_len >= chunk_size - run_start, run_len, run_cat = run_start[keep], run_len[keep], run_cat[keep] + # restrict to the requested subset of categories, if any + if categories is not None: + categories = np.asarray(categories) + missing = np.setdiff1d(categories, np.unique(codes)) + if missing.size: + raise ValueError(f"Requested categories {missing.tolist()} are not present in codes.") + keep = np.isin(run_cat, categories) + run_start, run_len, run_cat = run_start[keep], run_len[keep], run_cat[keep] + if run_cat.size == 0: + raise ValueError("No categories to sample from.") + + # run-length rule: every (selected) run must hold at least one full chunk. + too_short = run_len < chunk_size + if np.any(too_short): + bad = np.unique(run_cat[too_short]) + raise ValueError( + f"Every contiguous run must be at least chunk_size ({chunk_size}) observations long, " + f"but {int(too_short.sum())} run(s) are shorter (categories {bad.tolist()}). " + "Re-chunk the data so each category's fragments are large enough, or lower chunk_size." + ) n_pos = run_len - chunk_size + 1 # number of valid chunk-start positions in the run # group runs by category so each category owns a contiguous span of `cum` order = np.argsort(run_cat, kind="stable") - run_start, run_cat, n_pos = run_start[order], run_cat[order], n_pos[order] + run_start, run_cat, run_len, n_pos = run_start[order], run_cat[order], run_len[order], n_pos[order] cum = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(n_pos)]) cat_ids, first = np.unique(run_cat, return_index=True) last = np.append(first[1:], len(run_cat)) - if cat_ids.size == 0: - raise ValueError(f"No category has a run of at least chunk_size ({chunk_size}) observations.") - self._run_start = run_start self._cum = cum - self._cat_ids = cat_ids # category codes that have at least one full-chunk run + self._cat_ids = cat_ids self._cat_base = cum[first] # offset into `cum` where each category begins self._cat_total = cum[last] - cum[first] # # of valid chunk positions per category + def _build_category_probs(self, category_weights: np.ndarray | None) -> None: + if category_weights is None: + weights = np.ones(self._cat_ids.shape, dtype=float) + else: + weights = np.asarray(category_weights, dtype=float) + if weights.shape != self._cat_ids.shape: + raise ValueError( + f"category_weights must align with categories (expected shape {self._cat_ids.shape}, " + f"got {weights.shape}). See the `categories` property for the expected order." + ) + if np.any(weights < 0) or weights.sum() == 0: + raise ValueError("category_weights must be non-negative and not all zero.") + self._probs = weights / weights.sum() + @property def categories(self) -> np.ndarray: - """Category codes this sampler draws from.""" + """Category codes this sampler draws from, in the order ``category_weights`` expects.""" return self._cat_ids @property @@ -192,8 +238,8 @@ def _compute_chunks(self) -> list[slice]: if remainder > 0 and not self._drop_last: n_chunks += 1 - # 1) pick a category for each chunk uniformly - cat_of_draw = self._rng.integers(len(self._cat_ids), size=n_chunks) + # 1) pick a category for each chunk according to the sampling policy + cat_of_draw = self._rng.choice(len(self._cat_ids), size=n_chunks, p=self._probs) # 2) one uniform draw within each chosen category's flat span of valid positions local_off = (self._rng.random(n_chunks) * self._cat_total[cat_of_draw]).astype(np.int64) diff --git a/tests/test_categorical_sampler.py b/tests/test_categorical_sampler.py index 6bd6835..bbaba5c 100644 --- a/tests/test_categorical_sampler.py +++ b/tests/test_categorical_sampler.py @@ -53,7 +53,7 @@ def test_codes_must_be_1d(): ) -def test_no_category_large_enough_raises(): +def test_all_runs_shorter_than_chunk_size_raises(): # every run is shorter than chunk_size codes = np.array([0, 0, 1, 1, 2, 2], dtype=np.int64) with pytest.raises(ValueError, match="at least chunk_size"): @@ -101,16 +101,11 @@ def test_multiple_workers_not_supported(): list(sampler.sample(len(codes))) -def test_runs_shorter_than_chunk_size_are_dropped(): - # cat 0 has one good run (>= chunk_size) and one tiny run; cat 1 only a good run. +def test_any_run_shorter_than_chunk_size_raises(): + # run-length rule: cat 0 has a good run AND a tiny 3-row run -> must raise, naming cat 0. codes = np.array([0] * 30 + [1] * 30 + [0] * 3, dtype=np.int64) - sampler = CategoricalSampler( - chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=200, rng=np.random.default_rng(0) - ) - # both categories are still sampleable; the 3-row tail of cat 0 is never a chunk start - assert set(sampler.categories.tolist()) == {0, 1} - chunks = _collect_chunks(sampler, len(codes)) - assert all(c.stop <= 60 for c in chunks), "no chunk should reach into the dropped 3-row run" + with pytest.raises(ValueError, match=r"at least chunk_size.*\[0\]"): + CategoricalSampler(chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=200) # ============================================================================= @@ -165,6 +160,113 @@ def test_categories_drawn_uniformly(): assert np.allclose(shares, 0.25, atol=0.02), f"expected ~uniform 0.25 per category, got {shares}" +def test_select_subset_of_categories(): + codes = np.array([0] * 50 + [1] * 50 + [2] * 50 + [3] * 50, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=4, + batch_size=10, + codes=codes, + num_samples=2000, + categories=np.array([0, 2]), + rng=np.random.default_rng(0), + ) + assert list(sampler.categories) == [0, 2] + chunks = _collect_chunks(sampler, len(codes)) + drawn = {int(np.unique(codes[c])[0]) for c in chunks} + assert drawn == {0, 2}, f"only selected categories should be sampled, got {drawn}" + + +def test_select_missing_category_raises(): + codes = np.repeat([0, 1], 50) + with pytest.raises(ValueError, match=r"\[5\].*not present in codes"): + CategoricalSampler( + chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=50, categories=np.array([0, 5]) + ) + + +def test_subset_ignores_short_runs_of_unselected_categories(): + # cat 1 has a too-short run, but we only sample cats 0 and 2 -> must NOT raise. + codes = np.array([0] * 30 + [1] * 3 + [2] * 30, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=100, categories=np.array([0, 2]) + ) + assert list(sampler.categories) == [0, 2] + # but selecting the offending category surfaces the run-length rule + with pytest.raises(ValueError, match="at least chunk_size"): + CategoricalSampler( + chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=100, categories=np.array([1]) + ) + + +def test_weights_align_with_selected_subset(): + codes = np.array([0] * 50 + [1] * 50 + [2] * 50, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=4, + batch_size=10, + codes=codes, + num_samples=40_000, + categories=np.array([0, 2]), + category_weights=np.array([3.0, 1.0]), # aligned with [0, 2] -> 0.75 / 0.25 + rng=np.random.default_rng(0), + ) + assert np.allclose(_chunk_shares(sampler, codes), [0.75, 0.25], atol=0.02) + + +def _chunk_shares(sampler: CategoricalSampler, codes: np.ndarray) -> np.ndarray: + chunks = _collect_chunks(sampler, len(codes)) + cats = np.array(_chunk_categories(chunks, codes)) + _, counts = np.unique(cats, return_counts=True) + return counts / counts.sum() + + +def test_proportional_sampling_via_weights(): + # "proportional" is just weights == per-category observation counts. + codes = np.array([0] * 300 + [1] * 100, dtype=np.int64) # 75% / 25% + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=4, + batch_size=10, + codes=codes, + num_samples=40_000, + category_weights=np.bincount(codes), # [300, 100] -> 0.75 / 0.25 + rng=np.random.default_rng(0), + ) + assert np.allclose(_chunk_shares(sampler, codes), [0.75, 0.25], atol=0.02) + + +def test_explicit_category_weights(): + codes = np.array([0] * 100 + [1] * 100 + [2] * 100, dtype=np.int64) + sampler = CategoricalSampler( + chunk_size=10, + preload_nchunks=4, + batch_size=10, + codes=codes, + num_samples=40_000, + category_weights=np.array([6.0, 3.0, 1.0]), # -> 0.6 / 0.3 / 0.1 + rng=np.random.default_rng(0), + ) + assert list(sampler.categories) == [0, 1, 2] + assert np.allclose(_chunk_shares(sampler, codes), [0.6, 0.3, 0.1], atol=0.02) + + +@pytest.mark.parametrize( + ("weights", "match"), + [ + pytest.param(np.array([1.0, 1.0, 1.0]), "must align with categories", id="wrong_shape"), + pytest.param(np.array([1.0, -1.0]), "non-negative", id="negative"), + pytest.param(np.array([0.0, 0.0]), "not all zero", id="all_zero"), + ], +) +def test_invalid_category_weights(weights: np.ndarray, match: str): + codes = np.repeat([0, 1], 50) + with pytest.raises(ValueError, match=match): + CategoricalSampler( + chunk_size=10, preload_nchunks=2, batch_size=10, codes=codes, num_samples=50, category_weights=weights + ) + + @pytest.mark.parametrize( ("num_samples", "batch_size", "drop_last", "expected_iters"), [