Skip to content

Commit a251b2f

Browse files
committed
Delay some sklearn imports
After this commit, `import cuml` should only import core parts of sklearn, such as `sklearn.base`. Anything algorithm-specific is deferred until first use, which speeds up import time.
1 parent 32a2e39 commit a251b2f

4 files changed

Lines changed: 22 additions & 15 deletions

File tree

python/cuml/cuml/explainer/kernel_shap.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ from random import randint
2020

2121
import cupy as cp
2222
import numpy as np
23-
from sklearn.linear_model import LassoLarsIC, lars_path
2423

2524
from cuml.explainer.base import SHAPBase
2625
from cuml.explainer.common import get_cai_ptr, model_func_call
@@ -565,6 +564,7 @@ def _l1_regularization(X,
565564
"""
566565
Function calls LASSO or LARS if l1 regularization is needed.
567566
"""
567+
from sklearn.linear_model import LassoLarsIC, lars_path
568568

569569
# create augmented dataset for feature selection
570570
s = cp.sum(X, axis=1)

python/cuml/cuml/explainer/tree_shap.pyx

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
# limitations under the License.
1515
#
1616

17+
import cuml
1718
from cuml.common import input_to_cuml_array
18-
from cuml.ensemble import RandomForestClassifier as curfc
19-
from cuml.ensemble import RandomForestRegressor as curfr
2019
from cuml.internals.array import CumlArray
2120
from cuml.internals.input_utils import determine_array_type
2221

2322
from cuml.internals.treelite cimport *
23+
2424
from cuml.internals.treelite import safe_treelite_call
2525

2626
from libc.stdint cimport uintptr_t
@@ -29,8 +29,9 @@ import re
2929

3030
import numpy as np
3131
import treelite
32-
from sklearn.ensemble import RandomForestClassifier as sklrfc
33-
from sklearn.ensemble import RandomForestRegressor as sklrfr
32+
33+
# from sklearn.ensemble import RandomForestClassifier as sklrfc
34+
# from sklearn.ensemble import RandomForestRegressor as sklrfr
3435

3536

3637
cdef extern from "cuml/explainer/tree_shap.hpp" namespace "ML::Explainer" nogil:
@@ -178,15 +179,20 @@ cdef class TreeExplainer:
178179
model = model.booster_
179180
tl_model = treelite.frontend.from_lightgbm(model)
180181
# cuML RF model object
181-
elif isinstance(model, (curfr, curfc)):
182+
elif isinstance(model, (cuml.RandomForestClassifier, cuml.RandomForestRegressor)):
182183
tl_model = model.convert_to_treelite_model()
183184
# scikit-learn RF model object
184-
elif isinstance(model, (sklrfr, sklrfc)):
185-
tl_model = treelite.sklearn.import_model(model)
186185
elif isinstance(model, treelite.Model):
187186
tl_model = model
188187
else:
189-
raise ValueError(f"Unrecognized model object type: {type(model)}")
188+
from sklearn.ensemble import (
189+
RandomForestClassifier,
190+
RandomForestRegressor,
191+
)
192+
if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
193+
tl_model = treelite.sklearn.import_model(model)
194+
else:
195+
raise ValueError(f"Unrecognized model object type: {type(model)}")
190196

191197
# Get num_class
192198
self.num_class = tl_model.get_header_accessor().get_field("num_class").copy()

python/cuml/cuml/internals/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import os
1919

2020
import pylibraft.common.handle
21-
from sklearn.utils import estimator_html_repr
2221

2322
import cuml
2423
import cuml.common
@@ -458,10 +457,11 @@ def _more_tags(self):
458457

459458
def _repr_mimebundle_(self, **kwargs):
460459
"""Prepare representations used by jupyter kernels to display estimator"""
461-
if estimator_html_repr is not None:
462-
output = {"text/plain": repr(self)}
463-
output["text/html"] = estimator_html_repr(self)
464-
return output
460+
from sklearn.utils import estimator_html_repr
461+
462+
output = {"text/plain": repr(self)}
463+
output["text/html"] = estimator_html_repr(self)
464+
return output
465465

466466
def set_nvtx_annotations(self):
467467
for func_name in [

python/cuml/cuml/random_projection/random_projection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import cupyx.scipy.sparse as cp_sp
1616
import numpy as np
1717
import scipy.sparse as sp
18-
import sklearn.random_projection
1918

2019
import cuml
2120
from cuml.common.array_descriptor import CumlArrayDescriptor
@@ -51,6 +50,8 @@ def johnson_lindenstrauss_min_dim(n_samples, eps=0.1):
5150
The minimal number of components to guarantee with good probability
5251
an eps-embedding with n_samples.
5352
"""
53+
import sklearn.random_projection
54+
5455
return sklearn.random_projection.johnson_lindenstrauss_min_dim(
5556
n_samples, eps=eps
5657
)

0 commit comments

Comments
 (0)