Skip to content

Commit 30ebe8e

Browse files
committed
dataset api int64
1 parent 1d7dbc5 commit 30ebe8e

1 file changed

Lines changed: 31 additions & 26 deletions

File tree

cpp/src/preprocessing/spectral/detail/spectral_embedding.cuh

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,7 @@ void transform(raft::resources const& handle,
162162
{
163163
const int n_samples = connectivity_graph.structure_view().get_n_rows();
164164

165-
auto sym_coo_row_ind = raft::make_device_vector<int>(handle, n_samples + 1);
166-
auto diagonal = raft::make_device_vector<DataT, int>(handle, n_samples);
165+
auto diagonal = raft::make_device_vector<DataT, int>(handle, n_samples);
167166

168167
auto laplacian = create_laplacian<DataT, raft::device_coo_matrix<DataT, int, int, NNZType>>(
169168
handle, spectral_embedding_config, connectivity_graph, diagonal.view());
@@ -179,16 +178,17 @@ void transform(raft::resources const& handle,
179178
handle, spectral_embedding_config, n_samples, laplacian.view(), diagonal.view(), embedding);
180179
}
181180

181+
template <typename NNZType>
182182
void create_connectivity_graph(
183183
raft::resources const& handle,
184184
cuvs::preprocessing::spectral_embedding::params spectral_embedding_config,
185185
raft::device_matrix_view<float, int, raft::row_major> dataset,
186-
raft::device_coo_matrix<float, int, int, int>& connectivity_graph)
186+
raft::device_coo_matrix<float, int, int, NNZType>& connectivity_graph)
187187
{
188188
const int n_samples = dataset.extent(0);
189189
const int n_features = dataset.extent(1);
190190
const int k_search = spectral_embedding_config.n_neighbors;
191-
const size_t nnz = n_samples * k_search;
191+
const NNZType nnz = static_cast<NNZType>(n_samples) * k_search;
192192

193193
auto stream = raft::resource::get_cuda_stream(handle);
194194

@@ -221,29 +221,30 @@ void create_connectivity_graph(
221221
// set all distances to 1.0f (connectivity KNN graph)
222222
raft::matrix::fill(handle, raft::make_device_vector_view(d_distances.data_handle(), nnz), 1.0f);
223223

224-
auto coo_matrix_view = raft::make_device_coo_matrix_view<const float, int, int, int>(
224+
auto coo_matrix_view = raft::make_device_coo_matrix_view<const float, int, int, NNZType>(
225225
d_distances.data_handle(),
226-
raft::make_device_coordinate_structure_view<int, int, int>(
226+
raft::make_device_coordinate_structure_view<int, int, NNZType>(
227227
knn_rows.data_handle(), knn_cols.data_handle(), n_samples, n_samples, nnz));
228228

229229
auto sym_coo1_matrix =
230-
raft::make_device_coo_matrix<float, int, int, int>(handle, n_samples, n_samples);
231-
raft::sparse::linalg::coo_symmetrize<128, float, int, int>(
230+
raft::make_device_coo_matrix<float, int, int, NNZType>(handle, n_samples, n_samples);
231+
raft::sparse::linalg::coo_symmetrize<128, float, int, NNZType>(
232232
handle, coo_matrix_view, sym_coo1_matrix, [] __device__(int row, int col, float a, float b) {
233233
return 0.5f * (a + b);
234234
});
235235

236-
raft::sparse::op::coo_sort<float>(n_samples,
237-
n_samples,
238-
sym_coo1_matrix.structure_view().get_nnz(),
239-
sym_coo1_matrix.structure_view().get_rows().data(),
240-
sym_coo1_matrix.structure_view().get_cols().data(),
241-
sym_coo1_matrix.get_elements().data(),
242-
stream);
236+
raft::sparse::op::coo_sort<float, int, NNZType>(
237+
n_samples,
238+
n_samples,
239+
sym_coo1_matrix.structure_view().get_nnz(),
240+
sym_coo1_matrix.structure_view().get_rows().data(),
241+
sym_coo1_matrix.structure_view().get_cols().data(),
242+
sym_coo1_matrix.get_elements().data(),
243+
stream);
243244

244-
raft::sparse::op::coo_remove_scalar<128, float, int, int>(
245+
raft::sparse::op::coo_remove_scalar<128, float, int, NNZType>(
245246
handle,
246-
raft::make_device_coo_matrix_view<const float, int, int, int>(
247+
raft::make_device_coo_matrix_view<const float, int, int, NNZType>(
247248
sym_coo1_matrix.get_elements().data(), sym_coo1_matrix.structure_view()),
248249
raft::make_host_scalar<float>(0.0f).view(),
249250
connectivity_graph);
@@ -257,15 +258,19 @@ void transform(raft::resources const& handle,
257258
const int n_samples = dataset.extent(0);
258259

259260
auto sym_coo_matrix =
260-
raft::make_device_coo_matrix<float, int, int, int>(handle, n_samples, n_samples);
261-
auto sym_coo_row_ind = raft::make_device_vector<int>(handle, n_samples + 1);
262-
auto diagonal = raft::make_device_vector<float, int>(handle, n_samples);
263-
264-
create_connectivity_graph(handle, spectral_embedding_config, dataset, sym_coo_matrix);
265-
auto csr_matrix_view =
266-
coo_to_csr_matrix<float>(handle, n_samples, sym_coo_row_ind.view(), sym_coo_matrix.view());
267-
auto laplacian = create_laplacian<float, raft::device_csr_matrix<float, int, int, int>>(
268-
handle, spectral_embedding_config, csr_matrix_view, diagonal.view());
261+
raft::make_device_coo_matrix<float, int, int, int64_t>(handle, n_samples, n_samples);
262+
auto diagonal = raft::make_device_vector<float, int>(handle, n_samples);
263+
264+
create_connectivity_graph<int64_t>(handle, spectral_embedding_config, dataset, sym_coo_matrix);
265+
auto laplacian = create_laplacian<float, raft::device_coo_matrix<float, int, int, int64_t>>(
266+
handle, spectral_embedding_config, sym_coo_matrix.view(), diagonal.view());
267+
raft::sparse::op::coo_sort<float, int, int64_t>(n_samples,
268+
n_samples,
269+
laplacian.structure_view().get_nnz(),
270+
laplacian.structure_view().get_rows().data(),
271+
laplacian.structure_view().get_cols().data(),
272+
laplacian.get_elements().data(),
273+
raft::resource::get_cuda_stream(handle));
269274
compute_eigenpairs<float>(
270275
handle, spectral_embedding_config, n_samples, laplacian.view(), diagonal.view(), embedding);
271276
}

0 commit comments

Comments
 (0)