Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
07dbefe
initial
tarang-jain Jun 25, 2024
b5c1f2c
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 25, 2024
2387a15
Merge branch 'branch-24.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jun 28, 2024
6156cf5
update postprocess_distances
tarang-jain Jun 28, 2024
79de8a8
resolve merge conflicts
tarang-jain Aug 13, 2025
cff7494
update tests
tarang-jain Aug 13, 2025
a38b92b
update instantiations
tarang-jain Aug 13, 2025
0e1c980
re-update cagra-search
tarang-jain Aug 13, 2025
2f19510
corrections
tarang-jain Aug 13, 2025
93a9944
correct
tarang-jain Aug 13, 2025
c732eca
correct query normalization
tarang-jain Aug 13, 2025
ccfda68
correct template type
tarang-jain Aug 13, 2025
954e417
use ip dist_op
tarang-jain Aug 13, 2025
7dd97fe
cleanup
tarang-jain Aug 13, 2025
b753ceb
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 13, 2025
56a2a81
cleanup
tarang-jain Aug 13, 2025
c166424
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 13, 2025
6cddce9
style
tarang-jain Aug 13, 2025
7ca1936
only float and half
tarang-jain Aug 13, 2025
8586fc4
compute dataset norm
tarang-jain Aug 13, 2025
8e185dd
fix errors
tarang-jain Aug 13, 2025
a828f7d
compilation errors
tarang-jain Aug 13, 2025
c148ac6
fix compilation errors
tarang-jain Aug 13, 2025
9bc030b
compilation errors
tarang-jain Aug 13, 2025
f292a49
error instead of warning
tarang-jain Aug 13, 2025
667c2bb
fix error
tarang-jain Aug 13, 2025
b7fe9ec
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 14, 2025
4a20ade
Merge branch 'branch-25.10' into cagra-dist-metric
cjnolet Aug 14, 2025
290fc18
fix compilation;add cmake targets for spec
tarang-jain Aug 14, 2025
59b6333
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 14, 2025
91a3306
debug
tarang-jain Aug 15, 2025
2a9789e
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 18, 2025
9fe270f
everything seems to be working for cosine metric
tarang-jain Aug 20, 2025
002dbf7
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 20, 2025
40a392a
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 20, 2025
b7be4fe
move norm computation to helper
tarang-jain Aug 20, 2025
a24ba7c
separate out compute_dataset_norms;style
tarang-jain Aug 20, 2025
32f02bc
rm log statements
tarang-jain Aug 20, 2025
d85f526
cmake c flags,copyright
tarang-jain Aug 20, 2025
b56931f
rm extra files
tarang-jain Aug 20, 2025
40fcdab
cleanup docs'rm unused headers
tarang-jain Aug 20, 2025
ffb4d8d
assertion
tarang-jain Aug 20, 2025
2d7bd0d
cleanup tests
tarang-jain Aug 20, 2025
abb2e10
fix bad optional access
tarang-jain Aug 20, 2025
137c5c8
update python tests
tarang-jain Aug 20, 2025
684b465
update python tests
tarang-jain Aug 20, 2025
12d6d31
fix failing py tests
tarang-jain Aug 21, 2025
b871a80
style
tarang-jain Aug 21, 2025
926dc39
allow nnd and interative
tarang-jain Aug 21, 2025
5f52b84
compute_distance types for uint8 and int8
tarang-jain Aug 21, 2025
0ce20c7
clang format
tarang-jain Aug 21, 2025
b62c55c
add int8 and uint8 src files to CMakeLists.txt
tarang-jain Aug 21, 2025
d673179
update norm computation for iterative
tarang-jain Aug 22, 2025
f88a4ec
fix norm scaling
tarang-jain Aug 22, 2025
30d3882
style
tarang-jain Aug 22, 2025
40880ef
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 22, 2025
0b5c183
compose_op with scale
tarang-jain Aug 22, 2025
dd9bb33
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 22, 2025
7ee4eaf
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 22, 2025
9aae03b
correct test skips
tarang-jain Aug 23, 2025
a1f0689
compute scaled norms
tarang-jain Aug 25, 2025
215cbc5
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 25, 2025
c660038
debug
tarang-jain Aug 25, 2025
5731bd8
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 25, 2025
1084bc9
checkzero div
tarang-jain Aug 25, 2025
e6faf97
update tests; rm iterative
tarang-jain Aug 25, 2025
13431cb
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Aug 25, 2025
6e6bfb2
update skip conditions
tarang-jain Aug 25, 2025
7650ef5
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Aug 25, 2025
ffd8d44
shorten diff
tarang-jain Aug 25, 2025
134b257
rm debug prints;docs
tarang-jain Aug 26, 2025
9507cf6
rm double computation of norms
tarang-jain Aug 26, 2025
b01d54f
rm unused header
tarang-jain Aug 26, 2025
e548e09
rm set_dataset_norms;simplify compute_dataset_norms
tarang-jain Aug 26, 2025
7ec1263
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 26, 2025
03e2f79
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 2, 2025
eb94316
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 2, 2025
fcb0c8e
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 5, 2025
b3bef25
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 9, 2025
d6ddce9
Merge branch 'branch-25.10' into cagra-dist-metric
cjnolet Sep 15, 2025
38e2549
compute_dataset_norms private function
tarang-jain Sep 15, 2025
9424bf6
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Sep 15, 2025
538d792
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 15, 2025
ca0b7c2
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 17, 2025
2e6fe2f
update cagra python test
tarang-jain Sep 18, 2025
de4fbdc
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 18, 2025
2222009
Update cpp/src/neighbors/cagra.cuh
tarang-jain Sep 22, 2025
f180794
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 22, 2025
0b9c9e1
deallocate norms
tarang-jain Sep 23, 2025
2e60619
pull origin
tarang-jain Sep 23, 2025
fc7fdde
ivfpq cosine support for int types
tarang-jain Sep 23, 2025
66adf7b
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 23, 2025
2f51b66
rm gtest filter for ivfpq
tarang-jain Sep 23, 2025
405c21f
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Sep 23, 2025
9080462
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 25, 2025
5a7a694
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 25, 2025
887d82b
Update cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
tarang-jain Sep 25, 2025
d4d2cff
style
tarang-jain Sep 25, 2025
c0801ac
update cagra tests
tarang-jain Sep 26, 2025
accd841
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Sep 26, 2025
1c33e17
fix cpp warning
tarang-jain Sep 26, 2025
5da970a
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 27, 2025
544dd8d
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 29, 2025
eec8c04
update tests
tarang-jain Sep 29, 2025
b416b81
Merge branch 'cagra-dist-metric' of https://github.com/tarang-jain/cu…
tarang-jain Sep 29, 2025
c5089f7
update tests
tarang-jain Sep 30, 2025
d689129
fix syntax
tarang-jain Sep 30, 2025
a6c6592
fix compilation errors
tarang-jain Sep 30, 2025
99f1317
fix cosine docstring
tarang-jain Sep 30, 2025
c17f380
fix cosine docstring
tarang-jain Sep 30, 2025
f4c8d82
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Sep 30, 2025
da6ea75
Merge branch 'branch-25.10' into cagra-dist-metric
tarang-jain Oct 1, 2025
883d368
Merge branch 'branch-25.12' into cagra-dist-metric
tarang-jain Oct 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ void build_knn_graph(
cuvs::neighbors::cagra::graph_build_params::ivf_pq_params pq)
{
RAFT_EXPECTS(pq.build_params.metric == cuvs::distance::DistanceType::L2Expanded ||
pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");
pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct ||
pq.build_params.metric == cuvs::distance::DistanceType::CosineExpanded,
"Currently only L2Expanded, InnerProduct and CosineExpanded metrics are supported");

uint32_t node_degree = knn_graph.extent(1);
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope(
Expand Down
84 changes: 78 additions & 6 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @tarang-jain, a heads up here: #296 does a major refactoring of related code; let's have a look together how we can proceed with this PR once you're back to it, ok?
I have similar performance concerns as the ones we discussed on IVF-PQ; maybe it makes sense to keep the dataset normalized for cosine distance (and reuse the inner-product code path)?
Then we can either normalize the query at the time we copy it to the shared memory (pre-processing) or divide by the query norm at the post-processing/filtering step at the end of the kernel.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored this PR to do the divide by query norm at the very end (postprocessing stage)-

Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
TEAM_SIZE,
cuvs::distance::DistanceType::InnerProduct>(
query_buffer, seed_index, valid_i);
case cuvs::distance::DistanceType::CosineExpanded:
norm2 =
dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
cuvs::distance::DistanceType::CosineExpanded>(
query_buffer, seed_index, valid_i);
break;
default: break;
}
Expand Down Expand Up @@ -191,6 +197,13 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
cuvs::distance::DistanceType::InnerProduct>(
query_buffer, child_id, child_id != invalid_index);
break;
case cuvs::distance::DistanceType::CosineExpanded:
norm2 =
dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
cuvs::distance::DistanceType::CosineExpanded>(
query_buffer, child_id, child_id != invalid_index);
break;
default: break;
}

