Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 21 additions & 23 deletions sycl/doc/extensions/GroupMask/GroupMask.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ must be encountered by all work-items in the group in converged control flow.
|===
|Function|Description

|`template <typename Group> Group::mask_type group_ballot(Group g, bool predicate = true) const`
|`template <typename Group> Group::mask_type group_ballot(Group g, bool predicate = true)`
|Return a `group_mask` representing the set of work-items in group _g_ for which _predicate_ is `true`.
|===

Expand Down Expand Up @@ -137,17 +137,15 @@ work-item with the id `max_local_range()-1`.
|Return the highest `id` with a corresponding bit set in the mask. If no bits
are set, the return value is equal to `size()`.

|`template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>> void insert_bits(T bits, id<1> pos = 0)`
|`template <typename T = marray<uint32_t, marray_size>> void insert_bits(const T& bits, id<1> pos = 0)`
|Insert `CHAR_BIT * sizeof(T)` bits into the mask, starting from _pos_. `T`
must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
`CHAR_BIT * sizeof(T)` is greater than `size()`, the final (_pos_ pass:[+]
`CHAR_BIT * sizeof(T)`) - `size()` bits are ignored.

|`template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>> T extract_bits(id<1> pos = 0) const`
|`template <typename T = marray<uint32_t, marray_size>> T extract_bits(id<1> pos = 0) const`
|Return `CHAR_BIT * sizeof(T)` bits from the mask, starting from _pos_. `T`
must be an integral type or a SYCL `marray` of integral types. _pos_ must be a
multiple of `CHAR_BIT * sizeof(T)` in the range [0, `size()`). If _pos_ pass:[+]
must be an integral type or a SYCL `marray` of integral types. If _pos_ pass:[+]
`CHAR_BIT * sizeof(T)` is greater than `size()`, the final (_pos_ pass:[+]
`CHAR_BIT * sizeof(T)`) - `size()` bits of the return value are zero.

Expand Down Expand Up @@ -259,6 +257,7 @@ struct group_mask {
};

static constexpr size_t max_bits = /* implementation-defined */;
static constexpr size_t marray_size = max_bits/sizeof(uint32_t)/CHAR_BIT;

bool operator[](id<1> id) const;
reference operator[](id<1> id);
Expand All @@ -271,10 +270,10 @@ struct group_mask {
id<1> find_low() const;
id<1> find_high() const;

template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>>
void insert_bits(T bits, id<1> pos = 0);
template <typename T = marray<uint32_t, marray_size>>
void insert_bits(const T& bits, id<1> pos = 0);

template <typename T = marray<uint32_t, max_bits/sizeof(uint32_t)>>
template <typename T = marray<uint32_t, marray_size>>
T extract_bits(id<1> pos = 0);

void set();
Expand All @@ -286,25 +285,24 @@ struct group_mask {
void flip();
void flip(id<1> id);

bool operator==(group_mask rhs) const;
bool operator!=(group_mask rhs) const;
bool operator==(const group_mask& rhs) const;
bool operator!=(const group_mask& rhs) const;

group_mask operator &=(group_mask rhs);
group_mask operator |=(group_mask rhs);
group_mask operator ^=(group_mask rhs);
group_mask operator <<=(size_t);
group_mask operator >>=(size_t rhs);
group_mask &operator &=(const group_mask& rhs);
group_mask &operator |=(const group_mask& rhs);
group_mask &operator ^=(const group_mask& rhs);
group_mask &operator <<=(size_t n);
group_mask &operator >>=(size_t n);

group_mask operator ~() const;
group_mask operator <<(size_t) const;
group_mask operator >>(size_t) const;
group_mask operator <<(size_t n) const;
group_mask operator >>(size_t n) const;

group_mask operator &(const group_mask& rhs) const;
group_mask operator |(const group_mask& rhs) const;
group_mask operator ^(const group_mask& rhs) const;
};

group_mask operator &(const group_mask& lhs, const group_mask& rhs);
group_mask operator |(const group_mask& lhs, const group_mask& rhs);
group_mask operator ^(const group_mask& lhs, const group_mask& rhs);

} // namespace oneapi
} // namespace ext
} // namespace sycl
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ __spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
extern SYCL_EXTERNAL uint16_t __spirv_ConvertFToBF16INTEL(float) noexcept;
extern SYCL_EXTERNAL float __spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;

// Device-side ballot builtin: takes an execution-scope value and a
// per-work-item predicate and returns four 32-bit words.
// NOTE(review): presumably this maps to SPIR-V OpGroupNonUniformBallot, whose
// result carries one bit per invocation in the scope -- confirm against the
// SPIR-V specification.
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT __ocl_vec_t<uint32_t, 4>
__spirv_GroupNonUniformBallot(uint32_t Execution, bool Predicate) noexcept;

