-
Notifications
You must be signed in to change notification settings - Fork 808
[Matrix][SYCL] Add bfloat16 support for joint_matrix #5566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
aa5cf45
c310976
7b8b37e
37d75f5
a6a7c82
6bc53d5
b5c1194
a42e3aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| #include <CL/__spirv/spirv_ops.hpp> | ||
| #include <CL/sycl/detail/defines_elementary.hpp> | ||
| #include <CL/sycl/feature_test.hpp> | ||
| #include <sycl/ext/intel/experimental/bfloat16.hpp> | ||
|
|
||
| __SYCL_INLINE_NAMESPACE(cl) { | ||
| namespace sycl { | ||
|
|
@@ -737,6 +738,165 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> { | |
| } | ||
| }; | ||
|
|
||
| template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group> | ||
| class wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, NumCols, | ||
|
||
| Layout, Group> { | ||
| joint_matrix<sycl::ext::intel::experimental::bfloat16, NumRows, NumCols, | ||
| Layout, Group> &M; | ||
| std::size_t idx; | ||
|
|
||
| public: | ||
| wi_element(joint_matrix<sycl::ext::intel::experimental::bfloat16, NumRows, | ||
| NumCols, Layout, Group> &Mat, | ||
| std::size_t i) | ||
| : M(Mat), idx(i) {} | ||
| operator sycl::ext::intel::experimental::bfloat16() { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| return __spirv_VectorExtractDynamic(M.spvm, idx); | ||
| #else | ||
| throw runtime_error("joint matrix is not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif // __SYCL_DEVICE_ONLY__ | ||
| } | ||
|
|
||
| explicit operator bool() { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic( | ||
| M.spvm, idx))) >= std::numeric_limits<float>::epsilon(); | ||
| #else | ||
| throw runtime_error("joint matrix is not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif // __SYCL_DEVICE_ONLY__ | ||
| } | ||
|
|
||
| wi_element &operator=(const sycl::ext::intel::experimental::bfloat16 &rhs) { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); | ||
| return *this; | ||
| #else | ||
| (void)rhs; | ||
| throw runtime_error("joint matrix is not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif // __SYCL_DEVICE_ONLY__ | ||
| } | ||
|
|
||
| wi_element & | ||
| operator=(const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, | ||
| NumCols, Layout, Group> &rhs) { | ||
| #ifdef __SYCL_DEVICE_ONLY__ | ||
| M.spvm = __spirv_VectorInsertDynamic( | ||
| M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); | ||
| return *this; | ||
| #else | ||
| (void)rhs; | ||
| throw runtime_error("joint matrix is not supported on host device.", | ||
| PI_INVALID_DEVICE); | ||
| #endif // __SYCL_DEVICE_ONLY__ | ||
| } | ||
|
|
||
| #if __SYCL_DEVICE_ONLY__ | ||
| #define OP(opassign, op) \ | ||
| wi_element &operator opassign( \ | ||
| const sycl::ext::intel::experimental::bfloat16 &rhs) { \ | ||
| M.spvm = __spirv_VectorInsertDynamic( \ | ||
| M.spvm, \ | ||
| static_cast<float>(__spirv_VectorExtractDynamic(M.spvm, idx) \ | ||
| op static_cast<float>(rhs))), \ | ||
| idx); \ | ||
| return *this; \ | ||
| } | ||
| #else // __SYCL_DEVICE_ONLY__ | ||
| #define OP(opassign, op) \ | ||
| wi_element &operator opassign( \ | ||
| const sycl::ext::intel::experimental::bfloat16 &rhs) { \ | ||
| (void)rhs; \ | ||
| throw runtime_error("joint matrix is not supported on host device.", \ | ||
| PI_INVALID_DEVICE); \ | ||
| } | ||
| #endif // __SYCL_DEVICE_ONLY__ | ||
| OP(+=, +) | ||
| OP(-=, -) | ||
| OP(*=, *) | ||
| OP(/=, /) | ||
| #undef OP | ||
|
|
||
| #if __SYCL_DEVICE_ONLY__ | ||
| #define OP(type, op) \ | ||
| friend type operator op( \ | ||
| const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \ | ||
| NumCols, Layout, Group> &lhs, \ | ||
| const sycl::ext::intel::experimental::bfloat16 &rhs) { \ | ||
| return static_cast<float>(__spirv_VectorExtractDynamic( \ | ||
| lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs); \ | ||
| } \ | ||
| friend type operator op( \ | ||
| const sycl::ext::intel::experimental::bfloat16 &lhs, \ | ||
| const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \ | ||
| NumCols, Layout, Group> &rhs) { \ | ||
| return static_cast<float>(__spirv_VectorExtractDynamic( \ | ||
| rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs); \ | ||
| } | ||
| OP(sycl::ext::intel::experimental::bfloat16, +) | ||
| OP(sycl::ext::intel::experimental::bfloat16, -) | ||
| OP(sycl::ext::intel::experimental::bfloat16, *) | ||
| OP(sycl::ext::intel::experimental::bfloat16, /) | ||
| #undef OP | ||
| #define OP(type, op) \ | ||
| friend type operator op( \ | ||
| const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \ | ||
| NumCols, Layout, Group> &lhs, \ | ||
| const sycl::ext::intel::experimental::bfloat16 &rhs) { \ | ||
| return type{static_cast<float>(__spirv_VectorExtractDynamic( \ | ||
| lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \ | ||
| } \ | ||
| friend type operator op( \ | ||
| const sycl::ext::intel::experimental::bfloat16 &lhs, \ | ||
| const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \ | ||
| NumCols, Layout, Group> &rhs) { \ | ||
| return type{static_cast<float>(__spirv_VectorExtractDynamic( \ | ||
| rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \ | ||
| } | ||
| OP(bool, ==) | ||
| OP(bool, !=) | ||
| OP(bool, <) | ||
| OP(bool, >) | ||
| OP(bool, <=) | ||
| OP(bool, >=) | ||
| #undef OP | ||
| #else // __SYCL_DEVICE_ONLY__ | ||
| #define OP(type, op) \ | ||
| friend type operator op( \ | ||
| const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \ | ||
| NumCols, Layout, Group> &lhs, \ | ||
| const sycl::ext::intel::experimental::bfloat16 &rhs) { \ | ||
| (void)lhs; \ | ||
| (void)rhs; \ | ||
|
||
| throw runtime_error("joint matrix is not supported on host device.", \ | ||
| PI_INVALID_DEVICE); \ | ||
| } \ | ||
| friend type operator op( \ | ||
| const sycl::ext::intel::experimental::bfloat16 &lhs, \ | ||
| const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \ | ||
| NumCols, Layout, Group> &rhs) { \ | ||
| (void)lhs; \ | ||
| (void)rhs; \ | ||
| throw runtime_error("joint matrix is not supported on host device.", \ | ||
| PI_INVALID_DEVICE); \ | ||
| } | ||
| OP(sycl::ext::intel::experimental::bfloat16, +) | ||
| OP(sycl::ext::intel::experimental::bfloat16, -) | ||
| OP(sycl::ext::intel::experimental::bfloat16, *) | ||
| OP(sycl::ext::intel::experimental::bfloat16, /) | ||
| OP(bool, ==) | ||
| OP(bool, !=) | ||
| OP(bool, <) | ||
| OP(bool, >) | ||
| OP(bool, <=) | ||
| OP(bool, >=) | ||
| #undef OP | ||
| #endif // __SYCL_DEVICE_ONLY__ | ||
| }; | ||
|
|
||
| template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout, | ||
| typename Group> | ||
| class wi_slice { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Buildfail will happen since
class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16
and sycl_detail::uses_aspects is unsupported code in g++ which is used for intel/llvm's build