Skip to content

Commit c225698

Browse files
committed
Overlapped gather for AVQ
1 parent 2c0e124 commit c225698

2 files changed

Lines changed: 239 additions & 27 deletions

File tree

cpp/src/neighbors/scann/detail/scann_avq.cuh

Lines changed: 233 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ __global__ void build_clusters(
6868
template <typename LabelT, typename IdxT>
6969
void compute_cluster_offsets(raft::resources const& dev_resources,
7070
raft::device_vector_view<const LabelT, IdxT> clusters,
71-
raft::device_vector_view<LabelT, int64_t> cluster_sizes)
71+
raft::device_vector_view<LabelT, int64_t> cluster_sizes,
72+
raft::host_scalar_view<int64_t> max_cluster_size)
7273
{
7374
cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources);
7475
rmm::device_async_resource_ref device_memory =
@@ -103,10 +104,33 @@ void compute_cluster_offsets(raft::resources const& dev_resources,
103104
clusters.extent(0),
104105
stream);
105106

106-
temp_storage_bytes = 0;
107+
int num_items = cluster_sizes.extent(0);
108+
109+
// Compute max cluster size
110+
auto d_max_cluster_size = raft::make_device_scalar<int64_t>(dev_resources, 0);
111+
temp_storage_bytes = 0;
112+
113+
cub::DeviceReduce::Max(nullptr,
114+
temp_storage_bytes,
115+
cluster_sizes.data_handle(),
116+
d_max_cluster_size.data_handle(),
117+
num_items);
118+
119+
rmm::device_uvector<int64_t> temp_storage_max(temp_storage_bytes, stream, device_memory);
120+
121+
cub::DeviceReduce::Max(temp_storage_max.data(),
122+
temp_storage_bytes,
123+
cluster_sizes.data_handle(),
124+
d_max_cluster_size.data_handle(),
125+
num_items);
126+
127+
raft::copy(max_cluster_size.data_handle(),
128+
d_max_cluster_size.data_handle(),
129+
d_max_cluster_size.size(),
130+
stream);
107131
// Scan to sum cluster sizes and get cluster start ptrs in flat array
108132
// Done in place
109-
int num_items = cluster_sizes.extent(0);
133+
temp_storage_bytes = 0;
110134

111135
cub::DeviceScan::ExclusiveSum(nullptr,
112136
temp_storage_bytes,
@@ -368,6 +392,190 @@ void rescale_avq_centroids(raft::resources const& dev_resources,
368392
});
369393
}
370394

