@@ -68,7 +68,8 @@ __global__ void build_clusters(
6868template <typename LabelT, typename IdxT>
6969void compute_cluster_offsets (raft::resources const & dev_resources,
7070 raft::device_vector_view<const LabelT, IdxT> clusters,
71- raft::device_vector_view<LabelT, int64_t > cluster_sizes)
71+ raft::device_vector_view<LabelT, int64_t > cluster_sizes,
72+ raft::host_scalar_view<int64_t > max_cluster_size)
7273{
7374 cudaStream_t stream = raft::resource::get_cuda_stream (dev_resources);
7475 rmm::device_async_resource_ref device_memory =
@@ -103,10 +104,33 @@ void compute_cluster_offsets(raft::resources const& dev_resources,
103104 clusters.extent (0 ),
104105 stream);
105106
106- temp_storage_bytes = 0 ;
107+ int num_items = cluster_sizes.extent (0 );
108+
109+ // Compute max cluster size
110+ auto d_max_cluster_size = raft::make_device_scalar<int64_t >(dev_resources, 0 );
111+ temp_storage_bytes = 0 ;
112+
113+ cub::DeviceReduce::Max (nullptr ,
114+ temp_storage_bytes,
115+ cluster_sizes.data_handle (),
116+ d_max_cluster_size.data_handle (),
117+ num_items);
118+
119+ rmm::device_uvector<int64_t > temp_storage_max (temp_storage_bytes, stream, device_memory);
120+
121+ cub::DeviceReduce::Max (temp_storage_max.data (),
122+ temp_storage_bytes,
123+ cluster_sizes.data_handle (),
124+ d_max_cluster_size.data_handle (),
125+ num_items);
126+
127+ raft::copy (max_cluster_size.data_handle (),
128+ d_max_cluster_size.data_handle (),
129+ d_max_cluster_size.size (),
130+ stream);
107131 // Scan to sum cluster sizes and get cluster start ptrs in flat array
108132 // Done in place
109- int num_items = cluster_sizes. extent ( 0 ) ;
133+ temp_storage_bytes = 0 ;
110134
111135 cub::DeviceScan::ExclusiveSum (nullptr ,
112136 temp_storage_bytes,
@@ -368,6 +392,190 @@ void rescale_avq_centroids(raft::resources const& dev_resources,
368392 });
369393}
370394
395+ /* *
396+ * A class for loading clusters into a compact matrix (sparse gather)
397+ * for use in AVQ.
398+ *
399+ * There are two possible scenarios:
400+ * 1. Dataset is stored in device memory: No host buffers are allocated,
401+ * and the gather is performed on device
402+ * 2. Dataset is stored in host memory: Two pinned buffers are allocated
403+ * in host for fast DtoH copies of cluster ids, and fast HtoD copy of the
404+ * cluster matrix, while amortizing the cost of allocating pinned memory.
405+ * The gather is performed on cpu, overlapping with GPU compute. Copies are
406+ * allocated on the provided stream, allowing for overlapping with
407+ * other work on other streams.
408+ */
409+ template <typename T, typename LabelT>
410+ class cluster_loader {
411+ private:
412+ raft::pinned_matrix<T, int64_t > cluster_buf_;
413+ raft::pinned_vector<LabelT, int64_t > cluster_ids_buf_;
414+ raft::device_matrix<T, int64_t > d_cluster_buf_;
415+ raft::device_matrix<T, int64_t > d_cluster_copy_buf_;
416+ const T* dataset_ptr_;
417+ raft::host_vector_view<const LabelT> h_cluster_offsets_;
418+ raft::device_vector_view<const LabelT> cluster_ids_;
419+ cudaStream_t stream_;
420+ int64_t dim_;
421+ int64_t n_rows_;
422+ bool needs_copy_;
423+
424+ int64_t cur_idx_ = -1 ;
425+ int64_t copy_idx_ = -1 ;
426+
427+ size_t cluster_size (LabelT idx)
428+ {
429+ if (idx + 1 < h_cluster_offsets_.extent (0 )) {
430+ return h_cluster_offsets_ (idx + 1 ) - h_cluster_offsets_ (idx);
431+ }
432+ return n_rows_ - h_cluster_offsets_ (idx);
433+ }
434+
435+ cluster_loader (raft::resources const & res,
436+ const T* dataset_ptr,
437+ int64_t dim,
438+ int64_t n_rows,
439+ int64_t max_cluster_size,
440+ int64_t h_buf_size,
441+ raft::host_vector_view<LabelT> h_cluster_offsets,
442+ raft::device_vector_view<LabelT> cluster_ids,
443+ bool needs_copy,
444+ cudaStream_t stream)
445+ : dim_(dim),
446+ n_rows_ (n_rows),
447+ dataset_ptr_(dataset_ptr),
448+ cluster_buf_(raft::make_pinned_matrix<T, int64_t >(res, h_buf_size, dim)),
449+ cluster_ids_buf_(raft::make_pinned_vector<LabelT, int64_t >(res, h_buf_size)),
450+ d_cluster_buf_(raft::make_device_matrix<T, int64_t >(res, max_cluster_size, dim)),
451+ d_cluster_copy_buf_(raft::make_device_matrix<T, int64_t >(res, max_cluster_size, dim)),
452+ h_cluster_offsets_(h_cluster_offsets),
453+ cluster_ids_(cluster_ids),
454+ needs_copy_(needs_copy),
455+ stream_(stream)
456+ {
457+ }
458+
459+ public:
460+ cluster_loader (raft::resources const & res,
461+ raft::device_matrix_view<const T, int64_t > dataset_view,
462+ raft::host_vector_view<LabelT> h_cluster_offsets,
463+ raft::device_vector_view<LabelT> cluster_ids,
464+ int64_t max_cluster_size,
465+ cudaStream_t stream)
466+ : cluster_loader(res,
467+ dataset_view.data_handle(),
468+ dataset_view.extent(1 ),
469+ dataset_view.extent(0 ),
470+ max_cluster_size,
471+ 0,
472+ h_cluster_offsets,
473+ cluster_ids,
474+ false,
475+ stream)
476+
477+ {
478+ }
479+
480+ cluster_loader (raft::resources const & res,
481+ raft::host_matrix_view<const T, int64_t > dataset_view,
482+ raft::host_vector_view<LabelT> h_cluster_offsets,
483+ raft::device_vector_view<LabelT> cluster_ids,
484+ int64_t max_cluster_size,
485+ cudaStream_t stream)
486+ : cluster_loader(res,
487+ dataset_view.data_handle(),
488+ dataset_view.extent(1 ),
489+ dataset_view.extent(0 ),
490+ max_cluster_size,
491+ max_cluster_size,
492+ h_cluster_offsets,
493+ cluster_ids,
494+ true,
495+ stream)
496+
497+ {
498+ }
499+
500+ /* *
501+ * @brief load and return a view of the provided cluster
502+ *
503+ * @param res: the raft resources;
504+ * @param cluster_idx: the index of the cluster to be loaded
505+ * @return device_matrix_view of the cluster vectors
506+ */
507+ raft::device_matrix_view<T, int64_t > load_cluster (raft::resources const & res, LabelT cluster_idx)
508+ {
509+ size_t size = cluster_size (cluster_idx);
510+
511+ // Check if cluster is already loaded
512+ if (cur_idx_ != cluster_idx) {
513+ // If not, load the cluster
514+ if (copy_idx_ != cluster_idx) { prefetch_cluster (res, cluster_idx); }
515+
516+ // swap buffers
517+ std::swap (d_cluster_buf_, d_cluster_copy_buf_);
518+ std::swap (cur_idx_, copy_idx_);
519+ }
520+
521+ return raft::make_device_matrix_view<T, int64_t >(d_cluster_buf_.data_handle (), size, dim_);
522+ }
523+
524+ /* * @brief Perform gather operation on stream_
525+ *
526+ * @param res: the raft resources
527+ * @param cluster_idx: the index of the cluster
528+ */
529+ void prefetch_cluster (raft::resources const & res, LabelT cluster_idx)
530+ {
531+ if (cluster_idx >= h_cluster_offsets_.extent (0 )) { return ; }
532+
533+ size_t size = cluster_size (cluster_idx);
534+
535+ auto cluster_ids = raft::make_device_vector_view<const LabelT, int64_t >(
536+ cluster_ids_.data_handle () + h_cluster_offsets_ (cluster_idx), size);
537+
538+ auto cluster_vectors =
539+ raft::make_device_matrix_view<float , int64_t >(d_cluster_copy_buf_.data_handle (), size, dim_);
540+
541+ if (needs_copy_) {
542+ // htod
543+ auto h_cluster_ids =
544+ raft::make_pinned_vector_view<LabelT, int64_t >(cluster_ids_buf_.data_handle (), size);
545+
546+ raft::copy (
547+ h_cluster_ids.data_handle (), cluster_ids.data_handle (), cluster_ids.size (), stream_);
548+ raft::resource::sync_stream (res, stream_);
549+
550+ auto pinned_cluster = raft::make_pinned_matrix_view<T, int64_t >(
551+ cluster_buf_.data_handle (), cluster_vectors.extent (0 ), cluster_vectors.extent (1 ));
552+
553+ int n_threads = std::min<int >(omp_get_max_threads (), 32 );
554+ #pragma omp parallel for num_threads(n_threads)
555+ for (int i = 0 ; i < h_cluster_ids.extent (0 ); i++) {
556+ memcpy (pinned_cluster.data_handle () + i * pinned_cluster.extent (1 ),
557+ dataset_ptr_ + h_cluster_ids (i) * dim_,
558+ sizeof (T) * dim_);
559+ }
560+
561+ raft::copy (cluster_vectors.data_handle (),
562+ pinned_cluster.data_handle (),
563+ pinned_cluster.size (),
564+ stream_);
565+ raft::resource::sync_stream (res, stream_);
566+
567+ } else {
568+ // dtod
569+ auto dataset_view =
570+ raft::make_device_matrix_view<const T, int64_t >(dataset_ptr_, n_rows_, dim_);
571+
572+ raft::matrix::gather (res, dataset_view, cluster_ids, cluster_vectors);
573+ }
574+
575+ copy_idx_ = cluster_idx;
576+ }
577+ };
578+
371579/* *
372580 * @brief Perform AVQ adjustment on cluster centers
373581 *
@@ -396,63 +604,63 @@ void apply_avq(raft::resources const& res,
396604 raft::mdspan<const T, raft::matrix_extent<IdxT>, raft::row_major, Accessor> dataset,
397605 raft::device_matrix_view<T, IdxT> centroids_view,
398606 raft::device_vector_view<const LabelT, IdxT> labels_view,
399- float eta)
607+ float eta,
608+ cudaStream_t gather_stream)
400609{
401610 // Compute clusters
402611
403- cudaStream_t stream = raft::resource::get_cuda_stream (res);
404- auto cluster_ptrs = raft::make_device_vector<uint32_t , int64_t >(res, centroids_view.extent (0 ));
405- auto clusters = raft::make_device_vector<uint32_t , int64_t >(res, dataset.extent (0 ));
406-
407- compute_cluster_offsets (res, labels_view, cluster_ptrs.view ());
612+ cudaStream_t stream = raft::resource::get_cuda_stream (res);
613+ auto cluster_offsets = raft::make_device_vector<uint32_t , int64_t >(res, centroids_view.extent (0 ));
614+ auto clusters = raft::make_device_vector<uint32_t , int64_t >(res, dataset.extent (0 ));
615+ auto max_cluster_size = raft::make_host_scalar<int64_t >(0 );
408616
409- auto h_cluster_ptrs = raft::make_host_vector<uint32_t , int64_t >(cluster_ptrs.extent (0 ));
617+ compute_cluster_offsets (res, labels_view, cluster_offsets.view (), max_cluster_size.view ());
618+ auto h_cluster_offsets = raft::make_host_vector<uint32_t , int64_t >(cluster_offsets.extent (0 ));
410619
411- raft::copy (h_cluster_ptrs.data_handle (), cluster_ptrs.data_handle (), cluster_ptrs.size (), stream);
620+ raft::copy (
621+ h_cluster_offsets.data_handle (), cluster_offsets.data_handle (), cluster_offsets.size (), stream);
412622
413623 dim3 block (32 , 1 , 1 );
414624 dim3 grid ((dataset.extent (0 ) + block.x - 1 ) / block.x , 1 , 1 );
415625
416626 build_clusters<uint32_t , uint32_t ><<<grid, block>>> (labels_view.data_handle (),
417627 clusters.view ().data_handle (),
418- cluster_ptrs .view ().data_handle (),
628+ cluster_offsets .view ().data_handle (),
419629 dataset.extent (0 ),
420630 labels_view.extent (0 ));
421631 RAFT_CUDA_TRY (cudaPeekAtLastError ());
422632
423633 auto rescale_num = raft::make_device_vector<float , int64_t >(res, centroids_view.extent (0 ));
424634 auto rescale_denom = raft::make_device_vector<float , int64_t >(res, centroids_view.extent (0 ));
425635
636+ cluster_loader<T, LabelT> loader (
637+ res, dataset, h_cluster_offsets.view (), clusters.view (), max_cluster_size (0 ), gather_stream);
426638 raft::resource::sync_stream (res);
427639
428640 RAFT_LOG_DEBUG (" Compute AVQ centroids\n " );
429641
430- for (int i = 0 ; i < h_cluster_ptrs.extent (0 ); i++) {
431- int cluster_size = i + 1 < h_cluster_ptrs.extent (0 ) ? h_cluster_ptrs (i + 1 ) - h_cluster_ptrs (i)
432- : dataset.extent (0 ) - h_cluster_ptrs (i);
642+ for (int i = 0 ; i < h_cluster_offsets.extent (0 ); i++) {
643+ auto cluster_vectors = loader.load_cluster (res, i);
433644
434- if (cluster_size == 0 ) { continue ; }
435- auto cluster_ids = raft::make_device_vector_view<const uint32_t , int64_t >(
436- clusters.data_handle () + h_cluster_ptrs (i), cluster_size);
437- auto cluster_vectors =
438- raft::make_device_matrix<float , int64_t >(res, cluster_size, dataset.extent (1 ));
439645 auto avq_centroid = raft::make_device_vector_view<float , int64_t >(
440646 centroids_view.data_handle () + i * dataset.extent (1 ), dataset.extent (1 ));
441647 auto rescale_num_view = raft::make_device_scalar_view<float >(rescale_num.data_handle () + i);
442648 auto rescale_denom_view = raft::make_device_scalar_view<float >(rescale_denom.data_handle () + i);
443649
444- gather_functor<float , uint32_t >{}(
445- res, dataset, cluster_ids, cluster_vectors.view (), raft::resource::get_cuda_stream (res));
446-
447650 compute_avq_centroid (
448- res, cluster_vectors.view (), avq_centroid, rescale_num_view, rescale_denom_view, eta);
651+ res, cluster_vectors, avq_centroid, rescale_num_view, rescale_denom_view, eta);
652+
653+ loader.prefetch_cluster (res, i + 1 );
654+
655+ // make sure work is done before swapping buffers in cluster_loader
656+ raft::resource::sync_stream (res);
449657 }
450658
451659 rescale_avq_centroids (res,
452660 centroids_view,
453661 rescale_num.view (),
454662 rescale_denom.view (),
455- cluster_ptrs .view (),
663+ cluster_offsets .view (),
456664 dataset.extent (0 ));
457665
458666 raft::resource::sync_stream (res);
0 commit comments