Skip to content

Commit 192a917

Browse files
authored
Adding int64 search for MG CAGRA (#975)
Answers #894 and #895 by exposing new MG CAGRA search functions with `int64_t` neighbors output type. The PR also updates the tests and benchs to use this new search function. Additionally, the tree merge feature used to perform the KNN merge operation in-place which was an issue, this PR fixes this. Authors: - Victor Lafargue (https://github.com/viclafargue) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #975
1 parent b18e199 commit 192a917

18 files changed

Lines changed: 386 additions & 100 deletions

cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void cuvs_mg_cagra<T, IdxT>::search_base(
159159
auto queries_view =
160160
raft::make_host_matrix_view<const T, int64_t, raft::row_major>(queries, batch_size, dim_);
161161
auto neighbors_view =
162-
raft::make_host_matrix_view<IdxT, int64_t, raft::row_major>((IdxT*)neighbors, batch_size, k);
162+
raft::make_host_matrix_view<int64_t, int64_t, raft::row_major>(neighbors, batch_size, k);
163163
auto distances_view =
164164
raft::make_host_matrix_view<float, int64_t, raft::row_major>(distances, batch_size, k);
165165

cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ void cuvs_mg_ivf_flat<T, IdxT>::search(
132132
auto queries_view = raft::make_host_matrix_view<const T, int64_t, raft::row_major>(
133133
queries, IdxT(batch_size), IdxT(dim_));
134134
auto neighbors_view = raft::make_host_matrix_view<IdxT, int64_t, raft::row_major>(
135-
(IdxT*)neighbors, IdxT(batch_size), IdxT(k));
135+
neighbors, IdxT(batch_size), IdxT(k));
136136
auto distances_view = raft::make_host_matrix_view<float, int64_t, raft::row_major>(
137137
distances, IdxT(batch_size), IdxT(k));
138138

cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ void cuvs_mg_ivf_pq<T, IdxT>::search(
129129
auto queries_view = raft::make_host_matrix_view<const T, int64_t, raft::row_major>(
130130
queries, IdxT(batch_size), IdxT(dim_));
131131
auto neighbors_view = raft::make_host_matrix_view<IdxT, int64_t, raft::row_major>(
132-
(IdxT*)neighbors, IdxT(batch_size), IdxT(k));
132+
neighbors, IdxT(batch_size), IdxT(k));
133133
auto distances_view = raft::make_host_matrix_view<float, int64_t, raft::row_major>(
134134
distances, IdxT(batch_size), IdxT(k));
135135

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2335,6 +2335,124 @@ void extend(const raft::resources& clique,
23352335

23362336
/// \defgroup mg_cpp_index_search ANN MG index search
23372337

2338+
/// \ingroup mg_cpp_index_search
2339+
/**
2340+
* @brief Searches a multi-GPU index
2341+
*
2342+
* Usage example:
2343+
* @code{.cpp}
2344+
* raft::device_resources_snmg clique;
2345+
* cuvs::neighbors::mg_index_params<cagra::index_params> index_params;
2346+
* auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset);
2347+
* cuvs::neighbors::mg_search_params<cagra::search_params> search_params;
2348+
* cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors,
2349+
* distances);
2350+
* @endcode
2351+
*
2352+
* @param[in] clique a `raft::resources` object specifying the NCCL clique configuration
2353+
* @param[in] index the pre-built index
2354+
* @param[in] search_params configure the index search
2355+
* @param[in] queries a row-major matrix on host [n_rows, dim]
2356+
* @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors]
2357+
* @param[out] distances a row-major matrix on host [n_rows, n_neighbors]
2358+
*
2359+
*/
2360+
void search(const raft::resources& clique,
2361+
const cuvs::neighbors::mg_index<cagra::index<float, uint32_t>, float, uint32_t>& index,
2362+
const cuvs::neighbors::mg_search_params<cagra::search_params>& search_params,
2363+
raft::host_matrix_view<const float, int64_t, row_major> queries,
2364+
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors,
2365+
raft::host_matrix_view<float, int64_t, row_major> distances);
2366+
2367+
/// \ingroup mg_cpp_index_search
2368+
/**
2369+
* @brief Searches a multi-GPU index
2370+
*
2371+
* Usage example:
2372+
* @code{.cpp}
2373+
* raft::device_resources_snmg clique;
2374+
* cuvs::neighbors::mg_index_params<cagra::index_params> index_params;
2375+
* auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset);
2376+
* cuvs::neighbors::mg_search_params<cagra::search_params> search_params;
2377+
* cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors,
2378+
* distances);
2379+
* @endcode
2380+
*
2381+
* @param[in] clique a `raft::resources` object specifying the NCCL clique configuration
2382+
* @param[in] index the pre-built index
2383+
* @param[in] search_params configure the index search
2384+
* @param[in] queries a row-major matrix on host [n_rows, dim]
2385+
* @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors]
2386+
* @param[out] distances a row-major matrix on host [n_rows, n_neighbors]
2387+
*
2388+
*/
2389+
void search(const raft::resources& clique,
2390+
const cuvs::neighbors::mg_index<cagra::index<half, uint32_t>, half, uint32_t>& index,
2391+
const cuvs::neighbors::mg_search_params<cagra::search_params>& search_params,
2392+
raft::host_matrix_view<const half, int64_t, row_major> queries,
2393+
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors,
2394+
raft::host_matrix_view<float, int64_t, row_major> distances);
2395+
2396+
/// \ingroup mg_cpp_index_search
2397+
/**
2398+
* @brief Searches a multi-GPU index
2399+
*
2400+
* Usage example:
2401+
* @code{.cpp}
2402+
* raft::device_resources_snmg clique;
2403+
* cuvs::neighbors::mg_index_params<cagra::index_params> index_params;
2404+
* auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset);
2405+
* cuvs::neighbors::mg_search_params<cagra::search_params> search_params;
2406+
* cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors,
2407+
* distances);
2408+
* @endcode
2409+
*
2410+
* @param[in] clique a `raft::resources` object specifying the NCCL clique configuration
2411+
* @param[in] index the pre-built index
2412+
* @param[in] search_params configure the index search
2413+
* @param[in] queries a row-major matrix on host [n_rows, dim]
2414+
* @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors]
2415+
* @param[out] distances a row-major matrix on host [n_rows, n_neighbors]
2416+
*
2417+
*/
2418+
void search(
2419+
const raft::resources& clique,
2420+
const cuvs::neighbors::mg_index<cagra::index<int8_t, uint32_t>, int8_t, uint32_t>& index,
2421+
const cuvs::neighbors::mg_search_params<cagra::search_params>& search_params,
2422+
raft::host_matrix_view<const int8_t, int64_t, row_major> queries,
2423+
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors,
2424+
raft::host_matrix_view<float, int64_t, row_major> distances);
2425+
2426+
/// \ingroup mg_cpp_index_search
2427+
/**
2428+
* @brief Searches a multi-GPU index
2429+
*
2430+
* Usage example:
2431+
* @code{.cpp}
2432+
* raft::device_resources_snmg clique;
2433+
* cuvs::neighbors::mg_index_params<cagra::index_params> index_params;
2434+
* auto index = cuvs::neighbors::cagra::build(clique, index_params, index_dataset);
2435+
* cuvs::neighbors::mg_search_params<cagra::search_params> search_params;
2436+
* cuvs::neighbors::cagra::search(clique, index, search_params, queries, neighbors,
2437+
* distances);
2438+
* @endcode
2439+
*
2440+
* @param[in] clique a `raft::resources` object specifying the NCCL clique configuration
2441+
* @param[in] index the pre-built index
2442+
* @param[in] search_params configure the index search
2443+
* @param[in] queries a row-major matrix on host [n_rows, dim]
2444+
* @param[out] neighbors a row-major matrix on host [n_rows, n_neighbors]
2445+
* @param[out] distances a row-major matrix on host [n_rows, n_neighbors]
2446+
*
2447+
*/
2448+
void search(
2449+
const raft::resources& clique,
2450+
const cuvs::neighbors::mg_index<cagra::index<uint8_t, uint32_t>, uint8_t, uint32_t>& index,
2451+
const cuvs::neighbors::mg_search_params<cagra::search_params>& search_params,
2452+
raft::host_matrix_view<const uint8_t, int64_t, row_major> queries,
2453+
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors,
2454+
raft::host_matrix_view<float, int64_t, row_major> distances);
2455+
23382456
/// \ingroup mg_cpp_index_search
23392457
/**
23402458
* @brief Searches a multi-GPU index

cpp/include/cuvs/neighbors/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,12 +786,12 @@ void extend(
786786
std::optional<raft::mdspan<const IdxT, vector_extent<int64_t>, layout_c_contiguous, Accessor2>>
787787
new_indices);
788788

789-
template <typename AnnIndexType, typename T, typename IdxT>
789+
template <typename AnnIndexType, typename T, typename IdxT, typename searchIdxT>
790790
void search(const raft::resources& handle,
791791
const cuvs::neighbors::iface<AnnIndexType, T, IdxT>& interface,
792792
const cuvs::neighbors::search_params* search_params,
793793
raft::device_matrix_view<const T, int64_t, row_major> h_queries,
794-
raft::device_matrix_view<IdxT, int64_t, row_major> d_neighbors,
794+
raft::device_matrix_view<searchIdxT, int64_t, row_major> d_neighbors,
795795
raft::device_matrix_view<float, int64_t, row_major> d_distances);
796796

797797
template <typename AnnIndexType, typename T, typename IdxT>

cpp/src/neighbors/iface/generate_iface.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,28 @@
190190
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \\
191191
const cuvs::neighbors::search_params* search_params, \\
192192
raft::device_matrix_view<const T, int64_t, row_major> queries, \\
193-
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \\
193+
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors, \\
194194
raft::device_matrix_view<float, int64_t, row_major> distances); \\
195195
\\
196196
template void search(const raft::resources& handle, \\
197197
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \\
198198
const cuvs::neighbors::search_params* search_params, \\
199199
raft::host_matrix_view<const T, int64_t, row_major> h_queries, \\
200-
raft::device_matrix_view<IdxT, int64_t, row_major> d_neighbors, \\
200+
raft::device_matrix_view<int64_t, int64_t, row_major> d_neighbors, \\
201+
raft::device_matrix_view<float, int64_t, row_major> d_distances); \\
202+
\\
203+
template void search(const raft::resources& handle, \\
204+
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \\
205+
const cuvs::neighbors::search_params* search_params, \\
206+
raft::device_matrix_view<const T, int64_t, row_major> queries, \\
207+
raft::device_matrix_view<uint32_t, int64_t, row_major> neighbors, \\
208+
raft::device_matrix_view<float, int64_t, row_major> distances); \\
209+
\\
210+
template void search(const raft::resources& handle, \\
211+
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \\
212+
const cuvs::neighbors::search_params* search_params, \\
213+
raft::host_matrix_view<const T, int64_t, row_major> h_queries, \\
214+
raft::device_matrix_view<uint32_t, int64_t, row_major> d_neighbors, \\
201215
raft::device_matrix_view<float, int64_t, row_major> d_distances); \\
202216
\\
203217
template void serialize(const raft::resources& handle, \\

cpp/src/neighbors/iface/iface.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ void extend(
7878
resource::sync_stream(handle);
7979
}
8080

81-
template <typename AnnIndexType, typename T, typename IdxT>
81+
template <typename AnnIndexType, typename T, typename IdxT, typename searchIdxT>
8282
void search(const raft::resources& handle,
8383
const cuvs::neighbors::iface<AnnIndexType, T, IdxT>& interface,
8484
const cuvs::neighbors::search_params* search_params,
8585
raft::device_matrix_view<const T, int64_t, row_major> queries,
86-
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
86+
raft::device_matrix_view<searchIdxT, int64_t, row_major> neighbors,
8787
raft::device_matrix_view<float, int64_t, row_major> distances)
8888
{
8989
// std::lock_guard(*interface.mutex_);
@@ -114,12 +114,12 @@ void search(const raft::resources& handle,
114114
}
115115

116116
// for MG ANN only
117-
template <typename AnnIndexType, typename T, typename IdxT>
117+
template <typename AnnIndexType, typename T, typename IdxT, typename searchIdxT>
118118
void search(const raft::resources& handle,
119119
const cuvs::neighbors::iface<AnnIndexType, T, IdxT>& interface,
120120
const cuvs::neighbors::search_params* search_params,
121121
raft::host_matrix_view<const T, int64_t, row_major> h_queries,
122-
raft::device_matrix_view<IdxT, int64_t, row_major> d_neighbors,
122+
raft::device_matrix_view<searchIdxT, int64_t, row_major> d_neighbors,
123123
raft::device_matrix_view<float, int64_t, row_major> d_distances)
124124
{
125125
// std::lock_guard(*interface.mutex_);

cpp/src/neighbors/iface/iface_cagra_float_uint32_t.cu

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,28 @@ namespace cuvs::neighbors {
6767
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
6868
const cuvs::neighbors::search_params* search_params, \
6969
raft::device_matrix_view<const T, int64_t, row_major> queries, \
70-
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \
70+
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors, \
7171
raft::device_matrix_view<float, int64_t, row_major> distances); \
7272
\
7373
template void search(const raft::resources& handle, \
7474
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
7575
const cuvs::neighbors::search_params* search_params, \
7676
raft::host_matrix_view<const T, int64_t, row_major> h_queries, \
77-
raft::device_matrix_view<IdxT, int64_t, row_major> d_neighbors, \
77+
raft::device_matrix_view<int64_t, int64_t, row_major> d_neighbors, \
78+
raft::device_matrix_view<float, int64_t, row_major> d_distances); \
79+
\
80+
template void search(const raft::resources& handle, \
81+
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
82+
const cuvs::neighbors::search_params* search_params, \
83+
raft::device_matrix_view<const T, int64_t, row_major> queries, \
84+
raft::device_matrix_view<uint32_t, int64_t, row_major> neighbors, \
85+
raft::device_matrix_view<float, int64_t, row_major> distances); \
86+
\
87+
template void search(const raft::resources& handle, \
88+
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
89+
const cuvs::neighbors::search_params* search_params, \
90+
raft::host_matrix_view<const T, int64_t, row_major> h_queries, \
91+
raft::device_matrix_view<uint32_t, int64_t, row_major> d_neighbors, \
7892
raft::device_matrix_view<float, int64_t, row_major> d_distances); \
7993
\
8094
template void serialize(const raft::resources& handle, \

cpp/src/neighbors/iface/iface_cagra_half_uint32_t.cu

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,28 @@ namespace cuvs::neighbors {
6767
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
6868
const cuvs::neighbors::search_params* search_params, \
6969
raft::device_matrix_view<const T, int64_t, row_major> queries, \
70-
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \
70+
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors, \
7171
raft::device_matrix_view<float, int64_t, row_major> distances); \
7272
\
7373
template void search(const raft::resources& handle, \
7474
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
7575
const cuvs::neighbors::search_params* search_params, \
7676
raft::host_matrix_view<const T, int64_t, row_major> h_queries, \
77-
raft::device_matrix_view<IdxT, int64_t, row_major> d_neighbors, \
77+
raft::device_matrix_view<int64_t, int64_t, row_major> d_neighbors, \
78+
raft::device_matrix_view<float, int64_t, row_major> d_distances); \
79+
\
80+
template void search(const raft::resources& handle, \
81+
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
82+
const cuvs::neighbors::search_params* search_params, \
83+
raft::device_matrix_view<const T, int64_t, row_major> queries, \
84+
raft::device_matrix_view<uint32_t, int64_t, row_major> neighbors, \
85+
raft::device_matrix_view<float, int64_t, row_major> distances); \
86+
\
87+
template void search(const raft::resources& handle, \
88+
const cuvs::neighbors::iface<cagra::index<T, IdxT>, T, IdxT>& interface, \
89+
const cuvs::neighbors::search_params* search_params, \
90+
raft::host_matrix_view<const T, int64_t, row_major> h_queries, \
91+
raft::device_matrix_view<uint32_t, int64_t, row_major> d_neighbors, \
7892
raft::device_matrix_view<float, int64_t, row_major> d_distances); \
7993
\
8094
template void serialize(const raft::resources& handle, \

0 commit comments

Comments
 (0)