Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ struct extend_params {
* degrade recall because no edges are added between the nodes in the same chunk. Auto select when
* 0. */
uint32_t max_chunk_size = 0;

/** The dataset chunk where the maximum size is defined by `max_chunk_size` is divided by
* sub-chunks to limit the working memory usage. This is the knob to control the working memory
* usage. Large working memory size can result in high throughput.
* */
uint32_t max_working_device_memory_size_in_megabyte = 512;
};

/**
Expand Down
30 changes: 21 additions & 9 deletions cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ void add_node_core(
const cuvs::neighbors::cagra::index<T, IdxT>& idx,
raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::layout_stride, Accessor>
additional_dataset_view,
raft::host_matrix_view<IdxT, std::int64_t> updated_graph)
raft::host_matrix_view<IdxT, std::int64_t> updated_graph,
const cuvs::neighbors::cagra::extend_params& extend_params)
{
using DistanceT = float;
const std::size_t degree = idx.graph_degree();
Expand Down Expand Up @@ -68,7 +69,17 @@ void add_node_core(
new_size,
raft::resource::get_cuda_stream(handle));

const std::size_t max_chunk_size = 1024;
const std::size_t data_size_per_vector =
sizeof(IdxT) * base_degree + sizeof(DistanceT) * base_degree + sizeof(T) * dim;
Comment thread
tfeher marked this conversation as resolved.
const std::size_t max_search_batch_size =
std::min(std::max(1lu,
extend_params.max_working_device_memory_size_in_megabyte * (1u << 20) /
data_size_per_vector),
num_add);
if (extend_params.max_working_device_memory_size_in_megabyte == 0) {
RAFT_LOG_DEBUG("Overwrites the memory size for the extend function to %lu Byte",
data_size_per_vector);
}

cuvs::neighbors::cagra::search_params params;
params.itopk_size = std::max(base_degree * 2lu, 256lu);
Expand All @@ -77,22 +88,22 @@ void add_node_core(
auto mr = raft::resource::get_workspace_resource(handle);

auto neighbor_indices = raft::make_device_mdarray<IdxT, std::int64_t>(
handle, mr, raft::make_extents<std::int64_t>(max_chunk_size, base_degree));
handle, mr, raft::make_extents<std::int64_t>(max_search_batch_size, base_degree));

auto neighbor_distances = raft::make_device_mdarray<DistanceT, std::int64_t>(
handle, mr, raft::make_extents<std::int64_t>(max_chunk_size, base_degree));
handle, mr, raft::make_extents<std::int64_t>(max_search_batch_size, base_degree));

auto queries = raft::make_device_mdarray<T, std::int64_t>(
handle, mr, raft::make_extents<std::int64_t>(max_chunk_size, dim));
handle, mr, raft::make_extents<std::int64_t>(max_search_batch_size, dim));

auto host_neighbor_indices =
raft::make_host_matrix<IdxT, std::int64_t>(max_chunk_size, base_degree);
raft::make_host_matrix<IdxT, std::int64_t>(max_search_batch_size, base_degree);

cuvs::spatial::knn::detail::utils::batch_load_iterator<T> additional_dataset_batch(
additional_dataset_view.data_handle(),
num_add,
additional_dataset_view.stride(0),
max_chunk_size,
max_search_batch_size,
raft::resource::get_cuda_stream(handle),
raft::resource::get_workspace_resource(handle));
for (const auto& batch : additional_dataset_batch) {
Expand Down Expand Up @@ -254,7 +265,8 @@ void add_graph_nodes(
const std::size_t degree = index.graph_degree();
const std::size_t dim = index.dim();
const std::size_t stride = input_updated_dataset_view.stride(0);
const std::size_t max_chunk_size_ = params.max_chunk_size == 0 ? 1 : params.max_chunk_size;
const std::size_t max_chunk_size_ =
params.max_chunk_size == 0 ? new_dataset_size : params.max_chunk_size;

raft::copy(updated_graph_view.data_handle(),
index.graph().data_handle(),
Expand Down Expand Up @@ -298,7 +310,7 @@ void add_graph_nodes(
stride);

neighbors::cagra::add_node_core<T, IdxT>(
handle, internal_index, additional_dataset_view, updated_graph);
handle, internal_index, additional_dataset_view, updated_graph, params);
raft::resource::sync_stream(handle);
}
}
Expand Down