Skip to content
Merged
Show file tree
Hide file tree
Changes from 103 commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
07dbefe
initial
tarang-jain Jun 25, 2024
b5c1f2c
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 25, 2024
2387a15
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 28, 2024
6156cf5
update postprocess_distances
tarang-jain Jun 28, 2024
79de8a8
resolve merge conflicts
tarang-jain Aug 13, 2025
cff7494
update tests
tarang-jain Aug 13, 2025
a38b92b
update instantiations
tarang-jain Aug 13, 2025
0e1c980
re-update cagra-search
tarang-jain Aug 13, 2025
2f19510
corrections
tarang-jain Aug 13, 2025
93a9944
correct
tarang-jain Aug 13, 2025
c732eca
correct query normalization
tarang-jain Aug 13, 2025
ccfda68
correct template type
tarang-jain Aug 13, 2025
954e417
use ip dist_op
tarang-jain Aug 13, 2025
7dd97fe
cleanup
tarang-jain Aug 13, 2025
b753ceb
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 13, 2025
56a2a81
cleanup
tarang-jain Aug 13, 2025
c166424
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 13, 2025
6cddce9
style
tarang-jain Aug 13, 2025
7ca1936
only float and half
tarang-jain Aug 13, 2025
8586fc4
compute dataset norm
tarang-jain Aug 13, 2025
8e185dd
fix errors
tarang-jain Aug 13, 2025
a828f7d
compilation errors
tarang-jain Aug 13, 2025
c148ac6
fix compilation errors
tarang-jain Aug 13, 2025
9bc030b
compilation errors
tarang-jain Aug 13, 2025
f292a49
error instead of warning
tarang-jain Aug 13, 2025
667c2bb
fix error
tarang-jain Aug 13, 2025
b7fe9ec
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 14, 2025
4a20ade
Merge branch 'branch-25.10' into cagra-dist-metric
cjnolet Aug 14, 2025
290fc18
fix compilation;add cmake targets for spec
tarang-jain Aug 14, 2025
59b6333
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 14, 2025
91a3306
debug
tarang-jain Aug 15, 2025
2a9789e
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 18, 2025
9fe270f
everything seems to be working for cosine metric
tarang-jain Aug 20, 2025
002dbf7
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 20, 2025
40a392a
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 20, 2025
b7be4fe
move norm computation to helper
tarang-jain Aug 20, 2025
a24ba7c
separate out compute_dataset_norms;style
tarang-jain Aug 20, 2025
32f02bc
rm log statements
tarang-jain Aug 20, 2025
d85f526
cmake c flags,copyright
tarang-jain Aug 20, 2025
b56931f
rm extra files
tarang-jain Aug 20, 2025
40fcdab
cleanup docs'rm unused headers
tarang-jain Aug 20, 2025
ffb4d8d
assertion
tarang-jain Aug 20, 2025
2d7bd0d
cleanup tests
tarang-jain Aug 20, 2025
abb2e10
fix bad optional access
tarang-jain Aug 20, 2025
137c5c8
update python tests
tarang-jain Aug 20, 2025
684b465
update python tests
tarang-jain Aug 20, 2025
12d6d31
fix failing py tests
tarang-jain Aug 21, 2025
b871a80
style
tarang-jain Aug 21, 2025
926dc39
allow nnd and interative
tarang-jain Aug 21, 2025
5f52b84
compute_distance types for uint8 and int8
tarang-jain Aug 21, 2025
0ce20c7
clang format
tarang-jain Aug 21, 2025
b62c55c
add int8 and uint8 src files to CMakeLists.txt
tarang-jain Aug 21, 2025
d673179
update norm computation for iterative
tarang-jain Aug 22, 2025
f88a4ec
fix norm scaling
tarang-jain Aug 22, 2025
30d3882
style
tarang-jain Aug 22, 2025
40880ef
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 22, 2025
0b5c183
compose_op with scale
tarang-jain Aug 22, 2025
dd9bb33
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 22, 2025
7ee4eaf
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 22, 2025
9aae03b
correct test skips
tarang-jain Aug 23, 2025
a1f0689
compute scaled norms
tarang-jain Aug 25, 2025
215cbc5
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 25, 2025
c660038
debug
tarang-jain Aug 25, 2025
5731bd8
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 25, 2025
1084bc9
checkzero div
tarang-jain Aug 25, 2025
e6faf97
update tests; rm iterative
tarang-jain Aug 25, 2025
13431cb
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 25, 2025
6e6bfb2
update skip conditions
tarang-jain Aug 25, 2025
7650ef5
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 25, 2025
ffd8d44
shorten diff
tarang-jain Aug 25, 2025
134b257
rm debug prints;docs
tarang-jain Aug 26, 2025
9507cf6
rm double computation of norms
tarang-jain Aug 26, 2025
b01d54f
rm unused header
tarang-jain Aug 26, 2025
e548e09
rm set_dataset_norms;simplify compute_dataset_norms
tarang-jain Aug 26, 2025
7ec1263
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 26, 2025
03e2f79
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 2, 2025
eb94316
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 2, 2025
fcb0c8e
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 5, 2025
b3bef25
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 9, 2025
d6ddce9
Merge branch 'branch-25.10' into cagra-dist-metric
cjnolet Sep 15, 2025
38e2549
compute_dataset_norms private function
tarang-jain Sep 15, 2025
9424bf6
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Sep 15, 2025
538d792
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 15, 2025
ca0b7c2
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 17, 2025
2e6fe2f
update cagra python test
tarang-jain Sep 18, 2025
de4fbdc
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 18, 2025
2222009
Update cpp/src/neighbors/cagra.cuh
tarang-jain Sep 22, 2025
f180794
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 22, 2025
0b9c9e1
deallocate norms
tarang-jain Sep 23, 2025
2e60619
pull origin
tarang-jain Sep 23, 2025
fc7fdde
ivfpq cosine support for int types
tarang-jain Sep 23, 2025
66adf7b
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 23, 2025
2f51b66
rm gtest filter for ivfpq
tarang-jain Sep 23, 2025
405c21f
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Sep 23, 2025
9080462
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 25, 2025
5a7a694
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 25, 2025
887d82b
Update cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
tarang-jain Sep 25, 2025
d4d2cff
style
tarang-jain Sep 25, 2025
c0801ac
update cagra tests
tarang-jain Sep 26, 2025
accd841
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Sep 26, 2025
1c33e17
fix cpp warning
tarang-jain Sep 26, 2025
5da970a
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 27, 2025
544dd8d
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 29, 2025
eec8c04
update tests
tarang-jain Sep 29, 2025
b416b81
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Sep 29, 2025
c5089f7
update tests
tarang-jain Sep 30, 2025
d689129
fix syntax
tarang-jain Sep 30, 2025
a6c6592
fix compilation errors
tarang-jain Sep 30, 2025
99f1317
fix cosine docstring
tarang-jain Sep 30, 2025
c17f380
fix cosine docstring
tarang-jain Sep 30, 2025
f4c8d82
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 30, 2025
da6ea75
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Oct 1, 2025
883d368
Merge branch 'branch-25.12' into cagra-dist-metric
tarang-jain Oct 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim512_t32.cu
Expand All @@ -246,6 +249,15 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_int8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_uint8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_uint8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim512_t32.cu
Expand Down
67 changes: 65 additions & 2 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ struct index : cuvs::neighbors::index {
return graph_view_;
}

