Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 233 additions & 25 deletions cpp/src/neighbors/scann/detail/scann_avq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ __global__ void build_clusters(
template <typename LabelT, typename IdxT>
void compute_cluster_offsets(raft::resources const& dev_resources,
raft::device_vector_view<const LabelT, IdxT> clusters,
raft::device_vector_view<LabelT, int64_t> cluster_sizes)
raft::device_vector_view<LabelT, int64_t> cluster_sizes,
raft::host_scalar_view<int64_t> max_cluster_size)
{
cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources);
rmm::device_async_resource_ref device_memory =
Expand Down Expand Up @@ -103,10 +104,33 @@ void compute_cluster_offsets(raft::resources const& dev_resources,
clusters.extent(0),
stream);

temp_storage_bytes = 0;
int num_items = cluster_sizes.extent(0);

// Compute max cluster size
auto d_max_cluster_size = raft::make_device_scalar<int64_t>(dev_resources, 0);
temp_storage_bytes = 0;

cub::DeviceReduce::Max(nullptr,
temp_storage_bytes,
cluster_sizes.data_handle(),
d_max_cluster_size.data_handle(),
num_items);

rmm::device_uvector<int64_t> temp_storage_max(temp_storage_bytes, stream, device_memory);

cub::DeviceReduce::Max(temp_storage_max.data(),
temp_storage_bytes,
cluster_sizes.data_handle(),
d_max_cluster_size.data_handle(),
num_items);

raft::copy(max_cluster_size.data_handle(),
Comment thread
tfeher marked this conversation as resolved.
Outdated
d_max_cluster_size.data_handle(),
d_max_cluster_size.size(),
stream);
// Scan to sum cluster sizes and get cluster start ptrs in flat array
// Done in place
int num_items = cluster_sizes.extent(0);
temp_storage_bytes = 0;

cub::DeviceScan::ExclusiveSum(nullptr,
temp_storage_bytes,
Expand Down Expand Up @@ -368,6 +392,190 @@ void rescale_avq_centroids(raft::resources const& dev_resources,
});
}

/**
 * A class for loading clusters into a compact matrix (sparse gather)
 * for use in AVQ.
 *
 * Two device buffers of `max_cluster_size x dim` are kept: one holds the
 * cluster most recently returned by load_cluster(), the other is the target
 * of prefetch_cluster(). load_cluster() swaps them when the requested cluster
 * is not the current one.
 *
 * There are two possible scenarios:
 * 1. Dataset is stored in device memory: the pinned host buffers are created
 *    with zero rows, and the gather is performed on device via
 *    raft::matrix::gather using the `res` passed to the call.
 * 2. Dataset is stored in host memory: two pinned host buffers are allocated
 *    once at construction (amortizing the cost of allocating pinned memory),
 *    one for the DtoH copy of the cluster's row ids and one for staging the
 *    gathered rows before the HtoD copy. The gather is performed on the CPU,
 *    which lets it overlap with GPU compute already enqueued by the caller.
 *    The copies are issued on the stream provided at construction, allowing
 *    for overlap with work on other streams.
 *
 * @tparam T      element type of the dataset and of the cluster buffers
 * @tparam LabelT integer type of cluster offsets and row ids
 */
template <typename T, typename LabelT>
class cluster_loader {
 private:
  // Pinned host staging buffer for gathered rows (zero rows for a device dataset).
  raft::pinned_matrix<T, int64_t> cluster_buf_;
  // Pinned host buffer receiving the row ids of the cluster being gathered.
  raft::pinned_vector<LabelT, int64_t> cluster_ids_buf_;
  // Device buffer backing the view returned by load_cluster().
  raft::device_matrix<T, int64_t> d_cluster_buf_;
  // Device buffer written by prefetch_cluster(); swapped in by load_cluster().
  raft::device_matrix<T, int64_t> d_cluster_copy_buf_;
  // Raw pointer to the dataset; host or device memory depending on needs_copy_.
  const T* dataset_ptr_;
  // Host copy of the per-cluster start offsets into cluster_ids_.
  raft::host_vector_view<const LabelT> h_cluster_offsets_;
  // Device array of dataset row ids, grouped by cluster.
  raft::device_vector_view<const LabelT> cluster_ids_;
  // Stream used for the DtoH / HtoD copies of the host-dataset path.
  cudaStream_t stream_;
  int64_t dim_;
  int64_t n_rows_;
  // True when the dataset lives in host memory and must be staged through pinned buffers.
  bool needs_copy_;

