Use scikit-learn's array-api to accelerate StandardScaler #8020

rapids-bot[bot] merged 9 commits into rapidsai:main

Conversation
📝 Walkthrough
Replaces bespoke StandardScaler GPU override logic with an array-API proxy-based approach that delegates to scikit-learn's array-API-enabled StandardScaler (requires scikit-learn >= 1.8); adds a conditional wrapper for sklearn's array-API dispatch, internal-context signaling, a registry entry, and accompanying tests and docs updates.
Actionable comments posted: 1
🧹 Nitpick comments (1)
python/cuml/cuml_accel_tests/integration/test_preprocessing.py (1)
Lines 9-45: Consider adding explicit edge-case coverage in this new integration module. These tests are strong for the main path, but adding empty/single-sample (and one high-dimensional) scenarios would harden regressions around the new proxy path.
As per coding guidelines: "python/**/test_*.py: Test files must validate numerical correctness by comparing with scikit-learn, include edge case coverage (empty datasets, single sample, high-dimensional data), test fit/predict/transform consistency, and test different input types (cuDF, pandas, NumPy)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@python/cuml/cuml_accel_tests/integration/test_preprocessing.py` around lines 9-45, add explicit edge-case tests to the StandardScaler integration tests: extend test_standard_scaler and test_standard_scaler_partial_fit (or add new tests named e.g. test_standard_scaler_empty_single_highdim and test_standard_scaler_different_input_types) to cover empty arrays (0 samples), single-sample inputs (n_samples=1), and a high-dimensional case (n_features >> n_samples), and for each case compare cuml StandardScaler fit/transform/inverse_transform results against scikit-learn's StandardScaler for numerical equality; also add variants using numpy, pandas, and cuDF inputs to verify consistent behavior across input types, and ensure partial_fit updates n_samples_seen_ correctly in the single-sample and incremental scenarios.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@python/cuml/cuml_accel_tests/test_estimator_proxy.py`:
- Around line 783-801: In test_array_api_proxy_partial_fit, the final assertion
is validating the fit path (model._cpu) instead of the partial_fit path; update
the assertion to check model2._cpu for the absence of "n_features_in_" so the
test validates the partial_fit object (change the last line to assert not
hasattr(model2._cpu, "n_features_in_") in the test_array_api_proxy_partial_fit
function).
---
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b412029d-da23-4207-bd0b-f25a8c976fdd
📒 Files selected for processing (11)
- docs/source/cuml-accel/limitations.rst
- python/cuml/cuml/accel/_overrides/sklearn/preprocessing.py
- python/cuml/cuml/accel/_patches/sklearn/utils/__init__.py
- python/cuml/cuml/accel/_patches/sklearn/utils/_array_api.py
- python/cuml/cuml/accel/core.py
- python/cuml/cuml/accel/estimator_proxy.py
- python/cuml/cuml/internals/outputs.py
- python/cuml/cuml_accel_tests/integration/test_preprocessing.py
- python/cuml/cuml_accel_tests/test_core.py
- python/cuml/cuml_accel_tests/test_estimator_proxy.py
- python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml
💤 Files with no reviewable changes (1)
- python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml
jcrist left a comment

Annotating the PR to ease review.
    - If ``sample_weight`` is provided (weighted statistics not supported on GPU).
    - If ``X`` has object dtype, half precision (``float16``) dtype, or complex dtype (``complex64``, ``complex128``).
    - If ``X`` is a sparse matrix with integer dtype or in a format other than CSR or CSC.
    - If ``X`` is sparse
The array-api doesn't support sparse data, so no array-api accelerated estimators will work here. That said, while StandardScaler can support sparse inputs, doing so without with_mean=False would remove the sparsity. It's kind of a weird operation to do on sparse data anyway.
Given that, I'm not concerned about this limitation, and don't think this should prevent us from moving forward with this change.
csadorf replied:

It's a bit unfortunate that we are losing functionality, but I agree that we should move forward here. I think for some pre-processors that make more sense for sparse data (like MaxAbsScaler), we might have to revisit this.
| """partial_fit not supported on GPU - always fall back to CPU.""" | ||
| raise UnsupportedOnGPU("partial_fit not supported on GPU") | ||
| class StandardScaler(ArrayAPIProxyBase): | ||
| _cpu_class_path = "sklearn.preprocessing.StandardScaler" |
Other array-api backed estimators would be easy to wrap this same way. I've limited this PR to just StandardScaler to keep things clean, but the other *Scaler estimators should all work well too I'd think.
        return self._cpu._repr_html_

    class _ArrayAPIWrapper(Base, InteropMixin):
This class makes a scikit-learn array-api estimator look like a cuml estimator enough that it works with ProxyBase.
        return getattr(self._internal_model, name)

    class ArrayAPIProxyBase(ProxyBase):
This is the actual developer-facing API. Subclasses of this will generate a _ArrayAPIWrapper class automatically, and use that as the GPU estimator.
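As a rough, hypothetical sketch of that pattern (only `_cpu_class_path` is a real name from this PR; the base class, methods, and example subclass here are stand-ins), a base class can resolve a dotted class path at subclass-definition time via `__init_subclass__` and delegate construction and attribute access to the wrapped class:

```python
import importlib


class DottedPathProxyBase:
    """Sketch: resolve a subclass's ``_cpu_class_path`` into a real class."""

    _cpu_class_path: str

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        module_path, _, class_name = cls._cpu_class_path.rpartition(".")
        # Import the module and store the resolved class on the subclass.
        cls._internal_class = getattr(importlib.import_module(module_path), class_name)

    def __init__(self, *args, **kwargs):
        # Delegate construction to the wrapped class.
        self._internal_model = self._internal_class(*args, **kwargs)

    def __getattr__(self, name):
        # Forward attribute access (fitted attributes, methods, ...) to the
        # wrapped model.
        return getattr(self._internal_model, name)


class CounterProxy(DottedPathProxyBase):
    # Stand-in for "sklearn.preprocessing.StandardScaler" so this sketch runs
    # without scikit-learn installed.
    _cpu_class_path = "collections.Counter"
```

A subclass declares only the dotted path; everything else falls out of the base class, which is what keeps per-estimator special casing small.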
        if not SKLEARN_18:
            raise UnsupportedOnGPU(
                "scikit-learn >= 1.8 is required to run on GPU"
            )
Here's where we fallback when running on earlier versions of sklearn. This check could be made more flexible if some estimators require different versions.
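A minimal, hypothetical sketch of such a version gate (the `SKLEARN_18` flag name appears in the diff above; the naive parser here is a stand-in for something robust like `packaging.version.parse`):

```python
def _major_minor(version: str) -> tuple[int, int]:
    """Naively parse "1.8.0" -> (1, 8); real code should use packaging.version."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def meets_minimum(version: str, minimum: tuple[int, int] = (1, 8)) -> bool:
    """True if ``version`` satisfies the minimum, e.g. scikit-learn >= 1.8."""
    return _major_minor(version) >= minimum


class UnsupportedOnGPU(Exception):
    """Stand-in for the fallback-signaling exception used in the diff above."""


def check_gpu_support(sklearn_version: str) -> None:
    # Raising here triggers the CPU fallback path for older scikit-learn.
    if not meets_minimum(sklearn_version):
        raise UnsupportedOnGPU("scikit-learn >= 1.8 is required to run on GPU")
```

Making the `minimum` a parameter (rather than a hardcoded flag) is one way the check could become per-estimator later, as the comment suggests.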
csadorf left a comment

Very nice! I like the approach.
        def __init__(self, *args, output_type=None, verbose=False, **kwargs):
            super().__init__(output_type=output_type, verbose=verbose)
            self._internal_model = self._internal_class(*args, **kwargs)

        def __init_subclass__(cls, **kwargs):
            super().__init_subclass__(**kwargs)

            # Store `_internal_class` for ease-of-reference
            cls._internal_class = cls._get_cpu_class()

            # Wrap __init__ to ensure signature compatibility.
            orig_init = cls.__init__

            @functools.wraps(cls._internal_class.__init__)
            def __init__(self, *args, **kwargs):
                orig_init(self, *args, **kwargs)

            cls.__init__ = __init__
Might be worth adding one more comment about the constructor patching. It took me a moment to follow the logic here.
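To unpack the patching with a hypothetical toy example (the `Target`/`Proxy` classes below are stand-ins for the wrapped sklearn estimator and its proxy): `@functools.wraps(...)` copies the wrapped constructor's metadata, including `__wrapped__`, so introspection tools such as `inspect.signature` report the real parameter list instead of a generic `*args, **kwargs`:

```python
import functools
import inspect


class Target:
    """Stand-in for the wrapped sklearn estimator."""

    def __init__(self, copy=True, with_mean=True, with_std=True):
        self.copy, self.with_mean, self.with_std = copy, with_mean, with_std


class Proxy:
    def __init__(self, *args, **kwargs):
        self._internal_model = Target(*args, **kwargs)


# Patch Proxy.__init__ so it advertises Target's signature, mirroring what the
# __init_subclass__ hook in the diff above does for the wrapped sklearn class.
orig_init = Proxy.__init__


@functools.wraps(Target.__init__)
def __init__(self, *args, **kwargs):
    orig_init(self, *args, **kwargs)


Proxy.__init__ = __init__
```

After the patch, `inspect.signature(Proxy.__init__)` reports `(self, copy=True, with_mean=True, with_std=True)`, while construction still routes through the original `*args, **kwargs` initializer.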
/merge
This switches our cuml.accel acceleration of StandardScaler to use scikit-learn's array-api support. This helps us achieve higher compatibility with sklearn, and better expose (and test) their array-api support, all while relying on less code on our end.

We accomplish this by defining a new base class to use for array-api backed proxies (ArrayAPIProxyBase). This base class constructs a wrapper class wrapping the sklearn model in a cuml.Base-compatible class, then rewraps that in a ProxyBase. This lets the array-api compatible models fit into the existing cuml-accel framework with limited special casing.

In the case of StandardScaler I found only scikit-learn >= 1.8 worked successfully (earlier versions had bugs). For now I've hardcoded that as the minimum version. I suspect when we expand this to other estimators we might find we need this version check to be more flexible, but we can handle that then. When running with earlier versions of scikit-learn, the profiler and logger will flag the estimator as unaccelerated and note the scikit-learn version as the reason. I think that's decent enough UX.

Fixes #7841.