/** Dataset norms for cosine distance [size] */
[[nodiscard]] inline auto dataset_norms() const noexcept
-> std::optional<raft::device_vector_view<const float, int64_t>>
{
if (dataset_norms_.has_value()) { return raft::make_const_mdspan(dataset_norms_->view()); }
return std::nullopt;
}

// Don't allow copying the index for performance reasons (try avoiding copying data)
/** \cond */
index(const index&) = delete;
Expand All @@ -354,7 +362,8 @@ struct index : cuvs::neighbors::index {
: cuvs::neighbors::index(),
metric_(metric),
graph_(raft::make_device_matrix<IdxT, int64_t>(res, 0, 0)),
dataset_(new cuvs::neighbors::empty_dataset<int64_t>(0))
dataset_(new cuvs::neighbors::empty_dataset<int64_t>(0)),
dataset_norms_(std::nullopt)
{
}

Expand Down Expand Up @@ -420,12 +429,21 @@ struct index : cuvs::neighbors::index {
: cuvs::neighbors::index(),
metric_(metric),
graph_(raft::make_device_matrix<IdxT, int64_t>(res, 0, 0)),
dataset_(make_aligned_dataset(res, dataset, 16))
dataset_(make_aligned_dataset(res, dataset, 16)),
dataset_norms_(std::nullopt)
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_graph(res, knn_graph);

if (metric_ == cuvs::distance::DistanceType::CosineExpanded) {
auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
if (p) {
auto dataset_view = p->view();
if (dataset_view.extent(0) > 0) { compute_dataset_norms_(res); }
}
}

raft::resource::sync_stream(res);
}

Expand All @@ -435,48 +453,81 @@ struct index : cuvs::neighbors::index {
* If the new dataset rows are aligned on 16 bytes, then only a reference is stored to the
* dataset. It is the caller's responsibility to ensure that dataset stays alive as long as the
* index. It is expected that the same set of vectors are used for update_dataset and index build.
*
* Note: This will clear any precomputed dataset norms.
*/
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
dataset_norms_.reset();

if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
if (dataset.extent(0) > 0) { compute_dataset_norms_(res); }
}
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::layout_stride> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
dataset_norms_.reset();

if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
if (dataset.extent(0) > 0) { compute_dataset_norms_(res); }
}
}

/**
* Replace the dataset with a new dataset.
*
* We create a copy of the dataset on the device. The index manages the lifetime of this copy. It
* is expected that the same set of vectors are used for update_dataset and index build.
*
* Note: This will clear any precomputed dataset norms.
*/
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
dataset_norms_.reset();
if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
if (dataset.extent(0) > 0) { compute_dataset_norms_(res); }
}
}

