Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions cpp/src/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/coo.hpp>
Expand All @@ -31,10 +32,6 @@

#include <rmm/device_uvector.hpp>

#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>

#include <limits>

namespace cuvs::cluster::agglomerative::detail {
Expand Down Expand Up @@ -83,18 +80,25 @@ struct distance_graph_impl<Linkage::KNN_GRAPH, value_idx, value_t> {
data.resize(knn_graph_coo.nnz, stream);

// self-loops get max distance
auto transform_in = thrust::make_zip_iterator(
thrust::make_tuple(knn_graph_coo.rows(), knn_graph_coo.cols(), knn_graph_coo.vals()));

thrust::transform(thrust_policy,
transform_in,
transform_in + knn_graph_coo.nnz,
knn_graph_coo.vals(),
[=] __device__(const thrust::tuple<value_idx, value_idx, value_t>& tup) {
bool self_loop = thrust::get<0>(tup) == thrust::get<1>(tup);
return (self_loop * std::numeric_limits<value_t>::max()) +
(!self_loop * thrust::get<2>(tup));
});
auto rows_view = raft::make_device_vector_view<const value_idx, value_idx>(knn_graph_coo.rows(),
knn_graph_coo.nnz);
auto cols_view = raft::make_device_vector_view<const value_idx, value_idx>(knn_graph_coo.cols(),
knn_graph_coo.nnz);
auto vals_in_view = raft::make_device_vector_view<const value_t, value_idx>(
knn_graph_coo.vals(), knn_graph_coo.nnz);
auto vals_out_view =
raft::make_device_vector_view<value_t, value_idx>(knn_graph_coo.vals(), knn_graph_coo.nnz);

raft::linalg::map(
handle,
vals_out_view,
[=] __device__(const value_idx row, const value_idx col, const value_t val) {
bool self_loop = row == col;
return (self_loop * std::numeric_limits<value_t>::max()) + (!self_loop * val);
},
rows_view,
cols_view,
vals_in_view);

raft::sparse::convert::sorted_coo_to_csr(
knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), m + 1, stream);
Expand Down Expand Up @@ -147,7 +151,9 @@ void pairwise_distances(const raft::resources& handle,
value_idx blocks = raft::ceildiv(nnz, (value_idx)256);
fill_indices2<value_idx><<<blocks, 256, 0, stream>>>(indices, m, nnz);

thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m);
raft::linalg::map_offset(handle,
raft::make_device_vector_view<value_idx, value_idx>(indptr, m),
[=] __device__(value_idx idx) { return idx * m; });

raft::update_device(indptr + m, &nnz, 1, stream);

Expand All @@ -160,19 +166,13 @@ void pairwise_distances(const raft::resources& handle,
handle, X_view, X_view, raft::make_device_matrix_view<value_t, value_idx>(data, m, m), metric);

// self-loops get max distance
auto transform_in =
thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data));

thrust::transform(exec_policy,
transform_in,
transform_in + nnz,
data,
[=] __device__(const thrust::tuple<value_idx, value_t>& tup) {
value_idx idx = thrust::get<0>(tup);
bool self_loop = idx % m == idx / m;
return (self_loop * std::numeric_limits<value_t>::max()) +
(!self_loop * thrust::get<1>(tup));
});
auto data_view = raft::make_device_vector_view<value_t, value_idx>(data, nnz);

raft::linalg::map_offset(handle, data_view, [=] __device__(value_idx idx) {
value_t val = data[idx];
bool self_loop = idx % m == idx / m;
return (self_loop * std::numeric_limits<value_t>::max()) + (!self_loop * val);
});
}

