diff --git a/sycl/doc/extensions/GroupAlgorithms/SYCL_INTEL_group_sort.asciidoc b/sycl/doc/extensions/GroupAlgorithms/SYCL_INTEL_group_sort.asciidoc index b050bb3a372e0..fe4a76abe62d1 100755 --- a/sycl/doc/extensions/GroupAlgorithms/SYCL_INTEL_group_sort.asciidoc +++ b/sycl/doc/extensions/GroupAlgorithms/SYCL_INTEL_group_sort.asciidoc @@ -146,7 +146,7 @@ namespace sycl::ext::oneapi::experimental { class default_sorter { public: template - default_sorter(sycl::span scratch, Compare comp = Compare()); + default_sorter(sycl::span scratch, Compare comp = Compare()); template void operator()(Group g, Ptr first, Ptr last); @@ -167,7 +167,7 @@ namespace sycl::ext::oneapi::experimental { class radix_sorter { public: template - radix_sorter(sycl::span scratch, + radix_sorter(sycl::span scratch, const std::bitset mask = std::bitset (std::numeric_limits::max())); @@ -215,7 +215,7 @@ Table 4. Constructors of the `default_sorter` class. |Constructor|Description |`template -default_sorter(sycl::span scratch, Compare comp = Compare())` +default_sorter(sycl::span scratch, Compare comp = Compare())` |Creates the `default_sorter` object using `comp`. Additional memory for the algorithm is provided using `scratch`. If `scratch.size()` is less than the value returned by @@ -264,7 +264,7 @@ Table 6. Constructors of the `radix_sorter` class. |Constructor|Description |`template -radix_sorter(sycl::span scratch, const std::bitset mask = std::bitset +radix_sorter(sycl::span scratch, const std::bitset mask = std::bitset (std::numeric_limits::max()))` |Creates the `radix_sorter` object to sort values considering only bits that corresponds to 1 in `mask`. 
@@ -350,16 +350,16 @@ namespace sycl::ext::oneapi::experimental { class group_with_scratchpad { public: - group_with_scratchpad(Group group, sycl::span scratch); + group_with_scratchpad(Group group, sycl::span scratch); Group get_group() const; - sycl::span + sycl::span get_memory() const; }; // Deduction guides template - group_with_scratchpad(Group, sycl::span) + group_with_scratchpad(Group, sycl::span) -> group_with_scratchpad; } @@ -372,7 +372,7 @@ Table 9. Constructors of the `group_with_scratchpad` class. |=== |Constructor|Description -|`group_with_scratchpad(Group group, sycl::span scratch)` +|`group_with_scratchpad(Group group, sycl::span scratch)` |Creates the `group_with_scratchpad` object using `group` and `scratch`. `sycl::is_group_v>` must be true. `scratch.size()` must not be less than value returned by the `memory_required` method @@ -388,7 +388,7 @@ Table 10. Member functions of the `group_with_scratchpad` class. |`Group get_group() const` |Returns the `Group` class object that is handled by the `group_with_scratchpad` object. -|`sycl::span +|`sycl::span get_memory() const` |Returns `sycl::span` that represents an additional memory that is handled by the `group_with_scratchpad` object. 
@@ -508,7 +508,7 @@ size_t temp_memory_size = q.submit([&](sycl::handler& h) { auto acc = sycl::accessor(buf, h); - auto scratch = sycl::local_accessor( {temp_memory_size}, h ); + auto scratch = sycl::local_accessor( {temp_memory_size}, h ); h.parallel_for( sycl::nd_range<1>{ /*global_size = */ {256}, /*local_size = */ {256} }, @@ -546,7 +546,7 @@ size_t temp_memory_size = q.submit([&](sycl::handler& h) { auto acc = sycl::accessor(buf, h); - auto scratch = sycl::local_accessor( {temp_memory_size}, h); + auto scratch = sycl::local_accessor( {temp_memory_size}, h); h.parallel_for( sycl::nd_range<1>{ local_range, local_range }, @@ -583,7 +583,7 @@ size_t temp_memory_size = q.submit([&](sycl::handler& h) { auto keys_acc = sycl::accessor(keys_buf, h); auto vals_acc = sycl::accessor(vals_buf, h); - auto scratch = sycl::local_accessor( {temp_memory_size}, h); + auto scratch = sycl::local_accessor( {temp_memory_size}, h); h.parallel_for( sycl::nd_range<1>{ /*global_size = */ {1024}, /*local_size = */ {256} }, diff --git a/sycl/include/CL/sycl/detail/group_sort_impl.hpp b/sycl/include/CL/sycl/detail/group_sort_impl.hpp new file mode 100644 index 0000000000000..e80f2fd131e77 --- /dev/null +++ b/sycl/include/CL/sycl/detail/group_sort_impl.hpp @@ -0,0 +1,256 @@ +//==------------ group_sort_impl.hpp ---------------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file includes some functions for group sorting algorithm implementations +// + +#pragma once + +#if __cplusplus >= 201703L +#include + +#ifdef __SYCL_DEVICE_ONLY__ + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace detail { + +// ---- merge sort implementation + +// following two functions could be useless if std::[lower|upper]_bound worked +// well +// +// lower_bound: binary search over the index range [first, last) of acc. +// Returns the first index i for which comp(acc[i], value) is false, or last +// if there is no such index. That index is the partition point only when the +// range is partitioned with respect to comp(acc[i], value); not checked here. +template +std::size_t lower_bound(Acc acc, std::size_t first, std::size_t last, + const Value &value, Compare comp) { + std::size_t n = last - first; + std::size_t cur = n; + std::size_t it; + while (n > 0) { + it = first; + cur = n / 2; + it += cur; + if (comp(acc[it], value)) { + n -= cur + 1, first = ++it; + } else + n = cur; + } + return first; +} + +// upper_bound: returns the first index i in [first, last) for which +// comp(value, acc[i]) is true, or last if there is no such index. +// Implemented by calling detail::lower_bound with the negated comparator +// whose arguments are swapped. Note that the lambda takes both of its +// arguments by value (auto), not by reference. +template +std::size_t upper_bound(Acc acc, const std::size_t first, + const std::size_t last, const Value &value, + Compare comp) { + return detail::lower_bound(acc, first, last, value, + [comp](auto x, auto y) { return !comp(y, x); }); +} + +// swap for all data types including tuple-like types +// This version simply forwards to std::swap(a, b). +template void swap_tuples(T &a, T &b) { std::swap(a, b); } + +template