Skip to content
8 changes: 8 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ template <typename T, std::size_t R, std::size_t C,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __ocl_vec_t<int32_t, 2>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reasons can not add suggestion.
here and after int32_t -> uint32_t

__spirv_JointMatrixWorkItemElemCoord(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no such thing as std::tuple in SPIR-V. The instruction should return int2 and if we want to create a tuple for get_coord API, then we should read elements from this vector to create tuple.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. Can you please tell a bit more about this int2 type? Is there any documentation/ code that I can take a look?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a 2 elements vector. int2 is a spelling from OpenCL, but guess the appropriate alias should be known for DPCPP, see: https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_aliases

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@yubingex007-a11y yubingex007-a11y Oct 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eh, why we use __ocl_vec_t<int32_t, 2> instead of sycl::vec<int32_t, 2> here? @MrSidims

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please wait with the merge until the name for the instruction is picked. In the draft SPIR-V spec version it is JointMatrixGetElementCoordINTEL

size_t i);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down
44 changes: 44 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/feature_test.hpp>
#include <tuple>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -256,6 +257,21 @@ class wi_element {
wi_element(joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<size_t, size_t> get_coord() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need to add this function to the specialization of wi_element for bfloat16 type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: size_t -> uint32_t is probably better
same nit applicable to the code below

#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<int32_t, 2> co_ord =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove underscore from the name (co_ord)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't see it's applied

__spirv_JointMatrixWorkItemElemCoord(M.spvm, idx);
const int32_t row = co_ord[0];
const int32_t col = co_ord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// Various Operations
operator T() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -339,6 +355,20 @@ class wi_element<uint16_t, NumRows, NumCols, Use, Layout, Group> {
wi_element(joint_matrix<uint16_t, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<size_t, size_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<int32_t, 2> coord =
__spirv_JointMatrixWorkItemElemCoord(M.spvm, idx);
const int32_t row = coord[0];
const int32_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator uint16_t() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -489,6 +519,20 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

std::tuple<size_t, size_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
__ocl_vec_t<int32_t, 2> co_ord =
__spirv_JointMatrixWorkItemElemCoord(M.spvm, idx);
const int32_t row = co_ord[0];
const int32_t col = co_ord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator sycl::ext::oneapi::experimental::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down
Loading