#else // if !__SYCL_DEVICE_ONLY__

template <typename dataT>
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include <sycl/ext/oneapi/filter_selector.hpp>
#include <sycl/ext/oneapi/function_pointer.hpp>
#include <sycl/ext/oneapi/group_algorithm.hpp>
#include <sycl/ext/oneapi/group_mask.hpp>
#include <sycl/ext/oneapi/matrix/matrix.hpp>
#include <sycl/ext/oneapi/reduction.hpp>
#include <sycl/ext/oneapi/sub_group.hpp>
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/feature_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace sycl {

// TODO: Move these feature-test macros to compiler driver.
#define SYCL_EXT_INTEL_DEVICE_INFO 2
#define SYCL_EXT_ONEAPI_GROUP_MASK 1
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
// As for SYCL_EXT_ONEAPI_MATRIX:
// 1- provides AOT initial implementation for AMX for the experimental matrix
Expand Down
12 changes: 6 additions & 6 deletions sycl/include/CL/sycl/marray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ template <typename Type, std::size_t NumElements> class marray {
}

#define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \
template <typename T = DataT> \
friend typename std::enable_if<std::is_integral<T>::value, marray> \
operator BINOP(const marray &Lhs, const marray &Rhs) { \
template <typename T = DataT, \
typename = std::enable_if<std::is_integral<T>::value, marray>> \
friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \
marray Ret; \
for (size_t I = 0; I < NumElements; ++I) { \
Ret[I] = Lhs[I] BINOP Rhs[I]; \
Expand All @@ -166,9 +166,9 @@ template <typename Type, std::size_t NumElements> class marray {
operator BINOP(const marray &Lhs, const T &Rhs) { \
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
} \
template <typename T = DataT> \
friend typename std::enable_if<std::is_integral<T>::value, marray> \
&operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
template <typename T = DataT, \
typename = std::enable_if<std::is_integral<T>::value, marray>> \
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
Lhs = Lhs BINOP Rhs; \
return Lhs; \
} \
Expand Down
236 changes: 236 additions & 0 deletions sycl/include/sycl/ext/oneapi/group_mask.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//==----------------- group_mask.hpp --- SYCL group mask -------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#include <CL/__spirv/spirv_ops.hpp>
#include <CL/__spirv/spirv_vars.hpp>
#include <CL/sycl/exception.hpp>
#include <CL/sycl/id.hpp>
#include <CL/sycl/marray.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {

struct group_mask {
using WordType = uint32_t;
static constexpr size_t max_bits = 128 /* implementation-defined */;
static constexpr size_t word_size = sizeof(WordType) * CHAR_BIT;
/* Bitmask is packed in marray of uint32_t elements. This value represents
* legth of marray. */
static constexpr size_t marray_size = max_bits / word_size;
/* The bits are stored in the memory in the following way:
marray id | 0 | 1 | 2 | 3 |
bit id |127 .. 96|95 .. 64|63 .. 32|31 .. 0|
*/
// enable reference to individual bit
struct reference {
reference &operator=(bool x) {
if (x) {
Ref |= RefBit;
} else {
Ref &= ~RefBit;
}
return *this;
}
reference &operator=(const reference &x) {
operator=((bool)x);
return *this;
}
bool operator~() const { return !(Ref & RefBit); }
operator bool() const { return Ref & RefBit; }
reference &flip() {
operator=(!(bool)*this);
return *this;
}

reference(group_mask &gmask, size_t pos)
: Ref(gmask.Bits[marray_size - (pos / word_size) - 1]) {
RefBit = 1 << pos % word_size;
}

private:
// Reference to the word containing the bit
WordType &Ref;
// Bit mask where only referenced bit is set
WordType RefBit;
};

bool operator[](id<1> id) const {
return Bits[marray_size - id.get(0) / word_size - 1] &
(1 << (id.get(0) % word_size));
}
reference operator[](id<1> id) { return {*this, id.get(0)}; }
bool test(id<1> id) const { return operator[](id); }
bool all() const { return !(~(Bits[0] & Bits[1] & Bits[2] & Bits[3])); }
bool any() const { return Bits[0] | Bits[1] | Bits[2] | Bits[3]; }
bool none() const { return !any(); }
uint32_t count() const {
unsigned int count = 0;
for (auto word : Bits) {
while (word) {
word &= (word - 1);
count++;
}
}
return count;
}
uint32_t size() const { return max_bits; }
id<1> find_low() const {
size_t i = 0;
while (i < size() && !operator[](i))
i++;
return {i};
}
id<1> find_high() const {
size_t i = size() - 1;
while (i > 0 && !operator[](i))
i--;
return {operator[](i) ? i : size()};
}

template <typename T = marray<WordType, marray_size>>
void insert_bits(const T &bits, id<1> pos = 0) {
group_mask tmp(bits);
if (pos.get(0) > 0) {
operator<<=(max_bits - pos.get(0));
operator>>=(max_bits - pos.get(0));
tmp <<= pos.get(0);
} else {
reset();
}
Bits |= tmp.Bits;
}

template <typename T = marray<WordType, marray_size>>
T extract_bits(id<1> pos = 0) {
group_mask Tmp = *this;
Tmp <<= pos.get(0);
return Tmp.Bits;
}

void set() { Bits = ~(WordType{0}); }
void set(id<1> id, bool value = true) { operator[](id) = value; }
void reset() { Bits = WordType{0}; }
void reset(id<1> id) { operator[](id) = 0; }
void reset_low() { reset(find_low()); }
void reset_high() { reset(find_high()); }
void flip() { Bits = ~Bits; }
void flip(id<1> id) { operator[](id).flip(); }

bool operator==(const group_mask &rhs) const {
bool Res = true;
for (size_t i = 0; i < marray_size; i++)
Res &= Bits[i] == rhs.Bits[i];
return Res;
}
bool operator!=(const group_mask &rhs) const { return !(*this == rhs); }

group_mask &operator&=(const group_mask &rhs) {
Bits &= rhs.Bits;
return *this;
}
group_mask &operator|=(const group_mask &rhs) {
Bits |= rhs.Bits;
return *this;
}

group_mask &operator^=(const group_mask &rhs) {
Bits ^= rhs.Bits;
return *this;
}

group_mask &operator<<=(size_t pos) {
if (pos > 0) {
marray<WordType, marray_size> Res{0};
size_t word_shift = pos / word_size;
size_t bit_shift = pos % word_size;
WordType extra_bits = 0;
for (int i = marray_size - 1; i >= 0; i--) {
Res[i - word_shift] = (Bits[i] << bit_shift) + extra_bits;
extra_bits = Bits[i] >> (word_size - bit_shift);
}
Bits = Res;
}
return *this;
}

group_mask &operator>>=(size_t pos) {
if (pos > 0) {
marray<WordType, marray_size> Res{0};
size_t word_shift = pos / word_size;
size_t bit_shift = pos % word_size;
WordType extra_bits = 0;
for (size_t i = 0; i < marray_size; i++) {
Res[i + word_shift] = (Bits[i] >> bit_shift) + extra_bits;
extra_bits = Bits[i] << (word_size - bit_shift);
}
Bits = Res;
}
return *this;
}

group_mask operator~() const {
auto Tmp = *this;
Tmp.flip();
return Tmp;
}
group_mask operator<<(size_t pos) const {
auto Tmp = *this;
Tmp <<= pos;
return Tmp;
}
group_mask operator>>(size_t pos) const {
auto Tmp = *this;
Tmp >>= pos;
return Tmp;
}

group_mask(const group_mask &rhs) : Bits(rhs.Bits) {}
template <typename Group>
friend group_mask group_ballot(Group g, bool predicate);

group_mask(const marray<WordType, marray_size> &rhs) : Bits(rhs) {}

group_mask operator&(const group_mask &rhs) {
auto Res = *this;
Res &= rhs;
return Res;
}
group_mask operator|(const group_mask &rhs) {
auto Res = *this;
Res |= rhs;
return Res;
}
group_mask operator^(const group_mask &rhs) {
auto Res = *this;
Res ^= rhs;
return Res;
}

private:
marray<WordType, marray_size> Bits;
};
/// Builds a group_mask from a ballot of \p predicate over the work-items of
/// group \p g. Only the type of \p g is used, to select the execution scope.
///
/// On the host this always throws sycl exception with
/// errc::feature_not_supported.
template <typename Group> group_mask group_ballot(Group g, bool predicate) {
  (void)g;
#ifdef __SYCL_DEVICE_ONLY__
  using MaskWords = marray<group_mask::WordType, group_mask::marray_size>;
  auto Ballot = __spirv_GroupNonUniformBallot(
      detail::spirv::group_scope<Group>::value, predicate);
  // group_mask keeps bit ids 31..0 in the last marray element, so the four
  // words of the builtin's result are passed in reverse order.
  return MaskWords{Ballot[3], Ballot[2], Ballot[1], Ballot[0]};
#else
  (void)predicate;
  throw exception{errc::feature_not_supported,
                  "Group mask is not supported on host device"};
#endif
}
} // namespace oneapi
} // namespace ext
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
Loading