Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 73 additions & 73 deletions sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,25 @@ enum class matrix_type {
enum class scope_t { sub_group, work_group };

template <tpu u, typename Ta = void, typename Tb = void, typename Tc = void,
int M = 0, int N = 0, int K = 0, typename Enabled = void>
int sM = 0, int sN = 0, int sK = 0, typename Enabled = void>
struct tpu_params;

#if __cplusplus >= 201703L
template <typename Ta, typename Tb, typename Tc>
constexpr bool is_combination_valid_amx(int M, int N, int K) {
constexpr bool is_combination_valid_amx(int sM, int sN, int sK) {
// is_same_v is a C++17 feature
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
std::is_same_v<Tc, int> && M <= 16 && N <= 16 && K <= 64) ||
std::is_same_v<Tc, int> && sM <= 16 && sN <= 16 && sK <= 64) ||
// bf16
(std::is_same_v<Ta, unsigned short> &&
std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
M <= 16 && N <= 16 && K <= 32))
sM <= 16 && sN <= 16 && sK <= 32))
return true;
else
return false;
Expand All @@ -100,11 +100,11 @@ constexpr bool are_types_valid_amx() {

// General query:
// types are not given, no default sizes and no implicit matrix construction
template <int M, int N, int K>
struct tpu_params<tpu::amx, void, void, void, M, N, K> {
static constexpr std::size_t defaultM = -1; // depends on the type
static constexpr std::size_t defaultN = -1;
static constexpr std::size_t defaultK = -1;
template <int sM, int sN, int sK>
struct tpu_params<tpu::amx, void, void, void, sM, sN, sK> {
static constexpr std::size_t M = -1; // depends on the type
static constexpr std::size_t N = -1;
static constexpr std::size_t K = -1;

bool dynamic_p = false; // should be true in future implementations because
// AMX hardware supports dynamic sizes
Expand All @@ -116,7 +116,7 @@ struct tpu_params<tpu::amx, void, void, void, M, N, K> {
uint32_t max_ksize;
matrix_type atype;
matrix_type btype;
matrix_type ctype;
matrix_type accumulatortype;
uint32_t msize;
uint32_t nsize;
uint32_t ksize;
Expand Down Expand Up @@ -146,19 +146,19 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
"DPC++ code to implement bf16) ");

// construct the matrices using the default sizes
static constexpr std::size_t defaultM = 16;
static constexpr std::size_t defaultN = 16;
static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 64 : 32);
static constexpr std::size_t M = 16;
static constexpr std::size_t N = 16;
static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 64 : 32);

template <typename Group>
using joint_matrix_a =
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
joint_matrix<Ta, M, K, use::a, layout::row_major, Group>;
template <typename Group>
using joint_matrix_b =
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
joint_matrix<Tb, K, N, use::b, layout::packed_b, Group>;
template <typename Group>
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
layout::row_major, Group>;
using joint_matrix_accumulator =
joint_matrix<Tc, M, N, use::accumulator, layout::row_major, Group>;

bool dynamic_p = false; // should be true in future implementations because
// AMX hardware supports dynamic sizes
Expand All @@ -170,7 +170,7 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
uint32_t max_ksize;
matrix_type atype;
matrix_type btype;
matrix_type ctype;
matrix_type accumulatortype;
uint32_t msize;
uint32_t nsize;
uint32_t ksize;
Expand All @@ -183,36 +183,36 @@ struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,

// Valid or not:
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, int M, int N, int K>
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
struct tpu_params<
tpu::amx, Ta, Tb, Tc, M, N, K,
tpu::amx, Ta, Tb, Tc, sM, sN, sK,
typename std::enable_if<(
!std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
!std::is_same_v<Tc, void> && M != 0 && N != 0 && K != 0)>::type> {
!std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
// Validate that parameters are supported
static_assert(
(M == 0 && N == 0 && K == 0) ||
(is_combination_valid_amx<Ta, Tb, Tc>(M, N, K)),
(sM == 0 && sN == 0 && sK == 0) ||
(is_combination_valid_amx<Ta, Tb, Tc>(sM, sN, sK)),
"Invalid parameters for AMX, query valid types and maximum sizes "
"using: tpu_params<tpu::amx> myparams; and then check out "
"myparams.combinations array");

// if combination is valid, construct the matrices

static constexpr std::size_t defaultM = (M != 0) ? M : 16;
static constexpr std::size_t defaultN = (N != 0) ? N : 16;
static constexpr std::size_t defaultK =
(K != 0) ? K : ((sizeof(Ta) == 1) ? 64 : 32);
static constexpr std::size_t M = (sM != 0) ? sM : 16;
static constexpr std::size_t N = (sN != 0) ? sN : 16;
static constexpr std::size_t K =
(sK != 0) ? sK : ((sizeof(Ta) == 1) ? 64 : 32);

template <typename Group>
using joint_matrix_a =
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
joint_matrix<Ta, M, K, use::a, layout::row_major, Group>;
template <typename Group>
using joint_matrix_b =
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
joint_matrix<Tb, K, N, use::b, layout::packed_b, Group>;
template <typename Group>
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
layout::row_major, Group>;
using joint_matrix_accumulator =
joint_matrix<Tc, M, N, use::accumulator, layout::row_major, Group>;

bool dynamic_p = false; // should be true in future implementations
// because AMX hardware supports dynamic sizes
Expand All @@ -226,25 +226,25 @@ struct tpu_params<
// capabilities of the DPAS hardware.

template <typename Ta, typename Tb, typename Tc>
constexpr bool is_combination_valid_dpas(int M, int N, int K) {
constexpr bool is_combination_valid_dpas(int sM, int sN, int sK) {
if ((std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, int8_t> &&
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
N == 8 && K == 32) ||
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
sN == 8 && sK == 32) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tb, uint8_t> &&
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
N == 8 && K == 32) ||
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
sN == 8 && sK == 32) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, int8_t> &&
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
N == 8 && K == 32) ||
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
sN == 8 && sK == 32) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tb, uint8_t> &&
std::is_same_v<Tc, int> && (M == 1 || M == 2 || M == 4 || M == 8) &&
N == 8 && K == 32) ||
std::is_same_v<Tc, int> && (sM == 1 || sM == 2 || sM == 4 || sM == 8) &&
sN == 8 && sK == 32) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tb, half> &&
std::is_same_v<Tc, float> && (M == 1 || M == 2 || M == 4 || M == 8) &&
N == 8 && K == 16) ||
std::is_same_v<Tc, float> &&
(sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 8 && sK == 16) ||
(std::is_same_v<Ta, unsigned short> &&
std::is_same_v<Tb, unsigned short> && std::is_same_v<Tc, float> &&
(M == 1 || M == 2 || M == 4 || M == 8) && N == 8 && K == 16))
(sM == 1 || sM == 2 || sM == 4 || sM == 8) && sN == 8 && sK == 16))
return true;
else
return false;
Expand Down Expand Up @@ -272,11 +272,11 @@ constexpr bool are_types_valid_dpas() {

// General Query
// specialization for when types are not given --> no default values
template <int M, int N, int K>
struct tpu_params<tpu::dpas, void, void, void, M, N, K> {
static constexpr std::size_t defaultM = -1; // depends on the type
static constexpr std::size_t defaultN = -1;
static constexpr std::size_t defaultK = -1;
template <int sM, int sN, int sK>
struct tpu_params<tpu::dpas, void, void, void, sM, sN, sK> {
static constexpr std::size_t M = -1; // depends on the type
static constexpr std::size_t N = -1;
static constexpr std::size_t K = -1;

bool dynamic_p = false; // no dynamic allocation on the GPU
uint32_t numtiles = -1; // does not apply for DPAS
Expand All @@ -288,7 +288,7 @@ struct tpu_params<tpu::dpas, void, void, void, M, N, K> {
uint32_t max_ksize;
matrix_type atype;
matrix_type btype;
matrix_type ctype;
matrix_type accumulatortype;
uint32_t msize;
uint32_t nsize;
uint32_t ksize;
Expand Down Expand Up @@ -340,19 +340,19 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,

// construct the matrices using the default sizes

static constexpr std::size_t defaultM = 8;
static constexpr std::size_t defaultN = 8;
static constexpr std::size_t defaultK = ((sizeof(Ta) == 1) ? 32 : 16);
static constexpr std::size_t M = 8;
static constexpr std::size_t N = 8;
static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 32 : 16);

template <typename Group>
using joint_matrix_a =
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
joint_matrix<Ta, M, K, use::a, layout::row_major, Group>;
template <typename Group>
using joint_matrix_b =
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
joint_matrix<Tb, K, N, use::b, layout::packed_b, Group>;
template <typename Group>
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
layout::row_major, Group>;
using joint_matrix_accumulator =
joint_matrix<Tc, M, N, use::accumulator, layout::row_major, Group>;

bool dynamic_p = false; // no dynamic allocation on the GPU
uint32_t numtiles = -1; // does not apply for DPAS
Expand All @@ -363,7 +363,7 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,
uint32_t max_ksize;
matrix_type atype;
matrix_type btype;
matrix_type ctype;
matrix_type acuumulatortype;
Copy link
Contributor

@yubingex007-a11y yubingex007-a11y Oct 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accumulator type

uint32_t msize;
uint32_t nsize;
uint32_t ksize;
Expand All @@ -384,32 +384,32 @@ struct tpu_params<tpu::dpas, Ta, Tb, Tc, 0, 0, 0,

// Valid or not:
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, int M, int N, int K>
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
struct tpu_params<
tpu::dpas, Ta, Tb, Tc, M, N, K,
typename std::enable_if<((!std::is_same_v<Ta, void> && M != 0))>::type> {
tpu::dpas, Ta, Tb, Tc, sM, sN, sK,
typename std::enable_if<((!std::is_same_v<Ta, void> && sM != 0))>::type> {
// Validate that parameters are supported
static_assert((M == 0 && N == 0 && K == 0) ||
(is_combination_valid_dpas<Ta, Tb, Tc>(M, N, K)),
static_assert((sM == 0 && sN == 0 && sK == 0) ||
(is_combination_valid_dpas<Ta, Tb, Tc>(sM, sN, sK)),
"Invalid parameters for DPAS, query valid combinations "
"using: tpu_params<tpu::dpas> myparams; and then check out "
"myparams.combinations array");

// if combination is valid, construct the matrices
static constexpr std::size_t defaultM = (M != 0) ? M : 8;
static constexpr std::size_t defaultN = (N != 0) ? N : 8;
static constexpr std::size_t defaultK =
(K != 0) ? K : ((sizeof(Ta) == 1) ? 32 : 16);
static constexpr std::size_t M = (sM != 0) ? sM : 8;
static constexpr std::size_t N = (sN != 0) ? sN : 8;
static constexpr std::size_t K =
(sK != 0) ? sK : ((sizeof(Ta) == 1) ? 32 : 16);

template <typename Group>
using joint_matrix_a =
joint_matrix<Ta, defaultM, defaultK, use::a, layout::row_major, Group>;
joint_matrix<Ta, M, K, use::a, layout::row_major, Group>;
template <typename Group>
using joint_matrix_b =
joint_matrix<Tb, defaultK, defaultN, use::b, layout::packed_b, Group>;
joint_matrix<Tb, K, N, use::b, layout::packed_b, Group>;
template <typename Group>
using joint_matrix_c = joint_matrix<Tc, defaultM, defaultN, use::accumulator,
layout::row_major, Group>;
using joint_matrix_accumulator =
joint_matrix<Tc, M, N, use::accumulator, layout::row_major, Group>;

bool dynamic_p = false; // no dynamic allocation on the GPU
uint32_t numtiles = -1; // does not apply for DPAS
Expand Down
Loading