Skip to content
40 changes: 24 additions & 16 deletions cpp/src/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ __device__ __forceinline__ void remove_duplicates(
// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM
// is 1024 and 1536 respectively, which means the bounds don't work anymore
template <typename Index_t, typename ID_t = InternalID_t<Index_t>>
template <typename Index_t, typename ID_t = InternalID_t<Index_t>, typename DistEpilogue_t>
RAFT_KERNEL
#ifdef __CUDA_ARCH__
// Use minBlocksPerMultiprocessor = 4 on specific arches
Expand All @@ -513,7 +513,8 @@ __launch_bounds__(BLOCK_SIZE)
int graph_width,
int* locks,
DistData_t* l2_norms,
cuvs::distance::DistanceType metric)
cuvs::distance::DistanceType metric,
DistEpilogue_t dist_epilogue)
{
#if (__CUDA_ARCH__ >= 700)
using namespace nvcuda;
Expand Down Expand Up @@ -623,20 +624,22 @@ __launch_bounds__(BLOCK_SIZE)
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size &&
i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;

if (row_id < list_new_size && col_id < list_new_size) {
if (metric == cuvs::distance::DistanceType::InnerProduct) {
s_distances[i] = -s_distances[i];
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
s_distances[i] = 1.0 - s_distances[i];
} else { // L2Expanded or L2SqrtExpanded
s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
2.0 * s_distances[i];
s_distances[i] =
l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
// for fp32 vs fp16 precision differences resulting in negative distances when distance
// should be 0 related issue: https://github.com/rapidsai/cuvs/issues/991
s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
}
s_distances[i] = dist_epilogue(s_distances[i], new_neighbors[row_id], new_neighbors[col_id]);
} else {
s_distances[i] = std::numeric_limits<float>::max();
}
Expand Down Expand Up @@ -707,20 +710,21 @@ __launch_bounds__(BLOCK_SIZE)
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size &&
i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
if (row_id < list_old_size && col_id < list_new_size) {
if (metric == cuvs::distance::DistanceType::InnerProduct) {
s_distances[i] = -s_distances[i];
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
s_distances[i] = 1.0 - s_distances[i];
} else { // L2Expanded or L2SqrtExpanded
s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
2.0 * s_distances[i];
s_distances[i] =
l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
// for fp32 vs fp16 precision differences resulting in negative distances when distance
// should be 0 related issue: https://github.com/rapidsai/cuvs/issues/991
s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
}
s_distances[i] = dist_epilogue(s_distances[i], old_neighbors[row_id], new_neighbors[col_id]);
} else {
s_distances[i] = std::numeric_limits<float>::max();
}
Expand Down Expand Up @@ -1034,7 +1038,8 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
}

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
template <typename DistEpilogue_t>
void GNND<Data_t, Index_t>::local_join(cudaStream_t stream, DistEpilogue_t dist_epilogue)
{
raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits<float>::max());
local_join_kernel<<<nrow_, BLOCK_SIZE, 0, stream>>>(graph_.h_graph_new.data_handle(),
Expand All @@ -1051,15 +1056,18 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
DEGREE_ON_DEVICE,
d_locks_.data_handle(),
l2_norms_.data_handle(),
build_config_.metric);
build_config_.metric,
dist_epilogue);
}

template <typename Data_t, typename Index_t>
template <typename DistEpilogue_t>
void GNND<Data_t, Index_t>::build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances)
DistData_t* output_distances,
DistEpilogue_t dist_epilogue)
{
using input_t = typename std::remove_const<Data_t>::type;

Expand Down Expand Up @@ -1154,7 +1162,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future());

if (wmma_range.contains(runtime_arch)) {
local_join(stream);
local_join(stream, dist_epilogue);
} else {
THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700");
}
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/neighbors/detail/nn_descent_gnnd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,13 @@ class GNND {
GNND(const GNND&) = delete;
GNND& operator=(const GNND&) = delete;

template <typename DistEpilogue_t = raft::identity_op>
void build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances);
DistData_t* output_distances,
DistEpilogue_t dist_epilogue = DistEpilogue_t{});
~GNND() = default;
using ID_t = InternalID_t<Index_t>;
void reset(raft::resources const& res);
Expand All @@ -222,7 +224,9 @@ class GNND {
Index_t* d_rev_graph_ptr,
int2* list_sizes,
cudaStream_t stream = 0);
void local_join(cudaStream_t stream = 0);

