-
Notifications
You must be signed in to change notification settings - Fork 802
[SYCL] Implement sub-group mask extension #4481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
880ae2a
[SYCL] Implement GroupMask extension
vladimirlaz 8b20c3a
Add referencing individual build
vladimirlaz 95e56e9
Initial implementation of all methods
vladimirlaz a26748f
Remove dependency on lib
vladimirlaz 3c578c6
Bugfix
vladimirlaz 24e5a95
Fix internal layout
vladimirlaz db1f5b8
Update tests
vladimirlaz 4442ea4
Apply review comments
vladimirlaz 6cf5820
Merge branch 'group_mask' of github.com:vladimirlaz/llvm into group_mask
vladimirlaz baaeb39
Fix clang-format
vladimirlaz 585fa1b
Apply review comments
vladimirlaz d855e48
Rename feature to sub-group mask
vladimirlaz a184e29
Apply review comments
vladimirlaz e39e8e1
Fix clang format
vladimirlaz fdd0024
Apply review comments
vladimirlaz 1bd4fc3
Apply review comment
vladimirlaz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,236 @@ | ||
| //==----------------- group_mask.hpp --- SYCL group mask -------------------==// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| #pragma once | ||
|
|
||
| #include <CL/__spirv/spirv_ops.hpp> | ||
| #include <CL/__spirv/spirv_vars.hpp> | ||
| #include <CL/sycl/exception.hpp> | ||
| #include <CL/sycl/id.hpp> | ||
| #include <CL/sycl/marray.hpp> | ||
|
|
||
| __SYCL_INLINE_NAMESPACE(cl) { | ||
| namespace sycl { | ||
| namespace ext { | ||
| namespace oneapi { | ||
|
|
||
| struct group_mask { | ||
| using WordType = uint32_t; | ||
| static constexpr size_t max_bits = 128 /* implementation-defined */; | ||
| static constexpr size_t word_size = sizeof(WordType) * CHAR_BIT; | ||
| /* Bitmask is packed in marray of uint32_t elements. This value represents | ||
| * legth of marray. */ | ||
| static constexpr size_t marray_size = max_bits / word_size; | ||
| /* The bits are stored in the memory in the following way: | ||
| marray id | 0 | 1 | 2 | 3 | | ||
| bit id |127 .. 96|95 .. 64|63 .. 32|31 .. 0| | ||
| */ | ||
| // enable reference to individual bit | ||
| struct reference { | ||
| reference &operator=(bool x) { | ||
| if (x) { | ||
| Ref |= RefBit; | ||
| } else { | ||
| Ref &= ~RefBit; | ||
| } | ||
| return *this; | ||
| } | ||
| reference &operator=(const reference &x) { | ||
| operator=((bool)x); | ||
| return *this; | ||
| } | ||
| bool operator~() const { return !(Ref & RefBit); } | ||
| operator bool() const { return Ref & RefBit; } | ||
| reference &flip() { | ||
| operator=(!(bool)*this); | ||
| return *this; | ||
| } | ||
|
|
||
| reference(group_mask &gmask, size_t pos) | ||
| : Ref(gmask.Bits[marray_size - (pos / word_size) - 1]) { | ||
| RefBit = 1 << pos % word_size; | ||
| } | ||
|
|
||
| private: | ||
| // Reference to the word containing the bit | ||
| WordType &Ref; | ||
| // Bit mask where only referenced bit is set | ||
| WordType RefBit; | ||
| }; | ||
|
|
||
| bool operator[](id<1> id) const { | ||
| return Bits[marray_size - id.get(0) / word_size - 1] & | ||
| (1 << (id.get(0) % word_size)); | ||
| } | ||
| reference operator[](id<1> id) { return {*this, id.get(0)}; } | ||
| bool test(id<1> id) const { return operator[](id); } | ||
| bool all() const { return !(~(Bits[0] & Bits[1] & Bits[2] & Bits[3])); } | ||
| bool any() const { return Bits[0] | Bits[1] | Bits[2] | Bits[3]; } | ||
| bool none() const { return !any(); } | ||
| uint32_t count() const { | ||
| unsigned int count = 0; | ||
| for (auto word : Bits) { | ||
| while (word) { | ||
| word &= (word - 1); | ||
| count++; | ||
| } | ||
| } | ||
| return count; | ||
| } | ||
| uint32_t size() const { return max_bits; } | ||
| id<1> find_low() const { | ||
| size_t i = 0; | ||
| while (i < size() && !operator[](i)) | ||
| i++; | ||
| return {i}; | ||
| } | ||
| id<1> find_high() const { | ||
| size_t i = size() - 1; | ||
| while (i > 0 && !operator[](i)) | ||
| i--; | ||
| return {operator[](i) ? i : size()}; | ||
| } | ||
|
|
||
| template <typename T = marray<WordType, marray_size>> | ||
| void insert_bits(const T &bits, id<1> pos = 0) { | ||
| group_mask tmp(bits); | ||
| if (pos.get(0) > 0) { | ||
| operator<<=(max_bits - pos.get(0)); | ||
| operator>>=(max_bits - pos.get(0)); | ||
| tmp <<= pos.get(0); | ||
| } else { | ||
| reset(); | ||
| } | ||
| Bits |= tmp.Bits; | ||
| } | ||
|
|
||
| template <typename T = marray<WordType, marray_size>> | ||
| T extract_bits(id<1> pos = 0) { | ||
| group_mask Tmp = *this; | ||
| Tmp <<= pos.get(0); | ||
| return Tmp.Bits; | ||
| } | ||
|
|
||
| void set() { Bits = ~(WordType{0}); } | ||
| void set(id<1> id, bool value = true) { operator[](id) = value; } | ||
| void reset() { Bits = WordType{0}; } | ||
| void reset(id<1> id) { operator[](id) = 0; } | ||
| void reset_low() { reset(find_low()); } | ||
| void reset_high() { reset(find_high()); } | ||
| void flip() { Bits = ~Bits; } | ||
| void flip(id<1> id) { operator[](id).flip(); } | ||
|
|
||
| bool operator==(const group_mask &rhs) const { | ||
| bool Res = true; | ||
| for (size_t i = 0; i < marray_size; i++) | ||
| Res &= Bits[i] == rhs.Bits[i]; | ||
| return Res; | ||
| } | ||
| bool operator!=(const group_mask &rhs) const { return !(*this == rhs); } | ||
|
|
||
| group_mask &operator&=(const group_mask &rhs) { | ||
| Bits &= rhs.Bits; | ||
| return *this; | ||
| } | ||
| group_mask &operator|=(const group_mask &rhs) { | ||
| Bits |= rhs.Bits; | ||
| return *this; | ||
| } | ||
|
|
||
| group_mask &operator^=(const group_mask &rhs) { | ||
| Bits ^= rhs.Bits; | ||
| return *this; | ||
| } | ||
|
|
||
| group_mask &operator<<=(size_t pos) { | ||
| if (pos > 0) { | ||
| marray<WordType, marray_size> Res{0}; | ||
| size_t word_shift = pos / word_size; | ||
| size_t bit_shift = pos % word_size; | ||
| WordType extra_bits = 0; | ||
| for (int i = marray_size - 1; i >= 0; i--) { | ||
| Res[i - word_shift] = (Bits[i] << bit_shift) + extra_bits; | ||
| extra_bits = Bits[i] >> (word_size - bit_shift); | ||
| } | ||
| Bits = Res; | ||
| } | ||
| return *this; | ||
| } | ||
|
|
||
| group_mask &operator>>=(size_t pos) { | ||
| if (pos > 0) { | ||
| marray<WordType, marray_size> Res{0}; | ||
| size_t word_shift = pos / word_size; | ||
| size_t bit_shift = pos % word_size; | ||
| WordType extra_bits = 0; | ||
| for (size_t i = 0; i < marray_size; i++) { | ||
| Res[i + word_shift] = (Bits[i] >> bit_shift) + extra_bits; | ||
| extra_bits = Bits[i] << (word_size - bit_shift); | ||
| } | ||
| Bits = Res; | ||
| } | ||
| return *this; | ||
| } | ||
|
|
||
| group_mask operator~() const { | ||
| auto Tmp = *this; | ||
| Tmp.flip(); | ||
| return Tmp; | ||
| } | ||
| group_mask operator<<(size_t pos) const { | ||
| auto Tmp = *this; | ||
| Tmp <<= pos; | ||
| return Tmp; | ||
| } | ||
| group_mask operator>>(size_t pos) const { | ||
| auto Tmp = *this; | ||
| Tmp >>= pos; | ||
| return Tmp; | ||
| } | ||
|
|
||
| group_mask(const group_mask &rhs) : Bits(rhs.Bits) {} | ||
| template <typename Group> | ||
| friend group_mask group_ballot(Group g, bool predicate); | ||
|
|
||
| group_mask(const marray<WordType, marray_size> &rhs) : Bits(rhs) {} | ||
|
|
||
| group_mask operator&(const group_mask &rhs) { | ||
| auto Res = *this; | ||
| Res &= rhs; | ||
| return Res; | ||
| } | ||
| group_mask operator|(const group_mask &rhs) { | ||
| auto Res = *this; | ||
| Res |= rhs; | ||
| return Res; | ||
| } | ||
| group_mask operator^(const group_mask &rhs) { | ||
| auto Res = *this; | ||
| Res ^= rhs; | ||
| return Res; | ||
| } | ||
|
|
||
| private: | ||
| marray<WordType, marray_size> Bits; | ||
| }; | ||
| template <typename Group> group_mask group_ballot(Group g, bool predicate) { | ||
| (void)g; | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| auto res = __spirv_GroupNonUniformBallot( | ||
| detail::spirv::group_scope<Group>::value, predicate); | ||
| return marray<group_mask::WordType, group_mask::marray_size>{res[3], res[2], | ||
| res[1], res[0]}; | ||
| #else | ||
| (void)predicate; | ||
| throw exception{errc::feature_not_supported, | ||
| "Group mask is not supported on host device"}; | ||
| #endif | ||
| } | ||
| } // namespace oneapi | ||
| } // namespace ext | ||
| } // namespace sycl | ||
| } // __SYCL_INLINE_NAMESPACE(cl) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.