Skip to content
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a025668
Add support for Pipeline in cuml.accel
csadorf Feb 6, 2026
c73fc7f
Add basic pipeline unit test
csadorf Feb 9, 2026
0a72031
Support both accelerated and non-accelerated steps in the same pipeline.
csadorf Feb 9, 2026
b1ef357
refactor pipeline.py to deduplicate
csadorf Feb 9, 2026
0189bc7
Refactor to split pipeline implementation and proxy wrapper.
csadorf Feb 9, 2026
98263a5
patch collection logic
csadorf Feb 9, 2026
b7d6313
attempt to fix transformer tags complain
csadorf Feb 9, 2026
07e4c13
implement __getitem__
csadorf Feb 10, 2026
3ff2f08
fix transform / fit_transform availability
csadorf Feb 10, 2026
6414fd1
fix fit_predict availability
csadorf Feb 10, 2026
cefaf85
Propagate underlying cause for _cpu_has check
csadorf Feb 10, 2026
a16dc83
update xfail list with remaining failures
csadorf Feb 10, 2026
9e3cd07
make remaining methods conditional
csadorf Feb 10, 2026
f203cc3
patch sklearn test_construct_instances
csadorf Feb 10, 2026
8a2012a
Expand xfail_manager to support test removal
csadorf Feb 10, 2026
cba5a00
remove passing tests from xfail list
csadorf Feb 10, 2026
3447e4c
preserve doc-strings
csadorf Feb 11, 2026
29007e9
update xfail list
csadorf Feb 18, 2026
7220ab8
Refactor to split pipeline acceleration and wrapping.
csadorf Feb 18, 2026
ffe1911
ProxyBase supports conditional methods
csadorf Feb 18, 2026
eb80699
Only call .get() on cupy arrays.
csadorf Feb 18, 2026
b607066
do not convert sparse matrixes
csadorf Feb 18, 2026
312b6e6
Implement Pipeline __len__
csadorf Feb 18, 2026
e1b1d9d
Implement inefficient hack to support dataframe containers.
csadorf Feb 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion python/cuml/cuml/accel/_sklearn_patch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
import warnings
from collections import defaultdict
from operator import itemgetter

Expand Down Expand Up @@ -39,6 +40,34 @@ def _patched_all_estimators(*args, **kwargs):
return sorted(set(estimators), key=itemgetter(0))


def _construct_instances_for_proxy(init_params, skipped):
"""Return a _construct_instances implementation that looks up INIT_PARAMS by proxy or _cpu_class."""

def _construct_instances(Estimator):
if Estimator in skipped:
msg = f"Can't instantiate estimator {Estimator.__name__}"
from sklearn.exceptions import SkipTestWarning
from sklearn.utils._testing import SkipTest

warnings.warn(msg, SkipTestWarning)
raise SkipTest(msg)
key = (
Estimator
if Estimator in init_params
else getattr(Estimator, "_cpu_class", None)
)
if key is not None and key in init_params:
param_sets = init_params[key]
if not isinstance(param_sets, list):
param_sets = [param_sets]
for params in param_sets:
yield Estimator(**params)
else:
yield Estimator()

return _construct_instances


def apply_sklearn_patches():
"""Apply all sklearn patches necessary for the accelerator testing."""

Expand All @@ -56,3 +85,17 @@ def apply_sklearn_patches():
import sklearn.utils

sklearn.utils.all_estimators = _patched_all_estimators

# Patch _construct_instances so INIT_PARAMS lookup works for proxy classes.
# INIT_PARAMS is keyed by the class at import time; all_estimators() yields
# proxy classes, so "Estimator in INIT_PARAMS" can be False. Look up by
# _cpu_class when the estimator is a proxy so we use the same param sets
# (e.g. Pipeline(steps=...) instead of Pipeline()).
try:
from sklearn.utils._test_common import instance_generator
except ImportError:
return
instance_generator._construct_instances = _construct_instances_for_proxy(
instance_generator.INIT_PARAMS,
instance_generator.SKIPPED_ESTIMATORS,
)
Loading
Loading