Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions cpp/src/neighbors/mg/snmg.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -210,6 +210,29 @@ void sharded_search_with_direct_merge(
auto query_partition = raft::make_host_matrix_view<const T, int64_t, row_major>(
queries.data_handle() + query_offset, n_rows_of_current_batch, n_cols);

if (index.num_ranks_ == 1) {
const raft::resources& dev_res = raft::resource::set_current_device_to_root_rank(clique);
Comment thread
viclafargue marked this conversation as resolved.
Outdated
auto d_neighbors = raft::make_device_matrix<searchIdxT, int64_t, row_major>(
dev_res, n_rows_of_current_batch, n_neighbors);
auto d_distances = raft::make_device_matrix<float, int64_t, row_major>(
dev_res, n_rows_of_current_batch, n_neighbors);

auto& ann_if = index.ann_interfaces_[0];
cuvs::neighbors::search(
dev_res, ann_if, search_params, query_partition, d_neighbors.view(), d_distances.view());

raft::copy(neighbors.data_handle() + output_offset,
d_neighbors.data_handle(),
part_size,
raft::resource::get_cuda_stream(dev_res));
raft::copy(distances.data_handle() + output_offset,
d_distances.data_handle(),
part_size,
raft::resource::get_cuda_stream(dev_res));
resource::sync_stream(dev_res);
continue;
}

const int& requirements = index.num_ranks_;
check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang
#pragma omp parallel for num_threads(index.num_ranks_)
Expand Down Expand Up @@ -329,6 +352,30 @@ void sharded_search_with_tree_merge(
int64_t n_rows_of_current_batch = std::min((int64_t)n_rows_per_batch, n_rows - offset);
auto query_partition = raft::make_host_matrix_view<const T, int64_t, row_major>(
queries.data_handle() + query_offset, n_rows_of_current_batch, n_cols);
int64_t part_size = n_rows_of_current_batch * n_neighbors;

if (index.num_ranks_ == 1) {
const raft::resources& dev_res = raft::resource::set_current_device_to_root_rank(clique);
auto d_neighbors = raft::make_device_matrix<searchIdxT, int64_t, row_major>(
dev_res, n_rows_of_current_batch, n_neighbors);
auto d_distances = raft::make_device_matrix<float, int64_t, row_major>(
dev_res, n_rows_of_current_batch, n_neighbors);

auto& ann_if = index.ann_interfaces_[0];
cuvs::neighbors::search(
dev_res, ann_if, search_params, query_partition, d_neighbors.view(), d_distances.view());

raft::copy(neighbors.data_handle() + output_offset,
d_neighbors.data_handle(),
part_size,
raft::resource::get_cuda_stream(dev_res));
raft::copy(distances.data_handle() + output_offset,
d_distances.data_handle(),
part_size,
raft::resource::get_cuda_stream(dev_res));
resource::sync_stream(dev_res);
continue;
}

const int& requirements = index.num_ranks_;
check_omp_threads(requirements); // should use at least num_ranks_ threads to avoid NCCL hang
Expand All @@ -337,8 +384,6 @@ void sharded_search_with_tree_merge(
const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
auto& ann_if = index.ann_interfaces_[rank];

int64_t part_size = n_rows_of_current_batch * n_neighbors;

auto tmp_neighbors = raft::make_device_matrix<searchIdxT, int64_t, row_major>(
dev_res, 2 * n_rows_of_current_batch, n_neighbors);
auto tmp_distances = raft::make_device_matrix<float, int64_t, row_major>(
Expand Down