Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions c/src/cluster/kmeans.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -63,12 +63,10 @@ void _fit(cuvsResources_t res,
RAFT_FAIL("float64 is an unsupported dtype for hierarchical kmeans");
} else {
auto kmeans_params = convert_balanced_params(params);
cuvs::cluster::kmeans::fit(*res_ptr,
kmeans_params,
cuvs::core::from_dlpack<const_mdspan_type>(X_tensor),
cuvs::core::from_dlpack<mdspan_type>(centroids_tensor));

*inertia = 0;
T inertia_temp;
auto inertia_view = raft::make_host_scalar_view<T>(&inertia_temp);
cuvs::cluster::kmeans::fit(*res_ptr, kmeans_params, cuvs::core::from_dlpack<const_mdspan_type>(X_tensor), cuvs::core::from_dlpack<mdspan_type>(centroids_tensor), std::make_optional(inertia_view));
*inertia = inertia_temp;
Comment thread
tarang-jain marked this conversation as resolved.
*n_iter = params.hierarchical_n_iters;
}
} else {
Expand Down
88 changes: 75 additions & 13 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <raft/random/rng_state.hpp>
#include <rapids_logger/logger.hpp>

#include <optional>

namespace cuvs::cluster::kmeans {

/** Base structure for parameters that are common to all k-means algorithms */
Expand Down Expand Up @@ -420,11 +422,14 @@ void fit(raft::resources const& handle,
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
*/
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const float, int64_t> X,
raft::device_matrix_view<float, int64_t> centroids);
raft::device_matrix_view<float, int64_t> centroids,
std::optional<raft::host_scalar_view<float>> inertia = std::nullopt);

/**
* @brief Find balanced clusters with k-means algorithm.
Expand Down Expand Up @@ -454,11 +459,14 @@ void fit(const raft::resources& handle,
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
*/
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const int8_t, int64_t> X,
raft::device_matrix_view<float, int64_t> centroids);
raft::device_matrix_view<float, int64_t> centroids,
std::optional<raft::host_scalar_view<float>> inertia = std::nullopt);

/**
* @brief Find balanced clusters with k-means algorithm.
Expand Down Expand Up @@ -488,11 +496,14 @@ void fit(const raft::resources& handle,
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
*/
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const half, int64_t> X,
raft::device_matrix_view<float, int64_t> centroids);
raft::device_matrix_view<float, int64_t> centroids,
std::optional<raft::host_scalar_view<float>> inertia = std::nullopt);

/**
* @brief Find balanced clusters with k-means algorithm.
Expand Down Expand Up @@ -522,11 +533,14 @@ void fit(const raft::resources& handle,
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
*/
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const uint8_t, int64_t> X,
raft::device_matrix_view<float, int64_t> centroids);
raft::device_matrix_view<float, int64_t> centroids,
std::optional<raft::host_scalar_view<float>> inertia = std::nullopt);

/**
* @brief Predict the closest cluster each sample in X belongs to.
Expand Down Expand Up @@ -1380,7 +1394,7 @@ void transform(raft::resources const& handle,
raft::device_matrix_view<double, int> X_new);

/**
* @brief Compute cluster cost
* @brief Compute (optionally weighted) cluster cost
*
* @param[in] handle The raft handle
* @param[in] X Training instances to cluster. The data must
Expand All @@ -1390,12 +1404,16 @@ void transform(raft::resources const& handle,
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost
* @param[in] sample_weight Optional per-sample weights.
* [len = n_samples]
*
*/
void cluster_cost(const raft::resources& handle,
raft::device_matrix_view<const float, int> X,
raft::device_matrix_view<const float, int> centroids,
raft::host_scalar_view<float> cost);
void cluster_cost(
const raft::resources& handle,
raft::device_matrix_view<const float, int> X,
raft::device_matrix_view<const float, int> centroids,
raft::host_scalar_view<float> cost,
std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt);
Comment thread
tarang-jain marked this conversation as resolved.

/**
* @brief Compute cluster cost
Expand All @@ -1408,13 +1426,57 @@ void cluster_cost(const raft::resources& handle,
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost
* @param[in] sample_weight Optional per-sample weights.
* [len = n_samples]
*/
void cluster_cost(
const raft::resources& handle,
raft::device_matrix_view<const double, int> X,
raft::device_matrix_view<const double, int> centroids,
raft::host_scalar_view<double> cost,
std::optional<raft::device_vector_view<const double, int>> sample_weight = std::nullopt);

