From 4021a8f1feb391f85c601017bffd9f4c18396421 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 8 Jul 2025 07:28:36 +0000 Subject: [PATCH 1/8] Correctly handle non-default GPU context in FIL --- cpp/include/cuml/fil/forest_model.hpp | 6 ++++-- python/cuml/cuml/ensemble/randomforest_common.pyx | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuml/fil/forest_model.hpp b/cpp/include/cuml/fil/forest_model.hpp index c3a848aab4..410d8545ff 100644 --- a/cpp/include/cuml/fil/forest_model.hpp +++ b/cpp/include/cuml/fil/forest_model.hpp @@ -298,8 +298,10 @@ struct forest_model { std::optional specified_chunk_size = std::nullopt) { // TODO(wphicks): Make sure buffer lands on same device as model - auto out_buffer = raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type}; - auto in_buffer = raft_proto::buffer{input, num_rows * num_features(), in_mem_type}; + int current_device_id; + raft_proto::cuda_check(cudaGetDevice(&current_device_id)); + auto out_buffer = raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type, current_device_id}; + auto in_buffer = raft_proto::buffer{input, num_rows * num_features(), in_mem_type, current_device_id}; predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size); } diff --git a/python/cuml/cuml/ensemble/randomforest_common.pyx b/python/cuml/cuml/ensemble/randomforest_common.pyx index 7ccbed1429..006437fd7e 100644 --- a/python/cuml/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/cuml/ensemble/randomforest_common.pyx @@ -42,6 +42,8 @@ from cuml.prims.label.classlabels import check_labels, make_monotonic from cuml.ensemble.randomforest_shared cimport * from cuml.internals.treelite cimport * +from cuda.bindings import runtime + _split_criterion_lookup = { "0": GINI, "gini": GINI, @@ -436,6 +438,11 @@ class BaseRandomForestModel(Base, InteropMixin): align_bytes = None, ) -> CumlArray: treelite_bytes = self._serialize_treelite_bytes() + status, current_device_id = 
runtime.cudaGetDevice() + if status != runtime.cudaError_t.cudaSuccess: + _, name = runtime.cudaGetErrorName(status) + _, msg = runtime.cudaGetErrorString(status) + raise RuntimeError(f"Failed to run cudaGetDevice(). {name}: {msg}") fil_model = ForestInference( treelite_model=treelite_bytes, handle=self.handle, @@ -445,6 +452,7 @@ class BaseRandomForestModel(Base, InteropMixin): layout=layout, default_chunk_size=default_chunk_size, align_bytes=align_bytes, + device_id=current_device_id, ) if predict_proba: return fil_model.predict_proba(X) From 4de288395566735c408ba9fb058331316e6a09cc Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 8 Jul 2025 07:43:04 +0000 Subject: [PATCH 2/8] Fix formatting --- cpp/include/cuml/fil/forest_model.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuml/fil/forest_model.hpp b/cpp/include/cuml/fil/forest_model.hpp index 410d8545ff..01511217c5 100644 --- a/cpp/include/cuml/fil/forest_model.hpp +++ b/cpp/include/cuml/fil/forest_model.hpp @@ -300,8 +300,10 @@ struct forest_model { // TODO(wphicks): Make sure buffer lands on same device as model int current_device_id; raft_proto::cuda_check(cudaGetDevice(&current_device_id)); - auto out_buffer = raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type, current_device_id}; - auto in_buffer = raft_proto::buffer{input, num_rows * num_features(), in_mem_type, current_device_id}; + auto out_buffer = + raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type, current_device_id}; + auto in_buffer = + raft_proto::buffer{input, num_rows * num_features(), in_mem_type, current_device_id}; predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size); } From 11e6ab137866daad70287791771f164947c728b4 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 9 Jul 2025 00:37:34 +0000 Subject: [PATCH 3/8] If no device_id is given to FIL, use current device --- python/cuml/cuml/fil/fil.pyx | 40 +++++++++++++++++++++++------------- 1 file 
changed, 26 insertions(+), 14 deletions(-) diff --git a/python/cuml/cuml/fil/fil.pyx b/python/cuml/cuml/fil/fil.pyx index 2b857cf1c9..bcdcf05d43 100644 --- a/python/cuml/cuml/fil/fil.pyx +++ b/python/cuml/cuml/fil/fil.pyx @@ -48,6 +48,8 @@ from cuml.fil.postprocessing cimport element_op, row_op from cuml.fil.tree_layout cimport tree_layout as fil_tree_layout from cuml.internals.treelite cimport * +from cuda.bindings import runtime + cdef extern from "cuml/fil/forest_model.hpp" namespace "ML::fil" nogil: cdef cppclass forest_model: @@ -154,7 +156,7 @@ cdef class ForestInference_impl(): align_bytes=0, use_double_precision=None, mem_type=None, - device_id=0 + device_id=None, ): # Store reference to RAFT handle to control lifetime, since raft_proto # handle keeps a pointer to it @@ -197,6 +199,16 @@ cdef class ForestInference_impl(): else: raise RuntimeError(f"Unrecognized tree layout {layout}") + if device_id is None: + # If no device ID is explicitly given, use the currently + # active device + status, current_device_id = runtime.cudaGetDevice() + if status != runtime.cudaError_t.cudaSuccess: + _, name = runtime.cudaGetErrorName(status) + _, msg = runtime.cudaGetErrorString(status) + raise RuntimeError(f"Failed to run cudaGetDevice(). {name}: {msg}") + device_id = current_device_id + self.model = import_from_treelite_handle( tl_handle, tree_layout, @@ -457,9 +469,10 @@ class ForestInference(Base, CMajorInputTagMixin): only for models trained and double precision and when exact conformance between results from FIL and the original training framework is of paramount importance. - device_id : int, default=0 + device_id : int or None, default=None For GPU execution, the device on which to load and execute this - model. For CPU execution, this value is currently ignored. + model. If set to None, use the currently active device. + For CPU execution, this value is currently ignored. 
""" def _reload_model(self): @@ -553,7 +566,7 @@ class ForestInference(Base, CMajorInputTagMixin): try: return self._device_id_ except AttributeError: - self._device_id_ = 0 + self._device_id_ = None return self._device_id_ @device_id.setter @@ -562,14 +575,13 @@ class ForestInference(Base, CMajorInputTagMixin): old_value = self.device_id except AttributeError: old_value = None - if value is not None: - self._device_id_ = value - if ( - self.treelite_model is not None - and self.device_id != old_value - and hasattr(self, '_gpu_forest') - ): - self._load_to_fil(device_id=self.device_id) + self._device_id_ = value + if ( + self.treelite_model is not None + and self.device_id != old_value + and hasattr(self, '_gpu_forest') + ): + self._load_to_fil(device_id=self.device_id) @property def treelite_model(self): @@ -616,7 +628,7 @@ class ForestInference(Base, CMajorInputTagMixin): default_chunk_size=None, align_bytes=None, precision='single', - device_id=0, + device_id=None, ): super().__init__( handle=handle, verbose=verbose, output_type=output_type @@ -633,7 +645,7 @@ class ForestInference(Base, CMajorInputTagMixin): self.treelite_model = treelite_model self._load_to_fil(device_id=self.device_id) - def _load_to_fil(self, mem_type=None, device_id=0): + def _load_to_fil(self, mem_type=None, device_id=None): if mem_type is None: mem_type = GlobalSettings().fil_memory_type else: From 0f726ab8d28568961528795bb9b5e262d55547c6 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 9 Jul 2025 00:46:30 +0000 Subject: [PATCH 4/8] Remove duplicated cudaGetDevice logic --- python/cuml/cuml/ensemble/randomforest_common.pyx | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/cuml/cuml/ensemble/randomforest_common.pyx b/python/cuml/cuml/ensemble/randomforest_common.pyx index 63e248664b..3e7871ecff 100644 --- a/python/cuml/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/cuml/ensemble/randomforest_common.pyx @@ -42,8 +42,6 @@ from cuml.prims.label.classlabels import 
check_labels, make_monotonic from cuml.ensemble.randomforest_shared cimport * from cuml.internals.treelite cimport * -from cuda.bindings import runtime - _split_criterion_lookup = { "0": GINI, "gini": GINI, @@ -446,11 +444,6 @@ class BaseRandomForestModel(Base, InteropMixin): align_bytes = None, ) -> CumlArray: treelite_bytes = self._serialize_treelite_bytes() - status, current_device_id = runtime.cudaGetDevice() - if status != runtime.cudaError_t.cudaSuccess: - _, name = runtime.cudaGetErrorName(status) - _, msg = runtime.cudaGetErrorString(status) - raise RuntimeError(f"Failed to run cudaGetDevice(). {name}: {msg}") fil_model = ForestInference( treelite_model=treelite_bytes, handle=self.handle, @@ -460,7 +453,6 @@ class BaseRandomForestModel(Base, InteropMixin): layout=layout, default_chunk_size=default_chunk_size, align_bytes=align_bytes, - device_id=current_device_id, ) if predict_proba: return fil_model.predict_proba(X) From 0aa62ab985d4ba9ebc3da8d88e39391f28124255 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 17 Jul 2025 23:23:21 -0700 Subject: [PATCH 5/8] Move device_id=None logic to outer ForestInference --- python/cuml/cuml/fil/fil.pyx | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/python/cuml/cuml/fil/fil.pyx b/python/cuml/cuml/fil/fil.pyx index bcdcf05d43..2b1753d77c 100644 --- a/python/cuml/cuml/fil/fil.pyx +++ b/python/cuml/cuml/fil/fil.pyx @@ -199,15 +199,13 @@ cdef class ForestInference_impl(): else: raise RuntimeError(f"Unrecognized tree layout {layout}") - if device_id is None: - # If no device ID is explicitly given, use the currently - # active device - status, current_device_id = runtime.cudaGetDevice() - if status != runtime.cudaError_t.cudaSuccess: - _, name = runtime.cudaGetErrorName(status) - _, msg = runtime.cudaGetErrorString(status) - raise RuntimeError(f"Failed to run cudaGetDevice(). 
{name}: {msg}") - device_id = current_device_id + # Use assertion here, since device_id being None would indicate + # a bug, not a user error. The outer ForestInference object + # should set an integer device_id before passing it to + # ForestInference_impl. + assert device_id is not None, ( + "device_id should be set before building ForestInference_impl" + ) self.model = import_from_treelite_handle( tl_handle, @@ -651,6 +649,16 @@ class ForestInference(Base, CMajorInputTagMixin): else: mem_type = MemoryType.from_str(mem_type) + if device_id is None: + # If no device ID is explicitly given, use the currently + # active device + status, current_device_id = runtime.cudaGetDevice() + if status != runtime.cudaError_t.cudaSuccess: + _, name = runtime.cudaGetErrorName(status) + _, msg = runtime.cudaGetErrorString(status) + raise RuntimeError(f"Failed to run cudaGetDevice(). {name}: {msg}") + device_id = current_device_id + if mem_type.is_device_accessible: self.device_id = device_id From 05f6a1be0301a4b4ced7f1307343e35a2828a7a9 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 17 Jul 2025 23:41:34 -0700 Subject: [PATCH 6/8] Add tests for device selection --- python/cuml/cuml/tests/test_fil.py | 97 ++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/python/cuml/cuml/tests/test_fil.py b/python/cuml/cuml/tests/test_fil.py index 5feb0bd188..f0804b2aee 100644 --- a/python/cuml/cuml/tests/test_fil.py +++ b/python/cuml/cuml/tests/test_fil.py @@ -14,8 +14,10 @@ # import os +from contextlib import nullcontext from math import ceil +import cupy as cp import numpy as np import pandas as pd import pytest @@ -37,6 +39,9 @@ from sklearn.model_selection import train_test_split # noqa: E402 from cuml import ForestInference # noqa: E402 +from cuml.ensemble import ( # noqa: E402 + RandomForestClassifier as cumlRandomForestClassifier, +) from cuml.fil import get_fil_device_type, set_fil_device_type # noqa: E402 from cuml.internals.device_type import DeviceType # 
noqa: E402 from cuml.internals.global_settings import GlobalSettings # noqa: E402 @@ -897,3 +902,95 @@ def test_missing_categorical(category_list): fm = ForestInference.load_from_treelite_model(model) fil_preds = np.asarray(fm.predict(input)) np.testing.assert_equal(fil_preds.flatten(), gtil_preds.flatten()) + + +@pytest.mark.parametrize("device_id", [None, 0, 1, 2]) +@pytest.mark.parametrize("model_kind", ["sklearn", "xgboost", "cuml"]) +def test_device_selection(device_id, model_kind, tmp_path): + current_device = cp.cuda.runtime.getDevice() + + if device_id is not None and device_id >= cp.cuda.runtime.getDeviceCount(): + pytest.skip( + reason="device_id larger than the number of available GPU devices" + ) + + n_rows = 1000 + n_columns = 30 + n_classes = 3 + n_estimators = 10 + + X, y = simulate_data( + n_rows, + n_columns, + n_classes, + random_state=0, + classification=True, + ) + + # 1. Model can be loaded with device_id set + if model_kind == "sklearn": + skl_model = RandomForestClassifier( + max_depth=3, random_state=0, n_estimators=n_estimators + ) + skl_model.fit(X, y) + fm = ForestInference.load_from_sklearn( + skl_model, + precision="native", + is_classifier=True, + device_id=device_id, + ) + elif model_kind == "xgboost": + xgb_model = xgb.XGBClassifier( + max_depth=3, random_state=0, n_estimators=n_estimators + ) + xgb_model.fit(X, y) + model_path = os.path.join(tmp_path, "xgb_class.ubj") + xgb_model.save_model(model_path) + fm = ForestInference.load( + model_path, + model_type="xgboost_ubj", + precision="native", + is_classifier=True, + device_id=device_id, + ) + elif model_kind == "cuml": + device_context = ( + cp.cuda.Device(device_id) if device_id else nullcontext() + ) + + with device_context: + cuml_model = cumlRandomForestClassifier( + max_depth=3, + random_state=0, + n_estimators=n_estimators, + n_streams=1, + ) + cuml_model.fit(cp.array(X), cp.array(y)) + fm = cuml_model.convert_to_fil_model() + else: + raise NotImplementedError() + + # 2. 
The section above didn't corrupt current device context + assert cp.cuda.runtime.getDevice() == current_device + + # 3. Device selection is correctly saved to device_id property + assert fm.device_id == (device_id if device_id else 0) + + # 4. Inference can run on an input with the selected device + device_context = cp.cuda.Device(device_id) if device_id else nullcontext() + with device_context: + _ = fm.predict_proba(cp.array(X)) + + # 5. The section above didn't corrupt current device context + assert cp.cuda.runtime.getDevice() == current_device + + # 6. Attempting to run inference with an input from a different device + # is an error + if device_id is not None and device_id != 0: + with cp.cuda.Device(0), pytest.raises( + RuntimeError, match=r".*I/O data on different device than model.*" + ): + _ = fm.predict_proba(cp.array(X)) + + # 7. The section above didn't corrupt current device context + assert cp.cuda.runtime.getDevice() == current_device From 0e2f8f68f001f38a418701f435c4ae33b26df111 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 17 Jul 2025 23:43:56 -0700 Subject: [PATCH 7/8] Remove outdated comment --- cpp/include/cuml/fil/forest_model.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/cuml/fil/forest_model.hpp b/cpp/include/cuml/fil/forest_model.hpp index 01511217c5..7921ed6846 100644 --- a/cpp/include/cuml/fil/forest_model.hpp +++ b/cpp/include/cuml/fil/forest_model.hpp @@ -297,7 +297,6 @@ struct forest_model { infer_kind predict_type = infer_kind::default_kind, std::optional specified_chunk_size = std::nullopt) { - // TODO(wphicks): Make sure buffer lands on same device as model int current_device_id; raft_proto::cuda_check(cudaGetDevice(&current_device_id)); auto out_buffer = From 3df46088ab3eb67980fd2a986de302d9ae805626 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 18 Jul 2025 14:24:56 -0700 Subject: [PATCH 8/8] Add reference to the bug 5983 --- python/cuml/cuml/tests/test_fil.py | 2 ++ 1 file changed, 2 insertions(+) diff 
--git a/python/cuml/cuml/tests/test_fil.py b/python/cuml/cuml/tests/test_fil.py index f0804b2aee..3dd730aec0 100644 --- a/python/cuml/cuml/tests/test_fil.py +++ b/python/cuml/cuml/tests/test_fil.py @@ -959,6 +959,8 @@ def test_device_selection(device_id, model_kind, tmp_path): ) with device_context: + # TODO(hcho3): Remove n_streams=1 argument once the bug + # https://github.com/rapidsai/cuml/issues/5983 is resolved cuml_model = cumlRandomForestClassifier( max_depth=3, random_state=0,