@@ -519,10 +519,26 @@ void predict(raft::resources const& handle,
519519 raft::device_matrix_view<const float , int > X,
520520 std::optional<raft::device_vector_view<const float , int >> sample_weight,
521521 raft::device_matrix_view<const float , int > centroids,
522- raft::device_vector_view<int , int > labels,
523522 bool normalize_weight,
523+ raft::device_vector_view<int , int > labels,
524524 raft::host_scalar_view<float > inertia);
525525
526+ // This overload is retained for backward compatibility.
527+ [[deprecated(
528+ " The argument order of kmeans::predict has been corrected. Please use the new function "
529+ " instead." )]]
530+ inline void predict (raft::resources const & handle,
531+ const kmeans::params& params,
532+ raft::device_matrix_view<const float , int > X,
533+ std::optional<raft::device_vector_view<const float , int >> sample_weight,
534+ raft::device_matrix_view<const float , int > centroids,
535+ raft::device_vector_view<int , int > labels,
536+ bool normalize_weight,
537+ raft::host_scalar_view<float > inertia)
538+ {
539+ predict (handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
540+ }
541+
526542/* *
527543 * @brief Predict the closest cluster each sample in X belongs to.
528544 *
@@ -577,10 +593,26 @@ void predict(raft::resources const& handle,
577593 raft::device_matrix_view<const float , int > X,
578594 std::optional<raft::device_vector_view<const float , int >> sample_weight,
579595 raft::device_matrix_view<const float , int > centroids,
580- raft::device_vector_view<int64_t , int > labels,
581596 bool normalize_weight,
597+ raft::device_vector_view<int64_t , int > labels,
582598 raft::host_scalar_view<float > inertia);
583599
600+ // This overload is retained for backward compatibility.
601+ [[deprecated(
602+ " The argument order of kmeans::predict has been corrected. Please use the new function "
603+ " instead." )]]
604+ inline void predict (raft::resources const & handle,
605+ const kmeans::params& params,
606+ raft::device_matrix_view<const float , int > X,
607+ std::optional<raft::device_vector_view<const float , int >> sample_weight,
608+ raft::device_matrix_view<const float , int > centroids,
609+ raft::device_vector_view<int64_t , int > labels,
610+ bool normalize_weight,
611+ raft::host_scalar_view<float > inertia)
612+ {
613+ predict (handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
614+ }
615+
584616/* *
585617 * @brief Predict the closest cluster each sample in X belongs to.
586618 *
@@ -635,10 +667,26 @@ void predict(raft::resources const& handle,
635667 raft::device_matrix_view<const double , int > X,
636668 std::optional<raft::device_vector_view<const double , int >> sample_weight,
637669 raft::device_matrix_view<const double , int > centroids,
638- raft::device_vector_view<int , int > labels,
639670 bool normalize_weight,
671+ raft::device_vector_view<int , int > labels,
640672 raft::host_scalar_view<double > inertia);
641673
674+ // This overload is retained for backward compatibility.
675+ [[deprecated(
676+ " The argument order of kmeans::predict has been corrected. Please use the new function "
677+ " instead." )]]
678+ inline void predict (raft::resources const & handle,
679+ const kmeans::params& params,
680+ raft::device_matrix_view<const double , int > X,
681+ std::optional<raft::device_vector_view<const double , int >> sample_weight,
682+ raft::device_matrix_view<const double , int > centroids,
683+ raft::device_vector_view<int , int > labels,
684+ bool normalize_weight,
685+ raft::host_scalar_view<double > inertia)
686+ {
687+ predict (handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
688+ }
689+
642690/* *
643691 * @brief Predict the closest cluster each sample in X belongs to.
644692 *
@@ -693,10 +741,26 @@ void predict(raft::resources const& handle,
693741 raft::device_matrix_view<const double , int > X,
694742 std::optional<raft::device_vector_view<const double , int >> sample_weight,
695743 raft::device_matrix_view<const double , int > centroids,
696- raft::device_vector_view<int64_t , int > labels,
697744 bool normalize_weight,
745+ raft::device_vector_view<int64_t , int > labels,
698746 raft::host_scalar_view<double > inertia);
699747
748+ // This overload is retained for backward compatibility.
749+ [[deprecated(
750+ " The argument order of kmeans::predict has been corrected. Please use the new function "
751+ " instead." )]]
752+ inline void predict (raft::resources const & handle,
753+ const kmeans::params& params,
754+ raft::device_matrix_view<const double , int > X,
755+ std::optional<raft::device_vector_view<const double , int >> sample_weight,
756+ raft::device_matrix_view<const double , int > centroids,
757+ raft::device_vector_view<int64_t , int > labels,
758+ bool normalize_weight,
759+ raft::host_scalar_view<double > inertia)
760+ {
761+ predict (handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
762+ }
763+
700764/* *
701765 * @brief Predict the closest cluster each sample in X belongs to.
702766 *
0 commit comments