11/*
2- * Copyright (c) 2018-2024 , NVIDIA CORPORATION.
2+ * Copyright (c) 2018-2025 , NVIDIA CORPORATION.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
@@ -298,13 +298,28 @@ void distance_impl(raft::resources const& handle,
298298{
299299 cudaStream_t stream = raft::resource::get_cuda_stream (handle);
300300
301- // First sqrt x and y
301+ // Check if arrays overlap
302+ const DataT* x_end = x + m * k;
303+ const DataT* y_end = y + n * k;
304+ bool arrays_overlap = (x < y_end) && (y < x_end);
305+
302306 const auto raft_sqrt = raft::linalg::unaryOp<DataT, raft::sqrt_op, IdxT>;
307+ const auto raft_sq = raft::linalg::unaryOp<DataT, raft::sq_op, IdxT>;
308+
309+ if (!arrays_overlap) {
310+ // Arrays don't overlap: sqrt each array independently
311+ raft_sqrt ((DataT*)x, x, m * k, raft::sqrt_op{}, stream);
312+ raft_sqrt ((DataT*)y, y, n * k, raft::sqrt_op{}, stream);
313+ } else {
314+ // Arrays overlap: sqrt the union of both arrays exactly once
315+ const DataT* start = (x < y) ? x : y;
316+ const DataT* end = (x_end > y_end) ? x_end : y_end;
317+ IdxT union_size = end - start;
303318
304- raft_sqrt ((DataT*)x, x, m * k , raft::sqrt_op{}, stream);
305- if (x != y) { raft_sqrt ((DataT*)y, y, n * k, raft::sqrt_op{}, stream); }
319+ raft_sqrt ((DataT*)start, start, union_size , raft::sqrt_op{}, stream);
320+ }
306321
307- // Then calculate Hellinger distance
322+ // Calculate Hellinger distance
308323 ops::hellinger_distance_op<DataT, AccT, IdxT> distance_op{};
309324
310325 const OutT* x_norm = nullptr ;
@@ -313,9 +328,19 @@ void distance_impl(raft::resources const& handle,
313328 pairwise_matrix_dispatch<decltype (distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
314329 distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
315330
316- // Finally revert sqrt of x and y
317- raft_sqrt ((DataT*)x, x, m * k, raft::sqrt_op{}, stream);
318- if (x != y) { raft_sqrt ((DataT*)y, y, n * k, raft::sqrt_op{}, stream); }
331+ // Restore arrays by squaring back
332+ if (!arrays_overlap) {
333+ // Arrays don't overlap: square each array independently
334+ raft_sq ((DataT*)x, x, m * k, raft::sq_op{}, stream);
335+ raft_sq ((DataT*)y, y, n * k, raft::sq_op{}, stream);
336+ } else {
337+ // Arrays overlap: square the union back
338+ const DataT* start = (x < y) ? x : y;
339+ const DataT* end = (x_end > y_end) ? x_end : y_end;
340+ IdxT union_size = end - start;
341+
342+ raft_sq ((DataT*)start, start, union_size, raft::sq_op{}, stream);
343+ }
319344
320345 RAFT_CUDA_TRY (cudaGetLastError ());
321346}
0 commit comments