@@ -28,14 +28,14 @@ namespace cuvs::distance::detail::ops {
2828 * for round-off error tolerance.
2929 * @tparam DataT
3030 */
31- template <typename DataT>
32- __device__ constexpr DataT get_clamp_precision ()
31+ template <typename DataT, typename AccT >
32+ __device__ constexpr AccT get_clamp_precision ()
3333{
3434 switch (sizeof (DataT)) {
35- case 2 : return 1e-3 ;
36- case 4 : return 1e-6 ;
37- case 8 : return 1e-15 ;
38- default : return 0 ;
35+ case 2 : return AccT{ 1e-3 } ;
36+ case 4 : return AccT{ 1e-6 } ;
37+ case 8 : return AccT{ 1e-15 } ;
38+ default : return AccT{ 0 } ;
3939 }
4040}
4141
@@ -46,19 +46,27 @@ struct l2_exp_cutlass_op {
4646
4747 __device__ l2_exp_cutlass_op () noexcept : sqrt(false ) {}
4848 __device__ l2_exp_cutlass_op (bool isSqrt) noexcept : sqrt(isSqrt) {}
49- inline __device__ AccT operator ()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
49+ inline __device__ AccT operator ()(AccT aNorm, AccT bNorm, AccT accVal) const noexcept
5050 {
51- AccT outVal = aNorm + bNorm - DataT (2.0 ) * accVal;
51+ AccT outVal = aNorm + bNorm - AccT (2.0 ) * accVal;
5252
5353 /* *
5454 * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal)
5555 * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
5656 */
57- outVal = outVal * AccT (!((outVal * outVal < get_clamp_precision<AccT>()) * (aNorm == bNorm)));
57+ outVal =
58+ outVal * AccT (!((outVal * outVal < get_clamp_precision<DataT, AccT>()) * (aNorm == bNorm)));
5859 return sqrt ? raft::sqrt (outVal * static_cast <AccT>(outVal > AccT (0 ))) : outVal;
5960 }
6061
61- __device__ AccT operator ()(DataT aData) const noexcept { return aData; }
62+ __device__ AccT operator ()(DataT aData) const noexcept
63+ {
64+ if constexpr (std::is_same_v<DataT, half> && std::is_same_v<AccT, float >) {
65+ return __half2float (aData);
66+ } else {
67+ return aData;
68+ }
69+ }
6270};
6371
6472/* *
@@ -121,9 +129,9 @@ struct l2_exp_distance_op {
121129 * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
122130 * instead.
123131 */
124- acc[i][j] =
125- val * static_cast <AccT>((val > AccT ( 0 ))) *
126- static_cast <AccT>( !((val * val < get_clamp_precision<AccT>()) * (regxn[i] == regyn[j])));
132+ acc[i][j] = val * static_cast <AccT>((val > AccT ( 0 ))) *
133+ static_cast <AccT>(
134+ !((val * val < get_clamp_precision<DataT, AccT>()) * (regxn[i] == regyn[j])));
127135 }
128136 }
129137 if (sqrt) {
0 commit comments