diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index 3e561c9067061..03fe1fbe0549b 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -6,19 +6,16 @@ // // ===--------------------------------------------------------------------=== // // This file implements the static query interface for the joint_matrix -// experimental extension. AMX, DPAS and different other TPUs support different -// logical sizes and types. The query interface is used to validate user code -// and inform them about supported types, sizes, scope, and layouts by the -// current implementation. Note that this query interface is a compile-time -// query, so there will be no runtime errors. The query interface provides -// three functionalities: -// 1- At compile time, inform the user whether a specific -// combination is valid or not. -// 2- Construct the matrices using a default shape -// if user does not provide a combination -// 3- General query interface for sizes, types, -// static/dynamic, scope. This is needed to void padding by the user, -// for tuning, and efficient code generation if used by a library. +// experimental extension. Intel AMX, Intel XMX, and Nvidia Tensor Cores support +// different logical sizes and types. The query interface is used to validate +// user code and inform them about supported types, sizes, scopes, and layouts +// by the current implementation. Note that this query interface is a +// compile-time query, so there will be no runtime errors. The query interface +// provides three functionalities: 1- At compile time, inform the user whether a +// specific combination is valid or not. 2- Construct the matrices using a +// default shape if user does not provide a combination 3- General query +// interface for sizes, types, scopes. This is needed to void padding by the +// user, for tuning, and efficient code generation if used by a library. #pragma once @@ -29,14 +26,15 @@ namespace oneapi { namespace experimental::matrix { enum class tpu { - dpas, + xmx8, + xmx16, amx, }; enum class matrix_type { bf8, bf16, fp16, - fp19, // tfloat32 + tf32, fp32, fp64, sint2, @@ -104,10 +102,9 @@ struct tpu_params { static constexpr std::size_t N = -1; static constexpr std::size_t K = -1; - bool dynamic_p = false; // should be true in future implementations because - // AMX hardware supports dynamic sizes uint32_t numtiles = 8; - scope_t scope = scope_t::sub_group; + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); struct combination { uint32_t max_msize; uint32_t max_nsize; @@ -155,10 +152,9 @@ struct tpu_params; - bool dynamic_p = false; // should be true in future implementations because - // AMX hardware supports dynamic sizes uint32_t numtiles = 8; - scope_t scope = scope_t::sub_group; + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); struct combination { uint32_t max_msize; uint32_t max_nsize; @@ -207,19 +203,18 @@ struct tpu_params< using joint_matrix_accumulator = joint_matrix; - bool dynamic_p = false; // should be true in future implementations - // because AMX hardware supports dynamic sizes uint32_t numtiles = 8; - scope_t scope = scope_t::sub_group; + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); }; -// DPAS case -// The DPAS implementation supports the logical capability support of the HW -// So in this case, M, N, K sizes returned by the query represent the logical -// capabilities of the DPAS hardware. +// Intel XMX with SIMD8 capability +// The Intel XMX implementation supports the logical capability support of the +// HW So in this case, M, N, K sizes returned by the query represent the logical +// capabilities of the Intel XMX hardware. template -constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) { +constexpr bool is_combination_valid_xmx8(int sM, int sN, int sK) { if ((std::is_same_v && std::is_same_v && std::is_same_v && (sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 8 && sK == 32) || @@ -244,7 +239,7 @@ constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) { } template -constexpr bool are_types_valid_dpas() { +constexpr bool are_types_valid_xmx8() { if ((std::is_same_v && std::is_same_v && std::is_same_v) || (std::is_same_v && std::is_same_v && @@ -265,14 +260,14 @@ constexpr bool are_types_valid_dpas() { // General Query // specialization for when types are not given --> no default values template -struct tpu_params { +struct tpu_params { static constexpr std::size_t M = -1; // depends on the type static constexpr std::size_t N = -1; static constexpr std::size_t K = -1; - bool dynamic_p = false; // no dynamic allocation on the GPU - uint32_t numtiles = -1; // does not apply for DPAS - scope_t scope = scope_t::sub_group; + uint32_t numtiles = -1; // does not apply for XMX8 + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); struct combination { uint32_t max_msize; @@ -320,12 +315,12 @@ struct tpu_params { // Specialization for when only types are given, need to query only sizes template -struct tpu_params && !std::is_same_v && !std::is_same_v)>::type> { - static_assert((are_types_valid_dpas()), - "Invalid types for DPAS, supported types are int8_t, uint8_t, " + static_assert((are_types_valid_xmx8()), + "Invalid types for XMX8, supported types are int8_t, uint8_t, " "half, and bf16 (Note that unsigned short should be used in the" "DPC++ code to implement bf16)"); @@ -343,9 +338,9 @@ struct tpu_params; - bool dynamic_p = false; // no dynamic allocation on the GPU - uint32_t numtiles = -1; // does not apply for DPAS - scope_t scope = scope_t::sub_group; + uint32_t numtiles = -1; // does not apply for XMX8 + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); struct combination { uint32_t max_msize; uint32_t max_nsize; @@ -376,13 +371,13 @@ struct tpu_params struct tpu_params< - tpu::dpas, Ta, Tb, Tc, sM, sN, sK, + tpu::xmx8, Ta, Tb, Tc, sM, sN, sK, typename std::enable_if<((!std::is_same_v && sM != 0))>::type> { // Validate that parameters are supported static_assert((sM == 0 && sN == 0 && sK == 0) || - (is_combination_valid_dpas(sM, sN, sK)), - "Invalid parameters for DPAS, query valid combinations " - "using: tpu_params myparams; and then check out " + (is_combination_valid_xmx8(sM, sN, sK)), + "Invalid parameters for XMX8, query valid combinations " + "using: tpu_params myparams; and then check out " "myparams.combinations array"); // if combination is valid, construct the matrices @@ -399,9 +394,200 @@ struct tpu_params< using joint_matrix_accumulator = joint_matrix; - bool dynamic_p = false; // no dynamic allocation on the GPU - uint32_t numtiles = -1; // does not apply for DPAS - scope_t scope = scope_t::sub_group; + uint32_t numtiles = -1; // does not apply for XMX8 + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); +}; + +// Intel XMX with SIMD16 capability +// The Intel XMX implementation supports the logical capability support of the +// HW So in this case, M, N, K sizes returned by the query represent the logical +// capabilities of the Intel XMX hardware. + +template +constexpr bool is_combination_valid_xmx16(int sM, int sN, int sK) { + if ((std::is_same_v && std::is_same_v && + std::is_same_v && (sM == 1 || sM == 2 || sM == 4 || sM == 8) && + sN == 16 && sK == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (sM == 1 || sM == 2 || sM == 4 || sM == 8) && + sN == 16 && sK == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (sM == 1 || sM == 2 || sM == 4 || sM == 8) && + sN == 16 && sK == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && (sM == 1 || sM == 2 || sM == 4 || sM == 8) && + sN == 16 && sK == 32) || + (std::is_same_v && std::is_same_v && + std::is_same_v && + (sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 16 && sK == 16) || + (std::is_same_v && + std::is_same_v && std::is_same_v && + (sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 16 && sK == 16)) + return true; + else + return false; +} + +template +constexpr bool are_types_valid_xmx16() { + if ((std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && + std::is_same_v && std::is_same_v)) + return true; + else + return false; +} + +// General Query +// specialization for when types are not given --> no default values +template +struct tpu_params { + static constexpr std::size_t M = -1; // depends on the type + static constexpr std::size_t N = -1; + static constexpr std::size_t K = -1; + + uint32_t numtiles = -1; // does not apply for XMX + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); + + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type accumulatortype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 1, 16, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 2, 16, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 4, 16, 32}, + {0, 0, 0, mt::sint8, mt::sint8, mt::sint32, 8, 16, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 1, 16, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 2, 16, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 4, 16, 32}, + {0, 0, 0, mt::sint8, mt::uint8, mt::sint32, 8, 16, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 1, 16, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 2, 16, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 4, 16, 32}, + {0, 0, 0, mt::uint8, mt::sint8, mt::sint32, 8, 16, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 1, 16, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 2, 16, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 4, 16, 32}, + {0, 0, 0, mt::uint8, mt::uint8, mt::sint32, 8, 16, 32}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 1, 16, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 2, 16, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 4, 16, 16}, + {0, 0, 0, mt::fp16, mt::fp16, mt::fp32, 8, 16, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 1, 16, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 2, 16, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 4, 16, 16}, + {0, 0, 0, mt::bf16, mt::bf16, mt::fp32, 8, 16, 16}, + }; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Sizes-only query: +// Specialization for when only types are given, need to query only sizes + +template +struct tpu_params && + !std::is_same_v && + !std::is_same_v)>::type> { + static_assert((are_types_valid_xmx16()), + "Invalid types for XMX16, supported types are int8_t, uint8_t, " + "half, and bf16 (Note that unsigned short should be used in the" + "DPC++ code to implement bf16)"); + + // construct the matrices using the default sizes + + static constexpr std::size_t M = 8; + static constexpr std::size_t N = 16; + static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16); + + template + using joint_matrix_a = joint_matrix; + template + using joint_matrix_b = joint_matrix; + template + using joint_matrix_accumulator = + joint_matrix; + + uint32_t numtiles = -1; // does not apply for XMX + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); + struct combination { + uint32_t max_msize; + uint32_t max_nsize; + uint32_t max_ksize; + matrix_type atype; + matrix_type btype; + matrix_type accumulatortype; + uint32_t msize; + uint32_t nsize; + uint32_t ksize; + }; + using mt = matrix_type; + static constexpr combination combinations[] = { + // The types used in the initialization below are fake and not used. In + // this case, users already chose the types, they are only looking for + // the + // sizes + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 1, 16, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 2, 16, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 4, 16, (sizeof(Ta) == 1) ? 32 : 16}, + {0, 0, 0, mt::bf8, mt::bf8, mt::bf8, 8, 16, (sizeof(Ta) == 1) ? 32 : 16}, + }; + static constexpr int num_combinations = + sizeof(combinations) / sizeof(combination); +}; + +// Valid or not: +// Specialization when both types and sizes are given +template +struct tpu_params< + tpu::xmx16, Ta, Tb, Tc, sM, sN, sK, + typename std::enable_if<((!std::is_same_v && sM != 0))>::type> { + // Validate that parameters are supported + static_assert((sM == 0 && sN == 0 && sK == 0) || + (is_combination_valid_xmx16(sM, sN, sK)), + "Invalid parameters for XMX16, query valid combinations " + "using: tpu_params myparams; and then check out " + "myparams.combinations array"); + + // if combination is valid, construct the matrices + static constexpr std::size_t M = (sM != 0) ? sM : 8; + static constexpr std::size_t N = (sN != 0) ? sN : 8; + static constexpr std::size_t K = + (sK != 0) ? sK : ((sizeof(Ta) == 1) ? 32 : 16); + + template + using joint_matrix_a = joint_matrix; + template + using joint_matrix_b = joint_matrix; + template + using joint_matrix_accumulator = + joint_matrix; + + uint32_t numtiles = -1; // does not apply for XMX16 + static constexpr scope_t scopes[] = {scope_t::sub_group}; + static constexpr int num_scopes = sizeof(scopes) / sizeof(scope_t); }; } // namespace experimental::matrix } // namespace oneapi diff --git a/sycl/test/matrix/query-use.cpp b/sycl/test/matrix/query-use.cpp index 6239d05d5f79f..9afc8e1173043 100644 --- a/sycl/test/matrix/query-use.cpp +++ b/sycl/test/matrix/query-use.cpp @@ -36,6 +36,13 @@ void query_amx() { // general query: types are not given tpu_params myparams3; + if (myparams3.num_scopes > 0) + if (myparams3.scopes[0] == scope_t::sub_group) + std::cout << "There are " << myparams3.num_scopes + << " Scopes that are supported by AMX implementation and " + "subgroup is one of them " + << std::endl; + std::cout << "AMX query num combinations: " << myparams3.num_combinations << std::endl; @@ -68,43 +75,51 @@ void query_amx() { }); } -void query_dpas() { +void query_xmx8() { // generates combination assert - // using myparams = tpu_params; + // using myparams = tpu_params; // generate combination of type assert - // using myparams = tpu_params; + // using myparams = tpu_params; // tells whether a combination is valid or not, if valid, those will be set as // default - using myparams = tpu_params; + using myparams = tpu_params; size_t dmsize = myparams::M; size_t dnsize = myparams::N; size_t dksize = myparams::K; - std::cout << "sizes of DPAS tpu_params chosen by the user are: M " << dmsize + std::cout << "sizes of XMX8 tpu_params chosen by the user are: M " << dmsize << " N " << dnsize << " K " << dksize << std::endl; // sizes-only query: types are given, generate default sizes - using myparams2 = tpu_params; + using myparams2 = tpu_params; myparams2 p; dmsize = myparams2::M; dnsize = myparams2::N; dksize = myparams2::K; - std::cout << "Default DPAS sizes are: M " << dmsize << " N " << dnsize - << " K " << dksize << "\n DPAS int8 num combinations is " + std::cout << "Default XMX8 sizes are: M " << dmsize << " N " << dnsize + << " K " << dksize << "\n XMX8 int8 num combinations is " << p.num_combinations << std::endl; dmsize = myparams2::combinations[0].msize; dnsize = myparams2::combinations[0].nsize; dksize = myparams2::combinations[0].ksize; - std::cout << "one of DPAS combination sizes is: M " << dmsize << " N " + std::cout << "one of XMX8 combination sizes is: M " << dmsize << " N " << dnsize << " K " << dksize << std::endl; // general query: types are not given - tpu_params myparams3; - std::cout << "DPAS query num combinations: " << myparams3.num_combinations + tpu_params myparams3; + + if (myparams3.num_scopes > 0) + if (myparams3.scopes[0] == scope_t::sub_group) + std::cout << "There are " << myparams3.num_scopes + << " Scopes that are supported by XMX8 implementation and " + "subgroup is one of them " + << std::endl; + + std::cout << "XMX8 query num combinations: " << myparams3.num_combinations << std::endl; if (myparams3.combinations[0].msize == 0) // this is not a max params hardware @@ -112,9 +127,9 @@ void query_dpas() { constexpr int msize = myparams3.combinations[0].msize; constexpr int nsize = myparams3.combinations[0].nsize; constexpr int ksize = myparams3.combinations[0].ksize; - std::cout << "DPAS query sizes are: M " << msize << " N " << nsize << " K " + std::cout << "XMX8 query sizes are: M " << msize << " N " << nsize << " K " << ksize << std::endl; - std::cout << "DPAS query max sizes are: M " + std::cout << "XMX8 query max sizes are: M " << myparams3.combinations[0].max_msize << " N " << myparams3.combinations[0].max_nsize << " K " << myparams3.combinations[0].max_ksize << std::endl; @@ -142,6 +157,6 @@ void query_dpas() { int main() { query_amx(); - query_dpas(); + query_xmx8(); return 0; }