Skip to content

Commit 1c91e1f

Browse files
authored
[Fix] l2_exp random fail in half-float32 mixed precision on self-neighboring (#596)
Authors: - rhdong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) URL: #596
1 parent 43969ca commit 1c91e1f

2 files changed

Lines changed: 23 additions & 16 deletions

File tree

cpp/src/distance/detail/distance_ops/l2_exp.cuh

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ namespace cuvs::distance::detail::ops {
2828
* for round-off error tolerance.
2929
* @tparam DataT
3030
*/
31-
template <typename DataT>
32-
__device__ constexpr DataT get_clamp_precision()
31+
template <typename DataT, typename AccT>
32+
__device__ constexpr AccT get_clamp_precision()
3333
{
3434
switch (sizeof(DataT)) {
35-
case 2: return 1e-3;
36-
case 4: return 1e-6;
37-
case 8: return 1e-15;
38-
default: return 0;
35+
case 2: return AccT{1e-3};
36+
case 4: return AccT{1e-6};
37+
case 8: return AccT{1e-15};
38+
default: return AccT{0};
3939
}
4040
}
4141

@@ -46,19 +46,27 @@ struct l2_exp_cutlass_op {
4646

4747
__device__ l2_exp_cutlass_op() noexcept : sqrt(false) {}
4848
__device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {}
49-
inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
49+
inline __device__ AccT operator()(AccT aNorm, AccT bNorm, AccT accVal) const noexcept
5050
{
51-
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
51+
AccT outVal = aNorm + bNorm - AccT(2.0) * accVal;
5252

5353
/**
5454
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal)
5555
* can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
5656
*/
57-
outVal = outVal * AccT(!((outVal * outVal < get_clamp_precision<AccT>()) * (aNorm == bNorm)));
57+
outVal =
58+
outVal * AccT(!((outVal * outVal < get_clamp_precision<DataT, AccT>()) * (aNorm == bNorm)));
5859
return sqrt ? raft::sqrt(outVal * static_cast<AccT>(outVal > AccT(0))) : outVal;
5960
}
6061

61-
__device__ AccT operator()(DataT aData) const noexcept { return aData; }
62+
__device__ AccT operator()(DataT aData) const noexcept
63+
{
64+
if constexpr (std::is_same_v<DataT, half> && std::is_same_v<AccT, float>) {
65+
return __half2float(aData);
66+
} else {
67+
return aData;
68+
}
69+
}
6270
};
6371

6472
/**
@@ -121,9 +129,9 @@ struct l2_exp_distance_op {
121129
* (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
122130
* instead.
123131
*/
124-
acc[i][j] =
125-
val * static_cast<AccT>((val > AccT(0))) *
126-
static_cast<AccT>(!((val * val < get_clamp_precision<AccT>()) * (regxn[i] == regyn[j])));
132+
acc[i][j] = val * static_cast<AccT>((val > AccT(0))) *
133+
static_cast<AccT>(
134+
!((val * val < get_clamp_precision<DataT, AccT>()) * (regxn[i] == regyn[j])));
127135
}
128136
}
129137
if (sqrt) {

python/cuvs/cuvs/test/test_distance.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from cuvs.distance import pairwise_distance
2222

2323

24+
@pytest.mark.parametrize("times", range(20))
2425
@pytest.mark.parametrize("n_rows", [50, 100])
2526
@pytest.mark.parametrize("n_cols", [10, 50])
2627
@pytest.mark.parametrize(
@@ -43,7 +44,7 @@
4344
@pytest.mark.parametrize("inplace", [True, False])
4445
@pytest.mark.parametrize("order", ["F", "C"])
4546
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16])
46-
def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
47+
def test_distance(n_rows, n_cols, inplace, order, metric, dtype, times):
4748
input1 = np.random.random_sample((n_rows, n_cols))
4849
input1 = np.asarray(input1, order=order).astype(dtype)
4950

@@ -79,7 +80,5 @@ def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
7980
actual = output_device.copy_to_host()
8081

8182
tol = 1e-3
82-
if np.issubdtype(dtype, np.float16):
83-
tol = 1e-1
8483

8584
assert np.allclose(expected, actual, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)