  // Index of the cluster held in d_cluster_buf_ (-1: none).
  int64_t cur_idx_ = -1;
  // Index of the cluster held in d_cluster_copy_buf_ (-1: none).
  int64_t copy_idx_ = -1;

  /**
   * @brief Number of rows in cluster `idx`.
   *
   * The size is the distance to the next offset; the last cluster extends to n_rows_.
   * `idx` must be a valid index into h_cluster_offsets_ (no bounds check is done here).
   */
  size_t cluster_size(LabelT idx)
  {
    if (idx + 1 < h_cluster_offsets_.extent(0)) {
      return h_cluster_offsets_(idx + 1) - h_cluster_offsets_(idx);
    }
    return n_rows_ - h_cluster_offsets_(idx);
  }

  /**
   * @brief Common constructor used by the public device / host overloads.
   *
   * @param res               the raft resources used for the buffer allocations
   * @param dataset_ptr       pointer to the row-major dataset (n_rows x dim)
   * @param dim               number of columns of the dataset
   * @param n_rows            number of rows of the dataset
   * @param max_cluster_size  number of rows of each device cluster buffer
   * @param h_buf_size        number of rows of the pinned host buffers
   * @param h_cluster_offsets host view of the cluster start offsets
   * @param cluster_ids       device view of the row ids grouped by cluster
   * @param needs_copy        true if `dataset_ptr` points to host memory
   * @param stream            stream used for the host-path copies
   */
  cluster_loader(raft::resources const& res,
                 const T* dataset_ptr,
                 int64_t dim,
                 int64_t n_rows,
                 int64_t max_cluster_size,
                 int64_t h_buf_size,
                 raft::host_vector_view<LabelT> h_cluster_offsets,
                 raft::device_vector_view<LabelT> cluster_ids,
                 bool needs_copy,
                 cudaStream_t stream)
    // Listed in member declaration order, which is the order C++ initializes them in.
    : cluster_buf_(raft::make_pinned_matrix<T, int64_t>(res, h_buf_size, dim)),
      cluster_ids_buf_(raft::make_pinned_vector<LabelT, int64_t>(res, h_buf_size)),
      d_cluster_buf_(raft::make_device_matrix<T, int64_t>(res, max_cluster_size, dim)),
      d_cluster_copy_buf_(raft::make_device_matrix<T, int64_t>(res, max_cluster_size, dim)),
      dataset_ptr_(dataset_ptr),
      h_cluster_offsets_(h_cluster_offsets),
      cluster_ids_(cluster_ids),
      stream_(stream),
      dim_(dim),
      n_rows_(n_rows),
      needs_copy_(needs_copy)
  {
  }

 public:
  /**
   * @brief Construct a loader for a dataset stored in device memory.
   *
   * The pinned host buffers are created with zero rows and the gather runs on device.
   */
  cluster_loader(raft::resources const& res,
                 raft::device_matrix_view<const T, int64_t> dataset_view,
                 raft::host_vector_view<LabelT> h_cluster_offsets,
                 raft::device_vector_view<LabelT> cluster_ids,
                 int64_t max_cluster_size,
                 cudaStream_t stream)
    : cluster_loader(res,
                     dataset_view.data_handle(),
                     dataset_view.extent(1),
                     dataset_view.extent(0),
                     max_cluster_size,
                     0,
                     h_cluster_offsets,
                     cluster_ids,
                     false,
                     stream)

  {
  }

  /**
   * @brief Construct a loader for a dataset stored in host memory.
   *
   * The pinned host buffers are sized to `max_cluster_size` rows and the gather runs on the CPU.
   */
  cluster_loader(raft::resources const& res,
                 raft::host_matrix_view<const T, int64_t> dataset_view,
                 raft::host_vector_view<LabelT> h_cluster_offsets,
                 raft::device_vector_view<LabelT> cluster_ids,
                 int64_t max_cluster_size,
                 cudaStream_t stream)
    : cluster_loader(res,
                     dataset_view.data_handle(),
                     dataset_view.extent(1),
                     dataset_view.extent(0),
                     max_cluster_size,
                     max_cluster_size,
                     h_cluster_offsets,
                     cluster_ids,
                     true,
                     stream)

  {
  }

