diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp index 1e986a4f264e8..63cc42e0e62c6 100644 --- a/sycl/include/CL/sycl.hpp +++ b/sycl/include/CL/sycl.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/sycl/include/CL/sycl/ONEAPI/functional.hpp b/sycl/include/CL/sycl/ONEAPI/functional.hpp index ba4d49d7a414f..ede18a17dd96f 100644 --- a/sycl/include/CL/sycl/ONEAPI/functional.hpp +++ b/sycl/include/CL/sycl/ONEAPI/functional.hpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include __SYCL_INLINE_NAMESPACE(cl) { @@ -90,6 +92,15 @@ struct GroupOpTag::value>> { return Ret; \ } +// calc for sycl minimum/maximum function objects +__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum) +__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum) +__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum) +__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum) + +// calc for ONEAPI function objects __SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ONEAPI::minimum) __SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ONEAPI::minimum) __SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ONEAPI::minimum) diff --git a/sycl/include/CL/sycl/functional.hpp b/sycl/include/CL/sycl/functional.hpp new file mode 100644 index 0000000000000..1549945b23a56 --- /dev/null +++ b/sycl/include/CL/sycl/functional.hpp @@ -0,0 +1,57 @@ +//==----------- functional.hpp --- SYCL functional -------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { + +template using plus = std::plus; +template using multiplies = std::multiplies; +template using bit_or = std::bit_or; +template using bit_xor = std::bit_xor; +template using bit_and = std::bit_and; + +template struct minimum { + T operator()(const T &lhs, const T &rhs) const { + return std::less()(lhs, rhs) ? lhs : rhs; + } +}; + +template <> struct minimum { + struct is_transparent {}; + template + auto operator()(T &&lhs, U &&rhs) const -> + typename std::common_type::type { + return std::less<>()(std::forward(lhs), std::forward(rhs)) + ? std::forward(lhs) + : std::forward(rhs); + } +}; + +template struct maximum { + T operator()(const T &lhs, const T &rhs) const { + return std::greater()(lhs, rhs) ? lhs : rhs; + } +}; + +template <> struct maximum { + struct is_transparent {}; + template + auto operator()(T &&lhs, U &&rhs) const -> + typename std::common_type::type { + return std::greater<>()(std::forward(lhs), + std::forward(rhs)) + ? std::forward(lhs) + : std::forward(rhs); + } +}; + +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) \ No newline at end of file diff --git a/sycl/include/CL/sycl/group_algorithm.hpp b/sycl/include/CL/sycl/group_algorithm.hpp index c7dd9680eb0f7..e36c551c89b85 100644 --- a/sycl/include/CL/sycl/group_algorithm.hpp +++ b/sycl/include/CL/sycl/group_algorithm.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -86,7 +87,9 @@ template using native_op_list = type_list, ONEAPI::bit_or, ONEAPI::bit_xor, ONEAPI::bit_and, ONEAPI::maximum, ONEAPI::minimum, - ONEAPI::multiplies>; + ONEAPI::multiplies, sycl::plus, sycl::bit_or, + sycl::bit_xor, sycl::bit_and, sycl::maximum, + sycl::minimum, sycl::multiplies>; template struct is_native_op { static constexpr bool value = diff --git a/sycl/include/CL/sycl/known_identity.hpp b/sycl/include/CL/sycl/known_identity.hpp index c37caf2da7190..a7370b08d9386 100644 --- a/sycl/include/CL/sycl/known_identity.hpp +++ b/sycl/include/CL/sycl/known_identity.hpp @@ -19,37 +19,51 @@ namespace detail { template using IsPlus = - bool_constant>::value || + bool_constant>::value || + std::is_same>::value || + std::is_same>::value || std::is_same>::value>; template using IsMultiplies = bool_constant< + std::is_same>::value || + std::is_same>::value || std::is_same>::value || std::is_same>::value>; template using IsMinimum = - bool_constant>::value || + bool_constant>::value || + std::is_same>::value || + std::is_same>::value || std::is_same>::value>; template using IsMaximum = - bool_constant>::value || + bool_constant>::value || + std::is_same>::value || + std::is_same>::value || std::is_same>::value>; template using IsBitOR = - bool_constant>::value || + bool_constant>::value || + std::is_same>::value || + std::is_same>::value || std::is_same>::value>; template using IsBitXOR = - bool_constant>::value || + bool_constant>::value || + std::is_same>::value || + std::is_same>::value || std::is_same>::value>; template using IsBitAND = - bool_constant>::value || + bool_constant>::value || + std::is_same>::value || + std::is_same>::value || std::is_same>::value>; // Identity = 0 diff --git a/sycl/test/extensions/group-algorithm.cpp b/sycl/test/extensions/group-algorithm.cpp index 938ad2ef66227..71ed9772c1a62 100644 --- a/sycl/test/extensions/group-algorithm.cpp +++ b/sycl/test/extensions/group-algorithm.cpp @@ -76,13 +76,13 @@ int main() { std::iota(input.begin(), input.end(), 0); std::fill(output.begin(), output.end(), 0); - test(q, input, output, plus<>(), 0, GeZero()); - test(q, input, output, minimum<>(), + test(q, input, output, ONEAPI::plus<>(), 0, GeZero()); + test(q, input, output, ONEAPI::minimum<>(), std::numeric_limits::max(), IsEven()); #ifdef SPIRV_1_3 test(q, input, output, - multiplies(), 1, LtZero()); + ONEAPI::multiplies(), 1, LtZero()); #endif // SPIRV_1_3 std::cout << "Test passed." << std::endl; diff --git a/sycl/test/on-device/back_to_back_collectives.cpp b/sycl/test/on-device/back_to_back_collectives.cpp index 492ca0b6a157e..fc891b888ac0c 100644 --- a/sycl/test/on-device/back_to_back_collectives.cpp +++ b/sycl/test/on-device/back_to_back_collectives.cpp @@ -47,9 +47,9 @@ int main() { auto g = it.get_group(); // Loop to increase number of back-to-back calls for (int r = 0; r < 10; ++r) { - Sum[i] = reduce(g, Input[i], plus<>()); - EScan[i] = exclusive_scan(g, Input[i], plus<>()); - IScan[i] = inclusive_scan(g, Input[i], plus<>()); + Sum[i] = reduce(g, Input[i], sycl::plus<>()); + EScan[i] = exclusive_scan(g, Input[i], sycl::plus<>()); + IScan[i] = inclusive_scan(g, Input[i], sycl::plus<>()); } }); }); diff --git a/sycl/test/on-device/group_algorithms_sycl2020/exclusive_scan.cpp b/sycl/test/on-device/group_algorithms_sycl2020/exclusive_scan.cpp index 7932cb9ccd138..1987e52656a59 100644 --- a/sycl/test/on-device/group_algorithms_sycl2020/exclusive_scan.cpp +++ b/sycl/test/on-device/group_algorithms_sycl2020/exclusive_scan.cpp @@ -138,24 +138,24 @@ int main() { std::iota(input.begin(), input.end(), 0); std::fill(output.begin(), output.end(), 0); - test(q, input, output, std::plus<>(), 0); - test(q, input, output, sycl::ONEAPI::minimum<>(), + test(q, input, output, sycl::plus<>(), 0); + test(q, input, output, sycl::minimum<>(), std::numeric_limits::max()); - test(q, input, output, sycl::ONEAPI::maximum<>(), + test(q, input, output, sycl::maximum<>(), std::numeric_limits::lowest()); - test(q, input, output, std::plus(), 0); - test(q, input, output, sycl::ONEAPI::minimum(), + test(q, input, output, sycl::plus(), 0); + test(q, input, output, sycl::minimum(), std::numeric_limits::max()); - test(q, input, output, sycl::ONEAPI::maximum(), + test(q, input, output, sycl::maximum(), std::numeric_limits::lowest()); #ifdef SPIRV_1_3 - test(q, input, output, multiplies(), + test(q, input, output, sycl::multiplies(), 1); - test(q, input, output, bit_or(), 0); - test(q, input, output, bit_xor(), 0); - test(q, input, output, bit_and(), ~0); + test(q, input, output, sycl::bit_or(), 0); + test(q, input, output, sycl::bit_xor(), 0); + test(q, input, output, sycl::bit_and(), ~0); #endif // SPIRV_1_3 std::cout << "Test passed." << std::endl; diff --git a/sycl/test/on-device/group_algorithms_sycl2020/inclusive_scan.cpp b/sycl/test/on-device/group_algorithms_sycl2020/inclusive_scan.cpp index d084574b58c6c..f855c0717efcb 100644 --- a/sycl/test/on-device/group_algorithms_sycl2020/inclusive_scan.cpp +++ b/sycl/test/on-device/group_algorithms_sycl2020/inclusive_scan.cpp @@ -138,25 +138,25 @@ int main() { std::iota(input.begin(), input.end(), 0); std::fill(output.begin(), output.end(), 0); - test(q, input, output, std::plus<>(), 0); - test(q, input, output, sycl::ONEAPI::minimum<>(), + test(q, input, output, sycl::plus<>(), 0); + test(q, input, output, sycl::minimum<>(), std::numeric_limits::max()); - test(q, input, output, sycl::ONEAPI::maximum<>(), + test(q, input, output, sycl::maximum<>(), std::numeric_limits::lowest()); - test(q, input, output, std::plus(), 0); - test(q, input, output, sycl::ONEAPI::minimum(), + test(q, input, output, sycl::plus(), 0); + test(q, input, output, sycl::minimum(), std::numeric_limits::max()); - test(q, input, output, sycl::ONEAPI::maximum(), + test(q, input, output, sycl::maximum(), std::numeric_limits::lowest()); #ifdef SPIRV_1_3 test(q, input, output, - multiplies(), 1); - test(q, input, output, bit_or(), 0); - test(q, input, output, bit_xor(), + sycl::multiplies(), 1); + test(q, input, output, sycl::bit_or(), 0); + test(q, input, output, sycl::bit_xor(), 0); - test(q, input, output, bit_and(), ~0); + test(q, input, output, sycl::bit_and(), ~0); #endif // SPIRV_1_3 std::cout << "Test passed." << std::endl; diff --git a/sycl/test/on-device/group_algorithms_sycl2020/reduce.cpp b/sycl/test/on-device/group_algorithms_sycl2020/reduce.cpp index 0daf7b4158500..4ad407cf5ccb8 100644 --- a/sycl/test/on-device/group_algorithms_sycl2020/reduce.cpp +++ b/sycl/test/on-device/group_algorithms_sycl2020/reduce.cpp @@ -74,24 +74,24 @@ int main() { std::iota(input.begin(), input.end(), 0); std::fill(output.begin(), output.end(), 0); - test(q, input, output, std::plus<>(), 0); - test(q, input, output, sycl::ONEAPI::minimum<>(), + test(q, input, output, sycl::plus<>(), 0); + test(q, input, output, sycl::minimum<>(), std::numeric_limits::max()); - test(q, input, output, sycl::ONEAPI::maximum<>(), + test(q, input, output, sycl::maximum<>(), std::numeric_limits::lowest()); - test(q, input, output, std::plus(), 0); - test(q, input, output, sycl::ONEAPI::minimum(), + test(q, input, output, sycl::plus(), 0); + test(q, input, output, sycl::minimum(), std::numeric_limits::max()); - test(q, input, output, sycl::ONEAPI::maximum(), + test(q, input, output, sycl::maximum(), std::numeric_limits::lowest()); #ifdef SPIRV_1_3 test(q, input, output, - multiplies(), 1); - test(q, input, output, bit_or(), 0); - test(q, input, output, bit_xor(), 0); - test(q, input, output, bit_and(), ~0); + sycl::multiplies(), 1); + test(q, input, output, sycl::bit_or(), 0); + test(q, input, output, sycl::bit_xor(), 0); + test(q, input, output, sycl::bit_and(), ~0); #endif // SPIRV_1_3 std::cout << "Test passed." << std::endl;