@@ -68,7 +68,8 @@ __global__ void build_clusters(
6868template <typename LabelT, typename IdxT>
6969void compute_cluster_offsets (raft::resources const & dev_resources,
7070 raft::device_vector_view<const LabelT, IdxT> clusters,
71- raft::device_vector_view<LabelT, int64_t > cluster_sizes)
71+ raft::device_vector_view<LabelT, int64_t > cluster_sizes,
72+ int64_t & max_cluster_size)
7273{
7374 cudaStream_t stream = raft::resource::get_cuda_stream (dev_resources);
7475 rmm::device_async_resource_ref device_memory =
@@ -103,10 +104,30 @@ void compute_cluster_offsets(raft::resources const& dev_resources,
103104 clusters.extent (0 ),
104105 stream);
105106
107+ int num_items = cluster_sizes.extent (0 );
108+
109+ // Compute max cluster size
110+ // auto d_max_cluster_size = rmmo::make_device_scalar<int64_t>(dev_resources, 0);
111+ rmm::device_scalar<int64_t > d_max_cluster_size (stream);
112+
106113 temp_storage_bytes = 0 ;
114+
115+ cub::DeviceReduce::Max (
116+ nullptr , temp_storage_bytes, cluster_sizes.data_handle (), d_max_cluster_size.data (), num_items);
117+
118+ rmm::device_uvector<int64_t > temp_storage_max (temp_storage_bytes, stream, device_memory);
119+
120+ cub::DeviceReduce::Max (temp_storage_max.data (),
121+ temp_storage_bytes,
122+ cluster_sizes.data_handle (),
123+ d_max_cluster_size.data (),
124+ num_items);
125+
126+ max_cluster_size = d_max_cluster_size.value (stream);
127+
107128 // Scan to sum cluster sizes and get cluster start ptrs in flat array
108129 // Done in place
109- int num_items = cluster_sizes. extent ( 0 ) ;
130+ temp_storage_bytes = 0 ;
110131
111132 cub::DeviceScan::ExclusiveSum (nullptr ,
112133 temp_storage_bytes,
@@ -368,6 +389,190 @@ void rescale_avq_centroids(raft::resources const& dev_resources,
368389 });
369390}
370391
392+ /* *
393+ * A class for loading clusters into a compact matrix (sparse gather)
394+ * for use in AVQ.
395+ *
396+ * There are two possible scenarios:
397+ * 1. Dataset is stored in device memory: No host buffers are allocated,
398+ * and the gather is performed on device
399+ * 2. Dataset is stored in host memory: Two pinned buffers are allocated
400+ * in host for fast DtoH copies of cluster ids, and fast HtoD copy of the
401+ * cluster matrix, while amortizing the cost of allocating pinned memory.
402+ * The gather is performed on cpu, overlapping with GPU compute. Copies are
403+ * allocated on the provided stream, allowing for overlapping with
404+ * other work on other streams.
405+ */
406+ template <typename T, typename LabelT>
407+ class cluster_loader {
408+ private:
409+ raft::pinned_matrix<T, int64_t > cluster_buf_;
410+ raft::pinned_vector<LabelT, int64_t > cluster_ids_buf_;
411+ raft::device_matrix<T, int64_t > d_cluster_buf_;
412+ raft::device_matrix<T, int64_t > d_cluster_copy_buf_;
413+ const T* dataset_ptr_;
414+ raft::host_vector_view<const LabelT> h_cluster_offsets_;
415+ raft::device_vector_view<const LabelT> cluster_ids_;
416+ cudaStream_t stream_;
417+ int64_t dim_;
418+ int64_t n_rows_;
419+ bool needs_copy_;
420+
421+ int64_t cur_idx_ = -1 ;
422+ int64_t copy_idx_ = -1 ;
423+
424+ size_t cluster_size (LabelT idx)
425+ {
426+ if (idx + 1 < h_cluster_offsets_.extent (0 )) {
427+ return h_cluster_offsets_ (idx + 1 ) - h_cluster_offsets_ (idx);
428+ }
429+ return n_rows_ - h_cluster_offsets_ (idx);
430+ }
431+
432+ cluster_loader (raft::resources const & res,
433+ const T* dataset_ptr,
434+ int64_t dim,
435+ int64_t n_rows,
436+ int64_t max_cluster_size,
437+ int64_t h_buf_size,
438+ raft::host_vector_view<LabelT> h_cluster_offsets,
439+ raft::device_vector_view<LabelT> cluster_ids,
440+ bool needs_copy,
441+ cudaStream_t stream)
442+ : dim_(dim),
443+ n_rows_ (n_rows),
444+ dataset_ptr_(dataset_ptr),
445+ cluster_buf_(raft::make_pinned_matrix<T, int64_t >(res, h_buf_size, dim)),
446+ cluster_ids_buf_(raft::make_pinned_vector<LabelT, int64_t >(res, h_buf_size)),
447+ d_cluster_buf_(raft::make_device_matrix<T, int64_t >(res, max_cluster_size, dim)),
448+ d_cluster_copy_buf_(raft::make_device_matrix<T, int64_t >(res, max_cluster_size, dim)),
449+ h_cluster_offsets_(h_cluster_offsets),
450+ cluster_ids_(cluster_ids),
451+ needs_copy_(needs_copy),
452+ stream_(stream)
453+ {
454+ }
455+
456+ public:
457+ cluster_loader (raft::resources const & res,
458+ raft::device_matrix_view<const T, int64_t > dataset_view,
459+ raft::host_vector_view<LabelT> h_cluster_offsets,
460+ raft::device_vector_view<LabelT> cluster_ids,
461+ int64_t max_cluster_size,
462+ cudaStream_t stream)
463+ : cluster_loader(res,
464+ dataset_view.data_handle(),
465+ dataset_view.extent(1 ),
466+ dataset_view.extent(0 ),
467+ max_cluster_size,
468+ 0,
469+ h_cluster_offsets,
470+ cluster_ids,
471+ false,
472+ stream)
473+
474+ {
475+ }
476+
477+ cluster_loader (raft::resources const & res,
478+ raft::host_matrix_view<const T, int64_t > dataset_view,
479+ raft::host_vector_view<LabelT> h_cluster_offsets,
480+ raft::device_vector_view<LabelT> cluster_ids,
481+ int64_t max_cluster_size,
482+ cudaStream_t stream)
483+ : cluster_loader(res,
484+ dataset_view.data_handle(),
485+ dataset_view.extent(1 ),
486+ dataset_view.extent(0 ),
487+ max_cluster_size,
488+ max_cluster_size,
489+ h_cluster_offsets,
490+ cluster_ids,
491+ true,
492+ stream)
493+
494+ {
495+ }
496+
497+ /* *
498+ * @brief load and return a view of the provided cluster
499+ *
500+ * @param res: the raft resources;
501+ * @param cluster_idx: the index of the cluster to be loaded
502+ * @return device_matrix_view of the cluster vectors
503+ */
504+ raft::device_matrix_view<T, int64_t > load_cluster (raft::resources const & res, LabelT cluster_idx)
505+ {
506+ size_t size = cluster_size (cluster_idx);
507+
508+ // Check if cluster is already loaded
509+ if (cur_idx_ != cluster_idx) {
510+ // If not, load the cluster
511+ if (copy_idx_ != cluster_idx) { prefetch_cluster (res, cluster_idx); }
512+
513+ // swap buffers
514+ std::swap (d_cluster_buf_, d_cluster_copy_buf_);
515+ std::swap (cur_idx_, copy_idx_);
516+ }
517+
518+ return raft::make_device_matrix_view<T, int64_t >(d_cluster_buf_.data_handle (), size, dim_);
519+ }
520+
521+ /* * @brief Perform gather operation on stream_
522+ *
523+ * @param res: the raft resources
524+ * @param cluster_idx: the index of the cluster
525+ */
526+ void prefetch_cluster (raft::resources const & res, LabelT cluster_idx)
527+ {
528+ if (cluster_idx >= h_cluster_offsets_.extent (0 )) { return ; }
529+
530+ size_t size = cluster_size (cluster_idx);
531+
532+ auto cluster_ids = raft::make_device_vector_view<const LabelT, int64_t >(
533+ cluster_ids_.data_handle () + h_cluster_offsets_ (cluster_idx), size);
534+
535+ auto cluster_vectors =
536+ raft::make_device_matrix_view<float , int64_t >(d_cluster_copy_buf_.data_handle (), size, dim_);
537+
538+ if (needs_copy_) {
539+ // htod
540+ auto h_cluster_ids =
541+ raft::make_pinned_vector_view<LabelT, int64_t >(cluster_ids_buf_.data_handle (), size);
542+
543+ raft::copy (
544+ h_cluster_ids.data_handle (), cluster_ids.data_handle (), cluster_ids.size (), stream_);
545+ raft::resource::sync_stream (res, stream_);
546+
547+ auto pinned_cluster = raft::make_pinned_matrix_view<T, int64_t >(
548+ cluster_buf_.data_handle (), cluster_vectors.extent (0 ), cluster_vectors.extent (1 ));
549+
550+ int n_threads = std::min<int >(omp_get_max_threads (), 32 );
551+ #pragma omp parallel for num_threads(n_threads)
552+ for (int i = 0 ; i < h_cluster_ids.extent (0 ); i++) {
553+ memcpy (pinned_cluster.data_handle () + i * pinned_cluster.extent (1 ),
554+ dataset_ptr_ + h_cluster_ids (i) * dim_,
555+ sizeof (T) * dim_);
556+ }
557+
558+ raft::copy (cluster_vectors.data_handle (),
559+ pinned_cluster.data_handle (),
560+ pinned_cluster.size (),
561+ stream_);
562+ raft::resource::sync_stream (res, stream_);
563+
564+ } else {
565+ // dtod
566+ auto dataset_view =
567+ raft::make_device_matrix_view<const T, int64_t >(dataset_ptr_, n_rows_, dim_);
568+
569+ raft::matrix::gather (res, dataset_view, cluster_ids, cluster_vectors);
570+ }
571+
572+ copy_idx_ = cluster_idx;
573+ }
574+ };
575+
371576/* *
372577 * @brief Perform AVQ adjustment on cluster centers
373578 *
@@ -396,63 +601,63 @@ void apply_avq(raft::resources const& res,
396601 raft::mdspan<const T, raft::matrix_extent<IdxT>, raft::row_major, Accessor> dataset,
397602 raft::device_matrix_view<T, IdxT> centroids_view,
398603 raft::device_vector_view<const LabelT, IdxT> labels_view,
399- float eta)
604+ float eta,
605+ cudaStream_t copy_stream)
400606{
401607 // Compute clusters
402608
403- cudaStream_t stream = raft::resource::get_cuda_stream (res);
404- auto cluster_ptrs = raft::make_device_vector<uint32_t , int64_t >(res, centroids_view.extent (0 ));
405- auto clusters = raft::make_device_vector<uint32_t , int64_t >(res, dataset.extent (0 ));
406-
407- compute_cluster_offsets (res, labels_view, cluster_ptrs.view ());
609+ cudaStream_t stream = raft::resource::get_cuda_stream (res);
610+ auto cluster_offsets = raft::make_device_vector<uint32_t , int64_t >(res, centroids_view.extent (0 ));
611+ auto clusters = raft::make_device_vector<uint32_t , int64_t >(res, dataset.extent (0 ));
612+ int64_t max_cluster_size = 0 ;
408613
409- auto h_cluster_ptrs = raft::make_host_vector<uint32_t , int64_t >(cluster_ptrs.extent (0 ));
614+ compute_cluster_offsets (res, labels_view, cluster_offsets.view (), max_cluster_size);
615+ auto h_cluster_offsets = raft::make_host_vector<uint32_t , int64_t >(cluster_offsets.extent (0 ));
410616
411- raft::copy (h_cluster_ptrs.data_handle (), cluster_ptrs.data_handle (), cluster_ptrs.size (), stream);
617+ raft::copy (
618+ h_cluster_offsets.data_handle (), cluster_offsets.data_handle (), cluster_offsets.size (), stream);
412619
413620 dim3 block (32 , 1 , 1 );
414621 dim3 grid ((dataset.extent (0 ) + block.x - 1 ) / block.x , 1 , 1 );
415622
416623 build_clusters<uint32_t , uint32_t ><<<grid, block>>> (labels_view.data_handle (),
417624 clusters.view ().data_handle (),
418- cluster_ptrs .view ().data_handle (),
625+ cluster_offsets .view ().data_handle (),
419626 dataset.extent (0 ),
420627 labels_view.extent (0 ));
421628 RAFT_CUDA_TRY (cudaPeekAtLastError ());
422629
423630 auto rescale_num = raft::make_device_vector<float , int64_t >(res, centroids_view.extent (0 ));
424631 auto rescale_denom = raft::make_device_vector<float , int64_t >(res, centroids_view.extent (0 ));
425632
633+ cluster_loader<T, LabelT> loader (
634+ res, dataset, h_cluster_offsets.view (), clusters.view (), max_cluster_size, copy_stream);
426635 raft::resource::sync_stream (res);
427636
428637 RAFT_LOG_DEBUG (" Compute AVQ centroids\n " );
429638
430- for (int i = 0 ; i < h_cluster_ptrs.extent (0 ); i++) {
431- int cluster_size = i + 1 < h_cluster_ptrs.extent (0 ) ? h_cluster_ptrs (i + 1 ) - h_cluster_ptrs (i)
432- : dataset.extent (0 ) - h_cluster_ptrs (i);
639+ for (int i = 0 ; i < h_cluster_offsets.extent (0 ); i++) {
640+ auto cluster_vectors = loader.load_cluster (res, i);
433641
434- if (cluster_size == 0 ) { continue ; }
435- auto cluster_ids = raft::make_device_vector_view<const uint32_t , int64_t >(
436- clusters.data_handle () + h_cluster_ptrs (i), cluster_size);
437- auto cluster_vectors =
438- raft::make_device_matrix<float , int64_t >(res, cluster_size, dataset.extent (1 ));
439642 auto avq_centroid = raft::make_device_vector_view<float , int64_t >(
440643 centroids_view.data_handle () + i * dataset.extent (1 ), dataset.extent (1 ));
441644 auto rescale_num_view = raft::make_device_scalar_view<float >(rescale_num.data_handle () + i);
442645 auto rescale_denom_view = raft::make_device_scalar_view<float >(rescale_denom.data_handle () + i);
443646
444- gather_functor<float , uint32_t >{}(
445- res, dataset, cluster_ids, cluster_vectors.view (), raft::resource::get_cuda_stream (res));
446-
447647 compute_avq_centroid (
448- res, cluster_vectors.view (), avq_centroid, rescale_num_view, rescale_denom_view, eta);
648+ res, cluster_vectors, avq_centroid, rescale_num_view, rescale_denom_view, eta);
649+
650+ loader.prefetch_cluster (res, i + 1 );
651+
652+ // make sure work is done before swapping buffers in cluster_loader
653+ raft::resource::sync_stream (res);
449654 }
450655
451656 rescale_avq_centroids (res,
452657 centroids_view,
453658 rescale_num.view (),
454659 rescale_denom.view (),
455- cluster_ptrs .view (),
660+ cluster_offsets .view (),
456661 dataset.extent (0 ));
457662
458663 raft::resource::sync_stream (res);
0 commit comments