Skip to content

Commit 51fb88e

Browse files
authored
Use PQ API in CAGRA-Q + SCANN (#1746)
Follow-up to the PQ PR #1278 . Closes #1575 Closes #1747 This PR removes the need to compile multiple times the same code for PQ in CAGRA-Q and SCANN, removing code duplication and improving build time. CAGRA-Q can't use the new public API since it is using half for its math type so an private API function is used. A small test is added to SCANN to make sure that the returned index is not complete garbage but more testing should be done there. (Created issue #1747 to track this) This PR saves ~2-3 Mb on libcuvs.so compiled on a single architecture (141 Mb -> 138Mb) Authors: - Micka (https://github.com/lowener) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #1746
1 parent b5eebba commit 51fb88e

11 files changed

Lines changed: 420 additions & 407 deletions

File tree

cpp/src/neighbors/detail/cagra/cagra_build.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#pragma once
66

77
#include "../../../core/nvtx.hpp"
8-
#include "../../vpq_dataset.cuh"
8+
#include "../../../preprocessing/quantize/vpq_build-ext.cuh"
99
#include "graph_core.cuh"
1010

1111
#include <raft/core/copy.cuh>
@@ -2279,8 +2279,7 @@ index<T, IdxT> build(
22792279
idx.update_dataset(
22802280
res,
22812281
// TODO: hardcoding codebook math to `half`, we can do runtime dispatching later
2282-
cuvs::neighbors::vpq_build<decltype(dataset), half, int64_t>(
2283-
res, *params.compression, dataset));
2282+
cuvs::preprocessing::quantize::pq::vpq_build(res, *params.compression, dataset));
22842283

22852284
return idx;
22862285
}

cpp/src/neighbors/detail/cagra/factory.cuh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -11,6 +11,9 @@
1111
#include "search_plan.cuh"
1212
#include "search_single_cta.cuh"
1313

14+
#include <raft/core/resource/custom_resource.hpp>
15+
#include <raft/util/cache.hpp>
16+
1417
#include <cuvs/neighbors/common.hpp>
1518

1619
namespace cuvs::neighbors::cagra::detail {

cpp/src/neighbors/detail/vpq_dataset.cuh

Lines changed: 82 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,41 @@ void process_and_fill_codes(
467467
RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 16]", pq_bits);
468468
}
469469
}(pq_bits);
470+
bool need_copy_to_device =
471+
cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) ==
472+
cuvs::spatial::knn::detail::utils::pointer_residency::host_only;
473+
bool need_batching = n_rows > kReasonableMaxBatchSize;
474+
auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) {
475+
if (inline_vq_labels || (!vq_labels.empty() && !vq_centers.empty())) {
476+
predict_vq<label_t>(res, dataset_view, vq_centers, labels_view);
477+
}
478+
dim3 blocks(
479+
raft::div_rounding_up_safe<ix_t>(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1);
480+
kernel<<<blocks, threads, sharedMemorySize, stream>>>(codes_view,
481+
dataset_view,
482+
pq_centers,
483+
vq_centers,
484+
raft::make_const_mdspan(labels_view),
485+
rows_in_shared_memory,
486+
pq_bits,
487+
inline_vq_labels);
488+
RAFT_CUDA_TRY(cudaPeekAtLastError());
489+
};
490+
auto batch_labels = raft::make_device_vector<label_t, IdxT>(res, 0);
491+
if (!need_batching && !need_copy_to_device) {
492+
// No batching needed, launch the kernel directly
493+
auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim);
494+
auto labels_view = raft::make_device_vector_view<label_t, IdxT>(nullptr, 0);
495+
if (inline_vq_labels) {
496+
batch_labels = raft::make_device_vector<label_t, IdxT>(res, dataset_view.extent(0));
497+
labels_view = batch_labels.view();
498+
} else if (!vq_labels.empty() && !vq_centers.empty()) {
499+
labels_view = vq_labels;
500+
}
501+
launch_work(dataset_view, labels_view, codes);
502+
return;
503+
}
504+
470505
for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator(
471506
dataset.data_handle(),
472507
n_rows,
@@ -475,53 +510,20 @@ void process_and_fill_codes(
475510
stream,
476511
rmm::mr::get_current_device_resource())) {
477512
auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim);
478-
auto batch_labels = raft::make_device_vector<label_t, IdxT>(res, 0);
479513
auto batch_labels_view = raft::make_device_vector_view<label_t, IdxT>(nullptr, 0);
480514
if (inline_vq_labels) {
481515
batch_labels = raft::make_device_vector<label_t, IdxT>(res, batch.size());
482516
batch_labels_view = batch_labels.view();
483-
predict_vq<label_t>(res, batch_view, vq_centers, batch_labels_view);
484-
} else {
485-
if (!vq_labels.empty() && !vq_centers.empty()) {
486-
batch_labels_view = raft::make_device_vector_view<label_t, IdxT>(
487-
vq_labels.data_handle() + batch.offset(), batch.size());
488-
predict_vq<label_t>(res, batch_view, vq_centers, batch_labels_view);
489-
}
517+
} else if (!vq_labels.empty() && !vq_centers.empty()) {
518+
batch_labels_view = raft::make_device_vector_view<label_t, IdxT>(
519+
vq_labels.data_handle() + batch.offset(), batch.size());
490520
}
491-
dim3 blocks(raft::div_rounding_up_safe<ix_t>(n_rows, kBlockSize / threads_per_vec), 1, 1);
492-
kernel<<<blocks, threads, sharedMemorySize, stream>>>(
493-
raft::make_device_matrix_view<uint8_t, IdxT>(
494-
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen),
495-
batch_view,
496-
pq_centers,
497-
vq_centers,
498-
raft::make_const_mdspan(batch_labels_view),
499-
rows_in_shared_memory,
500-
pq_bits,
501-
inline_vq_labels);
502-
RAFT_CUDA_TRY(cudaPeekAtLastError());
521+
auto batch_codes_view = raft::make_device_matrix_view<uint8_t, IdxT>(
522+
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen);
523+
launch_work(batch_view, batch_labels_view, batch_codes_view);
503524
}
504525
}
505526

