Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
046aaee
change mg to use device_resources_snmg_nccl
jinsolp May 2, 2025
93ed24b
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 2, 2025
37bd44d
cmake file
jinsolp May 2, 2025
48dca90
Merge branch 'mg-change-for-multi-gpu-resource' of https://github.com…
jinsolp May 2, 2025
7f83024
pin branch
jinsolp May 3, 2025
280adf1
fix for renamed functions
jinsolp May 5, 2025
3868e52
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 5, 2025
db12f9a
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 5, 2025
e85ca9a
revert
jinsolp May 7, 2025
0ea8579
Merge branch 'mg-change-for-multi-gpu-resource' of https://github.com…
jinsolp May 7, 2025
fb0f539
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 7, 2025
929f40c
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 7, 2025
1b80798
revert cmake file
jinsolp May 8, 2025
0e7f8a1
Merge branch 'mg-change-for-multi-gpu-resource' of https://github.com…
jinsolp May 8, 2025
a20ff57
typo
jinsolp May 8, 2025
07c2bf7
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 8, 2025
237f5e0
changes and pin branch
jinsolp May 12, 2025
182fcd9
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 12, 2025
3afd86a
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 13, 2025
ddf29bd
Merge branch 'branch-25.06' into mg-change-for-multi-gpu-resource
jinsolp May 14, 2025
026a55c
revert cmake file
jinsolp May 14, 2025
ba29d49
Merge branch 'mg-change-for-multi-gpu-resource' of https://github.com…
jinsolp May 14, 2025
3876058
empty commit to trigger CI
jinsolp May 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/cuvs/cuvs_mg_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "cuvs_ann_bench_utils.h"
#include "cuvs_cagra_wrapper.h"
#include <cuvs/neighbors/cagra.hpp>
#include <raft/core/device_resources_snmg.hpp>
#include <raft/core/device_resources_snmg_nccl.hpp>

namespace cuvs::bench {
using namespace cuvs::neighbors;
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/cuvs/cuvs_mg_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "cuvs_ann_bench_utils.h"
#include "cuvs_ivf_flat_wrapper.h"
#include <cuvs/neighbors/ivf_flat.hpp>
#include <raft/core/device_resources_snmg.hpp>
#include <raft/core/device_resources_snmg_nccl.hpp>

namespace cuvs::bench {
using namespace cuvs::neighbors;
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/cuvs/cuvs_mg_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "cuvs_ann_bench_utils.h"
#include "cuvs_ivf_pq_wrapper.h"
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/device_resources_snmg.hpp>
#include <raft/core/device_resources_snmg_nccl.hpp>

namespace cuvs::bench {
using namespace cuvs::neighbors;
Expand Down
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ function(find_and_configure_raft)
COMPONENTS ${RAFT_COMPONENTS}
CPM_ARGS
EXCLUDE_FROM_ALL TRUE
GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git
GIT_TAG ${PKG_PINNED_TAG}
GIT_REPOSITORY https://github.com/jinsolp/raft.git
GIT_TAG multi-gpu-resource
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

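This needs to be reverted: `GIT_REPOSITORY` and `GIT_TAG` are pinned here to the `jinsolp/raft` fork's `multi-gpu-resource` branch and should go back to `https://github.com/${PKG_FORK}/raft.git` and `${PKG_PINNED_TAG}`.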

SOURCE_SUBDIR cpp
OPTIONS
"BUILD_TESTS OFF"
Expand Down
19 changes: 10 additions & 9 deletions cpp/src/neighbors/mg/snmg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#pragma once

#include "../detail/knn_merge_parts.cuh"
#include <raft/core/resource/nccl_clique.hpp>
#include <raft/core/resource/multi_gpu.hpp>
#include <raft/core/resource/nccl_comm.hpp>
#include <raft/core/serialize.hpp>
#include <raft/linalg/add.cuh>
#include <raft/util/cuda_dev_essentials.cuh>
Expand Down Expand Up @@ -75,10 +76,10 @@ void deserialize(const raft::resources& clique,
index.mode_ = (cuvs::neighbors::distribution_mode)deserialize_scalar<int>(handle, is);
index.num_ranks_ = deserialize_scalar<int>(handle, is);

if (index.num_ranks_ != raft::resource::get_nccl_num_ranks(clique)) {
if (index.num_ranks_ != raft::resource::get_num_ranks(clique)) {
RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks",
index.num_ranks_,
raft::resource::get_nccl_num_ranks(clique));
raft::resource::get_num_ranks(clique));
}

for (int rank = 0; rank < index.num_ranks_; rank++) {
Expand Down Expand Up @@ -215,8 +216,8 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
auto& ann_if = index.ann_interfaces_[rank];

if (rank == raft::resource::get_nccl_clique_root_rank(clique)) { // root rank
uint64_t batch_offset = raft::resource::get_nccl_clique_root_rank(clique) * part_size;
if (rank == raft::resource::get_root_rank(clique)) { // root rank
uint64_t batch_offset = raft::resource::get_root_rank(clique) * part_size;
auto d_neighbors = raft::make_device_matrix_view<IdxT, int64_t, row_major>(
in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors);
auto d_distances = raft::make_device_matrix_view<float, int64_t, row_major>(
Expand All @@ -227,7 +228,7 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
// wait for other ranks
ncclGroupStart();
for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) {
if (from_rank == raft::resource::get_nccl_clique_root_rank(clique)) continue;
if (from_rank == raft::resource::get_root_rank(clique)) continue;

batch_offset = from_rank * part_size;
ncclRecv(in_neighbors.data_handle() + batch_offset,
Expand Down Expand Up @@ -258,13 +259,13 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
ncclSend(d_neighbors.data_handle(),
part_size * sizeof(IdxT),
ncclUint8,
raft::resource::get_nccl_clique_root_rank(clique),
raft::resource::get_root_rank(clique),
raft::resource::get_nccl_comm(dev_res),
raft::resource::get_cuda_stream(dev_res));
ncclSend(d_distances.data_handle(),
part_size * sizeof(float),
ncclUint8,
raft::resource::get_nccl_clique_root_rank(clique),
raft::resource::get_root_rank(clique),
raft::resource::get_nccl_comm(dev_res),
raft::resource::get_cuda_stream(dev_res));
ncclGroupEnd();
Expand Down Expand Up @@ -655,7 +656,7 @@ template <typename AnnIndexType, typename T, typename IdxT>
mg_index<AnnIndexType, T, IdxT>::mg_index(const raft::resources& clique, distribution_mode mode)
: mode_(mode), round_robin_counter_(std::make_shared<std::atomic<int64_t>>(0))
{
num_ranks_ = raft::resource::get_nccl_num_ranks(clique);
num_ranks_ = raft::resource::get_num_ranks(clique);
}

template <typename AnnIndexType, typename T, typename IdxT>
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/neighbors/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/ivf_flat.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/device_resources_snmg.hpp>
#include <raft/core/device_resources_snmg_nccl.hpp>

namespace cuvs::neighbors::mg {

Expand Down Expand Up @@ -645,7 +645,7 @@ class AnnMGTest : public ::testing::TestWithParam<AnnMGInputs> {
void TearDown() override {}

private:
raft::device_resources_snmg clique_;
raft::device_resources_snmg_nccl clique_;
AnnMGInputs ps;
std::vector<DataT> h_index_dataset;
std::vector<DataT> h_queries;
Expand Down
Loading