Expand Down Expand Up @@ -275,9 +288,12 @@ struct standard_dataset_descriptor_t
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, cuvs::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
std::enable_if_t<METRIC == cuvs::distance::DistanceType::L2Expanded ||
METRIC == cuvs::distance::DistanceType::InnerProduct,
DISTANCE_T>
__device__ compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
{
const auto dataset_ptr = ptr + dataset_i * ld;
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand All @@ -286,7 +302,8 @@ struct standard_dataset_descriptor_t
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T norm2 = 0;
DISTANCE_T dist = 0;

if (valid) {
for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
#pragma unroll
Expand All @@ -307,16 +324,71 @@ struct standard_dataset_descriptor_t
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T d = query_ptr[device::swizzling(kv)];
norm2 += dist_op<DISTANCE_T, METRIC>(
dist += dist_op<DISTANCE_T, METRIC>(
d, cuvs::spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]));
}
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
dist += __shfl_xor_sync(0xffffffff, dist, offset);
}

return dist;
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, cuvs::distance::DistanceType METRIC>
std::enable_if_t<METRIC == cuvs::distance::DistanceType::CosineExpanded, DISTANCE_T> __device__
compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
{
const auto dataset_ptr = ptr + dataset_i * ld;
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = device::get_vlen<LOAD_T, DATA_T>();
// #include <raft/util/cuda_dev_essentials.cuh
constexpr unsigned reg_nelem = raft::ceildiv<unsigned>(DATASET_BLOCK_DIM, TEAM_SIZE * vlen);
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T dist = 0;
DISTANCE_T norm1 = 0;
DISTANCE_T norm2 = 0;
if (valid) {
for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dim) break;
dl_buff[e].load(dataset_ptr, k);
}
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dim) break;
#pragma unroll
for (uint32_t v = 0; v < vlen; v++) {
const uint32_t kv = k + v;
// Note this loop can go above the dataset_dim for padded arrays. This is not a problem
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T q = query_ptr[device::swizzling(kv)];
DISTANCE_T d =
cuvs::spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
dist += dist_op<DISTANCE_T, cuvs::distance::DistanceType::InnerProduct>(q, d);
norm1 += q * q;
norm2 += d * d;
}
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
dist += __shfl_xor_sync(0xffffffff, dist, offset);
norm1 += __shfl_xor_sync(0xffffffff, norm1, offset);
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
}
return norm2;

return dist / (norm1 * norm2);
}
};

Expand Down
13 changes: 13 additions & 0 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,19 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
raft::linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
case distance::DistanceType::CosineExpanded: {
float factor = (account_for_max_close ? -1.0 : 1.0);
if (factor != 1.0) {
raft::linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{factor}, raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast || needs_copy) {
raft::linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
default: RAFT_FAIL("Unexpected metric.");
}
}
Expand Down
24 changes: 18 additions & 6 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0},
{256},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.995});
Expand All @@ -401,7 +403,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0},
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.995});
Expand All @@ -417,7 +421,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0, 4, 8, 16, 32}, // team_size
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{false},
{0.995});
Expand All @@ -434,7 +440,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0}, // team_size
{32, 64, 128, 256, 512, 768},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.995});
Expand Down Expand Up @@ -469,7 +477,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0},
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false},
{true},
{0.6}); // don't demand high recall without refinement
Expand Down Expand Up @@ -497,7 +507,9 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0}, // team_size
{64},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
{false, true},
{false},
{0.99},
Expand Down