-
Notifications
You must be signed in to change notification settings - Fork 802
[SYCL] Add bfloat16 utils based on libdevice bfloat16 support. #7503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
8817b7c
c4c7da0
972ff48
3cd5bf2
44d0d6f
b514825
7e36261
4b313ae
418ff33
4032477
94b47e4
5fe308b
098ad9c
2afd5ed
46a2622
a7572c9
c3d9bb3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| //==-------------------- imf_bf16.hpp - bfloat16 utils ---------------------==// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // C++ APIs for bfloat16 util functions. | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #pragma once | ||
| #include <sycl/ext/oneapi/bfloat16.hpp> | ||
| #include <type_traits> | ||
|
|
||
| extern "C" { | ||
| float __imf_bfloat162float(uint16_t); | ||
| uint16_t __imf_float2bfloat16(float); | ||
| uint16_t __imf_float2bfloat16_rd(float); | ||
| uint16_t __imf_float2bfloat16_rn(float); | ||
| uint16_t __imf_float2bfloat16_ru(float); | ||
| uint16_t __imf_float2bfloat16_rz(float); | ||
| }; | ||
|
|
||
| namespace sycl { | ||
| __SYCL_INLINE_VER_NAMESPACE(_V1) { | ||
| namespace ext { | ||
| namespace intel { | ||
| namespace math { | ||
|
|
||
| // Need to ensure that sycl bfloat16 defined in bfloat16.hpp is compatible | ||
| // with uint16_t in layout. | ||
| #if __cplusplus >= 201703L | ||
| static_assert(sizeof(sycl::ext::oneapi::bfloat16) == sizeof(uint16_t), | ||
| "sycl bfloat16 is not compatible with uint16_t."); | ||
|
|
||
| float bfloat162float(sycl::ext::oneapi::bfloat16 b) { | ||
| return __imf_bfloat162float(__builtin_bit_cast(uint16_t, b)); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 float2bfloat16(float f) { | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be recommended to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, @JackAKirk |
||
| __imf_float2bfloat16(f)); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 float2bfloat16_rd(float f) { | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| __imf_float2bfloat16_rd(f)); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 float2bfloat16_rn(float f) { | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| __imf_float2bfloat16_rn(f)); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 float2bfloat16_ru(float f) { | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| __imf_float2bfloat16_ru(f)); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 float2bfloat16_rz(float f) { | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| __imf_float2bfloat16_rz(f)); | ||
| } | ||
|
|
||
| bool hisnan(sycl::ext::oneapi::bfloat16 b) { | ||
| return sycl::isnan(bfloat162float(b)); | ||
| } | ||
|
|
||
| bool hisinf(sycl::ext::oneapi::bfloat16 b) { | ||
| return sycl::isinf(bfloat162float(b)); | ||
| } | ||
|
|
||
| bool heq(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return false; | ||
| return __builtin_bit_cast(uint16_t, b1) == __builtin_bit_cast(uint16_t, b2); | ||
| } | ||
|
|
||
| bool hequ(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b1)) | ||
| return true; | ||
| return __builtin_bit_cast(uint16_t, b1) == __builtin_bit_cast(uint16_t, b2); | ||
| } | ||
|
|
||
| bool hne(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return false; | ||
| return __builtin_bit_cast(uint16_t, b1) != __builtin_bit_cast(uint16_t, b2); | ||
| } | ||
|
|
||
| bool hneu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return true; | ||
| return __builtin_bit_cast(uint16_t, b1) != __builtin_bit_cast(uint16_t, b2); | ||
| } | ||
|
|
||
| bool hge(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return false; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 >= bf2); | ||
| } | ||
|
|
||
| bool hgeu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return true; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 >= bf2); | ||
| } | ||
|
|
||
| bool hgt(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return false; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 > bf2); | ||
| } | ||
|
|
||
| bool hgtu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return true; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 > bf2); | ||
| } | ||
|
|
||
| bool hle(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return false; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 <= bf2); | ||
| } | ||
|
|
||
| bool hleu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return true; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 <= bf2); | ||
| } | ||
|
|
||
| bool hlt(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return false; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 < bf2); | ||
| } | ||
|
|
||
| bool hltu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return true; | ||
| float bf1 = bfloat162float(b1); | ||
| float bf2 = bfloat162float(b2); | ||
| return (bf1 < bf2); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 hmax(sycl::ext::oneapi::bfloat16 b1, | ||
| sycl::ext::oneapi::bfloat16 b2) { | ||
| uint16_t canonical_nan = 0x7FC0; | ||
| uint16_t b1a = __builtin_bit_cast(uint16_t, b1); | ||
| uint16_t b2a = __builtin_bit_cast(uint16_t, b2); | ||
| if (hisnan(b1) && hisnan(b2)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); | ||
| if (hisnan(b1)) | ||
| return b2; | ||
| else if (hisnan(b2)) | ||
| return b1; | ||
| else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| static_cast<uint16_t>(0x0)); | ||
| else { | ||
| return (hgt(b1, b2) ? b1 : b2); | ||
| } | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 hmax_nan(sycl::ext::oneapi::bfloat16 b1, | ||
| sycl::ext::oneapi::bfloat16 b2) { | ||
| uint16_t canonical_nan = 0x7FC0; | ||
| uint16_t b1a = __builtin_bit_cast(uint16_t, b1); | ||
| uint16_t b2a = __builtin_bit_cast(uint16_t, b2); | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); | ||
| else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| static_cast<uint16_t>(0x0)); | ||
| else | ||
| return (hgt(b1, b2) ? b1 : b2); | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 hmin(sycl::ext::oneapi::bfloat16 b1, | ||
| sycl::ext::oneapi::bfloat16 b2) { | ||
| uint16_t canonical_nan = 0x7FC0; | ||
| uint16_t b1a = __builtin_bit_cast(uint16_t, b1); | ||
| uint16_t b2a = __builtin_bit_cast(uint16_t, b2); | ||
| if (hisnan(b1) && hisnan(b2)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); | ||
| if (hisnan(b1)) | ||
| return b2; | ||
| else if (hisnan(b2)) | ||
| return b1; | ||
| else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| static_cast<uint16_t>(0x8000)); | ||
| else { | ||
| return (hlt(b1, b2) ? b1 : b2); | ||
| } | ||
| } | ||
|
|
||
| sycl::ext::oneapi::bfloat16 hmin_nan(sycl::ext::oneapi::bfloat16 b1, | ||
| sycl::ext::oneapi::bfloat16 b2) { | ||
| uint16_t canonical_nan = 0x7FC0; | ||
| uint16_t b1a = __builtin_bit_cast(uint16_t, b1); | ||
| uint16_t b2a = __builtin_bit_cast(uint16_t, b2); | ||
| if (hisnan(b1) || hisnan(b2)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); | ||
| else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) | ||
| return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, | ||
| static_cast<uint16_t>(0x8000)); | ||
| else | ||
| return (hlt(b1, b2) ? b1 : b2); | ||
| } | ||
|
|
||
| #endif | ||
| } // namespace math | ||
| } // namespace intel | ||
| } // namespace ext | ||
| } // __SYCL_INLINE_VER_NAMESPACE(_V1) | ||
| } // namespace sycl | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to check this - C++17 is the minimal supported version.