Skip to content

Commit dd57fc6

Browse files
enp1s0mythrocks
authored andcommitted
Fix kmeans::predict argument order (rapidsai#915)
This PR updates the order of arguments in `kmeans::predict` to match the order in the documentation. The order in the documentation is considered appropriate because it lists the output arguments last, unlike the order of the actual function arguments. Authors: - tsuki (https://github.com/enp1s0) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#915
1 parent d162bdc commit dd57fc6

4 files changed

Lines changed: 77 additions & 13 deletions

File tree

cpp/include/cuvs/cluster/kmeans.hpp

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,10 +519,26 @@ void predict(raft::resources const& handle,
519519
raft::device_matrix_view<const float, int> X,
520520
std::optional<raft::device_vector_view<const float, int>> sample_weight,
521521
raft::device_matrix_view<const float, int> centroids,
522-
raft::device_vector_view<int, int> labels,
523522
bool normalize_weight,
523+
raft::device_vector_view<int, int> labels,
524524
raft::host_scalar_view<float> inertia);
525525

526+
// This overload is retained for backward compatibility.
527+
[[deprecated(
528+
"The argument order of kmeans::predict has been corrected. Please use the new function "
529+
"instead.")]]
530+
inline void predict(raft::resources const& handle,
531+
const kmeans::params& params,
532+
raft::device_matrix_view<const float, int> X,
533+
std::optional<raft::device_vector_view<const float, int>> sample_weight,
534+
raft::device_matrix_view<const float, int> centroids,
535+
raft::device_vector_view<int, int> labels,
536+
bool normalize_weight,
537+
raft::host_scalar_view<float> inertia)
538+
{
539+
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
540+
}
541+
526542
/**
527543
* @brief Predict the closest cluster each sample in X belongs to.
528544
*
@@ -577,10 +593,26 @@ void predict(raft::resources const& handle,
577593
raft::device_matrix_view<const float, int> X,
578594
std::optional<raft::device_vector_view<const float, int>> sample_weight,
579595
raft::device_matrix_view<const float, int> centroids,
580-
raft::device_vector_view<int64_t, int> labels,
581596
bool normalize_weight,
597+
raft::device_vector_view<int64_t, int> labels,
582598
raft::host_scalar_view<float> inertia);
583599

600+
// This overload is retained for backward compatibility.
601+
[[deprecated(
602+
"The argument order of kmeans::predict has been corrected. Please use the new function "
603+
"instead.")]]
604+
inline void predict(raft::resources const& handle,
605+
const kmeans::params& params,
606+
raft::device_matrix_view<const float, int> X,
607+
std::optional<raft::device_vector_view<const float, int>> sample_weight,
608+
raft::device_matrix_view<const float, int> centroids,
609+
raft::device_vector_view<int64_t, int> labels,
610+
bool normalize_weight,
611+
raft::host_scalar_view<float> inertia)
612+
{
613+
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
614+
}
615+
584616
/**
585617
* @brief Predict the closest cluster each sample in X belongs to.
586618
*
@@ -635,10 +667,26 @@ void predict(raft::resources const& handle,
635667
raft::device_matrix_view<const double, int> X,
636668
std::optional<raft::device_vector_view<const double, int>> sample_weight,
637669
raft::device_matrix_view<const double, int> centroids,
638-
raft::device_vector_view<int, int> labels,
639670
bool normalize_weight,
671+
raft::device_vector_view<int, int> labels,
640672
raft::host_scalar_view<double> inertia);
641673

674+
// This overload is retained for backward compatibility.
675+
[[deprecated(
676+
"The argument order of kmeans::predict has been corrected. Please use the new function "
677+
"instead.")]]
678+
inline void predict(raft::resources const& handle,
679+
const kmeans::params& params,
680+
raft::device_matrix_view<const double, int> X,
681+
std::optional<raft::device_vector_view<const double, int>> sample_weight,
682+
raft::device_matrix_view<const double, int> centroids,
683+
raft::device_vector_view<int, int> labels,
684+
bool normalize_weight,
685+
raft::host_scalar_view<double> inertia)
686+
{
687+
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
688+
}
689+
642690
/**
643691
* @brief Predict the closest cluster each sample in X belongs to.
644692
*
@@ -693,10 +741,26 @@ void predict(raft::resources const& handle,
693741
raft::device_matrix_view<const double, int> X,
694742
std::optional<raft::device_vector_view<const double, int>> sample_weight,
695743
raft::device_matrix_view<const double, int> centroids,
696-
raft::device_vector_view<int64_t, int> labels,
697744
bool normalize_weight,
745+
raft::device_vector_view<int64_t, int> labels,
698746
raft::host_scalar_view<double> inertia);
699747

748+
// This overload is retained for backward compatibility.
749+
[[deprecated(
750+
"The argument order of kmeans::predict has been corrected. Please use the new function "
751+
"instead.")]]
752+
inline void predict(raft::resources const& handle,
753+
const kmeans::params& params,
754+
raft::device_matrix_view<const double, int> X,
755+
std::optional<raft::device_vector_view<const double, int>> sample_weight,
756+
raft::device_matrix_view<const double, int> centroids,
757+
raft::device_vector_view<int64_t, int> labels,
758+
bool normalize_weight,
759+
raft::host_scalar_view<double> inertia)
760+
{
761+
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
762+
}
763+
700764
/**
701765
* @brief Predict the closest cluster each sample in X belongs to.
702766
*

cpp/src/cluster/kmeans.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ void predict(raft::resources const& handle,
163163
raft::device_matrix_view<const DataT, IndexT> X,
164164
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
165165
raft::device_matrix_view<const DataT, IndexT> centroids,
166-
raft::device_vector_view<IndexT, IndexT> labels,
167166
bool normalize_weight,
167+
raft::device_vector_view<IndexT, IndexT> labels,
168168
raft::host_scalar_view<DataT> inertia)
169169
{
170170
cuvs::cluster::kmeans::detail::kmeans_predict<DataT, IndexT>(

cpp/src/cluster/kmeans_predict_double.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,26 @@ void predict(raft::resources const& handle,
2424
raft::device_matrix_view<const double, int> X,
2525
std::optional<raft::device_vector_view<const double, int>> sample_weight,
2626
raft::device_matrix_view<const double, int> centroids,
27-
raft::device_vector_view<int, int> labels,
2827
bool normalize_weight,
28+
raft::device_vector_view<int, int> labels,
2929
raft::host_scalar_view<double> inertia)
3030

3131
{
3232
cuvs::cluster::kmeans::predict<double, int>(
33-
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
33+
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
3434
}
3535

3636
void predict(raft::resources const& handle,
3737
const kmeans::params& params,
3838
raft::device_matrix_view<const double, int> X,
3939
std::optional<raft::device_vector_view<const double, int>> sample_weight,
4040
raft::device_matrix_view<const double, int> centroids,
41-
raft::device_vector_view<int64_t, int> labels,
4241
bool normalize_weight,
42+
raft::device_vector_view<int64_t, int> labels,
4343
raft::host_scalar_view<double> inertia)
4444

4545
{
4646
cuvs::cluster::kmeans::predict<double, int64_t>(
47-
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
47+
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
4848
}
4949
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_predict_float.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,25 @@ void predict(raft::resources const& handle,
2424
raft::device_matrix_view<const float, int> X,
2525
std::optional<raft::device_vector_view<const float, int>> sample_weight,
2626
raft::device_matrix_view<const float, int> centroids,
27-
raft::device_vector_view<int, int> labels,
2827
bool normalize_weight,
28+
raft::device_vector_view<int, int> labels,
2929
raft::host_scalar_view<float> inertia)
3030

3131
{
3232
cuvs::cluster::kmeans::predict<float, int>(
33-
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
33+
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
3434
}
3535
void predict(raft::resources const& handle,
3636
const kmeans::params& params,
3737
raft::device_matrix_view<const float, int> X,
3838
std::optional<raft::device_vector_view<const float, int>> sample_weight,
3939
raft::device_matrix_view<const float, int> centroids,
40-
raft::device_vector_view<int64_t, int> labels,
4140
bool normalize_weight,
41+
raft::device_vector_view<int64_t, int> labels,
4242
raft::host_scalar_view<float> inertia)
4343

4444
{
4545
cuvs::cluster::kmeans::predict<float, int64_t>(
46-
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
46+
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
4747
}
4848
} // namespace cuvs::cluster::kmeans

0 commit comments

Comments
 (0)