  /**
   * @brief load and return a view of the provided cluster
   *
   * If the cluster is not the current one, it is gathered (unless it has already been
   * prefetched) and the two device buffers are swapped. After a swap the storage behind the
   * previously returned view becomes the prefetch target, so work reading the previous view
   * must be complete before the next prefetch writes to it.
   *
   * `cluster_idx` must be a valid cluster index: its size is read from the offsets before any
   * bounds check.
   *
   * @param res: the raft resources;
   * @param cluster_idx: the index of the cluster to be loaded
   * @return device_matrix_view of the cluster vectors
   */
  raft::device_matrix_view<T, int64_t> load_cluster(raft::resources const& res, LabelT cluster_idx)
  {
    size_t size = cluster_size(cluster_idx);

    // Check if cluster is already loaded
    if (cur_idx_ != cluster_idx) {
      // If not, load the cluster (skipped when it was prefetched into the copy buffer)
      if (copy_idx_ != cluster_idx) { prefetch_cluster(res, cluster_idx); }

      // swap buffers: the prefetched data becomes current
      std::swap(d_cluster_buf_, d_cluster_copy_buf_);
      std::swap(cur_idx_, copy_idx_);
    }

    return raft::make_device_matrix_view<T, int64_t>(d_cluster_buf_.data_handle(), size, dim_);
  }

  /** @brief Gather the rows of a cluster into the copy (prefetch) buffer
   *
   * Out-of-range indices are ignored, so callers may request `i + 1` unconditionally.
   * For a host dataset the copies are issued on stream_ and this call returns only after
   * they have completed; for a device dataset the gather is issued through `res`.
   *
   * @param res: the raft resources
   * @param cluster_idx: the index of the cluster
   */
  void prefetch_cluster(raft::resources const& res, LabelT cluster_idx)
  {
    if (cluster_idx >= h_cluster_offsets_.extent(0)) { return; }

    size_t size = cluster_size(cluster_idx);

    // Row ids of this cluster (device memory)
    auto cluster_ids = raft::make_device_vector_view<const LabelT, int64_t>(
      cluster_ids_.data_handle() + h_cluster_offsets_(cluster_idx), size);

    // Destination rows in the prefetch buffer. The element type is T to match the buffer
    // and the view returned by load_cluster().
    auto cluster_vectors =
      raft::make_device_matrix_view<T, int64_t>(d_cluster_copy_buf_.data_handle(), size, dim_);

    if (needs_copy_) {
      // dtoh: bring the row ids to the host
      auto h_cluster_ids =
        raft::make_pinned_vector_view<LabelT, int64_t>(cluster_ids_buf_.data_handle(), size);

      raft::copy(
        h_cluster_ids.data_handle(), cluster_ids.data_handle(), cluster_ids.size(), stream_);
      // The ids are read on the host right below, so the copy must have finished.
      raft::resource::sync_stream(res, stream_);

      auto pinned_cluster = raft::make_pinned_matrix_view<T, int64_t>(
        cluster_buf_.data_handle(), cluster_vectors.extent(0), cluster_vectors.extent(1));

      // Host-side gather: copy each selected dataset row into the pinned staging buffer.
      int n_threads = std::min<int>(omp_get_max_threads(), 32);
#pragma omp parallel for num_threads(n_threads)
      for (int64_t i = 0; i < h_cluster_ids.extent(0); i++) {
        memcpy(pinned_cluster.data_handle() + i * pinned_cluster.extent(1),
               dataset_ptr_ + h_cluster_ids(i) * dim_,
               sizeof(T) * dim_);
      }

      // htod: upload the compacted cluster
      raft::copy(cluster_vectors.data_handle(),
                 pinned_cluster.data_handle(),
                 pinned_cluster.size(),
                 stream_);
      // Wait for the upload to complete before returning.
      raft::resource::sync_stream(res, stream_);

    } else {
      // dtod
      auto dataset_view =
        raft::make_device_matrix_view<const T, int64_t>(dataset_ptr_, n_rows_, dim_);

      raft::matrix::gather(res, dataset_view, cluster_ids, cluster_vectors);
    }

    copy_idx_ = cluster_idx;
  }
};

