Skip to content

Commit 7b31108

Browse files
committed
resolving pr comments
1 parent 74908f2 commit 7b31108

2 files changed

Lines changed: 67 additions & 56 deletions

File tree

cpp/include/raft/sparse/solver/detail/lanczos.cuh

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
#pragma once
1818

1919
// for cmath:
20-
#include "raft/core/device_csr_matrix.hpp"
2120
#define _USE_MATH_DEFINES
2221

2322
#include <raft/core/detail/macros.hpp>
23+
#include <raft/core/device_csr_matrix.hpp>
2424
#include <raft/core/device_mdspan.hpp>
2525
#include <raft/core/host_mdarray.hpp>
2626
#include <raft/core/host_mdspan.hpp>
@@ -1544,38 +1544,40 @@ void lanczos_solve_ritz(
15441544
}
15451545

15461546
template <typename index_type_t, typename value_type_t>
1547-
void lanczos_aux(raft::resources const& handle,
1548-
// spectral::matrix::sparse_matrix_t<index_type_t, value_type_t> const* A,
1549-
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, value_type_t> A,
1550-
raft::device_matrix_view<value_type_t, uint32_t, raft::row_major> V,
1551-
raft::device_matrix_view<value_type_t> u,
1552-
raft::device_matrix_view<value_type_t> alpha,
1553-
raft::device_matrix_view<value_type_t> beta,
1554-
int start_idx,
1555-
int end_idx,
1556-
int ncv,
1557-
raft::device_matrix_view<value_type_t> v,
1558-
raft::device_matrix_view<value_type_t> uu,
1559-
raft::device_matrix_view<value_type_t> vv)
1547+
void lanczos_aux(
1548+
raft::resources const& handle,
1549+
// spectral::matrix::sparse_matrix_t<index_type_t, value_type_t> const* A,
1550+
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t> A,
1551+
raft::device_matrix_view<value_type_t, uint32_t, raft::row_major> V,
1552+
raft::device_matrix_view<value_type_t> u,
1553+
raft::device_matrix_view<value_type_t> alpha,
1554+
raft::device_matrix_view<value_type_t> beta,
1555+
int start_idx,
1556+
int end_idx,
1557+
int ncv,
1558+
raft::device_matrix_view<value_type_t> v,
1559+
raft::device_matrix_view<value_type_t> uu,
1560+
raft::device_matrix_view<value_type_t> vv)
15601561
{
15611562
auto stream = resource::get_cuda_stream(handle);
15621563

1563-
auto A_structure = A.get_structure_view();
1564-
index_type_t n = A_structure.get_n_rows();
1564+
auto A_structure = A.structure_view();
1565+
index_type_t n = A_structure.get_n_rows();
15651566

15661567
raft::copy(v.data_handle(), &(V(start_idx, 0)), n, stream);
15671568

15681569
std::cout << start_idx << " " << end_idx << std::endl;
15691570

15701571
auto cusparse_h = resource::get_cusparse_handle(handle);
15711572
cusparseSpMatDescr_t cusparse_A;
1572-
raft::sparse::detail::cusparsecreatecsr(&cusparse_A,
1573-
A_structure.get_n_rows(),
1574-
A_structure.get_n_cols(),
1575-
A_structure.get_nnz(),
1576-
const_cast<index_type_t*>(A_structure.get_indptr().data()),
1577-
const_cast<index_type_t*>(A_structure.get_indices().data()),
1578-
const_cast<value_type_t*>(A_structure.get_elements().data()));
1573+
raft::sparse::detail::cusparsecreatecsr(
1574+
&cusparse_A,
1575+
A_structure.get_n_rows(),
1576+
A_structure.get_n_cols(),
1577+
A_structure.get_nnz(),
1578+
const_cast<index_type_t*>(A_structure.get_indptr().data()),
1579+
const_cast<index_type_t*>(A_structure.get_indices().data()),
1580+
const_cast<value_type_t*>(A.get_elements().data()));
15791581

15801582
cusparseDnVecDescr_t cusparse_v;
15811583
cusparseDnVecDescr_t cusparse_u;
@@ -1683,21 +1685,22 @@ void lanczos_aux(raft::resources const& handle,
16831685
}
16841686

16851687
template <typename index_type_t, typename value_type_t>
1686-
int lanczos_smallest(raft::resources const& handle,
1687-
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, value_type_t> A,
1688-
int nEigVecs,
1689-
int maxIter,
1690-
int restartIter,
1691-
value_type_t tol,
1692-
value_type_t* eigVals_dev,
1693-
value_type_t* eigVecs_dev,
1694-
value_type_t* v0,
1695-
uint64_t seed)
1688+
int lanczos_smallest(
1689+
raft::resources const& handle,
1690+
raft::device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t> A,
1691+
int nEigVecs,
1692+
int maxIter,
1693+
int restartIter,
1694+
value_type_t tol,
1695+
value_type_t* eigVals_dev,
1696+
value_type_t* eigVecs_dev,
1697+
value_type_t* v0,
1698+
uint64_t seed)
16961699
{
16971700
auto A_structure = A.structure_view();
1698-
int n = A_structure.get_n_rows();
1699-
int ncv = restartIter;
1700-
auto stream = resource::get_cuda_stream(handle);
1701+
int n = A_structure.get_n_rows();
1702+
int ncv = restartIter;
1703+
auto stream = resource::get_cuda_stream(handle);
17011704

17021705
std::cout << std::fixed << std::setprecision(7); // Set precision to 10 decimal places
17031706

@@ -1864,16 +1867,17 @@ int lanczos_smallest(raft::resources const& handle,
18641867

18651868
auto cusparse_h = resource::get_cusparse_handle(handle);
18661869
cusparseSpMatDescr_t cusparse_A;
1867-
// input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data());
1868-
// input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data());
1869-
// input_config.a_data = const_cast<ElementType*>(x.get_elements().data());
1870-
raft::sparse::detail::cusparsecreatecsr(&cusparse_A,
1871-
A_structure.get_n_rows(),
1872-
A_structure.get_n_cols(),
1873-
A_structure.get_nnz(),
1874-
const_cast<index_type_t*>(A_structure.get_indptr().data()),
1875-
const_cast<index_type_t*>(A_structure.get_indices().data()),
1876-
const_cast<value_type_t*>(A_structure.get_elements().data()));
1870+
// input_config.a_indptr = const_cast<IndexType*>(x_structure.get_indptr().data());
1871+
// input_config.a_indices = const_cast<IndexType*>(x_structure.get_indices().data());
1872+
// input_config.a_data = const_cast<ElementType*>(x.get_elements().data());
1873+
raft::sparse::detail::cusparsecreatecsr(
1874+
&cusparse_A,
1875+
A_structure.get_n_rows(),
1876+
A_structure.get_n_cols(),
1877+
A_structure.get_nnz(),
1878+
const_cast<index_type_t*>(A_structure.get_indptr().data()),
1879+
const_cast<index_type_t*>(A_structure.get_indices().data()),
1880+
const_cast<value_type_t*>(A.get_elements().data()));
18771881

18781882
cusparseDnVecDescr_t cusparse_v;
18791883
cusparseDnVecDescr_t cusparse_u;
@@ -2058,14 +2062,14 @@ int lanczos_smallest(raft::resources const& handle,
20582062
template <typename IndexTypeT, typename ValueTypeT>
20592063
auto lanczos_compute_smallest_eigenvectors(
20602064
raft::resources const& handle,
2061-
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, ValueTypeT> A,
2065+
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
20622066
lanczos_solver_config<IndexTypeT, ValueTypeT> const& config,
20632067
raft::device_vector_view<ValueTypeT, uint32_t> v0,
20642068
raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues,
20652069
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
20662070
{
20672071
return lanczos_smallest(handle,
2068-
&A,
2072+
A,
20692073
config.n_components,
20702074
config.max_iterations,
20712075
config.ncv,

cpp/include/raft/sparse/solver/lanczos.cuh

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,23 @@ auto lanczos_compute_smallest_eigenvectors(
4848
// raft::core::bitmap_view<const bitmap_t, index_t>(bitmap_d.data(), params.m, params.n);
4949

5050
// auto c = raft::make_device_csr_matrix_view<value_t>(c_data_d.data(), c_structure);
51-
51+
5252
// FIXME: move out of function
53-
auto csr_structure = raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
54-
A.row_offsets_,
55-
A.col_indices_,
56-
A.ncols_,
57-
A.nrows_,
58-
static_cast<IndexTypeT>(A.nnz_));
53+
IndexTypeT ncols = A.ncols_;
54+
IndexTypeT nrows = A.nrows_;
55+
IndexTypeT nnz = A.nnz_;
56+
57+
auto csr_structure =
58+
raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
59+
const_cast<IndexTypeT*>(A.row_offsets_),
60+
const_cast<IndexTypeT*>(A.col_indices_),
61+
ncols,
62+
nrows,
63+
nnz);
5964

60-
auto csr_matrix = raft::make_device_matrix_view<ValueTypeT>(A.values_, csr_structure);
65+
auto csr_matrix =
66+
raft::make_device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT>(
67+
const_cast<ValueTypeT*>(A.values_), csr_structure);
6168

6269
return detail::lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
6370
handle, csr_matrix, config, v0, eigenvalues, eigenvectors);

0 commit comments

Comments
 (0)