template <typename DistEpilogue_t = raft::identity_op>
void local_join(cudaStream_t stream = 0, DistEpilogue_t dist_epilogue = DistEpilogue_t{});

raft::resources const& res;

Expand Down
17 changes: 3 additions & 14 deletions cpp/src/neighbors/detail/reachability.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

#pragma once
#include "./knn_brute_force.cuh"
#include "./reachability_types.cuh"

#include <raft/linalg/unary_op.cuh>
#include <raft/sparse/convert/csr.cuh>
Expand Down Expand Up @@ -121,18 +122,6 @@ void _compute_core_dists(const raft::resources& handle,
core_distances<value_idx>(dists.data(), min_samples, min_samples, m, core_dists, stream);
}

// Functor to post-process distances into reachability space
template <typename value_idx, typename value_t>
struct ReachabilityPostProcess {
DI value_t operator()(value_t value, value_idx row, value_idx col) const
{
return max(core_dists[col], max(core_dists[row], alpha * value));
}

const value_t* core_dists;
value_t alpha;
};

/**
* Given core distances, Fuses computations of L2 distances between all
* points, projection into mutual reachability space, and k-selection.
Expand Down Expand Up @@ -163,7 +152,7 @@ void mutual_reachability_knn_l2(const raft::resources& handle,
// `A type local to a function cannot be used in the template argument of the
// enclosing parent function (and any parent classes) of an extended __device__
// or __host__ __device__ lambda`
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha};
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha, m};

cuvs::neighbors::detail::
tiled_brute_force_knn<value_t, value_idx, value_t, ReachabilityPostProcess<value_idx, value_t>>(
Expand Down
37 changes: 37 additions & 0 deletions cpp/src/neighbors/detail/reachability_types.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a reason to hold up this PR, but since cuVS is no longer header-only, we don't have to have these _types files anymore. We could just include this in the shared code (assuming there is a shared place to put it).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Placed it back in reachability.cuh.

The mutual reachability function will be cleaned up in another PR, and the reachability.cuh file will eventually be left with only ReachabilityPostProcess struct.

#include <raft/core/detail/macros.hpp>
#include <raft/matrix/shift.cuh>
#include <rmm/exec_policy.hpp>

namespace cuvs::neighbors::detail::reachability {

// Functor to post-process distances into reachability space
template <typename value_idx, typename value_t>
struct ReachabilityPostProcess {
RAFT_DEVICE_INLINE_FUNCTION value_t operator()(value_t value, value_idx row, value_idx col) const
{
return max(core_dists[col], max(core_dists[row], alpha * value));
}

const value_t* core_dists;
value_t alpha;
size_t n; // size of core_dists array
};

} // namespace cuvs::neighbors::detail::reachability
29 changes: 27 additions & 2 deletions cpp/src/neighbors/nn_descent_float.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include "./detail/nn_descent_gnnd.hpp"
#include "./detail/reachability_types.cuh"
#include "nn_descent.cuh"
#include <cuvs/neighbors/nn_descent.hpp>

Expand Down Expand Up @@ -54,7 +56,30 @@ namespace cuvs::neighbors::nn_descent {
return idx; \
} \
}; \
template class detail::GNND<const T, int>;
template class detail::GNND<const T, int>; \
\
template void detail::GNND<const T, int>::build< \
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T>>( \
const T* data, \
const int nrow, \
int* output_graph, \
bool return_distances, \
float* output_distances, \
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T> dist_epilogue); \
template void detail::GNND<const T, int>::local_join< \
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T>>( \
cudaStream_t stream, \
cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int, T> dist_epilogue); \
\
template void detail::GNND<const T, int>::build<raft::identity_op>( \
const T* data, \
const int nrow, \
int* output_graph, \
bool return_distances, \
float* output_distances, \
raft::identity_op dist_epilogue); \
template void detail::GNND<const T, int>::local_join<raft::identity_op>( \
cudaStream_t stream, raft::identity_op dist_epilogue);

CUVS_INST_NN_DESCENT_BUILD(float, uint32_t);

Expand Down
Loading
Loading