Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_cuvs.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ endfunction()
# To use a different CUVS locally, set the CMake variable
# CPM_cuvs_SOURCE=/path/to/local/cuvs
find_and_configure_cuvs(VERSION ${CUML_MIN_VERSION_cuvs}
FORK rapidsai
PINNED_TAG ${rapids-cmake-checkout-tag}
FORK aamijar
PINNED_TAG port-raft-epsilon-neighborhood
Comment thread
aamijar marked this conversation as resolved.
Outdated
EXCLUDE_FROM_ALL ${CUML_EXCLUDE_CUVS_FROM_ALL}
# When PINNED_TAG above doesn't match cuml,
# force local cuvs clone in build directory
Expand Down
22 changes: 17 additions & 5 deletions cpp/src/dbscan/vertexdeg/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <raft/linalg/coalesced_reduction.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/neighbors/epsilon_neighborhood.cuh>
#include <raft/util/device_atomics.cuh>

#include <rmm/device_uvector.hpp>
Expand All @@ -39,6 +38,7 @@
#include <thrust/transform.h>

#include <cuvs/neighbors/ball_cover.hpp>
#include <cuvs/neighbors/epsilon_neighborhood.hpp>
#include <math.h>

namespace ML {
Expand Down Expand Up @@ -206,8 +206,14 @@ void launcher(const raft::handle_t& handle,
if (data.rbc_index != nullptr) {
eps_nn(handle, data, start_vertex_id, batch_size, stream, (value_t)sqrtf(eps2));
} else {
raft::neighbors::epsilon_neighborhood::epsUnexpL2SqNeighborhood<value_t, index_t>(
data.adj, data.vd, data.x + start_vertex_id * k, data.x, n, m, k, eps2, stream);
cuvs::neighbors::epsilon_neighborhood::eps_neighbors_l2sq<value_t, index_t, int64_t>(
handle,
raft::make_device_matrix_view<const value_t, int64_t, raft::row_major>(
data.x + start_vertex_id * k, n, k),
raft::make_device_matrix_view<const value_t, int64_t, raft::row_major>(data.x, m, k),
raft::make_device_matrix_view<bool, int64_t, raft::row_major>(data.adj, n, m),
raft::make_device_vector_view<index_t, int64_t>(data.vd, n + 1),
eps2);
}

/**
Expand All @@ -226,8 +232,14 @@ void launcher(const raft::handle_t& handle,
if (data.rbc_index != nullptr) {
eps_nn(handle, data, start_vertex_id, batch_size, stream, data.eps);
} else {
raft::neighbors::epsilon_neighborhood::epsUnexpL2SqNeighborhood<value_t, index_t>(
data.adj, data.vd, data.x + start_vertex_id * k, data.x, n, m, k, eps2, stream);
cuvs::neighbors::epsilon_neighborhood::eps_neighbors_l2sq<value_t, index_t, int64_t>(
handle,
raft::make_device_matrix_view<const value_t, int64_t, raft::row_major>(
data.x + start_vertex_id * k, n, k),
raft::make_device_matrix_view<const value_t, int64_t, raft::row_major>(data.x, m, k),
raft::make_device_matrix_view<bool, int64_t, raft::row_major>(data.adj, n, m),
raft::make_device_vector_view<index_t, int64_t>(data.vd, n + 1),
eps2);
}
}

Expand Down
Loading