Skip to content

Commit c349e02

Browse files
committed
Use thrust::scatter approach
1 parent 76616dd commit c349e02

1 file changed

Lines changed: 5 additions & 6 deletions

File tree

cpp/src/randomforest/randomforest.cuh

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,11 @@ class RandomForest {
192192

193193
// Use Thrust to create boolean mask: first fill with false, then mark selected rows
194194
thrust::fill(thrust::cuda::par.on(s), tree_mask, tree_mask + n_rows, false);
195-
thrust::for_each(thrust::cuda::par.on(s),
196-
selected_rows[stream_id].data(),
197-
selected_rows[stream_id].data() + n_sampled_rows,
198-
[tree_mask, n_rows] __device__(int idx) {
199-
if (idx >= 0 && idx < n_rows) { tree_mask[idx] = true; }
200-
});
195+
thrust::scatter(thrust::cuda::par.on(s),
196+
thrust::make_constant_iterator(true),
197+
thrust::make_constant_iterator(true) + n_sampled_rows,
198+
selected_rows[stream_id].data(),
199+
tree_mask);
201200
}
202201
}
203202
// Cleanup

0 commit comments

Comments
 (0)