-
Notifications
You must be signed in to change notification settings - Fork 184
Spectral Embedding nnz_t
#1628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Spectral Embedding nnz_t
#1628
Changes from 7 commits
e72adcd
3ab13ea
9e7f2c8
81ebfb2
74087c9
8b653d2
a7389df
bf4e1b7
1d7dbc5
30ebe8e
f29ec0a
fb7e26b
a93ef93
a12e009
58cb448
c81cb15
3d962a3
54e3133
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
|
|
@@ -11,6 +11,7 @@ | |
| #include <raft/core/device_coo_matrix.hpp> | ||
| #include <raft/core/device_mdspan.hpp> | ||
| #include <raft/core/handle.hpp> | ||
| #include <raft/core/resource/cuda_stream.hpp> | ||
| #include <raft/core/resources.hpp> | ||
| #include <raft/linalg/matrix_vector_op.cuh> | ||
| #include <raft/matrix/gather.cuh> | ||
|
|
@@ -66,18 +67,17 @@ raft::device_csr_matrix_view<DataT, int, int, int> coo_to_csr_matrix( | |
| return csr_matrix_view; | ||
| } | ||
|
|
||
| template <typename DataT> | ||
| raft::device_csr_matrix<DataT, int, int, int> create_laplacian( | ||
| raft::resources const& handle, | ||
| params spectral_embedding_config, | ||
| raft::device_csr_matrix_view<DataT, int, int, int> csr_matrix_view, | ||
| raft::device_vector_view<DataT, int> diagonal) | ||
| template <typename DataT, typename RetA, typename A> | ||
| RetA create_laplacian(raft::resources const& handle, | ||
| params spectral_embedding_config, | ||
| A csr_matrix_view, | ||
| raft::device_vector_view<DataT, int> diagonal) | ||
| { | ||
| auto laplacian = spectral_embedding_config.norm_laplacian | ||
| ? raft::sparse::linalg::laplacian_normalized(handle, csr_matrix_view, diagonal) | ||
| : raft::sparse::linalg::compute_graph_laplacian(handle, csr_matrix_view); | ||
|
|
||
| auto laplacian_elements_view = raft::make_device_vector_view<DataT, int>( | ||
| auto laplacian_elements_view = raft::make_device_vector_view<DataT>( | ||
| laplacian.get_elements().data(), laplacian.structure_view().get_nnz()); | ||
|
|
||
| raft::linalg::unary_op(handle, | ||
|
|
@@ -88,11 +88,11 @@ raft::device_csr_matrix<DataT, int, int, int> create_laplacian( | |
| return laplacian; | ||
| } | ||
|
|
||
| template <typename DataT> | ||
| template <typename DataT, typename A> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use a more informative name here instead of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in 3d962a3 |
||
| void compute_eigenpairs(raft::resources const& handle, | ||
| params spectral_embedding_config, | ||
| const int n_samples, | ||
| raft::device_csr_matrix<DataT, int, int, int>& laplacian, | ||
| A laplacian_view, | ||
| raft::device_vector_view<DataT, int> diagonal, | ||
| raft::device_matrix_view<DataT, int, raft::col_major> embedding) | ||
| { | ||
|
|
@@ -110,13 +110,7 @@ void compute_eigenpairs(raft::resources const& handle, | |
| raft::make_device_matrix<DataT, int, raft::col_major>(handle, n_samples, config.n_components); | ||
|
|
||
| raft::sparse::solver::lanczos_compute_smallest_eigenvectors<int, DataT>( | ||
| handle, | ||
| config, | ||
| raft::make_device_csr_matrix_view<DataT, int, int, int>(laplacian.get_elements().data(), | ||
| laplacian.structure_view()), | ||
| std::nullopt, | ||
| eigenvalues.view(), | ||
| eigenvectors.view()); | ||
| handle, config, laplacian_view, std::nullopt, eigenvalues.view(), eigenvectors.view()); | ||
|
|
||
| if (spectral_embedding_config.norm_laplacian) { | ||
| raft::linalg::matrix_vector_op<raft::Apply::ALONG_COLUMNS>( | ||
|
|
@@ -160,23 +154,29 @@ void compute_eigenpairs(raft::resources const& handle, | |
| ); | ||
| } | ||
|
|
||
| template <typename DataT> | ||
| template <typename DataT, typename NNZType> | ||
| void transform(raft::resources const& handle, | ||
| params spectral_embedding_config, | ||
| raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph, | ||
| raft::device_coo_matrix_view<DataT, int, int, NNZType> connectivity_graph, | ||
| raft::device_matrix_view<DataT, int, raft::col_major> embedding) | ||
| { | ||
| const int n_samples = connectivity_graph.structure_view().get_n_rows(); | ||
|
|
||
| auto sym_coo_row_ind = raft::make_device_vector<int>(handle, n_samples + 1); | ||
| auto diagonal = raft::make_device_vector<DataT, int>(handle, n_samples); | ||
|
|
||
| auto csr_matrix_view = | ||
| coo_to_csr_matrix(handle, n_samples, sym_coo_row_ind.view(), connectivity_graph); | ||
| auto laplacian = | ||
| create_laplacian(handle, spectral_embedding_config, csr_matrix_view, diagonal.view()); | ||
| auto laplacian = create_laplacian<DataT, raft::device_coo_matrix<DataT, int, int, NNZType>>( | ||
| handle, spectral_embedding_config, connectivity_graph, diagonal.view()); | ||
|
|
||
| raft::sparse::op::coo_sort<DataT>(n_samples, | ||
|
aamijar marked this conversation as resolved.
Outdated
|
||
| n_samples, | ||
| laplacian.structure_view().get_nnz(), | ||
| laplacian.structure_view().get_rows().data(), | ||
| laplacian.structure_view().get_cols().data(), | ||
| laplacian.get_elements().data(), | ||
| raft::resource::get_cuda_stream(handle)); | ||
| compute_eigenpairs( | ||
| handle, spectral_embedding_config, n_samples, laplacian, diagonal.view(), embedding); | ||
| handle, spectral_embedding_config, n_samples, laplacian.view(), diagonal.view(), embedding); | ||
| } | ||
|
|
||
| void create_connectivity_graph( | ||
|
|
@@ -264,10 +264,10 @@ void transform(raft::resources const& handle, | |
| create_connectivity_graph(handle, spectral_embedding_config, dataset, sym_coo_matrix); | ||
| auto csr_matrix_view = | ||
| coo_to_csr_matrix<float>(handle, n_samples, sym_coo_row_ind.view(), sym_coo_matrix.view()); | ||
| auto laplacian = | ||
| create_laplacian<float>(handle, spectral_embedding_config, csr_matrix_view, diagonal.view()); | ||
| auto laplacian = create_laplacian<float, raft::device_csr_matrix<float, int, int, int>>( | ||
| handle, spectral_embedding_config, csr_matrix_view, diagonal.view()); | ||
| compute_eigenpairs<float>( | ||
| handle, spectral_embedding_config, n_samples, laplacian, diagonal.view(), embedding); | ||
| handle, spectral_embedding_config, n_samples, laplacian.view(), diagonal.view(), embedding); | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like the connectivity graph in the transform function that takes the dataset as argument is assumed to have a nnz of type int. Is this intentional? Will it be updated in a follow-up PR?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I'll try to change it so that it defaults to int64_t.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in 30ebe8e |
||
|
|
||
| } // namespace cuvs::preprocessing::spectral_embedding::detail | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,25 +1,29 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #include "./detail/spectral_embedding.cuh" | ||
|
|
||
| #include <cuvs/preprocessing/spectral_embedding.hpp> | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| namespace cuvs::preprocessing::spectral_embedding { | ||
|
|
||
| #define CUVS_INST_SPECTRAL_EMBEDDING(DataT) \ | ||
| void transform(raft::resources const& handle, \ | ||
| params config, \ | ||
| raft::device_coo_matrix_view<DataT, int, int, int> connectivity_graph, \ | ||
| raft::device_matrix_view<DataT, int, raft::col_major> embedding) \ | ||
| { \ | ||
| detail::transform<DataT>(handle, config, connectivity_graph, embedding); \ | ||
| #define CUVS_INST_SPECTRAL_EMBEDDING(DataT, NNZType) \ | ||
| void transform(raft::resources const& handle, \ | ||
| params config, \ | ||
| raft::device_coo_matrix_view<DataT, int, int, NNZType> connectivity_graph, \ | ||
| raft::device_matrix_view<DataT, int, raft::col_major> embedding) \ | ||
| { \ | ||
| detail::transform<DataT, NNZType>(handle, config, connectivity_graph, embedding); \ | ||
| } | ||
|
|
||
| CUVS_INST_SPECTRAL_EMBEDDING(float); | ||
| CUVS_INST_SPECTRAL_EMBEDDING(double); | ||
| CUVS_INST_SPECTRAL_EMBEDDING(float, int); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can keep the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes sounds good.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tracking here #1695 |
||
| CUVS_INST_SPECTRAL_EMBEDDING(float, int64_t); | ||
| CUVS_INST_SPECTRAL_EMBEDDING(double, int); | ||
| CUVS_INST_SPECTRAL_EMBEDDING(double, int64_t); | ||
|
|
||
| #undef CUVS_INST_SPECTRAL_EMBEDDING | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.