Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <CL/sycl/detail/defines_elementary.hpp>
#include <CL/sycl/feature_test.hpp>
#include <sycl/ext/intel/experimental/bfloat16.hpp>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A build failure will occur because

class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16

and the sycl_detail::uses_aspects attribute is not supported by g++, which is used to build intel/llvm.


__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
Expand Down Expand Up @@ -737,6 +738,165 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
}
};

template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
class wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, NumCols,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: an alias such as

using bfloat16 = sycl::ext::intel::experimental::bfloat16

could improve readability of this code.

Copy link
Contributor Author

@yubingex007-a11y yubingex007-a11y Feb 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't use "using" in a header file.

Copy link
Contributor

@keryell keryell Apr 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't use "using" in a header file.

What do you mean? I believe you can use one locally, as long as it is in a scope that is not visible to the end user.

Layout, Group> {
joint_matrix<sycl::ext::intel::experimental::bfloat16, NumRows, NumCols,
Layout, Group> &M;
std::size_t idx;

public:
wi_element(joint_matrix<sycl::ext::intel::experimental::bfloat16, NumRows,
NumCols, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}
/// Converts the referenced matrix element to a bfloat16 value.
/// When __SYCL_DEVICE_ONLY__ is not defined this always throws runtime_error.
operator sycl::ext::intel::experimental::bfloat16() {
#ifndef __SYCL_DEVICE_ONLY__
  // Host compilation: report that joint matrix is unsupported.
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#else
  // Device compilation: extract the element at position idx.
  return __spirv_VectorExtractDynamic(M.spvm, idx);
#endif // __SYCL_DEVICE_ONLY__
}

/// Truthiness test for the referenced element: true when the element's
/// magnitude, taken as a float, is at least float epsilon.
/// When __SYCL_DEVICE_ONLY__ is not defined this always throws runtime_error.
explicit operator bool() {
#ifdef __SYCL_DEVICE_ONLY__
  const float Magnitude = std::fabs(
      static_cast<float>(__spirv_VectorExtractDynamic(M.spvm, idx)));
  return Magnitude >= std::numeric_limits<float>::epsilon();
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

/// Stores \p rhs into the referenced matrix element and returns *this.
/// When __SYCL_DEVICE_ONLY__ is not defined this always throws runtime_error.
wi_element &operator=(const sycl::ext::intel::experimental::bfloat16 &rhs) {
#ifndef __SYCL_DEVICE_ONLY__
  (void)rhs; // Not used in host compilation.
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#else
  M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
  return *this;
#endif // __SYCL_DEVICE_ONLY__
}

/// Copies the element referenced by \p rhs into the element referenced by
/// this proxy (the proxy itself is not rebound) and returns *this.
/// When __SYCL_DEVICE_ONLY__ is not defined this always throws runtime_error.
wi_element &operator=(const wi_element &rhs) {
#ifdef __SYCL_DEVICE_ONLY__
  // Read the source element first, then write it into our own slot.
  const auto Source = __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx);
  M.spvm = __spirv_VectorInsertDynamic(M.spvm, Source, idx);
  return *this;
#else
  (void)rhs; // Not used in host compilation.
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// Compound assignment operators (+=, -=, *=, /=) with a bfloat16 right-hand
// side. OP(opassign, op) generates one member operator per invocation.
//
// Device compilation: the current element and rhs are each converted to
// float, combined with `op`, converted back to bfloat16 and written into the
// element at idx. The two operands are cast separately so the parentheses
// balance and the arithmetic matches the binary operators of this class.
//
// Host compilation: every operator throws runtime_error; the parameter is
// left unnamed because it is never read.
#ifdef __SYCL_DEVICE_ONLY__
#define OP(opassign, op)                                                       \
  wi_element &operator opassign(                                               \
      const sycl::ext::intel::experimental::bfloat16 &rhs) {                   \
    M.spvm = __spirv_VectorInsertDynamic(                                      \
        M.spvm,                                                                \
        sycl::ext::intel::experimental::bfloat16{                              \
            static_cast<float>(__spirv_VectorExtractDynamic(M.spvm, idx))      \
                op static_cast<float>(rhs)},                                   \
        idx);                                                                  \
    return *this;                                                              \
  }
#else // __SYCL_DEVICE_ONLY__
#define OP(opassign, op)                                                       \
  wi_element &operator opassign(                                               \
      const sycl::ext::intel::experimental::bfloat16 &) {                      \
    throw runtime_error("joint matrix is not supported on host device.",       \
                        PI_INVALID_DEVICE);                                    \
  }
#endif // __SYCL_DEVICE_ONLY__
OP(+=, +)
OP(-=, -)
OP(*=, *)
OP(/=, /)
#undef OP

#if __SYCL_DEVICE_ONLY__
#define OP(type, op) \
friend type operator op( \
const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &lhs, \
const sycl::ext::intel::experimental::bfloat16 &rhs) { \
return static_cast<float>(__spirv_VectorExtractDynamic( \
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs); \
} \
friend type operator op( \
const sycl::ext::intel::experimental::bfloat16 &lhs, \
const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &rhs) { \
return static_cast<float>(__spirv_VectorExtractDynamic( \
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs); \
}
OP(sycl::ext::intel::experimental::bfloat16, +)
OP(sycl::ext::intel::experimental::bfloat16, -)
OP(sycl::ext::intel::experimental::bfloat16, *)
OP(sycl::ext::intel::experimental::bfloat16, /)
#undef OP
#define OP(type, op) \
friend type operator op( \
const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &lhs, \
const sycl::ext::intel::experimental::bfloat16 &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
} \
friend type operator op( \
const sycl::ext::intel::experimental::bfloat16 &lhs, \
const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &rhs) { \
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
}
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
OP(bool, >)
OP(bool, <=)
OP(bool, >=)
#undef OP
#else // __SYCL_DEVICE_ONLY__
#define OP(type, op) \
friend type operator op( \
const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &lhs, \
const sycl::ext::intel::experimental::bfloat16 &rhs) { \
(void)lhs; \
(void)rhs; \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead, just remove the parameter names from the operator declaration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't get your point. This code is for the host; the parameter list of `friend type operator op` should be the same as the one on line 555, which is for device code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe what @keryell suggests is to change this to

  friend type operator op(                                                     \
      const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows,     \
                       NumCols, Layout, Group> &,                              \
      const sycl::ext::oneapi::experimental::bfloat16 &) {                     \
    throw runtime_error("joint matrix is not supported on host device.",       \
                        PI_INVALID_DEVICE);                                    \
  }  

which preserves the signature but doesn't have unused arguments as they are unnamed.

throw runtime_error("joint matrix is not supported on host device.", \
PI_INVALID_DEVICE); \
} \
friend type operator op( \
const sycl::ext::intel::experimental::bfloat16 &lhs, \
const wi_element<sycl::ext::intel::experimental::bfloat16, NumRows, \
NumCols, Layout, Group> &rhs) { \
(void)lhs; \
(void)rhs; \
throw runtime_error("joint matrix is not supported on host device.", \
PI_INVALID_DEVICE); \
}
OP(sycl::ext::intel::experimental::bfloat16, +)
OP(sycl::ext::intel::experimental::bfloat16, -)
OP(sycl::ext::intel::experimental::bfloat16, *)
OP(sycl::ext::intel::experimental::bfloat16, /)
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
OP(bool, >)
OP(bool, <=)
OP(bool, >=)
#undef OP
#endif // __SYCL_DEVICE_ONLY__
};

template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
typename Group>
class wi_slice {
Expand Down