Skip to content

Commit 03d62f6

Browse files
authored
ScaNN: Overlapped gather for AVQ (#1286)
Adds a class cluster_loader for AVQ that enables overlapping the gather operation and HtoD copy with GPU computation. There are two scenarios: 1. dataset on device: This is identical to the previous code, using raft::matrix::gather to perform the gather on device. 2. dataset on host: cluster_loader allocates to pinned buffers in host for fast (and possibly async) copies of cluster vectors DtoH. The actual gather operation is performed on cpu, into the pinned buffer. Copies can be overlapped with GPU work (namely AVQ update of the previous cluster) if scheduled on a separate stream. Authors: - https://github.com/rmaschal Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Ben Karsin (https://github.com/bkarsin) URL: #1286
1 parent 2c0e124 commit 03d62f6

2 files changed

Lines changed: 235 additions & 26 deletions

File tree

cpp/src/neighbors/scann/detail/scann_avq.cuh

Lines changed: 229 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ __global__ void build_clusters(
6868
template <typename LabelT, typename IdxT>
6969
void compute_cluster_offsets(raft::resources const& dev_resources,
7070
raft::device_vector_view<const LabelT, IdxT> clusters,
71-
raft::device_vector_view<LabelT, int64_t> cluster_sizes)
71+
raft::device_vector_view<LabelT, int64_t> cluster_sizes,
72+
int64_t& max_cluster_size)
7273
{
7374
cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources);
7475
rmm::device_async_resource_ref device_memory =
@@ -103,10 +104,30 @@ void compute_cluster_offsets(raft::resources const& dev_resources,
103104
clusters.extent(0),
104105
stream);
105106

107+
int num_items = cluster_sizes.extent(0);
108+
109+
// Compute max cluster size
110+
// auto d_max_cluster_size = rmmo::make_device_scalar<int64_t>(dev_resources, 0);
111+
rmm::device_scalar<int64_t> d_max_cluster_size(stream);
112+
106113
temp_storage_bytes = 0;
114+
115+
cub::DeviceReduce::Max(
116+
nullptr, temp_storage_bytes, cluster_sizes.data_handle(), d_max_cluster_size.data(), num_items);
117+
118+
rmm::device_uvector<int64_t> temp_storage_max(temp_storage_bytes, stream, device_memory);
119+
120+
cub::DeviceReduce::Max(temp_storage_max.data(),
121+
temp_storage_bytes,
122+
cluster_sizes.data_handle(),
123+
d_max_cluster_size.data(),
124+
num_items);
125+
126+
max_cluster_size = d_max_cluster_size.value(stream);
127+
107128
// Scan to sum cluster sizes and get cluster start ptrs in flat array
108129
// Done in place
109-
int num_items = cluster_sizes.extent(0);
130+
temp_storage_bytes = 0;
110131

111132
cub::DeviceScan::ExclusiveSum(nullptr,
112133
temp_storage_bytes,
@@ -368,6 +389,190 @@ void rescale_avq_centroids(raft::resources const& dev_resources,
368389
});
369390
}
370391

