Skip to content

Commit 156125b

Browse files
authored
Merge branch 'branch-25.06' into patch-1
2 parents 61652af + 043c06f commit 156125b

41 files changed

Lines changed: 1822 additions & 279 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cpp/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,9 @@ if(BUILD_SHARED_LIBS)
419419
src/neighbors/detail/cagra/cagra_build.cpp
420420
src/neighbors/detail/cagra/topk_for_cagra/topk.cu
421421
src/neighbors/dynamic_batching.cu
422+
src/neighbors/cagra_index_wrapper.cu
423+
src/neighbors/composite/index.cu
424+
src/neighbors/composite/merge.cpp
422425
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
423426
src/neighbors/ivf_flat_index.cpp
424427
src/neighbors/ivf_flat/ivf_flat_build_extend_float_int64_t.cu

cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,20 @@ void parse_build_param(const nlohmann::json& conf,
277277
parse_build_param(comp_search_conf, vpq_pams);
278278
param.cagra_params.compression.emplace(vpq_pams);
279279
}
280+
281+
if (conf.contains("num_dataset_splits")) {
282+
param.num_dataset_splits = conf.at("num_dataset_splits");
283+
}
284+
if (conf.contains("merge_type")) {
285+
std::string mt = conf.at("merge_type");
286+
if (mt == "PHYSICAL") {
287+
param.merge_type = cuvs::bench::CagraMergeType::kPhysical;
288+
} else if (mt == "LOGICAL") {
289+
param.merge_type = cuvs::bench::CagraMergeType::kLogical;
290+
} else {
291+
throw std::runtime_error("invalid value for merge_type");
292+
}
293+
}
280294
}
281295

282296
cuvs::bench::AllocatorType parse_allocator(std::string mem_type)

cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h

Lines changed: 164 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <cuvs/distance/distance.hpp>
2525
#include <cuvs/neighbors/cagra.hpp>
2626
#include <cuvs/neighbors/common.hpp>
27+
#include <cuvs/neighbors/composite/merge.hpp>
2728
#include <cuvs/neighbors/dynamic_batching.hpp>
2829
#include <cuvs/neighbors/ivf_pq.hpp>
2930
#include <cuvs/neighbors/nn_descent.hpp>
@@ -44,14 +45,17 @@
4445
#include <iostream>
4546
#include <memory>
4647
#include <optional>
48+
#include <raft/util/integer_utils.hpp>
4749
#include <stdexcept>
4850
#include <string>
4951
#include <type_traits>
52+
#include <vector>
5053

