Skip to content

Commit 23c6dc0

Browse files
authored
ReachabilityPostProcess distance epilogue for NN Descent (#1073)
NN Descent changed to support distance epilogues. Currently supporting `ReachabilityPostProcess` and `identity_op`. A new distance epilogue will need new instantiations. Mutual reachability computation will eventually be hidden behind the `all_neighbors` API. Basic tests are still added in this PR to ensure the correctness of this feature. Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #1073
1 parent 7462cbd commit 23c6dc0

6 files changed

Lines changed: 238 additions & 27 deletions

File tree

cpp/src/neighbors/detail/nn_descent.cuh

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ __device__ __forceinline__ void remove_duplicates(
488488
// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
489489
// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM
490490
// is 1024 and 1536 respectively, which means the bounds don't work anymore
491-
template <typename Index_t, typename ID_t = InternalID_t<Index_t>>
491+
template <typename Index_t, typename ID_t = InternalID_t<Index_t>, typename DistEpilogue_t>
492492
RAFT_KERNEL
493493
#ifdef __CUDA_ARCH__
494494
// Use minBlocksPerMultiprocessor = 4 on specific arches
@@ -513,7 +513,8 @@ __launch_bounds__(BLOCK_SIZE)
513513
int graph_width,
514514
int* locks,
515515
DistData_t* l2_norms,
516-
cuvs::distance::DistanceType metric)
516+
cuvs::distance::DistanceType metric,
517+
DistEpilogue_t dist_epilogue)
517518
{
518519
#if (__CUDA_ARCH__ >= 700)
519520
using namespace nvcuda;
@@ -623,20 +624,22 @@ __launch_bounds__(BLOCK_SIZE)
623624
__syncthreads();
624625

625626
for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
626-
if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size &&
627-
i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
627+
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
628+
int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
629+
630+
if (row_id < list_new_size && col_id < list_new_size) {
628631
if (metric == cuvs::distance::DistanceType::InnerProduct) {
629632
s_distances[i] = -s_distances[i];
630633
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
631634
s_distances[i] = 1.0 - s_distances[i];
632635
} else { // L2Expanded or L2SqrtExpanded
633-
s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
634-
l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
635-
2.0 * s_distances[i];
636+
s_distances[i] =
637+
l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
636638
// for fp32 vs fp16 precision differences resulting in negative distances when distance
637639
// should be 0 related issue: https://github.com/rapidsai/cuvs/issues/991
638640
s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
639641
}
642+
s_distances[i] = dist_epilogue(s_distances[i], new_neighbors[row_id], new_neighbors[col_id]);
640643
} else {
641644
s_distances[i] = std::numeric_limits<float>::max();
642645
}
@@ -707,20 +710,21 @@ __launch_bounds__(BLOCK_SIZE)
707710
__syncthreads();
708711

709712
for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
710-
if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size &&
711-
i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
713+
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
714+
int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
715+
if (row_id < list_old_size && col_id < list_new_size) {
712716
if (metric == cuvs::distance::DistanceType::InnerProduct) {
713717
s_distances[i] = -s_distances[i];
714718
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
715719
s_distances[i] = 1.0 - s_distances[i];
716720
} else { // L2Expanded or L2SqrtExpanded
717-
s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
718-
l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
719-
2.0 * s_distances[i];
721+
s_distances[i] =
722+
l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
720723
// for fp32 vs fp16 precision differences resulting in negative distances when distance
721724
// should be 0 related issue: https://github.com/rapidsai/cuvs/issues/991
722725
s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
723726
}
727+
s_distances[i] = dist_epilogue(s_distances[i], old_neighbors[row_id], new_neighbors[col_id]);
724728
} else {
725729
s_distances[i] = std::numeric_limits<float>::max();
726730
}
@@ -1034,7 +1038,8 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
10341038
}
10351039

10361040
template <typename Data_t, typename Index_t>
1037-
void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
1041+
template <typename DistEpilogue_t>
1042+
void GNND<Data_t, Index_t>::local_join(cudaStream_t stream, DistEpilogue_t dist_epilogue)
10381043
{
10391044
raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits<float>::max());
10401045
local_join_kernel<<<nrow_, BLOCK_SIZE, 0, stream>>>(graph_.h_graph_new.data_handle(),
@@ -1051,15 +1056,18 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
10511056
DEGREE_ON_DEVICE,
10521057
d_locks_.data_handle(),
10531058
l2_norms_.data_handle(),
1054-
build_config_.metric);
1059+
build_config_.metric,
1060+
dist_epilogue);
10551061
}
10561062

