Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ class cuvs_cagra_hnswlib : public algo<T>, public algo_gpu {
};

cuvs_cagra_hnswlib(Metric metric, int dim, const build_param& param, int concurrent_searches = 1)
: algo<T>(metric, dim),
build_param_{param},
cagra_build_{metric, dim, param.cagra_build_param, concurrent_searches}
: algo<T>(metric, dim), build_param_{param}
{
}

Expand All @@ -57,7 +55,13 @@ class cuvs_cagra_hnswlib : public algo<T>, public algo_gpu {

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
return cagra_build_.get_sync_stream();
return handle_.get_sync_stream();
}

/** Always returns `false` for this wrapper; the reason is given in the body comment. */
[[nodiscard]] auto uses_stream() const noexcept -> bool override
{
// there's no need to synchronize with the GPU either on build or on search
return false;
}

// to enable dataset access from GPU memory
Expand All @@ -77,28 +81,37 @@ class cuvs_cagra_hnswlib : public algo<T>, public algo_gpu {
}

private:
raft::resources handle_{};
configured_raft_resources handle_{};
build_param build_param_;
search_param search_param_;
cuvs_cagra<T, IdxT> cagra_build_;
std::shared_ptr<cuvs::neighbors::hnsw::index<T>> hnsw_index_;
};

template <typename T, typename IdxT>
void cuvs_cagra_hnswlib<T, IdxT>::build(const T* dataset, size_t nrow)
{
cagra_build_.build(dataset, nrow);
auto* cagra_index = cagra_build_.get_index();
auto host_dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
auto opt_dataset_view =
std::optional<raft::host_matrix_view<const T, int64_t>>(std::move(host_dataset_view));
const auto start_clock = std::chrono::system_clock::now();
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
handle_, build_param_.hnsw_index_params, *cagra_index, opt_dataset_view);
int time =
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - start_clock)
.count();
RAFT_LOG_DEBUG("Graph saved to HNSW format in %d:%d min", time / 60, time % 60);
// when the data set is on host, we can pass it directly to HNSW
bool dataset_is_on_host = raft::get_device_for_address(dataset) == -1;

// re-use the CAGRA wrapper to parse build params
auto bps = build_param_.cagra_build_param;
bps.cagra_params.attach_dataset_on_build = !dataset_is_on_host;
cuvs_cagra<T, IdxT> cagra_wrapper{this->metric_, this->dim_, bps};

// build the CAGRA index
cagra_wrapper.build(dataset, nrow);
auto& cagra_index = *cagra_wrapper.get_index();

// pass the dataset directly to HNSW if it's on the host
std::optional<raft::host_matrix_view<const T, int64_t>> opt_dataset_view = std::nullopt;
if (dataset_is_on_host) {
opt_dataset_view.emplace(
raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_));
}

// convert the index to HNSW format
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
handle_, build_param_.hnsw_index_params, cagra_index, opt_dataset_view);
}

template <typename T, typename IdxT>
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,13 @@ index<T, IdxT> build(
{
size_t intermediate_degree = params.intermediate_graph_degree;
size_t graph_degree = params.graph_degree;
common::nvtx::range<common::nvtx::domain::cuvs> function_scope(
"cagra::build<%s>(%zu, %zu)",
Accessor::is_managed_type::value ? "managed"
: Accessor::is_host_type::value ? "host"
: "device",
intermediate_degree,
graph_degree);
if (intermediate_degree >= static_cast<size_t>(dataset.extent(0))) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
Expand Down
Loading