Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,26 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int, int> labels,
bool normalize_weight,
raft::device_vector_view<int, int> labels,
raft::host_scalar_view<float> inertia);

// This overload is retained for backward compatibility.
[[deprecated(
"The argument order of kmeans::predict has been corrected. Please use the new function "
"instead.")]]
inline void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int, int> labels,
bool normalize_weight,
raft::host_scalar_view<float> inertia)
{
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}

/**
* @brief Predict the closest cluster each sample in X belongs to.
*
Expand Down Expand Up @@ -577,10 +593,26 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int64_t, int> labels,
bool normalize_weight,
raft::device_vector_view<int64_t, int> labels,
raft::host_scalar_view<float> inertia);

// This overload is retained for backward compatibility.
[[deprecated(
"The argument order of kmeans::predict has been corrected. Please use the new function "
"instead.")]]
inline void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int64_t, int> labels,
bool normalize_weight,
raft::host_scalar_view<float> inertia)
{
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}

/**
* @brief Predict the closest cluster each sample in X belongs to.
*
Expand Down Expand Up @@ -635,10 +667,26 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int, int> labels,
bool normalize_weight,
raft::device_vector_view<int, int> labels,
raft::host_scalar_view<double> inertia);

// This overload is retained for backward compatibility.
[[deprecated(
"The argument order of kmeans::predict has been corrected. Please use the new function "
"instead.")]]
inline void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int, int> labels,
bool normalize_weight,
raft::host_scalar_view<double> inertia)
{
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}

/**
* @brief Predict the closest cluster each sample in X belongs to.
*
Expand Down Expand Up @@ -693,10 +741,26 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int64_t, int> labels,
bool normalize_weight,
raft::device_vector_view<int64_t, int> labels,
raft::host_scalar_view<double> inertia);

// This overload is retained for backward compatibility.
[[deprecated(
"The argument order of kmeans::predict has been corrected. Please use the new function "
"instead.")]]
inline void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int64_t, int> labels,
bool normalize_weight,
raft::host_scalar_view<double> inertia)
{
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}

/**
* @brief Predict the closest cluster each sample in X belongs to.
*
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/cluster/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const DataT, IndexT> X,
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
raft::device_matrix_view<const DataT, IndexT> centroids,
raft::device_vector_view<IndexT, IndexT> labels,
bool normalize_weight,
raft::device_vector_view<IndexT, IndexT> labels,
raft::host_scalar_view<DataT> inertia)
{
cuvs::cluster::kmeans::detail::kmeans_predict<DataT, IndexT>(
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/cluster/kmeans_predict_double.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,26 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int, int> labels,
bool normalize_weight,
raft::device_vector_view<int, int> labels,
raft::host_scalar_view<double> inertia)

{
cuvs::cluster::kmeans::predict<double, int>(
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}

void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const double, int> X,
std::optional<raft::device_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<const double, int> centroids,
raft::device_vector_view<int64_t, int> labels,
bool normalize_weight,
raft::device_vector_view<int64_t, int> labels,
raft::host_scalar_view<double> inertia)

{
cuvs::cluster::kmeans::predict<double, int64_t>(
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}
} // namespace cuvs::cluster::kmeans
8 changes: 4 additions & 4 deletions cpp/src/cluster/kmeans_predict_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,25 @@ void predict(raft::resources const& handle,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int, int> labels,
bool normalize_weight,
raft::device_vector_view<int, int> labels,
raft::host_scalar_view<float> inertia)

{
cuvs::cluster::kmeans::predict<float, int>(
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}
void predict(raft::resources const& handle,
const kmeans::params& params,
raft::device_matrix_view<const float, int> X,
std::optional<raft::device_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<const float, int> centroids,
raft::device_vector_view<int64_t, int> labels,
bool normalize_weight,
raft::device_vector_view<int64_t, int> labels,
raft::host_scalar_view<float> inertia)

{
cuvs::cluster::kmeans::predict<float, int64_t>(
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
}
} // namespace cuvs::cluster::kmeans