diff --git a/sycl/include/CL/sycl/group_algorithm.hpp b/sycl/include/CL/sycl/group_algorithm.hpp index 5c306e5842c1e..8277f2d0037be 100644 --- a/sycl/include/CL/sycl/group_algorithm.hpp +++ b/sycl/include/CL/sycl/group_algorithm.hpp @@ -86,12 +86,9 @@ get_local_linear_id(ext::oneapi::sub_group g) { // ---- is_native_op template using native_op_list = - type_list, ext::oneapi::bit_or, - ext::oneapi::bit_xor, ext::oneapi::bit_and, - ext::oneapi::maximum, ext::oneapi::minimum, - ext::oneapi::multiplies, sycl::plus, sycl::bit_or, - sycl::bit_xor, sycl::bit_and, sycl::maximum, - sycl::minimum, sycl::multiplies>; + type_list, sycl::bit_or, sycl::bit_xor, + sycl::bit_and, sycl::maximum, sycl::minimum, + sycl::multiplies>; template struct is_native_op { static constexpr bool value = diff --git a/sycl/include/CL/sycl/known_identity.hpp b/sycl/include/CL/sycl/known_identity.hpp index ac4041764ad30..59e51252d7c3a 100644 --- a/sycl/include/CL/sycl/known_identity.hpp +++ b/sycl/include/CL/sycl/known_identity.hpp @@ -18,53 +18,39 @@ namespace sycl { namespace detail { template -using IsPlus = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsPlus = + bool_constant>::value || + std::is_same>::value>; template -using IsMultiplies = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsMultiplies = + bool_constant>::value || + std::is_same>::value>; template -using IsMinimum = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsMinimum = + bool_constant>::value || + std::is_same>::value>; template -using IsMaximum = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsMaximum = + bool_constant>::value || + std::is_same>::value>; template -using IsBitOR = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsBitOR = + bool_constant>::value || + std::is_same>::value>; template -using IsBitXOR = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsBitXOR = + bool_constant>::value || + std::is_same>::value>; template -using IsBitAND = bool_constant< - std::is_same>::value || - std::is_same>::value || - std::is_same>::value || - std::is_same>::value>; +using IsBitAND = + bool_constant>::value || + std::is_same>::value>; // Identity = 0 template diff --git a/sycl/include/sycl/ext/oneapi/functional.hpp b/sycl/include/sycl/ext/oneapi/functional.hpp index 3ae3e7e4e14be..e7a1820a937f3 100644 --- a/sycl/include/sycl/ext/oneapi/functional.hpp +++ b/sycl/include/sycl/ext/oneapi/functional.hpp @@ -16,46 +16,13 @@ namespace sycl { namespace ext { namespace oneapi { -template struct minimum { - T operator()(const T &lhs, const T &rhs) const { - return std::less()(lhs, rhs) ? lhs : rhs; - } -}; - -template <> struct minimum { - struct is_transparent {}; - template - auto operator()(T &&lhs, U &&rhs) const -> - typename std::common_type::type { - return std::less<>()(std::forward(lhs), std::forward(rhs)) - ? std::forward(lhs) - : std::forward(rhs); - } -}; - -template struct maximum { - T operator()(const T &lhs, const T &rhs) const { - return std::greater()(lhs, rhs) ? lhs : rhs; - } -}; - -template <> struct maximum { - struct is_transparent {}; - template - auto operator()(T &&lhs, U &&rhs) const -> - typename std::common_type::type { - return std::greater<>()(std::forward(lhs), - std::forward(rhs)) - ? std::forward(lhs) - : std::forward(rhs); - } -}; - template using plus = std::plus; template using multiplies = std::multiplies; template using bit_or = std::bit_or; template using bit_xor = std::bit_xor; template using bit_and = std::bit_and; +template using maximum = sycl::maximum; +template using minimum = sycl::minimum; } // namespace oneapi } // namespace ext @@ -106,41 +73,29 @@ struct GroupOpTag::value>> { return Ret; \ } -// calc for sycl minimum/maximum function objects +// calc for sycl function objects __SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum) __SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum) __SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum) + __SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum) __SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum) __SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum) -// calc for oneapi function objects -__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ext::oneapi::minimum) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ext::oneapi::minimum) -__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ext::oneapi::minimum) -__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, ext::oneapi::maximum) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, ext::oneapi::maximum) -__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, ext::oneapi::maximum) -__SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, ext::oneapi::plus) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, ext::oneapi::plus) -__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, ext::oneapi::plus) - -__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, ext::oneapi::multiplies) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul, - ext::oneapi::multiplies) -__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, ext::oneapi::multiplies) -__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr, - ext::oneapi::bit_or) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr, - ext::oneapi::bit_or) -__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor, - ext::oneapi::bit_xor) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor, - ext::oneapi::bit_xor) -__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd, - ext::oneapi::bit_and) -__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd, - ext::oneapi::bit_and) +__SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, sycl::plus) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, sycl::plus) +__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, sycl::plus) + +__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, sycl::multiplies) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul, sycl::multiplies) +__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, sycl::multiplies) + +__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr, sycl::bit_or) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr, sycl::bit_or) +__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor, sycl::bit_xor) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor, sycl::bit_xor) +__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd, sycl::bit_and) +__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd, sycl::bit_and) #undef __SYCL_CALC_OVERLOAD diff --git a/sycl/include/sycl/ext/oneapi/reduction.hpp b/sycl/include/sycl/ext/oneapi/reduction.hpp index 748ce8596fd66..f2d47d4781700 100644 --- a/sycl/include/sycl/ext/oneapi/reduction.hpp +++ b/sycl/include/sycl/ext/oneapi/reduction.hpp @@ -183,7 +183,7 @@ class reducer { /// using those operations, which are based on functionality provided by /// sycl::atomic class. /// -/// For example, it is known that 0 is identity for ext::oneapi::plus operations +/// For example, it is known that 0 is identity for sycl::plus operations /// accepting native scalar types to which scalar 0 is convertible. /// Also, for int32/64 types the atomic_combine() is lowered to /// sycl::atomic::fetch_add(). @@ -317,8 +317,7 @@ class reducer enable_if_t::type, T>::value && @@ -332,8 +331,7 @@ class reducer enable_if_t::type, T>::value && diff --git a/sycl/source/detail/reduction.cpp b/sycl/source/detail/reduction.cpp index cb814aa3a4bfa..6f268f2fb15a6 100644 --- a/sycl/source/detail/reduction.cpp +++ b/sycl/source/detail/reduction.cpp @@ -55,9 +55,10 @@ __SYCL_EXPORT uint32_t reduGetMaxNumConcurrentWorkGroups( std::shared_ptr Queue) { device Dev = Queue->get_device(); uint32_t NumThreads = Dev.get_info(); - // The heuristics require additional tuning for various devices and vendors. - // For now assuming that each of execution units have about 8 working threads - // gives good results on some known/supported GPU devices. + // TODO: The heuristics here require additional tuning for various devices + // and vendors. For now this code assumes that execution units have about + // 8 working threads, which gives good results on some known/supported + // GPU devices. if (Dev.is_gpu()) NumThreads *= 8; return NumThreads; diff --git a/sycl/test/basic_tests/reduction_known_identity.cpp b/sycl/test/basic_tests/reduction_known_identity.cpp index 6401801514a1e..ca445f95285c6 100644 --- a/sycl/test/basic_tests/reduction_known_identity.cpp +++ b/sycl/test/basic_tests/reduction_known_identity.cpp @@ -10,10 +10,10 @@ using namespace cl::sycl; template void checkCommonBasicKnownIdentity() { - static_assert(has_known_identity, T>::value); - static_assert(has_known_identity, T>::value); - static_assert(has_known_identity, T>::value); - static_assert(has_known_identity, T>::value); + static_assert(has_known_identity, T>::value); + static_assert(has_known_identity, T>::value); + static_assert(has_known_identity, T>::value); + static_assert(has_known_identity, T>::value); } template void checkCommonKnownIdentity() { @@ -100,7 +100,7 @@ int main() { // Few negative tests just to check that it does not always return true. static_assert(!has_known_identity, int>::value); - static_assert(!has_known_identity, float>::value); + static_assert(!has_known_identity, float>::value); return 0; }