-
Notifications
You must be signed in to change notification settings - Fork 808
[SYCL] Add complex support to group algorithms #5394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
bafeffa
4e9e321
a7b0349
1790647
ca53732
6a1cfd2
3e559f9
b5f68f2
8862bae
b9852c3
5a4990c
9823b32
a30df28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #pragma once | ||
| #include <complex> | ||
|
|
||
| #include <CL/__spirv/spirv_ops.hpp> | ||
| #include <CL/__spirv/spirv_types.hpp> | ||
| #include <CL/__spirv/spirv_vars.hpp> | ||
|
|
@@ -20,6 +22,8 @@ | |
| #include <sycl/ext/oneapi/experimental/group_sort.hpp> | ||
| #include <sycl/ext/oneapi/functional.hpp> | ||
|
|
||
| #define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1 | ||
|
|
||
| __SYCL_INLINE_NAMESPACE(cl) { | ||
| namespace sycl { | ||
| namespace detail { | ||
|
|
@@ -97,6 +101,37 @@ template <typename T, typename BinaryOperation> struct is_native_op { | |
| is_contained<BinaryOperation, native_op_list<void>>::value; | ||
| }; | ||
|
|
||
| // ---- is_complex | ||
| template <typename T> | ||
| struct is_complex | ||
| : std::integral_constant< | ||
| bool, std::is_same<T, std::complex<float>>::value || | ||
| std::is_same<T, std::complex<double>>::value || | ||
| std::is_same<T, std::complex<long double>>::value> {}; | ||
cperkinsintel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // ---- is_arithmetic_or_complex | ||
| template <typename T> | ||
| using is_arithmetic_or_complex = | ||
| std::integral_constant<bool, sycl::detail::is_complex<T>::value || | ||
| sycl::detail::is_arithmetic<T>::value>; | ||
|
|
||
| // ---- identity_for_ga_op | ||
| // the group algorithms support std::complex, limited to sycl::plus operation | ||
| // get the correct identity for group algorithm operation. | ||
| template <typename T, class BinaryOperation> | ||
| constexpr detail::enable_if_t< | ||
| (is_complex<T>::value && | ||
| std::is_same<BinaryOperation, sycl::plus<T>>::value), | ||
| T> | ||
| identity_for_ga_op() { | ||
| return {0, 0}; | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should consider whether it makes sense to extend
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Pennycook the reason I didn't extend known_identity was that I didn't want to accidentally introduce complex support to a bunch of places where it might not be intended (all the callers of known_identity), plus this particular identity of {0,0} only works because the only operation we are worried about is plus. But that wouldn't be correct for the more general known_identity. That said, I guess it's worth asking. |
||
|
|
||
| template <typename T, class BinaryOperation> | ||
| constexpr detail::enable_if_t<!is_complex<T>::value, T> identity_for_ga_op() { | ||
| return sycl::known_identity_v<BinaryOperation, T>; | ||
| } | ||
|
|
||
| // ---- for_each | ||
| template <typename Group, typename Ptr, class Function> | ||
| Function for_each(Group g, Ptr first, Ptr last, Function f) { | ||
|
|
@@ -119,6 +154,9 @@ Function for_each(Group g, Ptr first, Ptr last, Function f) { | |
| } // namespace detail | ||
|
|
||
| // ---- reduce_over_group | ||
| // three argument variant is specialized thrice: | ||
| // scalar arithmetic, complex (plus only), and vector arithmetic | ||
|
|
||
| template <typename Group, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_scalar_arithmetic<T>::value && | ||
|
|
@@ -141,6 +179,28 @@ reduce_over_group(Group, T x, BinaryOperation binary_op) { | |
| #endif | ||
| } | ||
|
|
||
| // complex specialization. T is std::complex<float> or similar. | ||
| // binary op is sycl::plus<std::complex<float>> | ||
| template <typename Group, typename T> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_complex<T>::value && | ||
| detail::is_native_op<T, sycl::plus<T>>::value), | ||
cperkinsintel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| T> | ||
| reduce_over_group(Group g, T x, sycl::plus<T> binary_op) { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| T result; | ||
| result.real(reduce_over_group(g, x.real(), sycl::plus<>())); | ||
| result.imag(reduce_over_group(g, x.imag(), sycl::plus<>())); | ||
| return result; | ||
| #else | ||
| (void)g; | ||
| (void)x; | ||
| (void)binary_op; | ||
| throw runtime_error("Group algorithms are not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif | ||
| } | ||
|
|
||
| template <typename Group, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_vector_arithmetic<T>::value && | ||
|
|
@@ -161,13 +221,16 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) { | |
| return result; | ||
| } | ||
|
|
||
| // four argument variant of reduce_over_group specialized twice | ||
| // (scalar arithmetic || complex), and vector_arithmetic | ||
| template <typename Group, typename V, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_scalar_arithmetic<V>::value && | ||
| detail::is_scalar_arithmetic<T>::value && | ||
| detail::is_native_op<V, BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
| T> | ||
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && | ||
| (detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) && | ||
| (detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) && | ||
| detail::is_native_op<V, BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
| T> | ||
cperkinsintel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { | ||
| // FIXME: Do not special-case for half precision | ||
| static_assert( | ||
|
|
@@ -214,25 +277,19 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) { | |
|
|
||
| // ---- joint_reduce | ||
| template <typename Group, typename Ptr, class BinaryOperation> | ||
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value && | ||
| detail::is_arithmetic<typename detail::remove_pointer<Ptr>::type>::value), | ||
| typename detail::remove_pointer<Ptr>::type> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_pointer<Ptr>::value && | ||
| detail::is_arithmetic_or_complex< | ||
| typename detail::remove_pointer<Ptr>::type>::value), | ||
| typename detail::remove_pointer<Ptr>::type> | ||
| joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) { | ||
| using T = typename detail::remove_pointer<Ptr>::type; | ||
| // FIXME: Do not special-case for half precision | ||
| static_assert( | ||
| std::is_same<decltype(binary_op(*first, *first)), T>::value || | ||
| (std::is_same<T, half>::value && | ||
| std::is_same<decltype(binary_op(*first, *first)), float>::value), | ||
| "Result type of binary_op must match reduction accumulation type."); | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| T partial = sycl::known_identity_v<BinaryOperation, T>; | ||
| sycl::detail::for_each(g, first, last, | ||
| [&](const T &x) { partial = binary_op(partial, x); }); | ||
| return reduce_over_group(g, partial, binary_op); | ||
| using T = typename detail::remove_pointer<Ptr>::type; | ||
| T init = detail::identity_for_ga_op<T, BinaryOperation>(); | ||
| return joint_reduce(g, first, last, init, binary_op); | ||
| #else | ||
| (void)g; | ||
| (void)first; | ||
| (void)last; | ||
| (void)binary_op; | ||
| throw runtime_error("Group algorithms are not supported on host device.", | ||
|
|
@@ -243,8 +300,9 @@ joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) { | |
| template <typename Group, typename Ptr, typename T, class BinaryOperation> | ||
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value && | ||
| detail::is_arithmetic<typename detail::remove_pointer<Ptr>::type>::value && | ||
| detail::is_arithmetic<T>::value && | ||
| detail::is_arithmetic_or_complex< | ||
| typename detail::remove_pointer<Ptr>::type>::value && | ||
| detail::is_arithmetic_or_complex<T>::value && | ||
| detail::is_native_op<typename detail::remove_pointer<Ptr>::type, | ||
| BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
|
|
@@ -257,7 +315,7 @@ joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) { | |
| std::is_same<decltype(binary_op(init, *first)), float>::value), | ||
| "Result type of binary_op must match reduction accumulation type."); | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| T partial = sycl::known_identity_v<BinaryOperation, T>; | ||
| T partial = detail::identity_for_ga_op<T, BinaryOperation>(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't it better to update sycl::known_identity to let it accept std::complex? Otherwise, there may be confusion when known_identity states it is unknown but functionality is supported.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are asking this question of @v-klochkov above. My feeling is that known_identity is used by the reductions and other parts of SYCL where std::complex is not supported. Adding complex there may not be wise. Furthermore, the identity for complex numbers is {0,0} only in the case of addition - which is the only binary operation the group algorithms support. But known_identity accepts all the binary operations - so having unsupported ones may change its API in ways that confuse it or needlessly increase its test surface.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is true, but if it's simple to add
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am ok to NOT extend known_identity with the complex type. but it is worth commenting with motivation on the definition of identity_for_ga_op.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was hoping v-klochkov would be able to comment on this, but he is not going to be available to comment any time soon. We should not wait on him.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that sycl::has_known_identity should return true for std::plusstd::complex. |
||
| sycl::detail::for_each( | ||
| g, first, last, [&](const typename detail::remove_pointer<Ptr>::type &x) { | ||
| partial = binary_op(partial, x); | ||
|
|
@@ -516,6 +574,9 @@ group_broadcast(Group g, T x) { | |
| } | ||
|
|
||
| // ---- exclusive_scan_over_group | ||
| // this function has two overloads, one with three arguments and one with four | ||
| // the three argument version is specialized thrice: scalar, complex, and | ||
| // vector | ||
| template <typename Group, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_scalar_arithmetic<T>::value && | ||
|
|
@@ -537,6 +598,28 @@ exclusive_scan_over_group(Group, T x, BinaryOperation binary_op) { | |
| #endif | ||
| } | ||
|
|
||
| // complex specialization. T is std::complex<float> or similar. | ||
| // binary op is sycl::plus<std::complex<float>> | ||
| template <typename Group, typename T> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_complex<T>::value && | ||
| detail::is_native_op<T, sycl::plus<T>>::value), | ||
| T> | ||
| exclusive_scan_over_group(Group g, T x, sycl::plus<T> binary_op) { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| T result; | ||
| result.real(exclusive_scan_over_group(g, x.real(), sycl::plus<>())); | ||
| result.imag(exclusive_scan_over_group(g, x.imag(), sycl::plus<>())); | ||
| return result; | ||
| #else | ||
| (void)g; | ||
| (void)x; | ||
| (void)binary_op; | ||
| throw runtime_error("Group algorithms are not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif | ||
| } | ||
|
|
||
| template <typename Group, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_vector_arithmetic<T>::value && | ||
|
|
@@ -557,6 +640,8 @@ exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { | |
| return result; | ||
| } | ||
|
|
||
| // four argument version of exclusive_scan_over_group is specialized twice | ||
| // once for vector_arithmetic, once for (scalar_arithmetic || complex) | ||
| template <typename Group, typename V, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_vector_arithmetic<V>::value && | ||
|
|
@@ -580,12 +665,13 @@ exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) { | |
| } | ||
|
|
||
| template <typename Group, typename V, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_scalar_arithmetic<V>::value && | ||
| detail::is_scalar_arithmetic<T>::value && | ||
| detail::is_native_op<V, BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
| T> | ||
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && | ||
| (detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) && | ||
| (detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) && | ||
| detail::is_native_op<V, BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
| T> | ||
| exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) { | ||
| // FIXME: Do not special-case for half precision | ||
| static_assert(std::is_same<decltype(binary_op(init, x)), T>::value || | ||
|
|
@@ -616,9 +702,9 @@ template <typename Group, typename InPtr, typename OutPtr, typename T, | |
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value && | ||
| detail::is_pointer<OutPtr>::value && | ||
| detail::is_arithmetic< | ||
| detail::is_arithmetic_or_complex< | ||
| typename detail::remove_pointer<InPtr>::type>::value && | ||
| detail::is_arithmetic<T>::value && | ||
| detail::is_arithmetic_or_complex<T>::value && | ||
| detail::is_native_op<typename detail::remove_pointer<InPtr>::type, | ||
| BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
|
|
@@ -670,7 +756,7 @@ template <typename Group, typename InPtr, typename OutPtr, | |
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value && | ||
| detail::is_pointer<OutPtr>::value && | ||
| detail::is_arithmetic< | ||
| detail::is_arithmetic_or_complex< | ||
| typename detail::remove_pointer<InPtr>::type>::value && | ||
| detail::is_native_op<typename detail::remove_pointer<InPtr>::type, | ||
| BinaryOperation>::value), | ||
|
|
@@ -685,14 +771,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, | |
| half>::value && | ||
| std::is_same<decltype(binary_op(*first, *first)), float>::value), | ||
| "Result type of binary_op must match scan accumulation type."); | ||
| return joint_exclusive_scan( | ||
| g, first, last, result, | ||
| sycl::known_identity_v<BinaryOperation, | ||
| typename detail::remove_pointer<OutPtr>::type>, | ||
| binary_op); | ||
| using T = typename detail::remove_pointer<InPtr>::type; | ||
| T init = detail::identity_for_ga_op<T, BinaryOperation>(); | ||
| return joint_exclusive_scan(g, first, last, result, init, binary_op); | ||
| } | ||
|
|
||
| // ---- inclusive_scan_over_group | ||
| // this function has two overloads, one with three arguments and one with four | ||
| // the three argument version is specialized thrice: vector, scalar, and | ||
| // complex | ||
| template <typename Group, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_vector_arithmetic<T>::value && | ||
|
|
@@ -734,13 +821,37 @@ inclusive_scan_over_group(Group, T x, BinaryOperation binary_op) { | |
| #endif | ||
| } | ||
|
|
||
| template <typename Group, typename V, class BinaryOperation, typename T> | ||
| // complex specializaiton | ||
| template <typename Group, typename T, class BinaryOperation> | ||
| detail::enable_if_t<(is_group_v<std::decay_t<Group>> && | ||
| detail::is_scalar_arithmetic<V>::value && | ||
| detail::is_scalar_arithmetic<T>::value && | ||
| detail::is_native_op<V, BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
| detail::is_complex<T>::value && | ||
| detail::is_native_op<T, sycl::plus<T>>::value), | ||
| T> | ||
| inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| T result; | ||
| result.real(inclusive_scan_over_group(g, x.real(), sycl::plus<>())); | ||
| result.imag(inclusive_scan_over_group(g, x.imag(), sycl::plus<>())); | ||
| return result; | ||
| #else | ||
| (void)g; | ||
| (void)x; | ||
| (void)binary_op; | ||
| throw runtime_error("Group algorithms are not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif | ||
| } | ||
|
|
||
| // four argument version of exclusive_scan_over_group is specialized twice | ||
| // once for (scalar_arithmetic || complex) and once for vector_arithmetic | ||
| template <typename Group, typename V, class BinaryOperation, typename T> | ||
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && | ||
| (detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) && | ||
| (detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) && | ||
| detail::is_native_op<V, BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
| T> | ||
| inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) { | ||
| // FIXME: Do not special-case for half precision | ||
| static_assert(std::is_same<decltype(binary_op(init, x)), T>::value || | ||
|
|
@@ -786,9 +897,9 @@ template <typename Group, typename InPtr, typename OutPtr, | |
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value && | ||
| detail::is_pointer<OutPtr>::value && | ||
| detail::is_arithmetic< | ||
| detail::is_arithmetic_or_complex< | ||
| typename detail::remove_pointer<InPtr>::type>::value && | ||
| detail::is_arithmetic<T>::value && | ||
| detail::is_arithmetic_or_complex<T>::value && | ||
| detail::is_native_op<typename detail::remove_pointer<InPtr>::type, | ||
| BinaryOperation>::value && | ||
| detail::is_native_op<T, BinaryOperation>::value), | ||
|
|
@@ -839,7 +950,7 @@ template <typename Group, typename InPtr, typename OutPtr, | |
| detail::enable_if_t< | ||
| (is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value && | ||
| detail::is_pointer<OutPtr>::value && | ||
| detail::is_arithmetic< | ||
| detail::is_arithmetic_or_complex< | ||
| typename detail::remove_pointer<InPtr>::type>::value && | ||
| detail::is_native_op<typename detail::remove_pointer<InPtr>::type, | ||
| BinaryOperation>::value), | ||
|
|
@@ -854,10 +965,10 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, | |
| half>::value && | ||
| std::is_same<decltype(binary_op(*first, *first)), float>::value), | ||
| "Result type of binary_op must match scan accumulation type."); | ||
| return joint_inclusive_scan( | ||
| g, first, last, result, binary_op, | ||
| sycl::known_identity_v<BinaryOperation, | ||
| typename detail::remove_pointer<OutPtr>::type>); | ||
|
|
||
| using T = typename detail::remove_pointer<InPtr>::type; | ||
| T init = detail::identity_for_ga_op<T, BinaryOperation>(); | ||
| return joint_inclusive_scan(g, first, last, result, binary_op, init); | ||
| } | ||
|
|
||
| namespace detail { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.