Skip to content

Commit 36266c8

Browse files
authored
Merge pull request rapidsai#787 from rapidsai/branch-25.04
Forward-merge branch-25.04 into branch-25.06
2 parents e4b5808 + ff86827 commit 36266c8

3 files changed

Lines changed: 57 additions & 14 deletions

File tree

cpp/src/neighbors/cagra.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ void build_knn_graph(
172172
* // build KNN graph not using `cagra::build_knn_graph`
173173
* // build(knn_graph, dataset, ...);
174174
* // sort graph index
175-
* sort_knn_graph(res, dataset.view(), knn_graph.view());
175+
* sort_knn_graph(res, build_params.metric, dataset.view(), knn_graph.view());
176176
* // optimize graph
177177
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
178178
* // Construct an index from dataset and optimized knn_graph
@@ -184,6 +184,7 @@ void build_knn_graph(
184184
* @tparam IdxT type of the dataset vector indices
185185
*
186186
* @param[in] res raft resources
187+
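* @param[in] metric distance metric used to compute the neighbor distances the graph is sorted by (InnerProduct, CosineExpanded, or L2Expanded)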
187188
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
188189
* @param[in,out] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
189190
* knn_graph_degree]
@@ -197,6 +198,7 @@ template <
197198
raft::host_device_accessor<std::experimental::default_accessor<IdxT>, raft::memory_type::host>>
198199
void sort_knn_graph(
199200
raft::resources const& res,
201+
cuvs::distance::DistanceType metric,
200202
raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, d_accessor> dataset,
201203
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, g_accessor> knn_graph)
202204
{
@@ -215,7 +217,7 @@ void sort_knn_graph(
215217
raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, d_accessor>(
216218
dataset.data_handle(), dataset.extent(0), dataset.extent(1));
217219

218-
cagra::detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal);
220+
cagra::detail::graph::sort_knn_graph(res, metric, dataset_internal, knn_graph_internal);
219221
}
220222

221223
/**

cpp/src/neighbors/detail/cagra/cagra_build.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ void build_knn_graph(
368368
nn_descent_idx.graph().extent(0),
369369
nn_descent_idx.graph().extent(1));
370370

371-
cuvs::neighbors::cagra::detail::graph::sort_knn_graph(res, dataset, knn_graph_internal);
371+
cuvs::neighbors::cagra::detail::graph::sort_knn_graph(
372+
res, build_params.metric, dataset, knn_graph_internal);
372373
}
373374

374375
template <

cpp/src/neighbors/detail/cagra/graph_core.cuh

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
7878
const uint32_t dataset_dim,
7979
IdxT* const knn_graph, // [graph_chunk_size, graph_degree]
8080
const uint32_t graph_size,
81-
const uint32_t graph_degree)
81+
const uint32_t graph_degree,
82+
const cuvs::distance::DistanceType metric)
8283
{
8384
const IdxT srcNode = (blockDim.x * blockIdx.x + threadIdx.x) / raft::WarpSize;
8485
if (srcNode >= graph_size) { return; }
@@ -91,19 +92,46 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
9192
// Compute distance from a src node to its neighbors
9293
for (int k = 0; k < graph_degree; k++) {
9394
const IdxT dstNode = knn_graph[k + static_cast<uint64_t>(graph_degree) * srcNode];
94-
float dist = 0.0;
95-
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
96-
float diff = cuvs::spatial::knn::detail::utils::mapping<float>{}(
97-
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) -
98-
cuvs::spatial::knn::detail::utils::mapping<float>{}(
99-
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
100-
dist += diff * diff;
95+
float dist = 0;
96+
float norm2_dst = 0;
97+
if (metric == cuvs::distance::DistanceType::InnerProduct ||
98+
metric == cuvs::distance::DistanceType::CosineExpanded) {
99+
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
100+
auto elem_b = cuvs::spatial::knn::detail::utils::mapping<float>{}(
101+
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
102+
dist -= cuvs::spatial::knn::detail::utils::mapping<float>{}(
103+
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) *
104+
elem_b;
105+
106+
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
107+
norm2_dst += elem_b * elem_b;
108+
}
109+
}
110+
} else {
111+
// L2Expanded: accumulate the squared differences between the source and neighbor vectors
112+
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
113+
float diff = cuvs::spatial::knn::detail::utils::mapping<float>{}(
114+
dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) -
115+
cuvs::spatial::knn::detail::utils::mapping<float>{}(
116+
dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
117+
dist += diff * diff;
118+
}
101119
}
102120
dist += __shfl_xor_sync(0xffffffff, dist, 1);
103121
dist += __shfl_xor_sync(0xffffffff, dist, 2);
104122
dist += __shfl_xor_sync(0xffffffff, dist, 4);
105123
dist += __shfl_xor_sync(0xffffffff, dist, 8);
106124
dist += __shfl_xor_sync(0xffffffff, dist, 16);
125+
126+
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
127+
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 1);
128+
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 2);
129+
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 4);
130+
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 8);
131+
norm2_dst += __shfl_xor_sync(0xffffffff, norm2_dst, 16);
132+
if (lane_id == (k % raft::WarpSize)) { dist /= sqrt(norm2_dst); }
133+
}
134+
107135
if (lane_id == (k % raft::WarpSize)) {
108136
my_keys[k / raft::WarpSize] = dist;
109137
my_vals[k / raft::WarpSize] = dstNode;
@@ -471,11 +499,17 @@ template <
471499
raft::host_device_accessor<std::experimental::default_accessor<IdxT>, raft::memory_type::host>>
472500
void sort_knn_graph(
473501
raft::resources const& res,
502+
const cuvs::distance::DistanceType metric,
474503
raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, d_accessor> dataset,
475504
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, g_accessor> knn_graph)
476505
{
477506
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
478507
"dataset size is expected to have the same number of graph index size");
508+
RAFT_EXPECTS(
509+
metric == cuvs::distance::DistanceType::InnerProduct ||
510+
metric == cuvs::distance::DistanceType::CosineExpanded ||
511+
metric == cuvs::distance::DistanceType::L2Expanded,
512+
"Unsupported metric. Only InnerProduct, CosineExpanded, and L2Expanded are supported");
479513
const uint64_t dataset_size = dataset.extent(0);
480514
const uint64_t dataset_dim = dataset.extent(1);
481515
const DataT* dataset_ptr = dataset.data_handle();
@@ -507,8 +541,13 @@ void sort_knn_graph(
507541
graph_size * input_graph_degree,
508542
raft::resource::get_cuda_stream(res));
509543

510-
void (*kernel_sort)(
511-
const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t);
544+
void (*kernel_sort)(const DataT* const,
545+
const IdxT,
546+
const uint32_t,
547+
IdxT* const,
548+
const uint32_t,
549+
const uint32_t,
550+
const cuvs::distance::DistanceType);
512551
if (input_graph_degree <= 32) {
513552
constexpr int numElementsPerThread = 1;
514553
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
@@ -545,7 +584,8 @@ void sort_knn_graph(
545584
dataset_dim,
546585
d_input_graph.data_handle(),
547586
graph_size,
548-
input_graph_degree);
587+
input_graph_degree,
588+
metric);
549589
raft::resource::sync_stream(res);
550590
RAFT_LOG_DEBUG(".");
551591
raft::copy(input_graph_ptr,

0 commit comments

Comments
 (0)