From b68aeadc3c0792823772e080bae0b7ec6c914368 Mon Sep 17 00:00:00 2001
From: Bing1 Yu
Date: Tue, 19 Sep 2023 11:46:01 +0800
Subject: [PATCH 01/50] [Matrix] Syntax changes as preparation for moving
 joint_matrix out of the experimental namespace
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

As part of the effort to move joint_matrix from the experimental namespace to
supported, the API is being reviewed in https://github.com/intel/llvm/pull/7964.
That review results in the following syntax changes:
1. Add Td to joint_matrix_mad, since Tc can be different from Td on the GPU;
   D is now passed to mad as an argument instead of being returned.
2. Change "packed" to ext_intel_packed.
3. Move the element-wise operations (get_wi_data, wi_element, get_coord) to
   the detail namespace.
4. Add const to the joint_matrix arguments of store and mad.
5. Add a joint_matrix_copy/assignment function.
6. Add apply with coordinates (and change the existing tests accordingly).
7. Change the get_coord vector type from int32_t to size_t.
8. Explicitly delete both operator= and the copy constructor.
---
 .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 59 +++++++--- .../oneapi/matrix/matrix-unified-utils.hpp | 7 +- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 105 ++++++++++-------- .../Matrix/Legacy/element_wise_ops_impl.hpp | 2 +- .../Legacy/elemwise_irreg_size_ops_bf16.cpp | 2 +- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 2 +- .../joint_matrix_bfloat16_32x64_impl.hpp | 2 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Legacy/joint_matrix_bfloat16_impl.hpp | 2 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 2 +- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Legacy/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Legacy/joint_matrix_query_default.cpp | 2 +- .../Legacy/joint_matrix_ss_int8_impl.hpp | 2 +- .../Legacy/joint_matrix_su_int8_impl.hpp | 2 +- .../Legacy/joint_matrix_us_int8_impl.hpp | 2 +- .../Legacy/joint_matrix_uu_int8_impl.hpp | 2 +- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 14 +-- .../Matrix/element_wise_all_ops_half_impl.hpp | 10 +- .../Matrix/element_wise_all_ops_impl.hpp | 4 +- .../Matrix/element_wise_all_ops_int8_impl.hpp | 10 +- .../element_wise_all_ops_int8_packed_impl.hpp | 40 +++---- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 10 +- .../Matrix/element_wise_all_sizes_impl.hpp | 2 +- .../element_wise_irreg_sum_rows_impl.hpp | 8 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 10 +- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 10 +- .../Matrix/joint_matrix_all_sizes_impl.hpp | 7 +- .../Matrix/joint_matrix_apply_cuda.hpp | 65 ++++++----- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 53 ++++----- .../joint_matrix_bfloat16_32x64_impl.hpp | 7 +- .../joint_matrix_bfloat16_array_impl.hpp | 7 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_impl.hpp | 7 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../joint_matrix_colA_rowB_colC_impl.hpp | 2 +- .../Matrix/joint_matrix_gemm_cuda.hpp | 2 +-
.../Matrix/joint_matrix_half_impl.hpp | 7 +- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- .../Matrix/joint_matrix_query_default.cpp | 5 +- .../Matrix/joint_matrix_ss_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_su_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_tf32_impl.hpp | 8 +- .../Matrix/joint_matrix_transposeC_impl.hpp | 7 +- .../Matrix/joint_matrix_us_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_uu_int8_impl.hpp | 7 +- .../matrix/matrix-nvptx-bfloat16-test.cpp | 12 +- .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 +- .../matrix/matrix-nvptx-half-float-test.cpp | 12 +- .../matrix/matrix-nvptx-half-half-test.cpp | 12 +- .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 +- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 +- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 +- .../matrix/matrix_load_store_as.cpp | 7 +- .../matrix/matrix_load_store_as_legacy.cpp | 2 +- .../matrix/legacy/matrix-bf16-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 2 +- .../matrix/legacy/matrix-bfloat16-test.cpp | 2 +- .../matrix/legacy/matrix-elemwise-ops.cpp | 2 +- .../matrix/legacy/matrix-int8-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-int8-test.cpp | 2 +- .../matrix-bfloat16-test-coord-basicB.cpp | 8 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 5 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 8 +- sycl/test/matrix/matrix-int8-test.cpp | 5 +- sycl/test/matrix/matrix-tf32-test.cpp | 5 +- sycl/test/matrix/query-use.cpp | 6 +- 70 files changed, 361 insertions(+), 299 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index 9de4f4ec1851f..eac624a9c3360 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -29,10 +29,6 @@ namespace sycl { inline namespace _V1 { namespace ext { -namespace intel::experimental::matrix::layout { -constexpr sycl::ext::oneapi::experimental::matrix::layout packed = - static_cast(2); -} namespace oneapi { namespace experimental { namespace matrix { @@ -48,8 +44,7 @@ template struct spv_matrix_layout_traits { SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor) SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor) -SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed, - __spv::MatrixLayout::Packed) +SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed) SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic) template struct spv_matrix_use_traits { @@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait< using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32; using storage_element_type = float; }; -} // namespace detail -} // namespace oneapi - -namespace intel::experimental::matrix { using namespace sycl::ext::oneapi::experimental::matrix; // Begin wi_element definition @@ -121,12 +112,12 @@ class wi_element { std::size_t i) : M(Mat), idx(i) {} - inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { + inline __SYCL_ALWAYS_INLINE std::tuple get_coord() { #if defined(__SYCL_DEVICE_ONLY__) __ocl_vec_t coord = __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx); - const uint32_t row = coord[0]; - const uint32_t col = coord[1]; + const size_t row = coord[0]; + const size_t col = coord[1]; return std::make_tuple(row, col); #else throw runtime_error("joint matrix 
is not supported on host device.", @@ -479,7 +470,10 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix< } // End wi_data definition +} // namespace detail +} // namespace oneapi +namespace intel::experimental::matrix { template < typename Group, typename T, typename Tp, sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows, @@ -490,7 +484,7 @@ template < bool> = true> inline __SYCL_ALWAYS_INLINE void joint_matrix_store(Group sg, - sycl::ext::oneapi::experimental::matrix::joint_matrix< + const sycl::ext::oneapi::experimental::matrix::joint_matrix< Group, Tp, Use, NumRows, NumCols, Layout> &src, multi_ptr dst, size_t stride) { #if defined(__SYCL_DEVICE_ONLY__) @@ -528,6 +522,43 @@ joint_matrix_store(Group sg, PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) } + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( + Group sg, + sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, + F &&lambda) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) { + lambda(jm.cuda_impl.wi_marray[i]); + } +#else // NVPTX + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T>::storage_element_type; + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); + for (int i = 0; i < wi_data_c.length(); i++) { + storage_element_type element = wi_data_c[i]; + auto [row, col] = wi_data_c[i].get_coord(); + lambda(element, row, col); + wi_data_c[i] = element; + } +#endif +#else + std::ignore = sg; + std::ignore = jm; + std::ignore = lambda; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif + return; +} + } // namespace intel::experimental::matrix } // namespace ext diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp index f51e146fd9a0c..8a9dbc12df2ec 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp @@ -16,7 +16,12 @@ namespace matrix { enum class use { a, b, accumulator }; -enum class layout { row_major = 0, col_major = 1, dynamic = 3 }; +enum class layout { + row_major = 0, + col_major = 1, + ext_intel_packed = 2, + dynamic = 3 +}; namespace precision { class tf32 { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 7b101b18cea90..5fc6290c1f71d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -61,19 +61,8 @@ struct joint_matrix { } #ifdef __SYCL_DEVICE_ONLY__ #if defined(__SPIR__) - // Generate a non-trivial assignment operator and copy c'tor that prevents - // memcpy from being generated. 
- // TODO: to remove, when either IGC can handle alloca JointMatrix or - // combination of InstCombine + SROA + mem2reg can remove it - joint_matrix(const joint_matrix &other) { - spvm = other.spvm; - return *this; - } - - joint_matrix &operator=(const joint_matrix &rhs) { - spvm = rhs.spvm; - return *this; - } + joint_matrix(const joint_matrix &other) = delete; + joint_matrix &operator=(const joint_matrix &rhs) = delete; #endif // defined(__SPIR__) #endif }; @@ -99,7 +88,7 @@ class wi_data { return jm.cuda_impl.wi_marray.size(); #else throw runtime_error("get_wi_data is available using: " - "ext::intel::experimental::matrix::get_wi_data.", + "ext::oneapi::detail::get_wi_data.", PI_ERROR_INVALID_DEVICE); #endif }; @@ -109,7 +98,7 @@ class wi_data { return (jm.cuda_impl.wi_marray[i]); #else throw runtime_error("get_wi_data is available using: " - "ext::intel::experimental::matrix::get_wi_data.", + "ext::oneapi::detail::get_wi_data.", PI_ERROR_INVALID_DEVICE); #endif }; @@ -138,9 +127,9 @@ template &jm, using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< T>::storage_element_type; - auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm); + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); for (int i = 0; i < wi_data_c.length(); i++) { storage_element_type element = wi_data_c[i]; lambda(element); @@ -262,7 +251,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::intel::experimental::matrix::layout::packed: + case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: res.spvm = __spirv_JointMatrixLoadINTEL< DecorT, S, NumRows, NumCols, spv_matrix_use_traits::value, @@ -327,8 +316,9 @@ template inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Group sg, - joint_matrix &src, + const joint_matrix + &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { #if defined(__SYCL_DEVICE_ONLY__) @@ -361,7 +351,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::intel::experimental::matrix::layout::packed: + case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: __spirv_JointMatrixStoreINTEL< DecorT, T, NumRows, NumCols, spv_matrix_use_traits::value, @@ -382,53 +372,78 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( #endif // defined(__SYCL_DEVICE_ONLY__) } -template -inline __SYCL_ALWAYS_INLINE - joint_matrix - joint_matrix_mad( - Group sg, joint_matrix &A, - joint_matrix &B, - joint_matrix - &C) { +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( + Group sg, const joint_matrix &A, + const joint_matrix &B, + const joint_matrix + &C, + joint_matrix &D) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; if constexpr (std::is_same::value) { - joint_matrix - D; sycl::ext::oneapi::detail::joint_matrix_mad_cuda( D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); - return D; } else { assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); } #else - joint_matrix res; if constexpr (std::is_same::value && std::is_same::value && std::is_same::value) - res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_unsigned::value && 
std::is_unsigned::value) - res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_signed::value && std::is_unsigned::value) - res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); else if constexpr (std::is_unsigned::value && std::is_signed::value) - res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); + D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); else - res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); - return res; + D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); #endif // defined(__NVPTX__) #else std::ignore = sg; std::ignore = A; std::ignore = B; std::ignore = C; + std::ignore = D; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +void joint_matrix_copy(Group sg, + joint_matrix &src, + joint_matrix &dst) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) + std::ignore = sg; + for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { + dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; + } +#else + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T2>::storage_element_type; + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src); + auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst); + for (int i = 0; i < wi_data_c.length(); i++) { + wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]); + } +#endif // defined(__NVPTX__) +#else + std::ignore = sg; + std::ignore = dst; + std::ignore = src; throw runtime_error("joint matrix is not supported on host device.", PI_ERROR_INVALID_DEVICE); #endif // defined(__SYCL_DEVICE_ONLY__) diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 8d15b78fd3198..7179b82855f50 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index 0f57377c571ac..d041cd6050d73 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,7 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 9868aef0d92e2..3ba9ae346e070 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,7 +76,7 @@
void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index ac4a0bc405816..8835bce054171 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 91845ac61a180..48c4a894ab385 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index 2598905f9f6fe..db16f05673321 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index c293d8ff22944..5676cd849e1b5 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index c663cc282c758..658985c10ab5c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = 
joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 21347c80c083b..85bffe8957e77 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 1948071dbf405..2cb8196a49306 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index 8aaf737a274a8..6b7878b5cce7d 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,7 +97,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index a2436bc56e792..ef594dc6bc3a8 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -70,7 +70,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index f0a9a7155fb0d..0eb5bc2bc8d58 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index 68cf40bb481b9..3d0067cbb7b36 100644 
--- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 14190434dd2b1..be4eee8452c3c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 8b7ee3af2b9c5..1da92f4fd4dbc 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,8 +55,9 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, T2, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -65,8 +66,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] += 1; } @@ -76,8 +76,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] += 1; } @@ -87,8 +86,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 1; } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 540d75c245815..063acee4c2ffe 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -42,7 +42,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } @@ -77,7 +77,7 @@ void matrix_verify_sub(queue q, big_matrix &A, 
nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } @@ -112,7 +112,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } @@ -147,7 +147,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } @@ -182,7 +182,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 8e15488e151a0..25d77269b03a4 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -64,7 +64,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { sub_mat; joint_matrix_fill(sg, sub_mat, l); auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } @@ -105,7 +105,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { sub_mat; joint_matrix_fill(sg, sub_mat, l); auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 803ebe0addb3a..5d1668b753baa 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -41,7 +41,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -76,7 +76,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } @@ -111,7 +111,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i 
= 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } @@ -146,7 +146,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } @@ -181,7 +181,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index ce89a04b4168c..a04d464605fef 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -36,14 +36,14 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] + 2; } @@ -73,14 +73,14 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] - 2; } @@ -110,14 +110,14 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] * 3; } @@ -147,14 +147,14 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] / 2; } @@ -184,14 +184,14 @@ void 
matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { if (wi_slice_b[i]) { if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index 27eacf89c748a..96fcfa975b408 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -42,7 +42,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -78,7 +78,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); } @@ -112,7 +112,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); } @@ -147,7 +147,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); } @@ -181,7 +181,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index c49f9b57e2f32..b18ca01193974 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -62,7 +62,7 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { joint_matrix_fill(sg, sub_a, val1); auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + val2; } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index cfce95cba269f..6e3a86d7cb77b 100644 --- 
a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -44,8 +44,9 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix< + sub_group, T, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_load( @@ -57,8 +58,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1206b556339a9..edf9f162d5c3a 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -72,10 +73,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 8e2865de207b4..08724fbb48e9d 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -80,8 +80,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load( @@ -101,10 +102,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index 469837cd26d41..f4572990ded76 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,8 +56,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix + joint_matrix< + sub_group, T2, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -78,7 +79,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 4303239cefe32..b58e0a1b7c467 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -51,34 +51,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>(r, [ - accC, lambda - ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>( + r, [accC, lambda]( + nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -111,8 +111,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [ accC, - Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [accC, + Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -154,7 +154,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); joint_matrix_store( sg, sub_c, @@ -162,8 +162,7 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }) - .wait(); + }).wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp 
b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index f2d359cfe130b..74af36f0238ef 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -147,36 +147,37 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { #endif ; - joint_matrix + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index cc0196660744a..c27df0c01caa4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -46,8 +46,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -68,7 +69,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d6390d8061dcc..9cdf9b2435f82 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -59,8 +59,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a[JM_ARRAY_SZ]; // For B, we assume B has been already VNNIed. 
- joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; @@ -81,7 +82,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); + joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i], sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 4847e093127a8..2c627d8b88e31 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 76ac69da27677..a35dd11ee8cfa 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -45,8 +45,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -66,7 +67,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 4d61a733e5927..08341b3835d57 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index f75da1824d94b..a90a46e258452 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 5e451d45d7727..244311f7cc9ee 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ 
b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index f92548d2f7ed8..fcfb87545ff6b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -50,8 +50,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, half, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -71,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6111f503007f5..a5b1fddcbbbb8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index b6fe3f0376ffd..e042a8e282b56 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 1c1c4f97819bf..46c49f6576e44 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index 048aed6341f6c..b44661a1f269f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -78,7 +78,8 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_a sub_a; myparams2::joint_matrix_b< - sub_group, ext::intel::experimental::matrix::layout::packed> + sub_group, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; myparams2::joint_matrix_accumulator sub_c; @@ -99,7 +100,7 @@ void 
matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5cca6572cef21..5820544722f55 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -68,7 +69,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 397fcc9a5aa97..628735e8523e2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, uint8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -72,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 4d4ba0ee951e9..e8a4ece3bb00f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,15 +76,13 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 24f6cce4cc09d..8634efde40c74 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -42,8 +42,9 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. 
- joint_matrix + joint_matrix< + sub_group, bfloat16, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, @@ -55,7 +56,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 1d82f8833aba6..493ece5bc7d5e 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -53,8 +53,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -75,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index e400b6694e4a9..d668fa604d395 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -51,8 +51,9 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix + joint_matrix< + sub_group, uint8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix sub_c; @@ -73,7 +74,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 80b67b14a55ac..4790a1dcb1bf6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } 
@llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 31f77dc55b16f..30704f8869778 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c 
= joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 7c24179022d55..3331c66d302b6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); 
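// Sketch of the revised multiply-accumulate step (sub_a, sub_b and sub_c are
// the fragments declared in the enclosing test): the product is no longer
// returned by value; the destination fragment is passed as an extra argument,
// and these tests reuse the accumulator as the destination:
joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // in place: C += A * B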
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp 
b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0e0b4ce903be2..e4250952fa0c2 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x 
half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 575039723d56e..50fceeb6c34dc 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), 
@@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void 
@llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 8ada375fff395..e8b77588237d1 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 69bc136e79776..256bf847645e8 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: 
tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index e5935b8b3af47..4c3c97c77edd8 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -29,8 +29,9 @@ int main(void) { joint_matrix tA; - joint_matrix + joint_matrix< + sub_group, unsigned short, use::b, 16, 16, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> tB; joint_matrix tC; @@ -49,7 +50,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call 
spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tA, tB, tC, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp index bb18a21bc1002..022ac65612a28 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp @@ -47,7 +47,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 3, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32, matrix_layout::packed_b); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tA, tB, tC, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 0, 3) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, matrix_layout::row_major); diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index 391a9be2197c6..2a4ac877a3864 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index 6c6bfc1066f01..f58fe277d0073 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index 022e69f9b75a2..3271b27bc466d 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,7 +87,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, 
sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index feddb05148c4e..5285b8eb9aa2b 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 335529ad3120a..47af048171265 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 77c57b4ef711e..614faf8defe5a 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index b141f2176971d..769efbdb0d959 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -154,8 +154,9 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { sub_group sg = spmd_item.get_sub_group(); // TK = 32, TN = 16 - joint_matrix + joint_matrix< + sub_group, int8_t, use::b, TK, TN, + ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b; joint_matrix_load( @@ -166,8 +167,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 2e0e309081464..eae42973d45c4 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,7 +68,8 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). 
joint_matrix + sycl::ext::oneapi::experimental::matrix::layout:: + ext_intel_packed> sub_b; joint_matrix sub_c; @@ -89,7 +90,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 3205e4c346ba6..861727c3fe92d 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,7 +69,8 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::oneapi::experimental::matrix::layout:: + ext_intel_packed> sub_b; joint_matrix sub_c; @@ -93,10 +94,9 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } - auto wi_data_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); for (int i = 0; i < wi_data_c.length(); i++) { wi_data_c[i] *= 2; } diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index f8dcc26ab1b17..959f5b2b30871 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,7 +74,8 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::oneapi::experimental::matrix::layout:: + ext_intel_packed> sub_b; joint_matrix sub_c; @@ -94,7 +95,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index d6affb4067003..64852b4a890cc 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -87,12 +87,11 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. 
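// Element-wise post-processing after the namespace move (a sketch; sg and
// sub_c come from the surrounding test): get_wi_data now lives under
// sycl::ext::oneapi::detail instead of sycl::ext::intel::experimental::matrix.
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c);
for (int i = 0; i < wi_data_c.length(); i++)
  wi_data_c[i] *= 2; // same doubling the element-wise-ops test performs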
joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 9afc8e1173043..96d2fd3c0c26a 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -64,7 +64,8 @@ void query_amx() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, + sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; @@ -144,7 +145,8 @@ void query_xmx8() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, + sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; From 5fbb285ad0dd6727b58d2865632d53cc829db16d Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 12:11:32 +0800 Subject: [PATCH 02/50] clang-format --- .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 8 +- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 6 +- .../Matrix/Legacy/element_wise_ops_impl.hpp | 97 ++++++++++--------- .../Legacy/elemwise_irreg_size_ops_bf16.cpp | 3 +- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 3 +- .../joint_matrix_bfloat16_32x64_impl.hpp | 3 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 3 +- .../Legacy/joint_matrix_bfloat16_impl.hpp | 3 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 3 +- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 85 ++++++++-------- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 3 +- .../Legacy/joint_matrix_int8_vnni_impl.hpp | 3 +- .../Legacy/joint_matrix_query_default.cpp | 3 +- .../Legacy/joint_matrix_ss_int8_impl.hpp | 77 +++++++-------- .../Legacy/joint_matrix_su_int8_impl.hpp | 89 ++++++++--------- .../Legacy/joint_matrix_us_int8_impl.hpp | 3 +- .../Legacy/joint_matrix_uu_int8_impl.hpp | 89 ++++++++--------- .../Matrix/element_wise_all_ops_half_impl.hpp | 15 +-- .../Matrix/element_wise_all_ops_impl.hpp | 6 +- .../Matrix/element_wise_all_ops_int8_impl.hpp | 15 +-- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 15 +-- .../Matrix/element_wise_all_sizes_impl.hpp | 47 +++++---- .../Matrix/joint_matrix_down_convert_impl.hpp | 9 +- .../Matrix/joint_matrix_out_bounds_impl.hpp | 3 +- .../matrix/matrix-nvptx-bfloat16-test.cpp | 18 ++-- .../cuda/matrix/matrix-nvptx-double-test.cpp | 6 +- .../matrix/matrix-nvptx-half-float-test.cpp | 18 ++-- .../matrix/matrix-nvptx-half-half-test.cpp | 18 ++-- .../cuda/matrix/matrix-nvptx-int8-test.cpp | 18 ++-- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 6 +- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 18 ++-- .../matrix/legacy/matrix-bf16-test-SG-16.cpp | 3 +- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 3 +- .../matrix/legacy/matrix-bfloat16-test.cpp | 3 +- .../matrix/legacy/matrix-elemwise-ops.cpp | 3 +- .../matrix/legacy/matrix-int8-test-SG-16.cpp | 3 +- sycl/test/matrix/legacy/matrix-int8-test.cpp | 3 +- 37 files changed, 371 insertions(+), 340 deletions(-) diff --git 
a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index eac624a9c3360..c3d4e973ef643 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -187,7 +187,7 @@ class wi_element { #if __SYCL_DEVICE_ONLY__ #define OP(op) \ - template wi_element &operator op##=(const T2 &rhs) { \ + template wi_element &operator op##=(const T2 & rhs) { \ M.spvm = __spirv_VectorInsertDynamic( \ M.spvm, \ static_cast( \ @@ -202,7 +202,7 @@ class wi_element { } #else // __SYCL_DEVICE_ONLY__ #define OP(op) \ - template wi_element &operator op##=(const T2 &rhs) { \ + template wi_element &operator op##=(const T2 & rhs) { \ (void)rhs; \ throw runtime_error("joint matrix is not supported on host device.", \ PI_ERROR_INVALID_DEVICE); \ @@ -306,7 +306,7 @@ class wi_element -void joint_matrix_copy(Group sg, - joint_matrix &src, - joint_matrix &dst) { +void joint_matrix_copy( + Group sg, joint_matrix &src, + joint_matrix &dst) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 7179b82855f50..dcc5764db1fcb 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -38,56 +38,57 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). 
+ joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - auto wi_slice_c = sub_c.get_wi_data(); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] *= 2; + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index d041cd6050d73..d0ee5869b1d54 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,7 +100,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 3ba9ae346e070..060830074d1fe 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,7 +76,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git 
a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index 8835bce054171..b8ed30ae495b9 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 48c4a894ab385..5738b3a109a9c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,7 +62,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index db16f05673321..9358ffb4168cd 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,7 +68,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 5676cd849e1b5..1a567c9147097 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,7 +62,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 658985c10ab5c..86439c83ff840 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -37,50 +37,51 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the 
workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. 
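// Why the B loads in these 16-bit tests use TK / 2 and a stride of N * 2 (a
// sketch of the VNNI packing the comment refers to; "factor" is 2 for half and
// bfloat16 and 4 for the int8 tests): a K x N matrix B is stored as K / factor
// rows of N * factor elements, with element (k, n) placed at
// packed[k / factor][n * factor + k % factor]. A hypothetical host-side
// packing loop over assumed buffers B and Bvnni:
for (int k = 0; k < K; k++)
  for (int n = 0; n < N; n++)
    Bvnni[(k / factor) * (N * factor) + n * factor + k % factor] = B[k * N + n];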
+ joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 85bffe8957e77..a8c97d97aced3 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,7 +72,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 2cb8196a49306..9a9d895f4771c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,7 +64,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index 6b7878b5cce7d..fb3dceee3f361 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,7 +97,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index ef594dc6bc3a8..1598f8b78b3a2 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -38,46 +38,47 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the 
workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_fill(sg, sub_c, 0); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. 
+ joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 0eb5bc2bc8d58..34320622f006c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -38,52 +38,53 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. 
- joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index 3d0067cbb7b36..bc0ed40202116 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,7 +78,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index be4eee8452c3c..164892179f674 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -38,52 +38,53 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since 
current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support + // non-packed layout, users need to specify the updated VNNI + // sizes along with the packed_b layout. By default, the layout + // is row_major and size is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. 
+ joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, + sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 063acee4c2ffe..df786d73a78ea 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,8 +41,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } @@ -76,8 +75,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } @@ -111,8 +109,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } @@ -146,8 +143,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } @@ -181,8 +177,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 25d77269b03a4..447dff879cb01 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,8 +63,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } @@ -104,8 +103,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } diff --git 
a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 5d1668b753baa..c18e371711824 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,8 +40,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -75,8 +74,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } @@ -110,8 +108,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } @@ -145,8 +142,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } @@ -180,8 +176,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index 96fcfa975b408..e73056639eb74 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -41,8 +41,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -77,8 +76,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); } @@ -111,8 +109,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); } @@ -146,8 +143,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 
round_to_tf32(4.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); } @@ -180,8 +176,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index b18ca01193974..f1aacbef1230f 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,30 +49,29 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( - SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + val2; + } + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 6972e3854c8e8..5dd21cfe5340b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -49,14 +49,7 @@ void matrix_copy(big_matrix &C, big_matrix &A) { (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); // This will be replaced by joint_matrix_copy API - // joint_matrix_copy(sg, sub_c, sub_ac); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_a[i] = (bfloat16)wi_slice_c[i]; - } + joint_matrix_copy(sg, sub_c, 
sub_a); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 46c49f6576e44..63eb4af659170 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,8 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 4790a1dcb1bf6..336db7c2f00e9 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +89,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +121,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); 
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +153,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +185,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +217,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 30704f8869778..c8fcccf2d015e 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, 
sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +99,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 3331c66d302b6..b1e4015460f3f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { 
float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x 
half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index e4250952fa0c2..0cbc24560c589 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> 
{{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 50fceeb6c34dc..0743dc0a7ffc1 100644 
--- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 
{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index e8b77588237d1..f6a10c56cd866 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,8 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +138,8 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 256bf847645e8..448dc86f3321f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call 
void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +88,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +120,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +152,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +184,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +216,8 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } 
@llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index 2a4ac877a3864..b45e32786acf6 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,7 +89,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index f58fe277d0073..f2bc1f3c5618e 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index 3271b27bc466d..d41d9152cd157 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,7 +87,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index 5285b8eb9aa2b..a13b79deea9f2 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 47af048171265..09aa3349c38d1 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) 
+ sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 614faf8defe5a..457f5b4747d54 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,7 +88,8 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, + sub_c); } joint_matrix_store( sg, sub_c, From bf6cd56fe117de6cb639545870b3fb8d0c8a361f Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 12:25:08 +0800 Subject: [PATCH 03/50] fix typo: dest->dst --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 55a163f4dc38e..90f267bff19c4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -428,7 +428,7 @@ void joint_matrix_copy( #if defined(__NVPTX__) std::ignore = sg; for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { - dest.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; + dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; } #else using storage_element_type = From b399041060b065afa32904f1181af9aef570ed57 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 13:53:35 +0800 Subject: [PATCH 04/50] fix testcase --- .../Matrix/Legacy/element_wise_ops_impl.hpp | 3 +-- .../Legacy/elemwise_irreg_size_ops_bf16.cpp | 3 +-- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 3 +-- .../joint_matrix_bfloat16_32x64_impl.hpp | 3 +-- ...atrix_bfloat16_colmajorA_colmajorB_impl.hpp | 3 +-- .../Legacy/joint_matrix_bfloat16_impl.hpp | 3 +-- ...atrix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 3 +-- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 3 +-- ...nt_matrix_int8_colmajorA_colmajorB_impl.hpp | 3 +-- .../Legacy/joint_matrix_int8_vnni_impl.hpp | 3 +-- .../Legacy/joint_matrix_query_default.cpp | 3 +-- .../Legacy/joint_matrix_ss_int8_impl.hpp | 3 +-- .../Legacy/joint_matrix_su_int8_impl.hpp | 3 +-- .../Legacy/joint_matrix_us_int8_impl.hpp | 3 +-- .../Legacy/joint_matrix_uu_int8_impl.hpp | 3 +-- .../Matrix/joint_matrix_out_bounds_impl.hpp | 3 +-- .../cuda/matrix/matrix-nvptx-bfloat16-test.cpp | 18 ++++++------------ .../cuda/matrix/matrix-nvptx-double-test.cpp | 6 ++---- .../matrix/matrix-nvptx-half-float-test.cpp | 18 ++++++------------ .../matrix/matrix-nvptx-half-half-test.cpp | 18 ++++++------------ .../cuda/matrix/matrix-nvptx-int8-test.cpp | 18 ++++++------------ .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 6 ++---- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 18 ++++++------------ .../matrix/legacy/matrix-bf16-test-SG-16.cpp | 3 +-- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 3 +-- .../matrix/legacy/matrix-bfloat16-test.cpp | 3 +-- .../test/matrix/legacy/matrix-elemwise-ops.cpp | 3 +-- .../matrix/legacy/matrix-int8-test-SG-16.cpp | 3 +-- sycl/test/matrix/legacy/matrix-int8-test.cpp | 3 +-- 29 files changed, 56 insertions(+), 112 deletions(-) diff --git 
a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index dcc5764db1fcb..51eb10095b4a5 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index d0ee5869b1d54..b6e2e5fbe2315 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,8 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 060830074d1fe..847ae955fd41e 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index b8ed30ae495b9..49775fb18d437 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,8 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 5738b3a109a9c..d322210bf7728 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,8 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index 9358ffb4168cd..610ac5158794a 100644 --- 
a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,8 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 1a567c9147097..11b74d7270d27 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,8 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 86439c83ff840..c1f780e09144e 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -73,8 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index a8c97d97aced3..6fb5b3981879e 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,8 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 9a9d895f4771c..066934b98221c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,8 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index fb3dceee3f361..0dd9cf7e1ec6c 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,8 +97,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, 
matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index 1598f8b78b3a2..c1a5d1c762e14 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -70,8 +70,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 34320622f006c..630708e0b54aa 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index bc0ed40202116..8072f813fdb26 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,8 +78,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 164892179f674..6ee550537f285 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -76,8 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, - sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 63eb4af659170..67f99facdd96f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,8 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp 
b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 336db7c2f00e9..784c3ad489cb6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,8 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -89,8 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -121,8 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -153,8 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float 
{{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -185,8 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -217,8 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index c8fcccf2d015e..0090805e7a55c 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,8 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -99,8 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index b1e4015460f3f..209341a71e03d 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ 
b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> 
{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0cbc24560c589..e78fbe523dd29 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 
x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x 
half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 0743dc0a7ffc1..743f9fd54e12e 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, 
accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index f6a10c56cd866..d3e28e94e5e71 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,8 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + 
joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -138,8 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 448dc86f3321f..1aa82e27f6c68 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,8 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,8 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -120,8 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -152,8 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, 
i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -184,8 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -216,8 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index b45e32786acf6..c33848a81a2ed 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,8 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index f2bc1f3c5618e..bd989a6e34d0f 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index d41d9152cd157..92715e3b488da 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp 
@@ -87,8 +87,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index a13b79deea9f2..3cb773a2c2239 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_data_c = sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 09aa3349c38d1..0bc7b66b2e878 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 457f5b4747d54..b3fe44fa56250 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,8 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c)(sg, sub_a, sub_b, - sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, From dae1ec6dabedfb65e92884ecccd162472aeba3cc Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 14:16:27 +0800 Subject: [PATCH 05/50] fix mad bug --- .../test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp | 2 +- .../Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp | 2 +- .../Matrix/Legacy/joint_matrix_bf16_impl.hpp | 2 +- .../Legacy/joint_matrix_bfloat16_32x64_impl.hpp | 2 +- ...oint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_bfloat16_impl.hpp | 2 +- ...oint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 2 +- .../joint_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_query_default.cpp | 2 +- .../Matrix/Legacy/joint_matrix_ss_int8_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_su_int8_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_us_int8_impl.hpp | 2 +- .../Matrix/Legacy/joint_matrix_uu_int8_impl.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- .../cuda/matrix/matrix-nvptx-bfloat16-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-half-float-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-half-half-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 ++++++------ 
.../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 ++++++------ sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-bf16-test.cpp | 2 +- sycl/test/matrix/legacy/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/legacy/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp | 2 +- sycl/test/matrix/legacy/matrix-int8-test.cpp | 2 +- 29 files changed, 56 insertions(+), 56 deletions(-) diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 51eb10095b4a5..2ef278e229ff5 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp index b6e2e5fbe2315..0f57377c571ac 100644 --- a/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -100,7 +100,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_slice_c = sub_c.get_wi_data(); for (int i = 0; i < wi_slice_c.length(); i++) { diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp index 847ae955fd41e..9868aef0d92e2 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp index 49775fb18d437..ac4a0bc405816 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index d322210bf7728..91845ac61a180 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, 
K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp index 610ac5158794a..2598905f9f6fe 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 11b74d7270d27..c293d8ff22944 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -62,7 +62,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index c1f780e09144e..81ca8faa4977d 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6fb5b3981879e..21347c80c083b 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K, matrix_layout::col_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp index 066934b98221c..1948071dbf405 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N, matrix_layout::row_major); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp index 0dd9cf7e1ec6c..8aaf737a274a8 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp 
+++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_query_default.cpp @@ -97,7 +97,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index c1a5d1c762e14..20e14381f24bf 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -70,7 +70,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 630708e0b54aa..7c80fff72a55f 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp index 8072f813fdb26..68cf40bb481b9 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 6ee550537f285..895b7f0339cfe 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -76,7 +76,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 67f99facdd96f..1c1c4f97819bf 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git 
a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 784c3ad489cb6..958aba55c3b46 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, 
i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 0090805e7a55c..567e3293e1862 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 209341a71e03d..93d431061a3e0 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template 
get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float 
{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index e78fbe523dd29..ba4d27a2feb89 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // 
CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x 
half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 743f9fd54e12e..2d012581faf8b 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void 
@llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index d3e28e94e5e71..a69246dac7315 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float 
{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 1aa82e27f6c68..c22a50908d1c7 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } 
@llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp index c33848a81a2ed..391a9be2197c6 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test-SG-16.cpp @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bf16-test.cpp b/sycl/test/matrix/legacy/matrix-bf16-test.cpp index bd989a6e34d0f..6c6bfc1066f01 100644 --- a/sycl/test/matrix/legacy/matrix-bf16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bf16-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp index 92715e3b488da..022e69f9b75a2 100644 --- a/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/legacy/matrix-bfloat16-test.cpp @@ -87,7 +87,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp index 3cb773a2c2239..feddb05148c4e 100644 --- a/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/legacy/matrix-elemwise-ops.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } auto wi_data_c = 
sub_c.get_wi_data(); for (int i = 0; i < wi_data_c.length(); i++) { diff --git a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp index 0bc7b66b2e878..335529ad3120a 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test-SG-16.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index b3fe44fa56250..77c57b4ef711e 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -88,7 +88,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4, matrix_layout::packed_b); - joint_matrix_mad(sg, sub_a, sub_b, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, From 4ec8360e83f5389ba68afaf9cbc2ca70959c4137 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 15:35:38 +0800 Subject: [PATCH 06/50] fix cuda const joint_matrix_cuda --- .../sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 10 +++++----- .../matrix/matrix_load_store_as_legacy.cpp | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 94ae318540012..1ab3b56b79ca2 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -482,11 +482,11 @@ void joint_matrix_mad_cuda( joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, - joint_matrix_cuda &A, - joint_matrix_cuda &B, - joint_matrix_cuda< + const joint_matrix_cuda &A, + const joint_matrix_cuda &B, + const joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (M == 16 && N == 16 && K == 16) { diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp index 022ac65612a28..bb18a21bc1002 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as_legacy.cpp @@ -47,7 +47,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 3, 3) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef [[#]], i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32, matrix_layout::packed_b); - joint_matrix_mad(sg, tA, tB, tC, tC); + tC = joint_matrix_mad(sg, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 0, 3) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, matrix_layout::row_major); From 
a461cbb88c64518d58a58e04218530512042fe8e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 16:04:10 +0800 Subject: [PATCH 07/50] fix const issue of jm_store_cuda --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 1ab3b56b79ca2..ac532dbb1886a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -415,7 +415,8 @@ template &src, + NumCols, const sycl::ext::oneapi::experimental::matrix::layout::dynamic> + &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { From 5ff715bef9394e9028f00273252da777a5f7d56f Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 18:07:15 +0800 Subject: [PATCH 08/50] fix const --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index ac532dbb1886a..ca912a04cedc5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -413,9 +413,9 @@ void store_layoutT( template void joint_matrix_store_cuda( - joint_matrix_cuda< + const joint_matrix_cuda< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, - NumCols, const sycl::ext::oneapi::experimental::matrix::layout::dynamic> + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { From 8ad7da922f4c00763098dd2f160675f161e7b749 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Tue, 19 Sep 2023 22:06:03 +0800 Subject: [PATCH 09/50] lint --- sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index ca912a04cedc5..849cb676c6613 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -415,8 +415,7 @@ template - &src, + NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) { switch (Layout) { From 26ea49da18fe71a3e9c94e179b3eff735ea1cc3a Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 14:52:53 +0800 Subject: [PATCH 10/50] address dounia's comments and roll back all the testcase changes --- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 29 +++--- .../Matrix/Legacy/element_wise_ops_impl.hpp | 96 +++++++++---------- .../Matrix/Legacy/joint_matrix_half_impl.hpp | 84 ++++++++-------- .../Legacy/joint_matrix_ss_int8_impl.hpp | 76 +++++++-------- .../Legacy/joint_matrix_su_int8_impl.hpp | 88 ++++++++--------- .../Legacy/joint_matrix_uu_int8_impl.hpp | 88 ++++++++--------- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 14 +-- .../Matrix/element_wise_all_ops_half_impl.hpp | 15 ++- .../Matrix/element_wise_all_ops_impl.hpp | 6 +- .../Matrix/element_wise_all_ops_int8_impl.hpp | 15 ++- .../element_wise_all_ops_int8_packed_impl.hpp | 40 ++++---- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 
15 ++- .../Matrix/element_wise_all_sizes_impl.hpp | 47 ++++----- .../element_wise_irreg_sum_rows_impl.hpp | 8 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 10 +- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 10 +- .../Matrix/joint_matrix_all_sizes_impl.hpp | 7 +- .../Matrix/joint_matrix_apply_cuda.hpp | 65 ++++++------- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 53 +++++----- .../joint_matrix_bfloat16_32x64_impl.hpp | 7 +- .../joint_matrix_bfloat16_array_impl.hpp | 7 +- ...trix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_impl.hpp | 7 +- ...trix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +- .../joint_matrix_colA_rowB_colC_impl.hpp | 2 +- .../Matrix/joint_matrix_down_convert_impl.hpp | 9 +- .../Matrix/joint_matrix_gemm_cuda.hpp | 2 +- .../Matrix/joint_matrix_half_impl.hpp | 7 +- ...t_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +- .../Matrix/joint_matrix_int8_vnni_impl.hpp | 2 +- .../Matrix/joint_matrix_query_default.cpp | 5 +- .../Matrix/joint_matrix_ss_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_su_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_tf32_impl.hpp | 8 +- .../Matrix/joint_matrix_transposeC_impl.hpp | 7 +- .../Matrix/joint_matrix_us_int8_impl.hpp | 7 +- .../Matrix/joint_matrix_uu_int8_impl.hpp | 7 +- .../matrix/matrix-nvptx-bfloat16-test.cpp | 12 +-- .../cuda/matrix/matrix-nvptx-double-test.cpp | 4 +- .../matrix/matrix-nvptx-half-float-test.cpp | 12 +-- .../matrix/matrix-nvptx-half-half-test.cpp | 12 +-- .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 +-- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 +- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 +-- .../matrix/matrix_load_store_as.cpp | 7 +- .../matrix-bfloat16-test-coord-basicB.cpp | 8 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 5 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 8 +- sycl/test/matrix/matrix-int8-test.cpp | 5 +- sycl/test/matrix/matrix-tf32-test.cpp | 5 +- sycl/test/matrix/query-use.cpp | 6 +- 51 files changed, 496 insertions(+), 479 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 90f267bff19c4..bf2441ac17c2e 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -87,9 +87,10 @@ class wi_data { #if defined(__NVPTX__) return jm.cuda_impl.wi_marray.size(); #else - throw runtime_error("get_wi_data is available using: " - "ext::oneapi::detail::get_wi_data.", - PI_ERROR_INVALID_DEVICE); + throw runtime_error( + "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " + "intel users are expected to use joint_matrix_copy instead.", + PI_ERROR_INVALID_DEVICE); #endif }; @@ -97,9 +98,10 @@ class wi_data { #if defined(__NVPTX__) return (jm.cuda_impl.wi_marray[i]); #else - throw runtime_error("get_wi_data is available using: " - "ext::oneapi::detail::get_wi_data.", - PI_ERROR_INVALID_DEVICE); + throw runtime_error( + "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " + "intel users are expected to use joint_matrix_copy instead.", + PI_ERROR_INVALID_DEVICE); #endif }; }; @@ -129,7 +131,7 @@ __SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. 
Please " #else __attribute__(( unavailable("get_wi_data can't be used on intel device, please use " - "sycl::ext::oneapi::detail::get_wi_data instead!"))) + "joint_matrix_apply instead!"))) #endif #endif inline __SYCL_ALWAYS_INLINE decltype(auto) @@ -251,7 +253,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( Ptr, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: + case layout::ext_intel_packed: res.spvm = __spirv_JointMatrixLoadINTEL< DecorT, S, NumRows, NumCols, spv_matrix_use_traits::value, @@ -351,7 +353,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, spv_scope_traits::value); break; - case sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed: + case layout::ext_intel_packed: __spirv_JointMatrixStoreINTEL< DecorT, T, NumRows, NumCols, spv_matrix_use_traits::value, @@ -376,13 +378,14 @@ template inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( - Group sg, const joint_matrix &A, + Group sg, + joint_matrix &D, + const joint_matrix &A, const joint_matrix &B, const joint_matrix - &C, - joint_matrix &D) { + &C) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp index 2ef278e229ff5..8d15b78fd3198 100644 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_ops_impl.hpp @@ -38,56 +38,56 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). 
+ joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - auto wi_slice_c = sub_c.get_wi_data(); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] *= 2; + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp index 81ca8faa4977d..c663cc282c758 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -37,50 +37,50 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = 
spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. 
+ joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index 20e14381f24bf..a2436bc56e792 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -38,46 +38,46 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - joint_matrix_fill(sg, sub_c, 0); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. 
+ joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp index 7c80fff72a55f..f0a9a7155fb0d 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -38,52 +38,52 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. 
- joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp index 895b7f0339cfe..14190434dd2b1 100644 --- a/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -38,52 +38,52 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup - // no code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support - // non-packed layout, users need to specify the updated VNNI - // sizes along with the packed_b layout. By default, the layout - // is row_major and size is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sycl::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). 
+ joint_matrix sub_b(sg); + joint_matrix sub_c(sg); - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 1da92f4fd4dbc..8b7ee3af2b9c5 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,9 +55,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix< - sub_group, T2, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -66,7 +65,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] += 1; } @@ -76,7 +76,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] += 1; } @@ -86,7 +87,8 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 1; } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index df786d73a78ea..540d75c245815 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,7 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } @@ -75,7 +76,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } @@ -109,7 +111,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } @@ -143,7 +146,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } @@ -177,7 +181,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp 
b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 447dff879cb01..8e15488e151a0 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,7 +63,8 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } @@ -103,7 +104,8 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = sycl::ext::oneapi::detail::get_wi_data(sg, sub_mat); + auto wi_slice = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); for (int i = 0; i < wi_slice.length(); i++) { wi_slice[i] = op(wi_slice[i], r); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index c18e371711824..803ebe0addb3a 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,7 +40,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ -74,7 +75,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } @@ -108,7 +110,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } @@ -142,7 +145,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } @@ -176,7 +180,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index a04d464605fef..ce89a04b4168c 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -36,14 +36,14 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); 
sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] + 2; } @@ -73,14 +73,14 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] - 2; } @@ -110,14 +110,14 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] * 3; } @@ -147,14 +147,14 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] / 2; } @@ -184,14 +184,14 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_slice_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { if (wi_slice_b[i]) { if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp index e73056639eb74..27eacf89c748a 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_tf32_impl.hpp @@ -41,7 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } @@ 
-76,7 +77,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); } @@ -109,7 +111,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); } @@ -143,7 +146,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); } @@ -176,7 +180,8 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + ext::intel::experimental::matrix::get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index f1aacbef1230f..c49f9b57e2f32 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,29 +49,30 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( + SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + val2; + } + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / 
SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index 6e3a86d7cb77b..cfce95cba269f 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -44,9 +44,8 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix< - sub_group, T, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_load( @@ -58,7 +57,8 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto data = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index edf9f162d5c3a..1206b556339a9 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -73,9 +72,10 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 08724fbb48e9d..8e2865de207b4 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -80,9 +80,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; joint_matrix_load( @@ -102,9 +101,10 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index f4572990ded76..469837cd26d41 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,9 +56,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, T2, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -79,7 +78,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index b58e0a1b7c467..4303239cefe32 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -51,34 +51,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>( - r, [accC, lambda]( - nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>(r, [ + accC, lambda + ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * 
nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -111,8 +111,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [accC, - Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [ accC, + Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -154,7 +154,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, @@ -162,7 +162,8 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }).wait(); + }) + .wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index 74af36f0238ef..f2d359cfe130b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -147,37 +147,36 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { #endif ; - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), - joint_matrix(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), + joint_matrix< + sub_group, TOperand, use::b, tK, tN, + ext::intel::experimental::matrix::layout::packed>(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index c27df0c01caa4..cc0196660744a 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -46,9 +46,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -69,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index 9cdf9b2435f82..d6390d8061dcc 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -59,9 +59,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a[JM_ARRAY_SZ]; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; @@ -82,7 +81,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i], sub_c[i]); + sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 2c627d8b88e31..4847e093127a8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index a35dd11ee8cfa..76ac69da27677 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -45,9 +45,8 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -67,7 +66,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 08341b3835d57..4d61a733e5927 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index a90a46e258452..f75da1824d94b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 5dd21cfe5340b..6972e3854c8e8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -49,7 +49,14 @@ void matrix_copy(big_matrix &C, big_matrix &A) { (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); // This will be replaced by joint_matrix_copy API - joint_matrix_copy(sg, sub_c, sub_a); + // joint_matrix_copy(sg, sub_c, sub_a); + auto wi_slice_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_a[i] = (bfloat16)wi_slice_c[i]; + } ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 244311f7cc9ee..5e451d45d7727 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index fcfb87545ff6b..f92548d2f7ed8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -50,9 +50,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already
VNNIed. - joint_matrix< - sub_group, half, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -72,7 +71,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index a5b1fddcbbbb8..6111f503007f5 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index e042a8e282b56..b6fe3f0376ffd 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index b44661a1f269f..048aed6341f6c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -78,8 +78,7 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_a sub_a; myparams2::joint_matrix_b< - sub_group, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + sub_group, ext::intel::experimental::matrix::layout::packed> sub_b; myparams2::joint_matrix_accumulator sub_c; @@ -100,7 +99,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5820544722f55..5cca6572cef21 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -69,7 +68,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 628735e8523e2..397fcc9a5aa97 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, uint8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -73,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index e8a4ece3bb00f..4d4ba0ee951e9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,13 +76,15 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wi_data_b = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_a = sycl::ext::oneapi::detail::get_wi_data(sg, sub_a); + auto wi_slice_a = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 8634efde40c74..24f6cce4cc09d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -42,9 +42,8 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. 
- joint_matrix< - sub_group, bfloat16, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, @@ -56,7 +55,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 493ece5bc7d5e..1d82f8833aba6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -53,9 +53,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -76,7 +75,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index d668fa604d395..e400b6694e4a9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -51,9 +51,8 @@ void matrix_multiply(big_matrix &C, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
- joint_matrix< - sub_group, uint8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix sub_c; @@ -74,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 958aba55c3b46..80b67b14a55ac 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, 
i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 567e3293e1862..31f77dc55b16f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double 
{{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 93d431061a3e0..7c24179022d55 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ 
int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index ba4d27a2feb89..0e0b4ce903be2 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, 
<2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) 
joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 2d012581faf8b..575039723d56e 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 
{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index a69246dac7315..8ada375fff395 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, 
sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index c22a50908d1c7..69bc136e79776 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } 
@llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index 4c3c97c77edd8..e5935b8b3af47 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -29,9 +29,8 @@ int main(void) { joint_matrix tA; - joint_matrix< - sub_group, unsigned short, use::b, 16, 16, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix tB; joint_matrix tC; @@ -50,7 +49,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - joint_matrix_mad(sg, tA, tB, tC, tC); + tC = joint_matrix_mad(sg, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index 769efbdb0d959..b141f2176971d 100644 --- 
a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -154,9 +154,8 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { sub_group sg = spmd_item.get_sub_group(); // TK = 32, TN = 16 - joint_matrix< - sub_group, int8_t, use::b, TK, TN, - ext::oneapi::experimental::matrix::layout::ext_intel_packed> + joint_matrix sub_b; joint_matrix_load( @@ -167,7 +166,8 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto wiData = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index eae42973d45c4..2e0e309081464 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,8 +68,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::intel::experimental::matrix::layout::packed> sub_b; joint_matrix sub_c; @@ -90,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 861727c3fe92d..3205e4c346ba6 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,8 +69,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + sycl::ext::intel::experimental::matrix::layout::packed> sub_b; joint_matrix sub_c; @@ -94,9 +93,10 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, sub_c); + auto wi_data_c = + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); for (int i = 0; i < wi_data_c.length(); i++) { wi_data_c[i] *= 2; } diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 959f5b2b30871..f8dcc26ab1b17 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,8 +74,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). 
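The comment above refers to the packed (VNNI) layout used for the B matrices in these tests: a TK x TN tile of B is stored as TK/vnniFactor rows of TN*vnniFactor elements, which is why the load offsets and strides in the surrounding hunks are scaled by 4 for int8 and by 2 for bfloat16. A rough sketch of a packed-B declaration and load, using the ext_intel_packed spelling from the e2e tests; sg, pB, k, TK, TN, N, sg_starty and SG_SZ are assumed to come from the enclosing GEMM kernel:

  // Sketch only: vnniFactor is 4 for int8_t (2 for bfloat16), and the
  // enclosing kernel provides sg, pB, k, TK, TN, N, sg_starty and SG_SZ.
  using namespace sycl::ext::oneapi::experimental::matrix;
  constexpr unsigned vnniFactor = 4;
  joint_matrix<sycl::sub_group, int8_t, use::b, TK, TN, layout::ext_intel_packed>
      sub_b;
  // B is VNNI-packed, so both the element offset and the stride are scaled
  // by vnniFactor relative to a plain row-major TK x TN tile.
  joint_matrix_load(sg, sub_b,
                    pB + (k * TK / vnniFactor) * (N * vnniFactor) +
                        sg_starty / SG_SZ * TN * vnniFactor,
                    N * vnniFactor);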
        joint_matrix
+           sycl::ext::intel::experimental::matrix::layout::packed>
            sub_b;
        joint_matrix sub_c;
@@ -95,7 +94,7 @@ void matrix_multiply(big_matrix &C,
                accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) +
                    sg_starty / SG_SZ * TN * 4,
                N * 4);
-        joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c);
+        sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
      }
      joint_matrix_store(
          sg, sub_c,
diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp
index 64852b4a890cc..d6affb4067003 100644
--- a/sycl/test/matrix/matrix-tf32-test.cpp
+++ b/sycl/test/matrix/matrix-tf32-test.cpp
@@ -87,11 +87,12 @@ void matrix_multiply(big_matrix &C,
          // function will work on truncated floats.
          joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); });
-         auto wi_data_b = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b);
+         auto wi_data_b =
+             sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
          for (int i = 0; i < wi_data_b.length(); i++) {
            wi_data_b[i] = round_to_tf32(wi_data_b[i]);
          }
-         joint_matrix_mad(sg, sub_a, sub_b, sub_c, sub_c);
+         sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
        }
        joint_matrix_store(
            sg, sub_c,
diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp
index 96d2fd3c0c26a..9afc8e1173043 100644
--- a/sycl/test/matrix/query-use.cpp
+++ b/sycl/test/matrix/query-use.cpp
@@ -64,8 +64,7 @@ void query_amx() {
        sub_group sg = spmd_item.get_sub_group();
        myparams2::joint_matrix_a sub_a1;
        myparams2::joint_matrix_b<
-           sub_group,
-           sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed>
+           sub_group, sycl::ext::intel::experimental::matrix::layout::packed>
            sub_b1;
        myparams2::joint_matrix_accumulator sub_c1;
@@ -145,8 +144,7 @@ void query_xmx8() {
        sub_group sg = spmd_item.get_sub_group();
        myparams2::joint_matrix_a sub_a1;
        myparams2::joint_matrix_b<
-           sub_group,
-           sycl::ext::oneapi::experimental::matrix::layout::ext_intel_packed>
+           sub_group, sycl::ext::intel::experimental::matrix::layout::packed>
            sub_b1;
        myparams2::joint_matrix_accumulator sub_c1;

From a09a778416f23bec6e4f4f4db185b5baed4c8079 Mon Sep 17 00:00:00 2001
From: Bing1 Yu
Date: Thu, 21 Sep 2023 15:08:16 +0800
Subject: [PATCH 11/50] test changes: move D in mad

---
 sycl/test-e2e/Matrix/element_wise_ops_impl.hpp | 2 +-
 .../test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp | 4 ++--
 .../Matrix/joint_matrix_bfloat16_32x64_impl.hpp | 2 +-
 .../Matrix/joint_matrix_bfloat16_array_impl.hpp | 2 +-
 ...oint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp | 2 +-
 ...oint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp | 2 +-
 .../Matrix/joint_matrix_colA_rowB_colC_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp | 2 +-
 .../joint_matrix_int8_colmajorA_colmajorB_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp | 2 +-
 .../test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_query_default.cpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp | 2 +-
 .../test-e2e/Matrix/joint_matrix_transposeC_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp | 2 +-
 sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp | 2 +-
 .../cuda/matrix/matrix-nvptx-bfloat16-test.cpp | 12 ++++++------
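The rest of this patch applies one mechanical change to every file in this diffstat: instead of assigning the result of joint_matrix_mad back into the accumulator, the destination matrix is passed explicitly, so the call becomes joint_matrix_mad(sg, D, A, B, C) and nothing is assigned back. A minimal sketch of the two call shapes, with hypothetical matrix declarations that are not taken from any one test:

  // Hypothetical helper showing only the call shape; TM/TN/TK, the element
  // types and the row-major layouts are illustrative assumptions.
  using namespace sycl::ext::oneapi::experimental::matrix;

  template <size_t TM, size_t TN, size_t TK>
  void mad_step(sycl::sub_group sg,
                joint_matrix<sycl::sub_group, float, use::accumulator, TM, TN> &sub_c,
                joint_matrix<sycl::sub_group, sycl::half, use::a, TM, TK, layout::row_major> &sub_a,
                joint_matrix<sycl::sub_group, sycl::half, use::b, TK, TN, layout::row_major> &sub_b) {
    // Old call shape (earlier in this series): the product is returned.
    //   sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
    // New call shape (this patch): the destination is an explicit argument.
    joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
  }

The tests reuse sub_c as both the C input and the D output, which is the pattern every hunk below follows.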
.../cuda/matrix/matrix-nvptx-double-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-half-float-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-half-half-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-int8-test.cpp | 12 ++++++------ .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 ++-- .../cuda/matrix/matrix-nvptx-uint8-test.cpp | 12 ++++++------ .../matrix/matrix_load_store_as.cpp | 2 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/matrix-int8-test.cpp | 2 +- sycl/test/matrix/matrix-tf32-test.cpp | 2 +- 34 files changed, 62 insertions(+), 62 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1206b556339a9..66fb8237c2be3 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_slice_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 8e2865de207b4..a74288c71ae16 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -101,7 +101,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k) * (N) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_slice_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index 469837cd26d41..dbd88e607f243 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, (k * TK / vnniFactor) * (N * vnniFactor) + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 4303239cefe32..754e6429d0d96 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -71,7 +71,7 @@ void matrix_verify_lambda(queue q, joint_matrix_apply(sg, sub_a, lambda); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, @@ -154,7 +154,7 @@ void matrix_verify_op(queue q, big_matrix &C, } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index cc0196660744a..264961a05ad96 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 
2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d6390d8061dcc..d1c9939318551 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM * JM_ARRAY_SZ + TM * i) * K + k * TK, K); - sub_c[i] = joint_matrix_mad(sg, sub_a[i], sub_b, sub_c[i]); + joint_matrix_mad(sg, sub_c[i], sub_a[i], sub_b, sub_c[i]); } } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp index 4847e093127a8..7c07afcb3ecb7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 76ac69da27677..4889e77812d72 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -66,7 +66,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp index 4d61a733e5927..119554c9b23ad 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, accB.template get_multi_ptr() + (k * TK) * (N) + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp index f75da1824d94b..f24f720715788 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp @@ -48,7 +48,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { joint_matrix_load(sg, sub_a, pA + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 5e451d45d7727..219a3976f4c90 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ 
b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -178,7 +178,7 @@ void test(queue &q) { } } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index f92548d2f7ed8..2ac425955a555 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -71,7 +71,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp index 6111f503007f5..d2081f01ec167 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_colmajorA_colmajorB_impl.hpp @@ -64,7 +64,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (sg_starty / SG_SZ * TN) * K + k * TK, K); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp index b6fe3f0376ffd..f4f4d682930a4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -65,7 +65,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK) * N + sg_starty / SG_SZ * TN, N); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 1c1c4f97819bf..008b3531a7ec8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -58,7 +58,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load_checked( sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor, K / vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added joint_matrix_store_checked( diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index 048aed6341f6c..ddc0c14e32de8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -99,7 +99,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 5cca6572cef21..21353754b6580 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, accB.template 
get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 397fcc9a5aa97..80dbbd1afbbc5 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -72,7 +72,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 4d4ba0ee951e9..33ee1d69b4e35 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_slice_a = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 24f6cce4cc09d..56dbf30bed4d4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -55,7 +55,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { joint_matrix_load(sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 1d82f8833aba6..2bda766fc290a 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -75,7 +75,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index e400b6694e4a9..832b28ecc0562 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -73,7 +73,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp index 80b67b14a55ac..309786a38003f 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-bfloat16-test.cpp @@ -57,7 +57,7 @@ int main() { sg, sub_b, 
accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -88,7 +88,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -119,7 +119,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -150,7 +150,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -181,7 +181,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, 
i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -212,7 +212,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp index 31f77dc55b16f..16603407d74b1 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-double-test.cpp @@ -67,7 +67,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), N); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -98,7 +98,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), K); //CHECK-OPAQUE: tail call { double, double } @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64(double {{.*}}, double {{.*}}, double {{.*}}, double {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64.p1(ptr addrspace(1) %{{.*}}, double {{.*}}, double {{.*}}, i32 8) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp index 7c24179022d55..47ddc0fb42f48 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-float-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x 
half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, 
float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %{{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp index 0e0b4ce903be2..0468f592b6427 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-half-half-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, 
accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0(ptr %{{.*}}, <2 x 
half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0(ptr %{{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp index 575039723d56e..858c8625cc6e9 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-int8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, 
accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 8ada375fff395..f47a701fe7bc6 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -88,7 +88,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -137,7 +137,7 @@ int main() { get_wi_data(sg, sub_b)[i] = round_to_tf32(get_wi_data(sg, sub_b)[i]); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), 
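In the tf32 tests, each work-item rounds its slice of A and B to tf32 precision before the multiply, and the mad call is updated to the same destination-argument form. Roughly, with sub_a, sub_b, sub_c, sg and round_to_tf32 as declared by the enclosing test kernel and the get_wi_data spelling taken from the e2e tests:

  // Sketch of the per-work-item tf32 rounding followed by the new mad call;
  // the matrices, sg and round_to_tf32 come from the surrounding kernel.
  joint_matrix_apply(sg, sub_a, [=](float &x) { x = round_to_tf32(x); });
  auto wi_data_b =
      sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
  for (int i = 0; i < wi_data_b.length(); i++)
    wi_data_b[i] = round_to_tf32(wi_data_b[i]);
  joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);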
diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp index 69bc136e79776..c6a1bda15cdcb 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-uint8-test.cpp @@ -56,7 +56,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -87,7 +87,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -118,7 +118,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -149,7 +149,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -180,7 +180,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 
{{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), @@ -211,7 +211,7 @@ int main() { sg, sub_b, accB.template get_multi_ptr(), stride); // CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}) - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); // CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1(ptr addrspace(1) %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 16) joint_matrix_store( sg, sub_c, accD.template get_multi_ptr(), diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index e5935b8b3af47..20689495d6aa8 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -49,7 +49,7 @@ int main(void) { // B should load from global address space // CHECK: %{{.*}} = tail call spir_func noundef target("spirv.JointMatrixINTEL", i16, 16, 16, 2, 3, 1) @_Z[[#]]__spirv_JointMatrixLoadINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, i64 noundef 32, i32 noundef 2, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_load(sg, tB, pB, 32); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_mad(sg, tC, tA, tB, tC); // C should store to global address space // CHECK: tail call spir_func void @_Z[[#]]__spirv_JointMatrixStoreINTEL{{.*}}(ptr addrspace(1) noundef %{{.*}}, target("spirv.JointMatrixINTEL", float, 8, 16, 3, 3, 2) noundef %{{.*}}, i64 noundef 16, i32 noundef 0, i32 noundef 3, i32 noundef 0) #{{.*}} joint_matrix_store(sg, tC, pC, 16, layout::row_major); diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 2e0e309081464..83bfb767e7d79 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -89,7 +89,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, N * 2); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index 3205e4c346ba6..a7f7b4526dc74 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -93,7 +93,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index f8dcc26ab1b17..dd83f6dd6242f 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ 
b/sycl/test/matrix/matrix-int8-test.cpp @@ -94,7 +94,7 @@ void matrix_multiply(big_matrix &C, accB.template get_multi_ptr() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, N * 4); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index d6affb4067003..fb52c722d3b9f 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -92,7 +92,7 @@ void matrix_multiply(big_matrix &C, for (int i = 0; i < wi_data_b.length(); i++) { wi_data_b[i] = round_to_tf32(wi_data_b[i]); } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( sg, sub_c, From 821fa89fc74493b0b984a472c5e3d4851d26dbf7 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 15:18:41 +0800 Subject: [PATCH 12/50] testcase changes: ext_intel_layout --- sycl/test-e2e/Matrix/element_wise_abc_impl.hpp | 2 +- .../element_wise_all_ops_int8_packed_impl.hpp | 10 +++++----- .../element_wise_irreg_sum_rows_impl.hpp | 2 +- sycl/test-e2e/Matrix/element_wise_ops_impl.hpp | 2 +- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 2 +- .../Matrix/get_coord_int8_matB_impl.hpp | 2 +- .../Matrix/joint_matrix_all_sizes_impl.hpp | 2 +- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 18 +++++++++--------- .../joint_matrix_bfloat16_32x64_impl.hpp | 2 +- .../joint_matrix_bfloat16_array_impl.hpp | 2 +- .../Matrix/joint_matrix_bfloat16_impl.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_half_impl.hpp | 2 +- .../Matrix/joint_matrix_out_bounds_impl.hpp | 2 +- .../Matrix/joint_matrix_query_default.cpp | 2 +- .../Matrix/joint_matrix_ss_int8_impl.hpp | 2 +- .../Matrix/joint_matrix_su_int8_impl.hpp | 2 +- .../Matrix/joint_matrix_transposeC_impl.hpp | 2 +- .../Matrix/joint_matrix_us_int8_impl.hpp | 2 +- .../Matrix/joint_matrix_uu_int8_impl.hpp | 2 +- .../matrix/matrix_load_store_as.cpp | 2 +- .../matrix-bfloat16-test-coord-basicB.cpp | 2 +- sycl/test/matrix/matrix-bfloat16-test.cpp | 2 +- sycl/test/matrix/matrix-elemwise-ops.cpp | 2 +- sycl/test/matrix/matrix-int8-test.cpp | 2 +- sycl/test/matrix/query-use.cpp | 4 ++-- 25 files changed, 38 insertions(+), 38 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 8b7ee3af2b9c5..3efc1c547c169 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -56,7 +56,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
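For reference, a minimal standalone sketch of the new joint_matrix_mad call form exercised by the test changes above, where the destination matrix is passed as the first argument instead of being returned. The helper name mad_tile, the int8 element types, the 8x16x32 tile shape, and the multi_ptr parameters are illustrative assumptions and not part of the patches; only the call signatures mirror the updated tests.

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// Illustrative tile shape for an int8 MAD; supported shapes are device specific.
constexpr size_t TM = 8, TN = 16, TK = 32;

// Kernel-body helper: sg is the calling kernel's sub-group, pA/pB/pC are
// multi_ptrs to a row-major A tile, a VNNI-packed B tile, and a row-major C tile.
template <typename PtrA, typename PtrB, typename PtrC>
void mad_tile(sub_group sg, PtrA pA, PtrB pB, PtrC pC, size_t K, size_t N) {
  joint_matrix<sub_group, int8_t, use::a, TM, TK, layout::row_major> sub_a;
  joint_matrix<sub_group, int8_t, use::b, TK, TN, layout::ext_intel_packed> sub_b;
  joint_matrix<sub_group, int32_t, use::accumulator, TM, TN> sub_c;

  joint_matrix_load(sg, sub_a, pA, K);
  joint_matrix_load(sg, sub_b, pB, N * 4); // B is assumed already VNNI-packed
  joint_matrix_fill(sg, sub_c, 0);

  // Old form returned the result:
  //   sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
  // New form writes into an explicit destination passed first, so its element
  // type may differ from the accumulator's:
  joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);

  joint_matrix_store(sg, sub_c, pC, N, layout::row_major);
}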
joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index ce89a04b4168c..f7f8f305ac19b 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -37,7 +37,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); @@ -74,7 +74,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); @@ -111,7 +111,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); @@ -148,7 +148,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 4); @@ -185,7 +185,7 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_fill(sg, sub_b, 5); diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index cfce95cba269f..d412033e684e7 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -45,7 +45,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 66fb8237c2be3..1a92343a27558 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index a74288c71ae16..11db7a04e5295 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -81,7 +81,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. 
joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load( diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index 414254dc669d4..eb1785e3da44d 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -113,7 +113,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index dbd88e607f243..b8b660fb33ee2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -57,7 +57,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, joint_matrix sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index f2d359cfe130b..94ab1d07646e1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -148,35 +148,35 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { ; joint_matrix + layout::ext_intel_packed> tB[NCACHE1 / tN][KCACHE2 / KCACHE1] #ifdef INIT_LIST = { joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), joint_matrix< sub_group, TOperand, use::b, tK, tN, - ext::intel::experimental::matrix::layout::packed>(), + layout::ext_intel_packed>(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index 264961a05ad96..40cc2ad58bdc6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -47,7 +47,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a; // For B, we assume B has been already VNNIed. 
joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp index d1c9939318551..671cf78b660a1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_array_impl.hpp @@ -60,7 +60,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c[JM_ARRAY_SZ]; diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp index 4889e77812d72..ddf731fba24a3 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp @@ -46,7 +46,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp index 2ac425955a555..c7a09229063eb 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp @@ -51,7 +51,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 008b3531a7ec8..51ea6745a8174 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -44,7 +44,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; // bounds-checked load where width and height are added diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index ddc0c14e32de8..ef5f702b4356c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -78,7 +78,7 @@ void matrix_multiply(big_matrix &C, myparams2::joint_matrix_a sub_a; myparams2::joint_matrix_b< - sub_group, ext::intel::experimental::matrix::layout::packed> + sub_group, layout::ext_intel_packed> sub_b; myparams2::joint_matrix_accumulator sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp index 21353754b6580..8135897f893f9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_ss_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp index 80dbbd1afbbc5..2730f0f6184de 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_su_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. 
joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp index 56dbf30bed4d4..02bad19d0d4f4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp @@ -43,7 +43,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { // For B, since current implementation does not support non-packed // layout, users need to specify the packed_b layout. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; joint_matrix_load(sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp index 2bda766fc290a..47c9d82e18479 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_us_int8_impl.hpp @@ -54,7 +54,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp index 832b28ecc0562..c132aeafef9d2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_uu_int8_impl.hpp @@ -52,7 +52,7 @@ void matrix_multiply(big_matrix &C, sub_a; // For B, we assume B has been already VNNIed. joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp index 20689495d6aa8..22c8203444ab4 100644 --- a/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp +++ b/sycl/test/check_device_code/matrix/matrix_load_store_as.cpp @@ -30,7 +30,7 @@ int main(void) { layout::row_major> tA; joint_matrix + layout::ext_intel_packed> tB; joint_matrix tC; diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index b141f2176971d..0823244cd1dc5 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -155,7 +155,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { // TK = 32, TN = 16 joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix_load( diff --git a/sycl/test/matrix/matrix-bfloat16-test.cpp b/sycl/test/matrix/matrix-bfloat16-test.cpp index 83bfb767e7d79..37dc5a1607631 100644 --- a/sycl/test/matrix/matrix-bfloat16-test.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test.cpp @@ -68,7 +68,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index a7f7b4526dc74..d5ec19ea0a096 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -69,7 +69,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). 
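As a side-by-side reference, a minimal sketch of the layout rename these hunks apply when declaring a packed B fragment; the bfloat16 element type, the 32x16 tile shape, and the helper name are illustrative assumptions, while both enumerator spellings are taken from the diffs themselves.

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

constexpr size_t TK = 32, TN = 16; // illustrative VNNI-packed B tile

template <typename PtrB>
void load_packed_b(sub_group sg, PtrB pB, size_t N) {
  // Before this series the packed layout was spelled
  //   ext::intel::experimental::matrix::layout::packed;
  // it is now an enumerator of the common layout enum:
  joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::ext_intel_packed>
      sub_b;
  // Stride reflects the VNNI packing: two bf16 rows are interleaved per
  // physical row, hence the factor of 2.
  joint_matrix_load(sg, sub_b, pB, N * 2);
}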
joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index dd83f6dd6242f..c4ab58c1deaec 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -74,7 +74,7 @@ void matrix_multiply(big_matrix &C, // the packed_b layout. By default, the layout is row_major and size // is (TK, TN). joint_matrix + layout::ext_intel_packed> sub_b; joint_matrix sub_c; diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 9afc8e1173043..05f62c093fb28 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -64,7 +64,7 @@ void query_amx() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; @@ -144,7 +144,7 @@ void query_xmx8() { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_group, layout::ext_intel_packed> sub_b1; myparams2::joint_matrix_accumulator sub_c1; From a3921b52b62f8cf8b8e6da131da7ad8d2264ed88 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 20:03:57 +0800 Subject: [PATCH 13/50] testcase changes: wi_data=>jm_apply --- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 18 ++----- .../Matrix/element_wise_all_ops_half_impl.hpp | 51 +++++++------------ .../Matrix/element_wise_all_ops_impl.hpp | 13 +---- .../Matrix/element_wise_all_ops_int8_impl.hpp | 41 ++++----------- .../element_wise_all_ops_int8_packed_impl.hpp | 41 ++++----------- .../Matrix/element_wise_all_sizes_impl.hpp | 44 ++++++++-------- .../element_wise_irreg_sum_rows_impl.hpp | 2 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 6 +-- .../Matrix/elemwise_irreg_size_ops_bf16.cpp | 6 +-- .../Matrix/get_coord_float_matC_impl.hpp | 12 ++--- .../Matrix/get_coord_int8_matA_impl.hpp | 13 ++--- .../Matrix/get_coord_int8_matB_impl.hpp | 2 +- .../Matrix/joint_matrix_down_convert_impl.hpp | 10 +--- .../Matrix/joint_matrix_tf32_impl.hpp | 10 ++-- .../matrix-bfloat16-test-coord-basicB.cpp | 24 ++++----- sycl/test/matrix/matrix-elemwise-ops.cpp | 6 +-- sycl/test/matrix/matrix-tf32-test.cpp | 7 +-- 17 files changed, 92 insertions(+), 214 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 3efc1c547c169..37c2a93554eec 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -65,33 +65,21 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, accA.template get_multi_ptr() + (sg_startx * TM) * K, K); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] += 1; - } + joint_matrix_apply(sg, sub_a, [](T2 &x) { x += 1; }); joint_matrix_load( sg, sub_b, accB.template get_multi_ptr() + sg_starty / SG_SZ * TN * vnniFactor, N * vnniFactor); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] += 1; - } + joint_matrix_apply(sg, sub_b, [](T2 &x) { x += 1; }); joint_matrix_load( sg, sub_c, accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - auto 
wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] += 1; - } + joint_matrix_apply(sg, sub_c, [](T1 &x) { x += 1; }); }); // parallel for }).wait(); } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp index 540d75c245815..42e1afb4d69f1 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp @@ -41,11 +41,8 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + static_cast(2); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x + static_cast(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -76,11 +73,8 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - static_cast(2); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x - static_cast(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -111,11 +105,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x * static_cast(3.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -146,11 +137,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); - } + joint_matrix_apply(sg, sub_a, + [=](T &x) { x = x / static_cast(2.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -181,30 +169,25 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > static_cast(2.0) || - wi_slice_a[i] >= static_cast(2.0) || - wi_slice_a[i] < static_cast(2.0) || - wi_slice_a[i] <= static_cast(2.0)) { - T val = (wi_slice_a[i] != static_cast(2.0)) - ? wi_slice_a[i] - : static_cast(2.0); + joint_matrix_apply(sg, sub_a, [](T &x) { + if (x) { + if (x > static_cast(2.0) || x >= static_cast(2.0) || + x < static_cast(2.0) || x <= static_cast(2.0)) { + T val = + (x != static_cast(2.0)) ? 
x : static_cast(2.0); val--; val++; - if (wi_slice_a[i] == static_cast(2.0)) { + if (x == static_cast(2.0)) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 8e15488e151a0..b11d3093bf08d 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -63,12 +63,7 @@ void verify_op_a(const T l, const T r, const float ref, OP op) { layout::row_major> sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); - for (int i = 0; i < wi_slice.length(); i++) { - wi_slice[i] = op(wi_slice[i], r); - } - + joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_mat, accessMat.template get_multi_ptr() + @@ -104,11 +99,7 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { joint_matrix sub_mat; joint_matrix_fill(sg, sub_mat, l); - auto wi_slice = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_mat); - for (int i = 0; i < wi_slice.length(); i++) { - wi_slice[i] = op(wi_slice[i], r); - } + joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); }); joint_matrix_store( sg, sub_mat, diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp index 803ebe0addb3a..4a43d39738657 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_impl.hpp @@ -40,11 +40,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x + 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -75,11 +71,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x - 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -110,11 +102,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * 3; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x * 3; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -145,11 +133,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / 2; - } + joint_matrix_apply(sg, sub_a, [](T &x) { x = x / 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template 
get_multi_ptr() + @@ -180,26 +164,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || - wi_slice_a[i] < 2 || wi_slice_a[i] <= 2) { - T val = (wi_slice_a[i] != 2) ? wi_slice_a[i] : 2; + joint_matrix_apply(sg, sub_a, [](T &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + T val = (x != 2) ? x : 2; val--; val++; - if (wi_slice_a[i] == 2) { + if (x == 2) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp index f7f8f305ac19b..e3d21a36bd6e1 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -42,11 +42,7 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] + 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x + 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -79,11 +75,7 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] - 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x - 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -116,11 +108,7 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] * 3; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x * 3; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -153,11 +141,7 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - wi_slice_b[i] = wi_slice_b[i] / 2; - } + joint_matrix_apply(sg, sub_b, [](int8_t &x) { x = x / 2; }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + @@ -190,26 +174,23 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_slice_b.length(); i++) { - if (wi_slice_b[i]) { - if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || - wi_slice_b[i] < 2 || wi_slice_b[i] <= 2) { - T val = (wi_slice_b[i] != 2) ? wi_slice_b[i] : 2; + joint_matrix_apply(sg, sub_b, [](T &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + T val = (x != 2) ? 
x : 2; val--; val++; - if (wi_slice_b[i] == 2) { + if (x == 2) { val -= 2; val *= 3; val /= 2; } else { val += 2; } - wi_slice_b[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_b, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp index c49f9b57e2f32..6e1b6410547ad 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_sizes_impl.hpp @@ -49,30 +49,26 @@ void matrix_verify_add(const T1 val1, const T1 val2, const T1 result) { q.submit([&](handler &cgh) { sycl::accessor accA{bufA, cgh, sycl::read_write}; - cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( - SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a; - - joint_matrix_fill(sg, sub_a, val1); - - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + val2; - } - - ext::intel::experimental::matrix::joint_matrix_store( - sg, sub_a, - accA.template get_multi_ptr() + - (sg_startx * TM) * K + sg_starty / SG_SZ * TK, - K); - }); // parallel for + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; + + joint_matrix_fill(sg, sub_a, val1); + + joint_matrix_apply(sg, sub_a, [=](T &x) { x += val2; }); + + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.template get_multi_ptr() + + (sg_startx * TM) * K + sg_starty / SG_SZ * TK, + K); + }); // parallel for }).wait(); assert_ops_ref(bufA.get_host_access(), result); } diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index d412033e684e7..18761986561ac 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -58,7 +58,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1a92343a27558..1dd9779aa0b56 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -74,11 +74,7 @@ void matrix_multiply(big_matrix &C, N * 4); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] *= 2; - } + joint_matrix_apply(sg, 
sub_c, [](int32_t &x) { x = x * 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp index 11db7a04e5295..cc8722467d262 100644 --- a/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/sycl/test-e2e/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -103,11 +103,7 @@ void matrix_multiply(big_matrix &C, N * 2); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_c[i] += 5.0; - } + joint_matrix_apply(sg, sub_c, [](float &x) { x += 5.0; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp b/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp index dea7601437742..f9d19e914e639 100644 --- a/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_float_matC_impl.hpp @@ -50,15 +50,11 @@ void matrix_sum_rows(big_matrix &C, float *sum_rows) { N, layout::row_major); float sum_local_rows[M] = {0}; - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - - for (int i = 0; i < data.length(); ++i) { - auto dataItem = data[i]; - auto [row, col] = dataItem.get_coord(); - sum_local_rows[row + global_idx * TM] += dataItem; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_c, [&](float &x, size_t row, size_t col) { + sum_local_rows[row + global_idx * TM] += x; + }); for (int i = 0; i < M; i++) { sum_local_rows[i] = reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp index ec21cfa036807..619f97969b29c 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matA_impl.hpp @@ -96,16 +96,11 @@ void matrix_sum_rows(queue q, big_matrix &A, nd_range<2> &r) { K); int32_t sum_local_rows[M] = {0}; - auto data = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - - // each WI calculates local sum of rows - for (int i = 0; i < data.length(); ++i) { - auto data_item = data[i]; - auto [row, col] = data_item.get_coord(); - sum_local_rows[row + global_idx * TM] += data_item; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_a, [&](int8_t &x, size_t row, size_t col) { + sum_local_rows[row + global_idx * TM] += x; + }); for (int i = 0; i < M; ++i) { sum_local_rows[i] = reduce_over_group(sg, sum_local_rows[i], sycl::plus<>()); diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index eb1785e3da44d..26295e8a5050e 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -124,7 +124,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); + sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of cols for (int i = 0; i < wiData.length(); ++i) { diff --git a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp index 6972e3854c8e8..68e9d3c145675 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp +++ 
b/sycl/test-e2e/Matrix/joint_matrix_down_convert_impl.hpp @@ -48,15 +48,7 @@ void matrix_copy(big_matrix &C, big_matrix &A) { accC.template get_multi_ptr() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, layout::row_major); - // This will be replaced by joint_matrix_copy API - // joint_matrix_copy(sg, sub_c, sub_ac); - auto wi_slice_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_c.length(); i++) { - wi_slice_a[i] = (bfloat16)wi_slice_c[i]; - } + joint_matrix_copy(sg, sub_c, sub_a); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + diff --git a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp index 33ee1d69b4e35..607aba535c74f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tf32_impl.hpp @@ -76,15 +76,11 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_data_b.length(); i++) { - wi_data_b[i] = round_to_tf32(wi_data_b[i]); - } + joint_matrix_apply(sg, sub_b, + [=](float x) { x = round_to_tf32(x); }); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_slice_a = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a); + joint_matrix_apply(sg, sub_a, [=](float x) { x *= 2; }); joint_matrix_store( sg, sub_c, diff --git a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp index 0823244cd1dc5..dd715887e728b 100644 --- a/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp +++ b/sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp @@ -166,8 +166,6 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { int32_t sum_local_cols[N] = {0}; // 4 local cols, N total // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row - auto wiData = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); size_t global_index; // Index into the result array that holds the sums. 
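For reference, a minimal sketch of the joint_matrix_copy call that replaces the manual wi_data copy loop in joint_matrix_down_convert_impl.hpp above; the helper name, fill value, and 8x16 tile shape are illustrative assumptions.

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

constexpr size_t TM = 8, TN = 16; // illustrative tile shape

// Kernel-body helper: down-convert a float accumulator tile into a bfloat16
// "a" tile in one call instead of iterating over the work-item's slice.
void down_convert_tile(sub_group sg) {
  joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
  joint_matrix<sub_group, bfloat16, use::a, TM, TN, layout::row_major> sub_a;

  joint_matrix_fill(sg, sub_c, 2.5f);
  // Copies every element of sub_c into sub_a, converting float to bfloat16.
  joint_matrix_copy(sg, sub_c, sub_a);
}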
@@ -175,19 +173,15 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { // Keep track of cols handled in this WI int32_t handled_cols[N] = {-1}; - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - - // Calculation of global index - int sg_idx = (int)global_idy / SG_SZ; - global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; - sum_local_cols[global_index] += wiData[i]; - handled_cols[global_index] = 1; - } - + sycl::ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_b, + [&](int8_t &x, size_t row, + size_t col) { // Calculation of global index + int sg_idx = (int)global_idy / SG_SZ; + global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; + sum_local_cols[global_index] += x; + handled_cols[global_index] = 1; + }); for (int j = 0; j < N; j++) { if (handled_cols[j] == 1) { global_index = j; diff --git a/sycl/test/matrix/matrix-elemwise-ops.cpp b/sycl/test/matrix/matrix-elemwise-ops.cpp index d5ec19ea0a096..9621f570cf461 100644 --- a/sycl/test/matrix/matrix-elemwise-ops.cpp +++ b/sycl/test/matrix/matrix-elemwise-ops.cpp @@ -95,11 +95,7 @@ void matrix_multiply(big_matrix &C, N * 4); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - auto wi_data_c = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c); - for (int i = 0; i < wi_data_c.length(); i++) { - wi_data_c[i] *= 2; - } + joint_matrix_apply(sg, sub_c, [](int32_t &x) { x *= 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + diff --git a/sycl/test/matrix/matrix-tf32-test.cpp b/sycl/test/matrix/matrix-tf32-test.cpp index fb52c722d3b9f..496af7dabd335 100644 --- a/sycl/test/matrix/matrix-tf32-test.cpp +++ b/sycl/test/matrix/matrix-tf32-test.cpp @@ -87,11 +87,8 @@ void matrix_multiply(big_matrix &C, // function will work on truncated floats. joint_matrix_apply(sg, sub_a, [=](float x) { x = round_to_tf32(x); }); - auto wi_data_b = - sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); - for (int i = 0; i < wi_data_b.length(); i++) { - wi_data_b[i] = round_to_tf32(wi_data_b[i]); - } + joint_matrix_apply(sg, sub_b, + [=](float &x) { x = round_to_tf32(x); }); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } joint_matrix_store( From ef1bc6764e2783e5d0bf06b3a09684310352861f Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 21 Sep 2023 20:06:10 +0800 Subject: [PATCH 14/50] lint --- .../test-e2e/Matrix/element_wise_abc_impl.hpp | 3 +- .../element_wise_irreg_sum_rows_impl.hpp | 6 +- .../Matrix/get_coord_int8_matB_impl.hpp | 76 +++++++++---------- .../Matrix/joint_matrix_all_sizes_impl.hpp | 3 +- .../Matrix/joint_matrix_apply_cuda.hpp | 63 ++++++++------- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 40 ++++------ .../Matrix/joint_matrix_query_default.cpp | 4 +- sycl/test/matrix/query-use.cpp | 8 +- 8 files changed, 91 insertions(+), 112 deletions(-) diff --git a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp index 37c2a93554eec..378c46c4b84d5 100644 --- a/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_abc_impl.hpp @@ -55,8 +55,7 @@ void matrix_elem_wise_ops(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. 
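For reference, a minimal sketch of the two joint_matrix_apply forms used in the hunks above: the element-wise overload that replaces the old get_wi_data indexing loops, and the Intel experimental coordinate-aware overload whose row/col parameters are now size_t. The helper name, tile shape, and fill values are illustrative assumptions; the call patterns follow the updated tests.

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

constexpr size_t TM = 8, TN = 16; // illustrative accumulator tile

void apply_tile(sub_group sg) {
  joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
  joint_matrix_fill(sg, sub_c, 1.0f);

  // Element-wise overload: the lambda sees each element owned by this
  // work-item by reference.
  joint_matrix_apply(sg, sub_c, [](float &x) { x *= 2; });

  // Coordinate-aware overload (Intel experimental): the lambda additionally
  // receives the element's (row, col) position, now typed as size_t.
  float sum_local_rows[TM] = {0};
  ext::intel::experimental::matrix::joint_matrix_apply(
      sg, sub_c, [&](float &x, size_t row, size_t col) {
        sum_local_rows[row] += x;
      });
  // sum_local_rows now holds this work-item's partial row sums; combine them
  // with reduce_over_group(sg, ...) as the tests above do.
}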
- joint_matrix + joint_matrix sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp index 18761986561ac..683ad694fe26a 100644 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -44,8 +44,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix + joint_matrix sub_b; joint_matrix_load( @@ -57,8 +56,7 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index 26295e8a5050e..0df698e3ace48 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -103,45 +103,43 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { auto accB = bufB.get_access(cgh); auto v = sum_cols_v.get_access(cgh); - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix - sub_b; - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (sg_startx * (TK / VF) * N) + sg_starty / SG_SZ * TN * VF, - N); - - int32_t sum_local_cols[N] = {0}; - auto wiData = - sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); - - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - size_t global_index = col + global_idy / SG_SZ * TN * VF; - sum_local_cols[global_index] += dataItem; - } - - for (int i = 0; i < N; i++) { - sum_local_cols[i] = - reduce_over_group(sg, sum_local_cols[i], sycl::plus<>()); - if (global_idy % SG_SZ == 0) - atomic_fetch_add(v[i], sum_local_cols[i]); - } - }); // parallel for + cgh.parallel_for(r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size( + SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sycl::sub_group sg = spmd_item.get_sub_group(); + + joint_matrix + sub_b; + + joint_matrix_load(sg, sub_b, + accB.template get_multi_ptr() + + (sg_startx * (TK / VF) * N) + + sg_starty / SG_SZ * TN * VF, + N); + + int32_t sum_local_cols[N] = {0}; + auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); + + // each WI calculates local sum of cols + for (int i = 0; i < wiData.length(); ++i) { + // get the index of the element in the submatrix + auto dataItem = wiData[i]; + auto [row, col] = dataItem.get_coord(); + size_t global_index = col + global_idy / SG_SZ * TN * VF; + sum_local_cols[global_index] 
+= dataItem; + } + + for (int i = 0; i < N; i++) { + sum_local_cols[i] = + reduce_over_group(sg, sum_local_cols[i], sycl::plus<>()); + if (global_idy % SG_SZ == 0) + atomic_fetch_add(v[i], sum_local_cols[i]); + } + }); // parallel for }).wait(); sum_cols_ref(bufB.get_host_access(), sum_cols_v.get_host_access()); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp index b8b660fb33ee2..00149e0b55ce4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_all_sizes_impl.hpp @@ -56,8 +56,7 @@ void matrix_multiply(big_matrix &C, big_matrix &A, sub_group sg = spmd_item.get_sub_group(); joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix sub_b; joint_matrix sub_c; diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 754e6429d0d96..e091442b84cb7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -51,34 +51,34 @@ void matrix_verify_lambda(queue q, q.submit([&](handler &cgh) { accessor accC(bufC, cgh); - cgh.parallel_for>(r, [ - accC, lambda - ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - joint_matrix_apply(sg, sub_a, lambda); - - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for + cgh.parallel_for>( + r, [accC, lambda]( + nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + joint_matrix_apply(sg, sub_a, lambda); + + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accC.template get_multi_ptr() + + (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, + (N * nWGperDim), layout::row_major); + }); // parallel for }); } assert_ref(C.get_data(), ref); @@ -111,8 +111,8 @@ void matrix_verify_op(queue q, big_matrix &C, cgh); cgh.parallel_for>( - r, [ accC, - Op ](nd_item<2> spmd_item)[[sycl::reqd_sub_group_size(SG_SZ)]] { + r, [accC, + Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); @@ -162,8 +162,7 @@ void matrix_verify_op(queue q, big_matrix &C, (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, (N * nWGperDim), layout::row_major); }); // parallel for - }) - 
.wait(); + }).wait(); } assert_ops_ref(C.get_data(), ref); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index 94ab1d07646e1..b9edadad461a6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -153,30 +153,22 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { #ifdef INIT_LIST = { - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), - joint_matrix< - sub_group, TOperand, use::b, tK, tN, - layout::ext_intel_packed>(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), + joint_matrix(), } #endif ; diff --git a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp index ef5f702b4356c..7fc3547b56e31 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_query_default.cpp @@ -77,9 +77,7 @@ void matrix_multiply(big_matrix &C, sycl::sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a; - myparams2::joint_matrix_b< - sub_group, layout::ext_intel_packed> - sub_b; + myparams2::joint_matrix_b sub_b; myparams2::joint_matrix_accumulator sub_c; joint_matrix_load( diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 05f62c093fb28..fce49a83df30b 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -63,9 +63,7 @@ void query_amx() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, layout::ext_intel_packed> - sub_b1; + myparams2::joint_matrix_b sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; @@ -143,9 +141,7 @@ void query_xmx8() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, layout::ext_intel_packed> - sub_b1; + myparams2::joint_matrix_b sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; From 8f2f1971b2af532635546bf10faec7d229ef1cc6 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 16:25:31 +0800 Subject: [PATCH 15/50] handle cuda testcase compfail --- .../sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 12 ++++++------ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 849cb676c6613..94ae318540012 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -413,7 +413,7 @@ void store_layoutT( template void joint_matrix_store_cuda( - const joint_matrix_cuda< 
+ joint_matrix_cuda< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, @@ -482,11 +482,11 @@ void joint_matrix_mad_cuda( joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, - const joint_matrix_cuda &A, - const joint_matrix_cuda &B, - const joint_matrix_cuda< + joint_matrix_cuda &A, + joint_matrix_cuda &B, + joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (M == 16 && N == 16 && K == 16) { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index bf2441ac17c2e..71a50663e9017 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,7 +40,7 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_cuda + mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< From 1411376b42069f36341dcd4930bade17dad45f4e Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 16:36:55 +0800 Subject: [PATCH 16/50] address dounia's comments --- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 6 ++--- .../Matrix/element_wise_irreg_sum_rows.cpp | 26 ------------------- .../Matrix/get_coord_int8_matB_impl.hpp | 15 ++++------- sycl/test/matrix/query-use.cpp | 8 ++++-- 4 files changed, 13 insertions(+), 42 deletions(-) delete mode 100644 sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 71a50663e9017..026729b32e788 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -88,8 +88,7 @@ class wi_data { return jm.cuda_impl.wi_marray.size(); #else throw runtime_error( - "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " - "intel users are expected to use joint_matrix_copy instead.", + "get_wi_data is unavailable, use joint_matrix_copy instead.", PI_ERROR_INVALID_DEVICE); #endif }; @@ -99,8 +98,7 @@ class wi_data { return (jm.cuda_impl.wi_marray[i]); #else throw runtime_error( - "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but " - "intel users are expected to use joint_matrix_copy instead.", + "get_wi_data is unavailable, use joint_matrix_copy instead.", PI_ERROR_INVALID_DEVICE); #endif }; diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp deleted file mode 100644 index 1cb48f1bc4f72..0000000000000 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix - -// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -// RUN: %{run} %t.out - -// This code calculates the sum of rows into a global array of number of rows -// elements. First, partial reduction is computed inside each SG, then atomic -// add is used to reduce between SG leaders - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 16 -constexpr size_t TN = 16; - -#include "element_wise_irreg_sum_rows_impl.hpp" diff --git a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp index 0df698e3ace48..22259eb0072b0 100644 --- a/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_int8_matB_impl.hpp @@ -122,16 +122,11 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { N); int32_t sum_local_cols[N] = {0}; - auto wiData = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); - - // each WI calculates local sum of cols - for (int i = 0; i < wiData.length(); ++i) { - // get the index of the element in the submatrix - auto dataItem = wiData[i]; - auto [row, col] = dataItem.get_coord(); - size_t global_index = col + global_idy / SG_SZ * TN * VF; - sum_local_cols[global_index] += dataItem; - } + ext::intel::experimental::matrix::joint_matrix_apply( + sg, sub_b, [&](int8_t &x, size_t row, size_t col) { + size_t global_index = col + global_idy / SG_SZ * TN * VF; + sum_local_cols[global_index] += x; + }); for (int i = 0; i < N; i++) { sum_local_cols[i] = diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index fce49a83df30b..9afc8e1173043 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -63,7 +63,9 @@ void query_amx() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b sub_b1; + myparams2::joint_matrix_b< + sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; @@ -141,7 +143,9 @@ void query_xmx8() { [msize, ksize, nsize](nd_item<2> spmd_item) { sub_group sg = spmd_item.get_sub_group(); myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b sub_b1; + myparams2::joint_matrix_b< + sub_group, sycl::ext::intel::experimental::matrix::layout::packed> + sub_b1; myparams2::joint_matrix_accumulator sub_c1; joint_matrix sub_a; From 95df3b18379e8993958bc8b6b800e345ae9099d3 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Fri, 22 Sep 2023 16:40:38 +0800 Subject: [PATCH 17/50] lint --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 026729b32e788..1d7ef2dae065c 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,7 +40,8 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - mutable sycl::ext::oneapi::detail::joint_matrix_cuda + mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< From fb1afdcd5f0b26f073a2bd1a09ea3ab3eb64d75d Mon Sep 17 00:00:00 2001 From: Bing1 Yu 
Date: Fri, 22 Sep 2023 17:01:59 +0800 Subject: [PATCH 18/50] rm sycl/test/matrix/query-use.cpp --- sycl/test/matrix/query-use.cpp | 162 --------------------------------- 1 file changed, 162 deletions(-) delete mode 100644 sycl/test/matrix/query-use.cpp diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp deleted file mode 100644 index 9afc8e1173043..0000000000000 --- a/sycl/test/matrix/query-use.cpp +++ /dev/null @@ -1,162 +0,0 @@ -// RUN: %clangxx -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -fsycl -o query-use %s -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -void query_amx() { - - // generates combination assert - // using myparams = tpu_params; - - // generates types assert - // using myparams2 = tpu_params; - - // tells whether a combination is valid or not, if valid, those will be set as - // default - using myparams = tpu_params; - - size_t dmsize = myparams::M; - size_t dnsize = myparams::N; - size_t dksize = myparams::K; - std::cout << "sizes of AMX tpu_params chosen by the user are: M " << dmsize - << " N " << dnsize << " K " << dksize << std::endl; - - // Sizes-only query: types are given, generate default sizes - using myparams2 = tpu_params; - myparams2 p; - dmsize = myparams2::M; - dnsize = myparams2::N; - dksize = myparams2::K; - std::cout << "default AMX sizes tpu_params are: M " << dmsize << " N " - << dnsize << " K " << dksize << "\n AMX int8 num combinations is " - << p.num_combinations << std::endl; - - // general query: types are not given - tpu_params myparams3; - - if (myparams3.num_scopes > 0) - if (myparams3.scopes[0] == scope_t::sub_group) - std::cout << "There are " << myparams3.num_scopes - << " Scopes that are supported by AMX implementation and " - "subgroup is one of them " - << std::endl; - - std::cout << "AMX query num combinations: " << myparams3.num_combinations - << std::endl; - - if (myparams3.combinations[0].msize != 0) // this is a max params hardware - return; - constexpr int msize = myparams3.combinations[0].max_msize; - constexpr int nsize = myparams3.combinations[0].max_nsize; - constexpr int ksize = myparams3.combinations[0].max_ksize; - std::cout << "AMX query sizes are: M " << msize << " N " << nsize << " K " - << ksize << std::endl; - - size_t NDRangeM = 1024 / msize; - size_t NDRangeN = 1024 / nsize; - queue q; - q.submit([&](handler &cgh) { - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [msize, ksize, nsize](nd_item<2> spmd_item) { - sub_group sg = spmd_item.get_sub_group(); - myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> - sub_b1; - myparams2::joint_matrix_accumulator sub_c1; - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - }); - }); -} - -void query_xmx8() { - - // generates combination assert - // using myparams = tpu_params; - - // generate combination of type assert - // using myparams = tpu_params; - - // tells whether a combination is valid or not, if valid, those will be set as - // default - using myparams = tpu_params; - - size_t dmsize = myparams::M; - size_t dnsize = myparams::N; - size_t dksize = myparams::K; - std::cout << "sizes of XMX8 tpu_params chosen by the user are: M " << dmsize - << " N " << dnsize << " K " << dksize << std::endl; - - // sizes-only query: types are given, generate default sizes - using myparams2 = tpu_params; - myparams2 p; - dmsize = myparams2::M; - dnsize = myparams2::N; - dksize = myparams2::K; - 
std::cout << "Default XMX8 sizes are: M " << dmsize << " N " << dnsize - << " K " << dksize << "\n XMX8 int8 num combinations is " - << p.num_combinations << std::endl; - - dmsize = myparams2::combinations[0].msize; - dnsize = myparams2::combinations[0].nsize; - dksize = myparams2::combinations[0].ksize; - std::cout << "one of XMX8 combination sizes is: M " << dmsize << " N " - << dnsize << " K " << dksize << std::endl; - - // general query: types are not given - tpu_params myparams3; - - if (myparams3.num_scopes > 0) - if (myparams3.scopes[0] == scope_t::sub_group) - std::cout << "There are " << myparams3.num_scopes - << " Scopes that are supported by XMX8 implementation and " - "subgroup is one of them " - << std::endl; - - std::cout << "XMX8 query num combinations: " << myparams3.num_combinations - << std::endl; - - if (myparams3.combinations[0].msize == 0) // this is not a max params hardware - return; - constexpr int msize = myparams3.combinations[0].msize; - constexpr int nsize = myparams3.combinations[0].nsize; - constexpr int ksize = myparams3.combinations[0].ksize; - std::cout << "XMX8 query sizes are: M " << msize << " N " << nsize << " K " - << ksize << std::endl; - std::cout << "XMX8 query max sizes are: M " - << myparams3.combinations[0].max_msize << " N " - << myparams3.combinations[0].max_nsize << " K " - << myparams3.combinations[0].max_ksize << std::endl; - - size_t NDRangeM = 1024 / msize; - size_t NDRangeN = 1024 / nsize; - queue q; - q.submit([&](handler &cgh) { - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN}, {1, 1}), - [msize, ksize, nsize](nd_item<2> spmd_item) { - sub_group sg = spmd_item.get_sub_group(); - myparams2::joint_matrix_a sub_a1; - myparams2::joint_matrix_b< - sub_group, sycl::ext::intel::experimental::matrix::layout::packed> - sub_b1; - myparams2::joint_matrix_accumulator sub_c1; - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - }); - }); -} - -int main() { - query_amx(); - query_xmx8(); - return 0; -} From 11df5313a65eda4d46cba2c4fff052dcdda7d4be Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Mon, 25 Sep 2023 16:22:14 +0800 Subject: [PATCH 19/50] fix x jm_mad in joint_matrix_bf16_fill_k_cache_impl.hpp --- sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index b9edadad461a6..7efdb03b25a8c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -220,8 +220,8 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) { for (unsigned int n = 0; n < NCACHE1 / tN; n++) { #endif - tC[m][n] = - joint_matrix_mad(sg, tA[m][k1], tB[n][k1], tC[m][n]); + joint_matrix_mad(sg, tC[m][n], tA[m][k1], tB[n][k1], + tC[m][n]); #ifdef MANUAL_UNROLL }); // n }); // m From a82110767505718b630b3de1c33171bd4ef581ec Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 01:31:20 +0800 Subject: [PATCH 20/50] address comments --- .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 10 ++-- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 14 +---- .../Matrix/element_wise_all_ops_tf32_impl.hpp | 59 +++++++------------ 3 files changed, 29 insertions(+), 54 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index 98836da6cc7c3..b852e3f1ff3f5 100644 --- 
a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -522,12 +522,13 @@ joint_matrix_store(Group, } template + sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows, + size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout, + typename F> inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( Group sg, - sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, + sycl::ext::oneapi::experimental::matrix::joint_matrix &jm, F &&lambda) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) @@ -554,7 +555,6 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( throw runtime_error("joint matrix is not supported on host device.", PI_ERROR_INVALID_DEVICE); #endif - return; } } // namespace intel::experimental::matrix diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 431ddc4dd4d82..1389076c52e97 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -87,21 +87,12 @@ class wi_data { size_t length() { #if defined(__NVPTX__) return jm.cuda_impl.wi_marray.size(); -#else - throw runtime_error( - "get_wi_data is unavailable, use joint_matrix_copy instead.", - PI_ERROR_INVALID_DEVICE); #endif }; decltype(auto) operator[](size_t i) { #if defined(__NVPTX__) return (jm.cuda_impl.wi_marray[i]); -#else - std::ignore = i; - throw runtime_error( - "get_wi_data is unavailable, use joint_matrix_copy instead.", - PI_ERROR_INVALID_DEVICE); #endif }; }; @@ -129,9 +120,8 @@ template &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] + 2; - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x + round_to_tf32(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, @@ -77,11 +74,9 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] - round_to_tf32(2); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x - round_to_tf32(2); }); + ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -111,11 +106,8 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, joint_matrix sub_a; joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] * round_to_tf32(3.0); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x * round_to_tf32(3.0); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -146,11 +138,8 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(4.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - wi_slice_a[i] = wi_slice_a[i] / round_to_tf32(2.0); - } + joint_matrix_apply(sg, sub_a, + [&](float &x) { x = x / round_to_tf32(2); }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + @@ -180,27 +169,23 @@ void 
matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, joint_matrix_fill(sg, sub_a, round_to_tf32(5.0)); - auto wi_slice_a = - ext::intel::experimental::matrix::get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 || - wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) { - Ts val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i] : 2.0; - val = val - static_cast(1); - val = val + static_cast(1); - if (wi_slice_a[i] == 2.0) { - val = val - static_cast(2); - val = val * static_cast(3); - val = val / static_cast(2); - + joint_matrix_apply(sg, sub_a, [&](float &x) { + if (x) { + if (x > 2 || x >= 2 || x < 2 || x <= 2) { + float val = (x != 2) ? x : 2; + val--; + val++; + if (x == 2) { + val -= 2; + val *= 3; + val /= 2; } else { - val = val + static_cast(2); + val += 2; } - wi_slice_a[i] = val; + x = val; } } - } + }); ext::intel::experimental::matrix::joint_matrix_store( sg, sub_a, accA.template get_multi_ptr() + From 1d091de81a2cbacce18b51edbac3b656e2fe8798 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 02:51:16 +0800 Subject: [PATCH 21/50] rm element_wise_irreg_sum_rows_impl.hpp --- .../element_wise_irreg_sum_rows_impl.hpp | 105 ------------------ .../element_wise_irreg_sum_rows_impl.hpp | 105 ------------------ 2 files changed, 210 deletions(-) delete mode 100644 sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp delete mode 100644 sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp deleted file mode 100644 index 6a18fe3650f2c..0000000000000 --- a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp +++ /dev/null @@ -1,105 +0,0 @@ -#define TN SG_SZ -#define TK 32 - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void sum_rows_ref(host_accessor B, - host_accessor sum_rows) { - int sum_rows_ref[M] = {0}; - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - sum_rows_ref[i] += B[i][j]; - } - auto diff = sum_rows[i] - sum_rows_ref[i]; - assert(std::fabs(static_cast(diff)) <= - std::numeric_limits::epsilon()); - } -} - -template -void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { - buffer bufB(B.get_data(), range<2>(M, N)); - // size of vector is known because SG size of set by the user in this case - int sum_rows[M] = {0}; - buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows - q.submit([&](handler &cgh) { - auto accB = bufB.get_access(cgh); - - auto v = sum_rows_v.get_access(cgh); - - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix sub_b(sg); - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, - N, matrix_layout::packed_b); - // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b - // (tK/4) - int32_t sum_local_rows[M] = {0}; // 8 local rows, M total - // sub_b has 32x8 elements, 32 elements per WI, 4 per 
WI per row - auto data = sub_b.get_wi_data(); - - // each WI calculates local sum of rows - for (int row = 0; row < TK / 4; row++) { // there are 8 rows - for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row - // i*SG_SIZE index is found based on the round robin - // distribution we are using in the implementation - sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; - } - sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( - sg, sum_local_rows[row + global_idx * (TK / 4)], - sycl::plus<>()); - - // only Groups leader perform the global reduction - if (global_idy % SG_SZ == 0) { - atomic_fetch_add(v[row + global_idx * (TK / 4)], - sum_local_rows[row + global_idx * (TK / 4)]); - } - } - }); // parallel for - }).wait(); - sum_rows_ref(bufB.get_host_access(read_only), - sum_rows_v.get_host_access(read_only)); -} - -static constexpr size_t MATRIX_K = TK / 4 * 2; -static constexpr size_t MATRIX_N = TN * 4 * 2; -int8_t B[MATRIX_K][MATRIX_N]; - -int main() { - big_matrix MB((int8_t *)&B); - - size_t NDRangeK = MATRIX_K / (TK / 4); - size_t NDRangeN = (MATRIX_N / 4) / TN; - queue q; - nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); - - for (int i = 0; i < MATRIX_K; i++) { - for (int j = 0; j < MATRIX_N; j++) { - B[i][j] = i; - } - } - - matrix_sum_rows(q, MB, r); - - return 0; -} diff --git a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp deleted file mode 100644 index 683ad694fe26a..0000000000000 --- a/sycl/test-e2e/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ /dev/null @@ -1,105 +0,0 @@ -#define TK 32 - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void sum_rows_ref(host_accessor B, - host_accessor sum_rows) { - int sum_rows_ref[M] = {0}; - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - sum_rows_ref[i] += B[i][j]; - } - auto diff = sum_rows[i] - sum_rows_ref[i]; - assert(std::fabs(static_cast(diff)) <= - std::numeric_limits::epsilon()); - } -} - -template -void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { - buffer bufB(B.get_data(), range<2>(M, N)); - // size of vector is known because SG size of set by the user in this case - int sum_rows[M] = {0}; - buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows - q.submit([&](handler &cgh) { - auto accB = bufB.get_access(cgh); - - auto v = sum_rows_v.get_access(cgh); - - cgh.parallel_for( - r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - sycl::sub_group sg = spmd_item.get_sub_group(); - - joint_matrix - sub_b; - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr() + - (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, - N); - // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b - // (tK/4) - int32_t sum_local_rows[M] = {0}; // 8 local rows, M total - // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = sycl::ext::oneapi::detail::get_wi_data(sg, sub_b); - - // each WI calculates local sum of rows - for (int row = 0; row < TK / 4; row++) { // there are 8 rows - for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 
per row - // i*SG_SIZE index is found based on the round robin - // distribution we are using in the implementation - sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; - } - sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( - sg, sum_local_rows[row + global_idx * (TK / 4)], - sycl::plus<>()); - - // only Groups leader perform the global reduction - if (global_idy % SG_SZ == 0) { - atomic_fetch_add(v[row + global_idx * (TK / 4)], - sum_local_rows[row + global_idx * (TK / 4)]); - } - } - }); // parallel for - }).wait(); - sum_rows_ref(bufB.get_host_access(read_only), - sum_rows_v.get_host_access(read_only)); -} - -static constexpr size_t MATRIX_K = TK / 4 * 2; -static constexpr size_t MATRIX_N = TN * 4 * 2; -int8_t B[MATRIX_K][MATRIX_N]; - -int main() { - big_matrix MB((int8_t *)&B); - - size_t NDRangeK = MATRIX_K / (TK / 4); - size_t NDRangeN = (MATRIX_N / 4) / TN; - queue q; - nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); - - for (int i = 0; i < MATRIX_K; i++) { - for (int j = 0; j < MATRIX_N; j++) { - B[i][j] = i; - } - } - - matrix_sum_rows(q, MB, r); - - return 0; -} From 1e20968f7752069c7e2e370f009e8a0499e5b1f1 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 02:54:50 +0800 Subject: [PATCH 22/50] small fix --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 1389076c52e97..327e1e326f108 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -93,6 +93,8 @@ class wi_data { decltype(auto) operator[](size_t i) { #if defined(__NVPTX__) return (jm.cuda_impl.wi_marray[i]); +#else + std::ignore = i; #endif }; }; From 1fe7fcdd13663619a0e7e3a911b1ea6bb72139d2 Mon Sep 17 00:00:00 2001 From: Bing1 Yu Date: Thu, 12 Oct 2023 03:06:42 +0800 Subject: [PATCH 23/50] small fix --- .../element_wise_irreg_sum_rows_impl.hpp | 105 ++++++++++++++++++ .../XMX8/element_wise_irreg_sum_rows.cpp | 26 ----- 2 files changed, 105 insertions(+), 26 deletions(-) create mode 100644 sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp delete mode 100644 sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp diff --git a/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp new file mode 100644 index 0000000000000..6a18fe3650f2c --- /dev/null +++ b/sycl/test-e2e/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp @@ -0,0 +1,105 @@ +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void sum_rows_ref(host_accessor B, + host_accessor sum_rows) { + int sum_rows_ref[M] = {0}; + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + sum_rows_ref[i] += B[i][j]; + } + auto diff = sum_rows[i] - sum_rows_ref[i]; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} + +template +void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { + buffer bufB(B.get_data(), range<2>(M, N)); + // size of vector is known because SG size of set by the user in this case + int sum_rows[M] = {0}; + buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows + q.submit([&](handler &cgh) { + auto accB = bufB.get_access(cgh); + 
+ auto v = sum_rows_v.get_access(cgh); + + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sycl::sub_group sg = spmd_item.get_sub_group(); + + joint_matrix sub_b(sg); + + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + N, matrix_layout::packed_b); + // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b + // (tK/4) + int32_t sum_local_rows[M] = {0}; // 8 local rows, M total + // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row + auto data = sub_b.get_wi_data(); + + // each WI calculates local sum of rows + for (int row = 0; row < TK / 4; row++) { // there are 8 rows + for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row + // i*SG_SIZE index is found based on the round robin + // distribution we are using in the implementation + sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; + } + sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( + sg, sum_local_rows[row + global_idx * (TK / 4)], + sycl::plus<>()); + + // only Groups leader perform the global reduction + if (global_idy % SG_SZ == 0) { + atomic_fetch_add(v[row + global_idx * (TK / 4)], + sum_local_rows[row + global_idx * (TK / 4)]); + } + } + }); // parallel for + }).wait(); + sum_rows_ref(bufB.get_host_access(read_only), + sum_rows_v.get_host_access(read_only)); +} + +static constexpr size_t MATRIX_K = TK / 4 * 2; +static constexpr size_t MATRIX_N = TN * 4 * 2; +int8_t B[MATRIX_K][MATRIX_N]; + +int main() { + big_matrix MB((int8_t *)&B); + + size_t NDRangeK = MATRIX_K / (TK / 4); + size_t NDRangeN = (MATRIX_N / 4) / TN; + queue q; + nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + for (int i = 0; i < MATRIX_K; i++) { + for (int j = 0; j < MATRIX_N; j++) { + B[i][j] = i; + } + } + + matrix_sum_rows(q, MB, r); + + return 0; +} diff --git a/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp b/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp deleted file mode 100644 index 6559f4c93248d..0000000000000 --- a/sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix-xmx8 - -// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -// RUN: %{run} %t.out - -// this code calculates the sum of rows into a global array of number of rows -// elements. 
First, partial reduction is computed inside each SG, then atomic -// add is used to reduce between SG leaders - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 8 -constexpr size_t TN = 8; - -#include "../element_wise_irreg_sum_rows_impl.hpp" From b5c091140e81ff17f2a5ee41b8024df5033e55ab Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Sun, 1 Oct 2023 20:32:03 +0100 Subject: [PATCH 24/50] * Support one block AMD matrix core instructions * Add HIP matrix core tests. * Add `gpu-amd-gfx90a` as a feature in lig.cfg.py. Add HIP matrix core support into joint_matrix documentation. --- .../sycl_ext_oneapi_matrix.asciidoc | 38 +- sycl/include/sycl/detail/defines.hpp | 5 +- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 389 ++++++++++++++++++ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 35 ++ .../Matrix/joint_matrix_hip_apply.cpp | 14 + .../Matrix/joint_matrix_hip_apply.hpp | 102 +++++ .../test-e2e/Matrix/joint_matrix_hip_fill.cpp | 14 + .../test-e2e/Matrix/joint_matrix_hip_fill.hpp | 88 ++++ .../Matrix/joint_matrix_hip_half_apply.cpp | 12 + .../Matrix/joint_matrix_hip_half_fill.cpp | 12 + .../Matrix/joint_matrix_hip_half_mfma.cpp | 15 + .../test-e2e/Matrix/joint_matrix_hip_mfma.cpp | 20 + .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 116 ++++++ sycl/test-e2e/lit.cfg.py | 2 + .../matrix/matrix-hip-bfloat16-float-test.cpp | 90 ++++ .../matrix/matrix-hip-double-double-test.cpp | 62 +++ .../hip/matrix/matrix-hip-half-float-test.cpp | 89 ++++ .../hip/matrix/matrix-hip-int8-int32-test.cpp | 89 ++++ 18 files changed, 1186 insertions(+), 6 deletions(-) create mode 100644 sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp create mode 100644 sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp create mode 100644 sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp create mode 100644 sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp create mode 100644 sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index 6a0983b9cf35c..c6ee448d0d5a4 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -50,7 +50,7 @@ specification.* This extension is currently implemented in {dpcpp} only for devices that contain a matrix hardware, specifically Intel(R) Advanced Matrix Extensions (Intel(R) AMX), Intel(R) Xe Matrix Extensions (Intel(R) -XMX) and Nvidia(R) Tensor Cores. +XMX), Nvidia(R) Tensor Cores and AMD Matrix Cores. 
The `joint_matrix` type and the `joint_matrix_mad` function are optional kernel features as defined in section 5.7 of the core SYCL @@ -67,8 +67,8 @@ implementation throws a synchronous exception with the == Overview Joint matrix is a SYCL extension for matrix hardware programming. It -unifies targets like Intel AMX in CPUs, Intel XMX in Intel GPUs and -Nvidia Tensor Cores. This provides a portable and performant API for +unifies targets like Intel AMX in CPUs, Intel XMX in Intel GPUs, +Nvidia Tensor Cores and AMD Matrix Cores. This provides a portable and performant API for users who want to build their own neural networks applications, perform custom optimizations, or experiment with new operations in a timely and performing manner. @@ -922,7 +922,7 @@ matrix. Also, the type of the C matrix must be the same as the type of the D matrix. IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target -arch backend flag, `-Xsycl-target-backend --cuda-gpu-arch=sm_xx`, must +arch backend flag, `-fsycl-targets=nvidia_gpu_sm_xx`, must be used, where `sm_xx` must be a Compute Capability that is equal to or greater than the appropriate Minimum Compute Capability. When an executable has been compiled for `sm_xx`, if the executable is run on @@ -964,6 +964,35 @@ multiple of 4 when `T` is `float`; where `T` is the type of the `joint_matrix` elements. When `T` is not `half` or `float` there are no restrictions to `stride`. +==== AMD Matrix Cores Supported Combinations +The complete set of matrix data types and dimensions that are supported by +the `ext_oneapi_hip` backend are represented in the following +table. In this architecture's implementation, A and B matrices must have the same type. +Similarly, C and D matrices must share the same type. + +IMPORTANT: Currently, only single-block AMD Matrix Core instructions on the +GFX90A architecture (MI200, MI210, MI250 and MI250X GPUs) are supported. +When compiling for the `ext_oneapi_hip` backend the target arch backend flag, + `-fsycl-targets=amd_gpu_gfx90a`, must +be used. An attempt to run the compiled code on an unsupported architecture will throw an error.
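For illustration, a minimal sketch of a kernel using the 16x16x16 bfloat16/fp32 shape from the table below on the `ext_oneapi_hip` backend; it mirrors the e2e tests added later in this series rather than being part of the patch itself. The build is assumed to be along the lines of `clang++ -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4`, as in the test RUN lines; names such as `out` are placeholders.

[source,c++]
----
#include <sycl/sycl.hpp>
#include <cassert>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

int main() {
  // One of the supported bf16 -> fp32 shapes from the table below.
  constexpr size_t M = 16, N = 16, K = 16;
  float out[M * N] = {};
  queue q;
  {
    buffer bufD{out, range{M * N}};
    q.submit([&](handler &cgh) {
       accessor accD{bufD, cgh, write_only};
       // A single work-group of 4 x 16 = 64 work-items, i.e. one gfx90a
       // wavefront, matching the HIP e2e tests in this series.
       cgh.parallel_for(nd_range<2>{{4, 16}, {4, 16}}, [=](nd_item<2> it) {
         auto sg = it.get_sub_group();
         joint_matrix<sub_group, bfloat16, use::a, M, K, layout::row_major> sub_a{};
         joint_matrix<sub_group, bfloat16, use::b, K, N, layout::row_major> sub_b{};
         joint_matrix<sub_group, float, use::accumulator, M, N> sub_c{};

         joint_matrix_fill(sg, sub_a, 1);
         joint_matrix_fill(sg, sub_b, 2);
         joint_matrix_fill(sg, sub_c, 0);

         // D = A * B + C; the accumulator is reused as the result here.
         sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);

         joint_matrix_store(sg, sub_c,
                            accD.get_multi_ptr<access::decorated::no>(),
                            N, layout::row_major);
       });
     }).wait();
  }
  // Every element of the result should be 1 * 2 * K = 32.
  assert(out[0] == 32.f);
  return 0;
}
----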
+ + +[frame="none",options="header"] +|====================== +| A and B type | C and D type | M | N | K +.2+| `matrix_type::fp16` .2+| `matrix_type::fp32` +|32 |32 |8 +|16 |16 |16 +.2+| `matrix_type::int8` .2+| `matrix_type::int32` +|32 |32 |8 +|16 |16 |16 +.2+|`matrix_type::bf16` .2+|`matrix_type::fp32` +|32 |32 |8 +|16 |16 |16 +.1+|`matrix_type::fp64` .1+| `matrix_type::fp64` +|16 |16 |4 +|====================== + === Revision History [frame="none",options="header"] @@ -983,4 +1012,5 @@ the Intel-specifics to a separate extension document type, runtime query, and supported combinations appendix for Intel AMX and Intel XMX |7 |2023-04-11 |Jack Kirk |Add Nvidia Tensor Cores supported combinations +|8 |2023-10-05 |Mahmoud Moadeli |Add AMD Matrix Core supported combinations |====================== diff --git a/sycl/include/sycl/detail/defines.hpp b/sycl/include/sycl/detail/defines.hpp index 5d44727d71fb1..ab90fd83331b7 100644 --- a/sycl/include/sycl/detail/defines.hpp +++ b/sycl/include/sycl/detail/defines.hpp @@ -40,8 +40,9 @@ #endif // joint matrix should only be included by default for SPIR or NVPTX backends -#if defined __SPIR__ || defined __NVPTX__ || !defined __SYCL_DEVICE_ONLY__ +#if defined __SPIR__ || defined __NVPTX__ || !defined __SYCL_DEVICE_ONLY__ || \ + defined __gfx90a__ #ifndef SYCL_EXT_ONEAPI_MATRIX_VERSION #define SYCL_EXT_ONEAPI_MATRIX_VERSION 4 #endif // SYCL_EXT_ONEAPI_MATRIX_VERSION -#endif // __SPIR__ || __NVPTX__ || !__SYCL_DEVICE_ONLY +#endif // __SPIR__ || __NVPTX__ || !__SYCL_DEVICE_ONLY || __gfx90a__ diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp new file mode 100644 index 0000000000000..2a4bd38385417 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -0,0 +1,389 @@ + +//===-------- matrix-hip.hpp - matrix ext impl ---*- C++ -*-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===-------------------------------------------------------------------=== // + +#pragma once +#include "matrix-unified-utils.hpp" +#include + +#if defined(__gfx90a__) +#define __HIP_PLATFORM_AMD_MFMA__ +#endif + +namespace sycl { +inline namespace _V1 { +namespace ext { +namespace oneapi { +namespace experimental { +namespace matrix {} // namespace matrix +} // namespace experimental + +using matrix_layout = sycl::ext::oneapi::experimental::matrix::layout; +using matrix_use = sycl::ext::oneapi::experimental::matrix::use; + +namespace detail { + +template +struct joint_matrix_hip; + +#if defined(__SYCL_DEVICE_ONLY__) && defined(__HIP_PLATFORM_AMD_MFMA__) + +template struct to_hip_type { + using type = T; +}; + +template <> struct to_hip_type { + using type = __bf16; +}; + +template <> struct to_hip_type { + using type = __fp16; +}; + +template <> struct to_hip_type { + using type = int32_t; +}; + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE) \ + template \ + struct joint_matrix_hip< \ + TYPE, matrix_use::USE, M, N, Layout, \ + typename std::enable_if_t> { \ + using ext_array_t = __attribute__(( \ + __vector_size__(SIZE * sizeof(typename to_hip_type::type)))) \ + typename to_hip_type::type; \ + ext_array_t data = {0}; \ + }; + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 16, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 32, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, b, 8, 32, 4) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 16, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 16, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, a, 32, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 8, 32, 4) + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 16, 4, 1) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 16, 1) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR + +#define __SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(USE, M, N, SIZE) \ + template \ + struct joint_matrix_hip< \ + int8_t, matrix_use::USE, M, N, Layout, \ + typename std::enable_if_t> { \ + int8_t data[SIZE]; \ + }; + +__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(a, 32, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(b, 8, 32, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(a, 16, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(b, 16, 16, 4) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR + +#define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N) \ + template <> \ + struct joint_matrix_hip { \ + using ext_array_t = \ + __attribute__((__vector_size__((M * N) / 64 * sizeof(TYPE)))) TYPE; \ + ext_array_t data = {0}; \ + }; + +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 32, 32) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(double, 16, 16) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 32, 32) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC + +template +void load_accumulator_layoutT(joint_matrix_hip &res, + multi_ptr src, + size_t stride, Group &sg) { + const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] + + sg.get_local_linear_id(); + + if constexpr (std::is_same_v) { + const auto thread_x = idx % N; + const auto thread_y = idx / N; + + if constexpr (Layout == matrix_layout::row_major) { + for (int i = 0; i < 4; ++i) { + const int s_idx = thread_x + i * 4 * stride + thread_y * stride; + res.data[i] = src[s_idx]; + } + } else { + for (int i = 0; i < 4; ++i) { + 
const int s_idx = i * 4 + thread_x * stride + thread_y; + res.data[i] = src[s_idx]; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (M == 16 && N == 16) { + const auto thread_x = idx % N; + const auto thread_y = idx / N; + + if constexpr (Layout == matrix_layout::row_major) { + for (int i = 0; i < 4; ++i) { + const int s_idx = thread_x + i * stride + thread_y * 4 * stride; + res.data[i] = src[s_idx]; + } + } else { + for (int i = 0; i < 4; ++i) { + const int s_idx = i + thread_x * stride + thread_y * 4; + res.data[i] = src[s_idx]; + } + } + } else if constexpr (M == 32 && N == 32) { + const auto thread_x = idx % N; + const auto thread_y = idx / N; + + if constexpr (Layout == matrix_layout::row_major) { + for (int j = 0; j < 4; ++j) { + for (int i = 0; i < 4; ++i) { + const int s_idx = + thread_x + i * stride + thread_y * 4 * stride + j * 8 * N; + res.data[i + 4 * j] = src[s_idx]; + } + } + } else { + for (int j = 0; j < 4; ++j) { + for (int i = 0; i < 4; ++i) { + const int s_idx = i + thread_x * stride + thread_y * 4 + j * 8; + res.data[i + 4 * j] = src[s_idx]; + } + } + } + } + } +} + +template < + typename Group, typename S, typename T, size_t M, size_t N, + access::address_space Space, access::decorated IsDecorated, + typename = std::enable_if_t>>> +void load_accumulator_hip(joint_matrix_hip &res, + multi_ptr src, size_t stride, + matrix_layout layout, Group &sg) { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "Unsupported matrix type!"); + + if (layout == matrix_layout::row_major) + load_accumulator_layoutT(res, src, stride, sg); + else + load_accumulator_layoutT(res, src, stride, sg); +} + +template >>> +void load_multiplicand_hip(joint_matrix_hip &res, + multi_ptr src, size_t stride, + Group &sg) { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Unsupported matrix type!"); + + const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] + + sg.get_local_linear_id(); + + if constexpr (std::is_same_v) { + if constexpr (Layout == matrix_layout::row_major) { + res.data[0] = src[idx]; + } else if constexpr (Layout == matrix_layout::col_major) { + res.data[0] = src[(idx % M) * 4 + idx / M]; + } + } else { + constexpr int Dim = (M == 16) ? 
16 : 32; + + const auto thread_x = idx % Dim; + const auto thread_y = idx / Dim; + + if constexpr (Layout == matrix_layout::col_major) { + for (int i = 0; i < 4; ++i) { + const int c_idx = thread_x * stride + i + thread_y * 4; + res.data[i] = src[c_idx]; + } + } else if constexpr (Layout == matrix_layout::row_major) { + for (int i = 0; i < 4; ++i) { + const int r_idx = thread_x + i * stride + thread_y * stride * 4; + res.data[i] = src[r_idx]; + } + } + } +} + +template +void store_layoutT(joint_matrix_hip &src, + multi_ptr dst, size_t stride, + Group &sg) { + const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] + + sg.get_local_linear_id(); + + if constexpr (std::is_same_v) { + const auto thread_x = idx % N; + const auto thread_y = idx / N; + + if constexpr (Layout == matrix_layout::row_major) { + for (int i = 0; i < 4; ++i) { + const int d_idx = thread_x + i * 4 * stride + thread_y * stride; + dst[d_idx] = src.data[i]; + } + } else { + for (int i = 0; i < 4; ++i) { + const int d_idx = i * 4 + thread_x * stride + thread_y; + dst[d_idx] = src.data[i]; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (M == 16 && N == 16) { + const auto thread_x = idx % N; + const auto thread_y = idx / N; + + if constexpr (Layout == matrix_layout::row_major) { + for (int i = 0; i < 4; ++i) { + const int d_idx = thread_x + i * stride + thread_y * 4 * stride; + dst[d_idx] = src.data[i]; + } + } else { + for (int i = 0; i < 4; ++i) { + const int d_idx = i + thread_x * stride + thread_y * 4; + dst[d_idx] = src.data[i]; + } + } + } else if constexpr (M == 32 && N == 32) { + const auto thread_x = idx % N; + const auto thread_y = idx / N; + + if constexpr (Layout == matrix_layout::row_major) { + for (int j = 0; j < 4; ++j) { + for (int i = 0; i < 4; ++i) { + const int d_idx = + thread_x + i * stride + thread_y * 4 * stride + j * 8 * stride; + dst[d_idx] = src.data[i + 4 * j]; + } + } + } else { + for (int j = 0; j < 4; ++j) { + for (int i = 0; i < 4; ++i) { + const int d_idx = i + thread_x * stride + thread_y * 4 + j * 8; + dst[d_idx] = src.data[i + 4 * j]; + } + } + } + } + } +} + +template +void joint_matrix_store_hip(joint_matrix_hip &src, + multi_ptr dst, size_t stride, + matrix_layout layout, Group &sg) { + if (matrix_layout::row_major == layout) { + store_layoutT(src, dst, stride, sg); + } else { + store_layoutT(src, dst, stride, sg); + } +} + +template = true> +void joint_matrix_mad_hip(joint_matrix_hip &D, + joint_matrix_hip &A, + joint_matrix_hip &B, + joint_matrix_hip &C) { + if constexpr (std::is_same_v) { + if constexpr (M == 16 && N == 16) { + D.data = __builtin_amdgcn_mfma_f32_16x16x16f16(A.data, B.data, C.data, 0, + 0, 0); + } else if constexpr (M == 32 && N == 32) { + D.data = + __builtin_amdgcn_mfma_f32_32x32x8f16(A.data, B.data, C.data, 0, 0, 0); + } + } else if constexpr (std::is_same_v) { + if constexpr (M == 16 && N == 16) { + D.data = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A.data, B.data, C.data, + 0, 0, 0); + } else if constexpr (M == 32 && N == 32) { + D.data = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(A.data, B.data, C.data, + 0, 0, 0); + } + } else if constexpr (std::is_same_v) { + if constexpr (M == 16 && N == 16) { + D.data = __builtin_amdgcn_mfma_f64_16x16x4f64(A.data[0], B.data[0], + C.data, 0, 0, 0); + } + } else if constexpr (std::is_same_v) { + if constexpr (M == 16 && N == 16) { + D.data = __builtin_amdgcn_mfma_i32_16x16x16i8( + *reinterpret_cast(A.data), + *reinterpret_cast(B.data), C.data, 0, 0, 0); + } else if constexpr (M == 32 
&& N == 32) { + D.data = __builtin_amdgcn_mfma_i32_32x32x8i8( + *reinterpret_cast(A.data), + *reinterpret_cast(B.data), C.data, 0, 0, 0); + } + } else { + static_assert(false && "Invalid configuration!"); + } +} + +template +void joint_matrix_apply(joint_matrix_hip &jm, + F &&lambda) { + if constexpr (std::is_same_v && Use != matrix_use::accumulator) { + jm.data[0] = lambda(jm.data[0]); + } else if constexpr (Use != matrix_use::accumulator || + (Use == matrix_use::accumulator && NumRows == 16)) { + for (auto i = 0; i < 4; ++i) + jm.data[i] = lambda(jm.data[i]); + } else { + for (auto i = 0; i < 16; ++i) + jm.data[i] = lambda(jm.data[i]); + } +} + +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + +} // namespace detail +} // namespace oneapi +} // namespace ext +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 327e1e326f108..86132fceaa003 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -8,6 +8,7 @@ #pragma once +#include "matrix-hip.hpp" #include "matrix-intel.hpp" #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) @@ -43,6 +44,9 @@ struct joint_matrix { mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + sycl::ext::oneapi::detail::joint_matrix_hip + hip_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< T, Rows, Cols, spv_matrix_layout_traits::value, @@ -155,6 +159,9 @@ joint_matrix_apply(Group sg, joint_matrix &jm, for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) { lambda(jm.cuda_impl.wi_marray[i]); } +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + std::ignore = sg; + sycl::ext::oneapi::detail::joint_matrix_apply(jm.hip_impl, lambda); #else // NVPTX using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< @@ -185,6 +192,10 @@ joint_matrix_fill(Group, #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) res.cuda_impl.wi_marray = v; +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + std::ignore = sg; + sycl::ext::oneapi::detail::joint_matrix_apply(res.hip_impl, + [=](T) { return v; }); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< @@ -220,6 +231,9 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( #if defined(__NVPTX__) sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride, Layout); +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + sycl::ext::oneapi::detail::load_accumulator_hip(res.hip_impl, src, stride, + Layout, sg); #else using DecorT = typename sycl::detail::DecoratedType::type; DecorT *Ptr = sycl::detail::getDecorated(src); @@ -281,6 +295,10 @@ joint_matrix_load(Group, sycl::ext::oneapi::detail::load_multiplicand_cuda( res.cuda_impl, src, stride); +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + sycl::ext::oneapi::detail::load_multiplicand_hip( + res.hip_impl, src, stride, sg); #else using DecorT = typename sycl::detail::DecoratedType::type; DecorT *Ptr = sycl::detail::getDecorated(src); @@ -316,6 +334,10 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( sycl::ext::oneapi::detail::joint_matrix_store_cuda(src.cuda_impl, dst, stride, Layout); +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + sycl::ext::oneapi::detail::joint_matrix_store_hip(src.hip_impl, dst, + stride, Layout, sg); #else using DecorT = typename sycl::detail::DecoratedType::type; DecorT *Ptr = sycl::detail::getDecorated(dst); @@ 
-380,6 +402,19 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); } +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + if constexpr (std::is_same::value) { + joint_matrix + D; + sycl::ext::oneapi::detail::joint_matrix_mad_hip( + D.hip_impl, A.hip_impl, B.hip_impl, C.hip_impl); + return D; + } else { + assert(false && "Ta != Tb : In the HIP backend joint_matrix_mad " + "requires that joint_matrix data types Ta and Tb match"); + } #else if constexpr (std::is_same::value && std::is_same::value && diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp new file mode 100644 index 0000000000000..c7cdedd8d53e2 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp @@ -0,0 +1,14 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a + +#include "joint_matrix_hip_fill.hpp" + +int main() { + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp new file mode 100644 index 0000000000000..e52427520a11d --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp @@ -0,0 +1,102 @@ + +#include + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::bfloat16; + +template +void hip_matrix_mfma() { + InType A[M * K]; + InType B[K * N]; + OutType C[M * N]; + OutType D[M * N]; + OutType E[M * N]; + + for (auto i = 0; i < M * K; ++i) { + A[i] = 1; + } + + for (auto i = 0; i < K * N; ++i) { + B[i] = 2; + } + + for (auto i = 0; i < M * N; ++i) { + D[i] = 0; + C[i] = 3; + E[i] = 3; + } + + try { + auto defaultQueue = sycl::queue{}; + + auto bufA = sycl::buffer{A, sycl::range{M * K}}; + auto bufB = sycl::buffer{B, sycl::range{K * N}}; + auto bufC = sycl::buffer{C, sycl::range{M * N}}; + auto bufD = sycl::buffer{D, sycl::range{M * N}}; + + defaultQueue + .submit([&](sycl::handler &cgh) { + sycl::accessor accA{bufA, cgh, sycl::read_write}; + sycl::accessor accB{bufB, cgh, sycl::read_write}; + sycl::accessor accC{bufC, cgh, sycl::read_only}; + sycl::accessor accD{bufD, cgh, sycl::write_only}; + + cgh.parallel_for( + sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { + auto sg = idx.get_sub_group(); + joint_matrix + sub_c{}; + joint_matrix + sub_b{}; + joint_matrix + sub_a{}; + + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr(), K); + + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr(), N); + + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr(), N, + layout::row_major); + + joint_matrix_apply(sg, sub_a, [=](InType v) { return v * 2; }); + joint_matrix_apply(sg, sub_b, [=](InType v) { return v * 3; }); + joint_matrix_apply(sg, sub_c, [=](OutType v) { return v * 4; }); + + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr(), N, + layout::row_major); + }); + }) + .wait(); + + defaultQueue.throw_asynchronous(); + } catch (const sycl::exception &e) { + std::cout << "Exception caught: " << e.what() << std::endl; + } + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + E[m * N + n] *= 4; + for (int k = 0; k < K; k++) { + E[m * N + n] += A[m * K + k] * 2 * B[k 
* N + n] * 3; + } + } + } + + for (int i = 0; i < M * N; ++i) { + assert(D[i] == E[i] && "Unexpected difference"); + } +}; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp new file mode 100644 index 0000000000000..c7cdedd8d53e2 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp @@ -0,0 +1,14 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a + +#include "joint_matrix_hip_fill.hpp" + +int main() { + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp new file mode 100644 index 0000000000000..2692b570a7a38 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp @@ -0,0 +1,88 @@ + +#include + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::bfloat16; + +template +void hip_matrix_mfma() { + InType A[M * K]; + InType B[K * N]; + OutType C[M * N]; + OutType D[M * N]; + OutType E[M * N]; + + for (auto i = 0; i < M * K; ++i) { + A[i] = 1; + } + + for (auto i = 0; i < K * N; ++i) { + B[i] = 2; + } + + for (auto i = 0; i < M * N; ++i) { + D[i] = 0; + C[i] = 3; + E[i] = 3; + } + + try { + auto defaultQueue = sycl::queue{}; + + auto bufA = sycl::buffer{A, sycl::range{M * K}}; + auto bufB = sycl::buffer{B, sycl::range{K * N}}; + auto bufC = sycl::buffer{C, sycl::range{M * N}}; + auto bufD = sycl::buffer{D, sycl::range{M * N}}; + + defaultQueue + .submit([&](sycl::handler &cgh) { + sycl::accessor accA{bufA, cgh, sycl::read_only}; + sycl::accessor accB{bufB, cgh, sycl::read_only}; + sycl::accessor accC{bufC, cgh, sycl::read_only}; + sycl::accessor accD{bufD, cgh, sycl::write_only}; + + cgh.parallel_for( + sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { + auto sg = idx.get_sub_group(); + joint_matrix + sub_c{}; + joint_matrix + sub_b{}; + joint_matrix + sub_a{}; + + joint_matrix_fill(sg, sub_a, 1); + joint_matrix_fill(sg, sub_b, 2); + joint_matrix_fill(sg, sub_c, 3); + + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr(), N, + layout::row_major); + }); + }) + .wait(); + + defaultQueue.throw_asynchronous(); + } catch (const sycl::exception &e) { + std::cout << "Exception caught: " << e.what() << std::endl; + } + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + E[m * N + n] += A[m * K + k] * B[k * N + n]; + } + } + } + + for (int i = 0; i < M * N; ++i) { + assert(D[i] == E[i] && "Unexpected difference"); + } +}; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp new file mode 100644 index 0000000000000..dec6ee41604c7 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp @@ -0,0 +1,12 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a +// REQUIRES: aspect-fp16 + +#include "joint_matrix_hip_fill.hpp" + +int main() { + hip_matrix_mfma(); + hip_matrix_mfma(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp new file mode 100644 index 0000000000000..dec6ee41604c7 --- /dev/null +++ 
b/sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp @@ -0,0 +1,12 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a +// REQUIRES: aspect-fp16 + +#include "joint_matrix_hip_fill.hpp" + +int main() { + hip_matrix_mfma(); + hip_matrix_mfma(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp new file mode 100644 index 0000000000000..644a1c2dc62df --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp @@ -0,0 +1,15 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a +// REQUIRES: aspect-fp16 + +#include "joint_matrix_hip_mfma.hpp" + +int main() { + hip_matrix_mfma(); + hip_matrix_mfma(); + + hip_matrix_mfma(); + hip_matrix_mfma(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp new file mode 100644 index 0000000000000..7fff98e0b4995 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp @@ -0,0 +1,20 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a + +#include "joint_matrix_hip_mfma.hpp" + +int main() { + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp new file mode 100644 index 0000000000000..b58477c408bc0 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -0,0 +1,116 @@ + +#include + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::bfloat16; + +template struct input_limit { + static constexpr int value = M * N; +}; + +template <> struct input_limit { + static constexpr auto value = 128; +}; + +template <> struct input_limit { + static constexpr auto value = 128; +}; + +template +void hip_matrix_mfma() { + InType A[M * K]; + InType B[K * N]; + OutType C[M * N]; + OutType D[M * N]; + OutType E[M * N]; + + for (auto i = 0; i < M * K; ++i) { + A[i] = i % input_limit::value; + } + + for (auto i = 0; i < K * N; ++i) { + B[i] = i % input_limit::value; + } + + for (auto i = 0; i < M * N; ++i) { + D[i] = 0; + C[i] = i; + if (OutLayout == layout::row_major) + E[i] = i; + else + E[(i % N) * M + int(i / M)] = i; + } + + try { + auto defaultQueue = sycl::queue{}; + + auto bufA = sycl::buffer{A, sycl::range{M * K}}; + auto bufB = sycl::buffer{B, sycl::range{K * N}}; + auto bufC = sycl::buffer{C, sycl::range{M * N}}; + auto bufD = sycl::buffer{D, sycl::range{M * N}}; + + defaultQueue + .submit([&](sycl::handler &cgh) { + sycl::accessor accA{bufA, cgh, sycl::read_only}; + sycl::accessor accB{bufB, cgh, sycl::read_only}; + sycl::accessor accC{bufC, cgh, sycl::read_only}; + sycl::accessor accD{bufD, cgh, sycl::write_only}; + + cgh.parallel_for( + sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { + auto sg = idx.get_sub_group(); + joint_matrix + sub_c{}; + joint_matrix + sub_b{}; + joint_matrix + sub_a{}; + + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr(), K); + + joint_matrix_load( + sg, sub_b, + accB.template 
get_multi_ptr(), N); + + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr(), N, + layout::row_major); + + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr(), N, + OutLayout); + }); + }) + .wait(); + + defaultQueue.throw_asynchronous(); + } catch (const sycl::exception &e) { + std::cout << "Exception caught: " << e.what() << std::endl; + } + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + if (OutLayout == layout::row_major) + E[m * N + n] += A[m * K + k] * B[k * N + n]; + else + E[n * M + m] += A[m * K + k] * B[k * N + n]; + } + } + } + + for (int i = 0; i < M * N; ++i) { + assert(abs(D[i] - E[i]) <= D[i] / 100 && "Unexpected difference"); + } +}; diff --git a/sycl/test-e2e/lit.cfg.py b/sycl/test-e2e/lit.cfg.py index 62093772043b1..a104d0eeec9bf 100644 --- a/sycl/test-e2e/lit.cfg.py +++ b/sycl/test-e2e/lit.cfg.py @@ -281,6 +281,8 @@ if "ext_oneapi_hip:gpu" in config.sycl_devices and config.hip_platform == "AMD": config.available_features.add('hip_amd') arch_flag = '-Xsycl-target-backend=amdgcn-amd-amdhsa --offload-arch=' + config.amd_arch + if "gfx90a" in config.sycl_devices: + config.available_features.add("gpu-amd-gfx90a") elif "ext_oneapi_hip:gpu" in config.sycl_devices and config.hip_platform == "NVIDIA": config.available_features.add('hip_nvidia') arch_flag = "" diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp new file mode 100644 index 0000000000000..9ff29d317b3f8 --- /dev/null +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp @@ -0,0 +1,90 @@ +// REQUIRES: hip + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::bfloat16; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 16, layout::row_major); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + 16); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 16); + // CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16bf16.1k(<4 x i16> %{{.*}}, <4 x i16> %{{.*}} <4 x float> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 16, layout::row_major); + }); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 32, layout::row_major); + joint_matrix_load( + 
sg, sub_a, accA.template get_multi_ptr(), + 32); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 8); + // CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8bf16.1k(<4 x i16> {{.*}}, <4 x i16> {{.*}}, <16 x float> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 32, layout::row_major); + }); + }); + + return 0; +}; diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp new file mode 100644 index 0000000000000..30cfdb1d8aa39 --- /dev/null +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp @@ -0,0 +1,62 @@ +// REQUIRES: hip + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 16, layout::row_major); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + 16); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 4); + // CHECK: tail call <4 x double> @llvm.amdgcn.mfma.f64.16x16x4f64(double %{{.*}}, double %{{.*}}, <4 x double> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 16, layout::row_major); + }); + }); + + return 0; +}; diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp new file mode 100644 index 0000000000000..5ee6aed4ae2f8 --- /dev/null +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp @@ -0,0 +1,89 @@ +// REQUIRES: hip + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 16, layout::row_major); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + 16); + 
joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 16); + // CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 16, layout::row_major); + }); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 32, layout::row_major); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + 32); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 8); + // CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> {{.*}}, <4 x half> {{.*}}, <16 x float> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 32, layout::row_major); + }); + }); + + return 0; +}; diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp new file mode 100644 index 0000000000000..cbadcd03328a6 --- /dev/null +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp @@ -0,0 +1,89 @@ +// REQUIRES: hip + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + sycl::accessor + accA(bufA, cgh); + sycl::accessor + accB(bufB, cgh); + sycl::accessor + accC(bufC, cgh); + sycl::accessor + accD(bufD, cgh); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 16, layout::row_major); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + 16); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 16); + // CHECK: tail call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 16, layout::row_major); + }); + + cgh.parallel_for( + nd_range<2>({1, 64}, {1, 64}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 64)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix sub_c{}; + joint_matrix + sub_a{}; + joint_matrix + sub_b{}; + + joint_matrix_load( + sg, sub_c, accC.template get_multi_ptr(), + 32, layout::row_major); + joint_matrix_load( + sg, sub_a, accA.template get_multi_ptr(), + 32); + joint_matrix_load( + sg, sub_b, accB.template get_multi_ptr(), + 8); + // CHECK: tail call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 {{.*}}, i32 {{.*}}, <16 x i32> zeroinitializer, i32 0, i32 0, i32 0) + sub_c = joint_matrix_mad(sg, sub_a, 
sub_b, sub_c); + joint_matrix_store( + sg, sub_c, accD.template get_multi_ptr(), + 32, layout::row_major); + }); + }); + + return 0; +}; From f6b2cb36435268979b9f08306858926d66cb9919 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 10 Oct 2023 15:15:09 +0100 Subject: [PATCH 25/50] * Update matrix core support into joint_matrix documentation. * Use fully qualified names. * Add diagnostic to tell the user that joint_maitrx is only supported for gfx90a. * Remove unnecessary `else` conditions. * Merge HIP matrix cpp files. --- .../sycl_ext_oneapi_matrix.asciidoc | 18 +- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 192 +++++++++++------- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 5 +- .../Matrix/joint_matrix_hip_apply.cpp | 14 -- .../Matrix/joint_matrix_hip_apply.hpp | 24 +-- .../test-e2e/Matrix/joint_matrix_hip_fill.cpp | 14 -- .../test-e2e/Matrix/joint_matrix_hip_fill.hpp | 8 +- ...p_mfma.cpp => joint_matrix_hip_gfx90a.cpp} | 15 +- .../Matrix/joint_matrix_hip_half_apply.cpp | 12 -- .../Matrix/joint_matrix_hip_half_fill.cpp | 12 -- .../Matrix/joint_matrix_hip_half_gfx90a.cpp | 22 ++ .../Matrix/joint_matrix_hip_half_mfma.cpp | 15 -- .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 22 +- 13 files changed, 179 insertions(+), 194 deletions(-) delete mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp delete mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp rename sycl/test-e2e/Matrix/{joint_matrix_hip_mfma.cpp => joint_matrix_hip_gfx90a.cpp} (60%) delete mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp delete mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp delete mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index c6ee448d0d5a4..2a77609eb9809 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -50,7 +50,7 @@ specification.* This extension is currently implemented in {dpcpp} only for devices that contain a matrix hardware, specifically Intel(R) Advanced Matrix Extensions (Intel(R) AMX), Intel(R) Xe Matrix Extensions (Intel(R) -XMX), Nvidia(R) Tensor Cores and AMD Matrix Cores. +XMX), Nvidia(R) Tensor Cores and AMD Matrix Cores(R). The `joint_matrix` type and the `joint_matrix_mad` function are optional kernel features as defined in section 5.7 of the core SYCL @@ -68,7 +68,7 @@ implementation throws a synchronous exception with the == Overview Joint matrix is a SYCL extension for matrix hardware programming. It unifies targets like Intel AMX in CPUs, Intel XMX in Intel GPUs, -Nvidia Tensor Cores and AMD Matrix Cores. This provides a portable and performant API for +Nvidia Tensor Cores and AMD Matrix Cores(R). This provides a portable and performant API for users who want to build their own neural networks applications, perform custom optimizations, or experiment with new operations in a timely and performing manner. @@ -922,7 +922,8 @@ matrix. Also, the type of the C matrix must be the same as the type of the D matrix. IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target -arch backend flag, `-fsycl-targets=nvidia_gpu_sm_xx`, must +arch backend flag, `-fsycl-targets=nvidia_gpu_sm_xx` +(or equivalents, e.g. 
`-Xsycl-target-backend --cuda-gpu-arch=sm_xx`), must be used, where `sm_xx` must be a Compute Capability that is equal to or greater than the appropriate Minimum Compute Capability. When an executable has been compiled for `sm_xx`, if the executable is run on @@ -965,15 +966,14 @@ multiple of 4 when `T` is `float`; where `T` is the type of the no restrictions to `stride`. ==== AMD Matrix Cores Supported Combinations -The complete set of matrix data types and dimenstions that are supported by +The complete set of matrix data types and dimensions that are supported by the `ext_oneapi_hip` backend are represented in the following table. In this architecture's implementation, A and B matrices must have the same type. Similarly, C and D matrices must share the same type. -IMPORTANT: Currently, only one block AMD Matrix Core instructions in -GFX90A (MI200, MI210, MI250 and MI250X GPUs) architecture are supported. -When compiling for the `ext_oneapi_hip` backend the target arch backend flag, - `-fsycl-targets=amd_gpu_gfx90a`, must +IMPORTANT: The supported instructions may be run on GFX90A (MI200, MI210, MI250 and MI250X GPUs) +architecture. When compiling for the `ext_oneapi_hip` backend the +target arch backend flag, `-fsycl-targets=amd_gpu_gfx90a`, must be used. An attempt to run the compiled code on an unsupported architecture will throw an error. @@ -983,7 +983,7 @@ be used. An attempt to run the compiled code on an unsupported architecture will .2+| `matrix_type::fp16` .2+| `matrix_type::fp32` |32 |32 |8 |16 |16 |16 -.2+| `matrix_type::int8` .2+| `matrix_type::int32` +.2+| `matrix_type::sint8` .2+| `matrix_type::sint32` |32 |32 |8 |16 |16 |16 .2+|`matrix_type::bf16` .2+|`matrix_type::fp32` diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 2a4bd38385417..1792c04c5b8c2 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -23,13 +23,13 @@ namespace experimental { namespace matrix {} // namespace matrix } // namespace experimental -using matrix_layout = sycl::ext::oneapi::experimental::matrix::layout; -using matrix_use = sycl::ext::oneapi::experimental::matrix::use; - namespace detail { -template +template struct joint_matrix_hip; #if defined(__SYCL_DEVICE_ONLY__) && defined(__HIP_PLATFORM_AMD_MFMA__) @@ -53,11 +53,14 @@ template <> struct to_hip_type { #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR(TYPE, USE, M, N, SIZE) \ - template \ + template \ struct joint_matrix_hip< \ - TYPE, matrix_use::USE, M, N, Layout, \ - typename std::enable_if_t> { \ + TYPE, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout, \ + typename std::enable_if_t< \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::row_major || \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ using ext_array_t = __attribute__(( \ __vector_size__(SIZE * sizeof(typename to_hip_type::type)))) \ typename to_hip_type::type; \ @@ -80,11 +83,14 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 16, 1) #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR #define __SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(USE, M, N, SIZE) \ - template \ + template \ struct joint_matrix_hip< \ - int8_t, matrix_use::USE, M, N, Layout, \ - typename std::enable_if_t> { \ + int8_t, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout, \ + typename std::enable_if_t< \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::row_major 
|| \ + Layout == \ + sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ int8_t data[SIZE]; \ }; @@ -97,8 +103,9 @@ __SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(b, 16, 16, 4) #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N) \ template <> \ - struct joint_matrix_hip { \ + struct joint_matrix_hip< \ + TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, \ + sycl::ext::oneapi::experimental::matrix::layout::dynamic> { \ using ext_array_t = \ __attribute__((__vector_size__((M * N) / 64 * sizeof(TYPE)))) TYPE; \ ext_array_t data = {0}; \ @@ -112,13 +119,14 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(int32_t, 16, 16) #undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC -template -void load_accumulator_layoutT(joint_matrix_hip &res, - multi_ptr src, - size_t stride, Group &sg) { +template +void load_accumulator_layoutT( + joint_matrix_hip< + S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, + multi_ptr src, size_t stride, Group &sg) { const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] + sg.get_local_linear_id(); @@ -126,7 +134,8 @@ void load_accumulator_layoutT(joint_matrix_hip>>> -void load_accumulator_hip(joint_matrix_hip &res, - multi_ptr src, size_t stride, - matrix_layout layout, Group &sg) { +void load_accumulator_hip( + joint_matrix_hip< + S, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, + multi_ptr src, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported matrix type!"); - if (layout == matrix_layout::row_major) - load_accumulator_layoutT(res, src, stride, sg); + if (layout == sycl::ext::oneapi::experimental::matrix::layout::row_major) + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, + stride, sg); else - load_accumulator_layoutT(res, src, stride, sg); + load_accumulator_layoutT< + sycl::ext::oneapi::experimental::matrix::layout::col_major>(res, src, + stride, sg); } -template >>> +template < + typename Group, typename S, typename T, size_t M, size_t N, + sycl::ext::oneapi::experimental::matrix::use Use, + sycl::ext::oneapi::experimental::matrix::layout Layout, + access::address_space Space, access::decorated IsDecorated, + typename = typename std::enable_if_t< + (Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major || + Layout == + sycl::ext::oneapi::experimental::matrix::layout::col_major) && + std::is_same_v>>> void load_multiplicand_hip(joint_matrix_hip &res, multi_ptr src, size_t stride, Group &sg) { @@ -213,9 +233,10 @@ void load_multiplicand_hip(joint_matrix_hip &res, sg.get_local_linear_id(); if constexpr (std::is_same_v) { - if constexpr (Layout == matrix_layout::row_major) { + if constexpr (Layout == + sycl::ext::oneapi::experimental::matrix::layout::row_major) { res.data[0] = src[idx]; - } else if constexpr (Layout == matrix_layout::col_major) { + } else { res.data[0] = src[(idx % M) * 4 + idx / M]; } } else { @@ -224,12 +245,13 @@ void load_multiplicand_hip(joint_matrix_hip &res, const auto thread_x = idx % Dim; const auto thread_y = idx / Dim; - if constexpr (Layout == matrix_layout::col_major) { + if constexpr (Layout == + sycl::ext::oneapi::experimental::matrix::layout::col_major) { for (int i = 0; i < 4; ++i) { const int c_idx = thread_x * stride + i + thread_y * 4; res.data[i] = src[c_idx]; } - } 
else if constexpr (Layout == matrix_layout::row_major) { + } else { for (int i = 0; i < 4; ++i) { const int r_idx = thread_x + i * stride + thread_y * stride * 4; res.data[i] = src[r_idx]; @@ -238,12 +260,15 @@ void load_multiplicand_hip(joint_matrix_hip &res, } } -template -void store_layoutT(joint_matrix_hip &src, - multi_ptr dst, size_t stride, - Group &sg) { +template +void store_layoutT( + joint_matrix_hip< + T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, + multi_ptr dst, size_t stride, Group &sg) { const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] + sg.get_local_linear_id(); @@ -251,7 +276,8 @@ void store_layoutT(joint_matrix_hip -void joint_matrix_store_hip(joint_matrix_hip &src, - multi_ptr dst, size_t stride, - matrix_layout layout, Group &sg) { - if (matrix_layout::row_major == layout) { - store_layoutT(src, dst, stride, sg); +void joint_matrix_store_hip( + joint_matrix_hip< + T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, + multi_ptr dst, size_t stride, + sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg) { + if (sycl::ext::oneapi::experimental::matrix::layout::row_major == layout) { + store_layoutT( + src, dst, stride, sg); } else { - store_layoutT(src, dst, stride, sg); + store_layoutT( + src, dst, stride, sg); } } template = true> -void joint_matrix_mad_hip(joint_matrix_hip &D, - joint_matrix_hip &A, - joint_matrix_hip &B, - joint_matrix_hip &C) { + sycl::ext::oneapi::experimental::matrix::layout LayoutA, + sycl::ext::oneapi::experimental::matrix::layout LayoutB> +void joint_matrix_mad_hip( + joint_matrix_hip< + Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, + joint_matrix_hip &A, + joint_matrix_hip &B, + joint_matrix_hip< + Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { D.data = __builtin_amdgcn_mfma_f32_16x16x16f16(A.data, B.data, C.data, 0, @@ -364,14 +399,19 @@ void joint_matrix_mad_hip(joint_matrix_hip +template void joint_matrix_apply(joint_matrix_hip &jm, F &&lambda) { - if constexpr (std::is_same_v && Use != matrix_use::accumulator) { + if constexpr (std::is_same_v && + Use != + sycl::ext::oneapi::experimental::matrix::use::accumulator) { jm.data[0] = lambda(jm.data[0]); - } else if constexpr (Use != matrix_use::accumulator || - (Use == matrix_use::accumulator && NumRows == 16)) { + } else if constexpr ( + Use != sycl::ext::oneapi::experimental::matrix::use::accumulator || + (Use == sycl::ext::oneapi::experimental::matrix::use::accumulator && + NumRows == 16)) { for (auto i = 0; i < 4; ++i) jm.data[i] = lambda(jm.data[i]); } else { diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 86132fceaa003..6dce725b4ca9d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -52,9 +52,8 @@ struct joint_matrix { T, Rows, Cols, spv_matrix_layout_traits::value, spv_scope_traits::value, spv_matrix_use_traits::value> *spvm; #else - static_assert( - false, - "The joint_matrix API is only supported by the Intel and CUDA backends"); + static_assert(false, "The joint_matrix API is only 
supported by the Intel, " + "CUDA and HIP (GFX90A) backends"); #endif // defined(__NVPTX__) #endif // defined(__SYCL_DEVICE_ONLY__) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp deleted file mode 100644 index c7cdedd8d53e2..0000000000000 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out -// RUN: %{run} %t.out - -// REQUIRES: gpu-amd-gfx90a - -#include "joint_matrix_hip_fill.hpp" - -int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); -} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp index e52427520a11d..7501da9fde3e3 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp @@ -9,7 +9,7 @@ using namespace sycl::ext::oneapi::experimental::matrix; using sycl::ext::oneapi::bfloat16; template -void hip_matrix_mfma() { +void hip_matrix_apply() { InType A[M * K]; InType B[K * N]; OutType C[M * N]; @@ -55,18 +55,10 @@ void hip_matrix_mfma() { joint_matrix sub_a{}; - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr(), K); - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr(), N); - - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_load(sg, sub_a, accA.template get_multi_ptr(), K); + joint_matrix_load(sg, sub_b, accB.template get_multi_ptr(), N); + joint_matrix_load(sg, sub_c, accC.template get_multi_ptr(), N, + layout::row_major); joint_matrix_apply(sg, sub_a, [=](InType v) { return v * 2; }); joint_matrix_apply(sg, sub_b, [=](InType v) { return v * 3; }); @@ -74,10 +66,8 @@ void hip_matrix_mfma() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store( - sg, sub_c, - accD.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_store(sg, sub_c, accD.template get_multi_ptr(), N, + layout::row_major); }); }) .wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp deleted file mode 100644 index c7cdedd8d53e2..0000000000000 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out -// RUN: %{run} %t.out - -// REQUIRES: gpu-amd-gfx90a - -#include "joint_matrix_hip_fill.hpp" - -int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); -} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp index 2692b570a7a38..91f4dd1d1435d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp @@ -9,7 +9,7 @@ using namespace sycl::ext::oneapi::experimental::matrix; using sycl::ext::oneapi::bfloat16; template -void hip_matrix_mfma() { +void hip_matrix_fill() { InType A[M * K]; InType B[K * N]; OutType C[M * N]; @@ -61,10 +61,8 @@ void hip_matrix_mfma() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store( - sg, sub_c, - accD.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_store(sg, sub_c, accD.template get_multi_ptr(), N, + layout::row_major); }); }) .wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp 
b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp similarity index 60% rename from sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp rename to sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp index 7fff98e0b4995..ab3b7b7eeb3d9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp @@ -3,6 +3,8 @@ // REQUIRES: gpu-amd-gfx90a +#include "joint_matrix_hip_apply.hpp" +#include "joint_matrix_hip_fill.hpp" #include "joint_matrix_hip_mfma.hpp" int main() { @@ -11,10 +13,21 @@ int main() { hip_matrix_mfma(); hip_matrix_mfma(); hip_matrix_mfma(); - hip_matrix_mfma(); hip_matrix_mfma(); hip_matrix_mfma(); hip_matrix_mfma(); hip_matrix_mfma(); + + hip_matrix_fill(); + hip_matrix_fill(); + hip_matrix_fill(); + hip_matrix_fill(); + hip_matrix_fill(); + + hip_matrix_apply(); + hip_matrix_apply(); + hip_matrix_apply(); + hip_matrix_apply(); + hip_matrix_apply(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp deleted file mode 100644 index dec6ee41604c7..0000000000000 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_apply.cpp +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out -// RUN: %{run} %t.out - -// REQUIRES: gpu-amd-gfx90a -// REQUIRES: aspect-fp16 - -#include "joint_matrix_hip_fill.hpp" - -int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); -} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp deleted file mode 100644 index dec6ee41604c7..0000000000000 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_fill.cpp +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out -// RUN: %{run} %t.out - -// REQUIRES: gpu-amd-gfx90a -// REQUIRES: aspect-fp16 - -#include "joint_matrix_hip_fill.hpp" - -int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); -} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp new file mode 100644 index 0000000000000..4a193be0535d0 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp @@ -0,0 +1,22 @@ +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{run} %t.out + +// REQUIRES: gpu-amd-gfx90a +// REQUIRES: aspect-fp16 + +#include "joint_matrix_hip_apply.hpp" +#include "joint_matrix_hip_fill.hpp" +#include "joint_matrix_hip_mfma.hpp" + +int main() { + hip_matrix_fill(); + hip_matrix_fill(); + hip_matrix_fill(); + hip_matrix_fill(); + + hip_matrix_fill(); + hip_matrix_fill(); + + hip_matrix_apply(); + hip_matrix_apply(); +} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp deleted file mode 100644 index 644a1c2dc62df..0000000000000 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_mfma.cpp +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out -// RUN: %{run} %t.out - -// REQUIRES: gpu-amd-gfx90a -// REQUIRES: aspect-fp16 - -#include "joint_matrix_hip_mfma.hpp" - -int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); - - hip_matrix_mfma(); - hip_matrix_mfma(); -} diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index b58477c408bc0..adc486ac6c10e 100644 --- 
a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -71,25 +71,15 @@ void hip_matrix_mfma() { joint_matrix sub_a{}; - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr(), K); - - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr(), N); - - joint_matrix_load( - sg, sub_c, - accC.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_load(sg, sub_a, accA.template get_multi_ptr(), K); + joint_matrix_load(sg, sub_b, accB.template get_multi_ptr(), N); + joint_matrix_load(sg, sub_c, accC.template get_multi_ptr(), N, + layout::row_major); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store( - sg, sub_c, - accD.template get_multi_ptr(), N, - OutLayout); + joint_matrix_store(sg, sub_c, accD.template get_multi_ptr(), N, + OutLayout); }); }) .wait(); From 919884ba994637e0dfccf91af2debd94bd19c7da Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 10 Oct 2023 22:45:43 +0100 Subject: [PATCH 26/50] - Fix adding `gpu-amd-gfx90a` as available feature. - Update comment with `GFX90A` info. --- sycl/include/sycl/detail/defines.hpp | 3 ++- sycl/test-e2e/lit.cfg.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/detail/defines.hpp b/sycl/include/sycl/detail/defines.hpp index ab90fd83331b7..de2de047528b1 100644 --- a/sycl/include/sycl/detail/defines.hpp +++ b/sycl/include/sycl/detail/defines.hpp @@ -39,7 +39,8 @@ #define __SYCL_TYPE(x) #endif -// joint matrix should only be included by default for SPIR or NVPTX backends +// joint matrix should only be included by default for SPIR, NVPTX or HIP(GFX90A +// only) backends #if defined __SPIR__ || defined __NVPTX__ || !defined __SYCL_DEVICE_ONLY__ || \ defined __gfx90a__ #ifndef SYCL_EXT_ONEAPI_MATRIX_VERSION diff --git a/sycl/test-e2e/lit.cfg.py b/sycl/test-e2e/lit.cfg.py index a104d0eeec9bf..cf5a1d038d4a8 100644 --- a/sycl/test-e2e/lit.cfg.py +++ b/sycl/test-e2e/lit.cfg.py @@ -253,6 +253,8 @@ devices = set() sp = subprocess.check_output(sycl_ls, text=True) for line in sp.splitlines(): + if "gfx90a" in line: + config.available_features.add("gpu-amd-gfx90a") (backend, device, _) = line[1:].split(':', 2) devices.add('{}:{}'.format(backend, device)) config.sycl_devices = list(devices) @@ -281,8 +283,6 @@ if "ext_oneapi_hip:gpu" in config.sycl_devices and config.hip_platform == "AMD": config.available_features.add('hip_amd') arch_flag = '-Xsycl-target-backend=amdgcn-amd-amdhsa --offload-arch=' + config.amd_arch - if "gfx90a" in config.sycl_devices: - config.available_features.add("gpu-amd-gfx90a") elif "ext_oneapi_hip:gpu" in config.sycl_devices and config.hip_platform == "NVIDIA": config.available_features.add('hip_nvidia') arch_flag = "" From 02bec23e3c2ba7df268f1bb9a57386a406fa5c91 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Wed, 11 Oct 2023 01:34:57 +0100 Subject: [PATCH 27/50] Fix a missing variable name change. 
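The branch being fixed is the accumulator path of the HIP joint_matrix_apply implementation. For reference, a minimal user-level sketch of that path (not part of this patch; it assumes a 16x16 float accumulator, the sub-group `sg` obtained from the nd_item as in the e2e tests, and `using namespace sycl::ext::oneapi::experimental::matrix;`):

    joint_matrix<sycl::sub_group, float, use::accumulator, 16, 16> sub_c;
    joint_matrix_fill(sg, sub_c, 1.0f);
    // every work-item rescales the accumulator elements it owns
    joint_matrix_apply(sg, sub_c, [=](float v) { return v * 2.0f; });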
--- sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 1792c04c5b8c2..605b5dc967950 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -411,7 +411,7 @@ void joint_matrix_apply(joint_matrix_hip &jm, } else if constexpr ( Use != sycl::ext::oneapi::experimental::matrix::use::accumulator || (Use == sycl::ext::oneapi::experimental::matrix::use::accumulator && - NumRows == 16)) { + M == 16)) { for (auto i = 0; i < 4; ++i) jm.data[i] = lambda(jm.data[i]); } else { From 3c460af198a5db637a51d356adeee1cfe3159a61 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Wed, 11 Oct 2023 11:46:21 +0100 Subject: [PATCH 28/50] Add decoration type for call to get_multi_ptr function. --- .../Matrix/joint_matrix_hip_apply.hpp | 20 +++++++++++++------ .../test-e2e/Matrix/joint_matrix_hip_fill.hpp | 6 ++++-- .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 20 +++++++++++++------ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp index 7501da9fde3e3..bb13f198b279c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp @@ -55,10 +55,16 @@ void hip_matrix_apply() { joint_matrix sub_a{}; - joint_matrix_load(sg, sub_a, accA.template get_multi_ptr(), K); - joint_matrix_load(sg, sub_b, accB.template get_multi_ptr(), N); - joint_matrix_load(sg, sub_c, accC.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr(), K); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr(), N); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr(), N, + layout::row_major); joint_matrix_apply(sg, sub_a, [=](InType v) { return v * 2; }); joint_matrix_apply(sg, sub_b, [=](InType v) { return v * 3; }); @@ -66,8 +72,10 @@ void hip_matrix_apply() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store(sg, sub_c, accD.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr(), N, + layout::row_major); }); }) .wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp index 91f4dd1d1435d..1802cffd0c3f6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp @@ -61,8 +61,10 @@ void hip_matrix_fill() { sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store(sg, sub_c, accD.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr(), N, + layout::row_major); }); }) .wait(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index adc486ac6c10e..df0f62fa5dcc3 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -71,15 +71,23 @@ void hip_matrix_mfma() { joint_matrix sub_a{}; - joint_matrix_load(sg, sub_a, accA.template get_multi_ptr(), K); - joint_matrix_load(sg, sub_b, accB.template get_multi_ptr(), N); - joint_matrix_load(sg, sub_c, accC.template get_multi_ptr(), N, - layout::row_major); + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr(), K); + joint_matrix_load( + sg, 
sub_b, + accB.template get_multi_ptr(), N); + joint_matrix_load( + sg, sub_c, + accC.template get_multi_ptr(), N, + layout::row_major); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - joint_matrix_store(sg, sub_c, accD.template get_multi_ptr(), N, - OutLayout); + joint_matrix_store( + sg, sub_c, + accD.template get_multi_ptr(), N, + OutLayout); }); }) .wait(); From 42e0c62ee83be46257f8f742ff9a97c5ef4e7a62 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Thu, 12 Oct 2023 10:35:51 +0100 Subject: [PATCH 29/50] Update use cases of `mad` to have variables holding result of `mad` as a parameter of the function. --- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 22 +++++++++---------- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 15 ++++++------- .../Matrix/joint_matrix_hip_apply.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_hip_fill.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 2 +- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 605b5dc967950..5713cb00bb7de 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -265,7 +265,7 @@ template void store_layoutT( - joint_matrix_hip< + const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, Group &sg) { @@ -333,7 +333,7 @@ void store_layoutT( template void joint_matrix_store_hip( - joint_matrix_hip< + const joint_matrix_hip< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, @@ -356,11 +356,11 @@ void joint_matrix_mad_hip( joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, - joint_matrix_hip &A, - joint_matrix_hip &B, - joint_matrix_hip< + const joint_matrix_hip &A, + const joint_matrix_hip &B, + const joint_matrix_hip< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (std::is_same_v) { @@ -387,12 +387,12 @@ void joint_matrix_mad_hip( } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { D.data = __builtin_amdgcn_mfma_i32_16x16x16i8( - *reinterpret_cast(A.data), - *reinterpret_cast(B.data), C.data, 0, 0, 0); + *reinterpret_cast(A.data), + *reinterpret_cast(B.data), C.data, 0, 0, 0); } else if constexpr (M == 32 && N == 32) { D.data = __builtin_amdgcn_mfma_i32_32x32x8i8( - *reinterpret_cast(A.data), - *reinterpret_cast(B.data), C.data, 0, 0, 0); + *reinterpret_cast(A.data), + *reinterpret_cast(B.data), C.data, 0, 0, 0); } } else { static_assert(false && "Invalid configuration!"); diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 6dce725b4ca9d..0d5410d32d2ac 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -192,7 +192,6 @@ joint_matrix_fill(Group, #if defined(__NVPTX__) res.cuda_impl.wi_marray = v; #elif defined(__HIP_PLATFORM_AMD_MFMA__) - std::ignore = sg; sycl::ext::oneapi::detail::joint_matrix_apply(res.hip_impl, [=](T) { return v; }); #else @@ -219,7 +218,7 @@ template < std::enable_if_t>::value, bool> = true> inline __SYCL_ALWAYS_INLINE 
void joint_matrix_load( - Group, + Group &sg, joint_matrix &res, multi_ptr src, size_t stride, @@ -228,6 +227,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( static_assert(Space != access::address_space::private_space, "Joint Matrix doesn't support load from private memory!"); #if defined(__NVPTX__) + std::ignore = sg; sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride, Layout); #elif defined(__HIP_PLATFORM_AMD_MFMA__) @@ -266,6 +266,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( } #endif // defined(__NVPTX__) #else + std::ignore = sg; std::ignore = res; std::ignore = src; std::ignore = stride; @@ -284,13 +285,14 @@ template < std::is_same, float>::value), bool> = true> inline __SYCL_ALWAYS_INLINE void -joint_matrix_load(Group, +joint_matrix_load(Group &sg, joint_matrix &res, multi_ptr src, size_t stride) { #if defined(__SYCL_DEVICE_ONLY__) static_assert(Space != access::address_space::private_space, "Joint Matrix doesn't support load from private memory!"); #if defined(__NVPTX__) + std::ignore = sg; sycl::ext::oneapi::detail::load_multiplicand_cuda( res.cuda_impl, src, stride); @@ -320,7 +322,7 @@ joint_matrix_load(Group, template inline __SYCL_ALWAYS_INLINE void joint_matrix_store( - Group, + Group &sg, const joint_matrix &src, @@ -330,6 +332,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( static_assert(Space != access::address_space::private_space, "Joint Matrix doesn't support store to private memory!"); #if defined(__NVPTX__) + std::ignore = sg; sycl::ext::oneapi::detail::joint_matrix_store_cuda(src.cuda_impl, dst, stride, Layout); @@ -403,13 +406,9 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( } #elif defined(__HIP_PLATFORM_AMD_MFMA__) if constexpr (std::is_same::value) { - joint_matrix - D; sycl::ext::oneapi::detail::joint_matrix_mad_hip( D.hip_impl, A.hip_impl, B.hip_impl, C.hip_impl); - return D; } else { assert(false && "Ta != Tb : In the HIP backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp index bb13f198b279c..ef3a63bd44396 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp @@ -70,7 +70,7 @@ void hip_matrix_apply() { joint_matrix_apply(sg, sub_b, [=](InType v) { return v * 3; }); joint_matrix_apply(sg, sub_c, [=](OutType v) { return v * 4; }); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp index 1802cffd0c3f6..2dc4be9d8d5f1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp @@ -59,7 +59,7 @@ void hip_matrix_fill() { joint_matrix_fill(sg, sub_b, 2); joint_matrix_fill(sg, sub_c, 3); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); joint_matrix_store( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index df0f62fa5dcc3..c15e62e5222b7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -82,7 +82,7 @@ void hip_matrix_mfma() { accC.template get_multi_ptr(), N, layout::row_major); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, 
sub_c); joint_matrix_store( sg, sub_c, From be7a90fd1ab0402c3640f014d9ab680a1006d45f Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Thu, 12 Oct 2023 13:31:02 +0100 Subject: [PATCH 30/50] Implement joint_matrix_copy for HIP(gfx90a) backend. --- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 23 ++++ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 4 + .../test-e2e/Matrix/joint_matrix_hip_copy.hpp | 118 ++++++++++++++++++ .../test-e2e/Matrix/joint_matrix_hip_fill.hpp | 1 - .../Matrix/joint_matrix_hip_gfx90a.cpp | 12 ++ .../Matrix/joint_matrix_hip_half_gfx90a.cpp | 14 ++- 6 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 5713cb00bb7de..1980a76d5fa38 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -420,6 +420,29 @@ void joint_matrix_apply(joint_matrix_hip &jm, } } +template +void joint_matrix_copy(joint_matrix_hip &src, + joint_matrix_hip &dst) { + if constexpr (std::is_same_v && + Use1 != + sycl::ext::oneapi::experimental::matrix::use::accumulator) { + dst.data[0] = src.data[0]; + } else if constexpr ( + Use1 != sycl::ext::oneapi::experimental::matrix::use::accumulator || + (Use1 == sycl::ext::oneapi::experimental::matrix::use::accumulator && + M == 16)) { + for (auto i = 0; i < 4; ++i) + dst.data[i] = src.data[i]; + } else { + for (auto i = 0; i < 16; ++i) + src.data[i] = src.data[i]; + } +} + #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } // namespace detail diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 0d5410d32d2ac..f1d9c4359bc65 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -341,6 +341,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( Space>(src.hip_impl, dst, stride, Layout, sg); #else + std::ignore = sg; using DecorT = typename sycl::detail::DecoratedType::type; DecorT *Ptr = sycl::detail::getDecorated(dst); switch (Layout) { @@ -448,6 +449,9 @@ void joint_matrix_copy( for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; } +#elif defined(__HIP_PLATFORM_AMD_MFMA__) + std::ignore = sg; + sycl::ext::oneapi::detail::joint_matrix_apply(src.hip_impl, src.hip_impl); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp new file mode 100644 index 0000000000000..6950f868149e3 --- /dev/null +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -0,0 +1,118 @@ +#include + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::bfloat16; + +template struct input_limit { + static constexpr int value = M * N; +}; + +template <> struct input_limit { + static constexpr auto value = 128; +}; + +template <> struct input_limit { + static constexpr auto value = 128; +}; + +template +void hip_matrix_copy() { + InType A[M * K]; + InType B[K * N]; + OutType C[M * N]; + OutType D[M * N]; + OutType E[M * N]; + + for (auto i = 0; i < M * K; ++i) { + A[i] = i % input_limit::value; + } + + for (auto i = 0; i < K * N; ++i) { + B[i] = i % input_limit::value; + } + + for (auto i = 0; i 
< M * N; ++i) { + D[i] = 0; + C[i] = i; + if (OutLayout == layout::row_major) + E[i] = i; + else + E[(i % N) * M + int(i / M)] = i; + } + + try { + auto defaultQueue = sycl::queue{}; + + auto bufA = sycl::buffer{A, sycl::range{M * K}}; + auto bufB = sycl::buffer{B, sycl::range{K * N}}; + auto bufC = sycl::buffer{C, sycl::range{M * N}}; + auto bufD = sycl::buffer{D, sycl::range{M * N}}; + + defaultQueue + .submit([&](sycl::handler &cgh) { + sycl::accessor accA{bufA, cgh, sycl::read_only}; + sycl::accessor accB{bufB, cgh, sycl::read_only}; + sycl::accessor accC{bufC, cgh, sycl::read_only}; + sycl::accessor accD{bufD, cgh, sycl::write_only}; + + cgh.parallel_for( + sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { + auto sg = idx.get_sub_group(); + joint_matrix sub_c, + sub_c_copy{}; + joint_matrix + sub_b{}, sub_b_copy{}; + joint_matrix + sub_a, sub_a_copy{}; + + joint_matrix_copy(sg, sub_c, sub_c_copy); + joint_matrix_copy(sg, sub_a, sub_a_copy); + joint_matrix_copy(sg, sub_b, sub_b_copy); + + joint_matrix_load( + sg, sub_a_copy, + accA.template get_multi_ptr(), K); + joint_matrix_load( + sg, sub_b_copy, + accB.template get_multi_ptr(), N); + joint_matrix_load( + sg, sub_c_copy, + accC.template get_multi_ptr(), N, + layout::row_major); + + joint_matrix_mad(sg, sub_c_copy, sub_a_copy, sub_b_copy, + sub_c_copy); + + joint_matrix_store( + sg, sub_c_copy, + accD.template get_multi_ptr(), N, + OutLayout); + }); + }) + .wait(); + + defaultQueue.throw_asynchronous(); + } catch (const sycl::exception &e) { + std::cout << "Exception caught: " << e.what() << std::endl; + } + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + if (OutLayout == layout::row_major) + E[m * N + n] += A[m * K + k] * B[k * N + n]; + else + E[n * M + m] += A[m * K + k] * B[k * N + n]; + } + } + } + + for (int i = 0; i < M * N; ++i) { + assert(abs(D[i] - E[i]) <= D[i] / 100 && "Unexpected difference"); + } +}; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp index 2dc4be9d8d5f1..c58e5bce5c5b5 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp @@ -1,4 +1,3 @@ - #include #include diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp index ab3b7b7eeb3d9..e437cd0e7b28c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp @@ -4,6 +4,7 @@ // REQUIRES: gpu-amd-gfx90a #include "joint_matrix_hip_apply.hpp" +#include "joint_matrix_hip_copy.hpp" #include "joint_matrix_hip_fill.hpp" #include "joint_matrix_hip_mfma.hpp" @@ -19,6 +20,17 @@ int main() { hip_matrix_mfma(); hip_matrix_mfma(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_fill(); hip_matrix_fill(); hip_matrix_fill(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp index 4a193be0535d0..47fee217ce5d8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp @@ -5,14 +5,20 @@ // REQUIRES: aspect-fp16 #include "joint_matrix_hip_apply.hpp" +#include "joint_matrix_hip_copy.hpp" #include "joint_matrix_hip_fill.hpp" #include "joint_matrix_hip_mfma.hpp" int main() { - 
hip_matrix_fill(); - hip_matrix_fill(); - hip_matrix_fill(); - hip_matrix_fill(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); + hip_matrix_copy(); hip_matrix_fill(); hip_matrix_fill(); From 3d1237a7bd90f5473b182408179889dcf60dedcb Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Thu, 12 Oct 2023 13:58:07 +0100 Subject: [PATCH 31/50] std::ignore unused input parameters. --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index f1d9c4359bc65..f83d7a8c7ea11 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -234,6 +234,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( sycl::ext::oneapi::detail::load_accumulator_hip(res.hip_impl, src, stride, Layout, sg); #else + std::ignore = sg; using DecorT = typename sycl::detail::DecoratedType::type; DecorT *Ptr = sycl::detail::getDecorated(src); switch (Layout) { @@ -301,6 +302,7 @@ joint_matrix_load(Group &sg, NumCols, Use, Layout, Space>( res.hip_impl, src, stride, sg); #else + std::ignore = sg; using DecorT = typename sycl::detail::DecoratedType::type; DecorT *Ptr = sycl::detail::getDecorated(src); res.spvm = @@ -311,6 +313,7 @@ joint_matrix_load(Group &sg, spv_scope_traits::value); #endif // defined(__NVPTX__) #else + std::ignore = sg; std::ignore = res; std::ignore = src; std::ignore = stride; From ad7b8cdfda5a4358ba2f2f771ed7fce67f3a8a48 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Thu, 12 Oct 2023 14:22:08 +0100 Subject: [PATCH 32/50] std::ignore unused `sg` parameters in`joint_matrix_store`. --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index f83d7a8c7ea11..2572384a71b1e 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -377,6 +377,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( } #endif // defined(__NVPTX__) #else + std::ignore = sg; std::ignore = src; std::ignore = dst; std::ignore = stride; From 4231a1cb2ee4850d8b4bf7926ade4151f7d09b7e Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Thu, 12 Oct 2023 17:43:32 +0100 Subject: [PATCH 33/50] Fix AMD `joint_matrix_copy` function. 
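A minimal sketch of the call being fixed, as exercised by joint_matrix_hip_copy.hpp (it assumes the sub-group `sg`, a 16x16 float accumulator and the using-directives from the tests):

    joint_matrix<sycl::sub_group, float, use::accumulator, 16, 16> sub_c, sub_c_copy;
    joint_matrix_fill(sg, sub_c, 3.0f);
    // duplicates the per-work-item fragments of sub_c into sub_c_copy
    joint_matrix_copy(sg, sub_c, sub_c_copy);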
--- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 2572384a71b1e..222e58b56b75d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -455,7 +455,7 @@ void joint_matrix_copy( } #elif defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; - sycl::ext::oneapi::detail::joint_matrix_apply(src.hip_impl, src.hip_impl); + sycl::ext::oneapi::detail::joint_matrix_copy(src.hip_impl, src.hip_impl); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp index 6950f868149e3..1ebdf2fac0b8f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -7,6 +7,8 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; using sycl::ext::oneapi::bfloat16; +namespace details { + template struct input_limit { static constexpr int value = M * N; }; @@ -19,6 +21,8 @@ template <> struct input_limit { static constexpr auto value = 128; }; +} // namespace details + template void hip_matrix_copy() { @@ -29,11 +33,11 @@ void hip_matrix_copy() { OutType E[M * N]; for (auto i = 0; i < M * K; ++i) { - A[i] = i % input_limit::value; + A[i] = i % details::input_limit::value; } for (auto i = 0; i < K * N; ++i) { - B[i] = i % input_limit::value; + B[i] = i % details::input_limit::value; } for (auto i = 0; i < M * N; ++i) { From 1a595806dc21e70e7ef0d837ee34cc708256b1ae Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 13 Oct 2023 09:45:57 +0100 Subject: [PATCH 34/50] - Fix joint_matrix_hip_copy. - Initialize the matrix with random values. 
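
The tests now draw their inputs from a seeded std::mt19937 instead of a deterministic ramp (A[i] = i % limit, C[i] = i), so the accumulated products are no longer uniformly growing values. A minimal sketch of the pattern is below; fill_random is a hypothetical helper written only for illustration, the tests inline the same few lines directly.

    // Sketch of the seeded random initialization used by the tests: a fixed seed
    // keeps failures reproducible, and a symmetric range exercises both signs.
    #include <random>

    template <typename ElemT, int Size> void fill_random(ElemT (&buf)[Size]) {
      std::mt19937 gen(0); // fixed seed so every run sees the same inputs
      std::uniform_real_distribution<float> dist(-100, 100);
      for (int i = 0; i < Size; ++i)
        buf[i] = static_cast<ElemT>(dist(gen)); // narrow to bfloat16 / half / int8_t ...
    }
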
--- .../test-e2e/Matrix/joint_matrix_hip_copy.hpp | 22 +++++++++++-------- .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 14 +++++++----- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp index 1ebdf2fac0b8f..0af4844e70f77 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -2,6 +2,7 @@ #include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -32,21 +33,24 @@ void hip_matrix_copy() { OutType D[M * N]; OutType E[M * N]; + std::mt19937 gen(0); + std::uniform_real_distribution dist(-100, 100); + for (auto i = 0; i < M * K; ++i) { - A[i] = i % details::input_limit::value; + A[i] = static_cast(dist(gen)); } for (auto i = 0; i < K * N; ++i) { - B[i] = i % details::input_limit::value; + B[i] = static_cast(dist(gen)); } for (auto i = 0; i < M * N; ++i) { D[i] = 0; - C[i] = i; + C[i] = static_cast(dist(gen)); if (OutLayout == layout::row_major) - E[i] = i; + E[i] = C[i]; else - E[(i % N) * M + int(i / M)] = i; + E[(i % N) * M + int(i / M)] = C[i]; } try { @@ -74,10 +78,6 @@ void hip_matrix_copy() { joint_matrix sub_a, sub_a_copy{}; - joint_matrix_copy(sg, sub_c, sub_c_copy); - joint_matrix_copy(sg, sub_a, sub_a_copy); - joint_matrix_copy(sg, sub_b, sub_b_copy); - joint_matrix_load( sg, sub_a_copy, accA.template get_multi_ptr(), K); @@ -89,6 +89,10 @@ void hip_matrix_copy() { accC.template get_multi_ptr(), N, layout::row_major); + joint_matrix_copy(sg, sub_c, sub_c_copy); + joint_matrix_copy(sg, sub_a, sub_a_copy); + joint_matrix_copy(sg, sub_b, sub_b_copy); + joint_matrix_mad(sg, sub_c_copy, sub_a_copy, sub_b_copy, sub_c_copy); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index c15e62e5222b7..4d5f209444890 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -3,6 +3,7 @@ #include #include +#include using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; @@ -29,21 +30,24 @@ void hip_matrix_mfma() { OutType D[M * N]; OutType E[M * N]; + std::mt19937 gen(0); + std::uniform_real_distribution dist(-100, 100); + for (auto i = 0; i < M * K; ++i) { - A[i] = i % input_limit::value; + A[i] = static_cast(dist(gen)); } for (auto i = 0; i < K * N; ++i) { - B[i] = i % input_limit::value; + B[i] = static_cast(dist(gen)); } for (auto i = 0; i < M * N; ++i) { D[i] = 0; - C[i] = i; + C[i] = static_cast(dist(gen)); if (OutLayout == layout::row_major) - E[i] = i; + E[i] = C[i]; else - E[(i % N) * M + int(i / M)] = i; + E[(i % N) * M + int(i / M)] = C[i]; } try { From fc31965414606ead62a1d15db28174da02850f37 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 13 Oct 2023 10:05:16 +0100 Subject: [PATCH 35/50] Remove curly braces for initialization of joint_matrix. 
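
Every matrix in these tests is fully written by joint_matrix_load or joint_matrix_fill before it is read, so the `{}` value-initialization at the declaration adds nothing. The intended pattern is sketched below; this is a fragment rather than a complete test, and it assumes a sub_group `sg` and a read accessor `accA` are already in scope inside the kernel.

    // Declare matrices without braces, then give them values explicitly
    // (illustrative fragment only):
    using namespace sycl::ext::oneapi::experimental::matrix;

    joint_matrix<sycl::sub_group, float, use::accumulator, 16, 16> sub_c;
    joint_matrix<sycl::sub_group, sycl::half, use::a, 16, 16, layout::row_major> sub_a;

    joint_matrix_fill(sg, sub_c, 0.0f); // accumulator starts from an explicit zero
    joint_matrix_load(
        sg, sub_a,
        accA.template get_multi_ptr<sycl::access::decorated::no>(), 16);
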
--- sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp | 7 +++---- sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp | 6 +++--- sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 7 +++---- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp index ef3a63bd44396..062480b541249 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp @@ -48,12 +48,11 @@ void hip_matrix_apply() { cgh.parallel_for( sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { auto sg = idx.get_sub_group(); - joint_matrix - sub_c{}; + joint_matrix sub_c; joint_matrix - sub_b{}; + sub_b; joint_matrix - sub_a{}; + sub_a; joint_matrix_load( sg, sub_a, diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp index 0af4844e70f77..f5b058095941d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -72,11 +72,11 @@ void hip_matrix_copy() { sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { auto sg = idx.get_sub_group(); joint_matrix sub_c, - sub_c_copy{}; + sub_c_copy; joint_matrix - sub_b{}, sub_b_copy{}; + sub_b, sub_b_copy; joint_matrix - sub_a, sub_a_copy{}; + sub_a, sub_a_copy; joint_matrix_load( sg, sub_a_copy, diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index 4d5f209444890..a17eb784f17ed 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -68,12 +68,11 @@ void hip_matrix_mfma() { cgh.parallel_for( sycl::nd_range<2>{{4, 16}, {4, 16}}, [=](sycl::nd_item<2> idx) { auto sg = idx.get_sub_group(); - joint_matrix - sub_c{}; + joint_matrix sub_c; joint_matrix - sub_b{}; + sub_b; joint_matrix - sub_a{}; + sub_a; joint_matrix_load( sg, sub_a, From 8bba0fb1eab9669d99fcdfecda876ba6e40cefdd Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 13 Oct 2023 15:49:27 +0100 Subject: [PATCH 36/50] - Use sycl::marray as container for jont_matrix data. - Update joint_matrix_apply / test. - Remove un-necessary code from test. 
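
sycl::marray gives the HIP fragment storage the same container interface the CUDA path already relies on, instead of a raw __vector_size__ extension type. A small host-side illustration of the operations the fragment code depends on (indexing, size(), scalar arithmetic, whole-object assignment) is below; it is illustrative only.

    #include <sycl/sycl.hpp>
    #include <cassert>

    int main() {
      sycl::marray<float, 4> a{1.0f, 2.0f, 3.0f, 4.0f};
      a *= 2.0f;             // element-wise scalar arithmetic
      assert(a.size() == 4); // element count is part of the type
      assert(a[2] == 6.0f);  // plain indexing, no reinterpret_cast needed

      sycl::marray<float, 4> b(0.0f); // broadcast constructor
      b = a;                          // copies all elements in one assignment
      assert(b[0] == 2.0f);
      return 0;
    }
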
--- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 118 +++++++++--------- .../matrix/matrix-hip-bfloat16-float-test.cpp | 18 --- .../matrix/matrix-hip-double-double-test.cpp | 9 -- .../hip/matrix/matrix-hip-half-float-test.cpp | 18 --- .../hip/matrix/matrix-hip-int8-int32-test.cpp | 18 --- 5 files changed, 59 insertions(+), 122 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 1980a76d5fa38..fbaabb941a2ce 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -19,10 +19,6 @@ namespace sycl { inline namespace _V1 { namespace ext { namespace oneapi { -namespace experimental { -namespace matrix {} // namespace matrix -} // namespace experimental - namespace detail { template struct to_hip_type { using type = T; }; @@ -61,10 +65,7 @@ template <> struct to_hip_type { sycl::ext::oneapi::experimental::matrix::layout::row_major || \ Layout == \ sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ - using ext_array_t = __attribute__(( \ - __vector_size__(SIZE * sizeof(typename to_hip_type::type)))) \ - typename to_hip_type::type; \ - ext_array_t data = {0}; \ + sycl::marray data; \ }; __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 4) @@ -80,35 +81,19 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(half, b, 8, 32, 4) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, a, 16, 4, 1) __SYCL_JOINT_MATRIX_OVERLOAD_ARR(double, b, 4, 16, 1) -#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 32, 8, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 8, 32, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, a, 16, 16, 4) +__SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 4) -#define __SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(USE, M, N, SIZE) \ - template \ - struct joint_matrix_hip< \ - int8_t, sycl::ext::oneapi::experimental::matrix::use::USE, M, N, Layout, \ - typename std::enable_if_t< \ - Layout == \ - sycl::ext::oneapi::experimental::matrix::layout::row_major || \ - Layout == \ - sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ - int8_t data[SIZE]; \ - }; - -__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(a, 32, 8, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(b, 8, 32, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(a, 16, 16, 4) -__SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR(b, 16, 16, 4) - -#undef __SYCL_JOINT_MATRIX_OVERLOAD_INT8_ARR +#undef __SYCL_JOINT_MATRIX_OVERLOAD_ARR #define __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(TYPE, M, N) \ template <> \ struct joint_matrix_hip< \ TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, \ sycl::ext::oneapi::experimental::matrix::layout::dynamic> { \ - using ext_array_t = \ - __attribute__((__vector_size__((M * N) / 64 * sizeof(TYPE)))) TYPE; \ - ext_array_t data = {0}; \ + sycl::marray data; \ }; __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16) @@ -198,10 +183,6 @@ void load_accumulator_hip( sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res, multi_ptr src, size_t stride, sycl::ext::oneapi::experimental::matrix::layout layout, Group &sg) { - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v, - "Unsupported matrix type!"); - if (layout == sycl::ext::oneapi::experimental::matrix::layout::row_major) load_accumulator_layoutT< sycl::ext::oneapi::experimental::matrix::layout::row_major>(res, src, @@ -225,10 +206,6 @@ template < void load_multiplicand_hip(joint_matrix_hip &res, multi_ptr src, size_t stride, Group &sg) { - static_assert(std::is_same_v || 
std::is_same_v || - std::is_same_v || std::is_same_v, - "Unsupported matrix type!"); - const auto idx = sg.get_group_linear_id() * sg.get_local_range()[0] + sg.get_local_linear_id(); @@ -365,37 +342,60 @@ void joint_matrix_mad_hip( sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { - D.data = __builtin_amdgcn_mfma_f32_16x16x16f16(A.data, B.data, C.data, 0, - 0, 0); + auto result = __builtin_amdgcn_mfma_f32_16x16x16f16( + *reinterpret_cast(&A.data), + *reinterpret_cast(&B.data), + *reinterpret_cast(&C.data), 0, 0, 0); + for (int i = 0; i < 4; ++i) + D.data[i] = result[i]; } else if constexpr (M == 32 && N == 32) { - D.data = - __builtin_amdgcn_mfma_f32_32x32x8f16(A.data, B.data, C.data, 0, 0, 0); + auto result = __builtin_amdgcn_mfma_f32_32x32x8f16( + *reinterpret_cast(&A.data), + *reinterpret_cast(&B.data), + *reinterpret_cast(&C.data), 0, 0, 0); + for (int i = 0; i < 16; ++i) + D.data[i] = result[i]; } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { - D.data = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A.data, B.data, C.data, - 0, 0, 0); + auto result = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + *reinterpret_cast(&A.data), + *reinterpret_cast(&B.data), + *reinterpret_cast(&C.data), 0, 0, 0); + for (int i = 0; i < 4; ++i) + D.data[i] = result[i]; } else if constexpr (M == 32 && N == 32) { - D.data = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(A.data, B.data, C.data, - 0, 0, 0); + auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + *reinterpret_cast(&A.data), + *reinterpret_cast(&B.data), + *reinterpret_cast(&C.data), 0, 0, 0); + for (int i = 0; i < 16; ++i) + D.data[i] = result[i]; } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { - D.data = __builtin_amdgcn_mfma_f64_16x16x4f64(A.data[0], B.data[0], - C.data, 0, 0, 0); + auto result = __builtin_amdgcn_mfma_f64_16x16x4f64( + A.data[0], B.data[0], *reinterpret_cast(&C.data), 0, + 0, 0); + for (int i = 0; i < 4; ++i) + D.data[i] = result[i]; } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { - D.data = __builtin_amdgcn_mfma_i32_16x16x16i8( - *reinterpret_cast(A.data), - *reinterpret_cast(B.data), C.data, 0, 0, 0); + auto result = __builtin_amdgcn_mfma_i32_16x16x16i8( + *reinterpret_cast(&A.data), + *reinterpret_cast(&B.data), + *reinterpret_cast(&C.data), 0, 0, 0); + for (int i = 0; i < 4; ++i) + D.data[i] = result[i]; } else if constexpr (M == 32 && N == 32) { - D.data = __builtin_amdgcn_mfma_i32_32x32x8i8( - *reinterpret_cast(A.data), - *reinterpret_cast(B.data), C.data, 0, 0, 0); + auto result = __builtin_amdgcn_mfma_i32_32x32x8i8( + *reinterpret_cast(&A.data), + *reinterpret_cast(&B.data), + *reinterpret_cast(&C.data), 0, 0, 0); + for (int i = 0; i < 16; ++i) + D.data[i] = result[i]; } - } else { - static_assert(false && "Invalid configuration!"); } } @@ -407,16 +407,16 @@ void joint_matrix_apply(joint_matrix_hip &jm, if constexpr (std::is_same_v && Use != sycl::ext::oneapi::experimental::matrix::use::accumulator) { - jm.data[0] = lambda(jm.data[0]); + lambda(jm.data[0]); } else if constexpr ( Use != sycl::ext::oneapi::experimental::matrix::use::accumulator || (Use == sycl::ext::oneapi::experimental::matrix::use::accumulator && M == 16)) { for (auto i = 0; i < 4; ++i) - jm.data[i] = lambda(jm.data[i]); + lambda(jm.data[i]); } else { for (auto i = 0; i < 16; ++i) - jm.data[i] = lambda(jm.data[i]); + lambda(jm.data[i]); } } diff --git 
a/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp index 9ff29d317b3f8..f9f8e847606d0 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp @@ -42,15 +42,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 16, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 16); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 16); // CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16bf16.1k(<4 x i16> %{{.*}}, <4 x i16> %{{.*}} <4 x float> zeroinitializer, i32 0, i32 0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( @@ -69,15 +60,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 32, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 32); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 8); // CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8bf16.1k(<4 x i16> {{.*}}, <4 x i16> {{.*}}, <16 x float> zeroinitializer, i32 0, i32 0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp index 30cfdb1d8aa39..32fe4a23d4ea0 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp @@ -41,15 +41,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 16, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 16); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 4); // CHECK: tail call <4 x double> @llvm.amdgcn.mfma.f64.16x16x4f64(double %{{.*}}, double %{{.*}}, <4 x double> zeroinitializer, i32 0, i32 0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp index 5ee6aed4ae2f8..0db7e4ce45b16 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp @@ -41,15 +41,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 16, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 16); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 16); // CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> zeroinitializer, i32 0, i32 0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( @@ -68,15 +59,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 32, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 32); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 8); // CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> {{.*}}, <4 x half> {{.*}}, <16 x float> zeroinitializer, i32 0, i32 
0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp index cbadcd03328a6..64db6fe21a493 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp @@ -41,15 +41,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 16, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 16); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 16); // CHECK: tail call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> zeroinitializer, i32 0, i32 0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( @@ -68,15 +59,6 @@ int main() { joint_matrix sub_b{}; - joint_matrix_load( - sg, sub_c, accC.template get_multi_ptr(), - 32, layout::row_major); - joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr(), - 32); - joint_matrix_load( - sg, sub_b, accB.template get_multi_ptr(), - 8); // CHECK: tail call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 {{.*}}, i32 {{.*}}, <16 x i32> zeroinitializer, i32 0, i32 0, i32 0) sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); joint_matrix_store( From a152eda0abe02d1c52db2c8507e01129e7a7e33a Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 13 Oct 2023 16:58:00 +0100 Subject: [PATCH 37/50] Modify `joint_matrix_apply` test. --- sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp index 062480b541249..296128915a136 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_apply.hpp @@ -65,9 +65,9 @@ void hip_matrix_apply() { accC.template get_multi_ptr(), N, layout::row_major); - joint_matrix_apply(sg, sub_a, [=](InType v) { return v * 2; }); - joint_matrix_apply(sg, sub_b, [=](InType v) { return v * 3; }); - joint_matrix_apply(sg, sub_c, [=](OutType v) { return v * 4; }); + joint_matrix_apply(sg, sub_a, [=](InType &v) { v *= 2; }); + joint_matrix_apply(sg, sub_b, [=](InType &v) { v *= 3; }); + joint_matrix_apply(sg, sub_c, [=](OutType &v) { v *= 4; }); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); From b44a76c70def070594e9cbfbdeb74a6c36201831 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 13 Oct 2023 17:11:36 +0100 Subject: [PATCH 38/50] Update allow difference after using matrix random input. 
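
With inputs drawn from a symmetric random range, the reference value D[i] can be negative or close to zero, so the old bound abs(D[i] - E[i]) <= D[i] / 100 is no longer meaningful and is replaced with a fixed absolute bound. A mixed absolute/relative check would be another option; the sketch below is illustrative only and its thresholds are not taken from the test suite.

    // Hedged alternative to a fixed absolute bound: accept either a small
    // absolute error (covers references near zero) or a small relative error.
    #include <cmath>

    bool approx_equal(float computed, float reference,
                      float abs_tol = 1e-3f, float rel_tol = 1e-2f) {
      const float err = std::fabs(computed - reference);
      return err <= abs_tol || err <= rel_tol * std::fabs(reference);
    }
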
--- sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp index f5b058095941d..3436f441e28f6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -121,6 +121,6 @@ void hip_matrix_copy() { } for (int i = 0; i < M * N; ++i) { - assert(abs(D[i] - E[i]) <= D[i] / 100 && "Unexpected difference"); + assert(abs(D[i] - E[i]) < 100 && "Unexpected difference"); } }; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index a17eb784f17ed..aede9bd04a92c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -112,6 +112,6 @@ void hip_matrix_mfma() { } for (int i = 0; i < M * N; ++i) { - assert(abs(D[i] - E[i]) <= D[i] / 100 && "Unexpected difference"); + assert(abs(D[i] - E[i]) < 100 && "Unexpected difference"); } }; From 66131d759641d071ed899a3f299da4ae832e4ccf Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Fri, 13 Oct 2023 19:48:06 +0100 Subject: [PATCH 39/50] Fix `hip_matrix_copy`. --- sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp index 3436f441e28f6..276fa823d54a9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -79,13 +79,13 @@ void hip_matrix_copy() { sub_a, sub_a_copy; joint_matrix_load( - sg, sub_a_copy, + sg, sub_a, accA.template get_multi_ptr(), K); joint_matrix_load( - sg, sub_b_copy, + sg, sub_b, accB.template get_multi_ptr(), N); joint_matrix_load( - sg, sub_c_copy, + sg, sub_c, accC.template get_multi_ptr(), N, layout::row_major); From f719779cb7411b423ad16c63bae69464ff43b2e3 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Sun, 22 Oct 2023 20:54:46 +0100 Subject: [PATCH 40/50] - Improve hip mfma tests to support matrices of multiple of K size. - Fix 'load / double' to to support matrices of multiple of K size. --- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 2 +- .../test-e2e/Matrix/joint_matrix_hip_copy.hpp | 26 ++------ .../test-e2e/Matrix/joint_matrix_hip_fill.hpp | 8 +-- .../Matrix/joint_matrix_hip_gfx90a.cpp | 27 +++++--- .../Matrix/joint_matrix_hip_half_gfx90a.cpp | 15 +++-- .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 66 +++++++++---------- 6 files changed, 69 insertions(+), 75 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index fbaabb941a2ce..24a22a68d8c4e 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -214,7 +214,7 @@ void load_multiplicand_hip(joint_matrix_hip &res, sycl::ext::oneapi::experimental::matrix::layout::row_major) { res.data[0] = src[idx]; } else { - res.data[0] = src[(idx % M) * 4 + idx / M]; + res.data[0] = src[(idx % M) * stride + idx / M]; } } else { constexpr int Dim = (M == 16) ? 
16 : 32; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp index 276fa823d54a9..b1ed65bc8fc93 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_copy.hpp @@ -8,22 +8,6 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; using sycl::ext::oneapi::bfloat16; -namespace details { - -template struct input_limit { - static constexpr int value = M * N; -}; - -template <> struct input_limit { - static constexpr auto value = 128; -}; - -template <> struct input_limit { - static constexpr auto value = 128; -}; - -} // namespace details - template void hip_matrix_copy() { @@ -34,7 +18,7 @@ void hip_matrix_copy() { OutType E[M * N]; std::mt19937 gen(0); - std::uniform_real_distribution dist(-100, 100); + std::uniform_real_distribution dist(-10, 10); for (auto i = 0; i < M * K; ++i) { A[i] = static_cast(dist(gen)); @@ -109,9 +93,9 @@ void hip_matrix_copy() { std::cout << "Exception caught: " << e.what() << std::endl; } - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { + for (auto m = 0; m < M; m++) { + for (auto n = 0; n < N; n++) { + for (auto k = 0; k < K; k++) { if (OutLayout == layout::row_major) E[m * N + n] += A[m * K + k] * B[k * N + n]; else @@ -120,7 +104,7 @@ void hip_matrix_copy() { } } - for (int i = 0; i < M * N; ++i) { + for (auto i = 0; i < M * N; ++i) { assert(abs(D[i] - E[i]) < 100 && "Unexpected difference"); } }; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp index c58e5bce5c5b5..642562aed4de9 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_fill.hpp @@ -73,15 +73,15 @@ void hip_matrix_fill() { std::cout << "Exception caught: " << e.what() << std::endl; } - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { + for (auto m = 0; m < M; m++) { + for (auto n = 0; n < N; n++) { + for (auto k = 0; k < K; k++) { E[m * N + n] += A[m * K + k] * B[k * N + n]; } } } - for (int i = 0; i < M * N; ++i) { + for (auto i = 0; i < M * N; ++i) { assert(D[i] == E[i] && "Unexpected difference"); } }; diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp index e437cd0e7b28c..db9ac7d4d4da6 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp @@ -8,17 +8,24 @@ #include "joint_matrix_hip_fill.hpp" #include "joint_matrix_hip_mfma.hpp" +template void matrix_mfma() { + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); +} + int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); + matrix_mfma<1>(); + matrix_mfma<2>(); + matrix_mfma<3>(); + matrix_mfma<4>(); hip_matrix_copy(); hip_matrix_copy(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp index 47fee217ce5d8..729bd8a16c18d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp @@ -9,11 +9,18 @@ #include "joint_matrix_hip_fill.hpp" 
#include "joint_matrix_hip_mfma.hpp" +template void half_matrix_mfma() { + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); + hip_matrix_mfma(); +} + int main() { - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); - hip_matrix_mfma(); + half_matrix_mfma<1>(); + half_matrix_mfma<2>(); + half_matrix_mfma<3>(); + half_matrix_mfma<4>(); hip_matrix_copy(); hip_matrix_copy(); diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index aede9bd04a92c..48b1143977fe2 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -9,35 +9,23 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; using sycl::ext::oneapi::bfloat16; -template struct input_limit { - static constexpr int value = M * N; -}; - -template <> struct input_limit { - static constexpr auto value = 128; -}; - -template <> struct input_limit { - static constexpr auto value = 128; -}; - template + size_t KX, layout OutLayout> void hip_matrix_mfma() { - InType A[M * K]; - InType B[K * N]; + InType A[M * K * KX]; + InType B[K * N * KX]; OutType C[M * N]; OutType D[M * N]; OutType E[M * N]; std::mt19937 gen(0); - std::uniform_real_distribution dist(-100, 100); + std::uniform_real_distribution dist(-10, 10); - for (auto i = 0; i < M * K; ++i) { + for (auto i = 0; i < M * K * KX; ++i) { A[i] = static_cast(dist(gen)); } - for (auto i = 0; i < K * N; ++i) { + for (auto i = 0; i < K * N * KX; ++i) { B[i] = static_cast(dist(gen)); } @@ -53,8 +41,8 @@ void hip_matrix_mfma() { try { auto defaultQueue = sycl::queue{}; - auto bufA = sycl::buffer{A, sycl::range{M * K}}; - auto bufB = sycl::buffer{B, sycl::range{K * N}}; + auto bufA = sycl::buffer{A, sycl::range{M * K * KX}}; + auto bufB = sycl::buffer{B, sycl::range{K * N * KX}}; auto bufC = sycl::buffer{C, sycl::range{M * N}}; auto bufD = sycl::buffer{D, sycl::range{M * N}}; @@ -74,18 +62,24 @@ void hip_matrix_mfma() { joint_matrix sub_a; - joint_matrix_load( - sg, sub_a, - accA.template get_multi_ptr(), K); - joint_matrix_load( - sg, sub_b, - accB.template get_multi_ptr(), N); joint_matrix_load( sg, sub_c, accC.template get_multi_ptr(), N, layout::row_major); - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); + for (auto kx = 0; kx < KX; ++kx) { + joint_matrix_load( + sg, sub_a, + accA.template get_multi_ptr() + + kx * K, + K * KX); + joint_matrix_load( + sg, sub_b, + accB.template get_multi_ptr() + + kx * K * N, + N); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); + } joint_matrix_store( sg, sub_c, @@ -100,18 +94,20 @@ void hip_matrix_mfma() { std::cout << "Exception caught: " << e.what() << std::endl; } - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - if (OutLayout == layout::row_major) - E[m * N + n] += A[m * K + k] * B[k * N + n]; - else - E[n * M + m] += A[m * K + k] * B[k * N + n]; + for (auto kx = 0; kx < KX; kx++) { + for (auto m = 0; m < M; m++) { + for (auto n = 0; n < N; n++) { + for (auto k = 0; k < K; k++) { + if (OutLayout == layout::row_major) + E[m * N + n] += A[m * K + k + kx * K] * B[k * N + n + kx * K * N]; + else + E[n * M + m] += A[m * K + k + kx * K] * B[k * N + n + kx * K * N]; + } } } } - for (int i = 0; i < M * N; ++i) { + for (auto i = 0; i < M * N; ++i) { assert(abs(D[i] - E[i]) < 100 && "Unexpected difference"); } }; From 7aa4ce393d2734c7347c2d6bd4e79339b48013ba Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Mon, 23 Oct 2023 14:27:36 +0100 Subject: [PATCH 
41/50] Fix call to `copy` and `fill` for hip joint matrix. --- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 222e58b56b75d..e76963f2707fd 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -193,7 +193,7 @@ joint_matrix_fill(Group, res.cuda_impl.wi_marray = v; #elif defined(__HIP_PLATFORM_AMD_MFMA__) sycl::ext::oneapi::detail::joint_matrix_apply(res.hip_impl, - [=](T) { return v; }); + [=](T &value) { value = v; }); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< @@ -455,7 +455,7 @@ void joint_matrix_copy( } #elif defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; - sycl::ext::oneapi::detail::joint_matrix_copy(src.hip_impl, src.hip_impl); + sycl::ext::oneapi::detail::joint_matrix_copy(src.hip_impl, dst.hip_impl); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< From 3d1348406278ce5b9039a719b2340ac247840ed8 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Mon, 23 Oct 2023 16:41:56 +0100 Subject: [PATCH 42/50] Fix reference mma calculation. --- .../test-e2e/Matrix/joint_matrix_hip_mfma.hpp | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp index 48b1143977fe2..650bcbaa0908b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_mfma.hpp @@ -94,16 +94,18 @@ void hip_matrix_mfma() { std::cout << "Exception caught: " << e.what() << std::endl; } - for (auto kx = 0; kx < KX; kx++) { - for (auto m = 0; m < M; m++) { - for (auto n = 0; n < N; n++) { - for (auto k = 0; k < K; k++) { - if (OutLayout == layout::row_major) - E[m * N + n] += A[m * K + k + kx * K] * B[k * N + n + kx * K * N]; - else - E[n * M + m] += A[m * K + k + kx * K] * B[k * N + n + kx * K * N]; - } + constexpr int LDA = K * KX; + + for (auto m = 0; m < M; m++) { + for (auto n = 0; n < N; n++) { + OutType e = 0; + for (auto k = 0; k < LDA; k++) { + e += A[m * LDA + k] * B[k * N + n]; } + if (OutLayout == layout::row_major) + E[m * N + n] += e; + else + E[n * M + m] += e; } } From af5cc07e230684a6eddaf81a37c0d14d4c532c35 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Mon, 23 Oct 2023 16:56:29 +0100 Subject: [PATCH 43/50] Rename `cuda_impl` and `hip_impl` member of `joint_matrix` to `matrix_impl` --- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index e76963f2707fd..55cb9d1354062 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -43,10 +43,10 @@ struct joint_matrix { #if defined(__NVPTX__) mutable sycl::ext::oneapi::detail::joint_matrix_cuda - cuda_impl; + matrix_impl; #elif defined(__HIP_PLATFORM_AMD_MFMA__) sycl::ext::oneapi::detail::joint_matrix_hip - hip_impl; + matrix_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< T, Rows, Cols, spv_matrix_layout_traits::value, @@ -89,13 +89,13 @@ class wi_data { public: size_t length() { #if defined(__NVPTX__) - return jm.cuda_impl.wi_marray.size(); + return 
jm.matrix_impl.wi_marray.size(); #endif }; decltype(auto) operator[](size_t i) { #if defined(__NVPTX__) - return (jm.cuda_impl.wi_marray[i]); + return (jm.matrix_impl.wi_marray[i]); #else std::ignore = i; #endif @@ -155,12 +155,12 @@ joint_matrix_apply(Group sg, joint_matrix &jm, #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; - for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) { - lambda(jm.cuda_impl.wi_marray[i]); + for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) { + lambda(jm.matrix_impl.wi_marray[i]); } #elif defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; - sycl::ext::oneapi::detail::joint_matrix_apply(jm.hip_impl, lambda); + sycl::ext::oneapi::detail::joint_matrix_apply(jm.matrix_impl, lambda); #else // NVPTX using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< @@ -190,9 +190,9 @@ joint_matrix_fill(Group, const T2 &v) { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - res.cuda_impl.wi_marray = v; + res.matrix_impl.wi_marray = v; #elif defined(__HIP_PLATFORM_AMD_MFMA__) - sycl::ext::oneapi::detail::joint_matrix_apply(res.hip_impl, + sycl::ext::oneapi::detail::joint_matrix_apply(res.matrix_impl, [=](T &value) { value = v; }); #else using storage_element_type = @@ -228,10 +228,10 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( "Joint Matrix doesn't support load from private memory!"); #if defined(__NVPTX__) std::ignore = sg; - sycl::ext::oneapi::detail::load_accumulator_cuda(res.cuda_impl, src, stride, + sycl::ext::oneapi::detail::load_accumulator_cuda(res.matrix_impl, src, stride, Layout); #elif defined(__HIP_PLATFORM_AMD_MFMA__) - sycl::ext::oneapi::detail::load_accumulator_hip(res.hip_impl, src, stride, + sycl::ext::oneapi::detail::load_accumulator_hip(res.matrix_impl, src, stride, Layout, sg); #else std::ignore = sg; @@ -296,11 +296,11 @@ joint_matrix_load(Group &sg, std::ignore = sg; sycl::ext::oneapi::detail::load_multiplicand_cuda( - res.cuda_impl, src, stride); + res.matrix_impl, src, stride); #elif defined(__HIP_PLATFORM_AMD_MFMA__) sycl::ext::oneapi::detail::load_multiplicand_hip( - res.hip_impl, src, stride, sg); + res.matrix_impl, src, stride, sg); #else std::ignore = sg; using DecorT = typename sycl::detail::DecoratedType::type; @@ -337,11 +337,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( #if defined(__NVPTX__) std::ignore = sg; sycl::ext::oneapi::detail::joint_matrix_store_cuda(src.cuda_impl, dst, - stride, Layout); + Space>( + src.matrix_impl, dst, stride, Layout); #elif defined(__HIP_PLATFORM_AMD_MFMA__) sycl::ext::oneapi::detail::joint_matrix_store_hip(src.hip_impl, dst, + Space>(src.matrix_impl, dst, stride, Layout, sg); #else std::ignore = sg; @@ -404,7 +404,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( if constexpr (std::is_same::value) { sycl::ext::oneapi::detail::joint_matrix_mad_cuda( - D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); + D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl); } else { assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); @@ -413,7 +413,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( if constexpr (std::is_same::value) { sycl::ext::oneapi::detail::joint_matrix_mad_hip( - D.hip_impl, A.hip_impl, B.hip_impl, C.hip_impl); + D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl); } else { assert(false && "Ta != Tb : In the HIP backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); @@ -450,12 
+450,13 @@ void joint_matrix_copy( #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; - for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { - dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; + for (int i = 0; i < src.matrix_impl.wi_marray.size(); i++) { + dst.matrix_impl.wi_marray[i] = src.matrix_impl.wi_marray[i]; } #elif defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; - sycl::ext::oneapi::detail::joint_matrix_copy(src.hip_impl, dst.hip_impl); + sycl::ext::oneapi::detail::joint_matrix_copy(src.matrix_impl, + dst.matrix_impl); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< From c6a3cee806065d0c61b606d985290667d65804e3 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Mon, 23 Oct 2023 20:33:11 +0100 Subject: [PATCH 44/50] Replace `cuda_impl` with `matrix_impl` --- sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp | 4 ++-- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index b852e3f1ff3f5..e1efa463939dd 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -533,8 +533,8 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply( #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) std::ignore = sg; - for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) { - lambda(jm.cuda_impl.wi_marray[i]); + for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) { + lambda(jm.matrix_impl.wi_marray[i]); } #else // NVPTX using storage_element_type = diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index fc73269b60650..62d154ac83ba2 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -347,7 +347,7 @@ joint_matrix_mad( if constexpr (std::is_same::value) { sycl::ext::oneapi::detail::joint_matrix_mad_cuda( - D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); + D.matrix_impl, A.matrix_impl, B.matrix_impl, C.matrix_impl); } else { assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " "requires that joint_matrix data types Ta and Tb match"); From 2f6885fc845ef2831ae1ea50448640e95b239585 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 24 Oct 2023 13:29:39 +0100 Subject: [PATCH 45/50] Rename `data` in `joint_matrix_hip` with `wi_marray`. Use same code for `copy`, `fill` and `apply`. 
Remove `-DSYCL_EXT_ONEAPI_MATRIX_VERSION=4` --- .../sycl/ext/oneapi/matrix/matrix-hip.hpp | 136 ++++++------------ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 16 +-- .../matrix/matrix-hip-bfloat16-float-test.cpp | 2 +- .../matrix/matrix-hip-double-double-test.cpp | 2 +- .../hip/matrix/matrix-hip-half-float-test.cpp | 2 +- .../hip/matrix/matrix-hip-int8-int32-test.cpp | 2 +- 6 files changed, 54 insertions(+), 106 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index 24a22a68d8c4e..e53dc21dec8e5 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -30,6 +30,8 @@ struct joint_matrix_hip; #if defined(__SYCL_DEVICE_ONLY__) && defined(__HIP_PLATFORM_AMD_MFMA__) +constexpr int WAVEFRONT_SIZE = 64; + using bfloat16x4 = __attribute__((__vector_size__(4 * sizeof(__bf16)))) __fp16; using float16x4 = __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16; using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; @@ -65,7 +67,7 @@ template <> struct to_hip_type { sycl::ext::oneapi::experimental::matrix::layout::row_major || \ Layout == \ sycl::ext::oneapi::experimental::matrix::layout::col_major>> { \ - sycl::marray data; \ + sycl::marray wi_marray; \ }; __SYCL_JOINT_MATRIX_OVERLOAD_ARR(bfloat16, a, 16, 16, 4) @@ -93,7 +95,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD_ARR(int8_t, b, 16, 16, 4) struct joint_matrix_hip< \ TYPE, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, \ sycl::ext::oneapi::experimental::matrix::layout::dynamic> { \ - sycl::marray data; \ + sycl::marray wi_marray; \ }; __SYCL_JOINT_MATRIX_OVERLOAD_ARR_ACC(float, 16, 16) @@ -123,12 +125,12 @@ void load_accumulator_layoutT( sycl::ext::oneapi::experimental::matrix::layout::row_major) { for (int i = 0; i < 4; ++i) { const int s_idx = thread_x + i * 4 * stride + thread_y * stride; - res.data[i] = src[s_idx]; + res.wi_marray[i] = src[s_idx]; } } else { for (int i = 0; i < 4; ++i) { const int s_idx = i * 4 + thread_x * stride + thread_y; - res.data[i] = src[s_idx]; + res.wi_marray[i] = src[s_idx]; } } } else if constexpr (std::is_same_v || std::is_same_v) { @@ -140,12 +142,12 @@ void load_accumulator_layoutT( row_major) { for (int i = 0; i < 4; ++i) { const int s_idx = thread_x + i * stride + thread_y * 4 * stride; - res.data[i] = src[s_idx]; + res.wi_marray[i] = src[s_idx]; } } else { for (int i = 0; i < 4; ++i) { const int s_idx = i + thread_x * stride + thread_y * 4; - res.data[i] = src[s_idx]; + res.wi_marray[i] = src[s_idx]; } } } else if constexpr (M == 32 && N == 32) { @@ -158,14 +160,14 @@ void load_accumulator_layoutT( for (int i = 0; i < 4; ++i) { const int s_idx = thread_x + i * stride + thread_y * 4 * stride + j * 8 * N; - res.data[i + 4 * j] = src[s_idx]; + res.wi_marray[i + 4 * j] = src[s_idx]; } } } else { for (int j = 0; j < 4; ++j) { for (int i = 0; i < 4; ++i) { const int s_idx = i + thread_x * stride + thread_y * 4 + j * 8; - res.data[i + 4 * j] = src[s_idx]; + res.wi_marray[i + 4 * j] = src[s_idx]; } } } @@ -212,9 +214,9 @@ void load_multiplicand_hip(joint_matrix_hip &res, if constexpr (std::is_same_v) { if constexpr (Layout == sycl::ext::oneapi::experimental::matrix::layout::row_major) { - res.data[0] = src[idx]; + res.wi_marray[0] = src[idx]; } else { - res.data[0] = src[(idx % M) * stride + idx / M]; + res.wi_marray[0] = src[(idx % M) * stride + idx / M]; } } else { constexpr int Dim = (M == 16) ? 
16 : 32; @@ -226,12 +228,12 @@ void load_multiplicand_hip(joint_matrix_hip &res, sycl::ext::oneapi::experimental::matrix::layout::col_major) { for (int i = 0; i < 4; ++i) { const int c_idx = thread_x * stride + i + thread_y * 4; - res.data[i] = src[c_idx]; + res.wi_marray[i] = src[c_idx]; } } else { for (int i = 0; i < 4; ++i) { const int r_idx = thread_x + i * stride + thread_y * stride * 4; - res.data[i] = src[r_idx]; + res.wi_marray[i] = src[r_idx]; } } } @@ -257,12 +259,12 @@ void store_layoutT( sycl::ext::oneapi::experimental::matrix::layout::row_major) { for (int i = 0; i < 4; ++i) { const int d_idx = thread_x + i * 4 * stride + thread_y * stride; - dst[d_idx] = src.data[i]; + dst[d_idx] = src.wi_marray[i]; } } else { for (int i = 0; i < 4; ++i) { const int d_idx = i * 4 + thread_x * stride + thread_y; - dst[d_idx] = src.data[i]; + dst[d_idx] = src.wi_marray[i]; } } } else if constexpr (std::is_same_v || std::is_same_v) { @@ -274,12 +276,12 @@ void store_layoutT( row_major) { for (int i = 0; i < 4; ++i) { const int d_idx = thread_x + i * stride + thread_y * 4 * stride; - dst[d_idx] = src.data[i]; + dst[d_idx] = src.wi_marray[i]; } } else { for (int i = 0; i < 4; ++i) { const int d_idx = i + thread_x * stride + thread_y * 4; - dst[d_idx] = src.data[i]; + dst[d_idx] = src.wi_marray[i]; } } } else if constexpr (M == 32 && N == 32) { @@ -292,14 +294,14 @@ void store_layoutT( for (int i = 0; i < 4; ++i) { const int d_idx = thread_x + i * stride + thread_y * 4 * stride + j * 8 * stride; - dst[d_idx] = src.data[i + 4 * j]; + dst[d_idx] = src.wi_marray[i + 4 * j]; } } } else { for (int j = 0; j < 4; ++j) { for (int i = 0; i < 4; ++i) { const int d_idx = i + thread_x * stride + thread_y * 4 + j * 8; - dst[d_idx] = src.data[i + 4 * j]; + dst[d_idx] = src.wi_marray[i + 4 * j]; } } } @@ -343,106 +345,62 @@ void joint_matrix_mad_hip( if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { auto result = __builtin_amdgcn_mfma_f32_16x16x16f16( - *reinterpret_cast(&A.data), - *reinterpret_cast(&B.data), - *reinterpret_cast(&C.data), 0, 0, 0); + *reinterpret_cast(&A.wi_marray), + *reinterpret_cast(&B.wi_marray), + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 4; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } else if constexpr (M == 32 && N == 32) { auto result = __builtin_amdgcn_mfma_f32_32x32x8f16( - *reinterpret_cast(&A.data), - *reinterpret_cast(&B.data), - *reinterpret_cast(&C.data), 0, 0, 0); + *reinterpret_cast(&A.wi_marray), + *reinterpret_cast(&B.wi_marray), + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 16; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { auto result = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - *reinterpret_cast(&A.data), - *reinterpret_cast(&B.data), - *reinterpret_cast(&C.data), 0, 0, 0); + *reinterpret_cast(&A.wi_marray), + *reinterpret_cast(&B.wi_marray), + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 4; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } else if constexpr (M == 32 && N == 32) { auto result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( - *reinterpret_cast(&A.data), - *reinterpret_cast(&B.data), - *reinterpret_cast(&C.data), 0, 0, 0); + *reinterpret_cast(&A.wi_marray), + *reinterpret_cast(&B.wi_marray), + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 16; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } } else if constexpr (std::is_same_v) { if 
constexpr (M == 16 && N == 16) { auto result = __builtin_amdgcn_mfma_f64_16x16x4f64( - A.data[0], B.data[0], *reinterpret_cast(&C.data), 0, - 0, 0); + A.wi_marray[0], B.wi_marray[0], + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 4; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } } else if constexpr (std::is_same_v) { if constexpr (M == 16 && N == 16) { auto result = __builtin_amdgcn_mfma_i32_16x16x16i8( - *reinterpret_cast(&A.data), - *reinterpret_cast(&B.data), - *reinterpret_cast(&C.data), 0, 0, 0); + *reinterpret_cast(&A.wi_marray), + *reinterpret_cast(&B.wi_marray), + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 4; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } else if constexpr (M == 32 && N == 32) { auto result = __builtin_amdgcn_mfma_i32_32x32x8i8( - *reinterpret_cast(&A.data), - *reinterpret_cast(&B.data), - *reinterpret_cast(&C.data), 0, 0, 0); + *reinterpret_cast(&A.wi_marray), + *reinterpret_cast(&B.wi_marray), + *reinterpret_cast(&C.wi_marray), 0, 0, 0); for (int i = 0; i < 16; ++i) - D.data[i] = result[i]; + D.wi_marray[i] = result[i]; } } } -template -void joint_matrix_apply(joint_matrix_hip &jm, - F &&lambda) { - if constexpr (std::is_same_v && - Use != - sycl::ext::oneapi::experimental::matrix::use::accumulator) { - lambda(jm.data[0]); - } else if constexpr ( - Use != sycl::ext::oneapi::experimental::matrix::use::accumulator || - (Use == sycl::ext::oneapi::experimental::matrix::use::accumulator && - M == 16)) { - for (auto i = 0; i < 4; ++i) - lambda(jm.data[i]); - } else { - for (auto i = 0; i < 16; ++i) - lambda(jm.data[i]); - } -} - -template -void joint_matrix_copy(joint_matrix_hip &src, - joint_matrix_hip &dst) { - if constexpr (std::is_same_v && - Use1 != - sycl::ext::oneapi::experimental::matrix::use::accumulator) { - dst.data[0] = src.data[0]; - } else if constexpr ( - Use1 != sycl::ext::oneapi::experimental::matrix::use::accumulator || - (Use1 == sycl::ext::oneapi::experimental::matrix::use::accumulator && - M == 16)) { - for (auto i = 0; i < 4; ++i) - dst.data[i] = src.data[i]; - } else { - for (auto i = 0; i < 16; ++i) - src.data[i] = src.data[i]; - } -} - #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } // namespace detail diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 62d154ac83ba2..15da559d7e241 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -84,14 +84,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(Group sg, joint_matrix &jm, F &&lambda) { #if defined(__SYCL_DEVICE_ONLY__) -#if defined(__NVPTX__) +#if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; for (int i = 0; i < jm.matrix_impl.wi_marray.size(); i++) { lambda(jm.matrix_impl.wi_marray[i]); } -#elif defined(__HIP_PLATFORM_AMD_MFMA__) - std::ignore = sg; - sycl::ext::oneapi::detail::joint_matrix_apply(jm.matrix_impl, lambda); #else // NVPTX using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< @@ -120,11 +117,8 @@ joint_matrix_fill(Group, joint_matrix &res, const T2 &v) { #if defined(__SYCL_DEVICE_ONLY__) -#if defined(__NVPTX__) +#if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__) res.matrix_impl.wi_marray = v; -#elif defined(__HIP_PLATFORM_AMD_MFMA__) - sycl::ext::oneapi::detail::joint_matrix_apply(res.matrix_impl, - [=](T &value) { value = v; }); #else using storage_element_type 
= typename oneapi::detail::jm_type_interpretation_helper_trait< @@ -391,15 +385,11 @@ void joint_matrix_copy( Group sg, joint_matrix &src, joint_matrix &dst) { #if defined(__SYCL_DEVICE_ONLY__) -#if defined(__NVPTX__) +#if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; for (int i = 0; i < src.matrix_impl.wi_marray.size(); i++) { dst.matrix_impl.wi_marray[i] = src.matrix_impl.wi_marray[i]; } -#elif defined(__HIP_PLATFORM_AMD_MFMA__) - std::ignore = sg; - sycl::ext::oneapi::detail::joint_matrix_copy(src.matrix_impl, - dst.matrix_impl); #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp index f9f8e847606d0..9f3454d5ef83a 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-bfloat16-float-test.cpp @@ -1,6 +1,6 @@ // REQUIRES: hip -// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s #include diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp index 32fe4a23d4ea0..8475afee205b7 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-double-double-test.cpp @@ -1,6 +1,6 @@ // REQUIRES: hip -// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s #include diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp index 0db7e4ce45b16..9019233a1fa8b 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-half-float-test.cpp @@ -1,6 +1,6 @@ // REQUIRES: hip -// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s #include diff --git a/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp b/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp index 64db6fe21a493..f2c7b1ec8c08c 100644 --- a/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp +++ b/sycl/test/check_device_code/hip/matrix/matrix-hip-int8-int32-test.cpp @@ -1,6 +1,6 @@ // REQUIRES: hip -// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -S -Xclang -emit-llvm %s -o -| FileCheck %s +// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s #include From defd8745726f54ed1c492a60e71260d231b86558 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 24 Oct 2023 13:55:16 +0100 Subject: [PATCH 46/50] Improve `joint_matrix_copy` by avoiding the loop. 
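
Since both the CUDA and HIP fragments now keep their elements in a sycl::marray, copying a matrix no longer needs an indexed loop; marray's copy assignment already copies every element. The two forms are equivalent, as the illustrative-only sketch below shows (the function names are not taken from the headers).

    #include <sycl/sycl.hpp>
    #include <cstddef>

    // Old shape of the device code path: copy the fragment element by element.
    void copy_with_loop(const sycl::marray<float, 16> &src,
                        sycl::marray<float, 16> &dst) {
      for (std::size_t i = 0; i < src.size(); ++i)
        dst[i] = src[i];
    }

    // New shape: a single whole-object assignment does the same work.
    void copy_with_assignment(const sycl::marray<float, 16> &src,
                              sycl::marray<float, 16> &dst) {
      dst = src;
    }
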
--- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 15da559d7e241..fb56120dba89f 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -387,9 +387,7 @@ void joint_matrix_copy( #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__) std::ignore = sg; - for (int i = 0; i < src.matrix_impl.wi_marray.size(); i++) { - dst.matrix_impl.wi_marray[i] = src.matrix_impl.wi_marray[i]; - } + dst.matrix_impl.wi_marray = src.matrix_impl.wi_marray; #else using storage_element_type = typename oneapi::detail::jm_type_interpretation_helper_trait< From 89c52d739a4d18bcae5b972fdbb4e77d9ef95c19 Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 24 Oct 2023 16:15:04 +0100 Subject: [PATCH 47/50] Add a missing `comma` to the test. --- sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp index 729bd8a16c18d..7c8e22faccd08 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp @@ -12,7 +12,7 @@ template void half_matrix_mfma() { hip_matrix_mfma(); hip_matrix_mfma(); - hip_matrix_mfma(); + hip_matrix_mfma(); hip_matrix_mfma(); } From abfa2abd96c523414538426b7dcd39b41417b05d Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 24 Oct 2023 16:22:54 +0100 Subject: [PATCH 48/50] Remove `-DSYCL_EXT_ONEAPI_MATRIX_VERSION=4` from AMD matrix compilation arguments. 
--- sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp | 2 +- sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp index db9ac7d4d4da6..e9dc8659e69ae 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_gfx90a.cpp @@ -1,4 +1,4 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a %s -o %t.out // RUN: %{run} %t.out // REQUIRES: gpu-amd-gfx90a diff --git a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp index 7c8e22faccd08..96aacbac9c280 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_hip_half_gfx90a.cpp @@ -1,4 +1,4 @@ -// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out +// RUN: %{build} -fsycl -fsycl-targets=amd_gpu_gfx90a %s -o %t.out // RUN: %{run} %t.out // REQUIRES: gpu-amd-gfx90a From 048ac39b65fdd78c9fa190433b8ed8fdc2fb118e Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 24 Oct 2023 17:08:15 +0100 Subject: [PATCH 49/50] Guard including `matrix-hip.hpp` --- sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp | 10 ++-------- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 7 +++++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp index e53dc21dec8e5..7f9f9b1219cf4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-hip.hpp @@ -11,9 +11,7 @@ #include "matrix-unified-utils.hpp" #include -#if defined(__gfx90a__) #define __HIP_PLATFORM_AMD_MFMA__ -#endif namespace sycl { inline namespace _V1 { @@ -21,6 +19,8 @@ namespace ext { namespace oneapi { namespace detail { +constexpr int WAVEFRONT_SIZE = 64; + template struct joint_matrix_hip; -#if defined(__SYCL_DEVICE_ONLY__) && defined(__HIP_PLATFORM_AMD_MFMA__) - -constexpr int WAVEFRONT_SIZE = 64; - using bfloat16x4 = __attribute__((__vector_size__(4 * sizeof(__bf16)))) __fp16; using float16x4 = __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16; using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; @@ -401,8 +397,6 @@ void joint_matrix_mad_hip( } } -#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) - } // namespace detail } // namespace oneapi } // namespace ext diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index fb56120dba89f..21271d1c09b14 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -8,11 +8,14 @@ #pragma once -#include "matrix-hip.hpp" #include "matrix-intel.hpp" -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) #include "matrix-tensorcores.hpp" +#elif defined(__gfx90a__) +#include "matrix-hip.hpp" +#endif #endif #include // for address_space From 05d2e9d2d883244ac68afb9852f20261727d2f4c Mon Sep 17 00:00:00 2001 From: mmoadeli Date: Tue, 24 Oct 2023 17:20:02 +0100 Subject: [PATCH 50/50] Minor macro readability improvement. 
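
With the backend headers now pulled in behind nested preprocessor conditions, a bare #endif two levels deep is easy to misread, so each #endif gains a comment naming the condition it closes. Assembled from the two hunks, the guard in matrix-unified.hpp ends up reading:

    #if defined(__SYCL_DEVICE_ONLY__)
    #if defined(__NVPTX__)
    #include "matrix-tensorcores.hpp"
    #elif defined(__gfx90a__)
    #include "matrix-hip.hpp"
    #endif // defined(__NVPTX__)
    #endif // defined(__SYCL_DEVICE_ONLY__)
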
--- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 21271d1c09b14..d8a751680900a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -15,8 +15,8 @@ #include "matrix-tensorcores.hpp" #elif defined(__gfx90a__) #include "matrix-hip.hpp" -#endif -#endif +#endif // defined(__NVPTX__) +#endif // defined(__SYCL_DEVICE_ONLY__) #include // for address_space #include // for __SYCL_ALWAYS_...