Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/cuml/cuml/accel/_sklearn_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections import defaultdict
from operator import itemgetter

from sklearn.utils._test_common.instance_generator import _construct_instances
from sklearn.utils.discovery import all_estimators as sklearn_all_estimators

from cuml.accel.estimator_proxy import is_proxy
Expand All @@ -28,6 +29,7 @@ def _patched_all_estimators(*args, **kwargs):
that filters out duplicate estimator names, keeping only proxy estimators
when both proxy and non-proxy versions exist.
"""

# Obtain the list of all estimators from sklearn
ret = sklearn_all_estimators(*args, **kwargs)

Expand All @@ -50,6 +52,12 @@ def _patched_all_estimators(*args, **kwargs):
return sorted(set(estimators), key=itemgetter(0))


def _patched_tested_estimators(*args, **kwargs):
for _, Estimator in _patched_all_estimators(*args, **kwargs):
for estimator in _construct_instances(Estimator):
yield estimator


def apply_sklearn_patches():
"""Apply all sklearn patches necessary for the accelerator testing."""

Expand All @@ -64,6 +72,8 @@ def apply_sklearn_patches():
# The patch filters out duplicates by keeping only proxy estimators when
# multiple classes with the same name exist, ensuring test collection
# succeeds.
import sklearn.utils
import sklearn.tests.test_common

sklearn.utils.discovery.all_estimators = _patched_all_estimators
sklearn.utils.all_estimators = _patched_all_estimators
sklearn.tests.test_common._tested_estimators = _patched_tested_estimators