Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cpp/include/cuvs/preprocessing/spectral_embedding.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -174,6 +174,16 @@ void transform(raft::resources const& handle,
raft::device_coo_matrix_view<double, int, int, int> connectivity_graph,
raft::device_matrix_view<double, int, raft::col_major> embedding);

void transform(raft::resources const& handle,
params config,
raft::device_coo_matrix_view<float, int, int, int64_t> connectivity_graph,
raft::device_matrix_view<float, int, raft::col_major> embedding);

void transform(raft::resources const& handle,
params config,
raft::device_coo_matrix_view<double, int, int, int64_t> connectivity_graph,
raft::device_matrix_view<double, int, raft::col_major> embedding);

/**
* @}
*/
Expand Down
154 changes: 54 additions & 100 deletions cpp/src/preprocessing/spectral/detail/spectral_embedding.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -11,6 +11,7 @@
#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/matrix/gather.cuh>
Expand All @@ -29,55 +30,18 @@

namespace cuvs::preprocessing::spectral_embedding::detail {

template <typename DataT>
raft::device_csr_matrix_view<DataT, int, int, int> coo_to_csr_matrix(
raft::resources const& handle,
const int n_samples,
raft::device_vector_view<int> sym_coo_row_ind,
raft::device_coo_matrix_view<DataT, int, int, int> sym_coo_matrix_view)
{
auto stream = raft::resource::get_cuda_stream(handle);

raft::sparse::op::coo_sort<DataT>(n_samples,
n_samples,
sym_coo_matrix_view.structure_view().get_nnz(),
sym_coo_matrix_view.structure_view().get_rows().data(),
sym_coo_matrix_view.structure_view().get_cols().data(),
sym_coo_matrix_view.get_elements().data(),
stream);

raft::sparse::convert::sorted_coo_to_csr(sym_coo_matrix_view.structure_view().get_rows().data(),
sym_coo_matrix_view.structure_view().get_nnz(),
sym_coo_row_ind.data_handle(),
n_samples,
stream);

auto sym_coo_nnz = sym_coo_matrix_view.structure_view().get_nnz();
raft::copy(sym_coo_row_ind.data_handle() + sym_coo_row_ind.size() - 1, &sym_coo_nnz, 1, stream);

auto csr_matrix_view = raft::make_device_csr_matrix_view<DataT, int, int, int>(
const_cast<DataT*>(sym_coo_matrix_view.get_elements().data()),
raft::make_device_compressed_structure_view<int, int, int>(
const_cast<int*>(sym_coo_row_ind.data_handle()),
const_cast<int*>(sym_coo_matrix_view.structure_view().get_cols().data()),
n_samples,
n_samples,
sym_coo_matrix_view.structure_view().get_nnz()));
return csr_matrix_view;
}

template <typename DataT>
raft::device_csr_matrix<DataT, int, int, int> create_laplacian(
raft::resources const& handle,
params spectral_embedding_config,
raft::device_csr_matrix_view<DataT, int, int, int> csr_matrix_view,
raft::device_vector_view<DataT, int> diagonal)
template <typename DataT, typename OutSparseMatrixType, typename InSparseMatrixViewType>
OutSparseMatrixType create_laplacian(raft::resources const& handle,
params spectral_embedding_config,
InSparseMatrixViewType sparse_matrix_view,
raft::device_vector_view<DataT, int> diagonal)
{
auto laplacian = spectral_embedding_config.norm_laplacian
? raft::sparse::linalg::laplacian_normalized(handle, csr_matrix_view, diagonal)
: raft::sparse::linalg::compute_graph_laplacian(handle, csr_matrix_view);
auto laplacian =
spectral_embedding_config.norm_laplacian
? raft::sparse::linalg::laplacian_normalized(handle, sparse_matrix_view, diagonal)
: raft::sparse::linalg::compute_graph_laplacian(handle, sparse_matrix_view);

auto laplacian_elements_view = raft::make_device_vector_view<DataT, int>(
auto laplacian_elements_view = raft::make_device_vector_view<DataT>(
laplacian.get_elements().data(), laplacian.structure_view().get_nnz());

raft::linalg::unary_op(handle,
Expand All @@ -88,11 +52,11 @@ raft::device_csr_matrix<DataT, int, int, int> create_laplacian(
return laplacian;
}

template <typename DataT>
template <typename DataT, typename InSparseMatrixViewType>
void compute_eigenpairs(raft::resources const& handle,
params spectral_embedding_config,
const int n_samples,
raft::device_csr_matrix<DataT, int, int, int>& laplacian,
InSparseMatrixViewType laplacian_view,
raft::device_vector_view<DataT, int> diagonal,
raft::device_matrix_view<DataT, int, raft::col_major> embedding)
{
Expand All @@ -110,13 +74,7 @@ void compute_eigenpairs(raft::resources const& handle,
raft::make_device_matrix<DataT, int, raft::col_major>(handle, n_samples, config.n_components);

raft::sparse::solver::lanczos_compute_smallest_eigenvectors<int, DataT>(
handle,
config,
raft::make_device_csr_matrix_view<DataT, int, int, int>(laplacian.get_elements().data(),
laplacian.structure_view()),
std::nullopt,
eigenvalues.view(),
eigenvectors.view());
handle, config, laplacian_view, std::nullopt, eigenvalues.view(), eigenvectors.view());

if (spectral_embedding_config.norm_laplacian) {
raft::linalg::matrix_vector_op<raft::Apply::ALONG_COLUMNS>(
Expand Down Expand Up @@ -160,53 +118,50 @@ void compute_eigenpairs(raft::resources const& handle,
);
}

template <typename DataT>
template <typename DataT, typename NNZType>
void transform(raft::resources const& handle,
params spectral_embedding_config,
raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph,
raft::device_coo_matrix_view<DataT, int, int, NNZType> connectivity_graph,
raft::device_matrix_view<DataT, int, raft::col_major> embedding)
{
const int n_samples = connectivity_graph.structure_view().get_n_rows();
auto diagonal = raft::make_device_vector<DataT, int>(handle, n_samples);

auto sym_coo_row_ind = raft::make_device_vector<int>(handle, n_samples + 1);
auto diagonal = raft::make_device_vector<DataT, int>(handle, n_samples);

auto csr_matrix_view =
coo_to_csr_matrix(handle, n_samples, sym_coo_row_ind.view(), connectivity_graph);
auto laplacian =
create_laplacian(handle, spectral_embedding_config, csr_matrix_view, diagonal.view());
auto laplacian = create_laplacian<DataT, raft::device_coo_matrix<DataT, int, int, NNZType>>(
handle, spectral_embedding_config, connectivity_graph, diagonal.view());
compute_eigenpairs(
handle, spectral_embedding_config, n_samples, laplacian, diagonal.view(), embedding);
handle, spectral_embedding_config, n_samples, laplacian.view(), diagonal.view(), embedding);
}

template <typename NNZType>
void create_connectivity_graph(
raft::resources const& handle,
cuvs::preprocessing::spectral_embedding::params spectral_embedding_config,
raft::device_matrix_view<float, int, raft::row_major> dataset,
raft::device_coo_matrix<float, int, int, int>& connectivity_graph)
raft::device_coo_matrix<float, int, int, NNZType>& connectivity_graph)
{
const int n_samples = dataset.extent(0);
const int n_features = dataset.extent(1);
const int k_search = spectral_embedding_config.n_neighbors;
const size_t nnz = n_samples * k_search;
const int64_t n_samples = dataset.extent(0);
const int64_t n_features = dataset.extent(1);
const int k_search = spectral_embedding_config.n_neighbors;
const NNZType nnz = static_cast<NNZType>(n_samples) * k_search;

auto stream = raft::resource::get_cuda_stream(handle);

cuvs::neighbors::brute_force::search_params search_params;
cuvs::neighbors::brute_force::index_params index_params;
index_params.metric = cuvs::distance::DistanceType::L2SqrtExpanded;

auto d_indices = raft::make_device_matrix<int64_t>(handle, n_samples, k_search);
auto d_distances = raft::make_device_matrix<float>(handle, n_samples, k_search);
auto d_indices = raft::make_device_matrix<int64_t, int64_t>(handle, n_samples, k_search);
auto d_distances = raft::make_device_matrix<float, int64_t>(handle, n_samples, k_search);

auto index =
cuvs::neighbors::brute_force::build(handle, index_params, raft::make_const_mdspan(dataset));

cuvs::neighbors::brute_force::search(
handle, search_params, index, dataset, d_indices.view(), d_distances.view());

auto knn_rows = raft::make_device_vector<int>(handle, nnz);
auto knn_cols = raft::make_device_vector<int>(handle, nnz);
auto knn_rows = raft::make_device_vector<int, NNZType>(handle, nnz);
auto knn_cols = raft::make_device_vector<int, NNZType>(handle, nnz);

raft::linalg::unary_op(
handle, make_const_mdspan(d_indices.view()), knn_cols.view(), [] __device__(int64_t x) {
Expand All @@ -216,34 +171,36 @@ void create_connectivity_graph(
thrust::tabulate(raft::resource::get_thrust_policy(handle),
knn_rows.data_handle(),
knn_rows.data_handle() + nnz,
[k_search] __device__(int idx) { return idx / k_search; });
[k_search] __device__(NNZType idx) { return idx / k_search; });

// set all distances to 1.0f (connectivity KNN graph)
raft::matrix::fill(handle, raft::make_device_vector_view(d_distances.data_handle(), nnz), 1.0f);
raft::matrix::fill(
handle, raft::make_device_vector_view<float, NNZType>(d_distances.data_handle(), nnz), 1.0f);

auto coo_matrix_view = raft::make_device_coo_matrix_view<const float, int, int, int>(
auto coo_matrix_view = raft::make_device_coo_matrix_view<const float, int, int, NNZType>(
d_distances.data_handle(),
raft::make_device_coordinate_structure_view<int, int, int>(
raft::make_device_coordinate_structure_view<int, int, NNZType>(
knn_rows.data_handle(), knn_cols.data_handle(), n_samples, n_samples, nnz));

auto sym_coo1_matrix =
raft::make_device_coo_matrix<float, int, int, int>(handle, n_samples, n_samples);
raft::sparse::linalg::coo_symmetrize<128, float, int, int>(
raft::make_device_coo_matrix<float, int, int, NNZType>(handle, n_samples, n_samples);
raft::sparse::linalg::coo_symmetrize<128, float, int, NNZType>(
handle, coo_matrix_view, sym_coo1_matrix, [] __device__(int row, int col, float a, float b) {
return 0.5f * (a + b);
});

raft::sparse::op::coo_sort<float>(n_samples,
n_samples,
sym_coo1_matrix.structure_view().get_nnz(),
sym_coo1_matrix.structure_view().get_rows().data(),
sym_coo1_matrix.structure_view().get_cols().data(),
sym_coo1_matrix.get_elements().data(),
stream);
raft::sparse::op::coo_sort<float, int, NNZType>(
n_samples,
n_samples,
sym_coo1_matrix.structure_view().get_nnz(),
sym_coo1_matrix.structure_view().get_rows().data(),
sym_coo1_matrix.structure_view().get_cols().data(),
sym_coo1_matrix.get_elements().data(),
stream);

raft::sparse::op::coo_remove_scalar<128, float, int, int>(
raft::sparse::op::coo_remove_scalar<128, float, int, NNZType>(
handle,
raft::make_device_coo_matrix_view<const float, int, int, int>(
raft::make_device_coo_matrix_view<const float, int, int, NNZType>(
sym_coo1_matrix.get_elements().data(), sym_coo1_matrix.structure_view()),
raft::make_host_scalar<float>(0.0f).view(),
connectivity_graph);
Expand All @@ -257,17 +214,14 @@ void transform(raft::resources const& handle,
const int n_samples = dataset.extent(0);

auto sym_coo_matrix =
raft::make_device_coo_matrix<float, int, int, int>(handle, n_samples, n_samples);
auto sym_coo_row_ind = raft::make_device_vector<int>(handle, n_samples + 1);
auto diagonal = raft::make_device_vector<float, int>(handle, n_samples);
raft::make_device_coo_matrix<float, int, int, int64_t>(handle, n_samples, n_samples);
auto diagonal = raft::make_device_vector<float, int>(handle, n_samples);

create_connectivity_graph(handle, spectral_embedding_config, dataset, sym_coo_matrix);
auto csr_matrix_view =
coo_to_csr_matrix<float>(handle, n_samples, sym_coo_row_ind.view(), sym_coo_matrix.view());
auto laplacian =
create_laplacian<float>(handle, spectral_embedding_config, csr_matrix_view, diagonal.view());
create_connectivity_graph<int64_t>(handle, spectral_embedding_config, dataset, sym_coo_matrix);
auto laplacian = create_laplacian<float, raft::device_coo_matrix<float, int, int, int64_t>>(
handle, spectral_embedding_config, sym_coo_matrix.view(), diagonal.view());
compute_eigenpairs<float>(
handle, spectral_embedding_config, n_samples, laplacian, diagonal.view(), embedding);
handle, spectral_embedding_config, n_samples, laplacian.view(), diagonal.view(), embedding);
}

} // namespace cuvs::preprocessing::spectral_embedding::detail
24 changes: 14 additions & 10 deletions cpp/src/preprocessing/spectral/spectral_embedding.cu
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "./detail/spectral_embedding.cuh"

#include <cuvs/preprocessing/spectral_embedding.hpp>

#include <cstdint>

namespace cuvs::preprocessing::spectral_embedding {

#define CUVS_INST_SPECTRAL_EMBEDDING(DataT) \
void transform(raft::resources const& handle, \
params config, \
raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph, \
raft::device_matrix_view<DataT, int, raft::col_major> embedding) \
{ \
detail::transform<DataT>(handle, config, connectivity_graph, embedding); \
#define CUVS_INST_SPECTRAL_EMBEDDING(DataT, NNZType) \
void transform(raft::resources const& handle, \
params config, \
raft::device_coo_matrix_view<DataT, int, int, NNZType> connectivity_graph, \
raft::device_matrix_view<DataT, int, raft::col_major> embedding) \
{ \
detail::transform<DataT, NNZType>(handle, config, connectivity_graph, embedding); \
}

CUVS_INST_SPECTRAL_EMBEDDING(float);
CUVS_INST_SPECTRAL_EMBEDDING(double);
CUVS_INST_SPECTRAL_EMBEDDING(float, int);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the int instantiations here? Or can we skip them and stick to int64_t only?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can keep the int ones to avoid breaking cuml and remove them later.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sounds good.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracking here #1695

CUVS_INST_SPECTRAL_EMBEDDING(float, int64_t);
CUVS_INST_SPECTRAL_EMBEDDING(double, int);
CUVS_INST_SPECTRAL_EMBEDDING(double, int64_t);

#undef CUVS_INST_SPECTRAL_EMBEDDING

Expand Down