Skip to content

Commit 0778a95

Browse files
authored
Optimize hnsw::from_cagra<GPU> (#826)
Reduce the CAGRA-for-HNSW build times by:
- avoiding unnecessary copies of the data between cagra::build and hnsw::from_cagra in the benchmarks
- avoiding unnecessary temporary data buffers in hnsw::from_cagra&lt;GPU&gt;
- reducing random reads by forcing a 1-1 mapping between the internal indices and external labels during the HNSW import

As a side effect, this PR also fixes the bug where hnsw::from_cagra segfaults in benchmarks if the dataset is passed in device memory (and incorrectly wrapped in a host_matrix_view). In addition, this PR adds slightly more verbose NVTX reporting of the different stages during the CAGRA/HNSW index build.

Authors:
- Artem M. Chirkin (https://github.com/achirkin)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Corey J. Nolet (https://github.com/cjnolet)

URL: #826
1 parent fd845a9 commit 0778a95

5 files changed

Lines changed: 210 additions & 99 deletions

File tree

cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ class cuvs_cagra_hnswlib : public algo<T>, public algo_gpu {
3939
};
4040

4141
cuvs_cagra_hnswlib(Metric metric, int dim, const build_param& param, int concurrent_searches = 1)
42-
: algo<T>(metric, dim),
43-
build_param_{param},
44-
cagra_build_{metric, dim, param.cagra_build_param, concurrent_searches}
42+
: algo<T>(metric, dim), build_param_{param}
4543
{
4644
}
4745

@@ -57,7 +55,13 @@ class cuvs_cagra_hnswlib : public algo<T>, public algo_gpu {
5755

5856
[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
5957
{
60-
return cagra_build_.get_sync_stream();
58+
return handle_.get_sync_stream();
59+
}
60+
61+
[[nodiscard]] auto uses_stream() const noexcept -> bool override
62+
{
63+
// there's no need to synchronize with the GPU neither on build nor on search
64+
return false;
6165
}
6266

6367
// to enable dataset access from GPU memory
@@ -77,28 +81,37 @@ class cuvs_cagra_hnswlib : public algo<T>, public algo_gpu {
7781
}
7882

7983
private:
80-
raft::resources handle_{};
84+
configured_raft_resources handle_{};
8185
build_param build_param_;
8286
search_param search_param_;
83-
cuvs_cagra<T, IdxT> cagra_build_;
8487
std::shared_ptr<cuvs::neighbors::hnsw::index<T>> hnsw_index_;
8588
};
8689

8790
template <typename T, typename IdxT>
8891
void cuvs_cagra_hnswlib<T, IdxT>::build(const T* dataset, size_t nrow)
8992
{
90-
cagra_build_.build(dataset, nrow);
91-
auto* cagra_index = cagra_build_.get_index();
92-
auto host_dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
93-
auto opt_dataset_view =
94-
std::optional<raft::host_matrix_view<const T, int64_t>>(std::move(host_dataset_view));
95-
const auto start_clock = std::chrono::system_clock::now();
96-
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
97-
handle_, build_param_.hnsw_index_params, *cagra_index, opt_dataset_view);
98-
int time =
99-
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now() - start_clock)
100-
.count();
101-
RAFT_LOG_DEBUG("Graph saved to HNSW format in %d:%d min", time / 60, time % 60);
93+
// when the data set is on host, we can pass it directly to HNSW
94+
bool dataset_is_on_host = raft::get_device_for_address(dataset) == -1;
95+
96+
// re-use the CAGRA wrapper to parse build params
97+
auto bps = build_param_.cagra_build_param;
98+
bps.cagra_params.attach_dataset_on_build = !dataset_is_on_host;
99+
cuvs_cagra<T, IdxT> cagra_wrapper{this->metric_, this->dim_, bps};
100+
101+
// build the CAGRA index
102+
cagra_wrapper.build(dataset, nrow);
103+
auto& cagra_index = *cagra_wrapper.get_index();
104+
105+
// pass the dataset directly to HNSW if it's on the host
106+
std::optional<raft::host_matrix_view<const T, int64_t>> opt_dataset_view = std::nullopt;
107+
if (dataset_is_on_host) {
108+
opt_dataset_view.emplace(
109+
raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_));
110+
}
111+
112+
// convert the index to HNSW format
113+
hnsw_index_ = cuvs::neighbors::hnsw::from_cagra(
114+
handle_, build_param_.hnsw_index_params, cagra_index, opt_dataset_view);
102115
}
103116

104117
template <typename T, typename IdxT>

cpp/src/neighbors/detail/cagra/cagra_build.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,13 @@ index<T, IdxT> build(
574574
{
575575
size_t intermediate_degree = params.intermediate_graph_degree;
576576
size_t graph_degree = params.graph_degree;
577+
common::nvtx::range<common::nvtx::domain::cuvs> function_scope(
578+
"cagra::build<%s>(%zu, %zu)",
579+
Accessor::is_managed_type::value ? "managed"
580+
: Accessor::is_host_type::value ? "host"
581+
: "device",
582+
intermediate_degree,
583+
graph_degree);
577584
if (intermediate_degree >= static_cast<size_t>(dataset.extent(0))) {
578585
RAFT_LOG_WARN(
579586
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",

0 commit comments

Comments
 (0)