-
Notifications
You must be signed in to change notification settings - Fork 184
Add function for calculating the mutual_reachability_graph #323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cc7d253
9872fdf
0fc4186
ce461b5
25592bd
905a6b4
5385458
e4e0740
2819ccb
1018e52
bc3feb7
21d8a18
407873c
000cf84
8ac36cb
19ab46f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| /* | ||
| * Copyright (c) 2024, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/handle.hpp> | ||
| #include <raft/sparse/coo.hpp> | ||
|
|
||
| #include <cuvs/distance/distance.hpp> | ||
|
|
||
| namespace cuvs::neighbors::reachability { | ||
|
|
||
| /** | ||
| * @defgroup reachability_cpp Mutual Reachability | ||
| * @{ | ||
| */ | ||
| /** | ||
| * Constructs a mutual reachability graph, which is a k-nearest neighbors | ||
| * graph projected into mutual reachability space using the following | ||
| * function for each pair of data points (a, b), where core_distance is the | ||
| * distance to the kth neighbor: max(core_distance(a), core_distance(b), d(a, b)) | ||
| * | ||
| * Unfortunately, points in the tails of the pdf (e.g. in sparse regions | ||
| * of the space) can have very large neighborhoods, which will impact | ||
| * nearby neighborhoods. Because of this, it's possible for the | ||
| * radius of points in the main mass, which might be very small | ||
| * initially, to expand considerably. As a result, the initial | ||
| * knn which was used to compute the core distances may no longer | ||
| * capture the actual neighborhoods after projection into mutual | ||
| * reachability space. | ||
| * | ||
| * For the experimental version, we execute the knn twice: once | ||
| * to compute the radii (core distances) and again to capture | ||
| * the final neighborhoods. Future iterations of this algorithm | ||
| * will improve upon this "exact" version by using | ||
| * more specialized data structures, such as space-partitioning | ||
| * structures. It has also been shown that approximate nearest | ||
| * neighbors can yield reasonable neighborhoods as the | ||
| * data sizes increase. | ||
| * | ||
| * @param[in] handle raft handle for resource reuse | ||
| * @param[in] X input data points (size m * n) | ||
| * @param[in] min_samples this neighborhood will be selected for core distances | ||
| * @param[out] indptr CSR indptr of output knn graph (size m + 1) | ||
| * @param[out] core_dists output core distances array (size m) | ||
| * @param[out] out COO object, uninitialized on entry, on exit it stores the | ||
| * (symmetrized) maximum reachability distance for the k nearest | ||
| * neighbors. | ||
| * @param[in] metric distance metric to use, default L2SqrtExpanded (Euclidean) | ||
| * @param[in] alpha weight applied when internal distance is chosen for | ||
| * mutual reachability (default 1.0; a value of 1.0 disables the weighting) | ||
| */ | ||
| void mutual_reachability_graph( | ||
| const raft::resources& handle, | ||
| raft::device_matrix_view<const float, int, raft::row_major> X, | ||
| int min_samples, | ||
| raft::device_vector_view<int> indptr, | ||
| raft::device_vector_view<float> core_dists, | ||
| raft::sparse::COO<float, int>& out, | ||
| cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded, | ||
| float alpha = 1.0); | ||
| /** | ||
| * @} | ||
| */ | ||
| } // namespace cuvs::neighbors::reachability |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,7 +62,10 @@ namespace cuvs::neighbors::detail { | |
| * Calculates brute force knn, using a fixed memory budget | ||
| * by tiling over both the rows and columns of pairwise_distances | ||
| */ | ||
| template <typename ElementType = float, typename IndexType = int64_t, typename DistanceT = float> | ||
| template <typename ElementType = float, | ||
| typename IndexType = int64_t, | ||
| typename DistanceT = float, | ||
| typename DistanceEpilogue = raft::identity_op> | ||
| void tiled_brute_force_knn(const raft::resources& handle, | ||
| const ElementType* search, // size (m ,d) | ||
| const ElementType* index, // size (n ,d) | ||
|
|
@@ -78,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
| size_t max_col_tile_size = 0, | ||
| const DistanceT* precomputed_index_norms = nullptr, | ||
| const DistanceT* precomputed_search_norms = nullptr, | ||
| const uint32_t* filter_bitmap = nullptr) | ||
| const uint32_t* filter_bitmap = nullptr, | ||
| DistanceEpilogue distance_epilogue = raft::identity_op()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's unfortunate that we have to instantiate bfknn twice now :-( but great that we don't have to expose the epilogue through the public APIs. Hopefully at some point soon we'll establish a good way to specify these (maybe as JIT-compiled functions) through the public APIs, where we don't need to instantiate all the kernels end to end for each one. |
||
| { | ||
| // Figure out the number of rows/cols to tile for | ||
| size_t tile_rows = 0; | ||
|
|
@@ -207,7 +211,8 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
| IndexType col = j + (idx % current_centroid_size); | ||
|
|
||
| cuvs::distance::detail::ops::l2_exp_cutlass_op<DistanceT, DistanceT> l2_op(sqrt); | ||
| return l2_op(row_norms[row], col_norms[col], dist[idx]); | ||
| auto val = l2_op(row_norms[row], col_norms[col], dist[idx]); | ||
| return distance_epilogue(val, row, col); | ||
| }); | ||
| } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { | ||
| auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); | ||
|
|
@@ -221,8 +226,22 @@ void tiled_brute_force_knn(const raft::resources& handle, | |
| IndexType row = i + (idx / current_centroid_size); | ||
| IndexType col = j + (idx % current_centroid_size); | ||
| auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]); | ||
| return val; | ||
| return distance_epilogue(val, row, col); | ||
| }); | ||
| } else { | ||
| // if we're not l2 distance, and we have a distance epilogue - run it now | ||
| if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) { | ||
| auto distances_ptr = temp_distances.data(); | ||
| raft::linalg::map_offset( | ||
| handle, | ||
| raft::make_device_vector_view(temp_distances.data(), | ||
| current_query_size * current_centroid_size), | ||
| [=] __device__(size_t idx) { | ||
| IndexType row = i + (idx / current_centroid_size); | ||
| IndexType col = j + (idx % current_centroid_size); | ||
| return distance_epilogue(distances_ptr[idx], row, col); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| if (filter_bitmap != nullptr) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's saddening to see us having to add this in here just for the
`int` type. Can you create an issue, or maybe even a larger tracker issue with some follow-up tasks, and add consolidating these type instantiations to that? Similar to the other types, it would be great if we could establish a single set of common integral types and maybe even consolidate float/double just into float. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added an issue here #370 for consolidating the template params -