Skip to content
Merged
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 161 additions & 50 deletions sycl/include/CL/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#pragma once
#include <complex>

#include <CL/__spirv/spirv_ops.hpp>
#include <CL/__spirv/spirv_types.hpp>
#include <CL/__spirv/spirv_vars.hpp>
Expand All @@ -20,6 +22,8 @@
#include <sycl/ext/oneapi/experimental/group_sort.hpp>
#include <sycl/ext/oneapi/functional.hpp>

#define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {
Expand Down Expand Up @@ -97,6 +101,37 @@ template <typename T, typename BinaryOperation> struct is_native_op {
is_contained<BinaryOperation, native_op_list<void>>::value;
};

// ---- is_complex
// Compile-time predicate: ::value is true only when T is exactly
// std::complex<float>, std::complex<double> or std::complex<long double>;
// it is false for every other type.
template <typename T>
struct is_complex
    : std::disjunction<std::is_same<T, std::complex<float>>,
                       std::is_same<T, std::complex<double>>,
                       std::is_same<T, std::complex<long double>>> {};

// ---- is_arithmetic_or_complex
// Alias for a boolean integral constant whose ::value is true when T
// satisfies sycl::detail::is_arithmetic or sycl::detail::is_complex.
template <typename T>
using is_arithmetic_or_complex =
    std::bool_constant<sycl::detail::is_arithmetic<T>::value ||
                       sycl::detail::is_complex<T>::value>;

// ---- identity_for_ga_op
// the group algorithms support std::complex, limited to sycl::plus operation
// get the correct identity for group algorithm operation.
template <typename T, class BinaryOperation>
constexpr detail::enable_if_t<
(is_complex<T>::value &&
std::is_same<BinaryOperation, sycl::plus<T>>::value),
T>
identity_for_ga_op() {
return {0, 0};
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider whether it makes sense to extend sycl::known_identity_v to std::complex types. I'd like to hear from @v-klochkov -- how hard would it be to do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Pennycook the reason I didn't extend known_identity was that I didn't want to accidentally introduce complex support to a bunch of places where it might not be intended (all the callers of known_identity), plus this particular identity of {0,0} only works because the only operation we are worried about is plus. But that wouldn't be correct for the more general known_identity.

That said, I guess it's worth asking.


template <typename T, class BinaryOperation>
constexpr detail::enable_if_t<!is_complex<T>::value, T> identity_for_ga_op() {
return sycl::known_identity_v<BinaryOperation, T>;
}

// ---- for_each
template <typename Group, typename Ptr, class Function>
Function for_each(Group g, Ptr first, Ptr last, Function f) {
Expand All @@ -119,6 +154,9 @@ Function for_each(Group g, Ptr first, Ptr last, Function f) {
} // namespace detail

// ---- reduce_over_group
// three argument variant is specialized thrice:
// scalar arithmetic, complex (plus only), and vector arithmetic

template <typename Group, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_scalar_arithmetic<T>::value &&
Expand All @@ -141,6 +179,28 @@ reduce_over_group(Group, T x, BinaryOperation binary_op) {
#endif
}

// Complex overload of the three-argument reduce_over_group.
// Enabled when T is one of the std::complex types; the operation is fixed to
// sycl::plus<T>. The real and imaginary components are reduced separately
// with sycl::plus<>() and recombined into a T.
// In host compilation (no __SYCL_DEVICE_ONLY__) it always throws
// runtime_error with PI_INVALID_DEVICE.
template <typename Group, typename T>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
                     detail::is_complex<T>::value &&
                     detail::is_native_op<T, sycl::plus<T>>::value),
                    T>
reduce_over_group(Group g, T x, sycl::plus<T> binary_op) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: nothing is computed, so silence unused-parameter diagnostics.
  (void)g;
  (void)x;
  (void)binary_op;
  throw runtime_error("Group algorithms are not supported on host device.",
                      PI_INVALID_DEVICE);
#else
  // Real part first, then imaginary part, matching the original call order.
  const auto real_total = reduce_over_group(g, x.real(), sycl::plus<>());
  const auto imag_total = reduce_over_group(g, x.imag(), sycl::plus<>());
  return T(real_total, imag_total);
#endif
}

