Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cpp/include/cuml/fil/forest_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,12 @@ struct forest_model {
std::optional<index_type> specified_chunk_size = std::nullopt)
{
// TODO(wphicks): Make sure buffer lands on same device as model
Comment thread
hcho3 marked this conversation as resolved.
Outdated
auto out_buffer = raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type};
auto in_buffer = raft_proto::buffer{input, num_rows * num_features(), in_mem_type};
int current_device_id;
raft_proto::cuda_check(cudaGetDevice(&current_device_id));
Comment thread
hcho3 marked this conversation as resolved.
auto out_buffer =
raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type, current_device_id};
auto in_buffer =
raft_proto::buffer{input, num_rows * num_features(), in_mem_type, current_device_id};
predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
}

Expand Down
40 changes: 26 additions & 14 deletions python/cuml/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ from cuml.fil.postprocessing cimport element_op, row_op
from cuml.fil.tree_layout cimport tree_layout as fil_tree_layout
from cuml.internals.treelite cimport *

from cuda.bindings import runtime


cdef extern from "cuml/fil/forest_model.hpp" namespace "ML::fil" nogil:
cdef cppclass forest_model:
Expand Down Expand Up @@ -154,7 +156,7 @@ cdef class ForestInference_impl():
align_bytes=0,
use_double_precision=None,
mem_type=None,
device_id=0
device_id=None,
):
# Store reference to RAFT handle to control lifetime, since raft_proto
# handle keeps a pointer to it
Expand Down Expand Up @@ -197,6 +199,16 @@ cdef class ForestInference_impl():
else:
raise RuntimeError(f"Unrecognized tree layout {layout}")

if device_id is None:
# If no device ID is explicitly given, use the currently
# active device
status, current_device_id = runtime.cudaGetDevice()
if status != runtime.cudaError_t.cudaSuccess:
_, name = runtime.cudaGetErrorName(status)
_, msg = runtime.cudaGetErrorString(status)
raise RuntimeError(f"Failed to run cudaGetDevice(). {name}: {msg}")
device_id = current_device_id

self.model = import_from_treelite_handle(
tl_handle,
tree_layout,
Expand Down Expand Up @@ -457,9 +469,10 @@ class ForestInference(Base, CMajorInputTagMixin):
only for models trained in double precision and when exact
conformance between results from FIL and the original training
framework is of paramount importance.
device_id : int, default=0
device_id : int or None, default=None
For GPU execution, the device on which to load and execute this
model. For CPU execution, this value is currently ignored.
model. If set to None, use the currently active device.
For CPU execution, this value is currently ignored.
"""

def _reload_model(self):
Expand Down Expand Up @@ -553,7 +566,7 @@ class ForestInference(Base, CMajorInputTagMixin):
try:
return self._device_id_
except AttributeError:
self._device_id_ = 0
self._device_id_ = None
return self._device_id_

@device_id.setter
Expand All @@ -562,14 +575,13 @@ class ForestInference(Base, CMajorInputTagMixin):
old_value = self.device_id
except AttributeError:
old_value = None
if value is not None:
self._device_id_ = value
if (
self.treelite_model is not None
and self.device_id != old_value
and hasattr(self, '_gpu_forest')
):
self._load_to_fil(device_id=self.device_id)
self._device_id_ = value
if (
self.treelite_model is not None
and self.device_id != old_value
and hasattr(self, '_gpu_forest')
):
self._load_to_fil(device_id=self.device_id)

@property
def treelite_model(self):
Expand Down Expand Up @@ -616,7 +628,7 @@ class ForestInference(Base, CMajorInputTagMixin):
default_chunk_size=None,
align_bytes=None,
precision='single',
device_id=0,
device_id=None,
):
super().__init__(
handle=handle, verbose=verbose, output_type=output_type
Expand All @@ -633,7 +645,7 @@ class ForestInference(Base, CMajorInputTagMixin):
self.treelite_model = treelite_model
self._load_to_fil(device_id=self.device_id)

def _load_to_fil(self, mem_type=None, device_id=0):
def _load_to_fil(self, mem_type=None, device_id=None):
if mem_type is None:
mem_type = GlobalSettings().fil_memory_type
else:
Expand Down
Loading