Skip to content

Commit e19590d

Browse files
committed
add support multi-stream for CompositeIndex::search
1 parent 9a69747 commit e19590d

1 file changed

Lines changed: 40 additions & 9 deletions

File tree

cpp/src/neighbors/composite/index.cu

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <cuvs/neighbors/composite/index.hpp>
1919
#include <cuvs/selection/select_k.hpp>
2020
#include <raft/core/resource/cuda_stream.hpp>
21+
#include <raft/core/resource/cuda_stream_pool.hpp>
2122
#include <raft/linalg/add.cuh>
2223
#include <rmm/device_uvector.hpp>
2324

@@ -47,29 +48,59 @@ void CompositeIndex<T, IdxT, OutputIdxT>::search(
4748
size_t num_indices = children_.size();
4849
size_t buffer_size = num_queries * K * num_indices;
4950

50-
auto stream = raft::resource::get_cuda_stream(handle);
51-
auto tmp_res = raft::resource::get_workspace_resource(handle);
51+
auto main_stream = raft::resource::get_cuda_stream(handle);
52+
auto tmp_res = raft::resource::get_workspace_resource(handle);
5253

53-
rmm::device_uvector<out_index_type> neighbors_buffer(buffer_size, stream, tmp_res);
54-
rmm::device_uvector<float> distances_buffer(buffer_size, stream, tmp_res);
54+
rmm::device_uvector<out_index_type> neighbors_buffer(buffer_size, main_stream, tmp_res);
55+
rmm::device_uvector<float> distances_buffer(buffer_size, main_stream, tmp_res);
56+
57+
std::vector<rmm::device_uvector<out_index_type>> temp_neighbors;
58+
std::vector<rmm::device_uvector<float>> temp_distances;
59+
60+
for (size_t i = 0; i < num_indices; i++) {
61+
temp_neighbors.emplace_back(num_queries * K, main_stream, tmp_res);
62+
temp_distances.emplace_back(num_queries * K, main_stream, tmp_res);
63+
}
64+
65+
raft::resource::wait_stream_pool_on_stream(handle);
5566

5667
out_index_type offset = 0;
5768
out_index_type stride = K * num_indices;
5869

5970
for (size_t i = 0; i < num_indices; i++) {
6071
const auto& sub_index = children_[i];
61-
sub_index->search(handle, params, queries, neighbors, distances, filter);
72+
73+
auto stream = raft::resource::get_next_usable_stream(handle, i);
74+
75+
raft::resources stream_pool_handle(handle);
76+
raft::resource::set_cuda_stream(stream_pool_handle, stream);
77+
78+
auto temp_neighbors_view =
79+
raft::make_device_matrix_view<out_index_type, matrix_index_type, raft::row_major>(
80+
temp_neighbors[i].data(), num_queries, K);
81+
auto temp_distances_view =
82+
raft::make_device_matrix_view<float, matrix_index_type, raft::row_major>(
83+
temp_distances[i].data(), num_queries, K);
84+
85+
sub_index->search(
86+
stream_pool_handle, params, queries, temp_neighbors_view, temp_distances_view, filter);
87+
6288
if (offset != 0) {
63-
raft::linalg::addScalar(
64-
neighbors.data_handle(), neighbors.data_handle(), offset, neighbors.size(), stream);
89+
raft::linalg::addScalar(temp_neighbors[i].data(),
90+
temp_neighbors[i].data(),
91+
offset,
92+
temp_neighbors[i].size(),
93+
stream);
6594
}
6695

6796
raft::copy_matrix(
68-
neighbors_buffer.data() + i * K, stride, neighbors.data_handle(), K, K, num_queries, stream);
97+
neighbors_buffer.data() + i * K, stride, temp_neighbors[i].data(), K, K, num_queries, stream);
6998
raft::copy_matrix(
70-
distances_buffer.data() + i * K, stride, distances.data_handle(), K, K, num_queries, stream);
99+
distances_buffer.data() + i * K, stride, temp_distances[i].data(), K, K, num_queries, stream);
100+
71101
offset += sub_index->size();
72102
}
103+
raft::resource::sync_stream_pool(handle);
73104

74105
auto distances_view = raft::make_device_matrix_view<const float, matrix_index_type>(
75106
distances_buffer.data(), num_queries, K * num_indices);

0 commit comments

Comments
 (0)