Skip to content

Commit a060d89

Browse files
authored
Add cuml.accel.is_proxy (#6559)
This adds a new utility for checking if a class or instance is an accelerated proxy object. Useful for easing introspection when working with and debugging `cuml.accel`. Authors: - Jim Crist-Harif (https://github.com/jcrist) Approvers: - Simon Adorf (https://github.com/csadorf) URL: #6559
1 parent 59ff097 commit a060d89

3 files changed

Lines changed: 25 additions & 1 deletion

File tree

python/cuml/cuml/accel/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
#
1616

1717
from cuml.accel.core import enabled, install
18+
from cuml.accel.estimator_proxy import is_proxy
1819
from cuml.accel.magics import load_ipython_extension
1920
from cuml.accel.pytest_plugin import pytest_load_initial_conftests
2021

2122
__all__ = (
22-
"install",
2323
"enabled",
24+
"install",
25+
"is_proxy",
2426
"load_ipython_extension",
2527
"pytest_load_initial_conftests",
2628
)

python/cuml/cuml/accel/estimator_proxy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818
import inspect
1919

2020

21+
def is_proxy(instance_or_class) -> bool:
22+
"""Check if an instance or class is a proxy object created by the accelerator."""
23+
24+
if isinstance(instance_or_class, type):
25+
cls = instance_or_class
26+
else:
27+
cls = type(instance_or_class)
28+
return issubclass(cls, ProxyMixin)
29+
30+
2131
def reconstruct_proxy(proxy_module, proxy_name, state):
2232
module = importlib.import_module(proxy_module)
2333
cls = getattr(module, proxy_name)

python/cuml/cuml_accel_tests/test_estimator_proxy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@
2222
from sklearn.decomposition import PCA, TruncatedSVD
2323
from sklearn.neighbors import NearestNeighbors
2424

25+
from cuml.accel import is_proxy
26+
27+
28+
def test_is_proxy():
29+
class Foo:
30+
pass
31+
32+
assert is_proxy(PCA)
33+
assert is_proxy(PCA())
34+
assert not is_proxy(Foo)
35+
assert not is_proxy(Foo())
36+
2537

2638
def test_meta_attributes():
2739
# Check that the proxy estimator pretends to look like the

0 commit comments

Comments
 (0)