Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ endfunction()
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG ${rapids-cmake-checkout-tag}
FORK aamijar
PINNED_TAG raft-deprecated-apis
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert later

EXCLUDE_FROM_ALL ${CUML_EXCLUDE_RAFT_FROM_ALL}
# When PINNED_TAG above doesn't match cuml,
# force local raft clone in build directory
Expand Down
3 changes: 0 additions & 3 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

#include <cuml/common/distance_type.hpp>

#include <raft/spatial/knn/detail/processing.hpp> // MetricProcessor

#include <cstdint>
#include <memory>
#include <vector>
Expand Down Expand Up @@ -89,7 +87,6 @@ struct knnIndex {
ML::distance::DistanceType metric;
float metricArg;
int nprobe;
std::unique_ptr<raft::spatial::knn::MetricProcessor<float>> metric_processor;
int device;

std::unique_ptr<knnIndexImpl> pimpl;
Expand Down
102 changes: 81 additions & 21 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

#include <raft/core/device_resources.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/operators.hpp>
#include <raft/label/classlabels.cuh>
#include <raft/spatial/knn/ann.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/stats/mean_center.cuh>
#include <raft/util/cuda_utils.cuh>

#include <rmm/device_uvector.hpp>
Expand All @@ -34,6 +38,9 @@ namespace ML {
struct knnIndexImpl {
std::unique_ptr<cuvs::neighbors::ivf_flat::index<float, int64_t>> ivf_flat;
std::unique_ptr<cuvs::neighbors::ivf_pq::index<int64_t>> ivf_pq;

std::unique_ptr<rmm::device_uvector<float>> corr_norms;
std::unique_ptr<rmm::device_uvector<float>> corr_means;
};

knnIndex::knnIndex() : pimpl{std::make_unique<knnIndexImpl>()} {}
Expand Down Expand Up @@ -215,21 +222,34 @@ void approx_knn_build_index(raft::handle_t& handle,

auto ivf_ft_pams = dynamic_cast<IVFFlatParam*>(params);
auto ivf_pq_pams = dynamic_cast<IVFPQParam*>(params);

index->metric_processor = raft::spatial::knn::create_processor<false, float>(
static_cast<raft::distance::DistanceType>(metric),
n,
D,
0,
raft::resource::get_cuda_stream(handle));
// For cosine/correlation distance, the metric processor translates distance
// to inner product via pre/post processing - pass the translated metric to
// ANN index
if (metric == ML::distance::DistanceType::CosineExpanded ||
metric == ML::distance::DistanceType::CorrelationExpanded) {
metric = index->metric = ML::distance::DistanceType::InnerProduct;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why we used to set index->metric?

Copy link
Copy Markdown
Member Author

@aamijar aamijar Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index->metric is still used as usual at the beginning of the function where it is set to
index->metric = metric then forwarded to cuvs. The reason we set it to inner product in the if statement before is because of the special processing that was required since cuvs ivf-flat and ivf-pq didn't support cosine or correlation. So to do the equivalent computation we used inner product + pre/post processing.

Copy link
Copy Markdown
Contributor

@viclafargue viclafargue Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, if I understood correctly CorrelationExpanded is not supported. We implement pre/post processing here, but should pass InnerProduct metric to cuVS, right? It looks like in this case metric = InnerProduct, but index->metric = CorrelationExpanded. Why don't we do metric = index->metric = InnerProduct?

Copy link
Copy Markdown
Member Author

@aamijar aamijar Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metric variable is what is actually passed to cuvs. index->metric is just recording locally what the original metric was.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, metric is what is sent to cuVS anyway so it should work, but index->metric is kept to CorrelationExpanded unlike before, but I guess this is just an implementation detail and is made on purpose.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's what I understood

auto stream = raft::resource::get_cuda_stream(handle);

// For correlation: preprocess (center + normalize), use InnerProduct, then revert
if (metric == ML::distance::DistanceType::CorrelationExpanded) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CorrelationExpanded case seems to be handled correctly. Most metrics would use the DefaultMetricProcessor that does not do anything. But, unless I am missing something it looks like the CorrelationExpanded is not handled atm.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean CosineExpanded is not handled? CosineExpanded does have support in cuvs ivf-flat and ivf-pq now so we don't need to do special processing and override the metric to be inner product.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

@viclafargue viclafargue Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I meant CosineExpanded. Makes sense. What about Lp/L2 metrics postprocessing, is it not handled in cuVS? It looks like we are sending the metricArg to cuVS.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes L2 is handled in cuvs, but Lp is not. I've updated the code to only do processing on Lp.

index->pimpl->corr_means = std::make_unique<rmm::device_uvector<float>>(n, stream);
index->pimpl->corr_norms = std::make_unique<rmm::device_uvector<float>>(n, stream);

// Compute means and center data
float normalizer = 1.0f / static_cast<float>(D);
raft::linalg::reduce<false, true>(
index->pimpl->corr_means->data(), index_array, D, n, 0.0f, stream);
raft::linalg::unaryOp(index->pimpl->corr_means->data(),
index->pimpl->corr_means->data(),
n,
raft::mul_const_op<float>(normalizer),
stream);
raft::stats::meanCenter<false, false>(
index_array, index_array, index->pimpl->corr_means->data(), D, n, stream);

// Compute norms and normalize
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(
index->pimpl->corr_norms->data(), index_array, D, n, stream, raft::sqrt_op{});
raft::linalg::matrixVectorOp<false, false>(
index_array, index_array, index->pimpl->corr_norms->data(), D, n, raft::div_op{}, stream);

metric = ML::distance::DistanceType::InnerProduct;
}
index->metric_processor->preprocess(index_array);

auto index_view = raft::make_device_matrix_view<const float, int64_t>(index_array, n, D);

if (ivf_ft_pams) {
Expand Down Expand Up @@ -257,7 +277,13 @@ void approx_knn_build_index(raft::handle_t& handle,
RAFT_FAIL("Unrecognized index type.");
}

index->metric_processor->revert(index_array);
// Revert user data for correlation
if (index->metric == ML::distance::DistanceType::CorrelationExpanded) {
raft::linalg::matrixVectorOp<false, false>(
index_array, index_array, index->pimpl->corr_norms->data(), D, n, raft::mul_op{}, stream);
raft::stats::meanAdd<false, false>(
index_array, index_array, index->pimpl->corr_means->data(), D, n, stream);
}
}

void approx_knn_search(raft::handle_t& handle,
Expand All @@ -268,8 +294,32 @@ void approx_knn_search(raft::handle_t& handle,
float* query_array,
int n)
{
index->metric_processor->preprocess(query_array);
index->metric_processor->set_num_queries(k);
auto stream = raft::resource::get_cuda_stream(handle);

// Get dimension from index
int D = index->pimpl->ivf_flat ? index->pimpl->ivf_flat->dim() : index->pimpl->ivf_pq->dim();

// Temporary storage for correlation query preprocessing
std::unique_ptr<rmm::device_uvector<float>> query_means;
std::unique_ptr<rmm::device_uvector<float>> query_norms;

// Preprocess queries for correlation
if (index->metric == ML::distance::DistanceType::CorrelationExpanded) {
query_means = std::make_unique<rmm::device_uvector<float>>(n, stream);
query_norms = std::make_unique<rmm::device_uvector<float>>(n, stream);

float normalizer = 1.0f / static_cast<float>(D);
raft::linalg::reduce<false, true>(query_means->data(), query_array, D, n, 0.0f, stream);
raft::linalg::unaryOp(
query_means->data(), query_means->data(), n, raft::mul_const_op<float>(normalizer), stream);
raft::stats::meanCenter<false, false>(
query_array, query_array, query_means->data(), D, n, stream);

raft::linalg::rowNorm<raft::linalg::L2Norm, false>(
query_norms->data(), query_array, D, n, stream, raft::sqrt_op{});
raft::linalg::matrixVectorOp<false, false>(
query_array, query_array, query_norms->data(), D, n, raft::div_op{}, stream);
}

auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k);
Expand All @@ -294,7 +344,12 @@ void approx_knn_search(raft::handle_t& handle,
RAFT_FAIL("The model is not trained");
}

index->metric_processor->revert(query_array);
// Revert query data for correlation
if (index->metric == ML::distance::DistanceType::CorrelationExpanded) {
raft::linalg::matrixVectorOp<false, false>(
query_array, query_array, query_norms->data(), D, n, raft::mul_op{}, stream);
raft::stats::meanAdd<false, false>(query_array, query_array, query_means->data(), D, n, stream);
}

// perform post-processing to show the real distances
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about that post-processing, is it also needed ?

if (index->metric == ML::distance::DistanceType::L2SqrtExpanded ||
Expand All @@ -311,7 +366,12 @@ void approx_knn_search(raft::handle_t& handle,
raft::pow_const_op<float>(p),
raft::resource::get_cuda_stream(handle));
}
index->metric_processor->postprocess(distances);

// Post-process correlation: convert inner product to correlation distance
if (index->metric == ML::distance::DistanceType::CorrelationExpanded) {
raft::linalg::unaryOp(
distances, distances, n * k, [] __device__(float in) { return 1.0f - in; }, stream);
}
}

void knn_classify(raft::handle_t& handle,
Expand Down
1 change: 0 additions & 1 deletion cpp/src/knn/knn_opg_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <cumlprims/opg/matrix/part_descriptor.hpp>
#include <raft/core/comms.hpp>
#include <raft/core/handle.hpp>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down
1 change: 0 additions & 1 deletion cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <cuml/metrics/metrics.hpp>

#include <raft/core/handle.hpp>
#include <raft/distance/distance.cuh>

#include <cuvs/distance/distance.hpp>

Expand Down
1 change: 0 additions & 1 deletion cpp/src/svm/linear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include <raft/core/handle.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/label/classlabels.cuh>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/gemv.cuh>
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/svm/smosolver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
#include "ws_util.cuh"

#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/gemv.cuh>
#include <raft/linalg/unary_op.cuh>
Expand Down
1 change: 0 additions & 1 deletion cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <cuml/neighbors/knn.hpp>

#include <raft/core/handle.hpp>
#include <raft/distance/distance.cuh>
#include <raft/label/classlabels.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down
3 changes: 1 addition & 2 deletions cpp/tests/prims/knn_classify.cu
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "test_utils.h"

#include <raft/label/classlabels.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down
1 change: 0 additions & 1 deletion cpp/tests/prims/knn_regression.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <raft/label/classlabels.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/random/rng.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down
1 change: 0 additions & 1 deletion cpp/tests/sg/dbscan_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <cuml/metrics/metrics.hpp>

#include <raft/core/handle.hpp>
#include <raft/distance/distance.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down
18 changes: 9 additions & 9 deletions cpp/tests/sg/hdbscan_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <cuml/cluster/hdbscan.hpp>
#include <cuml/common/distance_type.hpp>

#include <raft/cluster/detail/agglomerative.cuh> // build_dendrogram_host
#include <raft/core/handle.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/sparse/coo.hpp>
Expand Down Expand Up @@ -165,14 +164,15 @@ class ClusterCondensingTest : public ::testing::TestWithParam<ClusterCondensingI
/**
* Build dendrogram of MST
*/
raft::cluster::detail::build_dendrogram_host(handle,
mst_src.data(),
mst_dst.data(),
mst_data.data(),
params.n_row - 1,
out_children.data(),
out_delta.data(),
out_size.data());
cuvs::cluster::agglomerative::helpers::build_dendrogram(
handle,
raft::make_device_vector_view<const IdxT, IdxT>(mst_src.data(), params.n_row - 1),
raft::make_device_vector_view<const IdxT, IdxT>(mst_dst.data(), params.n_row - 1),
raft::make_device_vector_view<const T, IdxT>(mst_data.data(), params.n_row - 1),
raft::make_device_matrix_view<IdxT, IdxT, raft::row_major>(
out_children.data(), params.n_row - 1, 2),
raft::make_device_vector_view<T, IdxT>(out_delta.data(), params.n_row - 1),
raft::make_device_vector_view<IdxT, IdxT>(out_size.data(), params.n_row - 1));

/**
* Condense Hierarchy
Expand Down
2 changes: 0 additions & 2 deletions cpp/tests/sg/umap_parametrizable_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

#include <raft/core/handle.hpp>
#include <raft/core/host_coo_matrix.hpp>
#include <raft/distance/distance.cuh>
#include <raft/linalg/reduce_rows_by_key.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down