@@ -467,6 +467,41 @@ void process_and_fill_codes(
467467 RAFT_FAIL (" Invalid pq_bits (%u), the value must be within [4, 16]" , pq_bits);
468468 }
469469 }(pq_bits);
470+ bool need_copy_to_device =
471+ cuvs::spatial::knn::detail::utils::check_pointer_residency (dataset.data_handle ()) ==
472+ cuvs::spatial::knn::detail::utils::pointer_residency::host_only;
473+ bool need_batching = n_rows > kReasonableMaxBatchSize ;
474+ auto launch_work = [&](auto & dataset_view, auto & labels_view, auto & codes_view) {
475+ if (inline_vq_labels || (!vq_labels.empty () && !vq_centers.empty ())) {
476+ predict_vq<label_t >(res, dataset_view, vq_centers, labels_view);
477+ }
478+ dim3 blocks (
479+ raft::div_rounding_up_safe<ix_t >(dataset_view.extent (0 ), kBlockSize / threads_per_vec), 1 , 1 );
480+ kernel<<<blocks, threads, sharedMemorySize, stream>>> (codes_view,
481+ dataset_view,
482+ pq_centers,
483+ vq_centers,
484+ raft::make_const_mdspan (labels_view),
485+ rows_in_shared_memory,
486+ pq_bits,
487+ inline_vq_labels);
488+ RAFT_CUDA_TRY (cudaPeekAtLastError ());
489+ };
490+ auto batch_labels = raft::make_device_vector<label_t , IdxT>(res, 0 );
491+ if (!need_batching && !need_copy_to_device) {
492+ // No batching needed, launch the kernel directly
493+ auto dataset_view = raft::make_device_matrix_view (dataset.data_handle (), n_rows, dim);
494+ auto labels_view = raft::make_device_vector_view<label_t , IdxT>(nullptr , 0 );
495+ if (inline_vq_labels) {
496+ batch_labels = raft::make_device_vector<label_t , IdxT>(res, dataset_view.extent (0 ));
497+ labels_view = batch_labels.view ();
498+ } else if (!vq_labels.empty () && !vq_centers.empty ()) {
499+ labels_view = vq_labels;
500+ }
501+ launch_work (dataset_view, labels_view, codes);
502+ return ;
503+ }
504+
470505 for (const auto & batch : cuvs::spatial::knn::detail::utils::batch_load_iterator (
471506 dataset.data_handle (),
472507 n_rows,
@@ -475,53 +510,20 @@ void process_and_fill_codes(
475510 stream,
476511 rmm::mr::get_current_device_resource ())) {
477512 auto batch_view = raft::make_device_matrix_view (batch.data (), ix_t (batch.size ()), dim);
478- auto batch_labels = raft::make_device_vector<label_t , IdxT>(res, 0 );
479513 auto batch_labels_view = raft::make_device_vector_view<label_t , IdxT>(nullptr , 0 );
480514 if (inline_vq_labels) {
481515 batch_labels = raft::make_device_vector<label_t , IdxT>(res, batch.size ());
482516 batch_labels_view = batch_labels.view ();
483- predict_vq<label_t >(res, batch_view, vq_centers, batch_labels_view);
484- } else {
485- if (!vq_labels.empty () && !vq_centers.empty ()) {
486- batch_labels_view = raft::make_device_vector_view<label_t , IdxT>(
487- vq_labels.data_handle () + batch.offset (), batch.size ());
488- predict_vq<label_t >(res, batch_view, vq_centers, batch_labels_view);
489- }
517+ } else if (!vq_labels.empty () && !vq_centers.empty ()) {
518+ batch_labels_view = raft::make_device_vector_view<label_t , IdxT>(
519+ vq_labels.data_handle () + batch.offset (), batch.size ());
490520 }
491- dim3 blocks (raft::div_rounding_up_safe<ix_t >(n_rows, kBlockSize / threads_per_vec), 1 , 1 );
492- kernel<<<blocks, threads, sharedMemorySize, stream>>> (
493- raft::make_device_matrix_view<uint8_t , IdxT>(
494- codes.data_handle () + batch.offset () * codes_rowlen, batch.size (), codes_rowlen),
495- batch_view,
496- pq_centers,
497- vq_centers,
498- raft::make_const_mdspan (batch_labels_view),
499- rows_in_shared_memory,
500- pq_bits,
501- inline_vq_labels);
502- RAFT_CUDA_TRY (cudaPeekAtLastError ());
521+ auto batch_codes_view = raft::make_device_matrix_view<uint8_t , IdxT>(
522+ codes.data_handle () + batch.offset () * codes_rowlen, batch.size (), codes_rowlen);
523+ launch_work (batch_view, batch_labels_view, batch_codes_view);
503524 }
504525}
505526
506- template <typename NewMathT, typename OldMathT, typename IdxT>
507- auto vpq_convert_math_type (const raft::resources& res, vpq_dataset<OldMathT, IdxT>&& src)
508- -> vpq_dataset<NewMathT, IdxT>
509- {
510- auto vq_code_book = raft::make_device_mdarray<NewMathT>(res, src.vq_code_book .extents ());
511- auto pq_code_book = raft::make_device_mdarray<NewMathT>(res, src.pq_code_book .extents ());
512-
513- raft::linalg::map (res,
514- vq_code_book.view (),
515- cuvs::spatial::knn::detail::utils::mapping<NewMathT>{},
516- raft::make_const_mdspan (src.vq_code_book .view ()));
517- raft::linalg::map (res,
518- pq_code_book.view (),
519- cuvs::spatial::knn::detail::utils::mapping<NewMathT>{},
520- raft::make_const_mdspan (src.pq_code_book .view ()));
521- return vpq_dataset<NewMathT, IdxT>{
522- std::move (vq_code_book), std::move (pq_code_book), std::move (src.data )};
523- }
524-
525527// Helper for operations using vectorized loads of raft::TxN_t
526528template <typename MathT, int VectorSize>
527529struct vec_op : raft::TxN_t<MathT, VectorSize> {
@@ -858,14 +860,40 @@ void process_and_fill_codes_subspaces(
858860 }
859861 }(pq_bits);
860862
861- ix_t max_batch_size = std::min<ix_t >(n_rows, kReasonableMaxBatchSize );
862- auto copy_stream = raft::resource::get_cuda_stream (res); // Using the main stream by default
863- bool enable_prefetch = false ;
864- if (res.has_resource_factory (raft::resource::resource_type::CUDA_STREAM_POOL)) {
865- if (raft::resource::get_stream_pool_size (res) >= 1 ) {
866- enable_prefetch = true ;
867- copy_stream = raft::resource::get_stream_from_stream_pool (res);
863+ ix_t max_batch_size = std::min<ix_t >(n_rows, kReasonableMaxBatchSize );
864+ auto copy_stream = raft::resource::get_cuda_stream (res); // Using the main stream by default
865+ bool enable_prefetch_stream = false ;
866+ bool has_cuda_stream_pool_resource =
867+ res.has_resource_factory (raft::resource::resource_type::CUDA_STREAM_POOL) &&
868+ raft::resource::get_stream_pool_size (res) >= 1 ;
869+ bool need_copy_to_device =
870+ cuvs::spatial::knn::detail::utils::check_pointer_residency (dataset.data_handle ()) ==
871+ cuvs::spatial::knn::detail::utils::pointer_residency::host_only;
872+ bool need_batching = n_rows > kReasonableMaxBatchSize ;
873+ auto launch_work = [&](auto & dataset_view, auto & labels_view, auto & codes_view) {
874+ if (!vq_labels.empty () && !vq_centers.empty ()) {
875+ predict_vq<label_t >(res, dataset_view, vq_centers, labels_view);
868876 }
877+ dim3 blocks (
878+ raft::div_rounding_up_safe<ix_t >(dataset_view.extent (0 ), kBlockSize / threads_per_vec), 1 , 1 );
879+ kernel<<<blocks, threads, shared_memory_size, stream>>> (codes_view,
880+ dataset_view,
881+ pq_centers,
882+ vq_centers,
883+ raft::make_const_mdspan (labels_view),
884+ pq_bits,
885+ shared_memory_size > 0 );
886+ RAFT_CUDA_TRY (cudaPeekAtLastError ());
887+ };
888+ if (!need_batching && !need_copy_to_device) {
889+ // No batching and no copy to device needed, launch the kernel directly
890+ auto dataset_view = raft::make_device_matrix_view (dataset.data_handle (), n_rows, dim);
891+ launch_work (dataset_view, vq_labels, codes);
892+ return ;
893+ }
894+ if (has_cuda_stream_pool_resource && need_copy_to_device) {
895+ enable_prefetch_stream = true ;
896+ copy_stream = raft::resource::get_stream_from_stream_pool (res);
869897 }
870898 auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator (
871899 dataset.data_handle (),
@@ -874,62 +902,22 @@ void process_and_fill_codes_subspaces(
874902 max_batch_size,
875903 copy_stream,
876904 raft::resource::get_workspace_resource (res),
877- enable_prefetch );
905+ enable_prefetch_stream );
878906 vec_batches.prefetch_next_batch ();
879907 for (const auto & batch : vec_batches) {
880908 auto batch_view = raft::make_device_matrix_view (batch.data (), ix_t (batch.size ()), dim);
881909 auto batch_labels = raft::make_device_vector_view<label_t , IdxT>(nullptr , 0 );
882910 if (!vq_labels.empty () && !vq_centers.empty ()) {
883911 batch_labels = raft::make_device_vector_view<label_t , IdxT>(
884912 vq_labels.data_handle () + batch.offset (), batch.size ());
885- predict_vq<label_t >(res, batch_view, vq_centers, batch_labels);
886913 }
887- dim3 blocks (raft::div_rounding_up_safe<ix_t >(batch.size (), kBlockSize / threads_per_vec), 1 , 1 );
888- kernel<<<blocks, threads, shared_memory_size, stream>>> (
889- raft::make_device_matrix_view<uint8_t , IdxT>(
890- codes.data_handle () + batch.offset () * codes_rowlen, batch.size (), codes_rowlen),
891- batch_view,
892- pq_centers,
893- vq_centers,
894- raft::make_const_mdspan (batch_labels),
895- pq_bits,
896- shared_memory_size > 0 );
897- RAFT_CUDA_TRY (cudaPeekAtLastError ());
898- vec_batches.prefetch_next_batch ();
899- raft::resource::sync_stream (res);
914+ auto batch_codes_view = raft::make_device_matrix_view<uint8_t , IdxT>(
915+ codes.data_handle () + batch.offset () * codes_rowlen, batch.size (), codes_rowlen);
916+ launch_work (batch_view, batch_labels, batch_codes_view);
917+ if (enable_prefetch_stream) {
918+ vec_batches.prefetch_next_batch ();
919+ raft::resource::sync_stream (res);
920+ }
900921 }
901922}
902-
903- template <typename DatasetT, typename MathT, typename IdxT>
904- auto vpq_build (const raft::resources& res, const vpq_params& params, const DatasetT& dataset)
905- -> vpq_dataset<MathT, IdxT>
906- {
907- using label_t = uint32_t ;
908- // Use a heuristic to impute missing parameters.
909- auto ps = fill_missing_params_heuristics (params, dataset);
910-
911- // Train codes
912- auto vq_code_book = train_vq<MathT>(res, ps, dataset);
913- auto pq_code_book =
914- train_pq<MathT>(res, ps, dataset, raft::make_const_mdspan (vq_code_book.view ()));
915-
916- // Encode dataset
917- const IdxT n_rows = dataset.extent (0 );
918- const IdxT codes_rowlen = sizeof (label_t ) * (1 + raft::div_rounding_up_safe<IdxT>(
919- ps.pq_dim * ps.pq_bits , 8 * sizeof (label_t )));
920-
921- auto codes = raft::make_device_matrix<uint8_t , IdxT, raft::row_major>(res, n_rows, codes_rowlen);
922- process_and_fill_codes<MathT, IdxT>(res,
923- ps,
924- dataset,
925- raft::make_const_mdspan (pq_code_book.view ()),
926- raft::make_const_mdspan (vq_code_book.view ()),
927- raft::make_device_vector_view<label_t , IdxT>(nullptr , 0 ),
928- codes.view (),
929- true );
930-
931- return vpq_dataset<MathT, IdxT>{
932- std::move (vq_code_book), std::move (pq_code_book), std::move (codes)};
933- }
934-
935923} // namespace cuvs::neighbors::detail
0 commit comments