Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions include/xgboost/multi_target_tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ struct TreeParam;

/**
* @brief Tree structure for multi-target model.
*
* In order to support reduced gradient, the internal storage distinguishes weights
* between base weights and leaf weights. The former is the weight calculated from split
 * gradient, and the latter is the weight calculated from value gradient and used as
* outputs. Every node has a base weight, but only leaves have leaf weights.
*
* To access the leaf weights, we re-use the right child to store leaf indices. For split
 * nodes, the `right_` member stores their right child node indices; for leaf nodes, the
* `right_` member stores the corresponding leaf weight indices.
*/
class MultiTargetTree : public Model {
public:
Expand All @@ -33,24 +42,36 @@ class MultiTargetTree : public Model {

private:
TreeParam const* param_;
// Mapping from node index to its left child. -1 for a leaf node.
HostDeviceVector<bst_node_t> left_;
// Mapping from node index to its right child. Maps to leaf weight for a leaf node.
HostDeviceVector<bst_node_t> right_;
// Mapping from node index to its parent.
HostDeviceVector<bst_node_t> parent_;
// Feature index for node split.
HostDeviceVector<bst_feature_t> split_index_;
// Whether the left child is the default node when split feature is missing.
HostDeviceVector<std::uint8_t> default_left_;
// Threshold for splitting a node.
HostDeviceVector<float> split_conds_;
// Internal base weights.
HostDeviceVector<float> weights_;
// Output weights.
HostDeviceVector<float> leaf_weights_;

[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
auto beg = nidx * this->NumTargets();
auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTargets());
auto beg = nidx * this->NumSplitTargets();
auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumSplitTargets());
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
}
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
auto beg = nidx * this->NumTargets();
auto v = this->weights_.HostSpan().subspan(beg, this->NumTargets());
// Unlike the const version, `NumSplitTargets` is not reliable if the tree can change.
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx,
bst_target_t n_split_targets) {
auto beg = nidx * n_split_targets;
auto v = this->weights_.HostSpan().subspan(beg, n_split_targets);
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
}
[[nodiscard]] bst_node_t LeafIdx(bst_node_t nidx) const { return this->RightChild(nidx); }

public:
explicit MultiTargetTree(TreeParam const* param);
Expand All @@ -72,6 +93,8 @@ class MultiTargetTree : public Model {
linalg::VectorView<float const> right_weight);
/** @see RegTree::SetLeaves */
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
/** @brief Copy base weight into leaf weight for a non-reduced multi-target tree. */
void SetLeaves();

