Skip to content

Commit 4ba4537

Browse files
authored
Merge branch 'branch-25.08' into java/thread-check-on-resource-access
2 parents e91bd75 + c04ad9b commit 4ba4537

2 files changed

Lines changed: 108 additions & 39 deletions

File tree

cpp/src/neighbors/detail/cagra/cagra_build.cuh

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -166,25 +166,51 @@ void build_knn_graph(
166166
const auto num_queries = dataset.extent(0);
167167

168168
// Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed.
169-
const uint32_t max_queries = pq.search_params.max_internal_batch_size;
169+
uint32_t max_queries = pq.search_params.max_internal_batch_size;
170170

171171
// Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the
172172
// rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good
173173
// multiple of what is required for the I/O batching below.
174174
constexpr size_t kMinWorkspaceRatio = 5;
175-
auto desired_workspace_size = max_queries * kMinWorkspaceRatio *
176-
(sizeof(DataT) * dataset.extent(1) // queries (dataset batch)
177-
+ sizeof(float) * gpu_top_k // distances
178-
+ sizeof(int64_t) * gpu_top_k // neighbors
179-
+ sizeof(float) * top_k // refined_distances
180-
+ sizeof(int64_t) * top_k // refined_neighbors
181-
);
175+
constexpr size_t kMinLargeBatchSize = 512;
176+
auto desired_workspace_size =
177+
max_queries * (sizeof(DataT) * dataset.extent(1) // queries (dataset batch)
178+
+ sizeof(float) * gpu_top_k // distances
179+
+ sizeof(int64_t) * gpu_top_k // neighbors
180+
+ sizeof(float) * top_k // refined_distances
181+
+ sizeof(int64_t) * top_k // refined_neighbors
182+
);
183+
auto free_space_ratio = raft::resource::get_workspace_free_bytes(res) / desired_workspace_size;
184+
bool use_large_workspace = false;
185+
if (free_space_ratio < kMinWorkspaceRatio) {
186+
auto adjusted_max_queries =
187+
static_cast<uint32_t>(max_queries * free_space_ratio / kMinWorkspaceRatio);
188+
if (adjusted_max_queries >= kMinLargeBatchSize) {
189+
// Adjust max_queries so that free_space_ratio becomes at least
190+
// kMinWorkspaceRatio.
191+
RAFT_LOG_INFO(
192+
"CAGRA graph build: reducing IVF-PQ search max_internal_batch_size from %u -> %u to fit "
193+
"the workspace",
194+
max_queries,
195+
adjusted_max_queries);
196+
max_queries = adjusted_max_queries;
197+
pq.search_params.max_internal_batch_size = adjusted_max_queries;
198+
} else {
199+
// adjusting max_queries to a very small value isn't practical, so we use the large workspace
200+
// instead.
201+
use_large_workspace = true;
202+
RAFT_LOG_WARN(
203+
"Using large workspace memory for IVF-PQ search during CAGRA graph build. Desired "
204+
"workspace size: %zu, free workspace size: %zu",
205+
desired_workspace_size * kMinWorkspaceRatio,
206+
raft::resource::get_workspace_free_bytes(res));
207+
}
208+
}
182209

183210
// If the workspace is smaller than desired, put the I/O buffers into the large workspace.
184211
rmm::device_async_resource_ref workspace_mr =
185-
desired_workspace_size <= raft::resource::get_workspace_free_bytes(res)
186-
? raft::resource::get_workspace_resource(res)
187-
: raft::resource::get_large_workspace_resource(res);
212+
use_large_workspace ? raft::resource::get_large_workspace_resource(res)
213+
: raft::resource::get_workspace_resource(res);
188214

189215
RAFT_LOG_DEBUG(
190216
"IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u",

cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -151,7 +151,7 @@ void select_clusters(raft::resources const& handle,
151151
} break;
152152
default: RAFT_FAIL("Unsupported distance type %d.", int(metric));
153153
}
154-
rmm::device_uvector<float> qc_distances(n_queries * n_lists, stream, mr);
154+
rmm::device_uvector<float> qc_distances(size_t(n_queries) * size_t(n_lists), stream, mr);
155155
raft::linalg::gemm(handle,
156156
true,
157157
false,
@@ -169,7 +169,7 @@ void select_clusters(raft::resources const& handle,
169169
stream);
170170

171171
// Select neighbor clusters for each query.
172-
rmm::device_uvector<float> cluster_dists(n_queries * n_probes, stream, mr);
172+
rmm::device_uvector<float> cluster_dists(size_t(n_queries) * size_t(n_probes), stream, mr);
173173
cuvs::selection::select_k(
174174
handle,
175175
raft::make_device_matrix_view<const float, int64_t>(qc_distances.data(), n_queries, n_lists),
@@ -237,7 +237,7 @@ void select_clusters(raft::resources const& handle,
237237
} break;
238238
default: RAFT_FAIL("Unsupported distance type %d.", int(metric));
239239
}
240-
rmm::device_uvector<dist_type> qc_distances(n_queries * n_lists, stream, mr);
240+
rmm::device_uvector<dist_type> qc_distances(size_t(n_queries) * size_t(n_lists), stream, mr);
241241
raft::linalg::gemm(handle,
242242
true,
243243
false,
@@ -255,7 +255,7 @@ void select_clusters(raft::resources const& handle,
255255
stream);
256256

257257
// Select neighbor clusters for each query.
258-
rmm::device_uvector<dist_type> cluster_dists(n_queries * n_probes, stream, mr);
258+
rmm::device_uvector<dist_type> cluster_dists(size_t(n_queries) * size_t(n_probes), stream, mr);
259259
// cuvs::selection::select_k lacks uint32_t-as-a-value support at the moment
260260
raft::matrix::select_k<dist_type, uint32_t>(
261261
handle,
@@ -322,7 +322,7 @@ void select_clusters(raft::resources const& handle,
322322
} break;
323323
default: RAFT_FAIL("Unsupported distance type %d.", int(metric));
324324
}
325-
rmm::device_uvector<dist_type> qc_distances(n_queries * n_lists, stream, mr);
325+
rmm::device_uvector<dist_type> qc_distances(size_t(n_queries) * size_t(n_lists), stream, mr);
326326
raft::linalg::gemm(handle,
327327
true,
328328
false,
@@ -340,7 +340,7 @@ void select_clusters(raft::resources const& handle,
340340
stream);
341341

342342
// Select neighbor clusters for each query.
343-
rmm::device_uvector<dist_type> cluster_dists(n_queries * n_probes, stream, mr);
343+
rmm::device_uvector<dist_type> cluster_dists(size_t(n_queries) * size_t(n_probes), stream, mr);
344344
cuvs::selection::select_k(
345345
handle,
346346
raft::make_device_matrix_view<const dist_type, int64_t>(
@@ -730,8 +730,9 @@ struct ivfpq_search {
730730
};
731731

732732
/**
733-
* A heuristic for bounding the number of queries per batch, to improve GPU utilization.
734-
* (based on the number of SMs and the work size).
733+
* A heuristic for bounding the number of queries per compute-similarity kernel batch, to improve
734+
* GPU utilization (based on the number of SMs and the work size). A major restriction here is
735+
* max_samples — how many intermediate results may be kept in memory.
735736
*
736737
* @param res is used to query the workspace size
737738
* @param k top-k
@@ -742,11 +743,11 @@ struct ivfpq_search {
742743
*
743744
* @return maximum recommended batch size.
744745
*/
745-
inline auto get_max_batch_size(raft::resources const& res,
746-
uint32_t k,
747-
uint32_t n_probes,
748-
uint32_t n_queries,
749-
uint32_t max_samples) -> uint32_t
746+
inline auto get_max_fine_batch_size(raft::resources const& res,
747+
uint32_t k,
748+
uint32_t n_probes,
749+
uint32_t n_queries,
750+
uint32_t max_samples) -> uint32_t
750751
{
751752
uint32_t max_batch_size = n_queries;
752753
uint32_t n_ctas_total = raft::resource::get_device_properties(res).multiProcessorCount * 2;
@@ -778,6 +779,44 @@ inline auto get_max_batch_size(raft::resources const& res,
778779
return max_batch_size;
779780
}
780781

782+
/**
783+
* A heuristic for bounding the number of queries per batch for the outer loop of the search,
784+
* to improve GPU utilization and memory usage.
785+
*
786+
* @param res is used to query the workspace size
787+
* @param params search parameters; coarse_search_dtype determines the data element size in bytes
788+
* @param n_probes number of selected clusters per query
789+
* @param n_lists total number of clusters
790+
* @param n_queries number of queries to process
791+
* @note the result is additionally capped by params.max_internal_batch_size
792+
*
793+
* @return maximum recommended batch size for the outer loop
794+
*/
795+
inline auto get_max_coarse_batch_size(raft::resources const& res,
796+
const search_params& params,
797+
uint32_t n_probes,
798+
uint32_t n_lists,
799+
uint32_t n_queries) -> uint32_t
800+
{
801+
size_t data_size = 4;
802+
switch (params.coarse_search_dtype) {
803+
case CUDA_R_32F: data_size = 4; break;
804+
case CUDA_R_16F: data_size = 2; break;
805+
case CUDA_R_8I: data_size = 1; break;
806+
default: RAFT_FAIL("Unexpected coarse_search_dtype (%d)", int(params.coarse_search_dtype));
807+
}
808+
// How much data we allocate for coarse GEMM.
809+
// This is NOT all memory we need, as a rule of thumb max it out to half of the workspace.
810+
// We don't reach this limit by default, but only when we increase the max_internal_batch_size by
811+
// a lot.
812+
auto bytes_per_query = static_cast<size_t>(n_probes + n_lists) * data_size;
813+
auto max_per_ws = raft::resource::get_workspace_free_bytes(res) / bytes_per_query;
814+
return std::max<uint32_t>(
815+
1,
816+
std::min<uint32_t>(max_per_ws / 2,
817+
std::min<uint32_t>(params.max_internal_batch_size, n_queries)));
818+
}
819+
781820
template <typename T, typename IdxT>
782821
inline auto get_rotation_matrix(const raft::resources& res, const index<IdxT>& index)
783822
-> raft::device_matrix_view<const T, uint32_t, raft::row_major>
@@ -867,33 +906,37 @@ inline void search(raft::resources const& handle,
867906
auto mr = raft::resource::get_workspace_resource(handle);
868907

869908
// Maximum number of query vectors to search at the same time.
870-
const auto max_queries =
871-
std::min<uint32_t>(std::max<uint32_t>(n_queries, 1), params.max_internal_batch_size);
872-
auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples);
909+
// Number of queries in the outer loop, which includes query transform and coarse search.
910+
const auto max_bs_outer =
911+
get_max_coarse_batch_size(handle, params, n_probes, index.n_lists(), n_queries);
912+
// Number of queries in the inner loop, which includes the fine search;
913+
// This is usually smaller than the outer loop when the non-fused kernel has to keep intermediate
914+
// results in the device memory.
915+
auto max_bs_inner = get_max_fine_batch_size(handle, k, n_probes, max_bs_outer, max_samples);
873916

874917
using some_query_t = std::
875918
variant<rmm::device_uvector<float>, rmm::device_uvector<half>, rmm::device_uvector<int8_t>>;
876919
some_query_t gemm_queries(
877920
params.coarse_search_dtype == CUDA_R_32F
878921
? std::move(some_query_t{
879-
std::in_place_type_t<rmm::device_uvector<float>>{}, max_queries * dim_ext, stream, mr})
922+
std::in_place_type_t<rmm::device_uvector<float>>{}, max_bs_outer * dim_ext, stream, mr})
880923
: params.coarse_search_dtype == CUDA_R_16F
881924
? std::move(some_query_t{
882-
std::in_place_type_t<rmm::device_uvector<half>>{}, max_queries * dim_ext, stream, mr})
925+
std::in_place_type_t<rmm::device_uvector<half>>{}, max_bs_outer * dim_ext, stream, mr})
883926
: params.coarse_search_dtype == CUDA_R_8I
884927
? std::move(some_query_t{
885-
std::in_place_type_t<rmm::device_uvector<int8_t>>{}, max_queries * dim_ext, stream, mr})
928+
std::in_place_type_t<rmm::device_uvector<int8_t>>{}, max_bs_outer * dim_ext, stream, mr})
886929
: throw raft::logic_error("Unsupported coarse_search_dtype (only CUDA_R_32F, "
887930
"CUDA_R_16F, and CUDA_R_8I are supported)"));
888-
rmm::device_uvector<float> rot_queries(max_queries * index.rot_dim(), stream, mr);
889-
rmm::device_uvector<uint32_t> clusters_to_probe(max_queries * n_probes, stream, mr);
931+
rmm::device_uvector<float> rot_queries(max_bs_outer * index.rot_dim(), stream, mr);
932+
rmm::device_uvector<uint32_t> clusters_to_probe(max_bs_outer * n_probes, stream, mr);
890933

891934
auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter(
892935
index.inds_ptrs().data_handle(), sample_filter);
893936
auto search_instance = ivfpq_search<IdxT, decltype(filter_adapter)>::fun(params, index.metric());
894937

895-
for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
896-
uint32_t queries_batch = min(max_queries, n_queries - offset_q);
938+
for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_bs_outer) {
939+
uint32_t queries_batch = min(max_bs_outer, n_queries - offset_q);
897940
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> batch_scope(
898941
"ivf_pq::search-batch(queries: %u - %u)", offset_q, offset_q + queries_batch);
899942

@@ -942,12 +985,12 @@ inline void search(raft::resources const& handle,
942985
gemm_queries);
943986
if (index.metric() == distance::DistanceType::CosineExpanded) {
944987
auto rot_queries_view = raft::make_device_matrix_view<float, uint32_t>(
945-
rot_queries.data(), max_queries, index.rot_dim());
988+
rot_queries.data(), max_bs_outer, index.rot_dim());
946989
raft::linalg::row_normalize<raft::linalg::L2Norm>(
947990
handle, raft::make_const_mdspan(rot_queries_view), rot_queries_view);
948991
}
949-
for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) {
950-
uint32_t batch_size = min(max_batch_size, queries_batch - offset_b);
992+
for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_bs_inner) {
993+
uint32_t batch_size = min(max_bs_inner, queries_batch - offset_b);
951994
/* The distance calculation is done in the rotated/transformed space;
952995
as long as `index.rotation_matrix()` is orthogonal, the distances and thus results are
953996
preserved.

0 commit comments

Comments
 (0)