diff --git a/sycl/include/CL/sycl/feature_test.hpp.in b/sycl/include/CL/sycl/feature_test.hpp.in index e6053ebf4ff1c..44962a4dde012 100644 --- a/sycl/include/CL/sycl/feature_test.hpp.in +++ b/sycl/include/CL/sycl/feature_test.hpp.in @@ -5,6 +5,12 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // // ===--------------------------------------------------------------------=== // + +// +// IMPORTANT: feature_test.hpp is a generated file - DO NOT EDIT +// original definitions are in feature_test.hpp.in +// + #pragma once #include @@ -35,6 +41,7 @@ namespace sycl { #define SYCL_EXT_ONEAPI_MATRIX 2 #endif #define SYCL_EXT_ONEAPI_ASSERT 1 +#define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1 #define SYCL_EXT_ONEAPI_DISCARD_QUEUE_EVENTS 1 #define SYCL_EXT_ONEAPI_ENQUEUE_BARRIER 1 #define SYCL_EXT_ONEAPI_FREE_FUNCTION_QUERIES 1 diff --git a/sycl/include/CL/sycl/group_algorithm.hpp b/sycl/include/CL/sycl/group_algorithm.hpp index 6ede3ff23b68f..bc3a1d9834edf 100644 --- a/sycl/include/CL/sycl/group_algorithm.hpp +++ b/sycl/include/CL/sycl/group_algorithm.hpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include #include #include @@ -97,6 +99,50 @@ template struct is_native_op { is_contained>::value; }; +// ---- is_plus +template +using is_plus = std::integral_constant< + bool, std::is_same>::value || + std::is_same>::value>; + +// ---- is_complex +// NOTE: std::complex not yet supported by group algorithms. +template +struct is_complex + : std::integral_constant>::value || + std::is_same>::value> { +}; + +// ---- is_arithmetic_or_complex +template +using is_arithmetic_or_complex = + std::integral_constant::value || + sycl::detail::is_arithmetic::value>; +// ---- is_plus_if_complex +template +using is_plus_if_complex = + std::integral_constant::value + ? 
is_plus::value + : std::true_type::value)>; + +// ---- identity_for_ga_op +// the group algorithms support std::complex, limited to sycl::plus operation +// get the correct identity for group algorithm operation. +// TODO: identity_for_ga_op should be replaced with known_identity once the other +// callers of known_identity support complex numbers. +template +constexpr detail::enable_if_t< + (is_complex::value && is_plus::value), T> +identity_for_ga_op() { + return {0, 0}; +} + +template +constexpr detail::enable_if_t::value, T> identity_for_ga_op() { + return sycl::known_identity_v; +} + // ---- for_each template Function for_each(Group g, Ptr first, Ptr last, Function f) { @@ -119,6 +165,9 @@ Function for_each(Group g, Ptr first, Ptr last, Function f) { } // namespace detail // ---- reduce_over_group +// three argument variant is specialized thrice: +// scalar arithmetic, complex (plus only), and vector arithmetic + template detail::enable_if_t<(is_group_v> && detail::is_scalar_arithmetic::value && @@ -141,6 +190,29 @@ reduce_over_group(Group, T x, BinaryOperation binary_op) { #endif } +// complex specialization. T is std::complex or similar. 
+// binary op is sycl::plus> +template +detail::enable_if_t<(is_group_v> && + detail::is_complex::value && + detail::is_native_op>::value && + detail::is_plus::value), + T> +reduce_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + T result; + result.real(reduce_over_group(g, x.real(), sycl::plus<>())); + result.imag(reduce_over_group(g, x.imag(), sycl::plus<>())); + return result; +#else + (void)g; + (void)x; + (void)binary_op; + throw runtime_error("Group algorithms are not supported on host device.", + PI_INVALID_DEVICE); +#endif +} + template detail::enable_if_t<(is_group_v> && detail::is_vector_arithmetic::value && @@ -161,13 +233,18 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) { return result; } +// four argument variant of reduce_over_group specialized twice +// (scalar arithmetic || complex), and vector_arithmetic template -detail::enable_if_t<(is_group_v> && - detail::is_scalar_arithmetic::value && - detail::is_scalar_arithmetic::value && - detail::is_native_op::value && - detail::is_native_op::value), - T> +detail::enable_if_t< + (is_group_v> && + (detail::is_scalar_arithmetic::value || detail::is_complex::value) && + (detail::is_scalar_arithmetic::value || detail::is_complex::value) && + detail::is_native_op::value && + detail::is_native_op::value && + detail::is_plus_if_complex::value && + detail::is_plus_if_complex::value), + T> reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { // FIXME: Do not special-case for half precision static_assert( @@ -216,23 +293,19 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { template detail::enable_if_t< (is_group_v> && detail::is_pointer::value && - detail::is_arithmetic::type>::value), + detail::is_arithmetic_or_complex< + typename detail::remove_pointer::type>::value && + detail::is_plus_if_complex::type, + BinaryOperation>::value), typename detail::remove_pointer::type> joint_reduce(Group g, Ptr first, Ptr last, 
BinaryOperation binary_op) { - using T = typename detail::remove_pointer::type; - // FIXME: Do not special-case for half precision - static_assert( - std::is_same::value || - (std::is_same::value && - std::is_same::value), - "Result type of binary_op must match reduction accumulation type."); #ifdef __SYCL_DEVICE_ONLY__ - T partial = sycl::known_identity_v; - sycl::detail::for_each(g, first, last, - [&](const T &x) { partial = binary_op(partial, x); }); - return reduce_over_group(g, partial, binary_op); + using T = typename detail::remove_pointer::type; + T init = detail::identity_for_ga_op(); + return joint_reduce(g, first, last, init, binary_op); #else (void)g; + (void)first; (void)last; (void)binary_op; throw runtime_error("Group algorithms are not supported on host device.", @@ -243,10 +316,14 @@ joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) { template detail::enable_if_t< (is_group_v> && detail::is_pointer::value && - detail::is_arithmetic::type>::value && - detail::is_arithmetic::value && + detail::is_arithmetic_or_complex< + typename detail::remove_pointer::type>::value && + detail::is_arithmetic_or_complex::value && detail::is_native_op::type, BinaryOperation>::value && + detail::is_plus_if_complex::type, + BinaryOperation>::value && + detail::is_plus_if_complex::value && detail::is_native_op::value), T> joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) { @@ -257,7 +334,7 @@ joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) { std::is_same::value), "Result type of binary_op must match reduction accumulation type."); #ifdef __SYCL_DEVICE_ONLY__ - T partial = sycl::known_identity_v; + T partial = detail::identity_for_ga_op(); sycl::detail::for_each( g, first, last, [&](const typename detail::remove_pointer::type &x) { partial = binary_op(partial, x); @@ -516,6 +593,9 @@ group_broadcast(Group g, T x) { } // ---- exclusive_scan_over_group +// this function has two overloads, one with 
three arguments and one with four +// the three argument version is specialized thrice: scalar, complex, and +// vector template detail::enable_if_t<(is_group_v> && detail::is_scalar_arithmetic::value && @@ -537,6 +617,29 @@ exclusive_scan_over_group(Group, T x, BinaryOperation binary_op) { #endif } +// complex specialization. T is std::complex or similar. +// binary op is sycl::plus> +template +detail::enable_if_t<(is_group_v> && + detail::is_complex::value && + detail::is_native_op>::value && + detail::is_plus::value), + T> +exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + T result; + result.real(exclusive_scan_over_group(g, x.real(), sycl::plus<>())); + result.imag(exclusive_scan_over_group(g, x.imag(), sycl::plus<>())); + return result; +#else + (void)g; + (void)x; + (void)binary_op; + throw runtime_error("Group algorithms are not supported on host device.", + PI_INVALID_DEVICE); +#endif +} + template detail::enable_if_t<(is_group_v> && detail::is_vector_arithmetic::value && @@ -557,6 +660,8 @@ exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { return result; } +// four argument version of exclusive_scan_over_group is specialized twice +// once for vector_arithmetic, once for (scalar_arithmetic || complex) template detail::enable_if_t<(is_group_v> && detail::is_vector_arithmetic::value && @@ -580,12 +685,15 @@ exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) { } template -detail::enable_if_t<(is_group_v> && - detail::is_scalar_arithmetic::value && - detail::is_scalar_arithmetic::value && - detail::is_native_op::value && - detail::is_native_op::value), - T> +detail::enable_if_t< + (is_group_v> && + (detail::is_scalar_arithmetic::value || detail::is_complex::value) && + (detail::is_scalar_arithmetic::value || detail::is_complex::value) && + detail::is_native_op::value && + detail::is_native_op::value && + detail::is_plus_if_complex::value && + 
detail::is_plus_if_complex::value), + T> exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) { // FIXME: Do not special-case for half precision static_assert(std::is_same::value || @@ -616,12 +724,15 @@ template > && detail::is_pointer::value && detail::is_pointer::value && - detail::is_arithmetic< + detail::is_arithmetic_or_complex< typename detail::remove_pointer::type>::value && - detail::is_arithmetic::value && + detail::is_arithmetic_or_complex::value && detail::is_native_op::type, BinaryOperation>::value && - detail::is_native_op::value), + detail::is_native_op::value && + detail::is_plus_if_complex::type, + BinaryOperation>::value && + detail::is_plus_if_complex::value), OutPtr> joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init, BinaryOperation binary_op) { @@ -670,10 +781,12 @@ template > && detail::is_pointer::value && detail::is_pointer::value && - detail::is_arithmetic< + detail::is_arithmetic_or_complex< typename detail::remove_pointer::type>::value && detail::is_native_op::type, - BinaryOperation>::value), + BinaryOperation>::value && + detail::is_plus_if_complex::type, + BinaryOperation>::value), OutPtr> joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { @@ -685,14 +798,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, half>::value && std::is_same::value), "Result type of binary_op must match scan accumulation type."); - return joint_exclusive_scan( - g, first, last, result, - sycl::known_identity_v::type>, - binary_op); + using T = typename detail::remove_pointer::type; + T init = detail::identity_for_ga_op(); + return joint_exclusive_scan(g, first, last, result, init, binary_op); } // ---- inclusive_scan_over_group +// this function has two overloads, one with three arguments and one with four +// the three argument version is specialized thrice: vector, scalar, and +// complex template detail::enable_if_t<(is_group_v> && 
detail::is_vector_arithmetic::value && @@ -734,13 +848,40 @@ inclusive_scan_over_group(Group, T x, BinaryOperation binary_op) { #endif } -template +// complex specialization +template detail::enable_if_t<(is_group_v> && - detail::is_scalar_arithmetic::value && - detail::is_scalar_arithmetic::value && - detail::is_native_op::value && - detail::is_native_op::value), + detail::is_complex::value && + detail::is_native_op>::value && + detail::is_plus::value), T> +inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { +#ifdef __SYCL_DEVICE_ONLY__ + T result; + result.real(inclusive_scan_over_group(g, x.real(), sycl::plus<>())); + result.imag(inclusive_scan_over_group(g, x.imag(), sycl::plus<>())); + return result; +#else + (void)g; + (void)x; + (void)binary_op; + throw runtime_error("Group algorithms are not supported on host device.", + PI_INVALID_DEVICE); +#endif +} + +// four argument version of inclusive_scan_over_group is specialized twice +// once for (scalar_arithmetic || complex) and once for vector_arithmetic +template +detail::enable_if_t< + (is_group_v> && + (detail::is_scalar_arithmetic::value || detail::is_complex::value) && + (detail::is_scalar_arithmetic::value || detail::is_complex::value) && + detail::is_native_op::value && + detail::is_native_op::value && + detail::is_plus_if_complex::value && + detail::is_plus_if_complex::value), + T> inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) { // FIXME: Do not special-case for half precision static_assert(std::is_same::value || @@ -786,12 +927,15 @@ template > && detail::is_pointer::value && detail::is_pointer::value && - detail::is_arithmetic< + detail::is_arithmetic_or_complex< typename detail::remove_pointer::type>::value && - detail::is_arithmetic::value && + detail::is_arithmetic_or_complex::value && detail::is_native_op::type, BinaryOperation>::value && - detail::is_native_op::value), + detail::is_native_op::value && + detail::is_plus_if_complex::type, + 
BinaryOperation>::value && + detail::is_plus_if_complex::value), OutPtr> joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op, T init) { @@ -839,10 +983,12 @@ template > && detail::is_pointer::value && detail::is_pointer::value && - detail::is_arithmetic< + detail::is_arithmetic_or_complex< typename detail::remove_pointer::type>::value && detail::is_native_op::type, - BinaryOperation>::value), + BinaryOperation>::value && + detail::is_plus_if_complex::type, + BinaryOperation>::value), OutPtr> joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, BinaryOperation binary_op) { @@ -854,10 +1000,10 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, half>::value && std::is_same::value), "Result type of binary_op must match scan accumulation type."); - return joint_inclusive_scan( - g, first, last, result, binary_op, - sycl::known_identity_v::type>); + + using T = typename detail::remove_pointer::type; + T init = detail::identity_for_ga_op(); + return joint_inclusive_scan(g, first, last, result, binary_op, init); } namespace detail {