5154
namespace cuvs::bench {
5255

5356
enum class AllocatorType { kHostPinned, kHostHugePage, kDevice };
5457
enum class CagraBuildAlgo { kAuto, kIvfPq, kNnDescent };
58+
enum class CagraMergeType { kPhysical, kLogical };
5559

5660
template <typename T, typename IdxT>
5761
class cuvs_cagra : public algo<T>, public algo_gpu {
@@ -80,6 +84,8 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
8084
std::optional<float> ivf_pq_refine_rate = std::nullopt;
8185
std::optional<cuvs::neighbors::ivf_pq::index_params> ivf_pq_build_params = std::nullopt;
8286
std::optional<cuvs::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt;
87+
size_t num_dataset_splits = 1;
88+
CagraMergeType merge_type = CagraMergeType::kPhysical;
8389

8490
void prepare_build_params(const raft::extent_2d<IdxT>& dataset_extents)
8591
{
@@ -188,6 +194,7 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
188194
bool dynamic_batching_conservative_dispatch_;
189195

190196
std::shared_ptr<cuvs::neighbors::filtering::base_filter> filter_;
197+
std::vector<std::shared_ptr<cuvs::neighbors::cagra::index<T, IdxT>>> sub_indices_;
191198

192199
inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type)
193200
{
@@ -211,10 +218,57 @@ void cuvs_cagra<T, IdxT>::build(const T* dataset, size_t nrow)
211218
auto dataset_view_device =
212219
raft::make_mdspan<const T, IdxT, raft::row_major, false, true>(dataset, dataset_extents);
213220
bool dataset_is_on_host = raft::get_device_for_address(dataset) == -1;
221+
if (index_params_.num_dataset_splits <= 1) {
222+
index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(std::move(
223+
dataset_is_on_host ? cuvs::neighbors::cagra::build(handle_, params, dataset_view_host)
224+
: cuvs::neighbors::cagra::build(handle_, params, dataset_view_device)));
225+
} else {
226+
IdxT rows_per_split =
227+
raft::ceildiv<IdxT>(nrow, static_cast<IdxT>(index_params_.num_dataset_splits));
228+
for (size_t i = 0; i < index_params_.num_dataset_splits; ++i) {
229+
IdxT start = static_cast<IdxT>(i * rows_per_split);
230+
if (start >= nrow) break;
231+
IdxT rows = std::min(rows_per_split, static_cast<IdxT>(nrow) - start);
232+
const T* sub_ptr = dataset + static_cast<size_t>(start) * dimension_;
233+
auto sub_host =
234+
raft::make_host_matrix_view<const T, int64_t, raft::row_major>(sub_ptr, rows, dimension_);
235+
auto sub_dev =
236+
raft::make_device_matrix_view<const T, int64_t, raft::row_major>(sub_ptr, rows, dimension_);
237+
238+
auto sub_index =
239+
cuvs::neighbors::cagra::index<T, IdxT>(handle_, index_params_.cagra_params.metric);
240+
if (index_params_.merge_type == CagraMergeType::kPhysical) {
241+
if (dataset_is_on_host) {
242+
sub_index.update_dataset(handle_, sub_host);
243+
} else {
244+
sub_index.update_dataset(handle_, sub_dev);
245+
}
246+
}
247+
if (index_params_.merge_type == CagraMergeType::kLogical) {
248+
if (dataset_is_on_host) {
249+
sub_index = cuvs::neighbors::cagra::build(handle_, params, sub_host);
250+
} else {
251+
sub_index = cuvs::neighbors::cagra::build(handle_, params, sub_dev);
252+
}
253+
}
254+
auto sub_index_shared =
255+
std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(std::move(sub_index));
256+
sub_indices_.push_back(std::move(sub_index_shared));
257+
}
258+
if (index_params_.merge_type == CagraMergeType::kPhysical) {
259+
cuvs::neighbors::cagra::merge_params merge_params{index_params_.cagra_params};
260+
merge_params.merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL;
261+
262+
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> indices;
263+
indices.reserve(sub_indices_.size());
264+
for (auto& ptr : sub_indices_) {
265+
indices.push_back(ptr.get());
266+
}
214267

215-
index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(std::move(
216-
dataset_is_on_host ? cuvs::neighbors::cagra::build(handle_, params, dataset_view_host)
217-
: cuvs::neighbors::cagra::build(handle_, params, dataset_view_device)));
268+
index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(
269+
std::move(cuvs::neighbors::cagra::merge(handle_, merge_params, indices)));
270+
}
271+
}
218272
}
219273

220274
inline auto allocator_to_string(AllocatorType mem_type) -> std::string
@@ -233,7 +287,7 @@ template <typename T, typename IdxT>
233287
void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param,
234288
const void* filter_bitset)
235289
{
236-
filter_ = make_cuvs_filter(filter_bitset, index_->size());
290+
if (index_) { filter_ = make_cuvs_filter(filter_bitset, index_->size()); }
237291
auto sp = dynamic_cast<const search_param&>(param);
238292
bool needs_dynamic_batcher_update =
239293
(dynamic_batching_max_batch_size_ != sp.dynamic_batching_max_batch_size) ||
@@ -314,27 +368,65 @@ void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param,
314368
template <typename T, typename IdxT>
315369
void cuvs_cagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
316370
{
317-
using ds_idx_type = decltype(index_->data().n_rows());
318-
bool is_vpq =
319-
dynamic_cast<const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data()) ||
320-
dynamic_cast<const cuvs::neighbors::vpq_dataset<float, ds_idx_type>*>(&index_->data());
321-
// It can happen that we are re-using a previous algo object which already has
322-
// the dataset set. Check if we need update.
323-
if (static_cast<size_t>(input_dataset_v_->extent(0)) != nrow ||
324-
input_dataset_v_->data_handle() != dataset) {
325-
*input_dataset_v_ = raft::make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
326-
need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset.
371+
if (index_params_.num_dataset_splits > 1 &&
372+
index_params_.merge_type == CagraMergeType::kLogical) {
373+
bool dataset_is_on_host = raft::get_device_for_address(dataset) == -1;
374+
IdxT rows_per_split =
375+
raft::ceildiv<IdxT>(nrow, static_cast<IdxT>(index_params_.num_dataset_splits));
376+
for (size_t i = 0; i < sub_indices_.size(); ++i) {
377+
IdxT start = static_cast<IdxT>(i * rows_per_split);
378+
if (start >= nrow) break;
379+
IdxT rows = std::min(rows_per_split, static_cast<IdxT>(nrow) - start);
380+
const T* sub_ptr = dataset + static_cast<size_t>(start) * dimension_;
381+
auto sub_host =
382+
raft::make_host_matrix_view<const T, int64_t, raft::row_major>(sub_ptr, rows, dimension_);
383+
auto sub_dev =
384+
raft::make_device_matrix_view<const T, int64_t, raft::row_major>(sub_ptr, rows, dimension_);
385+
auto sub_index = sub_indices_[i].get();
386+
if (index_params_.merge_type == CagraMergeType::kLogical) {
387+
if (dataset_is_on_host) {
388+
sub_index->update_dataset(handle_, sub_host);
389+
} else {
390+
sub_index->update_dataset(handle_, sub_dev);
391+
}
392+
}
393+
}
394+
need_dataset_update_ = false;
395+
} else {
396+
using ds_idx_type = decltype(index_->data().n_rows());
397+
bool is_vpq =
398+
dynamic_cast<const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data()) ||
399+
dynamic_cast<const cuvs::neighbors::vpq_dataset<float, ds_idx_type>*>(&index_->data());
400+
// It can happen that we are re-using a previous algo object which already has
401+
// the dataset set. Check if we need update.
402+
if (static_cast<size_t>(input_dataset_v_->extent(0)) != nrow ||
403+
input_dataset_v_->data_handle() != dataset) {
404+
*input_dataset_v_ =
405+
raft::make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
406+
need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset.
407+
}
327408
}
328409
}
329410

330411
template <typename T, typename IdxT>
331412
void cuvs_cagra<T, IdxT>::save(const std::string& file) const
332413
{
333-
using ds_idx_type = decltype(index_->data().n_rows());
334-
bool is_vpq =
335-
dynamic_cast<const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data()) ||
336-
dynamic_cast<const cuvs::neighbors::vpq_dataset<float, ds_idx_type>*>(&index_->data());
337-
cuvs::neighbors::cagra::serialize(handle_, file, *index_, is_vpq);
414+
if (index_params_.num_dataset_splits > 1 &&
415+
index_params_.merge_type == CagraMergeType::kLogical) {
416+
for (size_t i = 0; i < sub_indices_.size(); ++i) {
417+
std::string subfile = file + (i == 0 ? "" : ".subidx." + std::to_string(i));
418+
cuvs::neighbors::cagra::serialize(handle_, subfile, *sub_indices_[i], false);
419+
}
420+
std::ofstream f(file + ".submeta", std::ios::out);
421+
f << sub_indices_.size();
422+
f.close();
423+
} else {
424+
using ds_idx_type = decltype(index_->data().n_rows());
425+
bool is_vpq =
426+
dynamic_cast<const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data()) ||
427+
dynamic_cast<const cuvs::neighbors::vpq_dataset<float, ds_idx_type>*>(&index_->data());
428+
cuvs::neighbors::cagra::serialize(handle_, file, *index_, is_vpq);
429+
}
338430
}
339431