/**
Expand Down
48 changes: 18 additions & 30 deletions cpp/src/neighbors/ball_cover/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/matrix/copy.cuh>
#include <raft/random/rng.cuh>
#include <raft/sparse/convert/csr.cuh>
Expand All @@ -34,12 +35,9 @@
#include <rmm/exec_policy.hpp>

#include <thrust/fill.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/tuple.h>

Expand Down Expand Up @@ -68,10 +66,7 @@ void sample_landmarks(raft::resources const& handle,
rmm::device_uvector<value_idx> R_indices(index.n_landmarks,
raft::resource::get_cuda_stream(handle));

thrust::sequence(raft::resource::get_thrust_policy(handle),
index.get_R_1nn_cols().data_handle(),
index.get_R_1nn_cols().data_handle() + index.m,
(value_idx)0);
raft::linalg::map_offset(handle, index.get_R_1nn_cols(), raft::identity_op{});

thrust::fill(raft::resource::get_thrust_policy(handle),
R_1nn_ones.data(),
Expand Down Expand Up @@ -124,25 +119,23 @@ void construct_landmark_1nn(raft::resources const& handle,
int64_t k,
cuvs::neighbors::ball_cover::index<value_idx, value_t>& index)
{
rmm::device_uvector<value_idx> R_1nn_inds(index.m, raft::resource::get_cuda_stream(handle));
auto R_1nn_inds = raft::make_device_vector<value_idx, value_idx>(handle, index.m);

thrust::fill(raft::resource::get_thrust_policy(handle),
R_1nn_inds.data(),
R_1nn_inds.data() + index.m,
R_1nn_inds.data_handle(),
R_1nn_inds.data_handle() + index.m,
std::numeric_limits<value_idx>::max());

value_idx* R_1nn_inds_ptr = R_1nn_inds.data();
value_t* R_1nn_dists_ptr = index.get_R_1nn_dists().data_handle();

auto idxs = thrust::make_counting_iterator<value_idx>(0);
thrust::for_each(
raft::resource::get_thrust_policy(handle), idxs, idxs + index.m, [=] __device__(value_idx i) {
R_1nn_inds_ptr[i] = R_knn_inds_ptr[i * k];
R_1nn_dists_ptr[i] = R_knn_dists_ptr[i * k];
raft::linalg::map_offset(handle, R_1nn_inds.view(), [R_knn_inds_ptr, k] __device__(value_idx i) {
return R_knn_inds_ptr[i * k];
});
raft::linalg::map_offset(
handle, index.get_R_1nn_dists(), [R_knn_dists_ptr, k] __device__(value_idx i) {
return R_knn_dists_ptr[i * k];
});

auto keys = thrust::make_zip_iterator(
thrust::make_tuple(R_1nn_inds.data(), index.get_R_1nn_dists().data_handle()));
thrust::make_tuple(R_1nn_inds.data_handle(), index.get_R_1nn_dists().data_handle()));

// group neighborhoods for each reference landmark and sort each group by distance
thrust::sort_by_key(raft::resource::get_thrust_policy(handle),
Expand All @@ -152,7 +145,7 @@ void construct_landmark_1nn(raft::resources const& handle,
NNComp());

// convert to CSR for fast lookup
raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data(),
raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data_handle(),
index.m,
index.get_R_indptr().data_handle(),
index.n_landmarks + 1,
Expand Down Expand Up @@ -212,18 +205,13 @@ template <typename value_idx, typename value_t>
void compute_landmark_radii(raft::resources const& handle,
cuvs::neighbors::ball_cover::index<value_idx, value_t>& index)
{
auto entries = thrust::make_counting_iterator<value_idx>(0);

const value_idx* R_indptr_ptr = index.get_R_indptr().data_handle();
const value_t* R_1nn_dists_ptr = index.get_R_1nn_dists().data_handle();
value_t* R_radius_ptr = index.get_R_radius().data_handle();
thrust::for_each(raft::resource::get_thrust_policy(handle),
entries,
entries + index.n_landmarks,
[=] __device__(value_idx input) {
value_idx last_row_idx = R_indptr_ptr[input + 1] - 1;
R_radius_ptr[input] = R_1nn_dists_ptr[last_row_idx];
});
raft::linalg::map_offset(
handle, index.get_R_radius(), [R_indptr_ptr, R_1nn_dists_ptr] __device__(value_idx input) {
value_idx last_row_idx = R_indptr_ptr[input + 1] - 1;
return R_1nn_dists_ptr[last_row_idx];
});
}

/**
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/neighbors/ball_cover/registers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/neighbors/detail/faiss_select/key_value_block_select.cuh>
#include <raft/util/cuda_utils.cuh>

Expand Down Expand Up @@ -1458,13 +1459,14 @@ void rbc_eps_pass(raft::resources const& handle,

if (actual_max > max_k_in) {
// ceil vd to max_k
thrust::transform(raft::resource::get_thrust_policy(handle),
vd_ptr,
vd_ptr + n_query_rows,
vd_ptr,
[max_k_in] __device__(value_idx vd_count) {
return vd_count > max_k_in ? max_k_in : vd_count;
});
raft::linalg::unaryOp(
vd_ptr,
vd_ptr,
n_query_rows,
[max_k_in] __device__(value_idx vd_count) {
return vd_count > max_k_in ? max_k_in : vd_count;
},
raft::resource::get_cuda_stream(handle));
}

thrust::exclusive_scan(raft::resource::get_thrust_policy(handle),
Expand Down
66 changes: 33 additions & 33 deletions cpp/src/neighbors/detail/reachability.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,16 @@
#pragma once
#include "./knn_brute_force.cuh"

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/linalg/symmetrize.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>

namespace cuvs::neighbors::detail::reachability {

Expand All @@ -47,17 +44,19 @@ namespace cuvs::neighbors::detail::reachability {
* @param[in] stream stream for which to order cuda operations
*/
template <typename value_idx, typename value_t, int tpb = 256>
void core_distances(
value_t* knn_dists, int min_samples, int n_neighbors, size_t n, value_t* out, cudaStream_t stream)
void core_distances(raft::resources const& handle,
value_t* knn_dists,
int min_samples,
int n_neighbors,
size_t n,
value_t* out)
{
ASSERT(n_neighbors >= min_samples,
"the size of the neighborhood should be greater than or equal to min_samples");

auto exec_policy = rmm::exec_policy(stream);

auto indices = thrust::make_counting_iterator<value_idx>(0);
auto out_view = raft::make_device_vector_view<value_t, value_idx>(out, n);

thrust::transform(exec_policy, indices, indices + n, out, [=] __device__(value_idx row) {
raft::linalg::map_offset(handle, out_view, [=] __device__(value_idx row) {
return knn_dists[row * n_neighbors + (min_samples - 1)];
});
}
Expand Down Expand Up @@ -118,7 +117,7 @@ void _compute_core_dists(const raft::resources& handle,
compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric);

// Slice core distances (distances to kth nearest neighbor)
core_distances<value_idx>(dists.data(), min_samples, min_samples, m, core_dists, stream);
core_distances<value_idx>(handle, dists.data(), min_samples, min_samples, m, core_dists);
}

// Functor to post-process distances into reachability space
Expand Down Expand Up @@ -202,8 +201,7 @@ void mutual_reachability_graph(const raft::resources& handle,
RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded,
"Currently only L2 expanded distance is supported");

auto stream = raft::resource::get_cuda_stream(handle);
auto exec_policy = raft::resource::get_thrust_policy(handle);
auto stream = raft::resource::get_cuda_stream(handle);

rmm::device_uvector<value_idx> coo_rows(min_samples * m, stream);
rmm::device_uvector<value_idx> inds(min_samples * m, stream);
Expand All @@ -213,7 +211,7 @@ void mutual_reachability_graph(const raft::resources& handle,
compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric);

// Slice core distances (distances to kth nearest neighbor)
core_distances<value_idx>(dists.data(), min_samples, min_samples, m, core_dists, stream);
core_distances<value_idx>(handle, dists.data(), min_samples, min_samples, m, core_dists);

/**
* Compute L2 norm
Expand All @@ -222,12 +220,12 @@ void mutual_reachability_graph(const raft::resources& handle,
handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha);

// self-loops get max distance
auto coo_rows_counting_itr = thrust::make_counting_iterator<value_idx>(0);
thrust::transform(exec_policy,
coo_rows_counting_itr,
coo_rows_counting_itr + (m * min_samples),
coo_rows.data(),
[min_samples] __device__(value_idx c) -> value_idx { return c / min_samples; });
auto coo_rows_view =
raft::make_device_vector_view<value_idx, value_idx>(coo_rows.data(), m * min_samples);
raft::linalg::map_offset(
handle, coo_rows_view, [min_samples] __device__(value_idx c) -> value_idx {
return c / min_samples;
});

raft::sparse::linalg::symmetrize(handle,
coo_rows.data(),
Expand All @@ -241,18 +239,20 @@ void mutual_reachability_graph(const raft::resources& handle,
raft::sparse::convert::sorted_coo_to_csr(out.rows(), out.nnz, indptr, m + 1, stream);

// self-loops get max distance
auto transform_in =
thrust::make_zip_iterator(thrust::make_tuple(out.rows(), out.cols(), out.vals()));
auto rows_view = raft::make_device_vector_view<const value_idx, nnz_t>(out.rows(), out.nnz);
auto cols_view = raft::make_device_vector_view<const value_idx, nnz_t>(out.cols(), out.nnz);
auto vals_in_view = raft::make_device_vector_view<const value_t, nnz_t>(out.vals(), out.nnz);
auto vals_out_view = raft::make_device_vector_view<value_t, nnz_t>(out.vals(), out.nnz);

thrust::transform(exec_policy,
transform_in,
transform_in + out.nnz,
out.vals(),
[=] __device__(const thrust::tuple<value_idx, value_idx, value_t>& tup) {
return thrust::get<0>(tup) == thrust::get<1>(tup)
? std::numeric_limits<value_t>::max()
: thrust::get<2>(tup);
});
raft::linalg::map(
handle,
vals_out_view,
[=] __device__(const value_idx row, const value_idx col, const value_t val) {
return row == col ? std::numeric_limits<value_t>::max() : val;
},
rows_view,
cols_view,
vals_in_view);
}

} // namespace cuvs::neighbors::detail::reachability
15 changes: 7 additions & 8 deletions cpp/tests/cluster/connect_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cuvs/distance/distance.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/map.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/coo.hpp>
Expand Down Expand Up @@ -78,7 +79,7 @@ class ConnectKNNTest : public ::testing::TestWithParam<ConnectKNNInputs> {
rmm::device_uvector<T> core_dists(ps.n_rows, stream);
if (ps.mutual_reach) {
cuvs::neighbors::detail::reachability::core_distances<int64_t, T>(
dists.data(), ps.k, ps.k, (size_t)ps.n_rows, core_dists.data(), stream);
handle, dists.data(), ps.k, ps.k, (size_t)ps.n_rows, core_dists.data());

auto epilogue = cuvs::neighbors::detail::reachability::ReachabilityPostProcess<int64_t, T>{
core_dists.data(), 1.0};
Expand Down Expand Up @@ -111,13 +112,11 @@ class ConnectKNNTest : public ::testing::TestWithParam<ConnectKNNInputs> {
rmm::device_uvector<int64_t> indptr(ps.n_rows + 1, stream);

// changing inds and dists to sparse format
int64_t k = ps.k;
auto coo_rows_counting_itr = thrust::make_counting_iterator<int64_t>(0);
thrust::transform(raft::resource::get_thrust_policy(handle),
coo_rows_counting_itr,
coo_rows_counting_itr + (ps.n_rows * ps.k),
coo_rows.data(),
[k] __device__(int64_t c) -> int64_t { return c / k; });
int64_t k = ps.k;
auto coo_rows_view =
raft::make_device_vector_view<int64_t, int64_t>(coo_rows.data(), ps.n_rows * ps.k);
raft::linalg::map_offset(
handle, coo_rows_view, [k] __device__(int64_t c) -> int64_t { return c / k; });

raft::sparse::linalg::symmetrize(handle,
coo_rows.data(),
Expand Down
Loading