template <typename Group, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_vector_arithmetic<T>::value &&
Expand All @@ -161,13 +221,16 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
return result;
}

// four argument variant of reduce_over_group specialized twice
// (scalar arithmetic || complex), and vector_arithmetic
template <typename Group, typename V, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_scalar_arithmetic<V>::value &&
detail::is_scalar_arithmetic<T>::value &&
detail::is_native_op<V, BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
T>
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> &&
(detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) &&
(detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) &&
detail::is_native_op<V, BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
T>
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
Expand Down Expand Up @@ -214,25 +277,19 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {

// ---- joint_reduce
template <typename Group, typename Ptr, class BinaryOperation>
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
detail::is_arithmetic<typename detail::remove_pointer<Ptr>::type>::value),
typename detail::remove_pointer<Ptr>::type>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_pointer<Ptr>::value &&
detail::is_arithmetic_or_complex<
typename detail::remove_pointer<Ptr>::type>::value),
typename detail::remove_pointer<Ptr>::type>
joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
using T = typename detail::remove_pointer<Ptr>::type;
// FIXME: Do not special-case for half precision
static_assert(
std::is_same<decltype(binary_op(*first, *first)), T>::value ||
(std::is_same<T, half>::value &&
std::is_same<decltype(binary_op(*first, *first)), float>::value),
"Result type of binary_op must match reduction accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
T partial = sycl::known_identity_v<BinaryOperation, T>;
sycl::detail::for_each(g, first, last,
[&](const T &x) { partial = binary_op(partial, x); });
return reduce_over_group(g, partial, binary_op);
using T = typename detail::remove_pointer<Ptr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
return joint_reduce(g, first, last, init, binary_op);
#else
(void)g;
(void)first;
(void)last;
(void)binary_op;
throw runtime_error("Group algorithms are not supported on host device.",
Expand All @@ -243,8 +300,9 @@ joint_reduce(Group g, Ptr first, Ptr last, BinaryOperation binary_op) {
template <typename Group, typename Ptr, typename T, class BinaryOperation>
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> && detail::is_pointer<Ptr>::value &&
detail::is_arithmetic<typename detail::remove_pointer<Ptr>::type>::value &&
detail::is_arithmetic<T>::value &&
detail::is_arithmetic_or_complex<
typename detail::remove_pointer<Ptr>::type>::value &&
detail::is_arithmetic_or_complex<T>::value &&
detail::is_native_op<typename detail::remove_pointer<Ptr>::type,
BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
Expand All @@ -257,7 +315,7 @@ joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
std::is_same<decltype(binary_op(init, *first)), float>::value),
"Result type of binary_op must match reduction accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
T partial = sycl::known_identity_v<BinaryOperation, T>;
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it better to update sycl::known_identity to let it accept std::complex? Otherwise, there may be confusion when known_identity states it is unknown but functionality is supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are asking this question of @v-klochkov above. My feeling is that known_identity is used by the reductions and other parts of SYCL where std::complex is not supported. Adding complex there may not be wise. Furthermore, the identity for complex numbers is {0,0} only in the case of addition - which is the only binary operation the group algorithms support. But known_identity accepts all the binary operations - so having unsupported ones may change its API in ways that confuse it or needlessly increase its test surface.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My feeling is that known_identity is used by the reductions and other parts of SYCL where std::complex is not supported. Adding complex there may not be wise

This is true, but if it's simple to add std::complex support to has_known_identity then it might be worth us revising the extension specification to say that we're also adding support for reduction with std::complex and plus<>.

Copy link
Contributor

@vladimirlaz vladimirlaz Feb 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am ok to NOT extend known_identity with the complex type. but it is worth commenting with motivation on the definition of identity_for_ga_op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was hoping v-klochkov would be able to comment on this, but he is not going to be available to comment any time soon. We should not wait on him.
I have expanded the comment on identity_for_ga_op and left a //TODO that we should eliminate it once complex numbers are supported by the other reductions.
@vladimirlaz and @Pennycook would this be enough for now?
@Pennycook do you want to open a ticket for complex support for the reductions? (I think they are the primary other caller of known_identity right now).

Copy link
Contributor

@v-klochkov v-klochkov Feb 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that sycl::has_known_identity should return true for std::plus<std::complex>.
It allows enabling std::complex + std::plus for reductions for free.

sycl::detail::for_each(
g, first, last, [&](const typename detail::remove_pointer<Ptr>::type &x) {
partial = binary_op(partial, x);
Expand Down Expand Up @@ -516,6 +574,9 @@ group_broadcast(Group g, T x) {
}

// ---- exclusive_scan_over_group
// this function has two overloads, one with three arguments and one with four
// the three argument version is specialized thrice: scalar, complex, and
// vector
template <typename Group, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_scalar_arithmetic<T>::value &&
Expand All @@ -537,6 +598,28 @@ exclusive_scan_over_group(Group, T x, BinaryOperation binary_op) {
#endif
}

// Complex overload of the three-argument exclusive_scan_over_group.
// Enabled when T is one of the std::complex types; the operation is fixed to
// sycl::plus<T>. The real and imaginary components are scanned separately
// with sycl::plus<>() and recombined into a T.
// In host compilation (no __SYCL_DEVICE_ONLY__) it always throws
// runtime_error with PI_INVALID_DEVICE.
template <typename Group, typename T>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
                     detail::is_complex<T>::value &&
                     detail::is_native_op<T, sycl::plus<T>>::value),
                    T>
exclusive_scan_over_group(Group g, T x, sycl::plus<T> binary_op) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: nothing is computed, so silence unused-parameter diagnostics.
  (void)g;
  (void)x;
  (void)binary_op;
  throw runtime_error("Group algorithms are not supported on host device.",
                      PI_INVALID_DEVICE);
#else
  // Real part first, then imaginary part, matching the original call order.
  const auto real_scan =
      exclusive_scan_over_group(g, x.real(), sycl::plus<>());
  const auto imag_scan =
      exclusive_scan_over_group(g, x.imag(), sycl::plus<>());
  return T(real_scan, imag_scan);
#endif
}

template <typename Group, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_vector_arithmetic<T>::value &&
Expand All @@ -557,6 +640,8 @@ exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
return result;
}

// four argument version of exclusive_scan_over_group is specialized twice
// once for vector_arithmetic, once for (scalar_arithmetic || complex)
template <typename Group, typename V, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_vector_arithmetic<V>::value &&
Expand All @@ -580,12 +665,13 @@ exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
}

