Merged
Changes from 3 commits
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -412,6 +412,7 @@ add_library(
src/neighbors/nn_descent_float.cu
src/neighbors/nn_descent_int8.cu
src/neighbors/nn_descent_uint8.cu
src/neighbors/reachability.cu
src/neighbors/refine/detail/refine_device_float_float.cu
src/neighbors/refine/detail/refine_device_half_float.cu
src/neighbors/refine/detail/refine_device_int8_t_float.cu
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
@@ -61,8 +61,8 @@ endfunction()
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${RAFT_VERSION}.00
- FORK ${RAFT_FORK}
- PINNED_TAG ${RAFT_PINNED_TAG}
+ FORK benfred
+ PINNED_TAG coo_sort_int64_t
ENABLE_MNMG_DEPENDENCIES OFF
ENABLE_NVTX OFF
USE_RAFT_STATIC ${CUVS_USE_RAFT_STATIC}
79 changes: 79 additions & 0 deletions cpp/include/cuvs/neighbors/reachability.hpp
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/sparse/coo.hpp>

#include <cuvs/distance/distance.hpp>

namespace cuvs::neighbors::reachability {

/**
* @defgroup reachability_cpp Mutual Reachability
* @{
*/
/**
* Constructs a mutual reachability graph, which is a k-nearest neighbors
* graph projected into mutual reachability space. For each pair of data
* points a and b, the distance becomes
* max(core_distance(a), core_distance(b), d(a, b)),
* where core_distance is the distance to the kth neighbor.
*
* Unfortunately, points in the tails of the pdf (e.g. in sparse regions
* of the space) can have very large neighborhoods, which will impact
* nearby neighborhoods. Because of this, the radius for points in the
* main mass, which might be very small initially, can expand to become
* very large. As a result, the initial knn which was used to compute
* the core distances may no longer capture the actual neighborhoods
* after projection into mutual reachability space.
*
* For the experimental version, we execute the knn twice: once
* to compute the radii (core distances) and again to capture
* the final neighborhoods. Future iterations of this algorithm
* will work to improve upon this "exact" version by using
* more specialized data structures, such as space-partitioning
* structures. It has also been shown that approximate nearest
* neighbors can yield reasonable neighborhoods as the
* data sizes increase.
*
* @param[in] handle raft handle for resource reuse
* @param[in] X input data points (size m * n)
* @param[in] min_samples number of nearest neighbors used to compute each point's core distance
* @param[out] indptr CSR indptr of output knn graph (size m + 1)
* @param[out] core_dists output core distances array (size m)
* @param[out] out COO object, uninitialized on entry, on exit it stores the
* (symmetrized) maximum reachability distance for the k nearest
* neighbors.
* @param[in] metric distance metric to use, default Euclidean
* @param[in] alpha weight applied when internal distance is chosen for
* mutual reachability (value of 1.0 disables the weighting)
*/
void mutual_reachability_graph(
const raft::resources& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> X,
int min_samples,
raft::device_vector_view<int64_t> indptr,
raft::device_vector_view<float> core_dists,
raft::sparse::COO<float, int64_t>& out,
Member

Small nitpick: this sparse::COO class is going away completely and it's really not designed well. Can we swap this out for the raft::sparsity_owning_coo_matrix? I think we can use a view since we already know the size, right?

Member

It's also okay if the answer is "time is running out, let's change that later"

Contributor Author

One challenge here is that we'd have to add support in RAFT for device_coo_matrix_view / device_sparsity_owning_coo_matrix to some sparse algorithms like raft::sparse::linalg::symmetrize and raft::sparse::convert::sorted_coo_to_csr, which only accept the raft::sparse::COO class. This shouldn't be too hard to do, though.

Member

Underneath the public API, we can just use the arrays / existing types. But I'd like to get to a point where we are using the new types at least for new public APIs. Eventually we need to go through all the sparse APIs and use the new types everywhere. However, new APIs could at least use the new types in the meantime.

Can you create an issue for this? I think we are running out of time to do it for 24.10.

Contributor Author

Created an issue here: #369

cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded,
float alpha = 1.0);
/**
* @}
*/
} // namespace cuvs::neighbors::reachability
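
For readers following the doc comment above, here is a minimal host-side sketch of the two passes it describes: one brute-force pass to get core distances, then a second pass that ranks neighbors by mutual reachability distance. It is not the cuVS implementation. All names (`Points`, `euclidean`, `core_distances`, `mutual_reachability`, `mr_neighbors`) are hypothetical, and two details are assumptions rather than facts from this PR: each point counts as its own nearest neighbor, and `alpha` divides the raw distance (the HDBSCAN convention; the header only says a weight is applied).

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host-side reference; one inner vector per data point.
using Points = std::vector<std::vector<float>>;

// Euclidean distance between two points of equal dimension.
float euclidean(const std::vector<float>& a, const std::vector<float>& b)
{
  float acc = 0.0f;
  for (std::size_t d = 0; d < a.size(); ++d) {
    float diff = a[d] - b[d];
    acc += diff * diff;
  }
  return std::sqrt(acc);
}

// Pass 1: core distance = distance to the min_samples-th nearest neighbor.
// Assumption: a point is its own first neighbor (distance 0).
// Requires 1 <= min_samples <= X.size().
std::vector<float> core_distances(const Points& X, int min_samples)
{
  std::vector<float> core(X.size());
  for (std::size_t i = 0; i < X.size(); ++i) {
    std::vector<float> dists;
    dists.reserve(X.size());
    for (const auto& p : X) dists.push_back(euclidean(X[i], p));
    std::nth_element(dists.begin(), dists.begin() + (min_samples - 1), dists.end());
    core[i] = dists[min_samples - 1];
  }
  return core;
}

// max(core(a), core(b), d(a, b)), with the raw distance scaled by 1/alpha.
// Assumption: alpha divides the distance; alpha = 1 leaves it unchanged.
float mutual_reachability(float core_a, float core_b, float dist_ab, float alpha = 1.0f)
{
  return std::max({core_a, core_b, dist_ab / alpha});
}

// Pass 2: indices of the k nearest neighbors of point i, ranked by
// mutual reachability distance (ties keep index order).
std::vector<std::size_t> mr_neighbors(const Points& X,
                                      const std::vector<float>& core,
                                      std::size_t i,
                                      std::size_t k)
{
  std::vector<std::size_t> order(X.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  auto mr = [&](std::size_t j) {
    return mutual_reachability(core[i], core[j], euclidean(X[i], X[j]));
  };
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) { return mr(a) < mr(b); });
  order.resize(k);
  return order;
}
```

With four 1-D points at 0, 1, 2 and 10 and `min_samples = 3`, the outlier at 10 gets a core distance of 9, so its distance to the point at 2 grows from 8 to 9. This is the neighborhood expansion the doc comment warns about, and the reason the second knn pass is run.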
27 changes: 23 additions & 4 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
@@ -62,7 +62,10 @@ namespace cuvs::neighbors::detail {
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
*/
- template <typename ElementType = float, typename IndexType = int64_t, typename DistanceT = float>
+ template <typename ElementType = float,
+ typename IndexType = int64_t,
+ typename DistanceT = float,
+ typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::resources& handle,
const ElementType* search, // size (m ,d)
const ElementType* index, // size (n ,d)
@@ -78,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t max_col_tile_size = 0,
const DistanceT* precomputed_index_norms = nullptr,
const DistanceT* precomputed_search_norms = nullptr,
- const uint32_t* filter_bitmap = nullptr)
+ const uint32_t* filter_bitmap = nullptr,
+ DistanceEpilogue distance_epilogue = raft::identity_op())
Member

Unfortunate we have to instantiate bfknn twice now :-( but great that we don't have to expose the epilogue through the public APIs. Hopefully at some point soon we'll establish a good way to specify these (maybe as JIT compiled functions) through the public APIs where we don't need to instantiate all the kernels end to end for each one.

{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
@@ -209,7 +213,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType col = j + (idx % current_centroid_size);

cuvs::distance::detail::ops::l2_exp_cutlass_op<DistanceT, DistanceT> l2_op(sqrt);
- return l2_op(row_norms[row], col_norms[col], dist[idx]);
+ auto val = l2_op(row_norms[row], col_norms[col], dist[idx]);
+ return distance_epilogue(val, row, col);
});
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data();
@@ -223,8 +228,22 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]);
- return val;
+ return distance_epilogue(val, row, col);
});
} else {
// if we're not l2 distance, and we have a distance epilogue - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
auto distances_ptr = temp_distances.data();
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(temp_distances.data(),
current_query_size * current_centroid_size),
[=] __device__(size_t idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
return distance_epilogue(distances_ptr[idx], row, col);
});
}
}

if (filter_bitmap != nullptr) {
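
The epilogue hook added in this hunk can be illustrated with a small host-side sketch. It is not the CUDA code from the PR. It only mirrors the pattern: a functor with a default no-op type is applied to each entry of a distance tile, and the flat index is mapped to global row and column indices the same way as in the diff (`row = i + idx / current_centroid_size`, `col = j + idx % current_centroid_size`). The names `identity_epilogue`, `reachability_epilogue` and `apply_epilogue_to_tile` are hypothetical. `identity_epilogue` is a local stand-in for `raft::identity_op`, and the reachability functor is an assumption about what a caller might pass in.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Local stand-in for raft::identity_op (hypothetical): returns the distance unchanged.
struct identity_epilogue {
  float operator()(float val, std::size_t /*row*/, std::size_t /*col*/) const { return val; }
};

// Hypothetical epilogue that lifts a raw distance into mutual reachability
// space using precomputed core distances indexed by global row/column.
struct reachability_epilogue {
  const float* core_dists;
  float operator()(float val, std::size_t row, std::size_t col) const
  {
    return std::max(core_dists[row], std::max(core_dists[col], val));
  }
};

// Applies the epilogue in place to a row-major tile of distances.
// row_offset/col_offset play the role of the tile origin (i, j) in the diff.
template <typename Epilogue = identity_epilogue>
void apply_epilogue_to_tile(std::vector<float>& tile,
                            std::size_t tile_cols,
                            std::size_t row_offset,
                            std::size_t col_offset,
                            Epilogue epilogue = Epilogue{})
{
  for (std::size_t idx = 0; idx < tile.size(); ++idx) {
    std::size_t row = row_offset + idx / tile_cols;  // global query index
    std::size_t col = col_offset + idx % tile_cols;  // global index-point index
    tile[idx] = epilogue(tile[idx], row, col);
  }
}
```

Because the epilogue is a template parameter with a no-op default, existing callers compile unchanged, while a caller that passes a different functor type produces a second instantiation. That is the cost noted in the review comment on this hunk.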