Correctly handle non-default GPU context in FIL #6987

Merged

rapids-bot[bot] merged 13 commits into rapidsai:branch-25.08 from hcho3:fix_mgpu_fil on Jul 23, 2025

Conversation

@hcho3 (Contributor) commented Jul 8, 2025

Closes #6930

@hcho3 requested review from a team as code owners July 8, 2025 07:29
@hcho3 requested review from bdice, betatim and vyasr July 8, 2025 07:29
@github-actions bot added the Cython / Python and CUDA/C++ labels Jul 8, 2025
@hcho3 added the improvement and non-breaking labels Jul 8, 2025
@csadorf (Contributor) commented Jul 8, 2025

> I'd like to write a unit test (pytest) modeling the MRE in #6930. Should I put this under the Dask tests?

Assuming the motivation is to get access to multiple GPUs: we are currently not running on multiple GPUs in any of our CI jobs, including the Dask test jobs.

Review thread on cpp/include/cuml/fil/forest_model.hpp

@csadorf (Contributor) left a comment

I can confirm that #6987 fixes the issues identified in #6930. However, I think we should add some basic tests in test_fil.py that check:

  1. that a model loads successfully with device_id set,
  2. that the selection is propagated to the device_id property,
  3. that inference runs on an array that is on the selected device,
  4. and that it fails with an "I/O data on a different device than model" RuntimeError otherwise.

We should test loading both an XGBoost and a scikit-learn model, and parameterize the tests over device_id as one of [None, 0, 1].

We can skip test cases on hosts with fewer devices (e.g. by checking the selected device id against cupy.cuda.runtime.getDeviceCount).
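A minimal sketch of what such a test could look like (the device_id keyword on load_from_sklearn and the default of device 0 are assumptions based on this PR, not a confirmed signature; the device_id property is the one proposed above):

import cupy as cp
import pytest
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

from cuml.fil import ForestInference


@pytest.mark.parametrize("device_id", [None, 0, 1])
def test_load_with_device_id(device_id):
    # Skip cases that need more GPUs than the host provides.
    n_devices = cp.cuda.runtime.getDeviceCount()
    if device_id is not None and device_id >= n_devices:
        pytest.skip("not enough GPUs for this test case")

    X, y = make_regression(n_samples=200, n_features=10, random_state=0)
    skl_model = RandomForestRegressor(n_estimators=5, random_state=0).fit(X, y)

    # (1) Loading with device_id set succeeds, and (2) the selection is
    # propagated to the device_id property (assumed names, see above).
    fm = ForestInference.load_from_sklearn(skl_model, device_id=device_id)
    assert fm.device_id == (0 if device_id is None else device_id)

    # (3) Inference works on an array residing on the selected device.
    with cp.cuda.Device(fm.device_id):
        X_dev = cp.asarray(X, dtype=cp.float32)
    fm.predict(X_dev)

    # (4) Inference on an array on a different device raises RuntimeError.
    other_id = (fm.device_id + 1) % n_devices
    if other_id != fm.device_id:
        with cp.cuda.Device(other_id):
            X_other = cp.asarray(X, dtype=cp.float32)
        with pytest.raises(RuntimeError):
            fm.predict(X_other)

An analogous parametrization over a saved XGBoost model file would cover the second loading path.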

Review thread on cpp/include/cuml/fil/forest_model.hpp
Review thread on cpp/include/cuml/fil/forest_model.hpp (outdated)
@hcho3 (Contributor, Author) commented Jul 18, 2025

@csadorf Can I get another round of review?

@csadorf (Contributor) left a comment

Two questions; otherwise this is probably good to go.

Review thread on python/cuml/cuml/tests/test_fil.py
@csadorf (Contributor) left a comment

Just one question about performance impact; otherwise this is almost good to merge.

@hcho3 (Contributor, Author) commented Jul 21, 2025

> Just one question about performance impact; otherwise this is almost good to merge.

I don't see this question anywhere?

@csadorf (Contributor) commented Jul 22, 2025

> Just one question about performance impact; otherwise this is almost good to merge.
>
> I don't see this question anywhere?

This was in reference to this comment, which I apparently forgot to link to. I think it would be good to do a very brief performance analysis to confirm that the cudaGetDevice query in fact has no or only negligible impact.

@hcho3 (Contributor, Author) commented Jul 23, 2025

I wrote a microbenchmark to analyze the performance impact of adding the cudaGetDevice query.

C++ microbenchmark source code

main.cu:

#include <raft/core/handle.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/random/make_regression.cuh>
#include <rmm/cuda_stream_pool.hpp>
#include <cuml/ensemble/randomforest.hpp>
#include <cuml/fil/infer_kind.hpp>
#include <cuml/fil/forest_model.hpp>
#include <cuml/fil/tree_layout.hpp>
#include <cuml/fil/treelite_importer.hpp>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <utility>

using Device1DArray
    = raft::device_mdarray<float, raft::vector_extent<std::uint64_t>, raft::layout_right>;
using Device1DArrayView
    = raft::device_mdspan<float, raft::vector_extent<std::uint64_t>, raft::layout_right>;
using Device2DArray
    = raft::device_mdarray<float, raft::matrix_extent<std::uint64_t>, raft::layout_right>;
using Device2DArrayView
    = raft::device_mdspan<float, raft::matrix_extent<std::uint64_t>, raft::layout_right>;

std::pair<Device2DArray, Device1DArray> make_regression(const raft::handle_t& handle,
    std::uint64_t n_rows, std::uint64_t n_cols) {
  Device2DArray X = raft::make_device_matrix<float>(handle, n_rows, n_cols);
  Device1DArray y = raft::make_device_vector<float>(handle, n_rows);

  raft::random::make_regression(handle,
      X.data_handle(),
      y.data_handle(),
      n_rows,
      n_cols,
      n_cols / 3,
      handle.get_stream(),
      (float*)nullptr,
      std::uint64_t(1),
      0.0f,
      n_cols / 3,
      0.1f,
      0.01f,
      false,
      12345ULL);
  handle.sync_stream();
  handle.sync_stream_pool();

  // Move: device_mdarray owns device memory and is not copyable.
  return {std::move(X), std::move(y)};
}

ML::fil::forest_model fit_rf_regressor(
    const raft::handle_t& handle, Device2DArrayView X, Device1DArrayView y,
    std::uint32_t n_trees, std::uint32_t max_depth) {
  // Take first 1000 rows as training set
  auto train_nrows = std::min(X.extent(0), std::uint64_t(1000));
  auto rf_model = std::make_unique<ML::RandomForestRegressorF>();
  auto* rf_model_ptr = rf_model.get();
  ML::RF_params rf_params = ML::set_rf_params(
      static_cast<int>(max_depth),
      (1 << 20),
      1.f,
      32,
      3,
      3,
      0.0f,
      true,
      static_cast<int>(n_trees),
      1.f,
      1234ULL,
      ML::CRITERION::MSE,
      8,
      128
  );
  ML::fit(handle, rf_model_ptr, X.data_handle(), train_nrows, X.extent(1), y.data_handle(), rf_params);
  handle.sync_stream();
  handle.sync_stream_pool();

  void* tl_model_ptr{nullptr};
  ML::build_treelite_forest(&tl_model_ptr, rf_model.get(), X.extent(1));
  std::unique_ptr<treelite::Model> tl_model{static_cast<treelite::Model*>(tl_model_ptr)};
  return ML::fil::import_from_treelite_handle(
    tl_model.get(), ML::fil::tree_layout::depth_first, 128, false,
    raft_proto::device_type::gpu, 0, handle.get_stream());
}

int main() {
  constexpr int num_streams = 16;
  auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(num_streams);
  raft::handle_t handle{rmm::cuda_stream_per_thread, stream_pool};

  constexpr std::uint64_t n_rows = 1000000;
  constexpr std::uint64_t n_cols = 30;
  constexpr std::uint32_t n_trees = 100;
  constexpr std::uint32_t max_depth = 10;
  constexpr std::uint32_t n_reps = 1000;

  std::cout << "Fit a random forest..." << std::endl;
  auto [X, y] = make_regression(handle, n_rows, n_cols);
  auto fm = fit_rf_regressor(handle, X.view(), y.view(), n_trees, max_depth);

  std::cout << "Run FIL predict..." << std::endl;

  auto ypred = raft::make_device_vector<float>(handle, X.extent(0));

  auto tstart = std::chrono::high_resolution_clock::now();
  for (std::uint32_t i = 0; i < n_reps; i++) {
    fm.predict(handle, ypred.data_handle(), X.data_handle(), X.extent(0),
        raft_proto::device_type::gpu, raft_proto::device_type::gpu,
        ML::fil::infer_kind::default_kind, std::nullopt);
  }
  handle.sync_stream();
  handle.sync_stream_pool();

  auto tend = std::chrono::high_resolution_clock::now();
  std::cout << "Time elapsed: "
    << static_cast<double>(
        std::chrono::duration_cast<std::chrono::nanoseconds>(tend - tstart).count()
        / 1000000)
    << " ms" << std::endl;

  return 0;
}

CMakeLists.txt:

cmake_minimum_required(VERSION 3.28)
project(fil_test LANGUAGES C CXX CUDA)

find_package(cuml 25.08 CONFIG REQUIRED)
find_package(raft 25.08 CONFIG REQUIRED)
find_package(Treelite 4.3.1 CONFIG REQUIRED)
find_package(CUDAToolkit REQUIRED)

add_executable(test main.cu)

target_link_libraries(test
    PRIVATE
    cuml::cuml++
    PUBLIC
    raft::raft
    treelite::treelite
    CUDA::cublas
    CUDA::cusolver)
target_compile_definitions(test
    PUBLIC
    CUML_ENABLE_GPU=1)
set_target_properties(test
    PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    CXX_STANDARD 17
    CXX_STANDARD_REQUIRED ON
    CUDA_ARCHITECTURES 75
)
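For reference, this builds in the usual CMake way (cmake -S . -B build, then cmake --build build) in an environment that provides the cuml 25.08 development packages; the benchmark binary is then ./build/test.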

I ran the microbenchmark program 5 times.

branch-25.08: 6762 ms, 6816 ms, 6848 ms, 6902 ms, 6938 ms (mean = 6853.2 ms, stddev = 69.46 ms)
This PR: 6856 ms, 6795 ms, 6836 ms, 6876 ms, 6944 ms (mean = 6861.4 ms, stddev = 55.04 ms)

Under null-hypothesis significance testing, the difference between the two versions is not statistically significant (p = 0.8413).
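For reference, the quoted p-value is consistent with a pooled-variance two-sample t-test on the five timings; a quick sketch with SciPy reproduces it:

from scipy import stats

# Five wall-clock timings (ms) per build, as listed above
branch = [6762, 6816, 6848, 6902, 6938]  # branch-25.08
pr     = [6856, 6795, 6836, 6876, 6944]  # this PR

# Pooled-variance two-sample t-test (equal_var=True is the default)
t, p = stats.ttest_ind(branch, pr)
print(f"t = {t:.4f}, p = {p:.4f}")  # -> p = 0.8413, not significant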

Conclusion: the extra cudaGetDevice call has no measurable effect on inference performance.

@csadorf (Contributor) commented Jul 23, 2025

/merge

rapids-bot[bot] merged commit ee51a45 into rapidsai:branch-25.08 on Jul 23, 2025
73 checks passed
@hcho3 deleted the fix_mgpu_fil branch July 23, 2025 18:50

Labels

CUDA/C++, Cython / Python (Cython or Python issue), improvement (Improvement / enhancement to an existing function), non-breaking (Non-breaking change)

Development

Successfully merging this pull request may close these issues:

[BUG] Errors with 25.06 RandomForest inference on a multi-gpu server (#6930)
