@@ -291,6 +291,7 @@ struct index : cuvs::neighbors::index {
291291 using index_type = IdxT;
292292 using value_type = T;
293293 using dataset_index_type = int64_t ;
// Element type of the neighborhood graph entries exposed by graph(); fixed to
// uint32_t and independent of the user-facing IdxT.
294+ using graph_index_type = uint32_t ;
294295
// Compile-time guard: converting uint32_t to IdxT must not be a narrowing conversion.
295296 static_assert (!raft::is_narrowing_v<uint32_t , IdxT>,
296297 " IdxT must be able to represent all values of uint32_t" );
@@ -334,11 +335,21 @@ struct index : cuvs::neighbors::index {
334335
335336 /* * neighborhood graph [size, graph-degree] */
336337 [[nodiscard]] inline auto graph () const noexcept
337- -> raft::device_matrix_view<const IdxT , int64_t, raft::row_major>
338+ -> raft::device_matrix_view<const graph_index_type , int64_t, raft::row_major>
338339 {
// Hands back the stored view unchanged; no copy of the graph is made here.
339340 return graph_view_;
340341 }
341342
343+ /* * Mapping from internal graph node indices to the original user-provided indices. */
344+ [[nodiscard]] inline auto source_indices () const noexcept
345+ -> std::optional<raft::device_vector_view<const index_type, int64_t>>
346+ {
347+ return source_indices_.has_value ()
348+ ? std::optional<raft::device_vector_view<const index_type, int64_t >>(
349+ source_indices_->view ())
350+ : std::nullopt ;
351+ }
352+
342353 /* * Dataset norms for cosine distance [size] */
343354 [[nodiscard]] inline auto dataset_norms () const noexcept
344355 -> std::optional<raft::device_vector_view<const float, int64_t>>
@@ -361,7 +372,7 @@ struct index : cuvs::neighbors::index {
361372 cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded)
362373 : cuvs::neighbors::index(),
363374 metric_ (metric),
364- graph_(raft::make_device_matrix<IdxT , int64_t >(res, 0 , 0 )),
375+ graph_(raft::make_device_matrix<graph_index_type , int64_t >(res, 0 , 0 )),
365376 dataset_(new cuvs::neighbors::empty_dataset<int64_t >(0 )),
366377 dataset_norms_(std::nullopt )
367378 {
@@ -405,7 +416,7 @@ struct index : cuvs::neighbors::index {
405416 * using namespace raft::neighbors::experimental;
406417 *
407418 * auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
408- * auto knn_graph = raft::make_device_matrix<uint32_n , int64_t>(res, n_rows, graph_degree);
419+ * auto knn_graph = raft::make_device_matrix<uint32_t , int64_t>(res, n_rows, graph_degree);
409420 *
410421 * // custom loading and graph creation
411422 * // load_dataset(dataset.view());
@@ -424,11 +435,13 @@ struct index : cuvs::neighbors::index {
424435 index (raft::resources const & res,
425436 cuvs::distance::DistanceType metric,
426437 raft::mdspan<const T, raft::matrix_extent<int64_t >, raft::row_major, data_accessor> dataset,
427- raft::mdspan<const IdxT, raft::matrix_extent<int64_t >, raft::row_major, graph_accessor>
428- knn_graph)
438+ raft::mdspan<const graph_index_type,
439+ raft::matrix_extent<int64_t >,
440+ raft::row_major,
441+ graph_accessor> knn_graph)
429442 : cuvs::neighbors::index(),
430443 metric_(metric),
431- graph_(raft::make_device_matrix<IdxT , int64_t >(res, 0 , 0 )),
444+ graph_(raft::make_device_matrix<graph_index_type , int64_t >(res, 0 , 0 )),
432445 dataset_(make_aligned_dataset(res, dataset, 16 )),
433446 dataset_norms_(std::nullopt )
434447 {
@@ -536,8 +549,9 @@ struct index : cuvs::neighbors::index {
536549 * Since the new graph is a device array, we store a reference to that, and it is
537550 * the caller's responsibility to ensure that knn_graph stays alive as long as the index.
538551 */
539- void update_graph (raft::resources const & res,
540- raft::device_matrix_view<const IdxT, int64_t , raft::row_major> knn_graph)
552+ void update_graph (
553+ raft::resources const & res,
554+ raft::device_matrix_view<const graph_index_type, int64_t , raft::row_major> knn_graph)
541555 {
// Only the view is stored: no data is copied and `res` is not used in this overload.
542556 graph_view_ = knn_graph;
543557 }
@@ -547,16 +561,19 @@ struct index : cuvs::neighbors::index {
547561 *
548562 * We create a copy of the graph on the device. The index manages the lifetime of this copy.
549563 */
550- void update_graph (raft::resources const & res,
551- raft::host_matrix_view<const IdxT, int64_t , raft::row_major> knn_graph)
564+ void update_graph (
565+ raft::resources const & res,
566+ raft::host_matrix_view<const graph_index_type, int64_t , raft::row_major> knn_graph)
552567 {
553568 RAFT_LOG_DEBUG (" Copying CAGRA knn graph from host to device" );
554569
555570 if ((graph_.extent (0 ) != knn_graph.extent (0 )) || (graph_.extent (1 ) != knn_graph.extent (1 ))) {
556571 // clear existing memory before allocating to prevent OOM errors on large graphs
557- if (graph_.size ()) { graph_ = raft::make_device_matrix<IdxT, int64_t >(res, 0 , 0 ); }
558- graph_ =
559- raft::make_device_matrix<IdxT, int64_t >(res, knn_graph.extent (0 ), knn_graph.extent (1 ));
572+ if (graph_.size ()) {
573+ graph_ = raft::make_device_matrix<graph_index_type, int64_t >(res, 0 , 0 );
574+ }
575+ graph_ = raft::make_device_matrix<graph_index_type, int64_t >(
576+ res, knn_graph.extent (0 ), knn_graph.extent (1 ));
560577 }
561578 raft::copy (graph_.data_handle (),
562579 knn_graph.data_handle (),
@@ -565,11 +582,52 @@ struct index : cuvs::neighbors::index {
565582 graph_view_ = graph_.view ();
566583 }
567584
585+ /* *
586+ * Replace the source indices with a new source indices taking the ownership of the passed vector.
587+ */
588+ void update_source_indices (raft::device_vector<index_type, int64_t >&& source_indices)
589+ {
590+ RAFT_EXPECTS (source_indices.extent (0 ) == size (),
591+ " Source indices must have the same number of rows as the index" );
592+ source_indices_.emplace (std::move (source_indices));
593+ }
594+
595+ /* *
596+ * Copy the provided source indices into the index.
597+ */
598+ template <typename Accessor>
599+ void update_source_indices (
600+ raft::resources const & res,
601+ raft::mdspan<const index_type, raft::vector_extent<int64_t >, raft::row_major, Accessor>
602+ source_indices)
603+ {
604+ RAFT_EXPECTS (source_indices.extent (0 ) == size (),
605+ " Source indices must have the same number of rows as the index" );
606+ // Reset the array if it's not compatible to avoid using more memory than necessary.
607+ // NB: this likely is never triggered because we check the invariant above (but it doesn't
608+ // hurt).
609+ if (source_indices_.has_value ()) {
610+ if (source_indices_->extent (0 ) != source_indices.extent (0 )) { source_indices_.reset (); }
611+ }
612+ // Allocate the new array if needed.
613+ if (!source_indices_.has_value ()) {
614+ source_indices_.emplace (
615+ raft::make_device_vector<index_type, int64_t >(res, source_indices.extent (0 )));
616+ }
617+ // Copy the data.
618+ raft::copy (source_indices_->data_handle (),
619+ source_indices.data_handle (),
620+ source_indices.extent (0 ),
621+ raft::resource::get_cuda_stream (res));
622+ }
623+
568624 private:
// Distance metric the index was constructed with.
569625 cuvs::distance::DistanceType metric_;
570- raft::device_matrix<IdxT , int64_t , raft::row_major> graph_;
571- raft::device_matrix_view<const IdxT , int64_t , raft::row_major> graph_view_;
// graph_ is the device storage owned by the index; graph_view_ is what graph() returns.
// The view refers either to graph_ (after the host-side update_graph copy) or to
// caller-owned device memory (after the device-view update_graph overload).
626+ raft::device_matrix<graph_index_type , int64_t , raft::row_major> graph_;
627+ raft::device_matrix_view<const graph_index_type , int64_t , raft::row_major> graph_view_;
// Dataset held through an owning pointer to the polymorphic dataset interface.
572628 std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
629+ // Mapping from internal graph node indices to the original user-provided indices.
630+ std::optional<raft::device_vector<IdxT, int64_t >> source_indices_;
573631 // only float distances supported at the moment
574632 std::optional<raft::device_vector<float , int64_t >> dataset_norms_;
575633
0 commit comments