2424#include < cuvs/distance/distance.hpp>
2525#include < cuvs/neighbors/cagra.hpp>
2626#include < cuvs/neighbors/common.hpp>
27+ #include < cuvs/neighbors/composite/merge.hpp>
2728#include < cuvs/neighbors/dynamic_batching.hpp>
2829#include < cuvs/neighbors/ivf_pq.hpp>
2930#include < cuvs/neighbors/nn_descent.hpp>
4445#include < iostream>
4546#include < memory>
4647#include < optional>
48+ #include < raft/util/integer_utils.hpp>
4749#include < stdexcept>
4850#include < string>
4951#include < type_traits>
52+ #include < vector>
5053
5154namespace cuvs ::bench {
5255
5356enum class AllocatorType { kHostPinned , kHostHugePage , kDevice };
5457enum class CagraBuildAlgo { kAuto , kIvfPq , kNnDescent };
58+ enum class CagraMergeType { kPhysical , kLogical };
5559
5660template <typename T, typename IdxT>
5761class cuvs_cagra : public algo <T>, public algo_gpu {
@@ -80,6 +84,8 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
8084 std::optional<float > ivf_pq_refine_rate = std::nullopt ;
8185 std::optional<cuvs::neighbors::ivf_pq::index_params> ivf_pq_build_params = std::nullopt ;
8286 std::optional<cuvs::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt ;
87+ size_t num_dataset_splits = 1 ;
88+ CagraMergeType merge_type = CagraMergeType::kPhysical ;
8389
8490 void prepare_build_params (const raft::extent_2d<IdxT>& dataset_extents)
8591 {
@@ -188,6 +194,7 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
188194 bool dynamic_batching_conservative_dispatch_;
189195
190196 std::shared_ptr<cuvs::neighbors::filtering::base_filter> filter_;
197+ std::vector<std::shared_ptr<cuvs::neighbors::cagra::index<T, IdxT>>> sub_indices_;
191198
192199 inline rmm::device_async_resource_ref get_mr (AllocatorType mem_type)
193200 {
@@ -211,10 +218,57 @@ void cuvs_cagra<T, IdxT>::build(const T* dataset, size_t nrow)
211218 auto dataset_view_device =
212219 raft::make_mdspan<const T, IdxT, raft::row_major, false , true >(dataset, dataset_extents);
213220 bool dataset_is_on_host = raft::get_device_for_address (dataset) == -1 ;
221+ if (index_params_.num_dataset_splits <= 1 ) {
222+ index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(std::move (
223+ dataset_is_on_host ? cuvs::neighbors::cagra::build (handle_, params, dataset_view_host)
224+ : cuvs::neighbors::cagra::build (handle_, params, dataset_view_device)));
225+ } else {
226+ IdxT rows_per_split =
227+ raft::ceildiv<IdxT>(nrow, static_cast <IdxT>(index_params_.num_dataset_splits ));
228+ for (size_t i = 0 ; i < index_params_.num_dataset_splits ; ++i) {
229+ IdxT start = static_cast <IdxT>(i * rows_per_split);
230+ if (start >= nrow) break ;
231+ IdxT rows = std::min (rows_per_split, static_cast <IdxT>(nrow) - start);
232+ const T* sub_ptr = dataset + static_cast <size_t >(start) * dimension_;
233+ auto sub_host =
234+ raft::make_host_matrix_view<const T, int64_t , raft::row_major>(sub_ptr, rows, dimension_);
235+ auto sub_dev =
236+ raft::make_device_matrix_view<const T, int64_t , raft::row_major>(sub_ptr, rows, dimension_);
237+
238+ auto sub_index =
239+ cuvs::neighbors::cagra::index<T, IdxT>(handle_, index_params_.cagra_params .metric );
240+ if (index_params_.merge_type == CagraMergeType::kPhysical ) {
241+ if (dataset_is_on_host) {
242+ sub_index.update_dataset (handle_, sub_host);
243+ } else {
244+ sub_index.update_dataset (handle_, sub_dev);
245+ }
246+ }
247+ if (index_params_.merge_type == CagraMergeType::kLogical ) {
248+ if (dataset_is_on_host) {
249+ sub_index = cuvs::neighbors::cagra::build (handle_, params, sub_host);
250+ } else {
251+ sub_index = cuvs::neighbors::cagra::build (handle_, params, sub_dev);
252+ }
253+ }
254+ auto sub_index_shared =
255+ std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(std::move (sub_index));
256+ sub_indices_.push_back (std::move (sub_index_shared));
257+ }
258+ if (index_params_.merge_type == CagraMergeType::kPhysical ) {
259+ cuvs::neighbors::cagra::merge_params merge_params{index_params_.cagra_params };
260+ merge_params.merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL;
261+
262+ std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> indices;
263+ indices.reserve (sub_indices_.size ());
264+ for (auto & ptr : sub_indices_) {
265+ indices.push_back (ptr.get ());
266+ }
214267
215- index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(std::move (
216- dataset_is_on_host ? cuvs::neighbors::cagra::build (handle_, params, dataset_view_host)
217- : cuvs::neighbors::cagra::build (handle_, params, dataset_view_device)));
268+ index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(
269+ std::move (cuvs::neighbors::cagra::merge (handle_, merge_params, indices)));
270+ }
271+ }
218272}
219273
220274inline auto allocator_to_string (AllocatorType mem_type) -> std::string
@@ -233,7 +287,7 @@ template <typename T, typename IdxT>
233287void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param,
234288 const void * filter_bitset)
235289{
236- filter_ = make_cuvs_filter (filter_bitset, index_->size ());
290+ if (index_) { filter_ = make_cuvs_filter (filter_bitset, index_->size ()); }
237291 auto sp = dynamic_cast <const search_param&>(param);
238292 bool needs_dynamic_batcher_update =
239293 (dynamic_batching_max_batch_size_ != sp.dynamic_batching_max_batch_size ) ||
@@ -314,27 +368,65 @@ void cuvs_cagra<T, IdxT>::set_search_param(const search_param_base& param,
314368template <typename T, typename IdxT>
315369void cuvs_cagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
316370{
317- using ds_idx_type = decltype (index_->data ().n_rows ());
318- bool is_vpq =
319- dynamic_cast <const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data ()) ||
320- dynamic_cast <const cuvs::neighbors::vpq_dataset<float , ds_idx_type>*>(&index_->data ());
321- // It can happen that we are re-using a previous algo object which already has
322- // the dataset set. Check if we need update.
323- if (static_cast <size_t >(input_dataset_v_->extent (0 )) != nrow ||
324- input_dataset_v_->data_handle () != dataset) {
325- *input_dataset_v_ = raft::make_device_matrix_view<const T, int64_t >(dataset, nrow, this ->dim_ );
326- need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset.
371+ if (index_params_.num_dataset_splits > 1 &&
372+ index_params_.merge_type == CagraMergeType::kLogical ) {
373+ bool dataset_is_on_host = raft::get_device_for_address (dataset) == -1 ;
374+ IdxT rows_per_split =
375+ raft::ceildiv<IdxT>(nrow, static_cast <IdxT>(index_params_.num_dataset_splits ));
376+ for (size_t i = 0 ; i < sub_indices_.size (); ++i) {
377+ IdxT start = static_cast <IdxT>(i * rows_per_split);
378+ if (start >= nrow) break ;
379+ IdxT rows = std::min (rows_per_split, static_cast <IdxT>(nrow) - start);
380+ const T* sub_ptr = dataset + static_cast <size_t >(start) * dimension_;
381+ auto sub_host =
382+ raft::make_host_matrix_view<const T, int64_t , raft::row_major>(sub_ptr, rows, dimension_);
383+ auto sub_dev =
384+ raft::make_device_matrix_view<const T, int64_t , raft::row_major>(sub_ptr, rows, dimension_);
385+ auto sub_index = sub_indices_[i].get ();
386+ if (index_params_.merge_type == CagraMergeType::kLogical ) {
387+ if (dataset_is_on_host) {
388+ sub_index->update_dataset (handle_, sub_host);
389+ } else {
390+ sub_index->update_dataset (handle_, sub_dev);
391+ }
392+ }
393+ }
394+ need_dataset_update_ = false ;
395+ } else {
396+ using ds_idx_type = decltype (index_->data ().n_rows ());
397+ bool is_vpq =
398+ dynamic_cast <const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data ()) ||
399+ dynamic_cast <const cuvs::neighbors::vpq_dataset<float , ds_idx_type>*>(&index_->data ());
400+ // It can happen that we are re-using a previous algo object which already has
401+ // the dataset set. Check if we need update.
402+ if (static_cast <size_t >(input_dataset_v_->extent (0 )) != nrow ||
403+ input_dataset_v_->data_handle () != dataset) {
404+ *input_dataset_v_ =
405+ raft::make_device_matrix_view<const T, int64_t >(dataset, nrow, this ->dim_ );
406+ need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset.
407+ }
327408 }
328409}
329410
330411template <typename T, typename IdxT>
331412void cuvs_cagra<T, IdxT>::save(const std::string& file) const
332413{
333- using ds_idx_type = decltype (index_->data ().n_rows ());
334- bool is_vpq =
335- dynamic_cast <const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data ()) ||
336- dynamic_cast <const cuvs::neighbors::vpq_dataset<float , ds_idx_type>*>(&index_->data ());
337- cuvs::neighbors::cagra::serialize (handle_, file, *index_, is_vpq);
414+ if (index_params_.num_dataset_splits > 1 &&
415+ index_params_.merge_type == CagraMergeType::kLogical ) {
416+ for (size_t i = 0 ; i < sub_indices_.size (); ++i) {
417+ std::string subfile = file + (i == 0 ? " " : " .subidx." + std::to_string (i));
418+ cuvs::neighbors::cagra::serialize (handle_, subfile, *sub_indices_[i], false );
419+ }
420+ std::ofstream f (file + " .submeta" , std::ios::out);
421+ f << sub_indices_.size ();
422+ f.close ();
423+ } else {
424+ using ds_idx_type = decltype (index_->data ().n_rows ());
425+ bool is_vpq =
426+ dynamic_cast <const cuvs::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data ()) ||
427+ dynamic_cast <const cuvs::neighbors::vpq_dataset<float , ds_idx_type>*>(&index_->data ());
428+ cuvs::neighbors::cagra::serialize (handle_, file, *index_, is_vpq);
429+ }
338430}
339431
340432template <typename T, typename IdxT>
@@ -346,8 +438,24 @@ void cuvs_cagra<T, IdxT>::save_to_hnswlib(const std::string& file) const
346438template <typename T, typename IdxT>
347439void cuvs_cagra<T, IdxT>::load(const std::string& file)
348440{
349- index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(handle_);
350- cuvs::neighbors::cagra::deserialize (handle_, file, index_.get ());
441+ std::ifstream meta (file + " .submeta" , std::ios::in);
442+ if (index_params_.num_dataset_splits > 1 &&
443+ index_params_.merge_type == CagraMergeType::kLogical && meta.good ()) {
444+ // Load multiple sub-indices for logical merge
445+ size_t count;
446+ meta >> count;
447+ meta.close ();
448+ sub_indices_.clear ();
449+ for (size_t i = 0 ; i < count; ++i) {
450+ std::string subfile = file + (i == 0 ? " " : " .subidx." + std::to_string (i));
451+ auto sub_index = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(handle_);
452+ cuvs::neighbors::cagra::deserialize (handle_, subfile, sub_index.get ());
453+ sub_indices_.push_back (std::move (sub_index));
454+ }
455+ } else {
456+ index_ = std::make_shared<cuvs::neighbors::cagra::index<T, IdxT>>(handle_);
457+ cuvs::neighbors::cagra::deserialize (handle_, file, index_.get ());
458+ }
351459}
352460
353461template <typename T, typename IdxT>
@@ -377,8 +485,41 @@ void cuvs_cagra<T, IdxT>::search_base(
377485 neighbors_view,
378486 distances_view);
379487 } else {
380- cuvs::neighbors::cagra::search (
381- handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, *filter_);
488+ if (index_params_.num_dataset_splits <= 1 ||
489+ index_params_.merge_type == CagraMergeType::kPhysical ) {
490+ cuvs::neighbors::cagra::search (
491+ handle_, search_params_, *index_, queries_view, neighbors_view, distances_view, *filter_);
492+ } else {
493+ if (index_params_.merge_type == CagraMergeType::kLogical ) {
494+ cuvs::neighbors::cagra::merge_params merge_params{index_params_.cagra_params };
495+ merge_params.merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL;
496+
497+ // Create wrapped indices for composite merge
498+ std::vector<std::shared_ptr<cuvs::neighbors::IndexWrapper<T, IdxT, algo_base::index_type>>>
499+ wrapped_indices;
500+ wrapped_indices.reserve (sub_indices_.size ());
501+ for (auto & ptr : sub_indices_) {
502+ auto index_wrapper =
503+ cuvs::neighbors::cagra::make_index_wrapper<T, IdxT, algo_base::index_type>(ptr.get ());
504+ wrapped_indices.push_back (index_wrapper);
505+ }
506+
507+ raft::resources composite_handle (handle_);
508+ size_t n_streams = wrapped_indices.size ();
509+ raft::resource::set_cuda_stream_pool (composite_handle,
510+ std::make_shared<rmm::cuda_stream_pool>(n_streams));
511+
512+ auto merged_index =
513+ cuvs::neighbors::composite::merge (composite_handle, merge_params, wrapped_indices);
514+ cuvs::neighbors::filtering::none_sample_filter empty_filter;
515+ merged_index->search (composite_handle,
516+ search_params_,
517+ queries_view,
518+ neighbors_view,
519+ distances_view,
520+ empty_filter);
521+ }
522+ }
382523 }
383524}
384525
0 commit comments