/**
* @brief Compute (optionally weighted) cluster cost
*
* @param[in] handle The raft handle
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost
* @param[in] sample_weight Optional per-sample weights.
* [len = n_samples]
*/
void cluster_cost(const raft::resources& handle,
raft::device_matrix_view<const double, int> X,
raft::device_matrix_view<const double, int> centroids,
raft::host_scalar_view<double> cost);
void cluster_cost(
const raft::resources& handle,
raft::device_matrix_view<const float, int64_t> X,
raft::device_matrix_view<const float, int64_t> centroids,
raft::host_scalar_view<float> cost,
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight = std::nullopt);

/**
* @brief Compute (optionally weighted) cluster cost
*
* @param[in] handle The raft handle
* @param[in] X Training instances to cluster. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] centroids Cluster centroids. The data must be in
* row-major format.
* [dim = n_clusters x n_features]
* @param[out] cost Resulting cluster cost
* @param[in] sample_weight Optional per-sample weights.
* [len = n_samples]
*/
void cluster_cost(
const raft::resources& handle,
raft::device_matrix_view<const double, int64_t> X,
raft::device_matrix_view<const double, int64_t> centroids,
raft::host_scalar_view<double> cost,
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight = std::nullopt);
/**
* @}
*/
Expand Down
51 changes: 33 additions & 18 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -999,22 +999,22 @@ auto build_fine_clusters(const raft::resources& handle,
/**
* @brief Hierarchical balanced k-means
*
* @tparam T element type
* @tparam MathT type of the centroids and mapped data
* @tparam IdxT index type
* @tparam LabelT label type
* @tparam T element type
* @tparam MathT type of the centroids and mapped data
* @tparam IdxT index type
* @tparam MappingOpT type of the mapping operation
*
* @param[in] handle The raft handle.
* @param[in] params Structure containing the hyper-parameters
* @param dim number of columns in `centers` and `dataset`
* @param[in] dataset a device pointer to the source dataset [n_rows, dim]
* @param n_rows number of rows in the input
* @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim]
* @param n_cluster
* @param metric the distance type
* @param mapping_op Mapping operation from T to MathT
* @param stream
* @param[in] handle The raft handle.
* @param[in] params Structure containing the hyper-parameters
* @param[in] dim Number of columns in `cluster_centers` and `dataset`
* @param[in] dataset A device pointer to the source dataset [n_rows, dim]
* @param[in] n_rows Number of rows in the input
* @param[out] cluster_centers A device pointer to the found cluster centers [n_clusters, dim]
* @param[in] n_clusters Requested number of clusters
* @param[in] mapping_op Mapping operation from T to MathT
* @param[out] inertia (optional) If non-null, the sum of squared distances of samples to
* their closest cluster center is written here.
* Only supported when T == MathT (float/double).
*/
template <typename T, typename MathT, typename IdxT, typename MappingOpT>
void build_hierarchical(const raft::resources& handle,
Expand All @@ -1025,7 +1025,7 @@ void build_hierarchical(const raft::resources& handle,
MathT* cluster_centers,
IdxT n_clusters,
MappingOpT mapping_op,
const MathT* dataset_norm = nullptr)
MathT* inertia = nullptr)
{
auto stream = raft::resource::get_cuda_stream(handle);
using LabelT = uint32_t;
Expand All @@ -1044,9 +1044,10 @@ void build_hierarchical(const raft::resources& handle,

// Precompute the L2 norm of the dataset if relevant and not yet computed.
rmm::device_uvector<MathT> dataset_norm_buf(0, stream, device_memory);
if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded)) {
const MathT* dataset_norm = nullptr;
if ((params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded)) {
dataset_norm_buf.resize(n_rows, stream);
for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) {
IdxT minibatch_size = std::min<IdxT>(max_minibatch_size, n_rows - offset);
Expand Down Expand Up @@ -1164,6 +1165,20 @@ void build_hierarchical(const raft::resources& handle,
MathT{0.2},
mapping_op,
device_memory);

// Compute inertia if requested (only supported when T == MathT)
if (inertia != nullptr) {
if constexpr (std::is_same_v<T, MathT>) {
Comment thread
tarang-jain marked this conversation as resolved.
auto X_view = raft::make_device_matrix_view<const MathT, IdxT>(
reinterpret_cast<const MathT*>(dataset), n_rows, dim);
auto centroids_view =
raft::make_device_matrix_view<const MathT, IdxT>(cluster_centers, n_clusters, dim);
cuvs::cluster::kmeans::cluster_cost(
handle, X_view, centroids_view, raft::make_host_scalar_view<MathT>(inertia));
} else {
RAFT_LOG_WARN("Inertia is not computed for non float/double types");
}
}
}

} // namespace cuvs::cluster::kmeans::detail
34 changes: 28 additions & 6 deletions cpp/src/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,27 @@ void min_cluster_distance(raft::resources const& handle,
workspace);
}

