Skip to content

Commit f60b5f0

Browse files
authored
CPU/GPU interop with RandomForest (#6175)
First version for CPU/GPU interop with RandomForest. Note. This feature requires latest Treelite. Authors: - Philip Hyunsu Cho (https://github.com/hcho3) - Dante Gama Dessavre (https://github.com/dantegd) Approvers: - William Hicks (https://github.com/wphicks) URL: #6175
1 parent f67b426 commit f60b5f0

8 files changed

Lines changed: 677 additions & 14 deletions

File tree

python/cuml/cuml/ensemble/randomforest_common.pyx

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -13,7 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import threading
17+
import treelite.sklearn
1618
from cuml.internals.safe_imports import gpu_only_import
19+
from cuml.internals.api_decorators import device_interop_preparation
20+
from cuml.internals.global_settings import GlobalSettings
21+
1722
cp = gpu_only_import('cupy')
1823
import math
1924
import warnings
@@ -24,7 +29,7 @@ np = cpu_only_import('numpy')
2429
from cuml import ForestInference
2530
from cuml.fil.fil import TreeliteModel
2631
from pylibraft.common.handle import Handle
27-
from cuml.internals.base import Base
32+
from cuml.internals.base import UniversalBase
2833
from cuml.internals.array import CumlArray
2934
from cuml.common.exceptions import NotFittedError
3035
import cuml.internals
@@ -39,7 +44,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
3944
from cuml.prims.label.classlabels import make_monotonic, check_labels
4045

4146

42-
class BaseRandomForestModel(Base):
47+
class BaseRandomForestModel(UniversalBase):
4348
_param_names = ['n_estimators', 'max_depth', 'handle',
4449
'max_features', 'n_bins',
4550
'split_criterion', 'min_samples_leaf',
@@ -67,6 +72,7 @@ class BaseRandomForestModel(Base):
6772

6873
classes_ = CumlArrayDescriptor()
6974

75+
@device_interop_preparation
7076
def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
7177
max_depth=16, handle=None, max_features='sqrt', n_bins=128,
7278
bootstrap=True,
@@ -88,7 +94,7 @@ class BaseRandomForestModel(Base):
8894
"class_weight": class_weight}
8995

9096
for key, vals in sklearn_params.items():
91-
if vals:
97+
if vals and not GlobalSettings().accelerator_active:
9298
raise TypeError(
9399
" The Scikit-learn variable ", key,
94100
" is not supported in cuML,"
@@ -97,7 +103,7 @@ class BaseRandomForestModel(Base):
97103
"api.html#random-forest) for more information")
98104

99105
for key in kwargs.keys():
100-
if key not in self._param_names:
106+
if key not in self._param_names and not GlobalSettings().accelerator_active:
101107
raise TypeError(
102108
" The variable ", key,
103109
" is not supported in cuML,"
@@ -154,6 +160,7 @@ class BaseRandomForestModel(Base):
154160
self.model_pbuf_bytes = bytearray()
155161
self.treelite_handle = None
156162
self.treelite_serialized_model = None
163+
self._cpu_model_class_lock = threading.RLock()
157164

158165
def _get_max_feat_val(self) -> float:
159166
if isinstance(self.max_features, int):
@@ -268,6 +275,24 @@ class BaseRandomForestModel(Base):
268275
self.treelite_handle = <uintptr_t> tl_handle
269276
return self.treelite_handle
270277

278+
def cpu_to_gpu(self):
279+
tl_model = treelite.sklearn.import_model(self._cpu_model)
280+
self._temp = TreeliteModel.from_treelite_bytes(tl_model.serialize_bytes())
281+
self.treelite_serialized_model = treelite_serialize(self._temp.handle)
282+
self._obtain_treelite_handle()
283+
self.dtype = np.float64
284+
self.update_labels = False
285+
super().cpu_to_gpu()
286+
287+
def gpu_to_cpu(self):
288+
self._obtain_treelite_handle()
289+
tl_model = TreeliteModel.from_treelite_model_handle(
290+
self.treelite_handle,
291+
take_handle_ownership=False)
292+
tl_bytes = tl_model.to_treelite_bytes()
293+
tl_model2 = treelite.Model.deserialize_bytes(tl_bytes)
294+
self._cpu_model = treelite.sklearn.export_model(tl_model2)
295+
271296
@cuml.internals.api_base_return_generic(set_output_type=True,
272297
set_n_features_in=True,
273298
get_output_type=False)

python/cuml/cuml/ensemble/randomforestclassifier.pyx

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
# limitations under the License.
1616
#
1717

18+
1819
# distutils: language = c++
20+
import sys
21+
import threading
22+
23+
from cuml.internals.api_decorators import device_interop_preparation
24+
from cuml.internals.api_decorators import enable_device_interop
1925
from cuml.internals.safe_imports import (
2026
cpu_only_import,
2127
gpu_only_import,
@@ -29,7 +35,9 @@ rmm = gpu_only_import('rmm')
2935

3036
from cuml.internals.array import CumlArray
3137
from cuml.internals.mixins import ClassifierMixin
38+
from cuml.internals.global_settings import GlobalSettings
3239
import cuml.internals
40+
from cuml.internals import logger
3341
from cuml.common.doc_utils import generate_docstring
3442
from cuml.common.doc_utils import insert_into_docstring
3543
from cuml.common import input_to_cuml_array
@@ -248,6 +256,22 @@ class RandomForestClassifier(BaseRandomForestModel,
248256
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_.
249257
"""
250258

259+
_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestClassifier'
260+
261+
_hyperparam_interop_translator = {
262+
"criterion": "NotImplemented",
263+
"oob_score": {
264+
True: "NotImplemented",
265+
},
266+
"max_depth": {
267+
None: 16,
268+
},
269+
"max_samples": {
270+
None: 1.0,
271+
},
272+
}
273+
274+
@device_interop_preparation
251275
def __init__(self, *, split_criterion=0, handle=None, verbose=False,
252276
output_type=None,
253277
**kwargs):
@@ -292,6 +316,10 @@ class RandomForestClassifier(BaseRandomForestModel,
292316
state["treelite_handle"] = None
293317
state["split_criterion"] = self.split_criterion
294318
state["handle"] = self.handle
319+
320+
if "_cpu_model_class_lock" in state:
321+
del state["_cpu_model_class_lock"]
322+
295323
return state
296324

297325
def __setstate__(self, state):
@@ -314,6 +342,7 @@ class RandomForestClassifier(BaseRandomForestModel,
314342

315343
self.treelite_serialized_model = state["treelite_serialized_model"]
316344
self.__dict__.update(state)
345+
self._cpu_model_class_lock = threading.RLock()
317346

318347
def __del__(self):
319348
self._reset_forest_data()
@@ -338,6 +367,9 @@ class RandomForestClassifier(BaseRandomForestModel,
338367
self.treelite_serialized_model = None
339368
self.n_cols = None
340369

370+
def get_attr_names(self):
371+
return []
372+
341373
def convert_to_treelite_model(self):
342374
"""
343375
Converts the cuML RF model to a Treelite model
@@ -418,6 +450,7 @@ class RandomForestClassifier(BaseRandomForestModel,
418450
@cuml.internals.api_base_return_any(set_output_type=False,
419451
set_output_dtype=True,
420452
set_n_features_in=False)
453+
@enable_device_interop
421454
def fit(self, X, y, convert_dtype=True):
422455
"""
423456
Perform Random Forest Classification on the input data
@@ -429,7 +462,6 @@ class RandomForestClassifier(BaseRandomForestModel,
429462
y to be of dtype int32. This will increase memory used for
430463
the method.
431464
"""
432-
433465
X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y,
434466
convert_dtype)
435467
# Track the labels to see if update is necessary
@@ -556,6 +588,7 @@ class RandomForestClassifier(BaseRandomForestModel,
556588
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
557589
return_values=[('dense', '(n_samples, 1)')])
558590
@cuml.internals.api_base_return_array(get_output_dtype=True)
591+
@enable_device_interop
559592
def predict(self, X, predict_model="GPU", threshold=0.5,
560593
algo='auto', convert_dtype=True,
561594
fil_sparse_format='auto') -> CumlArray:
@@ -828,3 +861,39 @@ class RandomForestClassifier(BaseRandomForestModel,
828861
if self.dtype == np.float64:
829862
return get_rf_json(rf_forest64).decode('utf-8')
830863
return get_rf_json(rf_forest).decode('utf-8')
864+
865+
def cpu_to_gpu(self):
866+
# treelite does an internal isinstance check to detect an sklearn
867+
# RF, which proxymodule interferes with. We work around that
868+
# temporarily here just for treelite internal check and
869+
# restore the __class__ at the end of the method.
870+
if GlobalSettings().accelerator_active:
871+
with self._cpu_model_class_lock:
872+
original_class = self._cpu_model.__class__
873+
self._cpu_model.__class__ = sys.modules['sklearn.ensemble'].RandomForestClassifier
874+
875+
try:
876+
super().cpu_to_gpu()
877+
finally:
878+
self._cpu_model.__class__ = original_class
879+
880+
else:
881+
super().cpu_to_gpu()
882+
883+
@classmethod
884+
def _hyperparam_translator(cls, **kwargs):
885+
kwargs, gpuaccel = super(RandomForestClassifier, cls)._hyperparam_translator(**kwargs)
886+
887+
if "max_samples" in kwargs:
888+
if isinstance(kwargs["max_samples"], int):
889+
logger.warn(
890+
f"Integer value of max_samples={kwargs['max_samples']}"
891+
"not supported, changed to 1.0."
892+
)
893+
kwargs["max_samples"] = 1.0
894+
895+
# determinism requires only 1 cuda stream
896+
if "random_state" in kwargs:
897+
kwargs["n_streams"] = 1
898+
899+
return kwargs, gpuaccel

python/cuml/cuml/ensemble/randomforestregressor.pyx

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
1716
# distutils: language = c++
1817

18+
19+
import sys
20+
import threading
21+
22+
from cuml.internals.api_decorators import device_interop_preparation
23+
from cuml.internals.api_decorators import enable_device_interop
1924
from cuml.internals.safe_imports import (
2025
cpu_only_import,
2126
gpu_only_import,
@@ -27,7 +32,9 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
2732
rmm = gpu_only_import('rmm')
2833

2934
from cuml.internals.array import CumlArray
35+
from cuml.internals.global_settings import GlobalSettings
3036
import cuml.internals
37+
from cuml.internals import logger
3138

3239
from cuml.internals.mixins import RegressorMixin
3340
from cuml.internals.logger cimport level_enum
@@ -251,6 +258,22 @@ class RandomForestRegressor(BaseRandomForestModel,
251258
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html>`_.
252259
"""
253260

261+
_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestRegressor'
262+
263+
_hyperparam_interop_translator = {
264+
"criterion": "NotImplemented",
265+
"oob_score": {
266+
True: "NotImplemented",
267+
},
268+
"max_depth": {
269+
None: 16,
270+
},
271+
"max_samples": {
272+
None: 1.0,
273+
},
274+
}
275+
276+
@device_interop_preparation
254277
def __init__(self, *,
255278
split_criterion=2,
256279
accuracy_metric='r2',
@@ -297,6 +320,9 @@ class RandomForestRegressor(BaseRandomForestModel,
297320
state["treelite_handle"] = None
298321
state["split_criterion"] = self.split_criterion
299322

323+
if "_cpu_model_class_lock" in state:
324+
del state["_cpu_model_class_lock"]
325+
300326
return state
301327

302328
def __setstate__(self, state):
@@ -318,6 +344,7 @@ class RandomForestRegressor(BaseRandomForestModel,
318344

319345
self.treelite_serialized_model = state["treelite_serialized_model"]
320346
self.__dict__.update(state)
347+
self._cpu_model_class_lock = threading.RLock()
321348

322349
def __del__(self):
323350
self._reset_forest_data()
@@ -342,6 +369,9 @@ class RandomForestRegressor(BaseRandomForestModel,
342369
self.treelite_serialized_model = None
343370
self.n_cols = None
344371

372+
def get_attr_names(self):
373+
return []
374+
345375
def convert_to_treelite_model(self):
346376
"""
347377
Converts the cuML RF model to a Treelite model
@@ -413,6 +443,7 @@ class RandomForestRegressor(BaseRandomForestModel,
413443
domain="cuml_python")
414444
@generate_docstring()
415445
@cuml.internals.api_base_return_any_skipall
446+
@enable_device_interop
416447
def fit(self, X, y, convert_dtype=True):
417448
"""
418449
Perform Random Forest Regression on the input data
@@ -535,6 +566,7 @@ class RandomForestRegressor(BaseRandomForestModel,
535566
domain="cuml_python")
536567
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
537568
return_values=[('dense', '(n_samples, 1)')])
569+
@enable_device_interop
538570
def predict(self, X, predict_model="GPU",
539571
algo='auto', convert_dtype=True,
540572
fil_sparse_format='auto') -> CumlArray:
@@ -752,3 +784,39 @@ class RandomForestRegressor(BaseRandomForestModel,
752784
if self.dtype == np.float64:
753785
return get_rf_json(rf_forest64).decode('utf-8')
754786
return get_rf_json(rf_forest).decode('utf-8')
787+
788+
def cpu_to_gpu(self):
789+
# treelite does an internal isinstance check to detect an sklearn
790+
# RF, which proxymodule interferes with. We work around that
791+
# temporarily here just for treelite internal check and
792+
# restore the __class__ at the end of the method.
793+
if GlobalSettings().accelerator_active:
794+
with self._cpu_model_class_lock:
795+
original_class = self._cpu_model.__class__
796+
self._cpu_model.__class__ = sys.modules['sklearn.ensemble'].RandomForestRegressor
797+
798+
try:
799+
super().cpu_to_gpu()
800+
finally:
801+
self._cpu_model.__class__ = original_class
802+
803+
else:
804+
super().cpu_to_gpu()
805+
806+
@classmethod
807+
def _hyperparam_translator(cls, **kwargs):
808+
kwargs, gpuaccel = super(RandomForestRegressor, cls)._hyperparam_translator(**kwargs)
809+
810+
if "max_samples" in kwargs:
811+
if isinstance(kwargs["max_samples"], int):
812+
logger.warn(
813+
f"Integer value of max_samples={kwargs['max_samples']}"
814+
"not supported, changed to 1.0."
815+
)
816+
kwargs["max_samples"] = 1.0
817+
818+
# determinism requires only 1 cuda stream
819+
if "random_state" in kwargs:
820+
kwargs["n_streams"] = 1
821+
822+
return kwargs, gpuaccel

0 commit comments

Comments
 (0)