392+
/**
393+
* A class for loading clusters into a compact matrix (sparse gather)
394+
* for use in AVQ.
395+
*
396+
* There are two possible scenarios:
397+
* 1. Dataset is stored in device memory: No host buffers are allocated,
398+
* and the gather is performed on device
399+
* 2. Dataset is stored in host memory: Two pinned buffers are allocated
400+
* in host for fast DtoH copies of cluster ids, and fast HtoD copy of the
401+
* cluster matrix, while amortizing the cost of allocating pinned memory.
402+
* The gather is performed on cpu, overlapping with GPU compute. Copies are
403+
* allocated on the provided stream, allowing for overlapping with
404+
* other work on other streams.
405+
*/
406+
template <typename T, typename LabelT>
407+
class cluster_loader {
408+
private:
409+
raft::pinned_matrix<T, int64_t> cluster_buf_;
410+
raft::pinned_vector<LabelT, int64_t> cluster_ids_buf_;
411+
raft::device_matrix<T, int64_t> d_cluster_buf_;
412+
raft::device_matrix<T, int64_t> d_cluster_copy_buf_;
413+
const T* dataset_ptr_;
414+
raft::host_vector_view<const LabelT> h_cluster_offsets_;
415+
raft::device_vector_view<const LabelT> cluster_ids_;
416+
cudaStream_t stream_;
417+
int64_t dim_;
418+
int64_t n_rows_;
419+
bool needs_copy_;
420+
421+
int64_t cur_idx_ = -1;
422+
int64_t copy_idx_ = -1;
423+
424+
size_t cluster_size(LabelT idx)
425+
{
426+
if (idx + 1 < h_cluster_offsets_.extent(0)) {
427+
return h_cluster_offsets_(idx + 1) - h_cluster_offsets_(idx);
428+
}
429+
return n_rows_ - h_cluster_offsets_(idx);
430+
}
431+
432+
cluster_loader(raft::resources const& res,
433+
const T* dataset_ptr,
434+
int64_t dim,
435+
int64_t n_rows,
436+
int64_t max_cluster_size,
437+
int64_t h_buf_size,
438+
raft::host_vector_view<LabelT> h_cluster_offsets,
439+
raft::device_vector_view<LabelT> cluster_ids,
440+
bool needs_copy,
441+
cudaStream_t stream)
442+
: dim_(dim),
443+
n_rows_(n_rows),
444+
dataset_ptr_(dataset_ptr),
445+
cluster_buf_(raft::make_pinned_matrix<T, int64_t>(res, h_buf_size, dim)),
446+
cluster_ids_buf_(raft::make_pinned_vector<LabelT, int64_t>(res, h_buf_size)),
447+
d_cluster_buf_(raft::make_device_matrix<T, int64_t>(res, max_cluster_size, dim)),
448+
d_cluster_copy_buf_(raft::make_device_matrix<T, int64_t>(res, max_cluster_size, dim)),
449+
h_cluster_offsets_(h_cluster_offsets),
450+
cluster_ids_(cluster_ids),
451+
needs_copy_(needs_copy),
452+
stream_(stream)
453+
{
454+
}
455+
456+
public:
457+
cluster_loader(raft::resources const& res,
458+
raft::device_matrix_view<const T, int64_t> dataset_view,
459+
raft::host_vector_view<LabelT> h_cluster_offsets,
460+
raft::device_vector_view<LabelT> cluster_ids,
461+
int64_t max_cluster_size,
462+
cudaStream_t stream)
463+
: cluster_loader(res,
464+
dataset_view.data_handle(),
465+
dataset_view.extent(1),
466+
dataset_view.extent(0),
467+
max_cluster_size,
468+
0,
469+
h_cluster_offsets,
470+
cluster_ids,
471+
false,
472+
stream)
473+
474+
{
475+
}
476+
477+
cluster_loader(raft::resources const& res,
478+
raft::host_matrix_view<const T, int64_t> dataset_view,
479+
raft::host_vector_view<LabelT> h_cluster_offsets,
480+
raft::device_vector_view<LabelT> cluster_ids,
481+
int64_t max_cluster_size,
482+
cudaStream_t stream)
483+
: cluster_loader(res,
484+
dataset_view.data_handle(),
485+
dataset_view.extent(1),
486+
dataset_view.extent(0),
487+
max_cluster_size,
488+
max_cluster_size,
489+
h_cluster_offsets,
490+
cluster_ids,
491+
true,
492+
stream)
493+
494+
{
495+
}
496+
497+
/**
498+
* @brief load and return a view of the provided cluster
499+
*
500+
* @param res: the raft resources;
501+
* @param cluster_idx: the index of the cluster to be loaded
502+
* @return device_matrix_view of the cluster vectors
503+
*/
504+
raft::device_matrix_view<T, int64_t> load_cluster(raft::resources const& res, LabelT cluster_idx)
505+
{
506+
size_t size = cluster_size(cluster_idx);
507+
508+
// Check if cluster is already loaded
509+
if (cur_idx_ != cluster_idx) {
510+
// If not, load the cluster
511+
if (copy_idx_ != cluster_idx) { prefetch_cluster(res, cluster_idx); }
512+
513+
// swap buffers
514+
std::swap(d_cluster_buf_, d_cluster_copy_buf_);
515+
std::swap(cur_idx_, copy_idx_);
516+
}
517+
518+
return raft::make_device_matrix_view<T, int64_t>(d_cluster_buf_.data_handle(), size, dim_);
519+
}
520+
521+
/** @brief Perform gather operation on stream_
522+
*
523+
* @param res: the raft resources
524+
* @param cluster_idx: the index of the cluster
525+
*/
526+
void prefetch_cluster(raft::resources const& res, LabelT cluster_idx)
527+
{
528+
if (cluster_idx >= h_cluster_offsets_.extent(0)) { return; }
529+
530+
size_t size = cluster_size(cluster_idx);
531+
532+
auto cluster_ids = raft::make_device_vector_view<const LabelT, int64_t>(
533+
cluster_ids_.data_handle() + h_cluster_offsets_(cluster_idx), size);
534+
535+
auto cluster_vectors =
536+
raft::make_device_matrix_view<float, int64_t>(d_cluster_copy_buf_.data_handle(), size, dim_);
537+
538+
if (needs_copy_) {
539+
// htod
540+
auto h_cluster_ids =
541+
raft::make_pinned_vector_view<LabelT, int64_t>(cluster_ids_buf_.data_handle(), size);
542+
543+
raft::copy(
544+
h_cluster_ids.data_handle(), cluster_ids.data_handle(), cluster_ids.size(), stream_);
545+
raft::resource::sync_stream(res, stream_);
546+
547+
auto pinned_cluster = raft::make_pinned_matrix_view<T, int64_t>(
548+
cluster_buf_.data_handle(), cluster_vectors.extent(0), cluster_vectors.extent(1));
549+
550+
int n_threads = std::min<int>(omp_get_max_threads(), 32);
551+
#pragma omp parallel for num_threads(n_threads)
552+
for (int i = 0; i < h_cluster_ids.extent(0); i++) {
553+
memcpy(pinned_cluster.data_handle() + i * pinned_cluster.extent(1),
554+
dataset_ptr_ + h_cluster_ids(i) * dim_,
555+
sizeof(T) * dim_);
556+
}
557+
558+
raft::copy(cluster_vectors.data_handle(),
559+
pinned_cluster.data_handle(),
560+
pinned_cluster.size(),
561+
stream_);
562+
raft::resource::sync_stream(res, stream_);
563+
564+
} else {
565+
// dtod
566+
auto dataset_view =
567+
raft::make_device_matrix_view<const T, int64_t>(dataset_ptr_, n_rows_, dim_);
568+
569+
raft::matrix::gather(res, dataset_view, cluster_ids, cluster_vectors);
570+
}
571+
572+
copy_idx_ = cluster_idx;
573+
}
574+
};
575+
371576
/**
372577
* @brief Perform AVQ adjustment on cluster centers
373578
*
@@ -396,63 +601,63 @@ void apply_avq(raft::resources const& res,
396601
raft::mdspan<const T, raft::matrix_extent<IdxT>, raft::row_major, Accessor> dataset,
397602
raft::device_matrix_view<T, IdxT> centroids_view,
398603
raft::device_vector_view<const LabelT, IdxT> labels_view,
399-
float eta)
604+
float eta,
605+
cudaStream_t copy_stream)
400606
{
401607
// Compute clusters
402608

403-
cudaStream_t stream = raft::resource::get_cuda_stream(res);
404-
auto cluster_ptrs = raft::make_device_vector<uint32_t, int64_t>(res, centroids_view.extent(0));
405-
auto clusters = raft::make_device_vector<uint32_t, int64_t>(res, dataset.extent(0));
406-
407-
compute_cluster_offsets(res, labels_view, cluster_ptrs.view());
609+
cudaStream_t stream = raft::resource::get_cuda_stream(res);
610+
auto cluster_offsets = raft::make_device_vector<uint32_t, int64_t>(res, centroids_view.extent(0));
611+
auto clusters = raft::make_device_vector<uint32_t, int64_t>(res, dataset.extent(0));
612+
int64_t max_cluster_size = 0;
408613

409-
auto h_cluster_ptrs = raft::make_host_vector<uint32_t, int64_t>(cluster_ptrs.extent(0));
614+
compute_cluster_offsets(res, labels_view, cluster_offsets.view(), max_cluster_size);
615+
auto h_cluster_offsets = raft::make_host_vector<uint32_t, int64_t>(cluster_offsets.extent(0));
410616

411-
raft::copy(h_cluster_ptrs.data_handle(), cluster_ptrs.data_handle(), cluster_ptrs.size(), stream);
617+
raft::copy(
618+
h_cluster_offsets.data_handle(), cluster_offsets.data_handle(), cluster_offsets.size(), stream);
412619

413620
dim3 block(32, 1, 1);
414621
dim3 grid((dataset.extent(0) + block.x - 1) / block.x, 1, 1);
415622

416623
build_clusters<uint32_t, uint32_t><<<grid, block>>>(labels_view.data_handle(),
417624
clusters.view().data_handle(),
418-
cluster_ptrs.view().data_handle(),
625+
cluster_offsets.view().data_handle(),
419626
dataset.extent(0),
420627
labels_view.extent(0));
421628
RAFT_CUDA_TRY(cudaPeekAtLastError());
422629

423630
auto rescale_num = raft::make_device_vector<float, int64_t>(res, centroids_view.extent(0));
424631
auto rescale_denom = raft::make_device_vector<float, int64_t>(res, centroids_view.extent(0));
425632

633+
cluster_loader<T, LabelT> loader(
634+
res, dataset, h_cluster_offsets.view(), clusters.view(), max_cluster_size, copy_stream);
426635
raft::resource::sync_stream(res);
427636

428637
RAFT_LOG_DEBUG("Compute AVQ centroids\n");
429638

430-
for (int i = 0; i < h_cluster_ptrs.extent(0); i++) {
431-
int cluster_size = i + 1 < h_cluster_ptrs.extent(0) ? h_cluster_ptrs(i + 1) - h_cluster_ptrs(i)
432-
: dataset.extent(0) - h_cluster_ptrs(i);
639+
for (int i = 0; i < h_cluster_offsets.extent(0); i++) {
640+
auto cluster_vectors = loader.load_cluster(res, i);
433641

434-
if (cluster_size == 0) { continue; }
435-
auto cluster_ids = raft::make_device_vector_view<const uint32_t, int64_t>(
436-
clusters.data_handle() + h_cluster_ptrs(i), cluster_size);
437-
auto cluster_vectors =
438-
raft::make_device_matrix<float, int64_t>(res, cluster_size, dataset.extent(1));
439642
auto avq_centroid = raft::make_device_vector_view<float, int64_t>(
440643
centroids_view.data_handle() + i * dataset.extent(1), dataset.extent(1));
441644
auto rescale_num_view = raft::make_device_scalar_view<float>(rescale_num.data_handle() + i);
442645
auto rescale_denom_view = raft::make_device_scalar_view<float>(rescale_denom.data_handle() + i);
443646

444-
gather_functor<float, uint32_t>{}(
445-
res, dataset, cluster_ids, cluster_vectors.view(), raft::resource::get_cuda_stream(res));
446-
447647
compute_avq_centroid(
448-
res, cluster_vectors.view(), avq_centroid, rescale_num_view, rescale_denom_view, eta);
648+
res, cluster_vectors, avq_centroid, rescale_num_view, rescale_denom_view, eta);
649+
650+
loader.prefetch_cluster(res, i + 1);
651+
652+
// make sure work is done before swapping buffers in cluster_loader
653+
raft::resource::sync_stream(res);
449654
}
450655

451656
rescale_avq_centroids(res,
452657
centroids_view,
453658
rescale_num.view(),
454659
rescale_denom.view(),
455-
cluster_ptrs.view(),
660+
cluster_offsets.view(),
456661
dataset.extent(0));
457662

458663
raft::resource::sync_stream(res);

cpp/src/neighbors/scann/detail/scann_build.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,12 @@ index<T, IdxT> build(
179179
}
180180

181181
// AVQ update of KMeans centroids
182-
apply_avq(
183-
res, dataset, centroids_view, raft::make_const_mdspan(labels_view), params.partitioning_eta);
182+
apply_avq(res,
183+
dataset,
184+
centroids_view,
185+
raft::make_const_mdspan(labels_view),
186+
params.partitioning_eta,
187+
copy_stream);
184188

185189
raft::device_vector_view<uint32_t, int64_t> soar_labels_view = idx.soar_labels();
186190

0 commit comments

Comments
 (0)