|
18 | 18 | #include <cuvs/neighbors/composite/index.hpp> |
19 | 19 | #include <cuvs/selection/select_k.hpp> |
20 | 20 | #include <raft/core/resource/cuda_stream.hpp> |
| 21 | +#include <raft/core/resource/cuda_stream_pool.hpp> |
21 | 22 | #include <raft/linalg/add.cuh> |
22 | 23 | #include <rmm/device_uvector.hpp> |
23 | 24 |
|
@@ -47,29 +48,59 @@ void CompositeIndex<T, IdxT, OutputIdxT>::search( |
47 | 48 | size_t num_indices = children_.size(); |
48 | 49 | size_t buffer_size = num_queries * K * num_indices; |
49 | 50 |
|
50 | | - auto stream = raft::resource::get_cuda_stream(handle); |
51 | | - auto tmp_res = raft::resource::get_workspace_resource(handle); |
| 51 | + auto main_stream = raft::resource::get_cuda_stream(handle); |
| 52 | + auto tmp_res = raft::resource::get_workspace_resource(handle); |
52 | 53 |
|
53 | | - rmm::device_uvector<out_index_type> neighbors_buffer(buffer_size, stream, tmp_res); |
54 | | - rmm::device_uvector<float> distances_buffer(buffer_size, stream, tmp_res); |
| 54 | + rmm::device_uvector<out_index_type> neighbors_buffer(buffer_size, main_stream, tmp_res); |
| 55 | + rmm::device_uvector<float> distances_buffer(buffer_size, main_stream, tmp_res); |
| 56 | + |
| 57 | + std::vector<rmm::device_uvector<out_index_type>> temp_neighbors; |
| 58 | + std::vector<rmm::device_uvector<float>> temp_distances; |
| 59 | + |
| 60 | + for (size_t i = 0; i < num_indices; i++) { |
| 61 | + temp_neighbors.emplace_back(num_queries * K, main_stream, tmp_res); |
| 62 | + temp_distances.emplace_back(num_queries * K, main_stream, tmp_res); |
| 63 | + } |
| 64 | + |
| 65 | + raft::resource::wait_stream_pool_on_stream(handle); |
55 | 66 |
|
56 | 67 | out_index_type offset = 0; |
57 | 68 | out_index_type stride = K * num_indices; |
58 | 69 |
|
59 | 70 | for (size_t i = 0; i < num_indices; i++) { |
60 | 71 | const auto& sub_index = children_[i]; |
61 | | - sub_index->search(handle, params, queries, neighbors, distances, filter); |
| 72 | + |
| 73 | + auto stream = raft::resource::get_next_usable_stream(handle, i); |
| 74 | + |
| 75 | + raft::resources stream_pool_handle(handle); |
| 76 | + raft::resource::set_cuda_stream(stream_pool_handle, stream); |
| 77 | + |
| 78 | + auto temp_neighbors_view = |
| 79 | + raft::make_device_matrix_view<out_index_type, matrix_index_type, raft::row_major>( |
| 80 | + temp_neighbors[i].data(), num_queries, K); |
| 81 | + auto temp_distances_view = |
| 82 | + raft::make_device_matrix_view<float, matrix_index_type, raft::row_major>( |
| 83 | + temp_distances[i].data(), num_queries, K); |
| 84 | + |
| 85 | + sub_index->search( |
| 86 | + stream_pool_handle, params, queries, temp_neighbors_view, temp_distances_view, filter); |
| 87 | + |
62 | 88 | if (offset != 0) { |
63 | | - raft::linalg::addScalar( |
64 | | - neighbors.data_handle(), neighbors.data_handle(), offset, neighbors.size(), stream); |
| 89 | + raft::linalg::addScalar(temp_neighbors[i].data(), |
| 90 | + temp_neighbors[i].data(), |
| 91 | + offset, |
| 92 | + temp_neighbors[i].size(), |
| 93 | + stream); |
65 | 94 | } |
66 | 95 |
|
67 | 96 | raft::copy_matrix( |
68 | | - neighbors_buffer.data() + i * K, stride, neighbors.data_handle(), K, K, num_queries, stream); |
| 97 | + neighbors_buffer.data() + i * K, stride, temp_neighbors[i].data(), K, K, num_queries, stream); |
69 | 98 | raft::copy_matrix( |
70 | | - distances_buffer.data() + i * K, stride, distances.data_handle(), K, K, num_queries, stream); |
| 99 | + distances_buffer.data() + i * K, stride, temp_distances[i].data(), K, K, num_queries, stream); |
| 100 | + |
71 | 101 | offset += sub_index->size(); |
72 | 102 | } |
| 103 | + raft::resource::sync_stream_pool(handle); |
73 | 104 |
|
74 | 105 | auto distances_view = raft::make_device_matrix_view<const float, matrix_index_type>( |
75 | 106 | distances_buffer.data(), num_queries, K * num_indices); |
|
0 commit comments