Conversation
… support; update dependencies
📝 Walkthrough

Adds a stratified, group-aware splitter (`StratifiedGroupShuffleSplit`), enhances `train_test_split` to support combined stratify+groups behavior and indexing, adds corresponding tests, and adds scikit-learn as a dependency.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User
    participant TTS as train_test_split
    participant Validator as Validator
    participant Router as Router
    participant SGSS as StratifiedGroupShuffleSplit
    participant SKL as sklearn_Splitter
    participant Indexer as Indexer
    participant Result as Result
    User->>TTS: call(X, y, test_size, stratify, groups, shuffle, random_state)
    TTS->>Validator: validate inputs/sizes
    Validator-->>TTS: validated
    TTS->>Router: choose splitter
    alt stratify AND groups
        Router->>SGSS: create & run splitter
        SGSS->>SGSS: compute group-wise class distributions
        SGSS->>SGSS: evaluate candidate splits, score by L1, pick best
        SGSS-->>Router: train/test indices
    else other cases
        Router->>SKL: use StratifiedShuffleSplit / GroupShuffleSplit / ShuffleSplit
        SKL-->>Router: train/test indices
    end
    Router->>Indexer: apply indices to inputs
    Indexer-->>Result: indexed arrays
    Result-->>User: return split arrays
```
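In the stratify-and-groups branch, candidate splits are scored by L1 distance between class distributions and the best one is picked. A minimal pure-Python sketch of that scoring idea (all names here are illustrative, not the actual `StratifiedGroupShuffleSplit` internals):

```python
from collections import Counter

def class_proportions(labels):
    """Map each class to its fraction of the given labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: n / total for cls, n in counts.items()}

def l1_distance(labels, test_idx):
    """L1 distance between the test subset's class mix and the overall mix."""
    overall = class_proportions(labels)
    subset = class_proportions([labels[i] for i in test_idx])
    return sum(abs(frac - subset.get(cls, 0.0)) for cls, frac in overall.items())

def pick_best(labels, candidate_test_sets):
    """Choose the candidate test-index set that best preserves class balance."""
    return min(candidate_test_sets, key=lambda idx: l1_distance(labels, idx))
```

For labels `["a", "a", "b", "b"]`, the candidate test set `[0, 2]` (one "a", one "b") scores 0 and is preferred over `[0, 1]` (two "a"s).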
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the data splitting capabilities of the project by introducing advanced utilities for creating robust train/test splits.

Highlights
Code Review
This pull request introduces `StratifiedGroupShuffleSplit` and an extended `train_test_split` to support grouped and stratified data splitting, which is a valuable addition. I've identified a significant logical flaw in the `StratifiedGroupShuffleSplit` implementation that could lead to repetitive and non-random splits, and have provided a high-severity comment with a suggested fix. Additionally, I've pointed out a misleading docstring, an inefficiency in the `train_test_split` function, and a suggestion to strengthen the tests to prevent similar issues in the future. Please address these points to ensure the reliability of these new features.
Actionable comments posted: 3
🧹 Nitpick comments (3)
tests/test_split.py (1)
3-6: Consider importing from the public API.

The tests import directly from `ratiopath.model_selection.split` (the internal module) rather than `ratiopath.model_selection` (the public re-export). Using the public API in tests helps verify that the re-exports work correctly and aligns with how external users will consume the module.

♻️ Proposed change

```diff
-from ratiopath.model_selection.split import (
+from ratiopath.model_selection import (
     StratifiedGroupShuffleSplit,
     train_test_split,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_split.py` around lines 3 - 6: Tests currently import `StratifiedGroupShuffleSplit` and `train_test_split` directly from the internal module `ratiopath.model_selection.split`; update the import to use the public re-export by importing both from `ratiopath.model_selection` instead, so tests exercise the public API and verify the re-exports for those symbols.

pyproject.toml (1)
40-47: Duplicate test dependencies between `[dependency-groups]` and `[project.optional-dependencies]`.

The same packages (`openslide-bin>=4.0.0.8`, `pytest>=8.4.1`) are declared in both sections. This may be intentional for compatibility with different tooling (PEP 735 dependency-groups vs PEP 508 optional-dependencies), but if so, consider adding a brief comment to clarify the intent and prevent future drift.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@pyproject.toml` around lines 40 - 47: The tests dependencies are duplicated between the dependency-groups `tests = ["openslide-bin>=4.0.0.8", "pytest>=8.4.1"]` and the `[project.optional-dependencies]` `tests = ["openslide-bin>=4.0.0.8", "pytest>=8.4.1"]`; either remove one duplicate or keep both but add a clear inline comment above the duplicated `tests` entry (in either `dependency-groups` or `[project.optional-dependencies]`) stating the reason for duplication (e.g., "kept for PEP 735 compatibility with tooling X") to prevent future drift and make intent explicit when editing the `tests` lists.

ratiopath/model_selection/split.py (1)
16-20: Reliance on private scikit-learn APIs is fragile.

The imports from `sklearn.model_selection._split`, `sklearn.utils._array_api`, `sklearn.utils._indexing`, and `sklearn.utils._param_validation` use underscore-prefixed (private) modules. These internal APIs can change without deprecation warnings between minor releases, potentially breaking this code.

Consider:

- Documenting which sklearn version these internals are compatible with
- Adding integration tests that catch breakage early
- Evaluating if public APIs can achieve the same result (e.g., `check_array` instead of internal utilities)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ratiopath/model_selection/split.py` around lines 16 - 20, The code imports private sklearn internals (GroupsConsumerMixin, _validate_shuffle_split, get_namespace_and_device, move_to, _safe_indexing, Interval, RealNotInt, validate_params, _num_samples, check_random_state, indexable) which is fragile; replace these with public equivalents where possible (e.g., use public model_selection APIs, sklearn.utils.check_array / sklearn.utils.resample / sklearn.utils.validation public helpers, and public param-validation helpers) or wrap usage behind adapter functions in this module so you can swap implementations, and add a short compatibility note in the module docstring stating the tested sklearn version and add integration tests that assert behavior of split-related functions (including group handling and parameter validation) to detect breakage early.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ratiopath/model_selection/split.py`:
- Around line 73-74: Update the misleading docstring that currently claims
"Groups appear exactly once in the test set across all splits" in the splitter's
docstring inside ratiopath.model_selection.split; change the wording to state
that each split independently selects its best fold for stratification and that
groups are not globally tracked across splits (so the same group may appear in
test sets of multiple splits), and reference the splitter class/function
docstring where that sentence appears so future readers understand the actual
behavior.
- Around line 264-271: The code path handling shuffle == False in
ratiopath/model_selection/split.py silently ignores the groups parameter; update
that branch to explicitly error when groups is provided instead of proceeding
(i.e., check if groups is not None and raise a ValueError), or alternatively
document this behavior — but the preferred fix is to raise an error referencing
the same wording style used for stratify (e.g., raise ValueError("Group-based
splitting not supported for shuffle=False")), and keep the rest of the
sequential assignment to train/test (variables train, test, n_train, n_test)
unchanged.
- Around line 133-155: The generator can yield None if cv.split() produced no
folds; update the loop in the split generator (the block using
StratifiedGroupKFold, min_diff, train_index, test_index, self._get_distribution
and flipped) to check after the inner for-loop whether train_index or test_index
is still None and, if so, raise a clear ValueError (or skip yielding) with
context about the input and n_splits rather than yielding None; also guard the
flipped swap so it only runs when train_index/test_index are not None. Ensure
the error message references the provided n_splits, X/y/groups sizes and that it
occurs before the yield.
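The guard described in that last item can be sketched as follows (hypothetical helper; `select_best_fold` and `score` are illustrative names, not the actual code, though `min_diff`, `train_index`, and `test_index` mirror the names in the comment):

```python
def select_best_fold(folds, score):
    """Pick the (train, test) fold with the lowest score; never return None."""
    min_diff = float("inf")
    train_index = test_index = None
    for train_idx, test_idx in folds:
        diff = score(test_idx)
        if diff < min_diff:
            min_diff, train_index, test_index = diff, train_idx, test_idx
    if train_index is None or test_index is None:
        # The cross-validator produced no folds at all: fail loudly here
        # instead of yielding a meaningless (None, None) pair downstream.
        raise ValueError(
            "cross-validator produced no folds; check n_splits and group counts"
        )
    return train_index, test_index
```

With two candidate folds and a score that prefers lower test indices, the second fold wins; with an empty iterator, the function raises instead of returning `None`.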
---
Nitpick comments:
In `@pyproject.toml`:
- Around line 40-47: The tests dependencies are duplicated between the
dependency-groups `tests = ["openslide-bin>=4.0.0.8", "pytest>=8.4.1"]` and the
`[project.optional-dependencies]` `tests = ["openslide-bin>=4.0.0.8",
"pytest>=8.4.1"]`; either remove one duplicate or keep both but add a clear
inline comment above the duplicated `tests` entry (in either `dependency-groups`
or `[project.optional-dependencies]`) stating the reason for duplication (e.g.,
"kept for PEP 735 compatibility with tooling X") to prevent future drift and
make intent explicit when editing the `tests` lists.
In `@ratiopath/model_selection/split.py`:
- Around line 16-20: The code imports private sklearn internals
(GroupsConsumerMixin, _validate_shuffle_split, get_namespace_and_device,
move_to, _safe_indexing, Interval, RealNotInt, validate_params, _num_samples,
check_random_state, indexable) which is fragile; replace these with public
equivalents where possible (e.g., use public model_selection APIs,
sklearn.utils.check_array / sklearn.utils.resample / sklearn.utils.validation
public helpers, and public param-validation helpers) or wrap usage behind
adapter functions in this module so you can swap implementations, and add a
short compatibility note in the module docstring stating the tested sklearn
version and add integration tests that assert behavior of split-related
functions (including group handling and parameter validation) to detect breakage
early.
In `@tests/test_split.py`:
- Around line 3-6: Tests currently import StratifiedGroupShuffleSplit and
train_test_split directly from the internal module
ratiopath.model_selection.split; update the import to use the public re-export
by importing both StratifiedGroupShuffleSplit and train_test_split from
ratiopath.model_selection instead, so tests exercise the public API and verify
the re-exports for those symbols.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (4)
- `pyproject.toml`
- `ratiopath/model_selection/__init__.py`
- `ratiopath/model_selection/split.py`
- `tests/test_split.py`
Actionable comments posted: 3
🧹 Nitpick comments (2)
ratiopath/model_selection/split.py (2)
88-88: Consider making `_default_test_size` configurable or documenting the rationale.

The hardcoded `0.2` differs from sklearn's `train_test_split` default of `0.25`. If this is intentional, consider adding a brief comment explaining why. Alternatively, align with sklearn's default for consistency.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ratiopath/model_selection/split.py` at line 88, The hardcoded attribute _default_test_size = 0.2 should be made explicit or configurable: either change the default to match sklearn (0.25) or add a constructor parameter (e.g., test_size=None or default_test_size=0.2) so callers can override it, and if you keep 0.2 add a one-line comment next to _default_test_size explaining why it differs from sklearn; update any references in methods like split/train_test_split wrappers to use the new parameter or attribute.
16-20: Using private scikit-learn APIs poses a stability risk.

These imports rely on sklearn's internal modules (`_split`, `_array_api`, `_indexing`, `_param_validation`), which are not part of the public API and may change without deprecation warnings between versions. Consider pinning an upper bound on scikit-learn (e.g., `scikit-learn>=1.8.0,<1.9`) in `pyproject.toml` to reduce the risk of breakage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ratiopath/model_selection/split.py` around lines 16 - 20, The code imports private scikit-learn internals (GroupsConsumerMixin, _validate_shuffle_split, get_namespace_and_device, move_to, _safe_indexing, Interval, RealNotInt, validate_params, _num_samples, check_random_state, indexable) which can break across versions; update the project to avoid instability by either switching to public equivalents where available or pinning scikit-learn to a safe range (for example add an upper bound like <1.9 in pyproject.toml such as scikit-learn>=1.8.0,<1.9) and note which private symbols (e.g., _validate_shuffle_split, _safe_indexing, Interval/RealNotInt) still require replacement so we can later refactor to public APIs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pyproject.toml`:
- Line 30: The project depends on internal scikit-learn APIs
(GroupsConsumerMixin and _validate_shuffle_split from
sklearn.model_selection._split) which can break across minor releases; update
the scikit-learn dependency in pyproject.toml to a pinned range such as
"scikit-learn>=1.8.0,<1.9" to avoid accidental upgrades, and either open an
issue with scikit-learn to request stabilizing those utilities or refactor
split.py to stop importing GroupsConsumerMixin and _validate_shuffle_split
(re-implement needed behavior or use public APIs) so the code no longer relies
on private symbols.
In `@ratiopath/model_selection/split.py`:
- Around line 128-129: The code converts y via np.asarray(y) and then calls
self._get_distribution(y) but does not validate y is not None; add an early
validation in the method handling the y parameter (the function that contains y
= np.asarray(y) and calls _get_distribution) to check if y is None or contains
only None-equivalent values and raise a clear ValueError indicating
stratification requires a non-None target. Ensure the check happens before
np.asarray(y) or immediately after it, and reference this validation in error
messages for the caller (use the same function/method name where y =
np.asarray(y) and _get_distribution is invoked).
- Line 99: Add an explicit validation for the parameter named groups (the
annotated parameter "groups: Any = None") at the start of the function where it
is defined: if groups is None raise a ValueError with a clear message like
"groups must be provided for stratified group splitting" so callers get an
immediate, descriptive error; place this check before any use of
StratifiedGroupKFold or downstream logic that assumes groups is present.
---
Nitpick comments:
In `@ratiopath/model_selection/split.py`:
- Line 88: The hardcoded attribute _default_test_size = 0.2 should be made
explicit or configurable: either change the default to match sklearn (0.25) or
add a constructor parameter (e.g., test_size=None or default_test_size=0.2) so
callers can override it, and if you keep 0.2 add a one-line comment next to
_default_test_size explaining why it differs from sklearn; update any references
in methods like split/train_test_split wrappers to use the new parameter or
attribute.
- Around line 16-20: The code imports private scikit-learn internals
(GroupsConsumerMixin, _validate_shuffle_split, get_namespace_and_device,
move_to, _safe_indexing, Interval, RealNotInt, validate_params, _num_samples,
check_random_state, indexable) which can break across versions; update the
project to avoid instability by either switching to public equivalents where
available or pinning scikit-learn to a safe range (for example add an upper
bound like <1.9 in pyproject.toml such as scikit-learn>=1.8.0,<1.9) and note
which private symbols (e.g., _validate_shuffle_split, _safe_indexing,
Interval/RealNotInt) still require replacement so we can later refactor to
public APIs.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (2)
- `pyproject.toml`
- `ratiopath/model_selection/split.py`
Actionable comments posted: 2
♻️ Duplicate comments (1)
ratiopath/model_selection/split.py (1)
96-130: ⚠️ Potential issue | 🟡 Minor

Add explicit `y` and `groups` validation at the top of `split()`.

The public API/docstring requires stratification target and groups, but missing values currently fail later with less direct errors.
🛡️ Proposed fix
```diff
 def split(
     self,
     X: list[str] | MatrixLike,  # noqa: N803
     y: ArrayLike | None = None,
     groups: Any = None,
 ) -> Iterator[Any]:
 @@
-    n_samples = _num_samples(X)
+    if y is None:
+        raise ValueError("`y` must be provided for stratified group splitting.")
+    if groups is None:
+        raise ValueError("`groups` must be provided for stratified group splitting.")
+
+    n_samples = _num_samples(X)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ratiopath/model_selection/split.py` around lines 96 - 130, Add explicit validation at the start of split(): ensure y is not None when stratification is required and groups is not None (raise ValueError with clear message if missing), check that len(y) and len(groups) (after converting y with np.asarray if needed) equal n_samples from _num_samples(X), and raise if they don't match; do this before using y in _get_distribution or any later logic so missing/length-mismatch errors are reported clearly (refer to the split method, variables y and groups, _num_samples, n_samples, and _get_distribution).
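As a standalone illustration of the fail-fast checks this comment asks for, including the length-match requirement (the function name below is hypothetical; the real location is `StratifiedGroupShuffleSplit.split`):

```python
def validate_split_inputs(X, y, groups):
    """Raise descriptive errors before any fold computation begins."""
    if y is None:
        raise ValueError("`y` must be provided for stratified group splitting.")
    if groups is None:
        raise ValueError("`groups` must be provided for stratified group splitting.")
    n_samples = len(X)
    if len(y) != n_samples or len(groups) != n_samples:
        raise ValueError(
            f"Length mismatch: len(X)={n_samples}, "
            f"len(y)={len(y)}, len(groups)={len(groups)}"
        )
```

Calling this first means a caller who forgets `groups` gets an immediate, descriptive error rather than an opaque failure deep inside the fold computation.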
🧹 Nitpick comments (1)
ratiopath/model_selection/split.py (1)
16-20: Isolate scikit-learn private API imports behind a compatibility layer to reduce future breakage risk.

Imports from underscore modules (e.g., `sklearn.model_selection._split`, `sklearn.utils._array_api`, `sklearn.utils._indexing`) and private functions (e.g., `_validate_shuffle_split`, `_num_samples`) are not guaranteed to be stable across sklearn releases. A thin wrapper module (e.g., `ratiopath/_sklearn_compat.py`) centralizing these imports would make version transitions more manageable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ratiopath/model_selection/split.py` around lines 16 - 20, Replace direct private sklearn imports in split.py with a compatibility layer: create ratiopath/_sklearn_compat.py that imports and re-exports GroupsConsumerMixin, _validate_shuffle_split, get_namespace_and_device, move_to, _safe_indexing, Interval, RealNotInt, validate_params, _num_samples, check_random_state, and indexable (and add graceful fallbacks/try/except aliases for different sklearn versions), then change the imports in ratiopath/model_selection/split.py to import those symbols from ratiopath._sklearn_compat instead of from sklearn.* private modules so future sklearn breakages are isolated to the compat module.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@ratiopath/model_selection/split.py`:
- Around line 248-250: Update the docstring for the split function that
documents the shuffle parameter to reflect runtime behavior: state that when
shuffle=False both stratify and groups must be None (not just stratify).
Reference the parameters by name (shuffle, stratify, groups) and update the
sentence currently mentioning only stratify to explicitly mention groups as well
so documentation matches the implemented checks in the function.
- Around line 127-137: n_splits computed as round(n_samples / n_test) can exceed
the number of unique groups and cause StratifiedGroupKFold to raise; before
constructing StratifiedGroupKFold in the loop, compute max_groups =
len(np.unique(groups)) (or 1 if groups is None) and clamp n_splits = max(2,
min(n_splits, max_groups)) (ensuring the lower bound 2 required by
StratifiedGroupKFold), then pass that clamped n_splits into
StratifiedGroupKFold(...) so the fold count never exceeds available group
cardinality.
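The clamping described above can be sketched as a small helper (hypothetical name; the actual code computes `n_splits` inline before constructing `StratifiedGroupKFold`):

```python
def clamp_n_splits(n_samples, n_test, groups):
    """Request at least 2 folds, but never more than there are distinct groups."""
    n_splits = round(n_samples / n_test)
    max_groups = len(set(groups)) if groups is not None else 1
    return max(2, min(n_splits, max_groups))
```

For example, 100 samples with `n_test=10` would naively request 10 folds, but with only 3 distinct groups the result is clamped to 3; a naive request of 1 fold is bumped up to the minimum of 2 that `StratifiedGroupKFold` requires.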
---
Duplicate comments:
In `@ratiopath/model_selection/split.py`:
- Around line 96-130: Add explicit validation at the start of split(): ensure y
is not None when stratification is required and groups is not None (raise
ValueError with clear message if missing), check that len(y) and len(groups)
(after converting y with np.asarray if needed) equal n_samples from
_num_samples(X), and raise if they don't match; do this before using y in
_get_distribution or any later logic so missing/length-mismatch errors are
reported clearly (refer to the split method, variables y and groups,
_num_samples, n_samples, and _get_distribution).
---
Nitpick comments:
In `@ratiopath/model_selection/split.py`:
- Around line 16-20: Replace direct private sklearn imports in split.py with a
compatibility layer: create ratiopath/_sklearn_compat.py that imports and
re-exports GroupsConsumerMixin, _validate_shuffle_split,
get_namespace_and_device, move_to, _safe_indexing, Interval, RealNotInt,
validate_params, _num_samples, check_random_state, and indexable (and add
graceful fallbacks/try/except aliases for different sklearn versions), then
change the imports in ratiopath/model_selection/split.py to import those symbols
from ratiopath._sklearn_compat instead of from sklearn.* private modules so
future sklearn breakages are isolated to the compat module.
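A minimal sketch of what such a shim could look like (assumption: `_num_samples` currently lives in `sklearn.utils.validation`; the fallback keeps the module importable even if the private symbol moves or sklearn is absent):

```python
# ratiopath/_sklearn_compat.py (hypothetical): the single place that touches
# sklearn's private namespace, with a graceful fallback.
try:
    from sklearn.utils.validation import _num_samples as num_samples
except ImportError:  # sklearn missing, or the private symbol relocated
    def num_samples(x):
        """Minimal stand-in: number of samples in a sized container."""
        return len(x)
```

The rest of the package would then import `num_samples` from this module, so a future sklearn reorganization only requires updating the shim, not every call site.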
`StratifiedGroupShuffleSplit` is inspired by the implementation at https://gitlab.ics.muni.cz/rationai/digital-pathology/pathology/prostate-cancer/-/blob/master/preprocessing/stratified_group_split.py?ref_type=heads, but follows the scikit-learn splitting API.
`train_test_split` extends `sklearn.model_selection.train_test_split` with group support.

Possible discussion:
– Should the `model_selection` folder be located inside the `sklearn` folder?

Summary by CodeRabbit
New Features
Chores
Tests