Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions cpp/src/neighbors/mg/snmg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cuvs/neighbors/knn_merge_parts.hpp>

#include <fstream>
#include <omp.h>

namespace cuvs::neighbors {
using namespace raft;
Expand Down Expand Up @@ -92,13 +93,25 @@ void build(const raft::resources& clique,
RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows);

index.ann_interfaces_.resize(index.num_ranks_);
#pragma omp parallel for

// Enable nested parallelism
int saved_omp_threads = omp_get_max_threads();
int threads_per_rank = std::max(1, saved_omp_threads / index.num_ranks_);
omp_set_nested(1);
Comment thread
viclafargue marked this conversation as resolved.
Outdated

#pragma omp parallel for num_threads(index.num_ranks_)
Comment thread
viclafargue marked this conversation as resolved.
for (int rank = 0; rank < index.num_ranks_; rank++) {
// Set thread limit for this rank's nested OpenMP regions
omp_set_num_threads(threads_per_rank);

const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
auto& ann_if = index.ann_interfaces_[rank];
cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset);
resource::sync_stream(dev_res);
}

// Restore original thread count
omp_set_num_threads(saved_omp_threads);
} else if (index.mode_ == SHARDED) {
int64_t n_rows = index_dataset.extent(0);
int64_t n_cols = index_dataset.extent(1);
Expand All @@ -107,8 +120,17 @@ void build(const raft::resources& clique,
RAFT_LOG_DEBUG("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard);

index.ann_interfaces_.resize(index.num_ranks_);
#pragma omp parallel for

// Enable nested parallelism
int saved_omp_threads = omp_get_max_threads();
int threads_per_rank = std::max(1, saved_omp_threads / index.num_ranks_);
omp_set_nested(1);

#pragma omp parallel for num_threads(index.num_ranks_)
for (int rank = 0; rank < index.num_ranks_; rank++) {
// Set thread limit for this rank's nested OpenMP regions
omp_set_num_threads(threads_per_rank);

const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
int64_t offset = rank * n_rows_per_shard;
int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset);
Expand All @@ -119,6 +141,9 @@ void build(const raft::resources& clique,
cuvs::neighbors::build(dev_res, ann_if, index_params, partition);
Comment thread
viclafargue marked this conversation as resolved.
resource::sync_stream(dev_res);
}

Comment thread
viclafargue marked this conversation as resolved.
// Restore original thread count
omp_set_num_threads(saved_omp_threads);
Comment thread
viclafargue marked this conversation as resolved.
Outdated
}
}

Expand Down
Loading