Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/src/core/omp_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ void set_nested(int v)
if constexpr (is_omp_enabled()) { omp_set_nested(v); }
}

void set_num_threads(int v)
{
(void)v;
Comment thread
viclafargue marked this conversation as resolved.
if constexpr (is_omp_enabled()) { omp_set_num_threads(v); }
}

void check_threads(const int requirements)
{
const int max_threads = get_max_threads();
Expand Down
1 change: 1 addition & 0 deletions cpp/src/core/omp_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ int get_num_threads();
int get_thread_num();

void set_nested(int v);
void set_num_threads(int v);

void check_threads(const int requirements);

Expand Down
34 changes: 32 additions & 2 deletions cpp/src/neighbors/mg/snmg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,28 @@ void build(const raft::resources& clique,
RAFT_LOG_DEBUG("REPLICATED BUILD: %d*%drows", index.num_ranks_, n_rows);

index.ann_interfaces_.resize(index.num_ranks_);
#pragma omp parallel for

// Enable nested parallelism
int saved_omp_threads = cuvs::core::omp::get_max_threads();
int threads_per_rank = std::max(1, saved_omp_threads / index.num_ranks_);
cuvs::core::omp::set_nested(1);

const int& requirements = index.num_ranks_;
cuvs::core::omp::check_threads(requirements);

#pragma omp parallel for num_threads(index.num_ranks_)
Comment thread
viclafargue marked this conversation as resolved.
for (int rank = 0; rank < index.num_ranks_; rank++) {
// Set thread limit for this rank's nested OpenMP regions
cuvs::core::omp::set_num_threads(threads_per_rank);
Comment thread
viclafargue marked this conversation as resolved.

const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
auto& ann_if = index.ann_interfaces_[rank];
cuvs::neighbors::build(dev_res, ann_if, index_params, index_dataset);
resource::sync_stream(dev_res);
}

// Restore original thread count
cuvs::core::omp::set_num_threads(saved_omp_threads);
Copy link
Copy Markdown
Contributor

@tarang-jain tarang-jain Nov 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this restoration is a bit confusing. the pool of threads is set inside the for loop above, but its restored after the for loop? Can you verify if this logic is correct?

Copy link
Copy Markdown
Contributor Author

@viclafargue viclafargue Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. This is actually not necessary. My worry was that the main thread could be one of the threads in the loop. But, after checking it appears like OpenMP is designed in such a way that omp_set_num_threads calls inside of a parallel region only affect how many threads each thread can use in nested parallel region. The main thread is unaffected, and its number of threads is actually preserved. I made sure that each thread sees its number of thread restored back to the global maximal number of threads within the parallel region.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added the same OpenMP usage for the extend function.

} else if (index.mode_ == SHARDED) {
int64_t n_rows = index_dataset.extent(0);
int64_t n_cols = index_dataset.extent(1);
Expand All @@ -107,8 +122,20 @@ void build(const raft::resources& clique,
RAFT_LOG_DEBUG("SHARDED BUILD: %d*%drows", index.num_ranks_, n_rows_per_shard);

index.ann_interfaces_.resize(index.num_ranks_);
#pragma omp parallel for

// Enable nested parallelism
int saved_omp_threads = cuvs::core::omp::get_max_threads();
int threads_per_rank = std::max(1, saved_omp_threads / index.num_ranks_);
cuvs::core::omp::set_nested(1);

const int& requirements = index.num_ranks_;
Comment thread
viclafargue marked this conversation as resolved.
Outdated
cuvs::core::omp::check_threads(requirements);

#pragma omp parallel for num_threads(index.num_ranks_)
for (int rank = 0; rank < index.num_ranks_; rank++) {
// Set thread limit for this rank's nested OpenMP regions
cuvs::core::omp::set_num_threads(threads_per_rank);

const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
int64_t offset = rank * n_rows_per_shard;
int64_t n_rows_of_current_shard = std::min(n_rows_per_shard, n_rows - offset);
Expand All @@ -119,6 +146,9 @@ void build(const raft::resources& clique,
cuvs::neighbors::build(dev_res, ann_if, index_params, partition);
Comment thread
viclafargue marked this conversation as resolved.
resource::sync_stream(dev_res);
}

Comment thread
viclafargue marked this conversation as resolved.
// Restore original thread count
cuvs::core::omp::set_num_threads(saved_omp_threads);
}
}

Expand Down
Loading