Skip to content

Commit 7a49356

Browse files
authored
Fix for hellinger metric (#1128)
The implementation of the distance function for the `hellinger` metric had two issues: 1) It did not restore the input data correctly: after computing the distances it applied `sqrt` a second time instead of squaring the values back. 2) It did not check for data overlap (the two data chunks/tiles in the pairwise operation may cover overlapping memory); it only handled the case where both pointers were identical. Authors: - Victor Lafargue (https://github.com/viclafargue) Approvers: - Tarang Jain (https://github.com/tarang-jain) - Corey J. Nolet (https://github.com/cjnolet) URL: #1128
1 parent abbbe95 commit 7a49356

1 file changed

Lines changed: 33 additions & 8 deletions

File tree

cpp/src/distance/detail/distance.cuh

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2018-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -298,13 +298,28 @@ void distance_impl(raft::resources const& handle,
298298
{
299299
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
300300

301-
// First sqrt x and y
301+
// Check if arrays overlap
302+
const DataT* x_end = x + m * k;
303+
const DataT* y_end = y + n * k;
304+
bool arrays_overlap = (x < y_end) && (y < x_end);
305+
302306
const auto raft_sqrt = raft::linalg::unaryOp<DataT, raft::sqrt_op, IdxT>;
307+
const auto raft_sq = raft::linalg::unaryOp<DataT, raft::sq_op, IdxT>;
308+
309+
if (!arrays_overlap) {
310+
// Arrays don't overlap: sqrt each array independently
311+
raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream);
312+
raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream);
313+
} else {
314+
// Arrays overlap: sqrt the union of both arrays exactly once
315+
const DataT* start = (x < y) ? x : y;
316+
const DataT* end = (x_end > y_end) ? x_end : y_end;
317+
IdxT union_size = end - start;
303318

304-
raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream);
305-
if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); }
319+
raft_sqrt((DataT*)start, start, union_size, raft::sqrt_op{}, stream);
320+
}
306321

307-
// Then calculate Hellinger distance
322+
// Calculate Hellinger distance
308323
ops::hellinger_distance_op<DataT, AccT, IdxT> distance_op{};
309324

310325
const OutT* x_norm = nullptr;
@@ -313,9 +328,19 @@ void distance_impl(raft::resources const& handle,
313328
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
314329
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
315330

316-
// Finally revert sqrt of x and y
317-
raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream);
318-
if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); }
331+
// Restore arrays by squaring back
332+
if (!arrays_overlap) {
333+
// Arrays don't overlap: square each array independently
334+
raft_sq((DataT*)x, x, m * k, raft::sq_op{}, stream);
335+
raft_sq((DataT*)y, y, n * k, raft::sq_op{}, stream);
336+
} else {
337+
// Arrays overlap: square the union back
338+
const DataT* start = (x < y) ? x : y;
339+
const DataT* end = (x_end > y_end) ? x_end : y_end;
340+
IdxT union_size = end - start;
341+
342+
raft_sq((DataT*)start, start, union_size, raft::sq_op{}, stream);
343+
}
319344

320345
RAFT_CUDA_TRY(cudaGetLastError());
321346
}

0 commit comments

Comments
 (0)