10571063
template <typename Data_t, typename Index_t>
1064+
template <typename DistEpilogue_t>
10581065
void GNND<Data_t, Index_t>::build(Data_t* data,
10591066
const Index_t nrow,
10601067
Index_t* output_graph,
10611068
bool return_distances,
1062-
DistData_t* output_distances)
1069+
DistData_t* output_distances,
1070+
DistEpilogue_t dist_epilogue)
10631071
{
10641072
using input_t = typename std::remove_const<Data_t>::type;
10651073

@@ -1154,7 +1162,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
11541162
raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future());
11551163
11561164
if (wmma_range.contains(runtime_arch)) {
1157-
local_join(stream);
1165+
local_join(stream, dist_epilogue);
11581166
} else {
11591167
THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700");
11601168
}

cpp/src/neighbors/detail/nn_descent_gnnd.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,13 @@ class GNND {
207207
GNND(const GNND&) = delete;
208208
GNND& operator=(const GNND&) = delete;
209209

210+
template <typename DistEpilogue_t = raft::identity_op>
210211
void build(Data_t* data,
211212
const Index_t nrow,
212213
Index_t* output_graph,
213214
bool return_distances,
214-
DistData_t* output_distances);
215+
DistData_t* output_distances,
216+
DistEpilogue_t dist_epilogue = DistEpilogue_t{});
215217
~GNND() = default;
216218
using ID_t = InternalID_t<Index_t>;
217219
void reset(raft::resources const& res);
@@ -222,7 +224,9 @@ class GNND {
222224
Index_t* d_rev_graph_ptr,
223225
int2* list_sizes,
224226
cudaStream_t stream = 0);
225-
void local_join(cudaStream_t stream = 0);
227+
228+
template <typename DistEpilogue_t = raft::identity_op>
229+
void local_join(cudaStream_t stream = 0, DistEpilogue_t dist_epilogue = DistEpilogue_t{});
226230

227231
raft::resources const& res;
228232

cpp/src/neighbors/detail/reachability.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -131,6 +131,7 @@ struct ReachabilityPostProcess {
131131

132132
const value_t* core_dists;
133133
value_t alpha;
134+
size_t n; // size of core_dists array
134135
};
135136

136137
/**
@@ -163,7 +164,7 @@ void mutual_reachability_knn_l2(const raft::resources& handle,
163164
// `A type local to a function cannot be used in the template argument of the
164165
// enclosing parent function (and any parent classes) of an extended __device__
165166
// or __host__ __device__ lambda`
166-
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha};
167+
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha, m};
167168

168169
cuvs::neighbors::detail::
169170
tiled_brute_force_knn<value_t, value_idx, value_t, ReachabilityPostProcess<value_idx, value_t>>(

cpp/src/neighbors/nn_descent_float.cu

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
1414
* limitations under the License.
1515
*/
1616

17+
#include "./detail/nn_descent_gnnd.hpp"
18+
#include "./detail/reachability.cuh"
1719
#include "nn_descent.cuh"
1820
#include <cuvs/neighbors/nn_descent.hpp>
1921

@@ -54,7 +56,30 @@ namespace cuvs::neighbors::nn_descent {
5456
return idx; \
5557
} \
5658
}; \
57-
template class detail::GNND<const T, int>;
59+
template class detail::GNND<const T, int>; \
60+
\
61+
template void detail::GNND<const T, int>::build< \
62+
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T>>( \
63+
const T* data, \
64+
const int nrow, \
65+
int* output_graph, \
66+
bool return_distances, \
67+
float* output_distances, \
68+
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T> dist_epilogue); \
69+
template void detail::GNND<const T, int>::local_join< \
70+
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T>>( \
71+
cudaStream_t stream, \
72+
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T> dist_epilogue); \
73+
\
74+
template void detail::GNND<const T, int>::build<raft::identity_op>( \
75+
const T* data, \
76+
const int nrow, \
77+
int* output_graph, \
78+
bool return_distances, \
79+
float* output_distances, \
80+
raft::identity_op dist_epilogue); \
81+
template void detail::GNND<const T, int>::local_join<raft::identity_op>( \
82+
cudaStream_t stream, raft::identity_op dist_epilogue);
5883

5984
CUVS_INST_NN_DESCENT_BUILD(float, uint32_t);
6085

0 commit comments

Comments
 (0)