@@ -162,8 +162,7 @@ void transform(raft::resources const& handle,
162162{
163163 const int n_samples = connectivity_graph.structure_view ().get_n_rows ();
164164
165- auto sym_coo_row_ind = raft::make_device_vector<int >(handle, n_samples + 1 );
166- auto diagonal = raft::make_device_vector<DataT, int >(handle, n_samples);
165+ auto diagonal = raft::make_device_vector<DataT, int >(handle, n_samples);
167166
168167 auto laplacian = create_laplacian<DataT, raft::device_coo_matrix<DataT, int , int , NNZType>>(
169168 handle, spectral_embedding_config, connectivity_graph, diagonal.view ());
@@ -179,16 +178,17 @@ void transform(raft::resources const& handle,
179178 handle, spectral_embedding_config, n_samples, laplacian.view (), diagonal.view (), embedding);
180179}
181180
181+ template <typename NNZType>
182182void create_connectivity_graph (
183183 raft::resources const & handle,
184184 cuvs::preprocessing::spectral_embedding::params spectral_embedding_config,
185185 raft::device_matrix_view<float , int , raft::row_major> dataset,
186- raft::device_coo_matrix<float , int , int , int >& connectivity_graph)
186+ raft::device_coo_matrix<float , int , int , NNZType >& connectivity_graph)
187187{
188188 const int n_samples = dataset.extent (0 );
189189 const int n_features = dataset.extent (1 );
190190 const int k_search = spectral_embedding_config.n_neighbors ;
191- const size_t nnz = n_samples * k_search;
191+ const NNZType nnz = static_cast <NNZType>( n_samples) * k_search;
192192
193193 auto stream = raft::resource::get_cuda_stream (handle);
194194
@@ -221,29 +221,30 @@ void create_connectivity_graph(
221221 // set all distances to 1.0f (connectivity KNN graph)
222222 raft::matrix::fill (handle, raft::make_device_vector_view (d_distances.data_handle (), nnz), 1 .0f );
223223
224- auto coo_matrix_view = raft::make_device_coo_matrix_view<const float , int , int , int >(
224+ auto coo_matrix_view = raft::make_device_coo_matrix_view<const float , int , int , NNZType >(
225225 d_distances.data_handle (),
226- raft::make_device_coordinate_structure_view<int , int , int >(
226+ raft::make_device_coordinate_structure_view<int , int , NNZType >(
227227 knn_rows.data_handle (), knn_cols.data_handle (), n_samples, n_samples, nnz));
228228
229229 auto sym_coo1_matrix =
230- raft::make_device_coo_matrix<float , int , int , int >(handle, n_samples, n_samples);
231- raft::sparse::linalg::coo_symmetrize<128 , float , int , int >(
230+ raft::make_device_coo_matrix<float , int , int , NNZType >(handle, n_samples, n_samples);
231+ raft::sparse::linalg::coo_symmetrize<128 , float , int , NNZType >(
232232 handle, coo_matrix_view, sym_coo1_matrix, [] __device__ (int row, int col, float a, float b) {
233233 return 0 .5f * (a + b);
234234 });
235235
236- raft::sparse::op::coo_sort<float >(n_samples,
237- n_samples,
238- sym_coo1_matrix.structure_view ().get_nnz (),
239- sym_coo1_matrix.structure_view ().get_rows ().data (),
240- sym_coo1_matrix.structure_view ().get_cols ().data (),
241- sym_coo1_matrix.get_elements ().data (),
242- stream);
236+ raft::sparse::op::coo_sort<float , int , NNZType>(
237+ n_samples,
238+ n_samples,
239+ sym_coo1_matrix.structure_view ().get_nnz (),
240+ sym_coo1_matrix.structure_view ().get_rows ().data (),
241+ sym_coo1_matrix.structure_view ().get_cols ().data (),
242+ sym_coo1_matrix.get_elements ().data (),
243+ stream);
243244
244- raft::sparse::op::coo_remove_scalar<128 , float , int , int >(
245+ raft::sparse::op::coo_remove_scalar<128 , float , int , NNZType >(
245246 handle,
246- raft::make_device_coo_matrix_view<const float , int , int , int >(
247+ raft::make_device_coo_matrix_view<const float , int , int , NNZType >(
247248 sym_coo1_matrix.get_elements ().data (), sym_coo1_matrix.structure_view ()),
248249 raft::make_host_scalar<float >(0 .0f ).view (),
249250 connectivity_graph);
@@ -257,15 +258,19 @@ void transform(raft::resources const& handle,
257258 const int n_samples = dataset.extent (0 );
258259
259260 auto sym_coo_matrix =
260- raft::make_device_coo_matrix<float , int , int , int >(handle, n_samples, n_samples);
261- auto sym_coo_row_ind = raft::make_device_vector<int >(handle, n_samples + 1 );
262- auto diagonal = raft::make_device_vector<float , int >(handle, n_samples);
263-
264- create_connectivity_graph (handle, spectral_embedding_config, dataset, sym_coo_matrix);
265- auto csr_matrix_view =
266- coo_to_csr_matrix<float >(handle, n_samples, sym_coo_row_ind.view (), sym_coo_matrix.view ());
267- auto laplacian = create_laplacian<float , raft::device_csr_matrix<float , int , int , int >>(
268- handle, spectral_embedding_config, csr_matrix_view, diagonal.view ());
261+ raft::make_device_coo_matrix<float , int , int , int64_t >(handle, n_samples, n_samples);
262+ auto diagonal = raft::make_device_vector<float , int >(handle, n_samples);
263+
264+ create_connectivity_graph<int64_t >(handle, spectral_embedding_config, dataset, sym_coo_matrix);
265+ auto laplacian = create_laplacian<float , raft::device_coo_matrix<float , int , int , int64_t >>(
266+ handle, spectral_embedding_config, sym_coo_matrix.view (), diagonal.view ());
267+ raft::sparse::op::coo_sort<float , int , int64_t >(n_samples,
268+ n_samples,
269+ laplacian.structure_view ().get_nnz (),
270+ laplacian.structure_view ().get_rows ().data (),
271+ laplacian.structure_view ().get_cols ().data (),
272+ laplacian.get_elements ().data (),
273+ raft::resource::get_cuda_stream (handle));
269274 compute_eigenpairs<float >(
270275 handle, spectral_embedding_config, n_samples, laplacian.view (), diagonal.view (), embedding);
271276}
0 commit comments