diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
index 14466fd5fafb4..b3da11e7c439d 100644
--- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
+++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
@@ -451,6 +451,286 @@ class wi_element {
   }
 };
 
+// wi_element variant for bf16 data: as in the other matrix functions, uint16_t
+// carries the bf16 bits. That reading is possible because the AMX and DPAS
+// implementations don't support uint16_t. The choice predates the SYCL
+// experimental bfloat16 type; the plan is to move towards it, but while it is
+// experimental both will probably be kept. Host builds throw in every operator.
+// NOTE(review): the template argument lists look stripped here - confirm.
+template
+class wi_element {
+  joint_matrix &M; // Matrix whose spvm storage this object reads and writes.
+  std::size_t idx; // Element position used in the extract/insert calls.
+
+public:
+  wi_element(joint_matrix &Mat,
+             std::size_t i)
+      : M(Mat), idx(i) {} // Only binds the reference and the index.
+  operator uint16_t() { // Returns the element at idx as stored (bf16 bits).
+#ifdef __SYCL_DEVICE_ONLY__
+    return __spirv_VectorExtractDynamic(M.spvm, idx);
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  explicit operator bool() { // True when the element at idx is not 0.
+#ifdef __SYCL_DEVICE_ONLY__
+    return __spirv_VectorExtractDynamic(M.spvm, idx) !=
+           static_cast(0);
+#else
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  wi_element &operator=(const uint16_t &rhs) { // Writes rhs at idx.
+#ifdef __SYCL_DEVICE_ONLY__
+    M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx);
+    return *this;
+#else
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  wi_element & // Copies rhs's element value; M and idx are not rebound.
+  operator=(const wi_element &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    M.spvm = __spirv_VectorInsertDynamic(
+        M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
+    return *this;
+#else
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  // Software bf16=>fp32 and fp32=>bf16 conversions. This is a workaround until
+  // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL are supported
+  // in the CPU backend. make_bf16 drops the low 16 bits without rounding.
+  // NOTE(review): both helpers type-pun through pointer casts - confirm.
+  static float make_fp32(uint16_t x) { // x supplies the upper 16 bits.
+    unsigned int y = x;
+    y = y << 16;
+    float *res = reinterpret_cast(&y);
+    return *res;
+  }
+
+  static uint16_t make_bf16(float x) { // Upper 16 bits of x (32-bit float).
+    int *res = reinterpret_cast(&x);
+    *res = *res >> 16; // Modifies only the local by-value copy of x.
+    return (uint16_t)*res;
+  }
+
+  friend uint16_t // fp32(element) + fp32(rhs), returned as bf16 bits.
+  operator+(const wi_element &lhs,
+            const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_bf16(
+        make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) +
+        make_fp32(rhs));
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  wi_element &operator+=(const uint16_t &rhs) { // element = element + rhs
+#ifdef __SYCL_DEVICE_ONLY__
+    M.spvm = __spirv_VectorInsertDynamic(
+        M.spvm,
+        make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) +
+                  make_fp32(rhs)),
+        idx);
+    return *this;
+#else
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend uint16_t // fp32(element) - fp32(rhs), returned as bf16 bits.
+  operator-(const wi_element &lhs,
+            const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_bf16(
+        make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) -
+        make_fp32(rhs));
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  wi_element &operator-=(const uint16_t &rhs) { // element = element - rhs
+#ifdef __SYCL_DEVICE_ONLY__
+    M.spvm = __spirv_VectorInsertDynamic(
+        M.spvm,
+        make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) -
+                  make_fp32(rhs)),
+        idx);
+    return *this;
+#else
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend uint16_t // fp32(element) * fp32(rhs), returned as bf16 bits.
+  operator*(const wi_element &lhs,
+            const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_bf16(
+        make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) *
+        make_fp32(rhs));
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  wi_element &operator*=(const uint16_t &rhs) { // element = element * rhs
+#ifdef __SYCL_DEVICE_ONLY__
+    M.spvm = __spirv_VectorInsertDynamic(
+        M.spvm,
+        make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) *
+                  make_fp32(rhs)),
+        idx);
+    return *this;
+#else
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend uint16_t // fp32(element) / fp32(rhs), returned as bf16 bits.
+  operator/(const wi_element &lhs,
+            const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_bf16(
+        make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) /
+        make_fp32(rhs));
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  wi_element &operator/=(const uint16_t &rhs) { // element = element / rhs
+#ifdef __SYCL_DEVICE_ONLY__
+    M.spvm = __spirv_VectorInsertDynamic(
+        M.spvm,
+        make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx)) /
+                  make_fp32(rhs)),
+        idx);
+    return *this;
+#else
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend bool // Compares the widened fp32 values, not the raw uint16_t bits.
+  operator<(const wi_element &lhs,
+            const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) <
+           make_fp32(rhs);
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend bool // fp32(element) <= fp32(rhs)
+  operator<=(const wi_element &lhs,
+             const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) <=
+           make_fp32(rhs);
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend bool // fp32(element) > fp32(rhs)
+  operator>(const wi_element &lhs,
+            const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) >
+           make_fp32(rhs);
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend bool // fp32(element) >= fp32(rhs)
+  operator>=(const wi_element &lhs,
+             const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) >=
+           make_fp32(rhs);
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend bool // Float equality on the widened values, not a bitwise compare.
+  operator==(const wi_element &lhs,
+             const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) ==
+           make_fp32(rhs);
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+
+  friend bool // Float inequality on the widened values, not a bitwise compare.
+  operator!=(const wi_element &lhs,
+             const uint16_t &rhs) {
+#ifdef __SYCL_DEVICE_ONLY__
+    return make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) !=
+           make_fp32(rhs);
+#else
+    (void)lhs;
+    (void)rhs;
+    throw runtime_error("joint matrix is not supported on host device.",
+                        PI_INVALID_DEVICE);
+#endif // __SYCL_DEVICE_ONLY__
+  }
+};
+
 template
 class wi_slice {