Skip to content
Merged
22 changes: 19 additions & 3 deletions cpp/src/umap/fuzzy_simpl_set/naive.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ static const float MIN_K_DIST_SCALE = 1e-3;
*
*/
template <typename value_t, typename nnz_t, int TPB_X>
CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists,
CUML_KERNEL void smooth_knn_dist_kernel(bool* error_status,
const value_t* knn_dists,
int n,
float mean_dist,
value_t* sigmas,
Expand Down Expand Up @@ -120,6 +121,11 @@ CUML_KERNEL void smooth_knn_dist_kernel(const value_t* knn_dists,
if (cur_dist > max_nonzero) max_nonzero = cur_dist;
}

if (start_nonzero == -1) {
*error_status = true;
return;
}

float ith_distances_mean = sum / float(n_neighbors);
if (total_nonzero >= local_connectivity) {
int index = int(floor(local_connectivity));
Expand Down Expand Up @@ -265,9 +271,19 @@ void smooth_knn_dist(nnz_t n,
/**
* Smooth kNN distances to be continuous
*/

bool has_found_an_error = false;
rmm::device_scalar<bool> error_status(stream);
error_status.set_value_async(has_found_an_error, stream);

smooth_knn_dist_kernel<value_t, nnz_t, TPB_X><<<grid, blk, 0, stream>>>(
knn_dists, n, mean_dist, sigmas, rhos, n_neighbors, local_connectivity);
RAFT_CUDA_TRY(cudaPeekAtLastError());
error_status.data(), knn_dists, n, mean_dist, sigmas, rhos, n_neighbors, local_connectivity);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

has_found_an_error = error_status.value(stream);
if (has_found_an_error) {
throw std::runtime_error("At least one row does not have any neighbor with non-zero distance.");
}
Comment thread
viclafargue marked this conversation as resolved.
Outdated
}

template <typename value_t, typename value_idx, typename nnz_t, int TPB_X>
Expand Down
Loading