340432
template <typename T, typename IdxT>
@@ -346,8 +438,24 @@ void cuvs_cagra<T, IdxT>::save_to_hnswlib(const std::string& file) const
346438
template <typename T, typename IdxT>
347439
void cuvs_cagra<T, IdxT>::load(const std::string& file)
348440
{
349-
index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(handle_);
350-
cuvs::neighbors::cagra::deserialize(handle_, file, index_.get());
441+
std::ifstream meta(file + ".submeta", std::ios::in);
442+
if (index_params_.num_dataset_splits > 1 &&
443+
index_params_.merge_type == CagraMergeType::kLogical && meta.good()) {
444+
// Load multiple sub-indices for logical merge
445+
size_t count;
446+
meta >> count;
447+
meta.close();
448+
sub_indices_.clear();
449+
for (size_t i = 0; i < count; ++i) {
450+
std::string subfile = file + (i == 0 ? "" : ".subidx." + std::to_string(i));
451+
auto sub_index = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(handle_);
452+
cuvs::neighbors::cagra::deserialize(handle_, subfile, sub_index.get());
453+
sub_indices_.push_back(std::move(sub_index));
454+
}
455+
} else {
456+
index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(handle_);
457+
cuvs::neighbors::cagra::deserialize(handle_, file, index_.get());
458+
}
351459
}
352460

353461
template <typename T, typename IdxT>
@@ -377,8 +485,41 @@ void cuvs_cagra<T, IdxT>::search_base(
377485
neighbors_view,
378486
distances_view);
379487
} else {
380-
cuvs::neighbors::cagra::search(
381-
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, *filter_);
488+
if (index_params_.num_dataset_splits <= 1 ||
489+
index_params_.merge_type == CagraMergeType::kPhysical) {
490+
cuvs::neighbors::cagra::search(
491+
handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, *filter_);
492+
} else {
493+
if (index_params_.merge_type == CagraMergeType::kLogical) {
494+
cuvs::neighbors::cagra::merge_params merge_params{index_params_.cagra_params};
495+
merge_params.merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL;
496+
497+
// Create wrapped indices for composite merge
498+
std::vector<std::shared_ptr<cuvs::neighbors::IndexWrapper<T, IdxT, algo_base::index_type>>>
499+
wrapped_indices;
500+
wrapped_indices.reserve(sub_indices_.size());
501+
for (auto& ptr : sub_indices_) {
502+
auto index_wrapper =
503+
cuvs::neighbors::cagra::make_index_wrapper<T, IdxT, algo_base::index_type>(ptr.get());
504+
wrapped_indices.push_back(index_wrapper);
505+
}
506+
507+
raft::resources composite_handle(handle_);
508+
size_t n_streams = wrapped_indices.size();
509+
raft::resource::set_cuda_stream_pool(composite_handle,
510+
std::make_shared<rmm::cuda_stream_pool>(n_streams));
511+
512+
auto merged_index =
513+
cuvs::neighbors::composite::merge(composite_handle, merge_params, wrapped_indices);
514+
cuvs::neighbors::filtering::none_sample_filter empty_filter;
515+
merged_index->search(composite_handle,
516+
search_params_,
517+
queries_view,
518+
neighbors_view,
519+
distances_view,
520+
empty_filter);
521+
}
522+
}
382523
}
383524
}
384525

0 commit comments

Comments
 (0)