Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/include/cuml/manifold/spectral_embedding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

Expand Down Expand Up @@ -52,4 +53,9 @@ void transform(raft::resources const& handle,
raft::device_matrix_view<float, int, raft::row_major> dataset,
raft::device_matrix_view<float, int, raft::col_major> embedding);

/**
 * @brief Compute a spectral embedding from a precomputed sparse connectivity graph.
 *
 * Overload of transform() that takes an already-built COO connectivity graph
 * instead of a dense dataset. Forwards to the cuvs spectral-embedding backend.
 *
 * @param handle             raft resources handle used for the computation
 * @param config             spectral embedding parameters (passed by value)
 * @param connectivity_graph sparse COO view of the sample connectivity graph
 *                           (presumably symmetric affinities; confirm with backend docs)
 * @param embedding          output device matrix view, column-major
 */
void transform(raft::resources const& handle,
ML::SpectralEmbedding::params config,
raft::device_coo_matrix_view<float, int, int, int> connectivity_graph,
raft::device_matrix_view<float, int, raft::col_major> embedding);

} // namespace ML::SpectralEmbedding
9 changes: 9 additions & 0 deletions cpp/src/spectral/spectral_embedding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,13 @@ void transform(raft::resources const& handle,
cuvs::preprocessing::spectral_embedding::transform(handle, to_cuvs(config), dataset, embedding);
}

/**
 * @brief Spectral embedding from a precomputed COO connectivity graph.
 *
 * Thin forwarding wrapper: converts the ML-level params to their cuvs
 * equivalent via to_cuvs() and delegates the actual computation to
 * cuvs::preprocessing::spectral_embedding::transform.
 *
 * @param handle             raft resources handle
 * @param config             spectral embedding parameters (converted with to_cuvs)
 * @param connectivity_graph sparse COO view of the sample connectivity graph
 * @param embedding          output device matrix view, column-major
 */
void transform(raft::resources const& handle,
ML::SpectralEmbedding::params config,
raft::device_coo_matrix_view<float, int, int, int> connectivity_graph,
raft::device_matrix_view<float, int, raft::col_major> embedding)
{
// All work happens in the cuvs backend; this layer only adapts the param struct.
cuvs::preprocessing::spectral_embedding::transform(
handle, to_cuvs(config), connectivity_graph, embedding);
}

} // namespace ML::SpectralEmbedding
72 changes: 24 additions & 48 deletions cpp/src/umap/init_embed/spectral_algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,16 @@

#pragma once

#include <cuml/cluster/spectral.hpp>
#include <cuml/manifold/spectral_embedding.hpp>
#include <cuml/manifold/umapparams.h>

#include <raft/core/device_coo_matrix.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/random/rng.cuh>
#include <raft/sparse/coo.hpp>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/extrema.h>

#include <stdint.h>

#include <iostream>

namespace UMAPAlgo {

namespace InitEmbed {
Expand All @@ -57,45 +50,28 @@ void launcher(const raft::handle_t& handle,
ASSERT(n > static_cast<nnz_t>(params->n_components),
"Spectral layout requires n_samples > n_components");

rmm::device_uvector<T> tmp_storage(n * params->n_components, stream);

uint64_t seed = params->random_state;

Spectral::fit_embedding(handle,
coo->rows(),
coo->cols(),
coo->vals(),
coo->nnz,
n,
params->n_components,
tmp_storage.data(),
seed);

raft::linalg::transpose(handle, tmp_storage.data(), embedding, n, params->n_components, stream);

raft::linalg::unaryOp<T>(
tmp_storage.data(),
tmp_storage.data(),
n * params->n_components,
[=] __device__(T input) { return fabsf(input); },
stream);

thrust::device_ptr<T> d_ptr = thrust::device_pointer_cast(tmp_storage.data());
T max =
*(thrust::max_element(thrust::cuda::par.on(stream), d_ptr, d_ptr + (n * params->n_components)));

// Reuse tmp_storage to add random noise
raft::random::Rng r(seed);
r.normal(tmp_storage.data(), n * params->n_components, 0.0f, 0.0001f, stream);

raft::linalg::unaryOp<T>(
embedding,
embedding,
n * params->n_components,
[=] __device__(T input) { return (10.0f / max) * input; },
stream);

raft::linalg::add(embedding, embedding, tmp_storage.data(), n * params->n_components, stream);
auto connectivity_graph_view = raft::make_device_coo_matrix_view<float, int, int, int>(
coo->vals(),
raft::make_device_coordinate_structure_view<int, int, int>(
coo->rows(), coo->cols(), n, n, coo->nnz));
Comment on lines +62 to +65
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @aamijar, could you make sure that this uses nnz_t for the nnz type instead of hardwiring it to int types? I think that should fix this issue

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jinsolp, I think we can create a follow-up issue for nnz_t types, since I would need to change the cuvs API too.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracking here rapidsai/cuvs#1243

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, thank you!


ML::SpectralEmbedding::params spectral_params;
spectral_params.n_neighbors = params->n_neighbors;
spectral_params.norm_laplacian = true;
spectral_params.drop_first = true;
spectral_params.seed = params->random_state;
spectral_params.n_components =
spectral_params.drop_first ? params->n_components + 1 : params->n_components;

auto tmp_embedding = raft::make_device_vector<float, int>(handle, n * params->n_components);
auto tmp_embedding_view = raft::make_device_matrix_view<float, int, raft::col_major>(
tmp_embedding.data_handle(), n, params->n_components);
Copy link
Copy Markdown
Contributor

@viclafargue viclafargue Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a detail, but you could use .view() to create a view.
EDIT: Unless raft::col_major is important here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 8f32544


ML::SpectralEmbedding::transform(
handle, spectral_params, connectivity_graph_view, tmp_embedding_view);

raft::linalg::transpose(
handle, tmp_embedding.data_handle(), embedding, n, params->n_components, stream);

RAFT_CUDA_TRY(cudaPeekAtLastError());
}
Expand Down
Loading