@@ -104,11 +104,12 @@ void batched_insert_vamana(
104104 " to 1.0" );
105105 max_batchsize = (int )dataset.extent (0 );
106106 }
107- int insert_iters = (int )(params.vamana_iters );
108- double base = (double )(params.batch_base );
109- float alpha = (float )(params.alpha );
110- int visited_size = params.visited_size ;
111- int queue_size = params.queue_size ;
107+ int insert_iters = (int )(params.vamana_iters );
108+ double base = (double )(params.batch_base );
109+ float alpha = (float )(params.alpha );
110+ int visited_size = params.visited_size ;
111+ int queue_size = params.queue_size ;
112+ int reverse_batch = params.reverse_batchsize ;
112113
113114 if ((visited_size & (visited_size - 1 )) != 0 ) {
114115 RAFT_LOG_WARN (" visited_size must be a power of 2, rounding up." );
@@ -152,39 +153,20 @@ void batched_insert_vamana(
152153 std::vector<IdxT> insert_order;
153154 create_insert_permutation<IdxT>(insert_order, (uint32_t )N);
154155
155- // Memory needed to sort reverse edges - potentially large memory footprint
156- auto edge_dest =
157- raft::make_device_mdarray<IdxT>(res,
158- raft::resource::get_large_workspace_resource (res),
159- raft::make_extents<int64_t >(max_batchsize, degree));
160- auto edge_src =
161- raft::make_device_mdarray<IdxT>(res,
162- raft::resource::get_large_workspace_resource (res),
163- raft::make_extents<int64_t >(max_batchsize, degree));
164-
165- size_t temp_storage_bytes = max_batchsize * degree * (2 * sizeof (IdxT));
166- RAFT_LOG_DEBUG (" Temp storage needed for sorting (bytes): %lu" , temp_storage_bytes);
167- auto temp_sort_storage =
168- raft::make_device_mdarray<IdxT>(res,
169- raft::resource::get_large_workspace_resource (res),
170- raft::make_extents<int64_t >(2 * max_batchsize, degree));
171-
172156 // Calculate the shared memory sizes of each kernel
173157 int search_smem_sort_size = 0 ;
174158 int prune_smem_sort_size = 0 ;
175159 SELECT_SMEM_SIZES (degree, visited_size); // Sets above 2 variables to appropriate sizes
176160
177161 // Total dynamic shared memory used by GreedySearch
178- int align_padding = ((((dim-1 )/16 )+1 )*16 ) - dim;
179- int search_smem_total_size =
180- static_cast <int >(search_smem_sort_size + (dim+align_padding) * sizeof (T) +
181- visited_size * sizeof (Node<accT>) +
182- degree * sizeof (int ) + queue_size * sizeof (DistPair<IdxT, accT>));
162+ int align_padding = ((((dim - 1 ) / 16 ) + 1 ) * 16 ) - dim;
163+ int search_smem_total_size = static_cast <int >(
164+ search_smem_sort_size + (dim + align_padding) * sizeof (T) + visited_size * sizeof (Node<accT>) +
165+ degree * sizeof (int ) + queue_size * sizeof (DistPair<IdxT, accT>));
183166
184167 // Total dynamic shared memory size needed by both RobustPrune calls
185- int prune_smem_total_size =
186- prune_smem_sort_size + (dim+align_padding) * sizeof (T)
187- + (degree + visited_size) * sizeof (DistPair<IdxT, accT>);
168+ int prune_smem_total_size = prune_smem_sort_size + (dim + align_padding) * sizeof (T) +
169+ (degree + visited_size) * sizeof (DistPair<IdxT, accT>);
188170
189171 RAFT_LOG_DEBUG (" Dynamic shared memory usage (bytes): GreedySearch: %d, RobustPrune: %d" ,
190172 search_smem_total_size,
@@ -255,6 +237,15 @@ void batched_insert_vamana(
255237 int total_edges;
256238 raft::copy (&total_edges, d_total_edges.data_handle (), 1 , stream);
257239
240+ auto edge_dest =
241+ raft::make_device_mdarray<IdxT>(res,
242+ raft::resource::get_large_workspace_resource (res),
243+ raft::make_extents<int64_t >(total_edges));
244+ auto edge_src =
245+ raft::make_device_mdarray<IdxT>(res,
246+ raft::resource::get_large_workspace_resource (res),
247+ raft::make_extents<int64_t >(total_edges));
248+
258249 // Create reverse edge list
259250 create_reverse_edge_list<accT, IdxT>
260251 <<<num_blocks, blockD, 0 , stream>>> (query_list_ptr.data_handle (),
@@ -263,6 +254,24 @@ void batched_insert_vamana(
263254 edge_src.data_handle (),
264255 edge_dest.data_handle ());
265256
257+ void * d_temp_storage = nullptr ;
258+ size_t temp_storage_bytes = 0 ;
259+
260+ cub::DeviceMergeSort::SortPairs (d_temp_storage,
261+ temp_storage_bytes,
262+ edge_dest.data_handle (),
263+ edge_src.data_handle (),
264+ total_edges,
265+ CmpEdge<IdxT>(),
266+ stream);
267+
268+ RAFT_LOG_DEBUG (" Temp storage needed for sorting (bytes): %lu" , temp_storage_bytes);
269+
270+ auto temp_sort_storage = raft::make_device_mdarray<IdxT>(
271+ res,
272+ raft::resource::get_large_workspace_resource (res),
273+ raft::make_extents<int64_t >(temp_storage_bytes / sizeof (IdxT)));
274+
266275 // Sort to group reverse edges by destination
267276 cub::DeviceMergeSort::SortPairs (temp_sort_storage.data_handle (),
268277 temp_storage_bytes,
@@ -285,61 +294,72 @@ void batched_insert_vamana(
285294 thrust::unique_by_key (
286295 edge_dest_vec.begin (), edge_dest_vec.end (), unique_indices.data_handle ());
287296
288- // Allocate reverse QueryCandidate list based on number of unique destinations
289- // TODO - Do this in batches to reduce memory footprint / support larger datasets
290- auto reverse_list_ptr = raft::make_device_mdarray<QueryCandidates<IdxT, accT>>(
291- res,
292- raft::resource::get_large_workspace_resource (res),
293- raft::make_extents<int64_t >(unique_dests));
294- auto rev_ids =
295- raft::make_device_mdarray<IdxT>(res,
296- raft::resource::get_large_workspace_resource (res),
297- raft::make_extents<int64_t >(unique_dests, visited_size));
298- auto rev_dists =
299- raft::make_device_mdarray<accT>(res,
300- raft::resource::get_large_workspace_resource (res),
301- raft::make_extents<int64_t >(unique_dests, visited_size));
302-
303- QueryCandidates<IdxT, accT>* reverse_list =
304- static_cast <QueryCandidates<IdxT, accT>*>(reverse_list_ptr.data_handle ());
305-
306- init_query_candidate_list<IdxT, accT><<<256 , blockD, 0 , stream>>> (reverse_list,
307- rev_ids.data_handle (),
308- rev_dists.data_handle (),
309- (int )unique_dests,
310- visited_size);
311-
312- // May need more blocks for reverse list
313- num_blocks = min (maxBlocks, unique_dests);
314-
315- // Populate reverse list ids and candidate lists from edge_src and edge_dest
316- populate_reverse_list_struct<T, accT, IdxT>
317- <<<num_blocks, blockD, 0 , stream>>> (reverse_list,
318- edge_src.data_handle (),
319- edge_dest.data_handle (),
320- unique_indices.data_handle (),
321- unique_dests,
322- total_edges,
323- dataset.extent (0 ));
324-
325- // Recompute distances (avoided keeping it during sorting)
326- recompute_reverse_dists<T, accT, IdxT>
327- <<<num_blocks, blockD, 0 , stream>>> (reverse_list, dataset, unique_dests, metric);
328-
329- // Call 2nd RobustPrune on reverse query_list
330- RobustPruneKernel<T, accT, IdxT>
331- <<<num_blocks, blockD, prune_smem_total_size, stream>>> (d_graph.view (),
332- raft::make_const_mdspan (dataset),
333- reverse_list_ptr.data_handle (),
334- unique_dests,
335- visited_size,
336- metric,
337- alpha,
338- prune_smem_sort_size);
339-
340- // Write new edge lists to graph
341- write_graph_edges_kernel<accT, IdxT><<<num_blocks, blockD, 0 , stream>>> (
342- d_graph.view (), reverse_list_ptr.data_handle (), degree, unique_dests);
297+ edge_dest_vec.clear ();
298+ edge_dest_vec.shrink_to_fit ();
299+
300+ // Batch execution of reverse edge creation/application
301+ for (int rev_start = 0 ; rev_start < (int )unique_dests; rev_start += reverse_batch) {
302+ if (rev_start + reverse_batch > (int )unique_dests) {
303+ reverse_batch = (int )unique_dests - rev_start;
304+ }
305+
306+ // Allocate reverse QueryCandidate list sized for the current batch (reverse_batch) of unique destinations
307+ auto reverse_list_ptr = raft::make_device_mdarray<QueryCandidates<IdxT, accT>>(
308+ res,
309+ raft::resource::get_large_workspace_resource (res),
310+ raft::make_extents<int64_t >(reverse_batch));
311+ auto rev_ids =
312+ raft::make_device_mdarray<IdxT>(res,
313+ raft::resource::get_large_workspace_resource (res),
314+ raft::make_extents<int64_t >(reverse_batch, visited_size));
315+ auto rev_dists =
316+ raft::make_device_mdarray<accT>(res,
317+ raft::resource::get_large_workspace_resource (res),
318+ raft::make_extents<int64_t >(reverse_batch, visited_size));
319+
320+ QueryCandidates<IdxT, accT>* reverse_list =
321+ static_cast <QueryCandidates<IdxT, accT>*>(reverse_list_ptr.data_handle ());
322+
323+ init_query_candidate_list<IdxT, accT><<<256 , blockD, 0 , stream>>> (reverse_list,
324+ rev_ids.data_handle (),
325+ rev_dists.data_handle (),
326+ (int )reverse_batch,
327+ visited_size);
328+
329+ // May need more blocks for reverse list
330+ num_blocks = min (maxBlocks, reverse_batch);
331+
332+ // Populate reverse list ids and candidate lists from edge_src and edge_dest
333+ populate_reverse_list_struct<T, accT, IdxT>
334+ <<<num_blocks, blockD, 0 , stream>>> (reverse_list,
335+ edge_src.data_handle (),
336+ edge_dest.data_handle (),
337+ unique_indices.data_handle (),
338+ unique_dests,
339+ total_edges,
340+ dataset.extent (0 ),
341+ rev_start,
342+ reverse_batch);
343+
344+ // Recompute distances (avoided keeping it during sorting)
345+ recompute_reverse_dists<T, accT, IdxT>
346+ <<<num_blocks, blockD, 0 , stream>>> (reverse_list, dataset, reverse_batch, metric);
347+
348+ // Call 2nd RobustPrune on reverse query_list
349+ RobustPruneKernel<T, accT, IdxT>
350+ <<<num_blocks, blockD, prune_smem_total_size, stream>>> (d_graph.view (),
351+ raft::make_const_mdspan (dataset),
352+ reverse_list_ptr.data_handle (),
353+ reverse_batch,
354+ visited_size,
355+ metric,
356+ alpha,
357+ prune_smem_sort_size);
358+
359+ // Write new edge lists to graph
360+ write_graph_edges_kernel<accT, IdxT><<<num_blocks, blockD, 0 , stream>>> (
361+ d_graph.view (), reverse_list_ptr.data_handle (), degree, reverse_batch);
362+ }
343363
344364 start += step_size;
345365 step_size *= base;
0 commit comments