
feat: splits #34

Merged
Adames4 merged 7 commits into main from feature/splits on Mar 3, 2026


Conversation


@Adames4 Adames4 commented Feb 28, 2026

StratifiedGroupShuffleSplit is inspired by the implementation at https://gitlab.ics.muni.cz/rationai/digital-pathology/pathology/prostate-cancer/-/blob/master/preprocessing/stratified_group_split.py?ref_type=heads, but follows the scikit-learn splitting API.

train_test_split extends sklearn.model_selection.train_test_split with group support.

Possible discussion:
– Should the model_selection folder be located inside the sklearn folder?

Summary by CodeRabbit

  • New Features

    • Added a public stratified, group-aware splitter and an enhanced train/test split that preserves class distributions while keeping groups non-overlapping.
  • Chores

    • Added scikit-learn>=1.8.0 to project dependencies.
  • Tests

    • Added tests validating grouped splits, stratification behavior, sample coverage, and train/test disjointness.

@Adames4 Adames4 self-assigned this Feb 28, 2026
@Adames4 Adames4 requested review from a team February 28, 2026 17:46

coderabbitai bot commented Feb 28, 2026

📝 Walkthrough

Adds a stratified, group-aware splitter (StratifiedGroupShuffleSplit), enhances train_test_split to support combined stratify+groups behavior and indexing, adds corresponding tests, and adds scikit-learn as a dependency.

Changes

Cohort / File(s) — Summary
  • Dependency — pyproject.toml: Added dependency scikit-learn>=1.8.0.
  • Model selection package — ratiopath/model_selection/__init__.py: Re-exports StratifiedGroupShuffleSplit and train_test_split from .split.
  • Split implementation — ratiopath/model_selection/split.py: New StratifiedGroupShuffleSplit class and enhanced train_test_split supporting stratify+groups, selection by distribution (L1) over group-aware folds, input validation, and safe indexing.
  • Tests — tests/test_split.py: Added tests for train_test_split with groups+stratify, groups without stratify, and for StratifiedGroupShuffleSplit properties (group disjointness, coverage, non-empty test).
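The properties the new tests check (train/test disjointness, full sample coverage, group disjointness, non-empty test set) can be verified generically against any group-aware splitter. A minimal sketch, using scikit-learn's public GroupShuffleSplit as a stand-in for the new splitter:

```python
# Generic property checks for a group-aware splitter; GroupShuffleSplit
# stands in for the PR's StratifiedGroupShuffleSplit here.
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 2))
y = rng.integers(0, 2, size=30)
groups = rng.integers(0, 6, size=30)

splitter = GroupShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
for train_idx, test_idx in splitter.split(X, y, groups):
    # 1. train/test indices are disjoint and cover every sample
    assert set(train_idx).isdisjoint(test_idx)
    assert len(train_idx) + len(test_idx) == len(X)
    # 2. no group appears on both sides of the split
    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
    # 3. the test set is non-empty
    assert len(test_idx) > 0
```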

Sequence Diagram

sequenceDiagram
    participant User as User
    participant TTS as train_test_split
    participant Validator as Validator
    participant Router as Router
    participant SGSS as StratifiedGroupShuffleSplit
    participant SKL as sklearn_Splitter
    participant Indexer as Indexer
    participant Result as Result

    User->>TTS: call(X, y, test_size, stratify, groups, shuffle, random_state)
    TTS->>Validator: validate inputs/sizes
    Validator-->>TTS: validated
    TTS->>Router: choose splitter
    alt stratify AND groups
        Router->>SGSS: create & run splitter
        SGSS->>SGSS: compute group-wise class distributions
        SGSS->>SGSS: evaluate candidate splits, score by L1, pick best
        SGSS-->>Router: train/test indices
    else other cases
        Router->>SKL: use StratifiedShuffleSplit / GroupShuffleSplit / ShuffleSplit
        SKL-->>Router: train/test indices
    end
    Router->>Indexer: apply indices to inputs
    Indexer-->>Result: indexed arrays
    Result-->>User: return split arrays

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 I hopped through code with careful paws,

Groups kept apart without a flaw,
Classes balanced, splits made right,
Tests and sklearn set the light,
A tiny hop, a big applause.

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 28.57%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Title check — ❓ Inconclusive: the title "feat: splits" is vague and generic and doesn't convey what changed. Consider a more descriptive title such as "feat: add StratifiedGroupShuffleSplit and extend train_test_split to support groups".
✅ Passed checks (1 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances the data splitting capabilities of the project by introducing advanced utilities for creating robust train/test splits. The new StratifiedGroupShuffleSplit class and the extended train_test_split function address common challenges in machine learning, particularly when dealing with grouped data that also requires class stratification. These additions ensure more reliable model evaluation by preventing data leakage across groups and maintaining representative class distributions in splits.

Highlights

  • New Data Splitting Utilities: Introduced StratifiedGroupShuffleSplit, a new class for creating train/test splits that preserve class distribution (stratification) while ensuring groups do not overlap between splits. This is crucial for robust model evaluation in grouped data.
  • Enhanced train_test_split: Extended the standard sklearn.model_selection.train_test_split function to support group-aware and stratified splitting. This new function intelligently uses StratifiedGroupShuffleSplit or GroupShuffleSplit when stratify and/or groups parameters are provided.
  • Dependency Updates: Added scikit-learn as a core dependency and introduced a new tests dependency group including openslide-bin and pytest to support the new splitting functionalities and their testing.
  • Comprehensive Testing: Added dedicated unit tests for both StratifiedGroupShuffleSplit and the extended train_test_split to verify correct behavior regarding group separation and stratification.
Changelog
  • pyproject.toml
    • Added scikit-learn as a new core dependency.
    • Introduced a new tests dependency group, including openslide-bin and pytest.
  • ratiopath/model_selection/__init__.py
    • Added a new initialization file for the model_selection module.
    • Exported StratifiedGroupShuffleSplit and train_test_split for module-level access.
  • ratiopath/model_selection/split.py
    • Implemented the StratifiedGroupShuffleSplit class, combining stratified sampling with group separation.
    • Extended the train_test_split function to incorporate group-aware and stratified splitting logic.
    • Added detailed docstrings and examples for the new class and function.
  • tests/test_split.py
    • Added new test cases for train_test_split to validate its behavior with groups and stratification.
    • Included tests for StratifiedGroupShuffleSplit to ensure correct group partitioning and index generation.
  • uv.lock
    • Updated the lock file to include joblib, scikit-learn, and threadpoolctl dependencies.
    • Reflected the new test dependencies (openslide-bin, pytest) in the lock file.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces StratifiedGroupShuffleSplit and an extended train_test_split to support grouped and stratified data splitting, which is a valuable addition. I've identified a significant logical flaw in the StratifiedGroupShuffleSplit implementation that could lead to repetitive, non-random splits, and have left a high-severity comment with a suggested fix. I've also pointed out a misleading docstring, an inefficiency in the train_test_split function, and a way to strengthen the tests to prevent similar issues in the future. Please address these points to ensure the reliability of these new features.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
tests/test_split.py (1)

3-6: Consider importing from the public API.

The tests import directly from ratiopath.model_selection.split (the internal module) rather than ratiopath.model_selection (the public re-export). Using the public API in tests helps verify that the re-exports work correctly and aligns with how external users will consume the module.

♻️ Proposed change
-from ratiopath.model_selection.split import (
+from ratiopath.model_selection import (
     StratifiedGroupShuffleSplit,
     train_test_split,
 )
pyproject.toml (1)

40-47: Duplicate test dependencies between [dependency-groups] and [project.optional-dependencies].

The same packages (openslide-bin>=4.0.0.8, pytest>=8.4.1) are declared in both sections. This may be intentional for compatibility with different tooling (PEP 735 dependency-groups vs PEP 508 optional-dependencies), but if so, consider adding a brief comment to clarify the intent and prevent future drift.

ratiopath/model_selection/split.py (1)

16-20: Reliance on private scikit-learn APIs is fragile.

The imports from sklearn.model_selection._split, sklearn.utils._array_api, sklearn.utils._indexing, and sklearn.utils._param_validation use underscore-prefixed (private) modules. These internal APIs can change without deprecation warnings between minor releases, potentially breaking this code.

Consider:

  • Documenting which sklearn version these internals are compatible with
  • Adding integration tests that catch breakage early
  • Evaluating if public APIs can achieve the same result (e.g., check_array instead of internal utilities)
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@ratiopath/model_selection/split.py`:
- Around line 73-74: Update the misleading docstring that currently claims
"Groups appear exactly once in the test set across all splits" in the splitter's
docstring inside ratiopath.model_selection.split; change the wording to state
that each split independently selects its best fold for stratification and that
groups are not globally tracked across splits (so the same group may appear in
test sets of multiple splits), and reference the splitter class/function
docstring where that sentence appears so future readers understand the actual
behavior.
- Around line 264-271: The code path handling shuffle == False in
ratiopath/model_selection/split.py silently ignores the groups parameter; update
that branch to explicitly error when groups is provided instead of proceeding
(i.e., check if groups is not None and raise a ValueError), or alternatively
document this behavior — but the preferred fix is to raise an error referencing
the same wording style used for stratify (e.g., raise ValueError("Group-based
splitting not supported for shuffle=False")), and keep the rest of the
sequential assignment to train/test (variables train, test, n_train, n_test)
unchanged.
- Around line 133-155: The generator can yield None if cv.split() produced no
folds; update the loop in the split generator (the block using
StratifiedGroupKFold, min_diff, train_index, test_index, self._get_distribution
and flipped) to check after the inner for-loop whether train_index or test_index
is still None and, if so, raise a clear ValueError (or skip yielding) with
context about the input and n_splits rather than yielding None; also guard the
flipped swap so it only runs when train_index/test_index are not None. Ensure
the error message references the provided n_splits, X/y/groups sizes and that it
occurs before the yield.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 546f393 and 866e0a0.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (4)
  • pyproject.toml
  • ratiopath/model_selection/__init__.py
  • ratiopath/model_selection/split.py
  • tests/test_split.py

JakubPekar previously approved these changes Mar 1, 2026
vejtek previously approved these changes Mar 2, 2026
@Adames4 Adames4 dismissed stale reviews from vejtek and JakubPekar via e0e6f7a March 3, 2026 16:39

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
ratiopath/model_selection/split.py (2)

88-88: Consider making _default_test_size configurable or documenting the rationale.

The hardcoded 0.2 differs from sklearn's train_test_split default of 0.25. If this is intentional, consider adding a brief comment explaining why. Alternatively, align with sklearn's default for consistency.


16-20: Using private scikit-learn APIs poses a stability risk.

These imports rely on sklearn's internal modules (_split, _array_api, _indexing, _param_validation), which are not part of the public API and may change without deprecation warnings between versions. Consider pinning an upper bound on scikit-learn (e.g., scikit-learn>=1.8.0,<1.9) in pyproject.toml to reduce the risk of breakage.

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pyproject.toml`:
- Line 30: The project depends on internal scikit-learn APIs
(GroupsConsumerMixin and _validate_shuffle_split from
sklearn.model_selection._split) which can break across minor releases; update
the scikit-learn dependency in pyproject.toml to a pinned range such as
"scikit-learn>=1.8.0,<1.9" to avoid accidental upgrades, and either open an
issue with scikit-learn to request stabilizing those utilities or refactor
split.py to stop importing GroupsConsumerMixin and _validate_shuffle_split
(re-implement needed behavior or use public APIs) so the code no longer relies
on private symbols.

In `@ratiopath/model_selection/split.py`:
- Around line 128-129: The code converts y via np.asarray(y) and then calls
self._get_distribution(y) but does not validate y is not None; add an early
validation in the method handling the y parameter (the function that contains y
= np.asarray(y) and calls _get_distribution) to check if y is None or contains
only None-equivalent values and raise a clear ValueError indicating
stratification requires a non-None target. Ensure the check happens before
np.asarray(y) or immediately after it, and reference this validation in error
messages for the caller (use the same function/method name where y =
np.asarray(y) and _get_distribution is invoked).
- Line 99: Add an explicit validation for the parameter named groups (the
annotated parameter "groups: Any = None") at the start of the function where it
is defined: if groups is None raise a ValueError with a clear message like
"groups must be provided for stratified group splitting" so callers get an
immediate, descriptive error; place this check before any use of
StratifiedGroupKFold or downstream logic that assumes groups is present.


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 866e0a0 and e0e6f7a.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (2)
  • pyproject.toml
  • ratiopath/model_selection/split.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
ratiopath/model_selection/split.py (1)

96-130: ⚠️ Potential issue | 🟡 Minor

Add explicit y and groups validation at the top of split().

The public API/docstring requires stratification target and groups, but missing values currently fail later with less direct errors.

🛡️ Proposed fix
     def split(
         self,
         X: list[str] | MatrixLike,  # noqa: N803
         y: ArrayLike | None = None,
         groups: Any = None,
     ) -> Iterator[Any]:
@@
-        n_samples = _num_samples(X)
+        if y is None:
+            raise ValueError("`y` must be provided for stratified group splitting.")
+        if groups is None:
+            raise ValueError("`groups` must be provided for stratified group splitting.")
+
+        n_samples = _num_samples(X)
🧹 Nitpick comments (1)
ratiopath/model_selection/split.py (1)

16-20: Isolate scikit-learn private API imports behind a compatibility layer to reduce future breakage risk.

Imports from underscore modules (e.g., sklearn.model_selection._split, sklearn.utils._array_api, sklearn.utils._indexing) and private functions (e.g., _validate_shuffle_split, _num_samples) are not guaranteed to be stable across sklearn releases. A thin wrapper module (e.g., ratiopath/_sklearn_compat.py) centralizing these imports would make version transitions more manageable.
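A compatibility shim of the kind suggested here might look like the following. The module name (`ratiopath/_sklearn_compat.py`) and the exact fallback paths are illustrative, since the private locations vary between scikit-learn releases:

```python
# Hypothetical ratiopath/_sklearn_compat.py: centralize private-API
# imports so a future sklearn reorganization is fixed in one place.
try:
    # location in recent scikit-learn releases
    from sklearn.utils._indexing import _safe_indexing as safe_indexing
except ImportError:
    # older releases exposed it directly under sklearn.utils
    from sklearn.utils import _safe_indexing as safe_indexing

try:
    from sklearn.utils.validation import _num_samples as num_samples
except ImportError:  # pragma: no cover - fallback for older layouts
    from sklearn.utils import _num_samples as num_samples
```

Call sites then import `safe_indexing` and `num_samples` from the compat module, so a sklearn upgrade that moves these symbols requires a change in only one file.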

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ratiopath/model_selection/split.py` around lines 16 - 20, Replace direct
private sklearn imports in split.py with a compatibility layer: create
ratiopath/_sklearn_compat.py that imports and re-exports GroupsConsumerMixin,
_validate_shuffle_split, get_namespace_and_device, move_to, _safe_indexing,
Interval, RealNotInt, validate_params, _num_samples, check_random_state, and
indexable (and add graceful fallbacks/try/except aliases for different sklearn
versions), then change the imports in ratiopath/model_selection/split.py to
import those symbols from ratiopath._sklearn_compat instead of from sklearn.*
private modules so future sklearn breakages are isolated to the compat module.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@ratiopath/model_selection/split.py`:
- Around line 248-250: Update the docstring for the split function that
documents the shuffle parameter to reflect runtime behavior: state that when
shuffle=False both stratify and groups must be None (not just stratify).
Reference the parameters by name (shuffle, stratify, groups) and update the
sentence currently mentioning only stratify to explicitly mention groups as well
so documentation matches the implemented checks in the function.
- Around line 127-137: n_splits computed as round(n_samples / n_test) can exceed
the number of unique groups and cause StratifiedGroupKFold to raise; before
constructing StratifiedGroupKFold in the loop, compute max_groups =
len(np.unique(groups)) (or 1 if groups is None) and clamp n_splits = max(2,
min(n_splits, max_groups)) (ensuring the lower bound 2 required by
StratifiedGroupKFold), then pass that clamped n_splits into
StratifiedGroupKFold(...) so the fold count never exceeds available group
cardinality.
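The clamping suggested above can be sketched as a small helper. The function name is hypothetical; the lower bound of 2 matches StratifiedGroupKFold's minimum allowed n_splits:

```python
import numpy as np

def clamp_n_splits(n_samples: int, n_test: int, groups) -> int:
    """Cap the fold count at the number of distinct groups, while
    respecting StratifiedGroupKFold's minimum of 2 splits."""
    n_splits = round(n_samples / n_test)
    max_groups = len(np.unique(groups)) if groups is not None else 1
    return max(2, min(n_splits, max_groups))

# e.g. 100 samples with a 20-sample test set would suggest 5 folds,
# but only 3 distinct groups are available, so the helper returns 3.
```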


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e0e6f7a and 7e8ca60.

📒 Files selected for processing (1)
  • ratiopath/model_selection/split.py

@Adames4 Adames4 merged commit 2a29a19 into main Mar 3, 2026
6 checks passed
@matejpekar matejpekar deleted the feature/splits branch March 5, 2026 10:32
