|
18 | 18 |
|
19 | 19 | #pragma once |
20 | 20 |
|
21 | | -#include <raft/core/mdspan.hpp> |
22 | 21 | #include <raft/sparse/solver/lanczos.cuh> |
23 | 22 | #include <raft/spectral/matrix_wrappers.hpp> |
24 | 23 |
|
@@ -58,31 +57,18 @@ struct lanczos_solver_t { |
58 | 57 | { |
59 | 58 | RAFT_EXPECTS(eigVals != nullptr, "Null eigVals buffer."); |
60 | 59 | RAFT_EXPECTS(eigVecs != nullptr, "Null eigVecs buffer."); |
61 | | - index_type_t iters{0}; // TODO: return total number of iter |
62 | | - auto lanczos_config = raft::sparse::solver::lanczos_solver_config<value_type_t>{ |
63 | | - config_.n_eigVecs, config_.maxIter, config_.restartIter, config_.tol, config_.seed}; |
64 | | - auto csr_structure = |
65 | | - raft::make_device_compressed_structure_view<index_type_t, index_type_t, index_type_t>( |
66 | | - const_cast<index_type_t*>(A.row_offsets_), |
67 | | - const_cast<index_type_t*>(A.col_indices_), |
68 | | - A.nrows_, |
69 | | - A.ncols_, |
70 | | - A.nnz_); |
71 | | - |
72 | | - auto csr_matrix = |
73 | | - raft::make_device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t>( |
74 | | - const_cast<value_type_t*>(A.values_), csr_structure); |
75 | | - std::optional<raft::device_vector_view<value_type_t, uint32_t, raft::row_major>> v0_opt; |
76 | | - |
77 | | - sparse::solver::lanczos_compute_smallest_eigenvectors( |
78 | | - handle, |
79 | | - lanczos_config, |
80 | | - csr_matrix, |
81 | | - v0_opt, |
82 | | - raft::make_device_vector_view<value_type_t, uint32_t, raft::col_major>(eigVals, |
83 | | - config_.n_eigVecs), |
84 | | - raft::make_device_matrix_view<value_type_t, uint32_t, raft::col_major>( |
85 | | - eigVecs, A.nrows_, config_.n_eigVecs)); |
| 60 | + index_type_t iters{}; |
| 61 | + sparse::solver::computeSmallestEigenvectors(handle, |
| 62 | + A, |
| 63 | + config_.n_eigVecs, |
| 64 | + config_.maxIter, |
| 65 | + config_.restartIter, |
| 66 | + config_.tol, |
| 67 | + config_.reorthogonalize, |
| 68 | + iters, |
| 69 | + eigVals, |
| 70 | + eigVecs, |
| 71 | + config_.seed); |
86 | 72 |
|
87 | 73 | return iters; |
88 | 74 | } |
|
0 commit comments