11/*
2- * Copyright (c) 2022-2024 , NVIDIA CORPORATION.
2+ * Copyright (c) 2022-2025 , NVIDIA CORPORATION.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
@@ -151,7 +151,7 @@ void select_clusters(raft::resources const& handle,
151151 } break ;
152152 default : RAFT_FAIL (" Unsupported distance type %d." , int (metric));
153153 }
154- rmm::device_uvector<float > qc_distances (n_queries * n_lists, stream, mr);
154+ rmm::device_uvector<float > qc_distances (size_t ( n_queries) * size_t ( n_lists) , stream, mr);
155155 raft::linalg::gemm (handle,
156156 true ,
157157 false ,
@@ -169,7 +169,7 @@ void select_clusters(raft::resources const& handle,
169169 stream);
170170
171171 // Select neighbor clusters for each query.
172- rmm::device_uvector<float > cluster_dists (n_queries * n_probes, stream, mr);
172+ rmm::device_uvector<float > cluster_dists (size_t ( n_queries) * size_t ( n_probes) , stream, mr);
173173 cuvs::selection::select_k (
174174 handle,
175175 raft::make_device_matrix_view<const float , int64_t >(qc_distances.data (), n_queries, n_lists),
@@ -237,7 +237,7 @@ void select_clusters(raft::resources const& handle,
237237 } break ;
238238 default : RAFT_FAIL (" Unsupported distance type %d." , int (metric));
239239 }
240- rmm::device_uvector<dist_type> qc_distances (n_queries * n_lists, stream, mr);
240+ rmm::device_uvector<dist_type> qc_distances (size_t ( n_queries) * size_t ( n_lists) , stream, mr);
241241 raft::linalg::gemm (handle,
242242 true ,
243243 false ,
@@ -255,7 +255,7 @@ void select_clusters(raft::resources const& handle,
255255 stream);
256256
257257 // Select neighbor clusters for each query.
258- rmm::device_uvector<dist_type> cluster_dists (n_queries * n_probes, stream, mr);
258+ rmm::device_uvector<dist_type> cluster_dists (size_t ( n_queries) * size_t ( n_probes) , stream, mr);
259259 // cuvs::selection::select_k lacks uint32_t-as-a-value support at the moment
260260 raft::matrix::select_k<dist_type, uint32_t >(
261261 handle,
@@ -322,7 +322,7 @@ void select_clusters(raft::resources const& handle,
322322 } break ;
323323 default : RAFT_FAIL (" Unsupported distance type %d." , int (metric));
324324 }
325- rmm::device_uvector<dist_type> qc_distances (n_queries * n_lists, stream, mr);
325+ rmm::device_uvector<dist_type> qc_distances (size_t ( n_queries) * size_t ( n_lists) , stream, mr);
326326 raft::linalg::gemm (handle,
327327 true ,
328328 false ,
@@ -340,7 +340,7 @@ void select_clusters(raft::resources const& handle,
340340 stream);
341341
342342 // Select neighbor clusters for each query.
343- rmm::device_uvector<dist_type> cluster_dists (n_queries * n_probes, stream, mr);
343+ rmm::device_uvector<dist_type> cluster_dists (size_t ( n_queries) * size_t ( n_probes) , stream, mr);
344344 cuvs::selection::select_k (
345345 handle,
346346 raft::make_device_matrix_view<const dist_type, int64_t >(
@@ -730,8 +730,9 @@ struct ivfpq_search {
730730};
731731
732732/* *
733- * A heuristic for bounding the number of queries per batch, to improve GPU utilization.
734- * (based on the number of SMs and the work size).
733+ * A heuristic for bounding the number of queries in compute similarity kernel batch, to improve GPU
734+ * utilization. (based on the number of SMs and the work size). A major restriction here is the
735+ * max_samples - how many intermediate results may be kept in memory.
735736 *
736737 * @param res is used to query the workspace size
737738 * @param k top-k
@@ -742,11 +743,11 @@ struct ivfpq_search {
742743 *
743744 * @return maximum recommended batch size.
744745 */
745- inline auto get_max_batch_size (raft::resources const & res,
746- uint32_t k,
747- uint32_t n_probes,
748- uint32_t n_queries,
749- uint32_t max_samples) -> uint32_t
746+ inline auto get_max_fine_batch_size (raft::resources const & res,
747+ uint32_t k,
748+ uint32_t n_probes,
749+ uint32_t n_queries,
750+ uint32_t max_samples) -> uint32_t
750751{
751752 uint32_t max_batch_size = n_queries;
752753 uint32_t n_ctas_total = raft::resource::get_device_properties (res).multiProcessorCount * 2 ;
@@ -778,6 +779,44 @@ inline auto get_max_batch_size(raft::resources const& res,
778779 return max_batch_size;
779780}
780781
782+ /* *
783+ * A heuristic for bounding the number of queries per batch for the outer loop of the search,
784+ * to improve GPU utilization and memory usage.
785+ *
786+ * @param res is used to query the workspace size
787+ * @param data_size size of each data element in bytes
788+ * @param n_probes number of selected clusters per query
789+ * @param n_lists total number of clusters
790+ * @param n_queries number of queries to process
791+ * @param max_queries maximum number of queries that can be processed at once
792+ *
793+ * @return maximum recommended batch size for the outer loop
794+ */
795+ inline auto get_max_coarse_batch_size (raft::resources const & res,
796+ const search_params& params,
797+ uint32_t n_probes,
798+ uint32_t n_lists,
799+ uint32_t n_queries) -> uint32_t
800+ {
801+ size_t data_size = 4 ;
802+ switch (params.coarse_search_dtype ) {
803+ case CUDA_R_32F: data_size = 4 ; break ;
804+ case CUDA_R_16F: data_size = 2 ; break ;
805+ case CUDA_R_8I: data_size = 1 ; break ;
806+ default : RAFT_FAIL (" Unexpected coarse_search_dtype (%d)" , int (params.coarse_search_dtype ));
807+ }
808+ // How much data we allocate for coarse GEMM.
809+ // This is NOT all memory we need, as a rule of thumb max it out to half of the workspace.
810+ // We don't reach this limit by default, but only when we increase the max_internal_batch_size by
811+ // a lot.
812+ auto bytes_per_query = static_cast <size_t >(n_probes + n_lists) * data_size;
813+ auto max_per_ws = raft::resource::get_workspace_free_bytes (res) / bytes_per_query;
814+ return std::max<uint32_t >(
815+ 1 ,
816+ std::min<uint32_t >(max_per_ws / 2 ,
817+ std::min<uint32_t >(params.max_internal_batch_size , n_queries)));
818+ }
819+
781820template <typename T, typename IdxT>
782821inline auto get_rotation_matrix (const raft::resources& res, const index<IdxT>& index)
783822 -> raft::device_matrix_view<const T, uint32_t, raft::row_major>
@@ -867,33 +906,37 @@ inline void search(raft::resources const& handle,
867906 auto mr = raft::resource::get_workspace_resource (handle);
868907
869908 // Maximum number of query vectors to search at the same time.
870- const auto max_queries =
871- std::min<uint32_t >(std::max<uint32_t >(n_queries, 1 ), params.max_internal_batch_size );
872- auto max_batch_size = get_max_batch_size (handle, k, n_probes, max_queries, max_samples);
909+ // Number of queries in the outer loop, which includes query transform and coarse search.
910+ const auto max_bs_outer =
911+ get_max_coarse_batch_size (handle, params, n_probes, index.n_lists (), n_queries);
912+ // Number of queries in the inner loop, which includes the fine search;
913+ // This is usually smaller than the outer loop when the non-fused kernel has to keep intermediate
914+ // results in the device memory.
915+ auto max_bs_inner = get_max_fine_batch_size (handle, k, n_probes, max_bs_outer, max_samples);
873916
874917 using some_query_t = std::
875918 variant<rmm::device_uvector<float >, rmm::device_uvector<half>, rmm::device_uvector<int8_t >>;
876919 some_query_t gemm_queries (
877920 params.coarse_search_dtype == CUDA_R_32F
878921 ? std::move (some_query_t {
879- std::in_place_type_t <rmm::device_uvector<float >>{}, max_queries * dim_ext, stream, mr})
922+ std::in_place_type_t <rmm::device_uvector<float >>{}, max_bs_outer * dim_ext, stream, mr})
880923 : params.coarse_search_dtype == CUDA_R_16F
881924 ? std::move (some_query_t {
882- std::in_place_type_t <rmm::device_uvector<half>>{}, max_queries * dim_ext, stream, mr})
925+ std::in_place_type_t <rmm::device_uvector<half>>{}, max_bs_outer * dim_ext, stream, mr})
883926 : params.coarse_search_dtype == CUDA_R_8I
884927 ? std::move (some_query_t {
885- std::in_place_type_t <rmm::device_uvector<int8_t >>{}, max_queries * dim_ext, stream, mr})
928+ std::in_place_type_t <rmm::device_uvector<int8_t >>{}, max_bs_outer * dim_ext, stream, mr})
886929 : throw raft::logic_error (" Unsupported coarse_search_dtype (only CUDA_R_32F, "
887930 " CUDA_R_16F, and CUDA_R_8I are supported)" ));
888- rmm::device_uvector<float > rot_queries (max_queries * index.rot_dim (), stream, mr);
889- rmm::device_uvector<uint32_t > clusters_to_probe (max_queries * n_probes, stream, mr);
931+ rmm::device_uvector<float > rot_queries (max_bs_outer * index.rot_dim (), stream, mr);
932+ rmm::device_uvector<uint32_t > clusters_to_probe (max_bs_outer * n_probes, stream, mr);
890933
891934 auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter (
892935 index.inds_ptrs ().data_handle (), sample_filter);
893936 auto search_instance = ivfpq_search<IdxT, decltype (filter_adapter)>::fun (params, index.metric ());
894937
895- for (uint32_t offset_q = 0 ; offset_q < n_queries; offset_q += max_queries ) {
896- uint32_t queries_batch = min (max_queries , n_queries - offset_q);
938+ for (uint32_t offset_q = 0 ; offset_q < n_queries; offset_q += max_bs_outer ) {
939+ uint32_t queries_batch = min (max_bs_outer , n_queries - offset_q);
897940 raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> batch_scope (
898941 " ivf_pq::search-batch(queries: %u - %u)" , offset_q, offset_q + queries_batch);
899942
@@ -942,12 +985,12 @@ inline void search(raft::resources const& handle,
942985 gemm_queries);
943986 if (index.metric () == distance::DistanceType::CosineExpanded) {
944987 auto rot_queries_view = raft::make_device_matrix_view<float , uint32_t >(
945- rot_queries.data (), max_queries , index.rot_dim ());
988+ rot_queries.data (), max_bs_outer , index.rot_dim ());
946989 raft::linalg::row_normalize<raft::linalg::L2Norm>(
947990 handle, raft::make_const_mdspan (rot_queries_view), rot_queries_view);
948991 }
949- for (uint32_t offset_b = 0 ; offset_b < queries_batch; offset_b += max_batch_size ) {
950- uint32_t batch_size = min (max_batch_size , queries_batch - offset_b);
992+ for (uint32_t offset_b = 0 ; offset_b < queries_batch; offset_b += max_bs_inner ) {
993+ uint32_t batch_size = min (max_bs_inner , queries_batch - offset_b);
951994 /* The distance calculation is done in the rotated/transformed space;
952995 as long as `index.rotation_matrix()` is orthogonal, the distances and thus results are
953996 preserved.
0 commit comments