/**
* @brief Compute (optionally weighted) cluster cost (inertia).
*
* @tparam DataT float or double
* @tparam IndexT Index type
*
* @param[in] handle The raft handle
* @param[in] X Input data [n_samples x n_features]
* @param[in] centroids Cluster centroids [n_clusters x n_features]
* @param[out] cost Sum of squared distances to nearest centroid
* @param[in] sample_weight Optional per-sample weights [n_samples]
*/
template <typename DataT, typename IndexT>
void cluster_cost(raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::host_scalar_view<DataT> cost)
void cluster_cost(
raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::host_scalar_view<DataT> cost,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight = std::nullopt)
{
auto stream = raft::resource::get_cuda_stream(handle);

auto stream = raft::resource::get_cuda_stream(handle);
auto n_clusters = centroids.extent(0);
auto n_samples = X.extent(0);
auto n_features = X.extent(1);
Expand Down Expand Up @@ -440,6 +453,15 @@ void cluster_cost(raft::resources const& handle,
n_clusters,
workspace);

// Apply sample weights if provided
if (sample_weight.has_value()) {
raft::linalg::map(handle,
min_cluster_distance.view(),
raft::mul_op{},
raft::make_const_mdspan(min_cluster_distance.view()),
sample_weight.value());
}

auto device_cost = raft::make_device_scalar<DataT>(handle, DataT(0));

cuvs::cluster::kmeans::cluster_cost(
Expand Down
12 changes: 9 additions & 3 deletions cpp/src/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -61,13 +61,16 @@ namespace cuvs::cluster::kmeans_balanced {
* @param[out] centroids The generated centroids [dim = n_clusters x n_features]
* @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic
* datatype. If DataT == MathT, this must be the identity.
* @param[out] inertia (optional) Sum of squared distances of samples to their
* closest cluster center.
*/
template <typename DataT, typename MathT, typename IndexT, typename MappingOpT = raft::identity_op>
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<MathT, IndexT> centroids,
MappingOpT mapping_op = raft::identity_op())
MappingOpT mapping_op = raft::identity_op(),
std::optional<raft::host_scalar_view<MathT>> inertia = std::nullopt)
{
RAFT_EXPECTS(X.extent(1) == centroids.extent(1),
"Number of features in dataset and centroids are different");
Expand All @@ -78,14 +81,17 @@ void fit(const raft::resources& handle,
"The number of centroids must be strictly positive and cannot exceed the number of "
"points in the training dataset.");

MathT* inertia_ptr = inertia.has_value() ? inertia.value().data_handle() : nullptr;

cuvs::cluster::kmeans::detail::build_hierarchical(handle,
params,
X.extent(1),
X.data_handle(),
X.extent(0),
centroids.data_handle(),
centroids.extent(0),
mapping_op);
mapping_op,
inertia_ptr);
}

/**
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/cluster/kmeans_balanced_fit_float.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -14,9 +14,10 @@ namespace cuvs::cluster::kmeans {
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const float, int64_t> X,
raft::device_matrix_view<float, int64_t> centroids)
raft::device_matrix_view<float, int64_t> centroids,
std::optional<raft::host_scalar_view<float>> inertia)
{
cuvs::cluster::kmeans_balanced::fit(
handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping<float>{});
handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping<float>{}, inertia);
}
} // namespace cuvs::cluster::kmeans
Loading