|
17 | 17 | #pragma once |
18 | 18 |
|
19 | 19 | // for cmath: |
20 | | -#include "raft/core/device_csr_matrix.hpp" |
21 | 20 | #define _USE_MATH_DEFINES |
22 | 21 |
|
23 | 22 | #include <raft/core/detail/macros.hpp> |
| 23 | +#include <raft/core/device_csr_matrix.hpp> |
24 | 24 | #include <raft/core/device_mdspan.hpp> |
25 | 25 | #include <raft/core/host_mdarray.hpp> |
26 | 26 | #include <raft/core/host_mdspan.hpp> |
@@ -1544,38 +1544,40 @@ void lanczos_solve_ritz( |
1544 | 1544 | } |
1545 | 1545 |
|
1546 | 1546 | template <typename index_type_t, typename value_type_t> |
1547 | | -void lanczos_aux(raft::resources const& handle, |
1548 | | - // spectral::matrix::sparse_matrix_t<index_type_t, value_type_t> const* A, |
1549 | | - raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, value_type_t> A, |
1550 | | - raft::device_matrix_view<value_type_t, uint32_t, raft::row_major> V, |
1551 | | - raft::device_matrix_view<value_type_t> u, |
1552 | | - raft::device_matrix_view<value_type_t> alpha, |
1553 | | - raft::device_matrix_view<value_type_t> beta, |
1554 | | - int start_idx, |
1555 | | - int end_idx, |
1556 | | - int ncv, |
1557 | | - raft::device_matrix_view<value_type_t> v, |
1558 | | - raft::device_matrix_view<value_type_t> uu, |
1559 | | - raft::device_matrix_view<value_type_t> vv) |
| 1547 | +void lanczos_aux( |
| 1548 | + raft::resources const& handle, |
| 1549 | + // spectral::matrix::sparse_matrix_t<index_type_t, value_type_t> const* A, |
| 1550 | + raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t> A, |
| 1551 | + raft::device_matrix_view<value_type_t, uint32_t, raft::row_major> V, |
| 1552 | + raft::device_matrix_view<value_type_t> u, |
| 1553 | + raft::device_matrix_view<value_type_t> alpha, |
| 1554 | + raft::device_matrix_view<value_type_t> beta, |
| 1555 | + int start_idx, |
| 1556 | + int end_idx, |
| 1557 | + int ncv, |
| 1558 | + raft::device_matrix_view<value_type_t> v, |
| 1559 | + raft::device_matrix_view<value_type_t> uu, |
| 1560 | + raft::device_matrix_view<value_type_t> vv) |
1560 | 1561 | { |
1561 | 1562 | auto stream = resource::get_cuda_stream(handle); |
1562 | 1563 |
|
1563 | | - auto A_structure = A.get_structure_view(); |
1564 | | - index_type_t n = A_structure.get_n_rows(); |
| 1564 | + auto A_structure = A.structure_view(); |
| 1565 | + index_type_t n = A_structure.get_n_rows(); |
1565 | 1566 |
|
1566 | 1567 | raft::copy(v.data_handle(), &(V(start_idx, 0)), n, stream); |
1567 | 1568 |
|
1568 | 1569 | std::cout << start_idx << " " << end_idx << std::endl; |
1569 | 1570 |
|
1570 | 1571 | auto cusparse_h = resource::get_cusparse_handle(handle); |
1571 | 1572 | cusparseSpMatDescr_t cusparse_A; |
1572 | | - raft::sparse::detail::cusparsecreatecsr(&cusparse_A, |
1573 | | - A_structure.get_n_rows(), |
1574 | | - A_structure.get_n_cols(), |
1575 | | - A_structure.get_nnz(), |
1576 | | - const_cast<index_type_t*>(A_structure.get_indptr().data()), |
1577 | | - const_cast<index_type_t*>(A_structure.get_indices().data()), |
1578 | | - const_cast<value_type_t*>(A_structure.get_elements().data())); |
| 1573 | + raft::sparse::detail::cusparsecreatecsr( |
| 1574 | + &cusparse_A, |
| 1575 | + A_structure.get_n_rows(), |
| 1576 | + A_structure.get_n_cols(), |
| 1577 | + A_structure.get_nnz(), |
| 1578 | + const_cast<index_type_t*>(A_structure.get_indptr().data()), |
| 1579 | + const_cast<index_type_t*>(A_structure.get_indices().data()), |
| 1580 | + const_cast<value_type_t*>(A.get_elements().data())); |
1579 | 1581 |
|
1580 | 1582 | cusparseDnVecDescr_t cusparse_v; |
1581 | 1583 | cusparseDnVecDescr_t cusparse_u; |
@@ -1683,21 +1685,22 @@ void lanczos_aux(raft::resources const& handle, |
1683 | 1685 | } |
1684 | 1686 |
|
1685 | 1687 | template <typename index_type_t, typename value_type_t> |
1686 | | -int lanczos_smallest(raft::resources const& handle, |
1687 | | - raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, value_type_t> A, |
1688 | | - int nEigVecs, |
1689 | | - int maxIter, |
1690 | | - int restartIter, |
1691 | | - value_type_t tol, |
1692 | | - value_type_t* eigVals_dev, |
1693 | | - value_type_t* eigVecs_dev, |
1694 | | - value_type_t* v0, |
1695 | | - uint64_t seed) |
| 1688 | +int lanczos_smallest( |
| 1689 | + raft::resources const& handle, |
| 1690 | + raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t> A, |
| 1691 | + int nEigVecs, |
| 1692 | + int maxIter, |
| 1693 | + int restartIter, |
| 1694 | + value_type_t tol, |
| 1695 | + value_type_t* eigVals_dev, |
| 1696 | + value_type_t* eigVecs_dev, |
| 1697 | + value_type_t* v0, |
| 1698 | + uint64_t seed) |
1696 | 1699 | { |
1697 | 1700 | auto A_structure = A.structure_view(); |
1698 | | - int n = A_structure.get_n_rows(); |
1699 | | - int ncv = restartIter; |
1700 | | - auto stream = resource::get_cuda_stream(handle); |
| 1701 | + int n = A_structure.get_n_rows(); |
| 1702 | + int ncv = restartIter; |
| 1703 | + auto stream = resource::get_cuda_stream(handle); |
1701 | 1704 |
|
1702 | 1705 | std::cout << std::fixed << std::setprecision(7); // Set precision to 10 decimal places |
1703 | 1706 |
|
@@ -1864,16 +1867,17 @@ int lanczos_smallest(raft::resources const& handle, |
1864 | 1867 |
|
1865 | 1868 | auto cusparse_h = resource::get_cusparse_handle(handle); |
1866 | 1869 | cusparseSpMatDescr_t cusparse_A; |
1867 | | - // input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data()); |
1868 | | - // input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data()); |
1869 | | - // input_config.a_data = const_cast<ElementType*>(x.get_elements().data()); |
1870 | | - raft::sparse::detail::cusparsecreatecsr(&cusparse_A, |
1871 | | - A_structure.get_n_rows(), |
1872 | | - A_structure.get_n_cols(), |
1873 | | - A_structure.get_nnz(), |
1874 | | - const_cast<index_type_t*>(A_structure.get_indptr().data()), |
1875 | | - const_cast<index_type_t*>(A_structure.get_indices().data()), |
1876 | | - const_cast<value_type_t*>(A_structure.get_elements().data())); |
| 1870 | + // input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data()); |
| 1871 | + // input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data()); |
| 1872 | + // input_config.a_data = const_cast<ElementType*>(x.get_elements().data()); |
| 1873 | + raft::sparse::detail::cusparsecreatecsr( |
| 1874 | + &cusparse_A, |
| 1875 | + A_structure.get_n_rows(), |
| 1876 | + A_structure.get_n_cols(), |
| 1877 | + A_structure.get_nnz(), |
| 1878 | + const_cast<index_type_t*>(A_structure.get_indptr().data()), |
| 1879 | + const_cast<index_type_t*>(A_structure.get_indices().data()), |
| 1880 | + const_cast<value_type_t*>(A.get_elements().data())); |
1877 | 1881 |
|
1878 | 1882 | cusparseDnVecDescr_t cusparse_v; |
1879 | 1883 | cusparseDnVecDescr_t cusparse_u; |
@@ -2058,14 +2062,14 @@ int lanczos_smallest(raft::resources const& handle, |
2058 | 2062 | template <typename IndexTypeT, typename ValueTypeT> |
2059 | 2063 | auto lanczos_compute_smallest_eigenvectors( |
2060 | 2064 | raft::resources const& handle, |
2061 | | - raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, ValueTypeT> A, |
| 2065 | + raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A, |
2062 | 2066 | lanczos_solver_config<IndexTypeT, ValueTypeT> const& config, |
2063 | 2067 | raft::device_vector_view<ValueTypeT, uint32_t> v0, |
2064 | 2068 | raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues, |
2065 | 2069 | raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int |
2066 | 2070 | { |
2067 | 2071 | return lanczos_smallest(handle, |
2068 | | - &A, |
| 2072 | + A, |
2069 | 2073 | config.n_components, |
2070 | 2074 | config.max_iterations, |
2071 | 2075 | config.ncv, |
|
0 commit comments