template <typename Group, typename V, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_scalar_arithmetic<V>::value &&
detail::is_scalar_arithmetic<T>::value &&
detail::is_native_op<V, BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
T>
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> &&
(detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) &&
(detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) &&
detail::is_native_op<V, BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same<decltype(binary_op(init, x)), T>::value ||
Expand Down Expand Up @@ -616,9 +702,9 @@ template <typename Group, typename InPtr, typename OutPtr, typename T,
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
detail::is_pointer<OutPtr>::value &&
detail::is_arithmetic<
detail::is_arithmetic_or_complex<
typename detail::remove_pointer<InPtr>::type>::value &&
detail::is_arithmetic<T>::value &&
detail::is_arithmetic_or_complex<T>::value &&
detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
Expand Down Expand Up @@ -670,7 +756,7 @@ template <typename Group, typename InPtr, typename OutPtr,
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
detail::is_pointer<OutPtr>::value &&
detail::is_arithmetic<
detail::is_arithmetic_or_complex<
typename detail::remove_pointer<InPtr>::type>::value &&
detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
BinaryOperation>::value),
Expand All @@ -685,14 +771,15 @@ joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
half>::value &&
std::is_same<decltype(binary_op(*first, *first)), float>::value),
"Result type of binary_op must match scan accumulation type.");
return joint_exclusive_scan(
g, first, last, result,
sycl::known_identity_v<BinaryOperation,
typename detail::remove_pointer<OutPtr>::type>,
binary_op);
using T = typename detail::remove_pointer<InPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
return joint_exclusive_scan(g, first, last, result, init, binary_op);
}

