
Use scikit-learn's array-api to accelerate StandardScaler#8020

Merged
rapids-bot[bot] merged 9 commits into rapidsai:main from jcrist:array-api-dispatch-preprocessing
Apr 28, 2026

Conversation

@jcrist
Member

@jcrist jcrist commented Apr 28, 2026

This switches our cuml.accel acceleration of StandardScaler to use scikit-learn's array-api support. This helps us achieve higher compatibility with sklearn and better exercise (and test) their array-api support, all while maintaining less code on our end.

We accomplish this by defining a new base class for array-api backed proxies (ArrayAPIProxyBase). This base class generates a wrapper class that makes the sklearn model look like a cuml.Base-compatible estimator, then wraps that in a ProxyBase. This lets the array-api compatible models fit into the existing cuml-accel framework with limited special casing.
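The pattern can be sketched roughly like this. This is a standalone illustration of the idea, not the actual cuml code; `CounterProxy` wraps a stdlib class as a stand-in so the sketch runs without sklearn installed:

```python
# Rough standalone sketch of the ArrayAPIProxyBase idea described above.
# The real cuml implementation additionally wraps the model in a
# cuml.Base-compatible class and routes calls through ProxyBase with
# sklearn's array-api dispatch enabled.
import importlib


class ArrayAPIProxyBase:
    """Subclasses declare `_cpu_class_path`; wrapping happens automatically."""

    _cpu_class_path: str

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Resolve the wrapped class from its dotted path at subclass time.
        module, _, name = cls._cpu_class_path.rpartition(".")
        cls._internal_class = getattr(importlib.import_module(module), name)

    def __init__(self, *args, **kwargs):
        self._internal_model = self._internal_class(*args, **kwargs)

    def __getattr__(self, name):
        # Forward everything else (methods, fitted attributes) to the
        # wrapped model.
        return getattr(self._internal_model, name)


# With the real base class this would look like
#   class StandardScaler(ArrayAPIProxyBase):
#       _cpu_class_path = "sklearn.preprocessing.StandardScaler"
# Here we wrap a stdlib class so the sketch is runnable anywhere.
class CounterProxy(ArrayAPIProxyBase):
    _cpu_class_path = "collections.Counter"
```

A subclass only needs the dotted path; construction and attribute access are handled by the base class.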

In the case of StandardScaler I found only scikit-learn >= 1.8 worked successfully (earlier versions had bugs). For now I've hardcoded that as the minimum version. I suspect when we expand this to other estimators we might find we need this version check to be more flexible, but we can handle that then. When running with earlier versions of scikit-learn, the profiler and logger will flag the estimator as unaccelerated and note the scikit-learn version as the reason for that. I think that's decent enough UX.

Fixes #7841.

@jcrist jcrist self-assigned this Apr 28, 2026
@jcrist jcrist requested a review from a team as a code owner April 28, 2026 17:14
@jcrist jcrist added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Apr 28, 2026
@jcrist jcrist requested a review from divyegala April 28, 2026 17:14
@jcrist jcrist added the cuml-accel Issues related to cuml.accel label Apr 28, 2026
@github-actions github-actions Bot added the Cython / Python Cython or Python issue label Apr 28, 2026
@jcrist jcrist requested a review from csadorf April 28, 2026 17:15
@coderabbitai

coderabbitai Bot commented Apr 28, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 5092870d-148c-460b-a60f-054c89da94d9

📥 Commits

Reviewing files that changed from the base of the PR and between 130d475 and 77f4c98.

📒 Files selected for processing (2)
  • python/cuml/cuml_accel_tests/test_pipeline.py
  • python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml

📝 Walkthrough

Summary by CodeRabbit

  • New Features

    • StandardScaler now leverages Array API support for enhanced compatibility when using scikit-learn ≥ 1.8
    • Improved partial_fit handling with better state management
  • Documentation

    • Updated limitations documentation to clarify acceleration requirements and CPU fallback conditions
  • Tests

    • Added comprehensive integration tests for StandardScaler operations
    • Enhanced compatibility testing for Array API-enabled estimators

Walkthrough

Replaces bespoke StandardScaler GPU override logic with an array‑API proxy-based approach that delegates to scikit-learn's array‑API-enabled StandardScaler (requires scikit-learn >= 1.8); adds a conditional wrapper for sklearn's array‑api dispatch, internal-context signaling, registry entry, and accompanying tests and docs updates.

Changes

Cohort / File(s) Summary
Documentation
docs/source/cuml-accel/limitations.rst
Documented requirement that cuml.accel acceleration for some estimators depends on scikit-learn's array‑API support (sklearn >= 1.8); tightened StandardScaler CPU-fallback triggers to sparse X and sklearn < 1.8; added array‑API Sphinx target.
StandardScaler override
python/cuml/cuml/accel/_overrides/sklearn/preprocessing.py
Removed custom GPU input-compat checks and bespoke fit/partial_fit overrides; StandardScaler now subclasses ArrayAPIProxyBase with _cpu_class_path="sklearn.preprocessing.StandardScaler".
Proxy & wrapper core
python/cuml/cuml/accel/estimator_proxy.py
Added ArrayAPIProxyBase and _ArrayAPIWrapper; treated partial_fit as fit-like; extended proxy exports and introduced array‑API proxy dispatch path that uses sklearn's array‑api mode and cuML input coercion.
Array‑API patch
python/cuml/cuml/accel/_patches/sklearn/utils/_array_api.py
Added wrapper for sklearn's _check_array_api_dispatch that skips calling upstream dispatch when running inside cuML internal context.
Registry
python/cuml/cuml/accel/core.py
Registered sklearn.utils._array_api in accel patch registry so the local patch is applied.
Internal context helper
python/cuml/cuml/internals/outputs.py
Exported new public in_internal_context() and updated module __all__ to include it.
Tests — integration & proxy
python/cuml/cuml_accel_tests/integration/test_preprocessing.py, python/cuml/cuml_accel_tests/test_estimator_proxy.py, python/cuml/cuml_accel_tests/test_pipeline.py
Added integration tests for StandardScaler (fit/transform/inverse_transform/partial_fit); extended proxy tests for array‑API behavior (pandas I/O, partial_fit accumulation, coercion, pickle, sklearn < 1.8 fallback); conditional skips based on sklearn >= 1.8.
Tests — core & upstream xfails
python/cuml/cuml_accel_tests/test_core.py, python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml
Filtered proxy subclass discovery to only those with _gpu_class; removed several StandardScaler xfail entries and added a targeted xfail for test_scaler_2d_arrays under sklearn >= 1.8.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • csadorf
  • divyegala
  • dantegd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage ⚠️ Warning: Docstring coverage is 21.95%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Title check ✅ Passed: The title accurately describes the main change: switching StandardScaler acceleration to use scikit-learn's array-api support, which is the primary objective of this PR.
Description check ✅ Passed: The description is well-detailed and directly related to the changeset, explaining the motivation, implementation approach with ArrayAPIProxyBase, version constraints, and user experience considerations.
Linked Issues check ✅ Passed: The PR fully addresses issue #7841 objectives: it uses upstream sklearn array-api implementations, provides a shim (ArrayAPIProxyBase) to enable array-api routing, and achieves reduced maintenance burden with improved sklearn compatibility.
Out of Scope Changes check ✅ Passed: All changes align with the stated objective of leveraging sklearn's array-api for StandardScaler acceleration; no unrelated modifications are present.



@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
python/cuml/cuml_accel_tests/integration/test_preprocessing.py (1)

9-45: Consider adding explicit edge-case coverage in this new integration module.

These tests are strong for the main path, but adding empty/single-sample (and one high-dimensional) scenarios would harden regressions around the new proxy path.

As per coding guidelines "python/**/test_*.py: Test files must validate numerical correctness by comparing with scikit-learn, include edge case coverage (empty datasets, single sample, high-dimensional data), test fit/predict/transform consistency, and test different input types (cuDF, pandas, NumPy)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/cuml/cuml_accel_tests/integration/test_preprocessing.py` around lines
9 - 45, Add explicit edge-case tests to the StandardScaler integration tests:
extend test_standard_scaler and test_standard_scaler_partial_fit (or add new
tests named e.g., test_standard_scaler_empty_single_highdim and
test_standard_scaler_different_input_types) to cover empty arrays (0 samples),
single-sample inputs (n_samples=1), and a high-dimensional case (n_features >>
n_samples), and for each case compare cuml StandardScaler
fit/transform/inverse_transform results against scikit-learn’s StandardScaler
for numerical equality; also add variants using numpy, pandas, and cuDF inputs
to verify consistent behavior across input types and ensure partial_fit updates
n_samples_seen_ correctly in the single-sample and incremental scenarios.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/cuml/cuml_accel_tests/test_estimator_proxy.py`:
- Around line 783-801: In test_array_api_proxy_partial_fit, the final assertion
is validating the fit path (model._cpu) instead of the partial_fit path; update
the assertion to check model2._cpu for the absence of "n_features_in_" so the
test validates the partial_fit object (change the last line to assert not
hasattr(model2._cpu, "n_features_in_") in the test_array_api_proxy_partial_fit
function).


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b412029d-da23-4207-bd0b-f25a8c976fdd

📥 Commits

Reviewing files that changed from the base of the PR and between ae8b901 and 130d475.

📒 Files selected for processing (11)
  • docs/source/cuml-accel/limitations.rst
  • python/cuml/cuml/accel/_overrides/sklearn/preprocessing.py
  • python/cuml/cuml/accel/_patches/sklearn/utils/__init__.py
  • python/cuml/cuml/accel/_patches/sklearn/utils/_array_api.py
  • python/cuml/cuml/accel/core.py
  • python/cuml/cuml/accel/estimator_proxy.py
  • python/cuml/cuml/internals/outputs.py
  • python/cuml/cuml_accel_tests/integration/test_preprocessing.py
  • python/cuml/cuml_accel_tests/test_core.py
  • python/cuml/cuml_accel_tests/test_estimator_proxy.py
  • python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml
💤 Files with no reviewable changes (1)
  • python/cuml/cuml_accel_tests/upstream/scikit-learn/xfail-list.yaml

Comment thread python/cuml/cuml_accel_tests/test_estimator_proxy.py
Member Author

@jcrist jcrist left a comment


Annotating the PR to ease review.

- If ``sample_weight`` is provided (weighted statistics not supported on GPU).
- If ``X`` has object dtype, half precision (``float16``) dtype, or complex dtype (``complex64``, ``complex128``).
- If ``X`` is a sparse matrix with integer dtype or in a format other than CSR or CSC.
- If ``X`` is sparse
Member Author


The array-api doesn't support sparse data, so no array-api accelerated estimators will work here. That said, while StandardScaler can support sparse inputs, it requires with_mean=False to do so, since centering would destroy the sparsity. It's kind of a weird operation to do on sparse data anyway.

Given that, I'm not concerned about this limitation, and don't think this should prevent us from moving forward with this change.

Contributor


It's a bit unfortunate that we are losing functionality, but I agree that we should move forward here. I think for some pre-processors that make more sense for sparse data (like MaxAbsScaler), we might have to revisit this.

"""partial_fit not supported on GPU - always fall back to CPU."""
raise UnsupportedOnGPU("partial_fit not supported on GPU")
class StandardScaler(ArrayAPIProxyBase):
_cpu_class_path = "sklearn.preprocessing.StandardScaler"
Member Author


Other array-api backed estimators would be easy to wrap this same way. I've limited this PR to just StandardScaler to keep things clean, but the other *Scaler estimators should all work well too I'd think.

Comment thread python/cuml/cuml/accel/_patches/sklearn/utils/_array_api.py
return self._cpu._repr_html_


class _ArrayAPIWrapper(Base, InteropMixin):
Member Author


This class makes a scikit-learn array-api estimator look like a cuml estimator enough that it works with ProxyBase.

return getattr(self._internal_model, name)


class ArrayAPIProxyBase(ProxyBase):
Member Author


This is the actual developer-facing API. Subclasses of this will generate a _ArrayAPIWrapper class automatically, and use that as the GPU estimator.

Comment on lines +708 to +711
if not SKLEARN_18:
raise UnsupportedOnGPU(
"scikit-learn >= 1.8 is required to run on GPU"
)
Member Author


Here's where we fall back when running on earlier versions of sklearn. This check could be made more flexible if some estimators require different versions.

Contributor

@csadorf csadorf left a comment


Very nice! I like the approach.


Comment thread python/cuml/cuml/accel/_patches/sklearn/utils/_array_api.py
Comment on lines +607 to +624
def __init__(self, *args, output_type=None, verbose=False, **kwargs):
super().__init__(output_type=output_type, verbose=verbose)
self._internal_model = self._internal_class(*args, **kwargs)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

# Store `_internal_class` for ease-of-reference
cls._internal_class = cls._get_cpu_class()

# Wrap __init__ to ensure signature compatibility.
orig_init = cls.__init__

@functools.wraps(cls._internal_class.__init__)
def __init__(self, *args, **kwargs):
orig_init(self, *args, **kwargs)

cls.__init__ = __init__
Contributor


Might be worth adding one more comment about the constructor patching. It took me a moment to follow the logic here.
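For readers following along, the trick under discussion can be shown in isolation: `functools.wraps` sets `__wrapped__` on the new `__init__`, so `inspect.signature` reports the inner class's parameters, which sklearn-style introspection (e.g. `get_params`) relies on. A standalone sketch with made-up class names:

```python
# Standalone illustration of the constructor patching; Inner/Proxy are
# hypothetical stand-ins for the wrapped sklearn class and the proxy.
import functools
import inspect


class Inner:
    def __init__(self, with_mean=True, with_std=True):
        self.with_mean = with_mean
        self.with_std = with_std


class Proxy:
    _internal_class = Inner

    def __init__(self, *args, **kwargs):
        self._internal_model = self._internal_class(*args, **kwargs)


# The patch: rebind __init__ so its reported signature matches
# Inner.__init__ while the original body still runs.
orig_init = Proxy.__init__


@functools.wraps(Inner.__init__)
def __init__(self, *args, **kwargs):
    orig_init(self, *args, **kwargs)


Proxy.__init__ = __init__
```

Without the patch, `inspect.signature(Proxy.__init__)` would report only `(*args, **kwargs)`, hiding the estimator's real parameters.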

@jcrist
Member Author

jcrist commented Apr 28, 2026

/merge

@rapids-bot rapids-bot Bot merged commit c811e80 into rapidsai:main Apr 28, 2026
93 checks passed


Development

Successfully merging this pull request may close these issues.

Use array-api to wrap preprocessors in cuml.accel

3 participants