506-
template <typename NewMathT, typename OldMathT, typename IdxT>
507-
auto vpq_convert_math_type(const raft::resources& res, vpq_dataset<OldMathT, IdxT>&& src)
508-
-> vpq_dataset<NewMathT, IdxT>
509-
{
510-
auto vq_code_book = raft::make_device_mdarray<NewMathT>(res, src.vq_code_book.extents());
511-
auto pq_code_book = raft::make_device_mdarray<NewMathT>(res, src.pq_code_book.extents());
512-
513-
raft::linalg::map(res,
514-
vq_code_book.view(),
515-
cuvs::spatial::knn::detail::utils::mapping<NewMathT>{},
516-
raft::make_const_mdspan(src.vq_code_book.view()));
517-
raft::linalg::map(res,
518-
pq_code_book.view(),
519-
cuvs::spatial::knn::detail::utils::mapping<NewMathT>{},
520-
raft::make_const_mdspan(src.pq_code_book.view()));
521-
return vpq_dataset<NewMathT, IdxT>{
522-
std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)};
523-
}
524-
525527
// Helper for operations using vectorized loads of raft::TxN_t
526528
template <typename MathT, int VectorSize>
527529
struct vec_op : raft::TxN_t<MathT, VectorSize> {
@@ -858,14 +860,40 @@ void process_and_fill_codes_subspaces(
858860
}
859861
}(pq_bits);
860862