[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
return left_.ConstHostVector()[nidx] == InvalidNodeId();
Expand All @@ -82,24 +105,36 @@ class MultiTargetTree : public Model {
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
return right_.ConstHostVector().at(nidx);
}

/**
* @brief Number of targets (size of a leaf).
*/
[[nodiscard]] bst_target_t NumTargets() const;
[[nodiscard]] auto NumLeaves() const { return this->weights_.Size() / this->NumTargets(); }
/**
* @brief Number of reduced targets.
*/
[[nodiscard]] bst_target_t NumSplitTargets() const;
[[nodiscard]] auto NumLeaves() const { return this->leaf_weights_.Size() / this->NumTargets(); }

[[nodiscard]] std::size_t Size() const;
[[nodiscard]] MultiTargetTree* Copy(TreeParam const* param) const;

common::Span<float const> Weights(DeviceOrd device) const {
common::Span<float const> LeafWeights(DeviceOrd device) const {
if (device.IsCPU()) {
return this->weights_.ConstHostSpan();
return this->leaf_weights_.ConstHostSpan();
}
this->weights_.SetDevice(device);
return this->weights_.ConstDeviceSpan();
this->leaf_weights_.SetDevice(device);
return this->leaf_weights_.ConstDeviceSpan();
}

[[nodiscard]] linalg::VectorView<float const> LeafValue(bst_node_t nidx) const {
CHECK(IsLeaf(nidx));
return this->NodeWeight(nidx);
auto n_targets = this->NumTargets();
auto h_leaf_mapping = this->right_.ConstHostSpan();
auto h_leaf_weights = this->leaf_weights_.ConstHostSpan();
auto lidx = h_leaf_mapping[nidx];
CHECK_NE(lidx, InvalidNodeId());
auto weight = h_leaf_weights.subspan(lidx * n_targets, n_targets);
return linalg::MakeVec(DeviceOrd::CPU(), weight);
}

void LoadModel(Json const& in) override;
Expand Down
2 changes: 2 additions & 0 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ class RegTree : public Model {
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
/**
* @brief Set the root weight for a multi-target tree.
*
* @param weight Internal split weight, with size equals to reduced targets.
*/
void SetRoot(linalg::VectorView<float const> weight) {
CHECK(IsMultiTarget());
Expand Down
8 changes: 8 additions & 0 deletions src/common/cuda_rt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
*/
#include "cuda_rt_utils.h"

#include "cuda_stream.h" // for StreamRef

#if defined(XGBOOST_USE_CUDA)
#include <cuda_runtime_api.h>

Expand Down Expand Up @@ -99,6 +101,10 @@ void GetDrVersionGlobal(std::int32_t* major, std::int32_t* minor) {
return numa_id;
}

void MemcpyAsync(void* dst, const void* src, std::size_t count, StreamRef stream) {
dh::safe_cuda(cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream));
}

#else
std::int32_t AllVisibleGPUs() { return 0; }

Expand Down Expand Up @@ -128,5 +134,7 @@ void SetDevice(std::int32_t device) {
return 0;
}

void MemcpyAsync(void*, const void*, std::size_t, StreamRef) { common::AssertGPUSupport(); }

#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::curt
5 changes: 5 additions & 0 deletions src/common/cuda_rt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <cstddef> // for size_t
#include <cstdint> // for int32_t

#include "cuda_stream.h" // for StreamRef

namespace xgboost::curt {
std::int32_t AllVisibleGPUs();

Expand Down Expand Up @@ -35,4 +37,7 @@ void GetDrVersionGlobal(std::int32_t* major, std::int32_t* minor);

// Get the current device's numa ID.
[[nodiscard]] std::int32_t GetNumaId();

// Asynchronous memcpy on the given stream; wraps cudaMemcpyAsync with
// cudaMemcpyDefault (direction inferred from the pointers). No-op assert
// when CUDA support is not compiled in.
void MemcpyAsync(void* dst, const void* src, std::size_t count, StreamRef stream);
} // namespace xgboost::curt
12 changes: 12 additions & 0 deletions src/common/cuda_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
* Copyright 2022-2025, XGBoost contributors
*/
#pragma once

#if defined(XGBOOST_USE_CUDA)
#include <cuda_runtime.h>
#endif // defined(XGBOOST_USE_CUDA)

#include <memory> // for unique_ptr
#include <utility> // for swap

#include "common.h"

namespace xgboost::curt {
#if defined(XGBOOST_USE_CUDA)
class StreamRef;

class Event {
Expand Down Expand Up @@ -94,4 +98,12 @@ class Stream {
void Sync() { this->View().Sync(); }
void Wait(Event const &e) { this->View().Wait(e); }
};
#else
class StreamRef {};

inline StreamRef DefaultStream() {
common::AssertGPUSupport();
return StreamRef{};
}
#endif
} // namespace xgboost::curt
6 changes: 3 additions & 3 deletions src/gbm/gbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class GBTree : public GradientBooster {
HostDeviceVector<bst_float>* out_preds,
uint32_t layer_begin, uint32_t layer_end) override {
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: [0, "
"n_iteration), use model slicing instead.";
this->GetPredictor(false)->PredictLeaf(p_fmat, out_preds, model_, tree_end);
}
Expand All @@ -304,7 +304,7 @@ class GBTree : public GradientBooster {
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) override {
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: (0, "
CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: [0, "
"n_iteration), using model slicing instead.";
this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, nullptr,
approximate);
Expand All @@ -314,7 +314,7 @@ class GBTree : public GradientBooster {
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) override {
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: (0, "
CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: [0, "
"n_iteration), using model slicing instead.";
this->GetPredictor(false)->PredictInteractionContributions(p_fmat, out_contribs, model_,
tree_end, nullptr, approximate);
Expand Down
1 change: 1 addition & 0 deletions src/tree/io_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace tree_field {
inline std::string const kLossChg{"loss_changes"};
inline std::string const kSumHess{"sum_hessian"};
inline std::string const kBaseWeight{"base_weights"};
inline std::string const kLeafWeight{"leaf_weights"};

inline std::string const kSplitIdx{"split_indices"};
inline std::string const kSplitCond{"split_conditions"};
Expand Down
Loading
Loading