// ---- inclusive_scan_over_group
// this function has two overloads, one with three arguments and one with four
// the three argument version is specialized thrice: vector, scalar, and
// complex
template <typename Group, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_vector_arithmetic<T>::value &&
Expand Down Expand Up @@ -734,13 +821,37 @@ inclusive_scan_over_group(Group, T x, BinaryOperation binary_op) {
#endif
}

template <typename Group, typename V, class BinaryOperation, typename T>
// complex specialization
template <typename Group, typename T, class BinaryOperation>
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_scalar_arithmetic<V>::value &&
detail::is_scalar_arithmetic<T>::value &&
detail::is_native_op<V, BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
detail::is_complex<T>::value &&
detail::is_native_op<T, sycl::plus<T>>::value),
T>
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
#ifdef __SYCL_DEVICE_ONLY__
T result;
result.real(inclusive_scan_over_group(g, x.real(), sycl::plus<>()));
result.imag(inclusive_scan_over_group(g, x.imag(), sycl::plus<>()));
return result;
#else
(void)g;
(void)x;
(void)binary_op;
throw runtime_error("Group algorithms are not supported on host device.",
PI_INVALID_DEVICE);
#endif
}

// four argument version of inclusive_scan_over_group is specialized twice
// once for (scalar_arithmetic || complex) and once for vector_arithmetic
template <typename Group, typename V, class BinaryOperation, typename T>
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> &&
(detail::is_scalar_arithmetic<V>::value || detail::is_complex<V>::value) &&
(detail::is_scalar_arithmetic<T>::value || detail::is_complex<T>::value) &&
detail::is_native_op<V, BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
T>
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same<decltype(binary_op(init, x)), T>::value ||
Expand Down Expand Up @@ -786,9 +897,9 @@ template <typename Group, typename InPtr, typename OutPtr,
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
detail::is_pointer<OutPtr>::value &&
detail::is_arithmetic<
detail::is_arithmetic_or_complex<
typename detail::remove_pointer<InPtr>::type>::value &&
detail::is_arithmetic<T>::value &&
detail::is_arithmetic_or_complex<T>::value &&
detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
BinaryOperation>::value &&
detail::is_native_op<T, BinaryOperation>::value),
Expand Down Expand Up @@ -839,7 +950,7 @@ template <typename Group, typename InPtr, typename OutPtr,
detail::enable_if_t<
(is_group_v<std::decay_t<Group>> && detail::is_pointer<InPtr>::value &&
detail::is_pointer<OutPtr>::value &&
detail::is_arithmetic<
detail::is_arithmetic_or_complex<
typename detail::remove_pointer<InPtr>::type>::value &&
detail::is_native_op<typename detail::remove_pointer<InPtr>::type,
BinaryOperation>::value),
Expand All @@ -854,10 +965,10 @@ joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
half>::value &&
std::is_same<decltype(binary_op(*first, *first)), float>::value),
"Result type of binary_op must match scan accumulation type.");
return joint_inclusive_scan(
g, first, last, result, binary_op,
sycl::known_identity_v<BinaryOperation,
typename detail::remove_pointer<OutPtr>::type>);

using T = typename detail::remove_pointer<InPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
return joint_inclusive_scan(g, first, last, result, binary_op, init);
}

namespace detail {
Expand Down