861-
ix_t max_batch_size = std::min<ix_t>(n_rows, kReasonableMaxBatchSize);
862-
auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default
863-
bool enable_prefetch = false;
864-
if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) {
865-
if (raft::resource::get_stream_pool_size(res) >= 1) {
866-
enable_prefetch = true;
867-
copy_stream = raft::resource::get_stream_from_stream_pool(res);
863+
ix_t max_batch_size = std::min<ix_t>(n_rows, kReasonableMaxBatchSize);
864+
auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default
865+
bool enable_prefetch_stream = false;
866+
bool has_cuda_stream_pool_resource =
867+
res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) &&
868+
raft::resource::get_stream_pool_size(res) >= 1;
869+
bool need_copy_to_device =
870+
cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) ==
871+
cuvs::spatial::knn::detail::utils::pointer_residency::host_only;
872+
bool need_batching = n_rows > kReasonableMaxBatchSize;
873+
auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) {
874+
if (!vq_labels.empty() && !vq_centers.empty()) {
875+
predict_vq<label_t>(res, dataset_view, vq_centers, labels_view);
868876
}
877+
dim3 blocks(
878+
raft::div_rounding_up_safe<ix_t>(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1);
879+
kernel<<<blocks, threads, shared_memory_size, stream>>>(codes_view,
880+
dataset_view,
881+
pq_centers,
882+
vq_centers,
883+
raft::make_const_mdspan(labels_view),
884+
pq_bits,
885+
shared_memory_size > 0);
886+
RAFT_CUDA_TRY(cudaPeekAtLastError());
887+
};
888+
if (!need_batching && !need_copy_to_device) {
889+
// No batching and no copy to device needed, launch the kernel directly
890+
auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim);
891+
launch_work(dataset_view, vq_labels, codes);
892+
return;
893+
}
894+
if (has_cuda_stream_pool_resource && need_copy_to_device) {
895+
enable_prefetch_stream = true;
896+
copy_stream = raft::resource::get_stream_from_stream_pool(res);
869897
}
870898
auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator(
871899
dataset.data_handle(),
@@ -874,62 +902,22 @@ void process_and_fill_codes_subspaces(
874902
max_batch_size,
875903
copy_stream,
876904
raft::resource::get_workspace_resource(res),
877-
enable_prefetch);
905+
enable_prefetch_stream);
878906
vec_batches.prefetch_next_batch();
879907
for (const auto& batch : vec_batches) {
880908
auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim);
881909
auto batch_labels = raft::make_device_vector_view<label_t, IdxT>(nullptr, 0);
882910
if (!vq_labels.empty() && !vq_centers.empty()) {
883911
batch_labels = raft::make_device_vector_view<label_t, IdxT>(
884912
vq_labels.data_handle() + batch.offset(), batch.size());
885-
predict_vq<label_t>(res, batch_view, vq_centers, batch_labels);
886913
}
887-
dim3 blocks(raft::div_rounding_up_safe<ix_t>(batch.size(), kBlockSize / threads_per_vec), 1, 1);
888-
kernel<<<blocks, threads, shared_memory_size, stream>>>(
889-
raft::make_device_matrix_view<uint8_t, IdxT>(
890-
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen),
891-
batch_view,
892-
pq_centers,
893-
vq_centers,
894-
raft::make_const_mdspan(batch_labels),
895-
pq_bits,
896-
shared_memory_size > 0);
897-
RAFT_CUDA_TRY(cudaPeekAtLastError());
898-
vec_batches.prefetch_next_batch();
899-
raft::resource::sync_stream(res);
914+
auto batch_codes_view = raft::make_device_matrix_view<uint8_t, IdxT>(
915+
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen);
916+
launch_work(batch_view, batch_labels, batch_codes_view);
917+
if (enable_prefetch_stream) {
918+
vec_batches.prefetch_next_batch();
919+
raft::resource::sync_stream(res);
920+
}
900921
}
901922
}
902-
903-
template <typename DatasetT, typename MathT, typename IdxT>
904-
auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset)
905-
-> vpq_dataset<MathT, IdxT>
906-
{
907-
using label_t = uint32_t;
908-
// Use a heuristic to impute missing parameters.
909-
auto ps = fill_missing_params_heuristics(params, dataset);
910-
911-
// Train codes
912-
auto vq_code_book = train_vq<MathT>(res, ps, dataset);
913-
auto pq_code_book =
914-
train_pq<MathT>(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view()));
915-
916-
// Encode dataset
917-
const IdxT n_rows = dataset.extent(0);
918-
const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe<IdxT>(
919-
ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t)));
920-
921-
auto codes = raft::make_device_matrix<uint8_t, IdxT, raft::row_major>(res, n_rows, codes_rowlen);
922-
process_and_fill_codes<MathT, IdxT>(res,
923-
ps,
924-
dataset,
925-
raft::make_const_mdspan(pq_code_book.view()),
926-
raft::make_const_mdspan(vq_code_book.view()),
927-
raft::make_device_vector_view<label_t, IdxT>(nullptr, 0),
928-
codes.view(),
929-
true);
930-
931-
return vpq_dataset<MathT, IdxT>{
932-
std::move(vq_code_book), std::move(pq_code_book), std::move(codes)};
933-
}
934-
935923
} // namespace cuvs::neighbors::detail

0 commit comments

Comments
 (0)