395+
/**
396+
* A class for loading clusters into a compact matrix (sparse gather)
397+
* for use in AVQ.
398+
*
399+
* There are two possible scenarios:
400+
* 1. Dataset is stored in device memory: No host buffers are allocated,
401+
* and the gather is performed on device
402+
* 2. Dataset is stored in host memory: Two pinned buffers are allocated
403+
* in host for fast DtoH copies of cluster ids, and fast HtoD copy of the
404+
* cluster matrix, while amortizing the cost of allocating pinned memory.
405+
* The gather is performed on cpu, overlapping with GPU compute. Copies are
406+
* allocated on the provided stream, allowing for overlapping with
407+
* other work on other streams.
408+
*/
409+
template <typename T, typename LabelT>
410+
class cluster_loader {
411+
private:
412+
raft::pinned_matrix<T, int64_t> cluster_buf_;
413+
raft::pinned_vector<LabelT, int64_t> cluster_ids_buf_;
414+
raft::device_matrix<T, int64_t> d_cluster_buf_;
415+
raft::device_matrix<T, int64_t> d_cluster_copy_buf_;
416+
const T* dataset_ptr_;
417+
raft::host_vector_view<const LabelT> h_cluster_offsets_;
418+
raft::device_vector_view<const LabelT> cluster_ids_;
419+
cudaStream_t stream_;
420+
int64_t dim_;
421+
int64_t n_rows_;
422+
bool needs_copy_;
423+
424+
int64_t cur_idx_ = -1;
425+
int64_t copy_idx_ = -1;
426+
427+
size_t cluster_size(LabelT idx)
428+
{
429+
if (idx + 1 < h_cluster_offsets_.extent(0)) {
430+
return h_cluster_offsets_(idx + 1) - h_cluster_offsets_(idx);
431+
}
432+
return n_rows_ - h_cluster_offsets_(idx);
433+
}
434+
435+
cluster_loader(raft::resources const& res,
436+
const T* dataset_ptr,
437+
int64_t dim,
438+
int64_t n_rows,
439+
int64_t max_cluster_size,
440+
int64_t h_buf_size,
441+
raft::host_vector_view<LabelT> h_cluster_offsets,
442+
raft::device_vector_view<LabelT> cluster_ids,
443+
bool needs_copy,
444+
cudaStream_t stream)
445+
: dim_(dim),
446+
n_rows_(n_rows),
447+
dataset_ptr_(dataset_ptr),
448+
cluster_buf_(raft::make_pinned_matrix<T, int64_t>(res, h_buf_size, dim)),
449+
cluster_ids_buf_(raft::make_pinned_vector<LabelT, int64_t>(res, h_buf_size)),
450+
d_cluster_buf_(raft::make_device_matrix<T, int64_t>(res, max_cluster_size, dim)),
451+
d_cluster_copy_buf_(raft::make_device_matrix<T, int64_t>(res, max_cluster_size, dim)),
452+
h_cluster_offsets_(h_cluster_offsets),
453+
cluster_ids_(cluster_ids),
454+
needs_copy_(needs_copy),
455+
stream_(stream)
456+
{
457+
}
458+
459+
public:
460+
cluster_loader(raft::resources const& res,
461+
raft::device_matrix_view<const T, int64_t> dataset_view,
462+
raft::host_vector_view<LabelT> h_cluster_offsets,
463+
raft::device_vector_view<LabelT> cluster_ids,
464+
int64_t max_cluster_size,
465+
cudaStream_t stream)
466+
: cluster_loader(res,
467+
dataset_view.data_handle(),
468+
dataset_view.extent(1),
469+
dataset_view.extent(0),
470+
max_cluster_size,
471+
0,
472+
h_cluster_offsets,
473+
cluster_ids,
474+
false,
475+
stream)
476+
477+
{
478+
}
479+
480+
cluster_loader(raft::resources const& res,
481+
raft::host_matrix_view<const T, int64_t> dataset_view,
482+
raft::host_vector_view<LabelT> h_cluster_offsets,
483+
raft::device_vector_view<LabelT> cluster_ids,
484+
int64_t max_cluster_size,
485+
cudaStream_t stream)
486+
: cluster_loader(res,
487+
dataset_view.data_handle(),
488+
dataset_view.extent(1),
489+
dataset_view.extent(0),
490+
max_cluster_size,
491+
max_cluster_size,
492+
h_cluster_offsets,
493+
cluster_ids,
494+
true,
495+
stream)
496+
497+
{
498+
}
499+
500+
/**
501+
* @brief load and return a view of the provided cluster
502+
*
503+
* @param res: the raft resources;
504+
* @param cluster_idx: the index of the cluster to be loaded
505+
* @return device_matrix_view of the cluster vectors
506+
*/
507+
raft::device_matrix_view<T, int64_t> load_cluster(raft::resources const& res, LabelT cluster_idx)
508+
{
509+
size_t size = cluster_size(cluster_idx);
510+
511+
// Check if cluster is already loaded
512+
if (cur_idx_ != cluster_idx) {
513+
// If not, load the cluster
514+
if (copy_idx_ != cluster_idx) { prefetch_cluster(res, cluster_idx); }
515+
516+
// swap buffers
517+
std::swap(d_cluster_buf_, d_cluster_copy_buf_);
518+
std::swap(cur_idx_, copy_idx_);
519+
}
520+
521+
return raft::make_device_matrix_view<T, int64_t>(d_cluster_buf_.data_handle(), size, dim_);
522+
}
523+
524+
/** @brief Perform gather operation on stream_
525+
*
526+
* @param res: the raft resources
527+
* @param cluster_idx: the index of the cluster
528+
*/
529+
void prefetch_cluster(raft::resources const& res, LabelT cluster_idx)
530+
{
531+
if (cluster_idx >= h_cluster_offsets_.extent(0)) { return; }
532+
533+
size_t size = cluster_size(cluster_idx);
534+
535+
auto cluster_ids = raft::make_device_vector_view<const LabelT, int64_t>(
536+
cluster_ids_.data_handle() + h_cluster_offsets_(cluster_idx), size);
537+
538+
auto cluster_vectors =
539+
raft::make_device_matrix_view<float, int64_t>(d_cluster_copy_buf_.data_handle(), size, dim_);
540+
541+
if (needs_copy_) {
542+
// htod
543+
auto h_cluster_ids =
544+
raft::make_pinned_vector_view<LabelT, int64_t>(cluster_ids_buf_.data_handle(), size);
545+
546+
raft::copy(
547+
h_cluster_ids.data_handle(), cluster_ids.data_handle(), cluster_ids.size(), stream_);
548+
raft::resource::sync_stream(res, stream_);
549+
550+
auto pinned_cluster = raft::make_pinned_matrix_view<T, int64_t>(
551+
cluster_buf_.data_handle(), cluster_vectors.extent(0), cluster_vectors.extent(1));
552+
553+
int n_threads = std::min<int>(omp_get_max_threads(), 32);
554+
#pragma omp parallel for num_threads(n_threads)
555+
for (int i = 0; i < h_cluster_ids.extent(0); i++) {
556+
memcpy(pinned_cluster.data_handle() + i * pinned_cluster.extent(1),
557+
dataset_ptr_ + h_cluster_ids(i) * dim_,
558+
sizeof(T) * dim_);
559+
}
560+
561+
raft::copy(cluster_vectors.data_handle(),
562+
pinned_cluster.data_handle(),
563+
pinned_cluster.size(),
564+
stream_);
565+
raft::resource::sync_stream(res, stream_);
566+
567+
} else {
568+
// dtod
569+
auto dataset_view =
570+
raft::make_device_matrix_view<const T, int64_t>(dataset_ptr_, n_rows_, dim_);
571+
572+
raft::matrix::gather(res, dataset_view, cluster_ids, cluster_vectors);
573+
}
574+
575+
copy_idx_ = cluster_idx;
576+
}
577+
};
578+
371579
/**
372580
* @brief Perform AVQ adjustment on cluster centers
373581
*
@@ -396,63 +604,63 @@ void apply_avq(raft::resources const& res,
396604
raft::mdspan<const T, raft::matrix_extent<IdxT>, raft::row_major, Accessor> dataset,
397605
raft::device_matrix_view<T, IdxT> centroids_view,
398606
raft::device_vector_view<const LabelT, IdxT> labels_view,
399-
float eta)
607+
float eta,
608+
cudaStream_t gather_stream)
400609
{
401610
// Compute clusters
402611

403-
cudaStream_t stream = raft::resource::get_cuda_stream(res);
404-
auto cluster_ptrs = raft::make_device_vector<uint32_t, int64_t>(res, centroids_view.extent(0));
405-
auto clusters = raft::make_device_vector<uint32_t, int64_t>(res, dataset.extent(0));
406-
407-
compute_cluster_offsets(res, labels_view, cluster_ptrs.view());
612+
cudaStream_t stream = raft::resource::get_cuda_stream(res);
613+
auto cluster_offsets = raft::make_device_vector<uint32_t, int64_t>(res, centroids_view.extent(0));
614+
auto clusters = raft::make_device_vector<uint32_t, int64_t>(res, dataset.extent(0));
615+
auto max_cluster_size = raft::make_host_scalar<int64_t>(0);
408616

409-
auto h_cluster_ptrs = raft::make_host_vector<uint32_t, int64_t>(cluster_ptrs.extent(0));
617+
compute_cluster_offsets(res, labels_view, cluster_offsets.view(), max_cluster_size.view());
618+
auto h_cluster_offsets = raft::make_host_vector<uint32_t, int64_t>(cluster_offsets.extent(0));
410619

411-
raft::copy(h_cluster_ptrs.data_handle(), cluster_ptrs.data_handle(), cluster_ptrs.size(), stream);
620+
raft::copy(
621+
h_cluster_offsets.data_handle(), cluster_offsets.data_handle(), cluster_offsets.size(), stream);
412622

413623
dim3 block(32, 1, 1);
414624
dim3 grid((dataset.extent(0) + block.x - 1) / block.x, 1, 1);
415625

416626
build_clusters<uint32_t, uint32_t><<<grid, block>>>(labels_view.data_handle(),
417627
clusters.view().data_handle(),
418-
cluster_ptrs.view().data_handle(),
628+
cluster_offsets.view().data_handle(),
419629
dataset.extent(0),
420630
labels_view.extent(0));
421631
RAFT_CUDA_TRY(cudaPeekAtLastError());
422632

423633
auto rescale_num = raft::make_device_vector<float, int64_t>(res, centroids_view.extent(0));
424634
auto rescale_denom = raft::make_device_vector<float, int64_t>(res, centroids_view.extent(0));
425635

636+
cluster_loader<T, LabelT> loader(
637+
res, dataset, h_cluster_offsets.view(), clusters.view(), max_cluster_size(0), gather_stream);
426638
raft::resource::sync_stream(res);
427639

428640
RAFT_LOG_DEBUG("Compute AVQ centroids\n");
429641

430-
for (int i = 0; i < h_cluster_ptrs.extent(0); i++) {
431-
int cluster_size = i + 1 < h_cluster_ptrs.extent(0) ? h_cluster_ptrs(i + 1) - h_cluster_ptrs(i)
432-
: dataset.extent(0) - h_cluster_ptrs(i);
642+
for (int i = 0; i < h_cluster_offsets.extent(0); i++) {
643+
auto cluster_vectors = loader.load_cluster(res, i);
433644

434-
if (cluster_size == 0) { continue; }
435-
auto cluster_ids = raft::make_device_vector_view<const uint32_t, int64_t>(
436-
clusters.data_handle() + h_cluster_ptrs(i), cluster_size);
437-
auto cluster_vectors =
438-
raft::make_device_matrix<float, int64_t>(res, cluster_size, dataset.extent(1));
439645
auto avq_centroid = raft::make_device_vector_view<float, int64_t>(
440646
centroids_view.data_handle() + i * dataset.extent(1), dataset.extent(1));
441647
auto rescale_num_view = raft::make_device_scalar_view<float>(rescale_num.data_handle() + i);
442648
auto rescale_denom_view = raft::make_device_scalar_view<float>(rescale_denom.data_handle() + i);
443649

444-
gather_functor<float, uint32_t>{}(
445-
res, dataset, cluster_ids, cluster_vectors.view(), raft::resource::get_cuda_stream(res));
446-
447650
compute_avq_centroid(
448-
res, cluster_vectors.view(), avq_centroid, rescale_num_view, rescale_denom_view, eta);
651+
res, cluster_vectors, avq_centroid, rescale_num_view, rescale_denom_view, eta);
652+
653+
loader.prefetch_cluster(res, i + 1);
654+
655+
// make sure work is done before swapping buffers in cluster_loader
656+
raft::resource::sync_stream(res);
449657
}
450658

451659
rescale_avq_centroids(res,
452660
centroids_view,
453661
rescale_num.view(),
454662
rescale_denom.view(),
455-
cluster_ptrs.view(),
663+
cluster_offsets.view(),
456664
dataset.extent(0));
457665

458666
raft::resource::sync_stream(res);

cpp/src/neighbors/scann/detail/scann_build.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,12 @@ index<T, IdxT> build(
179179
}
180180

181181
// AVQ update of KMeans centroids
182-
apply_avq(
183-
res, dataset, centroids_view, raft::make_const_mdspan(labels_view), params.partitioning_eta);
182+
apply_avq(res,
183+
dataset,
184+
centroids_view,
185+
raft::make_const_mdspan(labels_view),
186+
params.partitioning_eta,
187+
copy_stream);
184188

185189
raft::device_vector_view<uint32_t, int64_t> soar_labels_view = idx.soar_labels();
186190

0 commit comments

Comments
 (0)