Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 47 additions & 29 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,89 +23,107 @@

#ifdef __SYCL_DEVICE_ONLY__
template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *Object,
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
__spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T2, M, N, UC, LC, S> *
__spirv_JointMatrixMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T1, K, N, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *C,
__spv::__spirv_JointMatrixINTEL<T1, M, K, UA, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T1, K, N, UB, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T2, M, N, UC, LC, S> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
std::size_t N, __spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, UC, LC, S> *
__spirv_JointMatrixUUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
__spv::__spirv_JointMatrixINTEL<T1, M, K, UA, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, UB, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, UC, LC, S> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
std::size_t N, __spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, UC, LC, S> *
__spirv_JointMatrixUSMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
__spv::__spirv_JointMatrixINTEL<T1, M, K, UA, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, UB, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, UC, LC, S> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
std::size_t N, __spv::MatrixUse UA = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UB = __spv::MatrixUse::Unnecessary,
__spv::MatrixUse UC = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, UC, LC, S> *
__spirv_JointMatrixSUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
__spv::__spirv_JointMatrixINTEL<T1, M, K, UA, LA, S> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, UB, LB, S> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, UC, LC, S> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S> *
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *);
__spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *, size_t i);
__spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixLayout U,
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, U, S> *,
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, U, L, S> *,
T val, size_t i);

#ifndef __SPIRV_BUILTIN_DECLARATIONS__
Expand Down
13 changes: 11 additions & 2 deletions sycl/include/CL/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ enum class MatrixLayout : uint32_t {
PackedB = 3
};

enum class MatrixUse : uint32_t {
MatrixA = 0,
MatrixB = 1,
Accumulator = 2,
Unnecessary = 3
};

// TODO: replace the following W/A with a better solution when we have it.
// The following structure is used to represent the joint matrix type in the
// LLVM IR. The structure has a pointer to a multidimensional array member which
Expand All @@ -129,10 +136,12 @@ enum class MatrixLayout : uint32_t {
// information to SPIRV translator.
// The long term solution would be to introduce a matrix type in Clang and use
// it instead of this member.
template <typename T, std::size_t R, std::size_t C, MatrixLayout U,
template <typename T, std::size_t R, std::size_t C, MatrixUse U, MatrixLayout L,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This matches the template arguments used for joint_matrix in matrix-tensorcore.hpp (the current CUDA impl), except that the ordering is a little different (MatrixUse is the second template param in matrix-tensorcore.hpp). Trivial point, but we should align an ordering before merging the impls into a single SYCL_EXT_ONEAPI_MATRIX version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point. Currently, this is at this position because it is an optional argument for now (to keep backward compatibility with the previous API (no use argument)).
Once this API with "use" argument is stable enough that we can remove the non-use-API, we can revise the order.
My personal preference is that it is probably a good idea to keep the arguments that might be become "optional" in the future as last. Today, "use" is required but layout is definitely optional.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest not to modify the order of template parameters of internal SPIR-V joint matrix type representation. The reason is that after this patch: #6535 the clang started to generate opaque matrix type like this: spirv.JointMatrixINTEL._half_10_2_0_0 and hence we can and will remove the array W/A. So it becomes crucial to keep the template parameter's order internally (it still can be changed in user-visible API).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, we will keep the Workaround in SPIRVRegularizeLLVMBase::adaptStructTypes for now. Since the W/A will be removed in the future, we'd better move the Use at the end of template param list, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Scope::Flag S = Scope::Flag::Subgroup>
struct __spirv_JointMatrixINTEL {
T (*Value)[R][C][static_cast<size_t>(U) + 1][static_cast<size_t>(S) + 1];
T(*Value)
[R][C][static_cast<size_t>(L) + 1][static_cast<size_t>(S) + 1]
[static_cast<size_t>(U) + 1];
};

} // namespace __spv
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/CL/sycl/feature_test.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace sycl {
// 2- provides JIT implementation (target agnostic) for the
// experimental matrix extension
#ifndef SYCL_EXT_ONEAPI_MATRIX
#define SYCL_EXT_ONEAPI_MATRIX 2
#define SYCL_EXT_ONEAPI_MATRIX 3
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to 2

#endif
#define SYCL_EXT_ONEAPI_ASSERT 1
#define SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS 1
Expand Down
Loading