Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ namespace sycl::ext::oneapi::experimental {
class default_sorter {
public:
template<std::size_t Extent>
default_sorter(sycl::span<uint8_t, Extent> scratch, Compare comp = Compare());
default_sorter(sycl::span<std::byte, Extent> scratch, Compare comp = Compare());

template<typename Group, typename Ptr>
void operator()(Group g, Ptr first, Ptr last);
Expand All @@ -167,7 +167,7 @@ namespace sycl::ext::oneapi::experimental {
class radix_sorter {
public:
template<std::size_t Extent>
radix_sorter(sycl::span<uint8_t, Extent> scratch,
radix_sorter(sycl::span<std::byte, Extent> scratch,
const std::bitset<sizeof(T) * CHAR_BIT> mask =
std::bitset<sizeof(T) * CHAR_BIT> (std::numeric_limits<unsigned long long>::max()));

Expand Down Expand Up @@ -215,7 +215,7 @@ Table 4. Constructors of the `default_sorter` class.
|Constructor|Description

|`template<std::size_t Extent>
default_sorter(sycl::span<uint8_t, Extent> scratch, Compare comp = Compare())`
default_sorter(sycl::span<std::byte, Extent> scratch, Compare comp = Compare())`
|Creates the `default_sorter` object using `comp`.
Additional memory for the algorithm is provided using `scratch`.
If `scratch.size()` is less than the value returned by
Expand Down Expand Up @@ -264,7 +264,7 @@ Table 6. Constructors of the `radix_sorter` class.
|Constructor|Description

|`template<std::size_t Extent>
radix_sorter(sycl::span<uint8_t, Extent> scratch, const std::bitset<sizeof(T) * CHAR_BIT> mask = std::bitset<sizeof(T) * CHAR_BIT>
radix_sorter(sycl::span<std::byte, Extent> scratch, const std::bitset<sizeof(T) * CHAR_BIT> mask = std::bitset<sizeof(T) * CHAR_BIT>
(std::numeric_limits<unsigned long long>::max()))`
|Creates the `radix_sorter` object to sort values considering only bits
that corresponds to 1 in `mask`.
Expand Down Expand Up @@ -350,16 +350,16 @@ namespace sycl::ext::oneapi::experimental {
class group_with_scratchpad
{
public:
group_with_scratchpad(Group group, sycl::span<uint8_t, Extent> scratch);
group_with_scratchpad(Group group, sycl::span<std::byte, Extent> scratch);
Group get_group() const;

sycl::span<uint8_t, Extent>
sycl::span<std::byte, Extent>
get_memory() const;
};

// Deduction guides
template<typename Group, std::size_t Extent>
group_with_scratchpad(Group, sycl::span<uint8_t, Extent>)
group_with_scratchpad(Group, sycl::span<std::byte, Extent>)
-> group_with_scratchpad<Group, Extent>;

}
Expand All @@ -372,7 +372,7 @@ Table 9. Constructors of the `group_with_scratchpad` class.
|===
|Constructor|Description

|`group_with_scratchpad(Group group, sycl::span<uint8_t, Extent> scratch)`
|`group_with_scratchpad(Group group, sycl::span<std::byte, Extent> scratch)`
|Creates the `group_with_scratchpad` object using `group` and `scratch`.
`sycl::is_group_v<std::decay_t<Group>>` must be true.
`scratch.size()` must not be less than value returned by the `memory_required` method
Expand All @@ -388,7 +388,7 @@ Table 10. Member functions of the `group_with_scratchpad` class.
|`Group get_group() const`
|Returns the `Group` class object that is handled by the `group_with_scratchpad` object.

|`sycl::span<uint8_t, Extent>
|`sycl::span<std::byte, Extent>
get_memory() const`
|Returns `sycl::span` that represents an additional memory
that is handled by the `group_with_scratchpad` object.
Expand Down Expand Up @@ -508,7 +508,7 @@ size_t temp_memory_size =

q.submit([&](sycl::handler& h) {
auto acc = sycl::accessor(buf, h);
auto scratch = sycl::local_accessor<uint8_t, 1>( {temp_memory_size}, h );
auto scratch = sycl::local_accessor<std::byte, 1>( {temp_memory_size}, h );

h.parallel_for(
sycl::nd_range<1>{ /*global_size = */ {256}, /*local_size = */ {256} },
Expand Down Expand Up @@ -546,7 +546,7 @@ size_t temp_memory_size =

q.submit([&](sycl::handler& h) {
auto acc = sycl::accessor(buf, h);
auto scratch = sycl::local_accessor<uint8_t, 1>( {temp_memory_size}, h);
auto scratch = sycl::local_accessor<std::byte, 1>( {temp_memory_size}, h);

h.parallel_for(
sycl::nd_range<1>{ local_range, local_range },
Expand Down Expand Up @@ -583,7 +583,7 @@ size_t temp_memory_size =
q.submit([&](sycl::handler& h) {
auto keys_acc = sycl::accessor(keys_buf, h);
auto vals_acc = sycl::accessor(vals_buf, h);
auto scratch = sycl::local_accessor<uint8_t, 1>( {temp_memory_size}, h);
auto scratch = sycl::local_accessor<std::byte, 1>( {temp_memory_size}, h);

h.parallel_for(
sycl::nd_range<1>{ /*global_size = */ {1024}, /*local_size = */ {256} },
Expand Down
44 changes: 34 additions & 10 deletions sycl/include/CL/sycl/detail/group_sort_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,22 @@ struct GetValueType<sycl::multi_ptr<ElementType, Space>> {
using type = ElementType;
};

// since we couldn't assign data to raw memory, it's better to use placement for
// first assignment
template <typename Acc, typename T>
void set_value(Acc ptr, const std::size_t idx, const T &val, bool is_first) {
if (is_first) {
::new (ptr + idx) T(val);
} else {
ptr[idx] = val;
}
}

template <typename InAcc, typename OutAcc, typename Compare>
void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const std::size_t start_1, const std::size_t end_1,
const std::size_t end_2, const std::size_t start_out, Compare comp,
const std::size_t chunk) {
const std::size_t chunk, bool is_first) {
const std::size_t start_2 = end_1;
// Borders of the sequences to merge within this call
const std::size_t local_start_1 =
Expand Down Expand Up @@ -98,7 +109,9 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const std::size_t l_shift_1 = local_start_1 - start_1;
const std::size_t l_shift_2 = l_search_bound_2 - start_2;

out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_1;
// out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_1;
set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
is_first);

std::size_t r_search_bound_2{};
// find right border in 2nd sequence
Expand All @@ -109,7 +122,9 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const auto r_shift_1 = local_end_1 - 1 - start_1;
const auto r_shift_2 = r_search_bound_2 - start_2;

out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_1;
// out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_1;
set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
is_first);
}

// Handle intermediate items
Expand All @@ -123,7 +138,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const std::size_t shift_1 = idx - start_1;
const std::size_t shift_2 = l_search_bound_2 - start_2;

out_acc1[start_out + shift_1 + shift_2] = intermediate_item_1;
set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
is_first);
}
}
// Process 2nd sequence
Expand All @@ -136,7 +152,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const std::size_t l_shift_1 = l_search_bound_1 - start_1;
const std::size_t l_shift_2 = local_start_2 - start_2;

out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_2;
set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
is_first);

std::size_t r_search_bound_1{};
// find right border in 1st sequence
Expand All @@ -147,7 +164,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const std::size_t r_shift_1 = r_search_bound_1 - start_1;
const std::size_t r_shift_2 = local_end_2 - 1 - start_2;

out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_2;
set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
is_first);
}

// Handle intermediate items
Expand All @@ -161,7 +179,8 @@ void merge(const std::size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
const std::size_t shift_1 = l_search_bound_1 - start_1;
const std::size_t shift_2 = idx - start_2;

out_acc1[start_out + shift_1 + shift_2] = intermediate_item_2;
set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
is_first);
}
}
}
Expand All @@ -183,7 +202,7 @@ void bubble_sort(Iter first, const std::size_t begin, const std::size_t end,

template <typename Group, typename Iter, typename Compare>
void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
std::uint8_t *scratch) {
std::byte *scratch) {
using T = typename GetValueType<Iter>::type;
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
const std::size_t idx = id.get_local_id();
Expand All @@ -196,6 +215,7 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,

T *temp = reinterpret_cast<T *>(scratch);
bool data_in_temp = false;
bool is_first = true;
std::size_t sorted_size = 1;
while (sorted_size * chunk < n) {
const std::size_t start_1 =
Expand All @@ -205,14 +225,18 @@ void merge_sort(Group group, Iter first, const std::size_t n, Compare comp,
const std::size_t offset = chunk * (idx % sorted_size);

if (!data_in_temp) {
merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk);
merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
is_first);
} else {
merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk);
merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
/*is_first*/ false);
}
id.barrier();

data_in_temp = !data_in_temp;
sorted_size *= 2;
if (is_first)
is_first = false;
}

// copy back if data is in a temporary storage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,24 @@ namespace experimental {
// ---- group helpers
template <typename Group, std::size_t Extent> class group_with_scratchpad {
Group g;
sycl::span<std::uint8_t, Extent> scratch;
sycl::span<std::byte, Extent> scratch;

public:
group_with_scratchpad(Group g_, sycl::span<std::uint8_t, Extent> scratch_)
group_with_scratchpad(Group g_, sycl::span<std::byte, Extent> scratch_)
: g(g_), scratch(scratch_) {}
Group get_group() const { return g; }
sycl::span<std::uint8_t, Extent> get_memory() const { return scratch; }
sycl::span<std::byte, Extent> get_memory() const { return scratch; }
};

// ---- sorters
template <typename Compare = std::less<>> class default_sorter {
Compare comp;
std::uint8_t *scratch;
std::byte *scratch;
std::size_t scratch_size;

public:
template <std::size_t Extent>
default_sorter(sycl::span<std::uint8_t, Extent> scratch_,
default_sorter(sycl::span<std::byte, Extent> scratch_,
Compare comp_ = Compare())
: comp(comp_), scratch(scratch_.data()), scratch_size(scratch_.size()) {}

Expand All @@ -51,9 +51,9 @@ template <typename Compare = std::less<>> class default_sorter {
(void)g;
(void)first;
(void)last;
throw runtime_error(
"default_sorter constructor is not supported on host device.",
PI_INVALID_DEVICE);
throw sycl::exception(
std::error_code(PI_INVALID_DEVICE, sycl::sycl_category()),
"default_sorter constructor is not supported on host device.");
#endif
}

Expand All @@ -64,32 +64,31 @@ template <typename Compare = std::less<>> class default_sorter {
auto id = sycl::detail::Builder::getNDItem<Group::dimensions>();
uint32_t local_id = id.get_local_id();
T *temp = reinterpret_cast<T *>(scratch);
temp[local_id] = val;
::new (temp + local_id) T(val);
sycl::detail::merge_sort(g, temp, range_size, comp,
scratch + range_size * sizeof(T));
val = temp[local_id];
}
// TODO: it's better to add else branch
#else
(void)g;
(void)val;
throw runtime_error(
"default_sorter operator() is not supported on host device.",
PI_INVALID_DEVICE);
throw sycl::exception(
std::error_code(PI_INVALID_DEVICE, sycl::sycl_category()),
"default_sorter operator() is not supported on host device.");
#endif
return val;
}

template <typename T>
static constexpr std::size_t memory_required(sycl::memory_scope scope,
static constexpr std::size_t memory_required(sycl::memory_scope,
std::size_t range_size) {
return range_size * sizeof(T);
return range_size * sizeof(T) + alignof(T);
}

template <typename T, int dim = 1>
static constexpr std::size_t memory_required(sycl::memory_scope scope,
sycl::range<dim> r) {
return 2 * r.size() * sizeof(T);
return 2 * memory_required<T>(scope, r.size());
}
};

Expand Down
15 changes: 11 additions & 4 deletions sycl/include/sycl/ext/oneapi/group_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ sort_over_group(Group group, T value, Sorter sorter) {
return sorter(group, value);
#else
(void)group;
throw runtime_error("Group algorithms are not supported on host device.",
PI_INVALID_DEVICE);
(void)value;
(void)sorter;
throw sycl::exception(
std::error_code(PI_INVALID_DEVICE, sycl::sycl_category()),
"Group algorithms are not supported on host device.");
#endif
}

Expand Down Expand Up @@ -108,8 +111,12 @@ joint_sort(Group group, Iter first, Iter last, Sorter sorter) {
sorter(group, first, last);
#else
(void)group;
throw runtime_error("Group algorithms are not supported on host device.",
PI_INVALID_DEVICE);
(void)first;
(void)last;
(void)sorter;
throw sycl::exception(
std::error_code(PI_INVALID_DEVICE, sycl::sycl_category()),
"Group algorithms are not supported on host device.");
#endif
}

Expand Down