Skip to content

Commit b8f0531

Browse files
csadorf authored and vardhan30016 committed
Add Out-of-Bag (OOB) Score Support to RandomForest (rapidsai#7401)
## Summary Implements out-of-bag (OOB) scoring for `RandomForestClassifier` and `RandomForestRegressor`, enabling users to estimate model performance without requiring a separate validation set. Closes rapidsai#7395 ## Changes ### C++ Layer - Modified `fit()` functions to accept optional `bootstrap_masks` parameter for storing per-tree bootstrap sample indicators - Updated `RandomForest::fit()` to capture and store bootstrap masks when `oob_score=True` ### Python Layer - Added `oob_score` parameter (boolean only) to Random Forest estimators - Implemented `_compute_oob_score()` method that leverages FIL's `predict_per_tree()` for efficient OOB predictions - Added `oob_score_` and `oob_decision_function_` (or `oob_prediction_`) attributes - Validates that `oob_score` is boolean (custom scorer functions not supported) - Added proper attribute transfer for pickle and CPU interop ### Metrics - **Classifier**: Uses accuracy score on OOB predictions - **Regressor**: Uses R² score on OOB predictions ## Limitations - Custom scorer functions (callable `oob_score`) are not supported - only boolean values accepted - Multi-output targets not supported for OOB scoring ## Testing Added comprehensive tests covering: - Binary and multi-class classification OOB scoring - Regression OOB scoring - Error handling for invalid configurations - Comparison with scikit-learn baseline Authors: - Simon Adorf (https://github.com/csadorf) Approvers: - Victor Lafargue (https://github.com/viclafargue) - Divye Gala (https://github.com/divyegala) - Jim Crist-Harif (https://github.com/jcrist) URL: rapidsai#7401
1 parent 3edc2fa commit b8f0531

11 files changed

Lines changed: 444 additions & 63 deletions

File tree

cpp/include/cuml/ensemble/randomforest.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ void fit(const raft::handle_t& user_handle,
134134
int* labels,
135135
int n_unique_labels,
136136
RF_params rf_params,
137-
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
137+
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
138+
bool* bootstrap_masks = nullptr);
138139
void fit(const raft::handle_t& user_handle,
139140
RandomForestClassifierD* forest,
140141
double* input,
@@ -143,7 +144,8 @@ void fit(const raft::handle_t& user_handle,
143144
int* labels,
144145
int n_unique_labels,
145146
RF_params rf_params,
146-
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
147+
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
148+
bool* bootstrap_masks = nullptr);
147149

148150
template <typename T, typename L>
149151
void fit_treelite(const raft::handle_t& user_handle,
@@ -154,6 +156,7 @@ void fit_treelite(const raft::handle_t& user_handle,
154156
L* labels,
155157
int n_unique_labels,
156158
RF_params rf_params,
159+
bool* bootstrap_masks,
157160
rapids_logger::level_enum verbosity);
158161

159162
void predict(const raft::handle_t& user_handle,
@@ -211,15 +214,17 @@ void fit(const raft::handle_t& user_handle,
211214
int n_cols,
212215
float* labels,
213216
RF_params rf_params,
214-
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
217+
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
218+
bool* bootstrap_masks = nullptr);
215219
void fit(const raft::handle_t& user_handle,
216220
RandomForestRegressorD* forest,
217221
double* input,
218222
int n_rows,
219223
int n_cols,
220224
double* labels,
221225
RF_params rf_params,
222-
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
226+
rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
227+
bool* bootstrap_masks = nullptr);
223228

224229
template <typename T, typename L>
225230
void fit_treelite(const raft::handle_t& user_handle,
@@ -229,6 +234,7 @@ void fit_treelite(const raft::handle_t& user_handle,
229234
int n_cols,
230235
L* labels,
231236
RF_params rf_params,
237+
bool* bootstrap_masks,
232238
rapids_logger::level_enum verbosity);
233239

234240
void predict(const raft::handle_t& user_handle,

cpp/src/randomforest/randomforest.cu

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ void fit(const raft::handle_t& user_handle,
355355
int* labels,
356356
int n_unique_labels,
357357
RF_params rf_params,
358-
rapids_logger::level_enum verbosity)
358+
rapids_logger::level_enum verbosity,
359+
bool* bootstrap_masks)
359360
{
360361
raft::common::nvtx::range fun_scope("RF::fit @randomforest.cu");
361362
ML::default_logger().set_level(verbosity);
@@ -365,7 +366,8 @@ void fit(const raft::handle_t& user_handle,
365366

366367
std::shared_ptr<RandomForest<float, int>> rf_classifier =
367368
std::make_shared<RandomForest<float, int>>(rf_params, RF_type::CLASSIFICATION);
368-
rf_classifier->fit(user_handle, input, n_rows, n_cols, labels, n_unique_labels, forest);
369+
rf_classifier->fit(
370+
user_handle, input, n_rows, n_cols, labels, n_unique_labels, forest, bootstrap_masks);
369371
}
370372

371373
void fit(const raft::handle_t& user_handle,
@@ -376,7 +378,8 @@ void fit(const raft::handle_t& user_handle,
376378
int* labels,
377379
int n_unique_labels,
378380
RF_params rf_params,
379-
rapids_logger::level_enum verbosity)
381+
rapids_logger::level_enum verbosity,
382+
bool* bootstrap_masks)
380383
{
381384
raft::common::nvtx::range fun_scope("RF::fit @randomforest.cu");
382385
ML::default_logger().set_level(verbosity);
@@ -386,7 +389,8 @@ void fit(const raft::handle_t& user_handle,
386389

387390
std::shared_ptr<RandomForest<double, int>> rf_classifier =
388391
std::make_shared<RandomForest<double, int>>(rf_params, RF_type::CLASSIFICATION);
389-
rf_classifier->fit(user_handle, input, n_rows, n_cols, labels, n_unique_labels, forest);
392+
rf_classifier->fit(
393+
user_handle, input, n_rows, n_cols, labels, n_unique_labels, forest, bootstrap_masks);
390394
}
391395

392396
template <typename value_t, typename label_t>
@@ -398,10 +402,20 @@ void fit_treelite(const raft::handle_t& user_handle,
398402
label_t* labels,
399403
int n_unique_labels,
400404
RF_params rf_params,
405+
bool* bootstrap_masks,
401406
rapids_logger::level_enum verbosity)
402407
{
403408
RandomForestMetaData<value_t, label_t> metadata;
404-
fit(user_handle, &metadata, input, n_rows, n_cols, labels, n_unique_labels, rf_params, verbosity);
409+
fit(user_handle,
410+
&metadata,
411+
input,
412+
n_rows,
413+
n_cols,
414+
labels,
415+
n_unique_labels,
416+
rf_params,
417+
verbosity,
418+
bootstrap_masks);
405419
build_treelite_forest(model, &metadata, n_cols);
406420
}
407421

@@ -562,7 +576,8 @@ void fit(const raft::handle_t& user_handle,
562576
int n_cols,
563577
float* labels,
564578
RF_params rf_params,
565-
rapids_logger::level_enum verbosity)
579+
rapids_logger::level_enum verbosity,
580+
bool* bootstrap_masks)
566581
{
567582
raft::common::nvtx::range fun_scope("RF::fit @randomforest.cu");
568583
ML::default_logger().set_level(verbosity);
@@ -572,7 +587,7 @@ void fit(const raft::handle_t& user_handle,
572587

573588
std::shared_ptr<RandomForest<float, float>> rf_regressor =
574589
std::make_shared<RandomForest<float, float>>(rf_params, RF_type::REGRESSION);
575-
rf_regressor->fit(user_handle, input, n_rows, n_cols, labels, 1, forest);
590+
rf_regressor->fit(user_handle, input, n_rows, n_cols, labels, 1, forest, bootstrap_masks);
576591
}
577592

578593
void fit(const raft::handle_t& user_handle,
@@ -582,7 +597,8 @@ void fit(const raft::handle_t& user_handle,
582597
int n_cols,
583598
double* labels,
584599
RF_params rf_params,
585-
rapids_logger::level_enum verbosity)
600+
rapids_logger::level_enum verbosity,
601+
bool* bootstrap_masks)
586602
{
587603
raft::common::nvtx::range fun_scope("RF::fit @randomforest.cu");
588604
ML::default_logger().set_level(verbosity);
@@ -592,7 +608,7 @@ void fit(const raft::handle_t& user_handle,
592608

593609
std::shared_ptr<RandomForest<double, double>> rf_regressor =
594610
std::make_shared<RandomForest<double, double>>(rf_params, RF_type::REGRESSION);
595-
rf_regressor->fit(user_handle, input, n_rows, n_cols, labels, 1, forest);
611+
rf_regressor->fit(user_handle, input, n_rows, n_cols, labels, 1, forest, bootstrap_masks);
596612
}
597613

598614
template <typename value_t, typename label_t>
@@ -603,10 +619,11 @@ void fit_treelite(const raft::handle_t& user_handle,
603619
int n_cols,
604620
label_t* labels,
605621
RF_params rf_params,
622+
bool* bootstrap_masks,
606623
rapids_logger::level_enum verbosity)
607624
{
608625
RandomForestMetaData<value_t, label_t> metadata;
609-
fit(user_handle, &metadata, input, n_rows, n_cols, labels, rf_params, verbosity);
626+
fit(user_handle, &metadata, input, n_rows, n_cols, labels, rf_params, verbosity, bootstrap_masks);
610627
build_treelite_forest(model, &metadata, n_cols);
611628
}
612629

@@ -735,6 +752,7 @@ template void fit_treelite<float, int>(const raft::handle_t& user_handle,
735752
int* labels,
736753
int n_unique_labels,
737754
RF_params rf_params,
755+
bool* bootstrap_masks,
738756
rapids_logger::level_enum verbosity);
739757
template void fit_treelite<double, int>(const raft::handle_t& user_handle,
740758
TreeliteModelHandle* model,
@@ -744,6 +762,7 @@ template void fit_treelite<double, int>(const raft::handle_t& user_handle,
744762
int* labels,
745763
int n_unique_labels,
746764
RF_params rf_params,
765+
bool* bootstrap_masks,
747766
rapids_logger::level_enum verbosity);
748767
template void fit_treelite<float, float>(const raft::handle_t& user_handle,
749768
TreeliteModelHandle* model,
@@ -752,6 +771,7 @@ template void fit_treelite<float, float>(const raft::handle_t& user_handle,
752771
int n_cols,
753772
float* labels,
754773
RF_params rf_params,
774+
bool* bootstrap_masks,
755775
rapids_logger::level_enum verbosity);
756776
template void fit_treelite<double, double>(const raft::handle_t& user_handle,
757777
TreeliteModelHandle* model,
@@ -760,6 +780,7 @@ template void fit_treelite<double, double>(const raft::handle_t& user_handle,
760780
int n_cols,
761781
double* labels,
762782
RF_params rf_params,
783+
bool* bootstrap_masks,
763784
rapids_logger::level_enum verbosity);
764785

765786
} // End namespace ML

cpp/src/randomforest/randomforest.cuh

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
#include <raft/stats/regression_metrics.cuh>
1616
#include <raft/util/cudart_utils.hpp>
1717

18-
#include <thrust/execution_policy.h>
18+
#include <rmm/exec_policy.hpp>
19+
20+
#include <thrust/fill.h>
21+
#include <thrust/for_each.h>
1922
#include <thrust/sequence.h>
2023

2124
#include <decisiontree/batched-levelalgo/quantiles.cuh>
@@ -32,6 +35,7 @@
3235
#include <map>
3336

3437
namespace ML {
38+
3539
template <class T, class L>
3640
class RandomForest {
3741
protected:
@@ -56,7 +60,7 @@ class RandomForest {
5660

5761
} else {
5862
// Use all the samples from the dataset
59-
thrust::sequence(thrust::cuda::par.on(stream), selected_rows->begin(), selected_rows->end());
63+
thrust::sequence(rmm::exec_policy(stream), selected_rows->begin(), selected_rows->end());
6064
}
6165
}
6266

@@ -102,14 +106,17 @@ class RandomForest {
102106
* @param[in] n_unique_labels: (meaningful only for classification) #unique label values (known
103107
during preprocessing)
104108
* @param[in] forest: CPU point to RandomForestMetaData struct.
109+
* @param[out] bootstrap_masks: optional device pointer to store bootstrap masks
110+
* (n_trees * n_rows), only populated if a non-null pointer is provided
105111
*/
106112
void fit(const raft::handle_t& user_handle,
107113
const T* input,
108114
int n_rows,
109115
int n_cols,
110116
L* labels,
111117
int n_unique_labels,
112-
RandomForestMetaData<T, L>* forest)
118+
RandomForestMetaData<T, L>* forest,
119+
bool* bootstrap_masks = nullptr)
113120
{
114121
raft::common::nvtx::range fun_scope("RandomForest::fit @randomforest.cuh");
115122
this->error_checking(input, labels, n_rows, n_cols, false);
@@ -178,6 +185,20 @@ class RandomForest {
178185
this->rf_params.seed,
179186
quantiles,
180187
i);
188+
189+
// Store bootstrap mask if device buffer is provided
190+
if (bootstrap_masks != nullptr) {
191+
// Calculate pointer offset for this tree's mask
192+
bool* tree_mask = bootstrap_masks + (i * n_rows);
193+
194+
// Use Thrust to create boolean mask: first fill with false, then mark selected rows
195+
thrust::fill(rmm::exec_policy(s), tree_mask, tree_mask + n_rows, false);
196+
thrust::scatter(rmm::exec_policy(s),
197+
thrust::make_constant_iterator(true),
198+
thrust::make_constant_iterator(true) + n_sampled_rows,
199+
selected_rows[stream_id].data(),
200+
tree_mask);
201+
}
181202
}
182203
// Cleanup
183204
handle.sync_stream_pool();

docs/source/cuml-accel/limitations.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ RandomForestClassifier
187187
``RandomForestClassifier`` will fall back to CPU in the following cases:
188188

189189
- If ``criterion`` is ``"log_loss"``.
190-
- If ``oob_score=True``.
190+
- If ``oob_score`` is a callable.
191191
- If ``warm_start=True``.
192192
- If ``monotonic_cst`` is not ``None``.
193193
- If ``max_values`` is an integer.
@@ -196,21 +196,23 @@ RandomForestClassifier
196196
- If ``class_weight`` is not ``None``.
197197
- If ``sample_weight`` is passed to ``fit`` or ``score``.
198198
- If ``X`` is sparse.
199+
- If ``y`` is a multi-output target.
199200

200201
RandomForestRegressor
201202
^^^^^^^^^^^^^^^^^^^^^
202203

203204
``RandomForestRegressor`` will fall back to CPU in the following cases:
204205

205206
- If ``criterion`` is ``"absolute_error"`` or ``"friedman_mse"``.
206-
- If ``oob_score=True``.
207+
- If ``oob_score`` is a callable.
207208
- If ``warm_start=True``.
208209
- If ``monotonic_cst`` is not ``None``.
209210
- If ``max_values`` is an integer.
210211
- If ``min_weight_fraction_leaf`` is not ``0``.
211212
- If ``ccp_alpha`` is not ``0``.
212213
- If ``sample_weight`` is passed to ``fit`` or ``score``.
213214
- If ``X`` is sparse.
215+
- If ``y`` is a multi-output target.
214216

215217

216218
sklearn.kernel_ridge

python/cuml/cuml/accel/_wrappers/sklearn/ensemble.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import cuml.ensemble
77
from cuml.accel.estimator_proxy import ProxyBase
8+
from cuml.internals.input_utils import input_to_cuml_array
89
from cuml.internals.interop import UnsupportedOnGPU
910

1011
__all__ = ("RandomForestRegressor", "RandomForestClassifier")
@@ -14,6 +15,16 @@ class RandomForestRegressor(ProxyBase):
1415
_gpu_class = cuml.ensemble.RandomForestRegressor
1516

1617
def _gpu_fit(self, X, y, sample_weight=None):
18+
try:
19+
y = input_to_cuml_array(y, convert_to_mem_type=False)[0]
20+
except ValueError:
21+
raise
22+
else:
23+
if len(y.shape) > 1 and y.shape[1] > 1:
24+
raise UnsupportedOnGPU(
25+
"Multi-output targets are not supported"
26+
)
27+
1728
if sample_weight is not None:
1829
raise UnsupportedOnGPU("`sample_weight` is not supported")
1930
return self._gpu.fit(X, y)
@@ -37,6 +48,21 @@ class RandomForestClassifier(ProxyBase):
3748
_gpu_class = cuml.ensemble.RandomForestClassifier
3849

3950
def _gpu_fit(self, X, y, sample_weight=None):
51+
try:
52+
y = input_to_cuml_array(y, convert_to_mem_type=False)[0]
53+
except ValueError:
54+
raise
55+
else:
56+
if len(y.shape) > 1 and y.shape[1] > 1:
57+
if self.oob_score:
58+
raise ValueError(
59+
"The type of target cannot be used to compute OOB estimates"
60+
)
61+
else:
62+
raise UnsupportedOnGPU(
63+
"Multi-output targets are not supported"
64+
)
65+
4066
if sample_weight is not None:
4167
raise UnsupportedOnGPU("`sample_weight` is not supported")
4268
return self._gpu.fit(X, y)

0 commit comments

Comments (0)