/**
* Replace the dataset with a new dataset. It is expected that the same set of vectors are used
* for update_dataset and index build.
*
* Note: This will clear any precomputed dataset norms.
*/
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<dataset_index_type>, DatasetT>>
{
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
dataset_norms_.reset();
if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
if (p) {
auto dataset_view = p->view();
if (dataset_view.extent(0) > 0) { compute_dataset_norms_(res); }
}
}
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<dataset_index_type>, DatasetT>>
{
dataset_ = std::move(dataset);
dataset_norms_.reset();
if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
auto dataset_view = this->dataset();
if (dataset_view.extent(0) > 0) { compute_dataset_norms_(res); }
}
}

/**
Expand Down Expand Up @@ -519,6 +570,10 @@ struct index : cuvs::neighbors::index {
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
// only float distances supported at the moment
std::optional<raft::device_vector<float, int64_t>> dataset_norms_;
Comment thread
tfeher marked this conversation as resolved.

void compute_dataset_norms_(raft::resources const& res);
};
/**
* @}
Expand All @@ -539,6 +594,7 @@ struct index : cuvs::neighbors::index {
* The following distance metrics are supported:
* - L2
* - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -576,6 +632,7 @@ auto build(raft::resources const& res,
* The following distance metrics are supported:
* - L2
* - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -613,6 +670,7 @@ auto build(raft::resources const& res,
* The following distance metrics are supported:
* - L2
* - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -649,6 +707,7 @@ auto build(raft::resources const& res,
*
* The following distance metrics are supported:
* - L2
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -685,6 +744,7 @@ auto build(raft::resources const& res,
*
* The following distance metrics are supported:
* - L2
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -722,6 +782,7 @@ auto build(raft::resources const& res,
* The following distance metrics are supported:
* - L2
* - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -759,6 +820,7 @@ auto build(raft::resources const& res,
* The following distance metrics are supported:
* - L2
* - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
*
* Usage example:
* @code{.cpp}
Expand Down Expand Up @@ -796,6 +858,7 @@ auto build(raft::resources const& res,
* The following distance metrics are supported:
* - L2
* - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
* - CosineExpanded (dataset norms are computed as float regardless of input data type)
*
* Usage example:
* @code{.cpp}
Expand Down
35 changes: 34 additions & 1 deletion cpp/src/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,10 +22,12 @@
#include "detail/cagra/cagra_search.cuh"
#include "detail/cagra/graph_core.cuh"

#include "detail/ann_utils.cuh"
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/norm.cuh>

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
Expand All @@ -36,6 +38,37 @@

namespace cuvs::neighbors::cagra {

// Member function implementations for cagra::index
template <typename T, typename IdxT>
void index<T, IdxT>::compute_dataset_norms_(raft::resources const& res)
{
// Get the dataset view
auto dataset_view = this->dataset();

// Allocate norms vector if not already allocated
if (!dataset_norms_.has_value() || dataset_norms_->extent(0) != dataset_view.extent(0)) {
dataset_norms_.reset();
dataset_norms_ = raft::make_device_vector<float, int64_t>(res, dataset_view.extent(0));
Comment thread
tarang-jain marked this conversation as resolved.
}

constexpr float kScale = cuvs::spatial::knn::detail::utils::config<T>::kDivisor /
cuvs::spatial::knn::detail::utils::config<float>::kDivisor;

// first scale the dataset and then compute norms
auto scaled_sq_op = raft::compose_op(
raft::sq_op{}, raft::div_const_op<float>{float(kScale)}, raft::cast_op<float>());
raft::linalg::reduce<true, true, T, float, int64_t>(dataset_norms_->data_handle(),
dataset_view.data_handle(),
dataset_view.stride(0),
dataset_view.extent(0),
(float)0,
raft::resource::get_cuda_stream(res),
false,
scaled_sq_op,
raft::add_op(),
raft::sqrt_op{});
}

/**
* @defgroup cagra CUDA ANN Graph-based nearest neighbor search
* @{
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ void build_knn_graph(
cuvs::neighbors::cagra::graph_build_params::ivf_pq_params pq)
{
RAFT_EXPECTS(pq.build_params.metric == cuvs::distance::DistanceType::L2Expanded ||
pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");
pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct ||
pq.build_params.metric == cuvs::distance::DistanceType::CosineExpanded,
"Currently only L2Expanded, InnerProduct and CosineExpanded metrics are supported");

uint32_t node_degree = knn_graph.extent(1);
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
Expand Down Expand Up @@ -710,6 +711,11 @@ index<T, IdxT> build(
std::holds_alternative<cagra::graph_build_params::nn_descent_params>(knn_build_params),
"IVF_PQ for CAGRA graph build does not support BitwiseHamming as a metric. Please "
"use nn-descent or the iterative CAGRA search build.");
RAFT_EXPECTS(
params.metric != cuvs::distance::DistanceType::CosineExpanded ||
std::holds_alternative<cagra::graph_build_params::ivf_pq_params>(knn_build_params) ||
std::holds_alternative<cagra::graph_build_params::nn_descent_params>(knn_build_params),
"CosineExpanded distance is not supported for iterative CAGRA graph build.");

// Validate data type for BitwiseHamming metric
RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::BitwiseHamming ||
Expand Down
Loading
Loading