/**
* @brief Perform AVQ adjustment on cluster centers
*
Expand Down Expand Up @@ -396,63 +604,63 @@ void apply_avq(raft::resources const& res,
raft::mdspan<const T, raft::matrix_extent<IdxT>, raft::row_major, Accessor> dataset,
raft::device_matrix_view<T, IdxT> centroids_view,
raft::device_vector_view<const LabelT, IdxT> labels_view,
float eta)
float eta,
cudaStream_t gather_stream)
Comment thread
tfeher marked this conversation as resolved.
Outdated
{
// Compute clusters

cudaStream_t stream = raft::resource::get_cuda_stream(res);
auto cluster_ptrs = raft::make_device_vector<uint32_t, int64_t>(res, centroids_view.extent(0));
auto clusters = raft::make_device_vector<uint32_t, int64_t>(res, dataset.extent(0));

compute_cluster_offsets(res, labels_view, cluster_ptrs.view());
cudaStream_t stream = raft::resource::get_cuda_stream(res);
auto cluster_offsets = raft::make_device_vector<uint32_t, int64_t>(res, centroids_view.extent(0));
auto clusters = raft::make_device_vector<uint32_t, int64_t>(res, dataset.extent(0));
auto max_cluster_size = raft::make_host_scalar<int64_t>(0);
Comment thread
tfeher marked this conversation as resolved.
Outdated

auto h_cluster_ptrs = raft::make_host_vector<uint32_t, int64_t>(cluster_ptrs.extent(0));
compute_cluster_offsets(res, labels_view, cluster_offsets.view(), max_cluster_size.view());
auto h_cluster_offsets = raft::make_host_vector<uint32_t, int64_t>(cluster_offsets.extent(0));

raft::copy(h_cluster_ptrs.data_handle(), cluster_ptrs.data_handle(), cluster_ptrs.size(), stream);
raft::copy(
h_cluster_offsets.data_handle(), cluster_offsets.data_handle(), cluster_offsets.size(), stream);

dim3 block(32, 1, 1);
dim3 grid((dataset.extent(0) + block.x - 1) / block.x, 1, 1);

build_clusters<uint32_t, uint32_t><<<grid, block>>>(labels_view.data_handle(),
clusters.view().data_handle(),
cluster_ptrs.view().data_handle(),
cluster_offsets.view().data_handle(),
dataset.extent(0),
labels_view.extent(0));
RAFT_CUDA_TRY(cudaPeekAtLastError());

auto rescale_num = raft::make_device_vector<float, int64_t>(res, centroids_view.extent(0));
auto rescale_denom = raft::make_device_vector<float, int64_t>(res, centroids_view.extent(0));

cluster_loader<T, LabelT> loader(
res, dataset, h_cluster_offsets.view(), clusters.view(), max_cluster_size(0), gather_stream);
raft::resource::sync_stream(res);

RAFT_LOG_DEBUG("Compute AVQ centroids\n");

for (int i = 0; i < h_cluster_ptrs.extent(0); i++) {
int cluster_size = i + 1 < h_cluster_ptrs.extent(0) ? h_cluster_ptrs(i + 1) - h_cluster_ptrs(i)
: dataset.extent(0) - h_cluster_ptrs(i);
for (int i = 0; i < h_cluster_offsets.extent(0); i++) {
auto cluster_vectors = loader.load_cluster(res, i);

if (cluster_size == 0) { continue; }
auto cluster_ids = raft::make_device_vector_view<const uint32_t, int64_t>(
clusters.data_handle() + h_cluster_ptrs(i), cluster_size);
auto cluster_vectors =
raft::make_device_matrix<float, int64_t>(res, cluster_size, dataset.extent(1));
auto avq_centroid = raft::make_device_vector_view<float, int64_t>(
centroids_view.data_handle() + i * dataset.extent(1), dataset.extent(1));
auto rescale_num_view = raft::make_device_scalar_view<float>(rescale_num.data_handle() + i);
auto rescale_denom_view = raft::make_device_scalar_view<float>(rescale_denom.data_handle() + i);

gather_functor<float, uint32_t>{}(
res, dataset, cluster_ids, cluster_vectors.view(), raft::resource::get_cuda_stream(res));

compute_avq_centroid(
res, cluster_vectors.view(), avq_centroid, rescale_num_view, rescale_denom_view, eta);
res, cluster_vectors, avq_centroid, rescale_num_view, rescale_denom_view, eta);

loader.prefetch_cluster(res, i + 1);

// make sure work is done before swapping buffers in cluster_loader
raft::resource::sync_stream(res);
}

rescale_avq_centroids(res,
centroids_view,
rescale_num.view(),
rescale_denom.view(),
cluster_ptrs.view(),
cluster_offsets.view(),
dataset.extent(0));

raft::resource::sync_stream(res);
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/neighbors/scann/detail/scann_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,12 @@ index<T, IdxT> build(
}

// AVQ update of KMeans centroids
apply_avq(
res, dataset, centroids_view, raft::make_const_mdspan(labels_view), params.partitioning_eta);
apply_avq(res,
dataset,
centroids_view,
raft::make_const_mdspan(labels_view),
params.partitioning_eta,
copy_stream);
Comment thread
tfeher marked this conversation as resolved.

raft::device_vector_view<uint32_t, int64_t> soar_labels_view = idx.soar_labels();

Expand Down
Loading