diff --git a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc
similarity index 75%
rename from sycl/doc/extensions/GroupMask/GroupMask.asciidoc
rename to sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc
index d95fd6c5b12be..c3b9a6ca98ca4 100755
--- a/sycl/doc/extensions/GroupMask/GroupMask.asciidoc
+++ b/sycl/doc/extensions/SubGroupMask/SubGroupMask.asciidoc
@@ -1,4 +1,4 @@
-= SYCL_EXT_ONEAPI_GROUP_MASK
+= SYCL_EXT_ONEAPI_SUB_GROUP_MASK
 :source-highlighter: coderay
 :coderay-linenums-mode: table
@@ -21,7 +21,7 @@ IMPORTANT: This specification is a draft.
 NOTE: Khronos(R) is a registered trademark and SYCL(TM) and SPIR(TM) are trademarks of The Khronos Group Inc. OpenCL(TM) is a trademark of Apple Inc. used by permission by Khronos.
-This document describes an extension which adds a `group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a group for which a given Boolean condition holds. Group mask functionality is currently limited to groups that are instances of the `sub_group` class.
+This document describes an extension which adds a `sub_group_mask` type. Such a mask can be used to efficiently represent subsets of work-items in a sub-group for which a given Boolean condition holds.
 == Notice
@@ -51,9 +51,9 @@ This extension is written against the SYCL 2020 specification, Revision 3.
 This extension provides a feature-test macro as described in the core SYCL specification section 6.3.3 "Feature test macros". Therefore, an implementation supporting this extension must predefine the macro
-`SYCL_EXT_ONEAPI_GROUP_MASK` to one of the values defined in the table below.
-Applications can test for the existence of this macro to determine if the
-implementation supports this feature, or applications can test the macro's
+`SYCL_EXT_ONEAPI_SUB_GROUP_MASK` to one of the values defined in the table
+below.
Applications can test for the existence of this macro to determine if
+the implementation supports this feature, or applications can test the macro's
 value to determine which of the extension's APIs the implementation supports.
 [%header,cols="1,5"]
@@ -81,18 +81,18 @@ must be encountered by all work-items in the group in converged control flow.
 |===
 |Function|Description
-|`template Group::mask_type group_ballot(Group g, bool predicate = true) const`
-|Return a `group_mask` representing the set of work-items in group _g_ for which _predicate_ is `true`.
+|`template Group::mask_type group_ballot(Group g, bool predicate = true)`
+|Return a `sub_group_mask` with one bit for each work-item in group _g_. A bit is set in this mask if and only if the corresponding work-item's _predicate_ is `true`.
 |===
 === Group Masks
 The group mask type is an opaque type, permitting implementations to use any mask representation that has the same size and alignment across host and
-device. The maximum number of bits that can be stored in a `group_mask` is
-exposed as a static member variable, `group_mask::max_bits`.
+device. The maximum number of bits that can be stored in a `sub_group_mask` is
+exposed as a static member variable, `sub_group_mask::max_bits`.
-Functions declared in the `group_mask` class can be called independently by
+Functions declared in the `sub_group_mask` class can be called independently by
 different work-items in the same group. An instance of a group class (e.g. `group` or `sub_group`) is not required to manipulate a group mask.
@@ -107,7 +107,7 @@ work-item with the id `max_local_range()-1`.
 |Return `true` if the bit corresponding to the specified _id_ is set in the mask.
-|`group_mask::reference operator[](id<1> id)`
+|`sub_group_mask::reference operator[](id<1> id)`
 |Return a reference to the bit corresponding to the specified _id_ in the mask.
 |`bool test(id<1> id) const`
@@ -137,17 +137,15 @@ work-item with the id `max_local_range()-1`.
 |Return the highest `id` with a corresponding bit set in the mask. If no bits are set, the return value is equal to `size()`.
-|`template > void insert_bits(T bits, id<1> pos = 0)`
+|`template void insert_bits(const T &bits, id<1> pos = 0)`
 |Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T`
- must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
- multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
+ must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
  `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits are ignored.
-|`template > T extract_bits(id<1> pos = 0) const`
+|`template void extract_bits(T &out, id<1> pos = 0) const`
 |Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T`
- must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
- multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
+ must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
  `CHAR_BIT * sizeof(T)` is greater than `size()`, the final `size()` - (_pos_ pass:[+] `CHAR_BIT * sizeof(T)`) bits of the return value are zero.
@@ -178,62 +176,63 @@ work-item with the id `max_local_range()-1`.
 |`void flip(id<1> id)`
 |Toggle the value of the bit corresponding to the specified _id_.
-|`bool operator==(group_mask rhs) const`
+|`bool operator==(const sub_group_mask &rhs) const`
 |Return true if each bit in this mask is equal to the corresponding bit in `rhs`.
-|`bool operator!=(group_mask rhs) const`
+|`bool operator!=(const sub_group_mask &rhs) const`
 |Return true if any bit in this mask is not equal to the corresponding bit in `rhs`.
-|`group_mask operator &=(group_mask rhs)`
+|`sub_group_mask &operator &=(const sub_group_mask &rhs)`
 |Set the bits of this mask to the result of performing a bitwise AND with this mask and `rhs`.
-|`group_mask operator \|=(group_mask rhs)`
+|`sub_group_mask &operator \|=(const sub_group_mask &rhs)`
 |Set the bits of this mask to the result of performing a bitwise OR with this mask and `rhs`.
-|`group_mask operator ^=(group_mask rhs)`
+|`sub_group_mask &operator ^=(const sub_group_mask &rhs)`
 |Set the bits of this mask to the result of performing a bitwise XOR with this mask and `rhs`.
-|`group_mask operator pass:[<<=](size_t shift)`
+|`sub_group_mask &operator pass:[<<=](size_t shift)`
 |Set the bits of this mask to the result of shifting its bits _shift_ positions to the left using a logical shift. Bits that are shifted out to the left are discarded, and zeroes are shifted in from the right.
-|`group_mask operator >>=(size_t shift)`
+|`sub_group_mask &operator >>=(size_t shift)`
 |Set the bits of this mask to the result of shifting its bits _shift_ positions to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left.
-|`group_mask operator ~() const`
+|`sub_group_mask operator ~() const`
 |Return a mask representing the result of flipping all the bits in this mask.
-|`group_mask operator <<(size_t shift)`
+|`sub_group_mask operator <<(size_t shift) const`
 |Return a mask representing the result of shifting its bits _shift_ positions to the left using a logical shift. Bits that are shifted out to the left are discarded, and zeroes are shifted in from the right.
-|`group_mask operator >>(size_t shift)`
+|`sub_group_mask operator >>(size_t shift) const`
 |Return a mask representing the result of shifting its bits _shift_ positions to the right using a logical shift. Bits that are shifted out to the right are discarded, and zeroes are shifted in from the left.
+
 |===
 |===
 |Function|Description
-|`group_mask operator &(const group_mask& lhs, const group_mask& rhs)`
+|`sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs)`
 |Return a mask representing the result of performing a bitwise AND of `lhs` and `rhs`.
-|`group_mask operator \|(const group_mask& lhs, const group_mask& rhs)`
+|`sub_group_mask operator \|(const sub_group_mask& lhs, const sub_group_mask& rhs)`
 |Return a mask representing the result of performing a bitwise OR of `lhs` and `rhs`.
-|`group_mask operator ^(const group_mask& lhs, const group_mask& rhs)`
+|`sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs)`
 |Return a mask representing the result of performing a bitwise XOR of `lhs` and `rhs`.
@@ -247,7 +246,7 @@ namespace sycl {
 namespace ext {
 namespace oneapi {
-struct group_mask {
+struct sub_group_mask {
   // enable reference to individual bit
   struct reference {
@@ -271,11 +270,11 @@ struct group_mask {
   id<1> find_low() const;
   id<1> find_high() const;
-  template >
-  void insert_bits(T bits, id<1> pos = 0);
+  template
+  void insert_bits(const T &bits, id<1> pos = 0);
-  template >
-  T extract_bits(id<1> pos = 0);
+  template
+  void extract_bits(T &out, id<1> pos = 0);
   void set();
   void set(id<1> id, bool value = true);
@@ -286,24 +285,24 @@ struct group_mask {
   void flip();
   void flip(id<1> id);
-  bool operator==(group_mask rhs) const;
-  bool operator!=(group_mask rhs) const;
+  bool operator==(const sub_group_mask &rhs) const;
+  bool operator!=(const sub_group_mask &rhs) const;
-  group_mask operator &=(group_mask rhs);
-  group_mask operator |=(group_mask rhs);
-  group_mask operator ^=(group_mask rhs);
-  group_mask operator <<=(size_t);
-  group_mask operator >>=(size_t rhs);
+  sub_group_mask &operator &=(const sub_group_mask &rhs);
+  sub_group_mask &operator |=(const sub_group_mask &rhs);
+  sub_group_mask &operator ^=(const sub_group_mask &rhs);
+  sub_group_mask &operator <<=(size_t n);
+  sub_group_mask
&operator >>=(size_t n);
-  group_mask operator ~() const;
-  group_mask operator <<(size_t) const;
-  group_mask operator >>(size_t) const;
+  sub_group_mask operator ~() const;
+  sub_group_mask operator <<(size_t n) const;
+  sub_group_mask operator >>(size_t n) const;
 };
-group_mask operator &(const group_mask& lhs, const group_mask& rhs);
-group_mask operator |(const group_mask& lhs, const group_mask& rhs);
-group_mask operator ^(const group_mask& lhs, const group_mask& rhs);
+sub_group_mask operator &(const sub_group_mask& lhs, const sub_group_mask& rhs);
+sub_group_mask operator |(const sub_group_mask& lhs, const sub_group_mask& rhs);
+sub_group_mask operator ^(const sub_group_mask& lhs, const sub_group_mask& rhs);
 } // namespace oneapi
 } // namespace ext
@@ -328,6 +327,7 @@ None.
 |========================================
 |Rev|Date|Author|Changes
 |1|2021-08-11|John Pennycook|*Initial public working draft*
+|2|2021-09-13|Vladimir Lazarev|*Update during implementation*
 |========================================
 //************************************************************************
diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp
index 18ef03cc70607..7e6a2aad1f5fb 100644
--- a/sycl/include/CL/__spirv/spirv_ops.hpp
+++ b/sycl/include/CL/__spirv/spirv_ops.hpp
@@ -597,6 +597,9 @@ __spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
 extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
 extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;
+__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __ocl_vec_t
+__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept;
+
 #else // if !__SYCL_DEVICE_ONLY__
 template
diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp
index 75f0cedd57c44..46c080ee880e4 100644
--- a/sycl/include/CL/sycl.hpp
+++ b/sycl/include/CL/sycl.hpp
@@ -59,3 +59,4 @@
 #include
 #include
 #include
+#include
diff --git
a/sycl/include/CL/sycl/detail/helpers.hpp b/sycl/include/CL/sycl/detail/helpers.hpp
index 118271a35bab5..838b68f33a641 100644
--- a/sycl/include/CL/sycl/detail/helpers.hpp
+++ b/sycl/include/CL/sycl/detail/helpers.hpp
@@ -31,6 +31,7 @@ template class range;
 template class id;
 template class nd_item;
 template class h_item;
+template class marray;
 enum class memory_order;
 namespace detail {
@@ -82,6 +83,11 @@ class Builder {
     return group(Global, Local, Global / Local, Index);
   }
+  template
+  static ResType createSubGroupMask(uint32_t Bits, size_t BitsNum) {
+    return ResType(Bits, BitsNum);
+  }
+
   template
   static detail::enable_if_t>
   createItem(const range &Extent, const id &Index,
diff --git a/sycl/include/CL/sycl/feature_test.hpp b/sycl/include/CL/sycl/feature_test.hpp
index 4625cfa06fed7..e3cacc48f0982 100644
--- a/sycl/include/CL/sycl/feature_test.hpp
+++ b/sycl/include/CL/sycl/feature_test.hpp
@@ -14,6 +14,7 @@ namespace sycl {
 // TODO: Move these feature-test macros to compiler driver.
 #define SYCL_EXT_INTEL_DEVICE_INFO 2
+#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
 #define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
 // As for SYCL_EXT_ONEAPI_MATRIX:
 // 1- provides AOT initial implementation for AMX for the experimental matrix
diff --git a/sycl/include/CL/sycl/marray.hpp b/sycl/include/CL/sycl/marray.hpp
index 5b758b80683d0..0267f0a85ff8a 100644
--- a/sycl/include/CL/sycl/marray.hpp
+++ b/sycl/include/CL/sycl/marray.hpp
@@ -149,9 +149,9 @@ template class marray {
   }
 #define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \
-  template \
-  friend typename std::enable_if::value, marray> \
-  operator BINOP(const marray &Lhs, const marray &Rhs) { \
+  template ::value, marray>> \
+  friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \
     marray Ret; \
     for (size_t I = 0; I < NumElements; ++I) { \
       Ret[I] = Lhs[I] BINOP Rhs[I]; \
@@ -166,9 +166,9 @@ template class marray {
   operator BINOP(const marray &Lhs, const T &Rhs) { \
     return Lhs BINOP marray(static_cast(Rhs)); \
   } \
-  template \
-  friend typename std::enable_if::value, marray> \
-  &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
+  template ::value, marray>> \
+  friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
     Lhs = Lhs BINOP Rhs; \
     return Lhs; \
   } \
diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
new file mode 100644
index 0000000000000..3ff837beabd16
--- /dev/null
+++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
@@ -0,0 +1,267 @@
+//==------------ sub_group_mask.hpp --- SYCL sub-group mask ----------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+__SYCL_INLINE_NAMESPACE(cl) {
+namespace sycl {
+namespace detail {
+class Builder;
+} // namespace detail
+
+namespace ext {
+namespace oneapi {
+
+struct sub_group_mask {
+  friend class detail::Builder;
+  static constexpr size_t max_bits = 32 /* implementation-defined */;
+  static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;
+
+  // enable reference to individual bit
+  struct reference {
+    reference &operator=(bool x) {
+      if (x) {
+        Ref |= RefBit;
+      } else {
+        Ref &= ~RefBit;
+      }
+      return *this;
+    }
+    reference &operator=(const reference &x) {
+      operator=((bool)x);
+      return *this;
+    }
+    bool operator~() const { return !(Ref & RefBit); }
+    operator bool() const { return Ref & RefBit; }
+    reference &flip() {
+      operator=(!(bool)*this);
+      return *this;
+    }
+
+    reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
+      RefBit = 1 << pos % word_size;
+    }
+
+  private:
+    // Reference to the word containing the bit
+    uint32_t &Ref;
+    // Bit mask where only referenced bit is set
+    uint32_t RefBit;
+  };
+
+  bool operator[](id<1> id) const {
+    return Bits & (1 << (id.get(0) % word_size));
+  }
+  reference operator[](id<1> id) { return {*this, id.get(0)}; }
+  bool test(id<1> id) const { return operator[](id); }
+  bool all() const { return !~Bits; }
+  bool any() const { return Bits; }
+  bool none() const { return !Bits; }
+  uint32_t count() const {
+    unsigned int count = 0;
+    auto word = Bits;
+    while (word) {
+      word &= (word - 1);
+      count++;
+    }
+    return count;
+  }
+  uint32_t size() const { return bits_num; }
+  id<1> find_low() const {
+    size_t i = 0;
+    while (i < size() && !operator[](i))
+      i++;
+    return {i};
+  }
+  id<1> find_high() const {
+    size_t i = size() - 1;
+    while (i > 0 && !operator[](i))
+      i--;
+    return {operator[](i) ?
i : size()};
+  }
+
+  template ::value>>
+  void insert_bits(Type bits, id<1> pos = 0) {
+    size_t insert_size = sizeof(Type) * CHAR_BIT;
+    uint32_t insert_data = (uint32_t)bits;
+    insert_data <<= pos.get(0);
+    uint32_t mask = 0;
+    if (pos.get(0) + insert_size < size())
+      mask |= (0xffffffff << (pos.get(0) + insert_size));
+    if (pos.get(0) < size())
+      mask |= (0xffffffff >> (size() - pos.get(0)));
+    Bits &= mask;
+    Bits += insert_data;
+  }
+
+  /* The bits are stored in the memory in the following way:
+  marray id | 0 | 1 | 2 | 3 |
+  bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|
+  */
+  template ::value>>
+  void insert_bits(const marray &bits, id<1> pos = 0) {
+    size_t cur_pos = pos.get(0);
+    for (auto elem : bits) {
+      if (cur_pos < size()) {
+        this->insert_bits(elem, cur_pos);
+        cur_pos += sizeof(Type) * CHAR_BIT;
+      }
+    }
+  }
+
+  template ::value>>
+  void extract_bits(Type &bits, id<1> pos = 0) {
+    uint32_t Res = Bits;
+    if (pos.get(0) < size()) {
+      if (pos.get(0) > 0) {
+        Res >>= pos.get(0);
+      }
+
+      if (sizeof(Type) * CHAR_BIT < size()) {
+        Res &= (0xffffffff >> (size() - (sizeof(Type) * CHAR_BIT)));
+      }
+      bits = (Type)Res;
+    } else {
+      bits = 0;
+    }
+  }
+
+  template ::value>>
+  void extract_bits(marray &bits, id<1> pos = 0) {
+    size_t cur_pos = pos.get(0);
+    for (auto &elem : bits) {
+      if (cur_pos < size()) {
+        this->extract_bits(elem, cur_pos);
+        cur_pos += sizeof(Type) * CHAR_BIT;
+      } else {
+        elem = 0;
+      }
+    }
+  }
+
+  void set() { Bits = uint32_t{0xffffffff}; }
+  void set(id<1> id, bool value = true) { operator[](id) = value; }
+  void reset() { Bits = uint32_t{0}; }
+  void reset(id<1> id) { operator[](id) = 0; }
+  void reset_low() { reset(find_low()); }
+  void reset_high() { reset(find_high()); }
+  void flip() { Bits = ~Bits; }
+  void flip(id<1> id) { operator[](id).flip(); }
+
+  bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; }
+  bool operator!=(const sub_group_mask &rhs) const { return !(*this == rhs); }
+
+  sub_group_mask
&operator&=(const sub_group_mask &rhs) {
+    Bits &= rhs.Bits;
+    return *this;
+  }
+  sub_group_mask &operator|=(const sub_group_mask &rhs) {
+    Bits |= rhs.Bits;
+    return *this;
+  }
+
+  sub_group_mask &operator^=(const sub_group_mask &rhs) {
+    Bits ^= rhs.Bits;
+    return *this;
+  }
+
+  sub_group_mask &operator<<=(size_t pos) {
+    Bits <<= pos;
+    return *this;
+  }
+
+  sub_group_mask &operator>>=(size_t pos) {
+    Bits >>= pos;
+    return *this;
+  }
+
+  sub_group_mask operator~() const {
+    auto Tmp = *this;
+    Tmp.flip();
+    return Tmp;
+  }
+  sub_group_mask operator<<(size_t pos) const {
+    auto Tmp = *this;
+    Tmp <<= pos;
+    return Tmp;
+  }
+  sub_group_mask operator>>(size_t pos) const {
+    auto Tmp = *this;
+    Tmp >>= pos;
+    return Tmp;
+  }
+
+  sub_group_mask(const sub_group_mask &rhs)
+      : Bits(rhs.Bits), bits_num(rhs.bits_num) {}
+
+  template
+  friend detail::enable_if_t<
+      std::is_same, sub_group>::value, sub_group_mask>
+  group_ballot(Group g, bool predicate);
+
+  friend sub_group_mask operator&(const sub_group_mask &lhs,
+                                  const sub_group_mask &rhs) {
+    auto Res = lhs;
+    Res &= rhs;
+    return Res;
+  }
+
+  friend sub_group_mask operator|(const sub_group_mask &lhs,
+                                  const sub_group_mask &rhs) {
+    auto Res = lhs;
+    Res |= rhs;
+    return Res;
+  }
+
+  friend sub_group_mask operator^(const sub_group_mask &lhs,
+                                  const sub_group_mask &rhs) {
+    auto Res = lhs;
+    Res ^= rhs;
+    return Res;
+  }
+
+private:
+  sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) {
+    assert(bits_num <= max_bits);
+  }
+  uint32_t Bits;
+  // Number of valuable bits
+  size_t bits_num;
+};
+
+template
+detail::enable_if_t, sub_group>::value,
+                    sub_group_mask>
+group_ballot(Group g, bool predicate) {
+  (void)g;
+#ifdef __SYCL_DEVICE_ONLY__
+  auto res = __spirv_GroupNonUniformBallot(
+      detail::spirv::group_scope::value, predicate);
+  return detail::Builder::createSubGroupMask(
+      res[0], g.get_max_local_range()[0]);
+#else
+  (void)predicate;
+  throw exception{errc::feature_not_supported,
+                  "Sub-group
mask is not supported on host device"};
+#endif
+}
+
+} // namespace oneapi
+} // namespace ext
+} // namespace sycl
+} // __SYCL_INLINE_NAMESPACE(cl)
diff --git a/sycl/test/check_device_code/sub_group_mask.cpp b/sycl/test/check_device_code/sub_group_mask.cpp
new file mode 100644
index 0000000000000..b074567e6d461
--- /dev/null
+++ b/sycl/test/check_device_code/sub_group_mask.cpp
@@ -0,0 +1,10 @@
+// RUN: %clangxx -I %sycl_include -S -emit-llvm -fsycl-device-only %s -o - -Xclang -disable-llvm-passes | FileCheck %s
+
+#include
+
+using namespace sycl;
+
+SYCL_EXTERNAL void test_group_mask(sub_group g) {
+  ext::oneapi::group_ballot(g, true);
+}
+// CHECK: %{{.*}} = call spir_func <4 x i32> @_Z[[#]]__spirv_GroupNonUniformBallotjb(i32 {{.*}}, i1{{.*}})
diff --git a/sycl/test/extensions/macro.cpp b/sycl/test/extensions/macro.cpp
new file mode 100644
index 0000000000000..7264ac21e4264
--- /dev/null
+++ b/sycl/test/extensions/macro.cpp
@@ -0,0 +1,15 @@
+// This test checks presence of macros for available extensions.
+// RUN: %clangxx -fsycl %s -o %t.out
+// RUN: %t.out
+
+#include
+#include
+int main() {
+#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK == 1
+  std::cout << "SYCL_EXT_ONEAPI_SUB_GROUP_MASK=1" << std::endl;
+#else
+  std::cerr << "SYCL_EXT_ONEAPI_SUB_GROUP_MASK!=1" << std::endl;
+  exit(1);
+#endif
+  exit(0);
+}
diff --git a/sycl/test/extensions/sub_group_mask.cpp b/sycl/test/extensions/sub_group_mask.cpp
new file mode 100644
index 0000000000000..cc2cdfa17c43c
--- /dev/null
+++ b/sycl/test/extensions/sub_group_mask.cpp
@@ -0,0 +1,90 @@
+// RUN: %clangxx -g -O0 -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
+// RUN: %t.out
+
+//==-------- sub_group_mask.cpp - SYCL sub-group mask test -----------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include
+#include
+#include
+
+int main() {
+  auto g = sycl::detail::Builder::createSubGroupMask<
+      sycl::ext::oneapi::sub_group_mask>(0, 32);
+  assert(g.none() && !g.any() && !g.all());
+  assert(g[10] == false); // reference::operator[](id) const;
+  g[10] = true; // reference::operator=(bool);
+  assert(g[10] == true);
+  g[11] = g[10]; // reference::operator=(reference) reference::operator[](id);
+  assert(g[10].flip() == false); // reference::flip()
+  assert(~g[10] == true); // reference::operator~()
+  assert(g[10] == false);
+  assert(g[11] == true);
+  assert(g.test(10) == false && g.test(11) == true);
+  g.set(30, 1);
+  g.set(11, 0);
+  g.set(23, 1);
+  assert(!g.none() && g.any() && !g.all());
+
+  assert(g.count() == 2);
+  assert(g.find_low() == 23);
+  assert(g.find_high() == 30);
+  assert(g.size() == 32);
+
+  g.reset();
+  assert(g.none() && !g.any() && !g.all());
+  assert(g.find_low() == g.size() && g.find_high() == g.size());
+  g.set();
+  assert(!g.none() && g.any() && g.all());
+  assert(g.find_low() == 0 && g.find_high() == 31);
+  g.flip();
+  assert(g.none() && !g.any() && !g.all());
+
+  g.flip(13);
+  g.flip(23);
+  g.flip(29);
+  auto b = g;
+  assert(b == g && !(b != g));
+  g.flip(31);
+  assert(g.find_high() == 31);
+  assert(b.find_high() == 29);
+  assert(b != g && !(b == g));
+  b.flip(31);
+  assert(b == g && !(b != g));
+  b = g >> 1;
+  assert(b[12] && b[22] && b[28] && b[30]);
+  b <<= 1;
+  assert(b == g);
+  g ^= ~b;
+  assert(!g.none() && g.any() && g.all());
+  assert((g | ~g).all());
+  assert((g & ~g).none());
+  assert((g ^ ~g).all());
+  b.reset_low();
+  b.reset_high();
+  assert(!b[13] && b[23] && b[29] && !b[31]);
+  b.insert_bits(0x01020408);
+  assert(b[24] && b[17] && b[10] && b[3]);
+  b <<= 13;
+  assert(!b[24] && !b[17] && !b[10] && !b[3] && b[30] && b[23] && b[16]);
+  b.insert_bits((char)0b01010101, 18);
+
  assert(b[18] && b[20] && b[22] && b[24] && b[30] && !b[23] && b[16]);
+  b[3] = true;
+  b.insert_bits(sycl::marray{1, 2, 4, 8, 16, 32, 64, 128}, 5);
+  assert(!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[3] &&
+         b[5] && b[14] && b[23]);
+  char r;
+  b.extract_bits(r);
+  assert(r == 0b00101000);
+  long r2 = -1;
+  b.extract_bits(r2, 16);
+  assert(r2 == 128);
+  b[31] = true;
+  sycl::marray r3{-1};
+  b.extract_bits(r3, 14);
+  assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]);
+}
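Note on the mask semantics exercised by the test above: the following is a minimal host-only sketch, in plain C++ with no SYCL dependency, of how `count()`, `find_low()` and `find_high()` are expected to behave, including the documented fallback to `size()` when no bit is set. The names `toy_mask` and `make_example_mask` are hypothetical and introduced only for illustration; this is not the `sycl::ext::oneapi::sub_group_mask` implementation from the patch, and it assumes at most 32 meaningful bits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a sub-group mask; NOT the type added by the patch.
struct toy_mask {
  std::uint32_t bits = 0; // bit i models the work-item with id i
  std::size_t size = 32;  // number of meaningful bits (assumed <= 32)

  // Set or clear the bit for the given id.
  void set(std::size_t id, bool value = true) {
    assert(id < size);
    const std::uint32_t bit = std::uint32_t{1} << id;
    bits = value ? (bits | bit) : (bits & ~bit);
  }

  // Return true if the bit for the given id is set.
  bool test(std::size_t id) const { return ((bits >> id) & 1u) != 0; }

  // Count set bits by repeatedly clearing the lowest set bit.
  std::size_t count() const {
    std::size_t n = 0;
    for (std::uint32_t w = bits; w != 0; w &= w - 1)
      ++n;
    return n;
  }

  // Lowest id with a set bit, or size when the mask is empty.
  std::size_t find_low() const {
    for (std::size_t i = 0; i < size; ++i)
      if (test(i))
        return i;
    return size;
  }

  // Highest id with a set bit, or size when the mask is empty.
  std::size_t find_high() const {
    for (std::size_t i = size; i-- > 0;)
      if (test(i))
        return i;
    return size;
  }
};

// Build the same bit pattern the test uses: ids 23 and 30 set.
inline toy_mask make_example_mask() {
  toy_mask m;
  m.set(23);
  m.set(30);
  return m;
}
```

Calling `make_example_mask()` from a small `main` and checking `count()`, `find_low()` and `find_high()` reproduces the values asserted in the test (2, 23 and 30); a default-constructed `toy_mask` reports `size` for both searches.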