From 5c472a3c69266e22556eb23f7e0aba27ea607cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kuku=C4=8Dka?= Date: Sat, 28 Feb 2026 17:36:43 +0000 Subject: [PATCH 1/6] feat: add StratifiedGroupShuffleSplit and train_test_split with group support; update dependencies --- pyproject.toml | 5 + ratiopath/model_selection/__init__.py | 10 + ratiopath/model_selection/split.py | 300 ++++++++++++++++++++++++++ tests/test_split.py | 57 +++++ uv.lock | 72 +++++++ 5 files changed, 444 insertions(+) create mode 100644 ratiopath/model_selection/__init__.py create mode 100644 ratiopath/model_selection/split.py create mode 100644 tests/test_split.py diff --git a/pyproject.toml b/pyproject.toml index 5400891..3740efc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "zarr>=3.1.1", "geopandas>=1.1.1", "rasterio>=1.4.3", + "scikit-learn>=1.8.0", ] [dependency-groups] @@ -36,6 +37,10 @@ dev = [ "pyarrow-stubs>=20.0.0.20251107", "ruff", ] +tests = [ + "openslide-bin>=4.0.0.8", + "pytest>=8.4.1", +] [project.optional-dependencies] docs = ["mkdocs-material>=9.6.18", "mkdocstrings[python]>=0.30.0"] diff --git a/ratiopath/model_selection/__init__.py b/ratiopath/model_selection/__init__.py new file mode 100644 index 0000000..44b6819 --- /dev/null +++ b/ratiopath/model_selection/__init__.py @@ -0,0 +1,10 @@ +from ratiopath.model_selection.split import ( + StratifiedGroupShuffleSplit, + train_test_split, +) + + +__all__ = [ + "StratifiedGroupShuffleSplit", + "train_test_split", +] diff --git a/ratiopath/model_selection/split.py b/ratiopath/model_selection/split.py new file mode 100644 index 0000000..18b4fb3 --- /dev/null +++ b/ratiopath/model_selection/split.py @@ -0,0 +1,300 @@ +import numbers +from collections.abc import Iterator +from itertools import chain +from typing import Any, TypeAlias + +import numpy as np +import pandas as pd +from scipy.sparse import spmatrix +from sklearn.model_selection import ( + BaseShuffleSplit, + GroupShuffleSplit, + 
ShuffleSplit, + StratifiedGroupKFold, + StratifiedShuffleSplit, +) +from sklearn.model_selection._split import GroupsConsumerMixin, _validate_shuffle_split +from sklearn.utils._array_api import get_namespace_and_device, move_to +from sklearn.utils._indexing import _safe_indexing +from sklearn.utils._param_validation import Interval, RealNotInt, validate_params +from sklearn.utils.validation import _num_samples, check_random_state, indexable + + +ArrayLike: TypeAlias = np.typing.ArrayLike +MatrixLike: TypeAlias = np.ndarray | pd.DataFrame | spmatrix +Int: TypeAlias = int | np.int8 | np.int16 | np.int32 | np.int64 +Float: TypeAlias = float | np.float16 | np.float32 | np.float64 + + +class StratifiedGroupShuffleSplit(GroupsConsumerMixin, BaseShuffleSplit): + """Stratified shuffle split with non-overlapping groups. + + Provides train/test indices to split data such that both stratification + (preserving class distribution) and grouping (non-overlapping groups between + splits) are maintained. + + This splitter combines the functionality of StratifiedShuffleSplit and + GroupShuffleSplit. It attempts to create folds which preserve the percentage + of samples from each class while ensuring that samples from the same group + do not appear in both train and test sets. + + Read more in the :ref:`User Guide `. + + Parameters: + n_splits: Number of re-shuffling & splitting iterations. + test_size: If float, should be between 0.0 and 1.0 and represent the proportion of + the dataset to include in the test split. If int, represents the absolute number + of test samples. If None, the value is set to the complement of the train size. + train_size: If float, should be between 0.0 and 1.0 and represent the proportion of + the dataset to include in the train split. If int, represents the absolute + number of train samples. If None, the value is automatically set to the + complement of the test size. + random_state: Controls the randomness of the training and testing indices. 
Pass an + int for reproducible output across multiple function calls. + See :term:`Glossary `. + + Examples: + >>> import numpy as np + >>> from ratiopath.model_selection import StratifiedGroupShuffleSplit + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) + >>> y = np.array([0, 0, 1, 1, 0, 1]) + >>> groups = np.array([1, 1, 2, 2, 3, 3]) + >>> sgss = StratifiedGroupShuffleSplit(n_splits=2, random_state=42) + >>> for train_index, test_index in sgss.split(X, y, groups): + ... print(f"Train: {train_index}, Test: {test_index}") + Train: [0 1 2 3], Test: [4 5] + Train: [2 3 4 5], Test: [0 1] + + Notes: + The implementation finds the best stratification split by trying multiple splits + and selecting the one that minimizes the difference between the class + distributions in the original data and the test split. + + Groups appear exactly once in the test set across all splits. + """ + + def __init__( + self, + n_splits: Int = 5, + *, + test_size: None | Float = None, + train_size: None | Float = None, + random_state: np.random.RandomState | None | Int = None, + ) -> None: + super().__init__( + n_splits=n_splits, + test_size=test_size, + train_size=train_size, + random_state=random_state, + ) + self._default_test_size = 0.2 + + @staticmethod + def _get_distribution(labels: ArrayLike) -> np.ndarray: + _, counts = np.unique(labels, return_counts=True) + return counts / counts.sum() + + def split( + self, + X: list[str] | MatrixLike, # noqa: N803 + y: ArrayLike, + groups: Any = None, + ) -> Iterator[Any]: + """Generate indices to split data into training and test set. + + Parameters: + X: Training data, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. + y: The target variable for supervised learning problems. Stratification is + done based on the y labels. + groups: Group labels for the samples used while splitting the dataset into + train and test set. Must be provided. 
+ + Yields: + train: The training set indices for that split. + test: The testing set indices for that split. + """ + n_samples = _num_samples(X) + n_train, n_test = _validate_shuffle_split( + n_samples, self.test_size, self.train_size, self._default_test_size + ) + + flipped = False + if n_test > n_train: + # Approximation using folds is terrible when the test set is larger than the train set + n_test, n_train = n_train, n_test + flipped = True + + n_splits = round(n_samples / n_test) + rng = check_random_state(self.random_state) + y = np.asarray(y) + + data_distribution = self._get_distribution(y) + min_diff: Float | None = None + train_index: np.ndarray | None = None + test_index: np.ndarray | None = None + + for _ in range(self.n_splits): + cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=rng) + + for curr_train_index, curr_test_index in cv.split(X=X, y=y, groups=groups): + test_distribution = self._get_distribution(y[curr_test_index]) + + if len(test_distribution) == len(data_distribution): + diff = np.abs(test_distribution - data_distribution).sum() + else: + diff = float("inf") + + if min_diff is None or diff < min_diff: + min_diff = diff + train_index = curr_train_index + test_index = curr_test_index + + if flipped: + train_index, test_index = test_index, train_index + yield train_index, test_index + + +# https://github.com/scikit-learn/scikit-learn/blob/d3898d9d5/sklearn/model_selection/_split.py#L2757 +@validate_params( + { + "test_size": [ + Interval(RealNotInt, 0, 1, closed="neither"), + Interval(numbers.Integral, 1, None, closed="left"), + None, + ], + "train_size": [ + Interval(RealNotInt, 0, 1, closed="neither"), + Interval(numbers.Integral, 1, None, closed="left"), + None, + ], + "random_state": ["random_state"], + "shuffle": ["boolean"], + "stratify": ["array-like", None], + "groups": ["array-like", None], + }, + prefer_skip_nested_validation=True, +) +def train_test_split( + *arrays, + test_size: None | Float = None, + 
train_size: None | Float = None, + random_state: np.random.RandomState | None | Int = None, + shuffle: bool = True, + stratify: None | ArrayLike = None, + groups: None | ArrayLike = None, +) -> list: + """Split arrays or matrices into random train and test subsets. + + This is an extended version of ``sklearn.model_selection.train_test_split`` that + adds support for stratified splits with non-overlapping groups. When both + ``stratify`` and ``groups`` are provided, uses ``StratifiedGroupShuffleSplit`` to + ensure both class distributions and group separation are preserved. + + Parameters: + *arrays: sequence of indexables with same length / shape[0] + Allowed inputs are lists, numpy arrays, scipy-sparse matrices or pandas + dataframes. + test_size: If float, should be between 0.0 and 1.0 and represent the proportion + of the dataset to include in the test split. If int, represents the absolute + number of test samples. If None, the value is set to the complement of the + train size. If ``train_size`` is also None, it will be set to 0.25. + train_size: If float, should be between 0.0 and 1.0 and represent the proportion + of the dataset to include in the train split. If int, represents the + absolute number of train samples. If None, the value is automatically set to + the complement of the test size. + random_state: Controls the randomness of the training and testing indices. Pass + an int for reproducible output across multiple function calls. + See :term:`Glossary `. + shuffle: Whether or not to shuffle the data before splitting. If False, stratify + must be None. + stratify: If not None, data is split in a stratified fashion, using this as the + class labels. For binary or multiclass classification, this ensures that the + test and training sets have approximately the same percentage of samples of + each target class as the complete set. + groups: Group labels for the samples used while splitting the dataset into train + and test set. 
When provided with ``stratify``, ensures both stratification + and non-overlapping groups are maintained. + + Returns: + splitting: List containing train-test split of inputs. If ``shuffle=False``, the + ``train`` arrays will have shape ``[0:split_point]`` and ``test`` arrays + will have shape ``[split_point:n_samples]`` for each input. + + Examples: + >>> import numpy as np + >>> from ratiopath.model_selection import train_test_split + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> y = np.array([0, 0, 1, 1]) + >>> groups = np.array([1, 1, 2, 2]) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, test_size=0.25, random_state=42, stratify=y, groups=groups + ... ) + >>> X_train + array([[1, 2], + [5, 6], + [7, 8]]) + >>> X_test + array([[3, 4]]) + + Notes: + When ``shuffle=True`` and both ``stratify`` and ``groups`` are provided, uses + ``StratifiedGroupShuffleSplit`` to split the data, ensuring that: + + * The class distribution is preserved in train and test sets + * No group appears in both train and test sets + + When only one of ``stratify`` or ``groups`` is provided, uses the appropriate + single-constraint splitter. + + When ``shuffle=False``, a stratified split is not supported and ``stratify`` + must be None. 
+ """ + n_arrays = len(arrays) + if n_arrays == 0: + raise ValueError("At least one array required as input") + + arrays = indexable(*arrays) + + n_samples = _num_samples(arrays[0]) + n_train, n_test = _validate_shuffle_split( + n_samples, test_size, train_size, default_test_size=0.25 + ) + + if shuffle is False: + if stratify is not None: + raise ValueError( + "Stratified train/test split is not implemented for shuffle=False" + ) + + train = np.arange(n_train) + test = np.arange(n_train, n_train + n_test) + + else: + # Just this branch is different from sklearn's implementation + if groups is not None: + if stratify is not None: + cvclass = StratifiedGroupShuffleSplit + else: + cvclass = GroupShuffleSplit + else: + cvclass = StratifiedShuffleSplit if stratify is not None else ShuffleSplit + + # It is safer to pass fractions, because some splitters calculate n_samplers + # as number of groups, not samples + cv = cvclass( + test_size=n_test / n_samples, + train_size=n_train / n_samples, + random_state=random_state, + ) + + train, test = next(cv.split(X=arrays[0], y=stratify, groups=groups)) + + xp, _, device = get_namespace_and_device(arrays[0]) + train, test = move_to(train, test, xp=xp, device=device) + + return list( + chain.from_iterable( + (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays + ) + ) diff --git a/tests/test_split.py b/tests/test_split.py new file mode 100644 index 0000000..00ffb67 --- /dev/null +++ b/tests/test_split.py @@ -0,0 +1,57 @@ +import numpy as np + +from ratiopath.model_selection.split import ( + StratifiedGroupShuffleSplit, + train_test_split, +) + + +def test_train_test_split_with_groups_and_stratify(): + x = np.arange(12).reshape(6, 2) + y = np.array([0, 0, 1, 1, 0, 1]) + groups = np.array([1, 1, 2, 2, 3, 3]) + + # include groups as one of the arrays so we can inspect split groups + x_train, x_test, y_train, y_test, g_train, g_test = train_test_split( + x, y, groups, test_size=0.33, random_state=0, stratify=y, 
groups=groups + ) + + # ensure groups do not overlap between train and test + assert set(g_train).isdisjoint(set(g_test)) + + # ensure stratification roughly preserved in the test set + prop_full = (y == 0).sum() / len(y) + prop_test = (y_test == 0).sum() / len(y_test) + assert abs(prop_full - prop_test) <= 0.34 + + +def test_train_test_split_with_groups_no_stratify(): + x = np.arange(10).reshape(5, 2) + y = np.array([0, 1, 0, 1, 0]) + groups = np.array([1, 1, 2, 2, 3]) + + x_train, x_test, y_train, y_test, g_train, g_test = train_test_split( + x, y, groups, test_size=0.4, random_state=1, groups=groups + ) + + assert set(g_train).isdisjoint(set(g_test)) + + +def test_stratified_group_shuffle_split_splits(): + x = np.arange(12).reshape(6, 2) + y = np.array([0, 0, 1, 1, 0, 1]) + groups = np.array([1, 1, 2, 2, 3, 3]) + + sgss = StratifiedGroupShuffleSplit(n_splits=5, test_size=0.33, random_state=42) + + for train_idx, test_idx in sgss.split(x, y, groups=groups): + # groups should be non-overlapping + train_groups = set(groups[train_idx]) + test_groups = set(groups[test_idx]) + assert train_groups.isdisjoint(test_groups) + + # indices should cover all samples + assert len(train_idx) + len(test_idx) == len(x) + + # test must contain at least one sample + assert len(test_idx) > 0 diff --git a/uv.lock b/uv.lock index 5af0930..6b32886 100644 --- a/uv.lock +++ b/uv.lock @@ -405,6 +405,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joblib" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = 
"sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, +] + [[package]] name = "jsonschema" version = "4.25.0" @@ -1484,6 +1493,7 @@ dependencies = [ { name = "rasterio" }, { name = "ray", extra = ["data"] }, { name = "scikit-image" }, + { name = "scikit-learn" }, { name = "shapely" }, { name = "tifffile" }, { name = "torch" }, @@ -1507,6 +1517,10 @@ dev = [ { name = "pyarrow-stubs" }, { name = "ruff" }, ] +tests = [ + { name = "openslide-bin" }, + { name = "pytest" }, +] [package.metadata] requires-dist = [ @@ -1525,6 +1539,7 @@ requires-dist = [ { name = "rasterio", specifier = ">=1.4.3" }, { name = "ray", extras = ["data"], specifier = ">=2.50.0" }, { name = "scikit-image", specifier = ">=0.25.2" }, + { name = "scikit-learn", specifier = ">=1.8.0" }, { name = "shapely", specifier = ">=2.0.0" }, { name = "tifffile", specifier = ">=2024.5.22" }, { name = "torch", specifier = ">=2.6.0" }, @@ -1539,6 +1554,10 @@ dev = [ { name = "pyarrow-stubs", specifier = ">=20.0.0.20251107" }, { name = "ruff" }, ] +tests = [ + { name = "openslide-bin", specifier = ">=4.0.0.8" }, + { name = "pytest", specifier = ">=8.4.1" }, +] [[package]] name = "ray" @@ -1737,6 +1756,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cc/75e9f17e3670b5ed93c32456fda823333c6279b144cd93e2c03aa06aa472/scikit_image-0.25.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:330d061bd107d12f8d68f1d611ae27b3b813b8cdb0300a71d07b1379178dd4cd", size = 13862801, upload-time = "2025-02-18T18:05:20.783Z" }, ] +[[package]] +name = "scikit-learn" +version = "1.8.0" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/d4/40988bf3b8e34feec1d0e6a051446b1f66225f8529b9309becaeef62b6c4/scikit_learn-1.8.0.tar.gz", hash = "sha256:9bccbb3b40e3de10351f8f5068e105d0f4083b1a65fa07b6634fbc401a6287fd", size = 7335585, upload-time = "2025-12-10T07:08:53.618Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/74/e6a7cc4b820e95cc38cf36cd74d5aa2b42e8ffc2d21fe5a9a9c45c1c7630/scikit_learn-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5fb63362b5a7ddab88e52b6dbb47dac3fd7dafeee740dc6c8d8a446ddedade8e", size = 8548242, upload-time = "2025-12-10T07:07:51.568Z" }, + { url = "https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5025ce924beccb28298246e589c691fe1b8c1c96507e6d27d12c5fadd85bfd76", size = 8079075, upload-time = "2025-12-10T07:07:53.697Z" }, + { url = "https://files.pythonhosted.org/packages/dd/47/f187b4636ff80cc63f21cd40b7b2d177134acaa10f6bb73746130ee8c2e5/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4496bb2cf7a43ce1a2d7524a79e40bc5da45cf598dbf9545b7e8316ccba47bb4", size = 8660492, upload-time = "2025-12-10T07:07:55.574Z" }, + { url = "https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0bcfe4d0d14aec44921545fd2af2338c7471de9cb701f1da4c9d85906ab847a", size = 8931904, upload-time = "2025-12-10T07:07:57.666Z" }, + { url = "https://files.pythonhosted.org/packages/9f/c4/0ab22726a04ede56f689476b760f98f8f46607caecff993017ac1b64aa5d/scikit_learn-1.8.0-cp312-cp312-win_amd64.whl", hash = 
"sha256:35c007dedb2ffe38fe3ee7d201ebac4a2deccd2408e8621d53067733e3c74809", size = 8019359, upload-time = "2025-12-10T07:07:59.838Z" }, + { url = "https://files.pythonhosted.org/packages/24/90/344a67811cfd561d7335c1b96ca21455e7e472d281c3c279c4d3f2300236/scikit_learn-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:8c497fff237d7b4e07e9ef1a640887fa4fb765647f86fbe00f969ff6280ce2bb", size = 7641898, upload-time = "2025-12-10T07:08:01.36Z" }, + { url = "https://files.pythonhosted.org/packages/03/aa/e22e0768512ce9255eba34775be2e85c2048da73da1193e841707f8f039c/scikit_learn-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0d6ae97234d5d7079dc0040990a6f7aeb97cb7fa7e8945f1999a429b23569e0a", size = 8513770, upload-time = "2025-12-10T07:08:03.251Z" }, + { url = "https://files.pythonhosted.org/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:edec98c5e7c128328124a029bceb09eda2d526997780fef8d65e9a69eead963e", size = 8044458, upload-time = "2025-12-10T07:08:05.336Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5a/3f1caed8765f33eabb723596666da4ebbf43d11e96550fb18bdec42b467b/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74b66d8689d52ed04c271e1329f0c61635bcaf5b926db9b12d58914cdc01fe57", size = 8610341, upload-time = "2025-12-10T07:08:07.732Z" }, + { url = "https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8fdf95767f989b0cfedb85f7ed8ca215d4be728031f56ff5a519ee1e3276dc2e", size = 8900022, upload-time = "2025-12-10T07:08:09.862Z" }, + { url = "https://files.pythonhosted.org/packages/1c/f9/9b7563caf3ec8873e17a31401858efab6b39a882daf6c1bfa88879c0aa11/scikit_learn-1.8.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:2de443b9373b3b615aec1bb57f9baa6bb3a9bd093f1269ba95c17d870422b271", size = 7989409, upload-time = "2025-12-10T07:08:12.028Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/1f4001503650e72c4f6009ac0c4413cb17d2d601cef6f71c0453da2732fc/scikit_learn-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:eddde82a035681427cbedded4e6eff5e57fa59216c2e3e90b10b19ab1d0a65c3", size = 7619760, upload-time = "2025-12-10T07:08:13.688Z" }, + { url = "https://files.pythonhosted.org/packages/d2/7d/a630359fc9dcc95496588c8d8e3245cc8fd81980251079bc09c70d41d951/scikit_learn-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:7cc267b6108f0a1499a734167282c00c4ebf61328566b55ef262d48e9849c735", size = 8826045, upload-time = "2025-12-10T07:08:15.215Z" }, + { url = "https://files.pythonhosted.org/packages/cc/56/a0c86f6930cfcd1c7054a2bc417e26960bb88d32444fe7f71d5c2cfae891/scikit_learn-1.8.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:fe1c011a640a9f0791146011dfd3c7d9669785f9fed2b2a5f9e207536cf5c2fd", size = 8420324, upload-time = "2025-12-10T07:08:17.561Z" }, + { url = "https://files.pythonhosted.org/packages/46/1e/05962ea1cebc1cf3876667ecb14c283ef755bf409993c5946ade3b77e303/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72358cce49465d140cc4e7792015bb1f0296a9742d5622c67e31399b75468b9e", size = 8680651, upload-time = "2025-12-10T07:08:19.952Z" }, + { url = "https://files.pythonhosted.org/packages/fe/56/a85473cd75f200c9759e3a5f0bcab2d116c92a8a02ee08ccd73b870f8bb4/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:80832434a6cc114f5219211eec13dcbc16c2bac0e31ef64c6d346cde3cf054cb", size = 8925045, upload-time = "2025-12-10T07:08:22.11Z" }, + { url = "https://files.pythonhosted.org/packages/cc/b7/64d8cfa896c64435ae57f4917a548d7ac7a44762ff9802f75a79b77cb633/scikit_learn-1.8.0-cp313-cp313t-win_amd64.whl", hash = 
"sha256:ee787491dbfe082d9c3013f01f5991658b0f38aa8177e4cd4bf434c58f551702", size = 8507994, upload-time = "2025-12-10T07:08:23.943Z" }, + { url = "https://files.pythonhosted.org/packages/5e/37/e192ea709551799379958b4c4771ec507347027bb7c942662c7fbeba31cb/scikit_learn-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf97c10a3f5a7543f9b88cbf488d33d175e9146115a451ae34568597ba33dcde", size = 7869518, upload-time = "2025-12-10T07:08:25.71Z" }, + { url = "https://files.pythonhosted.org/packages/24/05/1af2c186174cc92dcab2233f327336058c077d38f6fe2aceb08e6ab4d509/scikit_learn-1.8.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c22a2da7a198c28dd1a6e1136f19c830beab7fdca5b3e5c8bba8394f8a5c45b3", size = 8528667, upload-time = "2025-12-10T07:08:27.541Z" }, + { url = "https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:6b595b07a03069a2b1740dc08c2299993850ea81cce4fe19b2421e0c970de6b7", size = 8066524, upload-time = "2025-12-10T07:08:29.822Z" }, + { url = "https://files.pythonhosted.org/packages/be/ce/a0623350aa0b68647333940ee46fe45086c6060ec604874e38e9ab7d8e6c/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:29ffc74089f3d5e87dfca4c2c8450f88bdc61b0fc6ed5d267f3988f19a1309f6", size = 8657133, upload-time = "2025-12-10T07:08:31.865Z" }, + { url = "https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb65db5d7531bccf3a4f6bec3462223bea71384e2cda41da0f10b7c292b9e7c4", size = 8923223, upload-time = "2025-12-10T07:08:34.166Z" }, + { url = "https://files.pythonhosted.org/packages/76/18/a8def8f91b18cd1ba6e05dbe02540168cb24d47e8dcf69e8d00b7da42a08/scikit_learn-1.8.0-cp314-cp314-win_amd64.whl", hash = 
"sha256:56079a99c20d230e873ea40753102102734c5953366972a71d5cb39a32bc40c6", size = 8096518, upload-time = "2025-12-10T07:08:36.339Z" }, + { url = "https://files.pythonhosted.org/packages/d1/77/482076a678458307f0deb44e29891d6022617b2a64c840c725495bee343f/scikit_learn-1.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:3bad7565bc9cf37ce19a7c0d107742b320c1285df7aab1a6e2d28780df167242", size = 7754546, upload-time = "2025-12-10T07:08:38.128Z" }, + { url = "https://files.pythonhosted.org/packages/2d/d1/ef294ca754826daa043b2a104e59960abfab4cf653891037d19dd5b6f3cf/scikit_learn-1.8.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:4511be56637e46c25721e83d1a9cea9614e7badc7040c4d573d75fbe257d6fd7", size = 8848305, upload-time = "2025-12-10T07:08:41.013Z" }, + { url = "https://files.pythonhosted.org/packages/5b/e2/b1f8b05138ee813b8e1a4149f2f0d289547e60851fd1bb268886915adbda/scikit_learn-1.8.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:a69525355a641bf8ef136a7fa447672fb54fe8d60cab5538d9eb7c6438543fb9", size = 8432257, upload-time = "2025-12-10T07:08:42.873Z" }, + { url = "https://files.pythonhosted.org/packages/26/11/c32b2138a85dcb0c99f6afd13a70a951bfdff8a6ab42d8160522542fb647/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c2656924ec73e5939c76ac4c8b026fc203b83d8900362eb2599d8aee80e4880f", size = 8678673, upload-time = "2025-12-10T07:08:45.362Z" }, + { url = "https://files.pythonhosted.org/packages/c7/57/51f2384575bdec454f4fe4e7a919d696c9ebce914590abf3e52d47607ab8/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15fc3b5d19cc2be65404786857f2e13c70c83dd4782676dd6814e3b89dc8f5b9", size = 8922467, upload-time = "2025-12-10T07:08:47.408Z" }, + { url = "https://files.pythonhosted.org/packages/35/4d/748c9e2872637a57981a04adc038dacaa16ba8ca887b23e34953f0b3f742/scikit_learn-1.8.0-cp314-cp314t-win_amd64.whl", hash = 
"sha256:00d6f1d66fbcf4eba6e356e1420d33cc06c70a45bb1363cd6f6a8e4ebbbdece2", size = 8774395, upload-time = "2025-12-10T07:08:49.337Z" }, + { url = "https://files.pythonhosted.org/packages/60/22/d7b2ebe4704a5e50790ba089d5c2ae308ab6bb852719e6c3bd4f04c3a363/scikit_learn-1.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f28dd15c6bb0b66ba09728cf09fd8736c304be29409bd8445a080c1280619e8c", size = 8002647, upload-time = "2025-12-10T07:08:51.601Z" }, +] + [[package]] name = "scipy" version = "1.16.1" @@ -1950,6 +2013,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, +] + [[package]] name = "tifffile" version = "2025.6.11" From 866e0a0a2215f70dea5895ff1f8ea88c13b370bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kuku=C4=8Dka?= Date: Sat, 28 Feb 2026 17:48:12 +0000 Subject: [PATCH 2/6] fix: ruff --- tests/test_split.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_split.py b/tests/test_split.py index 00ffb67..55b441d 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -12,7 +12,7 @@ 
def test_train_test_split_with_groups_and_stratify(): groups = np.array([1, 1, 2, 2, 3, 3]) # include groups as one of the arrays so we can inspect split groups - x_train, x_test, y_train, y_test, g_train, g_test = train_test_split( + _, _, _, y_test, g_train, g_test = train_test_split( x, y, groups, test_size=0.33, random_state=0, stratify=y, groups=groups ) @@ -30,7 +30,7 @@ def test_train_test_split_with_groups_no_stratify(): y = np.array([0, 1, 0, 1, 0]) groups = np.array([1, 1, 2, 2, 3]) - x_train, x_test, y_train, y_test, g_train, g_test = train_test_split( + _, _, _, _, g_train, g_test = train_test_split( x, y, groups, test_size=0.4, random_state=1, groups=groups ) From fb1c78da0561cbfa7070a4338dc0254008c43798 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kuku=C4=8Dka?= Date: Sat, 28 Feb 2026 17:58:45 +0000 Subject: [PATCH 3/6] fix: PR --- ratiopath/model_selection/split.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ratiopath/model_selection/split.py b/ratiopath/model_selection/split.py index 18b4fb3..0e9e0c9 100644 --- a/ratiopath/model_selection/split.py +++ b/ratiopath/model_selection/split.py @@ -69,8 +69,6 @@ class StratifiedGroupShuffleSplit(GroupsConsumerMixin, BaseShuffleSplit): The implementation finds the best stratification split by trying multiple splits and selecting the one that minimizes the difference between the class distributions in the original data and the test split. - - Groups appear exactly once in the test set across all splits. """ def __init__( @@ -97,7 +95,7 @@ def _get_distribution(labels: ArrayLike) -> np.ndarray: def split( self, X: list[str] | MatrixLike, # noqa: N803 - y: ArrayLike, + y: ArrayLike | None = None, groups: Any = None, ) -> Iterator[Any]: """Generate indices to split data into training and test set. 
@@ -129,12 +127,11 @@ def split( rng = check_random_state(self.random_state) y = np.asarray(y) - data_distribution = self._get_distribution(y) - min_diff: Float | None = None - train_index: np.ndarray | None = None - test_index: np.ndarray | None = None - for _ in range(self.n_splits): + data_distribution = self._get_distribution(y) + min_diff: Float | None = None + train_index: np.ndarray | None = None + test_index: np.ndarray | None = None cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=rng) for curr_train_index, curr_test_index in cv.split(X=X, y=y, groups=groups): @@ -283,6 +280,7 @@ class labels. For binary or multiclass classification, this ensures that the # It is safer to pass fractions, because some splitters calculate n_samplers # as number of groups, not samples cv = cvclass( + n_splits=1, test_size=n_test / n_samples, train_size=n_train / n_samples, random_state=random_state, From 756a84746a1313185c70b40d32d3e7b082a7a17e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kuku=C4=8Dka?= Date: Sat, 28 Feb 2026 17:59:42 +0000 Subject: [PATCH 4/6] fix: typo --- ratiopath/model_selection/split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ratiopath/model_selection/split.py b/ratiopath/model_selection/split.py index 0e9e0c9..a9635f4 100644 --- a/ratiopath/model_selection/split.py +++ b/ratiopath/model_selection/split.py @@ -126,9 +126,9 @@ def split( n_splits = round(n_samples / n_test) rng = check_random_state(self.random_state) y = np.asarray(y) + data_distribution = self._get_distribution(y) for _ in range(self.n_splits): - data_distribution = self._get_distribution(y) min_diff: Float | None = None train_index: np.ndarray | None = None test_index: np.ndarray | None = None From 9166abe9f373aa2fd06f97a6aa27d5a34ba4a408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kuku=C4=8Dka?= Date: Sat, 28 Feb 2026 18:02:09 +0000 Subject: [PATCH 5/6] fix: PR --- ratiopath/model_selection/split.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/ratiopath/model_selection/split.py b/ratiopath/model_selection/split.py index a9635f4..b7ad472 100644 --- a/ratiopath/model_selection/split.py +++ b/ratiopath/model_selection/split.py @@ -259,9 +259,9 @@ class labels. For binary or multiclass classification, this ensures that the ) if shuffle is False: - if stratify is not None: + if stratify is not None or groups is not None: raise ValueError( - "Stratified train/test split is not implemented for shuffle=False" + "Stratified or grouped train/test split is not implemented for shuffle=False" ) train = np.arange(n_train) From 7e8ca60be66931f7af2960ef6f5548c03838d497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Kuku=C4=8Dka?= Date: Tue, 3 Mar 2026 16:47:59 +0000 Subject: [PATCH 6/6] fix: docs indentation --- ratiopath/model_selection/split.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/ratiopath/model_selection/split.py b/ratiopath/model_selection/split.py index b7ad472..a304048 100644 --- a/ratiopath/model_selection/split.py +++ b/ratiopath/model_selection/split.py @@ -41,17 +41,18 @@ class StratifiedGroupShuffleSplit(GroupsConsumerMixin, BaseShuffleSplit): Read more in the :ref:`User Guide `. Parameters: - n_splits: Number of re-shuffling & splitting iterations. - test_size: If float, should be between 0.0 and 1.0 and represent the proportion of - the dataset to include in the test split. If int, represents the absolute number - of test samples. If None, the value is set to the complement of the train size. - train_size: If float, should be between 0.0 and 1.0 and represent the proportion of - the dataset to include in the train split. If int, represents the absolute - number of train samples. If None, the value is automatically set to the - complement of the test size. - random_state: Controls the randomness of the training and testing indices. Pass an - int for reproducible output across multiple function calls. 
- See :term:`Glossary `. + n_splits: Number of re-shuffling & splitting iterations. + test_size: If float, should be between 0.0 and 1.0 and represent the proportion + of the dataset to include in the test split. If int, represents the absolute + number of test samples. If None, the value is set to the complement of the + train size. + train_size: If float, should be between 0.0 and 1.0 and represent the proportion + of the dataset to include in the train split. If int, represents the + absolute number of train samples. If None, the value is automatically set to + the complement of the test size. + random_state: Controls the randomness of the training and testing indices. Pass + an int for reproducible output across multiple function calls. + See :term:`Glossary <random_state>`. Examples: >>> import numpy as np