Merged

Commits (99)
6014cef
[SYCL] Move bfloat support from experimental to supported.
rdeodhar Aug 3, 2022
bdd88e5
Corrections to tests.
rdeodhar Aug 3, 2022
73ed541
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Aug 24, 2022
0fe1884
Moved another file out of experimental space.
rdeodhar Aug 24, 2022
feb9d5f
Responses to review comments.
rdeodhar Aug 25, 2022
129f53f
Removed unneeded sycl::half conversion and updated doc.
rdeodhar Aug 26, 2022
2115f09
Added conversion from sycl::half to bfloat16.
rdeodhar Aug 29, 2022
3c2eb80
Cleanup of documentation.
rdeodhar Aug 31, 2022
74aa175
Hooked up bfloat16 aspect within OpenCL plugin.
rdeodhar Sep 2, 2022
bd05711
Support for bfloat16 aspect, and native or fallback support.
rdeodhar Sep 8, 2022
f8e894c
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 8, 2022
2ad68f6
Formatting changes.
rdeodhar Sep 8, 2022
4b78c03
Formatting changes.
rdeodhar Sep 8, 2022
0fce16d
Update to documentation.
rdeodhar Sep 8, 2022
4bcb383
Deprecate bfloat16 aspect.
rdeodhar Sep 8, 2022
35308f8
Fixes for ESIMD.
rdeodhar Sep 9, 2022
fa045e2
Reinstated to_float and from_float, used by NVidia, updated doc.
rdeodhar Sep 9, 2022
3322d6a
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 12, 2022
b12fd94
Update to doc.
rdeodhar Sep 12, 2022
87b0f09
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 14, 2022
f217eb4
Corrections to headers.
rdeodhar Sep 14, 2022
a908b11
Formatting change.
rdeodhar Sep 14, 2022
aab4c78
bfloat16 class supports all sm_xx devices.
Sep 15, 2022
a2568ba
Merge pull request #1 from JackAKirk/bfloat16-cuda-allarch
rdeodhar Sep 15, 2022
4d7a22b
Changes to keep bfloat math functions experimental for now.
rdeodhar Sep 16, 2022
38e5ad4
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 16, 2022
b9accad
Cleanup of bfloat16_math extension.
rdeodhar Sep 16, 2022
ca7880a
Document updates and minor changes.
rdeodhar Sep 19, 2022
dc3b2b5
Fixes for long lines in doc, a different way to check for NaN.
rdeodhar Sep 19, 2022
c955d36
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 20, 2022
1aa6ad3
Broke long lines into multiple lines.
rdeodhar Sep 20, 2022
ff04ce1
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 21, 2022
802f502
Changed library order on Windows.
rdeodhar Sep 21, 2022
8d7f46a
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 22, 2022
190f2a3
Fix for AOT compilation and correction to new headers.
rdeodhar Sep 22, 2022
84c50f3
Noted AOT limitation in doc.
rdeodhar Sep 23, 2022
df058ba
Adjustment for AOT compilation.
rdeodhar Sep 24, 2022
fed4d1d
Fixes for AOT builds.
rdeodhar Sep 26, 2022
28259d0
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 26, 2022
c11115b
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 26, 2022
6b05a2a
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 27, 2022
a82d73a
Fixes for AOT multiple devices.
rdeodhar Sep 27, 2022
3fc8885
Updated documentation.
rdeodhar Sep 27, 2022
1ec6838
Added back missing Status section in documentation.
rdeodhar Sep 27, 2022
105094b
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 27, 2022
432e775
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Sep 29, 2022
c135643
Added tests, corrected aspect check.
rdeodhar Oct 1, 2022
4eca414
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 1, 2022
8876ac8
Added missing newlines.
rdeodhar Oct 3, 2022
f0f2727
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 3, 2022
17673bf
Corrections to tests and macros, added host code emulation.
rdeodhar Oct 4, 2022
1094b8c
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 4, 2022
8d40228
Small corrections.
rdeodhar Oct 4, 2022
c5a85cf
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 4, 2022
cf8f6e0
Fixes for AOT.
rdeodhar Oct 4, 2022
5e50646
Formatting change.
rdeodhar Oct 4, 2022
45d3e70
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 5, 2022
a7be718
Renamed the bfloat aspects.
rdeodhar Oct 5, 2022
cac1c18
Fixes for generic JIT compilation.
rdeodhar Oct 6, 2022
208c09a
Changes for AOT sycl-targets switch.
rdeodhar Oct 6, 2022
46f406d
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 6, 2022
6830857
Corrected aspects queries.
rdeodhar Oct 6, 2022
46e5278
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 6, 2022
10fc9a3
Change in the way fallback/native libs are selected.
rdeodhar Oct 8, 2022
6195545
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 8, 2022
437e34a
Changed type of string.
rdeodhar Oct 10, 2022
09dc4c5
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 12, 2022
386353e
Replaced bfloat16 aspect with bfloat16_math_functions aspect.
rdeodhar Oct 12, 2022
0f93586
Improved devices check in clang driver.
rdeodhar Oct 13, 2022
48f3cac
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 13, 2022
d33cb10
Enhanced test for improved bfloat16 target detection.
rdeodhar Oct 13, 2022
28992c2
Updated bfloat16 driver test for windows.
rdeodhar Oct 13, 2022
ec28c8b
Use STL for parsing devices.
rdeodhar Oct 13, 2022
b958fc7
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 24, 2022
ec70b20
Allow spir64 target to be JIT even when combined with AOT targets.
rdeodhar Oct 24, 2022
1b86012
Updated documentation.
rdeodhar Oct 24, 2022
3e1e681
Modifications for mixed JIT and AOT compilations, added tests.
rdeodhar Oct 25, 2022
8c633d3
Corrections to comments.
rdeodhar Oct 25, 2022
1a59e03
Update to documentation.
rdeodhar Oct 25, 2022
b2fd6cc
Updated doc.
rdeodhar Oct 25, 2022
fab2e54
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 26, 2022
35b8910
Adjustments to tests.
rdeodhar Oct 27, 2022
a05c872
Test cleanup.
rdeodhar Oct 27, 2022
ac5f603
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Oct 27, 2022
6d45ed1
Adjustments to more tests.
rdeodhar Oct 27, 2022
077d0fe
Change to tests to ensure AOT components are available.
rdeodhar Oct 28, 2022
2ff6a9d
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 7, 2022
d7c80ee
Adjustment to test for new bfloat16 header.
rdeodhar Nov 7, 2022
20d13df
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 8, 2022
cd1d0a2
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 15, 2022
4bf60b9
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 18, 2022
45c32f7
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 21, 2022
5de1bf7
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 22, 2022
6ec2bb9
Changes for indirect accesses.
rdeodhar Nov 22, 2022
49e9cd1
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 22, 2022
2065060
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 23, 2022
e24e57b
Fixed conflicts.
rdeodhar Nov 23, 2022
41098ab
Merge branch 'sycl' of https://github.com/intel/llvm into bfloat16
rdeodhar Nov 25, 2022
37b05f0
Correction to library list.
rdeodhar Nov 25, 2022
@@ -135,13 +135,11 @@ public:
bfloat16(const float &a);
bfloat16 &operator=(const float &a);

// Convert from bfloat16 to float
// Convert bfloat16 to floating-point types
operator float() const;
operator sycl::half() const;
Contributor:
I like this conversion to sycl::half. However, we should also add the opposite conversion from sycl::half to bfloat16:

bfloat16(const sycl::half &a);
bfloat16 &operator=(const sycl::half &a);

Do we also need conversion to / from double?

Contributor Author:
This PR is intended to move the current bfloat16 support out of experimental space. Any changes to the level of bfloat16 support can be done in future PRs.

Contributor Author:
On Intel platforms the bfloat16 to/from float conversion is done using the `__spirv_ConvertBF16ToFINTEL` operator. I suspect a double version of that does not exist.
Float-to-double conversion can then be done in the usual C++ way, efficiently in hardware. A direct software conversion from bfloat16 to double would involve more bit twiddling than the float conversion, where only trailing zero fraction bits need to be inserted.
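(For illustration, a minimal host-side sketch of the bit trick described above, with hypothetical helper names; on Intel devices the actual conversion goes through `__spirv_ConvertBF16ToFINTEL`:)

#include <cstdint>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 float32, so widening to
// float only appends 16 zero fraction bits.
inline float bf16_bits_to_float(uint16_t raw) {
  uint32_t bits = static_cast<uint32_t>(raw) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Truncating narrowing; real implementations typically round to nearest even.
inline uint16_t float_to_bf16_bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}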

Contributor Author:
The sycl::half class includes conversions to/from float. Those kick in when bfloat16 is used with sycl::half, so conversions between bfloat16 and sycl::half are not needed.
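(A hedged sketch of what that means in practice — mixed bfloat16/half arithmetic can go through float on both sides, so no direct conversion is required:)

bfloat16 b{2.0f};
sycl::half h{0.5f};
// Both operands convert to float; the arithmetic happens in float.
float r = static_cast<float>(b) * static_cast<float>(h);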

Contributor:
Are you saying that we should remove this conversion from bfloat16 to sycl::half?

Contributor Author:
Yes, it's not needed.

Contributor Author:
This item was revisited and it turns out that sycl::half <-> bfloat16 conversions are needed. They have been added.
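(Illustrative use of the added conversions — a hedged sketch based on the declarations in this diff:)

sycl::half h{1.5f};
bfloat16 b{h};       // bfloat16(const sycl::half &)
b = h;               // bfloat16 &operator=(const sycl::half &)
sycl::half h2 = b;   // operator sycl::half() const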

MrSidims (Contributor), Sep 20, 2022:
Sorry for joining the discussion late. Maybe it's a nitpick, but should we state that the half <-> bfloat16 conversion follows the IEEE 754 float <-> half conversion? In other words, what happens if a bfloat16 value overflows the half range? Also, are the last 3 fraction bits added stochastically, or are they guaranteed to be zero (or is that an implementation detail)?
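(For context, a hedged sketch of the overflow case being asked about, assuming the conversion goes through float with IEEE-754 round-to-nearest semantics:)

// bfloat16 can hold values up to ~3.39e38, far above half's max finite
// value of 65504, so an IEEE-754 narrowing would round these to infinity.
bfloat16 big{1.0e20f};
sycl::half h = big;  // would become +inf under a float -> half IEEE path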


// Get bfloat16 as uint16.
operator storage_t() const;

// Convert to bool type
// Convert bfloat16 to bool type
explicit operator bool();

friend bfloat16 operator-(bfloat16 &bf) { /* ... */ }
@@ -195,11 +193,11 @@ Table 1. Member functions of `bfloat16` class.
| `operator float() const;`
| Return `bfloat16` value converted to `float`.

| `operator storage_t() const;`
| Return `uint16_t` value, whose bits represent `bfloat16` value.
| `operator sycl::half() const;`
| Return `bfloat16` value converted to `sycl::half`.

| `explicit operator bool() { /* ... */ }`
| Convert `bfloat16` to `bool` type. Return `false` if the value equals to
| Convert `bfloat16` to `bool` type. Return `false` if the `value` equals to
zero, return `true` otherwise.

| `friend bfloat16 operator-(bfloat16 &bf) { /* ... */ }`
@@ -408,4 +406,5 @@ Compute absolute value of a `bfloat16`.
|3|2021-08-18|Alexey Sotkin |Remove `uint16_t` constructor
|4|2022-03-07|Aidan Belton and Jack Kirk |Switch from Intel vendor specific to oneapi
|5|2022-04-05|Jack Kirk | Added section for bfloat16 math builtins
|6|2022-08-03|Alexey Sotkin |Add `operator sycl::half()`
Contributor:
I think it should be your name here. 😄

Contributor Author:
It was carried over from that author, but I agree that more changes have been made since, so I'm changing the name.

|========================================
@@ -15,7 +15,6 @@ __SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace ext {
namespace oneapi {
namespace experimental {

class bfloat16 {
using storage_t = uint16_t;
@@ -165,7 +164,6 @@ class bfloat16 {
// for floating-point types.
};

} // namespace experimental
} // namespace oneapi
} // namespace ext

2 changes: 1 addition & 1 deletion sycl/include/sycl/ext/oneapi/experimental/builtins.hpp
@@ -15,7 +15,7 @@
#include <sycl/detail/type_traits.hpp>

#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>

// TODO Decide whether to mark functions with this attribute.
#define __NOEXC /*noexcept*/
80 changes: 37 additions & 43 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
@@ -10,7 +10,7 @@

#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/feature_test.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
@@ -458,18 +458,16 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
};

template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
Layout, Group> {
joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
Layout, Group> &M;
class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> {
joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> &M;
std::size_t idx;

public:
wi_element(joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows,
NumCols, Layout, Group> &Mat,
wi_element(joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout,
Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}
operator sycl::ext::oneapi::experimental::bfloat16() {
operator sycl::ext::oneapi::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
#else
@@ -488,7 +486,7 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
#endif // __SYCL_DEVICE_ONLY__
}

wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) {
wi_element &operator=(const sycl::ext::oneapi::bfloat16 &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
return *this;
@@ -499,9 +497,8 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
#endif // __SYCL_DEVICE_ONLY__
}

wi_element &
operator=(const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows,
NumCols, Layout, Group> &rhs) {
wi_element &operator=(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
NumCols, Layout, Group> &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
M.spvm = __spirv_VectorInsertDynamic(
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
@@ -515,16 +512,14 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,

#if __SYCL_DEVICE_ONLY__
#define OP(opassign, op) \
wi_element &operator opassign( \
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
return *this; \
}
#else // __SYCL_DEVICE_ONLY__
#define OP(opassign, op) \
wi_element &operator opassign( \
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
(void)rhs; \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
@@ -539,34 +534,34 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
#if __SYCL_DEVICE_ONLY__
#define OP(type, op) \
friend type operator op( \
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &lhs, \
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
} \
friend type operator op( \
const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &rhs) { \
const sycl::ext::oneapi::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
Group> &rhs) { \
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
}
OP(sycl::ext::oneapi::experimental::bfloat16, +)
OP(sycl::ext::oneapi::experimental::bfloat16, -)
OP(sycl::ext::oneapi::experimental::bfloat16, *)
OP(sycl::ext::oneapi::experimental::bfloat16, /)
OP(sycl::ext::oneapi::bfloat16, +)
OP(sycl::ext::oneapi::bfloat16, -)
OP(sycl::ext::oneapi::bfloat16, *)
OP(sycl::ext::oneapi::bfloat16, /)
#undef OP
#define OP(type, op) \
friend type operator op( \
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &lhs, \
const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
Group> &lhs, \
const sycl::ext::oneapi::bfloat16 &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
} \
friend type operator op( \
const sycl::ext::oneapi::experimental::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &rhs) { \
const sycl::ext::oneapi::bfloat16 &lhs, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
Group> &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
}
@@ -579,24 +574,23 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
#undef OP
#else // __SYCL_DEVICE_ONLY__
#define OP(type, op) \
friend type operator op( \
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &, \
const sycl::ext::oneapi::experimental::bfloat16 &) { \
friend type operator op(const wi_element<sycl::ext::oneapi::bfloat16, \
NumRows, NumCols, Layout, Group> &, \
const sycl::ext::oneapi::bfloat16 &) { \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
} \
friend type operator op( \
const sycl::ext::oneapi::experimental::bfloat16 &, \
const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &) { \
const sycl::ext::oneapi::bfloat16 &, \
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
Group> &) { \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
}
OP(sycl::ext::oneapi::experimental::bfloat16, +)
OP(sycl::ext::oneapi::experimental::bfloat16, -)
OP(sycl::ext::oneapi::experimental::bfloat16, *)
OP(sycl::ext::oneapi::experimental::bfloat16, /)
OP(sycl::ext::oneapi::bfloat16, +)
OP(sycl::ext::oneapi::bfloat16, -)
OP(sycl::ext::oneapi::bfloat16, *)
OP(sycl::ext::oneapi::bfloat16, /)
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
17 changes: 8 additions & 9 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
@@ -7,7 +7,7 @@
// ===--------------------------------------------------------------------=== //

#pragma once
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
@@ -219,8 +219,7 @@ struct joint_matrix_load_impl<
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
multi_ptr<T, Space> src, size_t stride) {
if constexpr (std::is_same<T, uint16_t>::value ||
std::is_same<
T, sycl::ext::oneapi::experimental::bfloat16>::value) {
std::is_same<T, sycl::ext::oneapi::bfloat16>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -585,8 +584,8 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
std::is_same<T1,
sycl::ext::oneapi::bfloat16>::value) {
__mma_bf16_m16n16k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
@@ -622,8 +621,8 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
std::is_same<T1,
sycl::ext::oneapi::bfloat16>::value) {
__mma_bf16_m8n32k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
@@ -645,8 +644,8 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
std::is_same<T1,
sycl::ext::oneapi::bfloat16>::value) {
__mma_bf16_m32n8k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
@@ -6,7 +6,7 @@

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::experimental::bfloat16;
using sycl::ext::oneapi::bfloat16;

constexpr int stride = 16;

4 changes: 2 additions & 2 deletions sycl/test/extensions/bfloat16.cpp
@@ -2,10 +2,10 @@

// UNSUPPORTED: cuda || hip_amd

#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/sycl.hpp>

using sycl::ext::oneapi::experimental::bfloat16;
using sycl::ext::oneapi::bfloat16;

SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);
SYCL_EXTERNAL void foo(long x, sycl::half y);
2 changes: 1 addition & 1 deletion sycl/test/matrix/matrix-bfloat16-test.cpp
@@ -4,7 +4,7 @@
#include <iostream>

using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
using bfloat16 = sycl::ext::oneapi::bfloat16;

static constexpr auto TILE_SZ = 16;
static constexpr auto TM = TILE_SZ - 1;