-
Notifications
You must be signed in to change notification settings - Fork 623
Prepare cuml for removal of deprecated raft apis #7561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
b2d1c07
7c9199d
df7d906
8883369
1d94692
cc11ecc
f13d176
0b1204e
5899a33
72776f0
0014575
ef1f3b7
bc4e91e
988ad8d
fc3da25
28aef64
320cd83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,9 +8,13 @@ | |
|
|
||
| #include <raft/core/device_resources.hpp> | ||
| #include <raft/core/handle.hpp> | ||
| #include <raft/core/operators.hpp> | ||
| #include <raft/label/classlabels.cuh> | ||
| #include <raft/spatial/knn/ann.cuh> | ||
| #include <raft/spatial/knn/knn.cuh> | ||
| #include <raft/linalg/matrix_vector_op.cuh> | ||
| #include <raft/linalg/norm.cuh> | ||
| #include <raft/linalg/reduce.cuh> | ||
| #include <raft/linalg/unary_op.cuh> | ||
| #include <raft/stats/mean_center.cuh> | ||
| #include <raft/util/cuda_utils.cuh> | ||
|
|
||
| #include <rmm/device_uvector.hpp> | ||
|
|
@@ -34,6 +38,9 @@ namespace ML { | |
| struct knnIndexImpl { | ||
| std::unique_ptr<cuvs::neighbors::ivf_flat::index<float, int64_t>> ivf_flat; | ||
| std::unique_ptr<cuvs::neighbors::ivf_pq::index<int64_t>> ivf_pq; | ||
|
|
||
| std::unique_ptr<rmm::device_uvector<float>> corr_norms; | ||
| std::unique_ptr<rmm::device_uvector<float>> corr_means; | ||
| }; | ||
|
|
||
| knnIndex::knnIndex() : pimpl{std::make_unique<knnIndexImpl>()} {} | ||
|
|
@@ -215,21 +222,34 @@ void approx_knn_build_index(raft::handle_t& handle, | |
|
|
||
| auto ivf_ft_pams = dynamic_cast<IVFFlatParam*>(params); | ||
| auto ivf_pq_pams = dynamic_cast<IVFPQParam*>(params); | ||
|
|
||
| index->metric_processor = raft::spatial::knn::create_processor<false, float>( | ||
| static_cast<raft::distance::DistanceType>(metric), | ||
| n, | ||
| D, | ||
| 0, | ||
| raft::resource::get_cuda_stream(handle)); | ||
| // For cosine/correlation distance, the metric processor translates distance | ||
| // to inner product via pre/post processing - pass the translated metric to | ||
| // ANN index | ||
| if (metric == ML::distance::DistanceType::CosineExpanded || | ||
| metric == ML::distance::DistanceType::CorrelationExpanded) { | ||
| metric = index->metric = ML::distance::DistanceType::InnerProduct; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you know why we used to set
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But, if I understood correctly
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes,
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's what I understood |
||
| auto stream = raft::resource::get_cuda_stream(handle); | ||
|
|
||
| // For correlation: preprocess (center + normalize), use InnerProduct, then revert | ||
| if (metric == ML::distance::DistanceType::CorrelationExpanded) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you mean
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I meant
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes L2 is handled in cuvs, but Lp is not. I've updated the code to only do processing on Lp. |
||
| index->pimpl->corr_means = std::make_unique<rmm::device_uvector<float>>(n, stream); | ||
| index->pimpl->corr_norms = std::make_unique<rmm::device_uvector<float>>(n, stream); | ||
|
|
||
| // Compute means and center data | ||
| float normalizer = 1.0f / static_cast<float>(D); | ||
| raft::linalg::reduce<false, true>( | ||
| index->pimpl->corr_means->data(), index_array, D, n, 0.0f, stream); | ||
| raft::linalg::unaryOp(index->pimpl->corr_means->data(), | ||
| index->pimpl->corr_means->data(), | ||
| n, | ||
| raft::mul_const_op<float>(normalizer), | ||
| stream); | ||
| raft::stats::meanCenter<false, false>( | ||
| index_array, index_array, index->pimpl->corr_means->data(), D, n, stream); | ||
|
|
||
| // Compute norms and normalize | ||
| raft::linalg::rowNorm<raft::linalg::L2Norm, false>( | ||
| index->pimpl->corr_norms->data(), index_array, D, n, stream, raft::sqrt_op{}); | ||
| raft::linalg::matrixVectorOp<false, false>( | ||
| index_array, index_array, index->pimpl->corr_norms->data(), D, n, raft::div_op{}, stream); | ||
|
|
||
| metric = ML::distance::DistanceType::InnerProduct; | ||
| } | ||
| index->metric_processor->preprocess(index_array); | ||
|
|
||
| auto index_view = raft::make_device_matrix_view<const float, int64_t>(index_array, n, D); | ||
|
|
||
| if (ivf_ft_pams) { | ||
|
|
@@ -257,7 +277,13 @@ void approx_knn_build_index(raft::handle_t& handle, | |
| RAFT_FAIL("Unrecognized index type."); | ||
| } | ||
|
|
||
| index->metric_processor->revert(index_array); | ||
| // Revert user data for correlation | ||
| if (index->metric == ML::distance::DistanceType::CorrelationExpanded) { | ||
| raft::linalg::matrixVectorOp<false, false>( | ||
| index_array, index_array, index->pimpl->corr_norms->data(), D, n, raft::mul_op{}, stream); | ||
| raft::stats::meanAdd<false, false>( | ||
| index_array, index_array, index->pimpl->corr_means->data(), D, n, stream); | ||
| } | ||
| } | ||
|
|
||
| void approx_knn_search(raft::handle_t& handle, | ||
|
|
@@ -268,8 +294,32 @@ void approx_knn_search(raft::handle_t& handle, | |
| float* query_array, | ||
| int n) | ||
| { | ||
| index->metric_processor->preprocess(query_array); | ||
| index->metric_processor->set_num_queries(k); | ||
| auto stream = raft::resource::get_cuda_stream(handle); | ||
|
|
||
| // Get dimension from index | ||
| int D = index->pimpl->ivf_flat ? index->pimpl->ivf_flat->dim() : index->pimpl->ivf_pq->dim(); | ||
|
|
||
| // Temporary storage for correlation query preprocessing | ||
| std::unique_ptr<rmm::device_uvector<float>> query_means; | ||
| std::unique_ptr<rmm::device_uvector<float>> query_norms; | ||
|
|
||
| // Preprocess queries for correlation | ||
| if (index->metric == ML::distance::DistanceType::CorrelationExpanded) { | ||
| query_means = std::make_unique<rmm::device_uvector<float>>(n, stream); | ||
| query_norms = std::make_unique<rmm::device_uvector<float>>(n, stream); | ||
|
|
||
| float normalizer = 1.0f / static_cast<float>(D); | ||
| raft::linalg::reduce<false, true>(query_means->data(), query_array, D, n, 0.0f, stream); | ||
| raft::linalg::unaryOp( | ||
| query_means->data(), query_means->data(), n, raft::mul_const_op<float>(normalizer), stream); | ||
| raft::stats::meanCenter<false, false>( | ||
| query_array, query_array, query_means->data(), D, n, stream); | ||
|
|
||
| raft::linalg::rowNorm<raft::linalg::L2Norm, false>( | ||
| query_norms->data(), query_array, D, n, stream, raft::sqrt_op{}); | ||
| raft::linalg::matrixVectorOp<false, false>( | ||
| query_array, query_array, query_norms->data(), D, n, raft::div_op{}, stream); | ||
| } | ||
|
|
||
| auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k); | ||
| auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k); | ||
|
|
@@ -294,7 +344,12 @@ void approx_knn_search(raft::handle_t& handle, | |
| RAFT_FAIL("The model is not trained"); | ||
| } | ||
|
|
||
| index->metric_processor->revert(query_array); | ||
| // Revert query data for correlation | ||
| if (index->metric == ML::distance::DistanceType::CorrelationExpanded) { | ||
| raft::linalg::matrixVectorOp<false, false>( | ||
| query_array, query_array, query_norms->data(), D, n, raft::mul_op{}, stream); | ||
| raft::stats::meanAdd<false, false>(query_array, query_array, query_means->data(), D, n, stream); | ||
| } | ||
|
|
||
| // perform post-processing to show the real distances | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about that post-processing, is it also needed ? |
||
| if (index->metric == ML::distance::DistanceType::L2SqrtExpanded || | ||
|
|
@@ -311,7 +366,12 @@ void approx_knn_search(raft::handle_t& handle, | |
| raft::pow_const_op<float>(p), | ||
| raft::resource::get_cuda_stream(handle)); | ||
| } | ||
| index->metric_processor->postprocess(distances); | ||
|
|
||
| // Post-process correlation: convert inner product to correlation distance | ||
| if (index->metric == ML::distance::DistanceType::CorrelationExpanded) { | ||
| raft::linalg::unaryOp( | ||
| distances, distances, n * k, [] __device__(float in) { return 1.0f - in; }, stream); | ||
| } | ||
| } | ||
|
|
||
| void knn_classify(raft::handle_t& handle, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert later