11/*
2- * Copyright (c) 2021-2024 , NVIDIA CORPORATION.
2+ * Copyright (c) 2021-2025 , NVIDIA CORPORATION.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
1717#pragma once
1818
1919#include " ../../sparse/neighbors/cross_component_nn.cuh"
20+ #include < cuvs/distance/distance.hpp>
2021#include < raft/core/resource/cuda_stream.hpp>
22+ #include < raft/label/classlabels.cuh>
23+ #include < raft/matrix/detail/gather.cuh>
24+ #include < raft/matrix/diagonal.cuh>
2125#include < raft/sparse/op/sort.cuh>
2226#include < raft/sparse/solver/mst.cuh>
2327#include < raft/util/cuda_utils.cuh>
@@ -59,7 +63,7 @@ void merge_msts(raft::sparse::solver::Graph_COO<value_idx, value_idx, value_t>&
5963 * @tparam value_idx index type
6064 * @tparam value_t floating-point value type
6165 * @param[in] handle raft handle
62- * @param[in] X original dense data from which knn grpah was constructed
66+ * @param[in] X original dense data on device memory from which knn graph was constructed
6367 * @param[inout] msf edge list containing the mst result
6468 * @param[in] m number of rows in X
6569 * @param[in] n number of columns in X
@@ -117,6 +121,132 @@ void connect_knn_graph(
117121 merge_msts<value_idx, value_t >(msf, new_mst, stream);
118122}
119123
124+ /* *
125+ * Connect an unconnected knn graph (one in which mst returns an msf). The
126+ * device buffers underlying the Graph_COO object are modified in-place.
127+ * @tparam value_idx index type
128+ * @tparam value_t floating-point value type
129+ * @param[in] handle raft handle
130+ * @param[in] X original dense data on host memory from which knn graph was constructed
131+ * @param[inout] msf edge list containing the mst result
132+ * @param[in] m number of rows in X
133+ * @param[in] n number of columns in X
134+ * @param[in] n_components number of components in color
135+ * @param[inout] color the color labels array returned from the mst invocation
136+ * @return updated MST edge list
137+ */
138+ template <typename value_idx, typename value_t >
139+ void connect_knn_graph (
140+ raft::resources const & handle,
141+ const value_t * X,
142+ raft::sparse::solver::Graph_COO<value_idx, value_idx, value_t >& msf,
143+ size_t m,
144+ size_t n,
145+ int n_components,
146+ value_idx* color,
147+ cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded)
148+ {
149+ auto stream = raft::resource::get_cuda_stream (handle);
150+
151+ rmm::device_uvector<value_idx> d_color_remapped (m, stream);
152+ raft::label::make_monotonic (d_color_remapped.data (), color, m, stream, true );
153+
154+ std::vector<value_idx> h_color (m);
155+ raft::copy (h_color.data (), d_color_remapped.data (), m, stream);
156+ raft::resource::sync_stream (handle, stream);
157+
158+ // make key (color) : value (vector of ids that have that color)
159+ std::unordered_map<value_idx, std::vector<value_idx>> component_map;
160+ for (value_idx i = 0 ; i < static_cast <value_idx>(m); ++i) {
161+ component_map[h_color[i]].push_back (i);
162+ }
163+
164+ std::vector<std::tuple<value_idx, value_idx, value_t >> selected_edges;
165+
166+ std::random_device rd;
167+ std::mt19937 gen (rd ());
168+ std::uniform_int_distribution<> dis;
169+
170+ std::vector<value_idx> host_u_indices;
171+ std::vector<value_idx> host_v_indices;
172+
173+ // connect i-1 component and i component
174+ for (int i = 1 ; i < n_components; ++i) {
175+ value_idx color_a = i - 1 ;
176+ value_idx color_b = i;
177+
178+ const auto & nodes_a = component_map[color_a];
179+ const auto & nodes_b = component_map[color_b];
180+
181+ // Randomly pick a data index from each component
182+ dis.param (std::uniform_int_distribution<>::param_type (0 , nodes_a.size () - 1 ));
183+ value_idx u = nodes_a[dis (gen)];
184+
185+ dis.param (std::uniform_int_distribution<>::param_type (0 , nodes_b.size () - 1 ));
186+ value_idx v = nodes_b[dis (gen)];
187+
188+ host_u_indices.push_back (u);
189+ host_v_indices.push_back (v);
190+ }
191+
192+ auto device_u_indices = raft::make_device_vector<value_idx, int64_t >(handle, n_components - 1 );
193+ auto device_v_indices = raft::make_device_vector<value_idx, int64_t >(handle, n_components - 1 );
194+
195+ raft::copy (device_u_indices.data_handle (), host_u_indices.data (), n_components - 1 , stream);
196+ raft::copy (device_v_indices.data_handle (), host_v_indices.data (), n_components - 1 , stream);
197+
198+ auto X_view = raft::make_host_matrix_view<const value_t , int64_t >(X, m, n);
199+ auto data_u = raft::make_device_matrix<value_t , int64_t >(handle, n_components - 1 , n);
200+ auto data_v = raft::make_device_matrix<value_t , int64_t >(handle, n_components - 1 , n);
201+
202+ raft::matrix::detail::gather (
203+ handle, X_view, raft::make_const_mdspan (device_u_indices.view ()), data_u.view ());
204+ raft::matrix::detail::gather (
205+ handle, X_view, raft::make_const_mdspan (device_v_indices.view ()), data_v.view ());
206+
207+ auto pairwise_dist =
208+ raft::make_device_matrix<value_t , int64_t >(handle, n_components - 1 , n_components - 1 );
209+ cuvs::distance::pairwise_distance (handle,
210+ raft::make_const_mdspan (data_u.view ()),
211+ raft::make_const_mdspan (data_v.view ()),
212+ pairwise_dist.view (),
213+ metric);
214+
215+ auto pairwise_dist_vec = raft::make_device_vector<value_t , int64_t >(handle, n_components - 1 );
216+ raft::matrix::get_diagonal (
217+ handle, raft::make_const_mdspan (pairwise_dist.view ()), pairwise_dist_vec.view ());
218+
219+ size_t new_nnz = n_components - 1 ;
220+
221+ // sort in order of rows to run sorted_coo_to_csr
222+ auto rows_begin = thrust::device_pointer_cast (device_u_indices.data_handle ());
223+ auto cols_begin = thrust::device_pointer_cast (device_v_indices.data_handle ());
224+ auto dist_begin = thrust::device_pointer_cast (pairwise_dist_vec.data_handle ());
225+
226+ auto zipped_begin = thrust::make_zip_iterator (thrust::make_tuple (cols_begin, dist_begin));
227+ thrust::sort_by_key (rows_begin, rows_begin + new_nnz, zipped_begin);
228+
229+ rmm::device_uvector<value_idx> indptr2 (m + 1 , stream);
230+ raft::sparse::convert::sorted_coo_to_csr (
231+ device_u_indices.data_handle (), new_nnz, indptr2.data (), m + 1 , stream);
232+
233+ // On the second call, we hand the MST the original colors
234+ // and the new set of edges and let it restart the optimization process
235+ auto new_mst = raft::sparse::solver::mst<value_idx, value_idx, value_t , double >(
236+ handle,
237+ indptr2.data (),
238+ device_v_indices.data_handle (),
239+ pairwise_dist_vec.data_handle (),
240+ m,
241+ new_nnz,
242+ color,
243+ stream,
244+ false ,
245+ false );
246+
247+ merge_msts<value_idx, value_t >(msf, new_mst, stream);
248+ }
249+
120250/* *
121251 * Constructs an MST and sorts the resulting edges in ascending
122252 * order by their weight.
@@ -130,6 +260,7 @@ void connect_knn_graph(
130260 * @tparam value_idx
131261 * @tparam value_t
132262 * @param[in] handle raft handle
263+ * @param[in] X dataset residing on host or device memory
133264 * @param[in] indptr CSR indptr of connectivities graph
134265 * @param[in] indices CSR indices array of connectivities graph
135266 * @param[in] pw_dists CSR weights array of connectivities graph
@@ -168,8 +299,16 @@ void build_sorted_mst(
168299 int iters = 1 ;
169300 int n_components = cuvs::sparse::neighbors::get_n_components (color, m, stream);
170301
302+ cudaPointerAttributes attr;
303+ RAFT_CUDA_TRY (cudaPointerGetAttributes (&attr, X));
304+ bool data_on_device = attr.type == cudaMemoryTypeDevice;
305+
171306 while (n_components > 1 && iters < max_iter) {
172- connect_knn_graph<value_idx, value_t >(handle, X, mst_coo, m, n, color, reduction_op);
307+ if (data_on_device) {
308+ connect_knn_graph<value_idx, value_t >(handle, X, mst_coo, m, n, color, reduction_op);
309+ } else {
310+ connect_knn_graph<value_idx, value_t >(handle, X, mst_coo, m, n, n_components, color, metric);
311+ }
173312
174313 iters++;
175314
0 commit comments