diff --git a/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_bf16.cpp b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_bf16.cpp new file mode 100644 index 0000000000..0217eb3741 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_bf16.cpp @@ -0,0 +1,24 @@ +//==----------- element_wise_all_ops_bf16.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../element_wise_all_ops_bf16_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_half.cpp b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_half.cpp new file mode 100644 index 0000000000..08c7251d0b --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_half.cpp @@ -0,0 +1,25 @@ +//==----------- element_wise_all_ops_half.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// Only runs on DPAS because AMX implementation does not support half data type +// yet +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../element_wise_all_ops_half_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_int8.cpp b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_int8.cpp new file mode 100644 index 0000000000..8462dff815 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_int8.cpp @@ -0,0 +1,24 @@ +//==----------- element_wise_all_ops_int8.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../element_wise_all_ops_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_int8_packed.cpp b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_int8_packed.cpp new file mode 100644 index 0000000000..32f5e0138c --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/element_wise_all_ops_int8_packed.cpp @@ -0,0 +1,26 @@ +//==------ element_wise_all_ops_int8_packed.cpp - DPC++ joint_matrix-------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// XFAIL: gpu + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../element_wise_all_ops_int8_packed_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/element_wise_irreg_sum_rows.cpp b/SYCL/Matrix/Legacy/XMX8/element_wise_irreg_sum_rows.cpp new file mode 100644 index 0000000000..df7a479e06 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/element_wise_irreg_sum_rows.cpp @@ -0,0 +1,26 @@ +//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// this code calculates the sum of rows into a global array of number of rows +// elements. First, partial reduction is computed inside each SG, then atomic +// add is used to reduce between SG leaders + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../element_wise_irreg_sum_rows_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/element_wise_ops.cpp b/SYCL/Matrix/Legacy/XMX8/element_wise_ops.cpp new file mode 100644 index 0000000000..c64a99fee5 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/element_wise_ops.cpp @@ -0,0 +1,22 @@ +//==----------- element_wise_ops.cpp - DPC++ joint_matrix------------- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../element_wise_ops_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_bf16.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_bf16.cpp new file mode 100644 index 0000000000..ee084fd400 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_bf16.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_bf16.cpp - DPC++ joint_matrix--------------- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_bf16_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16.cpp new file mode 100644 index 0000000000..65c03c7dfc --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16.cpp @@ -0,0 +1,23 @@ +//==-------- joint_matrix_bfloat16.cpp - DPC++ joint_matrix----------- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + +#define SG_SZ 8 + +#include "../joint_matrix_bfloat16_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16_32x64.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16_32x64.cpp new file mode 100644 index 0000000000..b2ba255e41 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_bfloat16_32x64.cpp @@ -0,0 +1,25 @@ +//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// XFAIL: * + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + +#define SG_SZ 8 + +#include "../joint_matrix_bfloat16_32x64_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_half.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_half.cpp new file mode 100644 index 0000000000..3c2736ddd1 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_half.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_half.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// Only run on the GPU because half is not supported on AMX hardware +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_half_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_int8_vnni.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_int8_vnni.cpp new file mode 100644 index 0000000000..272f22e554 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_int8_vnni.cpp @@ -0,0 +1,24 @@ +//==-------- joint_matrix_bf16_vnni.cpp - DPC++ joint_matrix---------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// XFAIL: * + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_int8_vnni_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_ss_int8.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_ss_int8.cpp new file mode 100644 index 0000000000..fc9536ff78 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_ss_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_ss_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_ss_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_su_int8.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_su_int8.cpp new file mode 100644 index 0000000000..d1ab5e003c --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_su_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_su_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_su_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_us_int8.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_us_int8.cpp new file mode 100644 index 0000000000..b07ad6ec28 --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_us_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_us_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_us_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/XMX8/joint_matrix_uu_int8.cpp b/SYCL/Matrix/Legacy/XMX8/joint_matrix_uu_int8.cpp new file mode 100644 index 0000000000..a14b259b4c --- /dev/null +++ b/SYCL/Matrix/Legacy/XMX8/joint_matrix_uu_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_uu_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix-xmx8 + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 8 + +#include "../joint_matrix_uu_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_bf16.cpp b/SYCL/Matrix/Legacy/element_wise_all_ops_bf16.cpp new file mode 100644 index 0000000000..a81f0e255f --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_bf16.cpp @@ -0,0 +1,24 @@ +//==----------- element_wise_all_ops_bf16.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "element_wise_all_ops_bf16_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp b/SYCL/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp new file mode 100644 index 0000000000..5beb3fa8e1 --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_bf16_impl.hpp @@ -0,0 +1,253 @@ + +#define TM 8 +#define TN SG_SZ +#define TK 16 + +static float make_fp32(uint16_t x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +static uint16_t make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (uint16_t)*res; +} + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void assert_ops_ref( + accessor C, + const float ref) { + for (size_t i = 0; i < M; i++) + for (size_t j = 0; j < N; j++) { + auto diff = make_fp32(C[i][j]) - ref; + assert(std::fabs(static_cast(diff)) < + std::numeric_limits::epsilon()); + } +} +template +void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, make_bf16(5.0)); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + make_bf16(2); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, make_bf16(5.0)); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] - make_bf16(2); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, make_bf16(5.0)); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] * make_bf16(3.0); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, make_bf16(4.0)); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] / make_bf16(2.0); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, make_bf16(5.0)); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + if (wi_slice_a[i]) { + if (wi_slice_a[i] > make_bf16(2.0) || + wi_slice_a[i] >= make_bf16(2.0) || + wi_slice_a[i] < make_bf16(2.0) || + wi_slice_a[i] <= make_bf16(2.0)) { + T val = (wi_slice_a[i] != make_bf16(2.0)) ? wi_slice_a[i] + : make_bf16(2.0); + val = make_bf16(make_fp32(val) - static_cast(1)); + val = make_bf16(make_fp32(val) + static_cast(1)); + if (wi_slice_a[i] == make_bf16(2.0)) { + val = make_bf16(make_fp32(val) - static_cast(2)); + val = make_bf16(make_fp32(val) * static_cast(3)); + val = make_bf16(make_fp32(val) / static_cast(2)); + + } else { + val = make_bf16(make_fp32(val) + static_cast(2)); + } + wi_slice_a[i] = val; + } + } + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +unsigned short A[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +void matrix_ops_ref(float *D, int M, int N) { + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + *(D + m * N + n) = 0; + *(D + m * N + n) *= 2; + } +} + +int main() { + + big_matrix MD((float *)&D); + big_matrix MA((unsigned short *)&A); + + size_t NDRangeM = MATRIX_M / TM; + size_t NDRangeN = MATRIX_N / TN; + queue q; + nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + matrix_verify_add(q, MA, r, 7.0); + matrix_verify_sub(q, MA, r, 3.0); + matrix_verify_mul(q, MA, r, 15.0); + matrix_verify_div(q, MA, r, 2.0); + matrix_verify_logic(q, MA, r, 7.0); + + return 0; +} diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_half.cpp b/SYCL/Matrix/Legacy/element_wise_all_ops_half.cpp new file mode 100644 index 0000000000..bd8becff4c --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_half.cpp @@ -0,0 +1,25 @@ +//==----------- element_wise_all_ops_half.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// Only runs on DPAS because AMX implementation does not support half data type +// yet +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "element_wise_all_ops_half_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_half_impl.hpp b/SYCL/Matrix/Legacy/element_wise_all_ops_half_impl.hpp new file mode 100644 index 0000000000..49b7e165eb --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_half_impl.hpp @@ -0,0 +1,240 @@ +#define TM 8 +#define TN SG_SZ +#define TK 16 + +template struct big_matrix { +private: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void assert_ops_ref( + accessor C, + const float ref) { + for (size_t i = 0; i < M; i++) + for (size_t j = 0; j < N; j++) { + auto diff = C[i][j] - ref; + assert(std::fabs(static_cast(diff)) < + std::numeric_limits::epsilon()); + } +} +template +void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5.0); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + static_cast(2); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5.0); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] - static_cast(2); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5.0); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 4.0); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, + const float ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5.0); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + if (wi_slice_a[i]) { + if (wi_slice_a[i] > static_cast(2.0) || + wi_slice_a[i] >= static_cast(2.0) || + wi_slice_a[i] < static_cast(2.0) || + wi_slice_a[i] <= static_cast(2.0)) { + T val = (wi_slice_a[i] != static_cast(2.0)) + ? wi_slice_a[i] + : static_cast(2.0); + val--; + val++; + if (wi_slice_a[i] == static_cast(2.0)) { + val -= 2; + val *= 3.0; + val /= 2.0; + } else { + val += 2; + } + wi_slice_a[i] = val; + } + } + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +half A[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +void matrix_ops_ref(float *D, int M, int N) { + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + *(D + m * N + n) = 0; + *(D + m * N + n) *= 2; + } +} + +int main() { + + big_matrix MD((float *)&D); + big_matrix MA((half *)&A); + + size_t NDRangeM = MATRIX_M / TM; + size_t NDRangeN = MATRIX_N / TN; + queue q; + nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + matrix_verify_add(q, MA, r, 7.0); + matrix_verify_sub(q, MA, r, 3.0); + matrix_verify_mul(q, MA, r, 15.0); + matrix_verify_div(q, MA, r, 2.0); + matrix_verify_logic(q, MA, r, 7.0); + + return 0; +} diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_int8.cpp b/SYCL/Matrix/Legacy/element_wise_all_ops_int8.cpp new file mode 100644 index 0000000000..49a16d3964 --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_int8.cpp @@ -0,0 +1,24 @@ +//==----------- element_wise_all_ops_int8.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "element_wise_all_ops_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp b/SYCL/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp new file mode 100644 index 0000000000..8323695cf1 --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_int8_impl.hpp @@ -0,0 +1,228 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void assert_ops_ref( + accessor C, + const int ref) { + for (size_t i = 0; i < M; i++) + for (size_t j = 0; j < N; j++) { + auto diff = C[i][j] - ref; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} +template +void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] + 2; + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] - 2; + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] * 3; + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 4); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = wi_slice_a[i] / 2; + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +template +void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufA(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufA.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + + joint_matrix_fill(sg, sub_a, 5); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + if (wi_slice_a[i]) { + if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || + wi_slice_a[i] < 2 || wi_slice_a[i] <= 2) { + T val = (wi_slice_a[i] != 2) ? wi_slice_a[i] : 2; + val--; + val++; + if (wi_slice_a[i] == 2) { + val -= 2; + val *= 3; + val /= 2; + } else { + val += 2; + } + wi_slice_a[i] = val; + } + } + } + joint_matrix_store(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufA.get_access(), ref); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +int8_t A[MATRIX_M][MATRIX_N]; +int D[MATRIX_M][MATRIX_N]; + +int main() { + + big_matrix MD((int *)&D); + big_matrix MA((int8_t *)&A); + + size_t NDRangeM = MATRIX_M / TM; + size_t NDRangeN = MATRIX_N / TN; + queue q; + nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + matrix_verify_add(q, MA, r, 7); + matrix_verify_sub(q, MA, r, 3); + matrix_verify_mul(q, MA, r, 15); + matrix_verify_div(q, MA, r, 2); + matrix_verify_logic(q, MA, r, 7); + + return 0; +} diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_int8_packed.cpp b/SYCL/Matrix/Legacy/element_wise_all_ops_int8_packed.cpp new file mode 100644 index 0000000000..d3e38f638b --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_int8_packed.cpp @@ -0,0 +1,26 @@ +//==------ element_wise_all_ops_int8_packed.cpp - DPC++ joint_matrix-------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// XFAIL: gpu + +#include +#include +#include + +using namespace sycl; +using namespace sycl::ext::intel; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "element_wise_all_ops_int8_packed_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp b/SYCL/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp new file mode 100644 index 0000000000..21fecb2e4f --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_all_ops_int8_packed_impl.hpp @@ -0,0 +1,228 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void assert_ops_ref( + accessor C, + const int ref) { + for (size_t i = 0; i < M; i++) + for (size_t j = 0; j < N; j++) { + auto diff = C[i][j] - ref; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} +template +void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufB(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufB.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_b(sg); + + joint_matrix_fill(sg, sub_b, 5); + + auto wi_slice_b = sub_b.get_wi_data(); + for (int i = 0; i < wi_slice_b.length(); i++) { + wi_slice_b[i] = wi_slice_b[i] + 2; + } + joint_matrix_store(sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufB.get_access(), ref); +} + +template +void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufB(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufB.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_b(sg); + + joint_matrix_fill(sg, sub_b, 5); + + auto wi_slice_b = sub_b.get_wi_data(); + for (int i = 0; i < wi_slice_b.length(); i++) { + wi_slice_b[i] = wi_slice_b[i] - 2; + } + joint_matrix_store(sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufB.get_access(), ref); +} + +template +void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufB(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufB.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_b(sg); + + joint_matrix_fill(sg, sub_b, 5); + + auto wi_slice_b = sub_b.get_wi_data(); + for (int i = 0; i < wi_slice_b.length(); i++) { + wi_slice_b[i] = wi_slice_b[i] * 3; + } + joint_matrix_store(sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufB.get_access(), ref); +} + +template +void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufB(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufB.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_b(sg); + + joint_matrix_fill(sg, sub_b, 4); + + auto wi_slice_b = sub_b.get_wi_data(); + for (int i = 0; i < wi_slice_b.length(); i++) { + wi_slice_b[i] = wi_slice_b[i] / 2; + } + joint_matrix_store(sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufB.get_access(), ref); +} + +template +void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, + const int ref) { + buffer bufB(A.get_data(), range<2>(M, N)); + + q.submit([&](handler &cgh) { + auto accA = bufB.get_access(cgh); + + cgh.parallel_for( + r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_b(sg); + + joint_matrix_fill(sg, sub_b, 5); + + auto wi_slice_b = sub_b.get_wi_data(); + for (int i = 0; i < wi_slice_b.length(); i++) { + if (wi_slice_b[i]) { + if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || + wi_slice_b[i] < 2 || wi_slice_b[i] <= 2) { + T val = (wi_slice_b[i] != 2) ? wi_slice_b[i] : 2; + val--; + val++; + if (wi_slice_b[i] == 2) { + val -= 2; + val *= 3; + val /= 2; + } else { + val += 2; + } + wi_slice_b[i] = val; + } + } + } + joint_matrix_store(sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::row_major); + }); // parallel for + }).wait(); + assert_ops_ref(bufB.get_access(), ref); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +int8_t B[MATRIX_M][MATRIX_N]; +int D[MATRIX_M][MATRIX_N]; + +int main() { + + big_matrix MD((int *)&D); + big_matrix MB((int8_t *)&B); + + size_t NDRangeM = MATRIX_M / TM; + size_t NDRangeN = MATRIX_N / TN; + queue q; + nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + matrix_verify_add(q, MB, r, 7); + matrix_verify_sub(q, MB, r, 3); + matrix_verify_mul(q, MB, r, 15); + matrix_verify_div(q, MB, r, 2); + matrix_verify_logic(q, MB, r, 7); + + return 0; +} diff --git a/SYCL/Matrix/Legacy/element_wise_irreg_sum_rows.cpp b/SYCL/Matrix/Legacy/element_wise_irreg_sum_rows.cpp new file mode 100644 index 0000000000..2b7895ea5b --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_irreg_sum_rows.cpp @@ -0,0 +1,26 @@ +//==-------- element_wise_irreg_sum_rows.cpp - DPC++ joint_matrix----- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// this code calculates the sum of rows into a global array of number of rows +// elements. First, partial reduction is computed inside each SG, then atomic +// add is used to reduce between SG leaders + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "element_wise_irreg_sum_rows_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp b/SYCL/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp new file mode 100644 index 0000000000..736dfac507 --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_irreg_sum_rows_impl.hpp @@ -0,0 +1,106 @@ +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void sum_rows_ref( + accessor B, + accessor + sum_rows) { + int sum_rows_ref[M] = {0}; + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + sum_rows_ref[i] += B[i][j]; + } + auto diff = sum_rows[i] - sum_rows_ref[i]; + assert(std::fabs(static_cast(diff)) <= + std::numeric_limits::epsilon()); + } +} + +template +void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { + buffer bufB(B.get_data(), range<2>(M, N)); + // size of vector is known because SG size of set by the user in this case + int sum_rows[M] = {0}; + buffer sum_rows_v(sum_rows, M); // there are total of tK/4 * 2, 16 rows + q.submit([&](handler &cgh) { + auto accB = bufB.get_access(cgh); + + auto v = sum_rows_v.get_access(cgh); + + cgh.parallel_for( + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + + joint_matrix sub_b(sg); + + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (global_idx * (TK / 4) * N) + + sg_starty / SG_SZ * TN * 4, + N, matrix_layout::packed_b); + // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b + // (tK/4) + int32_t sum_local_rows[M] = {0}; // 8 local rows, M total + // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row + auto data = sub_b.get_wi_data(); + + // each WI calculates local sum of rows + for (int row = 0; row < TK / 4; row++) { // there are 8 rows + for (int i = 0; i < data.length() / (TK / 4); i++) { // 4 per row + // i*SG_SIZE index is found based on the round robin + // distribution we are using in the implementation + sum_local_rows[row + global_idx * (TK / 4)] += data[i + row * 4]; + } + sum_local_rows[row + global_idx * (TK / 4)] = reduce_over_group( + sg, sum_local_rows[row + global_idx * (TK / 4)], + sycl::plus<>()); + + // only Groups leader perform the global reduction + if (global_idy % SG_SZ == 0) { + atomic_fetch_add(v[row + global_idx * (TK / 4)], + sum_local_rows[row + global_idx * (TK / 4)]); + } + } + }); // parallel for + }).wait(); + sum_rows_ref(bufB.get_access(), + sum_rows_v.get_access()); +} + +static constexpr size_t MATRIX_K = TK / 4 * 2; +static constexpr size_t MATRIX_N = TN * 4 * 2; +int8_t B[MATRIX_K][MATRIX_N]; + +int main() { + big_matrix MB((int8_t *)&B); + + size_t NDRangeK = MATRIX_K / (TK / 4); + size_t NDRangeN = (MATRIX_N / 4) / TN; + queue q; + nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); + + for (int i = 0; i < MATRIX_K; i++) { + for (int j = 0; j < MATRIX_N; j++) { + B[i][j] = i; + } + } + + matrix_sum_rows(q, MB, r); + + return 0; +} diff --git a/SYCL/Matrix/joint_matrix_ss_int8_use.cpp b/SYCL/Matrix/Legacy/element_wise_ops.cpp similarity index 71% rename from SYCL/Matrix/joint_matrix_ss_int8_use.cpp rename to SYCL/Matrix/Legacy/element_wise_ops.cpp index a6c423f6f2..d9a407e131 100644 --- a/SYCL/Matrix/joint_matrix_ss_int8_use.cpp +++ b/SYCL/Matrix/Legacy/element_wise_ops.cpp @@ -1,4 +1,4 @@ -//==-------- joint_matrix_ss_int8_use.cpp - DPC++ joint_matrix-------------==// +//==----------- element_wise_ops.cpp - DPC++ joint_matrix------------- ----==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,12 +7,10 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX=2 +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out -// XFAIL: gpu - #include #include @@ -21,4 +19,4 @@ using namespace sycl::ext::oneapi::experimental::matrix; #define SG_SZ 16 -#include "joint_matrix_ss_int8_use_impl.hpp" +#include "element_wise_ops_impl.hpp" diff --git a/SYCL/Matrix/Legacy/element_wise_ops_impl.hpp b/SYCL/Matrix/Legacy/element_wise_ops_impl.hpp new file mode 100644 index 0000000000..506193f597 --- /dev/null +++ b/SYCL/Matrix/Legacy/element_wise_ops_impl.hpp @@ -0,0 +1,152 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stirde = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] *= 2; + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + char *va = (char *)(A_mem + m * K + k); + char *vb = (char *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (va[i] * vb[i]); + } + *(C_mem + m * N + n) = acc; + } + *(C_mem + m * N + n) *= 2; + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp b/SYCL/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp new file mode 100644 index 0000000000..20604ae676 --- /dev/null +++ b/SYCL/Matrix/Legacy/elemwise_irreg_size_ops_bf16.cpp @@ -0,0 +1,190 @@ +//==-------- elemwise_irreg_size_ops_bf16.cpp - DPC++ joint_matrix---- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// This test is for element wise operations when matrix size does not multiply +// SG size. This corner case only applies to AMX. Also, it tests bf16 type. +// only run this on AMX +// RUN: %CPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +// 10x12 is not multiply the sg size, slicing implementation will have to insert +// padding +#define TM 10 +#define TN 12 +#define TK 16 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K / 2, N * 2)); + buffer bufC((float *)C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the packed_b layout. + // By default, the layout is row_major + joint_matrix sub_b( + sg); + joint_matrix sub_c(sg); + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K; k += TK) { + joint_matrix_load(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * K + k, K, + matrix_layout::row_major); + // Assume we alreay in vnni format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k) * (N) + + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + auto wi_slice_c = sub_c.get_wi_data(); + for (int i = 0; i < wi_slice_c.length(); i++) { + wi_slice_c[i] += 5.0; + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +unsigned short A[MATRIX_M][MATRIX_K]; +unsigned short B[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +float make_fp32(short x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +unsigned short make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (unsigned short)*res; +} + +void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, + int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + short *va = (short *)(A_mem + m * K + k); + short *vb = (short *)(B_mem + k * N + n); + float acc = *((float *)(C_mem + m * N + n)); + // FIXME: Should we do reduce-add in another version? + for (int i = 0; i < 2; i++) { + acc += (make_fp32(va[i]) * make_fp32(vb[i])); + } + *((float *)(C_mem + m * N + n)) = acc; + } + *((float *)(C_mem + m * N + n)) += 5.0; + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = make_bf16(1.0f * (i + j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = make_bf16(2.0f * i + 3.0f * j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((unsigned short *)&A); + big_matrix MB( + (unsigned short *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_bf16.cpp b/SYCL/Matrix/Legacy/joint_matrix_bf16.cpp new file mode 100644 index 0000000000..b02e8cfc07 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_bf16.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_bf16.cpp - DPC++ joint_matrix--------------- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "joint_matrix_bf16_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_bf16_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_bf16_impl.hpp new file mode 100644 index 0000000000..795039ae66 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_bf16_impl.hpp @@ -0,0 +1,162 @@ +#define TM 8 +#define TN SG_SZ +#define TK 16 + +#define BF16_EPSILON 0.00781250 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K / 2, N * 2)); + buffer bufC((float *)C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the packed_b layout. + // By default, the layout is row_major + joint_matrix sub_b( + sg); + joint_matrix sub_c(sg); + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K; k += TK) { + joint_matrix_load(sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * K + k, K, + matrix_layout::row_major); + // Assume we alreay in vnni format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k) * (N) + + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +unsigned short A[MATRIX_M][MATRIX_K]; +unsigned short B[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +float make_fp32(short x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +unsigned short make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (unsigned short)*res; +} + +void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, + int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + short *va = (short *)(A_mem + m * K + k); + short *vb = (short *)(B_mem + k * N + n); + float acc = *((float *)(C_mem + m * N + n)); + // FIXME: Should we do reduce-add in another version? + for (int i = 0; i < 2; i++) { + acc += (make_fp32(va[i]) * make_fp32(vb[i])); + } + *((float *)(C_mem + m * N + n)) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = make_bf16(1.0f * (i + j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = make_bf16(2.0f * i + 3.0f * j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((unsigned short *)&A); + big_matrix MB( + (unsigned short *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_bfloat16.cpp b/SYCL/Matrix/Legacy/joint_matrix_bfloat16.cpp new file mode 100644 index 0000000000..6ee8304d22 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_bfloat16.cpp @@ -0,0 +1,23 @@ +//==-------- joint_matrix_bfloat16.cpp - DPC++ joint_matrix----------- ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + +#define SG_SZ 16 + +#include "joint_matrix_bfloat16_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64.cpp b/SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64.cpp new file mode 100644 index 0000000000..41e412d07a --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64.cpp @@ -0,0 +1,25 @@ +//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// XFAIL: * + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + +#define SG_SZ 16 + +#include "joint_matrix_bfloat16_32x64_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp new file mode 100644 index 0000000000..85b845501b --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_bfloat16_32x64_impl.hpp @@ -0,0 +1,159 @@ +#define TM 32 +#define TN 64 +#define TK 16 + +#define BF16_EPSILON 0.00781250 + +template struct big_matrix { +private: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, big_matrix &A, + big_matrix &B) { + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC((float *)C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { // + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 2) * (N * 2) + + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +bfloat16 A[MATRIX_M][MATRIX_K]; +bfloat16 B[MATRIX_K / 2][MATRIX_N * 2]; +unsigned short Aref[MATRIX_M][MATRIX_K]; +unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +float make_fp32(short x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +unsigned short make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (unsigned short)*res; +} + +void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, + int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + short *va = (short *)(A_mem + m * K + k); + short *vb = (short *)(B_mem + k * N + n); + float acc = *((float *)(C_mem + m * N + n)); + // FIXME: Should we do reduce-add in another version? + for (int i = 0; i < 2; i++) { + acc += (make_fp32(va[i]) * make_fp32(vb[i])); + } + *((float *)(C_mem + m * N + n)) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + // bfloat16 is created using unsigned short since conversion from float to + // bfloat16 is not supported on the host side yet + A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j))); + Aref[i][j] = make_bf16(1.0f * (i + j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j))); + Bref[i][j] = make_bf16(2.0f * i + 3.0f * j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((bfloat16 *)&A); + big_matrix MB((bfloat16 *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp new file mode 100644 index 0000000000..df3f21d0b1 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_bfloat16_impl.hpp @@ -0,0 +1,159 @@ +#define TM 8 +#define TN SG_SZ +#define TK 16 + +#define BF16_EPSILON 0.00781250 + +template struct big_matrix { +private: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, big_matrix &A, + big_matrix &B) { + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC((float *)C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { // + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 2) * (N * 2) + + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +bfloat16 A[MATRIX_M][MATRIX_K]; +bfloat16 B[MATRIX_K / 2][MATRIX_N * 2]; +unsigned short Aref[MATRIX_M][MATRIX_K]; +unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +float make_fp32(short x) { + unsigned int y = x; + y = y << 16; + float *res = reinterpret_cast(&y); + return *res; +} + +unsigned short make_bf16(float x) { + int *res = reinterpret_cast(&x); + *res = *res >> 16; + return (unsigned short)*res; +} + +void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, + int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + short *va = (short *)(A_mem + m * K + k); + short *vb = (short *)(B_mem + k * N + n); + float acc = *((float *)(C_mem + m * N + n)); + // FIXME: Should we do reduce-add in another version? + for (int i = 0; i < 2; i++) { + acc += (make_fp32(va[i]) * make_fp32(vb[i])); + } + *((float *)(C_mem + m * N + n)) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + // bfloat16 is created using unsigned short since conversion from float to + // bfloat16 is not supported on the host side yet + A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j))); + Aref[i][j] = make_bf16(1.0f * (i + j)); + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = bfloat16::from_bits((make_bf16(2.0f * i + 3.0f * j))); + Bref[i][j] = make_bf16(2.0f * i + 3.0f * j); + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((bfloat16 *)&A); + big_matrix MB((bfloat16 *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_half.cpp b/SYCL/Matrix/Legacy/joint_matrix_half.cpp new file mode 100644 index 0000000000..d88a9f0b1b --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_half.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_half.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// Only run on the GPU because half is not supported on AMX hardware +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "joint_matrix_half_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_half_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_half_impl.hpp new file mode 100644 index 0000000000..ff8643f909 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_half_impl.hpp @@ -0,0 +1,144 @@ +#define TM 8 +#define TN SG_SZ +#define TK 16 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}), + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 2) * (N * 2) + + sg_starty / SG_SZ * TN * 2, + N * 2, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +half A[MATRIX_M][MATRIX_K]; +half B[MATRIX_K / 2][MATRIX_N * 2]; +float C[MATRIX_M][MATRIX_N]; +float D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N, + int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + half *va = (half *)(A_mem + m * K + k); + half *vb = (half *)(B_mem + k * N + n); + float acc = *(C_mem + m * N + n); + for (int i = 0; i < 2; i++) { + acc += ((float)va[i] * (float)vb[i]); + } + *((float *)(C_mem + m * N + n)) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 2; i++) { + for (int j = 0; j < MATRIX_N * 2; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1.0; + D[i][j] = 1.0; + } + } + + big_matrix MC((float *)&C); + big_matrix MD((float *)&D); + big_matrix MA((half *)&A); + big_matrix MB((half *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N, + MATRIX_K / 2); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/joint_matrix_bfloat16_use.cpp b/SYCL/Matrix/Legacy/joint_matrix_int8_vnni.cpp similarity index 78% rename from SYCL/Matrix/joint_matrix_bfloat16_use.cpp rename to SYCL/Matrix/Legacy/joint_matrix_int8_vnni.cpp index aa6412195e..c16d3ad726 100644 --- a/SYCL/Matrix/joint_matrix_bfloat16_use.cpp +++ b/SYCL/Matrix/Legacy/joint_matrix_int8_vnni.cpp @@ -1,4 +1,4 @@ -//==-------- joint_matrix_bfloat16_use.cpp - DPC++ joint_matrix------------==// +//==-------- joint_matrix_bf16_vnni.cpp - DPC++ joint_matrix---------------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out @@ -16,9 +16,9 @@ #include #include +using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; -using bfloat16 = sycl::ext::oneapi::bfloat16; #define SG_SZ 16 -#include "joint_matrix_bfloat16_use_impl.hpp" +#include "joint_matrix_int8_vnni_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp new file mode 100644 index 0000000000..6fe03ddac1 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp @@ -0,0 +1,155 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + joint_matrix_fill(sg, sub_c, 0); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // VNNI transform is done automatically at this level + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K][MATRIX_N]; +int8_t Bvnni[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void int8_row_vnni_reformat(int8_t *_in, int8_t *_out, int K, int N, + int stride_in) { + // find the old index, new index, and copy element. + //(K, N) => (k/4, N*4) + // idx in 2d: (i,j)=>(i/4, j*4+i%4) + // linear idx: + for (int i = 0; i < K; ++i) { + for (int j = 0; j < N; ++j) { + size_t oldindex = i * stride_in + j; + size_t newindex = (i / 4) * N * 4 + j * 4 + i % 4; + _out[newindex] = _in[oldindex]; + } + } +} + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + char *va = (char *)(A_mem + m * K + k); + char *vb = (char *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (va[i] * vb[i]); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_K; i++) { + for (int j = 0; j < MATRIX_N; j++) { + B[i][j] = i + j * 2; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 0; + D[i][j] = 0; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + int8_row_vnni_reformat((int8_t *)B, (int8_t *)Bvnni, MATRIX_K, MATRIX_N, + MATRIX_N); + matrix_multiply_ref((int32_t *)A, (int32_t *)Bvnni, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_query_default.cpp b/SYCL/Matrix/Legacy/joint_matrix_query_default.cpp new file mode 100644 index 0000000000..4e82a2a928 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_query_default.cpp @@ -0,0 +1,166 @@ +//==-------- joint_matrix_query.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + + using myparams2 = tpu_params; + constexpr int TM = myparams2::defaultM; + constexpr int TN = myparams2::defaultN; + constexpr int TK = myparams2::defaultK; + + std::cout << "AMX query sizes are: M " << TM << " N " << TN << " K " << TK + << std::endl; + + constexpr int SG_SZ = TN; + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + + myparams2::joint_matrix_a sub_a(sg); + myparams2::joint_matrix_b sub_b(sg); + myparams2::joint_matrix_c sub_c(sg); + + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = 128; +static constexpr size_t MATRIX_N = 128; +static constexpr size_t MATRIX_K = 128; +int8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + char *va = (char *)(A_mem + m * K + k); + char *vb = (char *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (va[i] * vb[i]); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_ss_int8.cpp b/SYCL/Matrix/Legacy/joint_matrix_ss_int8.cpp new file mode 100644 index 0000000000..afdfd28feb --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_ss_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_ss_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "joint_matrix_ss_int8_impl.hpp" diff --git a/SYCL/Matrix/joint_matrix_ss_int8_use_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp similarity index 76% rename from SYCL/Matrix/joint_matrix_ss_int8_use_impl.hpp rename to SYCL/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp index 8f655e1fe9..d83332bfb4 100644 --- a/SYCL/Matrix/joint_matrix_ss_int8_use_impl.hpp +++ b/SYCL/Matrix/Legacy/joint_matrix_ss_int8_impl.hpp @@ -3,7 +3,7 @@ #define TK 32 template struct big_matrix { -private: +public: T *mat; public: @@ -21,7 +21,9 @@ void matrix_multiply(big_matrix &C, size_t M = NUM_ROWS_C; size_t N = NUM_COLS_C; size_t K = NUM_COLS_A; - static_assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stirde = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; buffer bufA(A.get_data(), range<2>(M, K)); @@ -30,44 +32,47 @@ void matrix_multiply(big_matrix &C, queue q; q.submit([&](handler &cgh) { - sycl::accessor accC{bufC, cgh}; - sycl::accessor accA{bufA, cgh}; - sycl::accessor accB{bufB, cgh}; + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, K](nd_item<2> spmd_item) - [[intel::reqd_sub_group_size(SG_SZ)]] - - { + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a - // subgroup + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, layout::row_major); + K, matrix_layout::row_major); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, layout::packed_b); + N * 4, matrix_layout::packed_b); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, layout::row_major); + N, matrix_layout::row_major); }); // parallel for }).wait(); } @@ -82,6 +87,7 @@ int32_t D[MATRIX_M][MATRIX_N]; void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, int N, int K) { + // tiling for (int m = 0; m < M; m++) for (int n = 0; n < N; n++) { for (int k = 0; k < K; k++) { diff --git a/SYCL/Matrix/Legacy/joint_matrix_su_int8.cpp b/SYCL/Matrix/Legacy/joint_matrix_su_int8.cpp new file mode 100644 index 0000000000..7c12200762 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_su_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_su_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "joint_matrix_su_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_su_int8_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_su_int8_impl.hpp new file mode 100644 index 0000000000..4e3e214f8a --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_su_int8_impl.hpp @@ -0,0 +1,147 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stirde = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +int8_t A[MATRIX_M][MATRIX_K]; +uint8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + int8_t *va = (int8_t *)(A_mem + m * K + k); + uint8_t *vb = (uint8_t *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (static_cast(va[i]) * static_cast(vb[i])); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((int8_t *)&A); + big_matrix MB((uint8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_us_int8.cpp b/SYCL/Matrix/Legacy/joint_matrix_us_int8.cpp new file mode 100644 index 0000000000..935606cbe6 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_us_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_us_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "joint_matrix_us_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_us_int8_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_us_int8_impl.hpp new file mode 100644 index 0000000000..22485b9af5 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_us_int8_impl.hpp @@ -0,0 +1,149 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stirde = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] + + { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +uint8_t A[MATRIX_M][MATRIX_K]; +int8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + uint8_t *va = (uint8_t *)(A_mem + m * K + k); + int8_t *vb = (int8_t *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (static_cast(va[i]) * static_cast(vb[i])); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((uint8_t *)&A); + big_matrix MB((int8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/Legacy/joint_matrix_uu_int8.cpp b/SYCL/Matrix/Legacy/joint_matrix_uu_int8.cpp new file mode 100644 index 0000000000..054f3aaae5 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_uu_int8.cpp @@ -0,0 +1,22 @@ +//==-------- joint_matrix_uu_int8.cpp - DPC++ joint_matrix------------ ----==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +#define SG_SZ 16 + +#include "joint_matrix_uu_int8_impl.hpp" diff --git a/SYCL/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp b/SYCL/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp new file mode 100644 index 0000000000..2f6e8ef6a5 --- /dev/null +++ b/SYCL/Matrix/Legacy/joint_matrix_uu_int8_impl.hpp @@ -0,0 +1,147 @@ +#define TM 8 +#define TN SG_SZ +#define TK 32 + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void matrix_multiply(big_matrix &C, + big_matrix &A, + big_matrix &B) { + size_t M = NUM_ROWS_C; + size_t N = NUM_COLS_C; + size_t K = NUM_COLS_A; + // B => K/4 x N*4, A => M x K, C => M, N + // stride should be X's cols, e.g., B's stirde = N*4 + assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); + size_t NDRangeM = M / TM; + size_t NDRangeN = N / TN; + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); + + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [accA, accB, accC, M, N, + K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a(sg); + // For B, since current implementation does not support non-packed + // layout, users need to specify the updated VNNI sizes along with + // the packed_b layout. By default, the layout is row_major and size + // is (TK, TN). + joint_matrix sub_b(sg); + joint_matrix sub_c(sg); + + // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 + // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, + K, matrix_layout::row_major); + // Assuming B data is already in VNNI format. + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 4) * (N * 4) + + sg_starty / SG_SZ * TN * 4, + N * 4, matrix_layout::packed_b); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, matrix_layout::row_major); + }); // parallel for + }).wait(); +} + +static constexpr size_t MATRIX_M = TM * 2; +static constexpr size_t MATRIX_N = TN * 2; +static constexpr size_t MATRIX_K = TK * 2; +uint8_t A[MATRIX_M][MATRIX_K]; +uint8_t B[MATRIX_K / 4][MATRIX_N * 4]; +int32_t C[MATRIX_M][MATRIX_N]; +int32_t D[MATRIX_M][MATRIX_N]; + +void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, + int N, int K) { + // tiling + for (int m = 0; m < M; m++) + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + uint8_t *va = (uint8_t *)(A_mem + m * K + k); + uint8_t *vb = (uint8_t *)(B_mem + k * N + n); + int acc = *(C_mem + m * N + n); + for (int i = 0; i < 4; i++) { + acc += (static_cast(va[i]) * static_cast(vb[i])); + } + *(C_mem + m * N + n) = acc; + } + } +} + +int main() { + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_K; j++) { + A[i][j] = i + 2 * j; + } + } + for (int i = 0; i < MATRIX_K / 4; i++) { + for (int j = 0; j < MATRIX_N * 4; j++) { + B[i][j] = i + j; + } + } + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + C[i][j] = 1; + D[i][j] = 1; + } + } + + big_matrix MC((int32_t *)&C); + big_matrix MD((int32_t *)&D); + big_matrix MA((uint8_t *)&A); + big_matrix MB((uint8_t *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, + MATRIX_N, MATRIX_K / 4); + + bool res = true; + for (int i = 0; i < MATRIX_M; i++) { + for (int j = 0; j < MATRIX_N; j++) { + if (C[i][j] != D[i][j]) + res = false; + } + } + if (res) + std::cout << "passed\n"; + else + std::cout << "failed\n"; +} diff --git a/SYCL/Matrix/XMX8/element_wise_all_ops_bf16.cpp b/SYCL/Matrix/XMX8/element_wise_all_ops_bf16.cpp index 42e976f7ba..97c6593ba3 100644 --- a/SYCL/Matrix/XMX8/element_wise_all_ops_bf16.cpp +++ b/SYCL/Matrix/XMX8/element_wise_all_ops_bf16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/element_wise_all_ops_half.cpp b/SYCL/Matrix/XMX8/element_wise_all_ops_half.cpp index 65cf66613b..62fd63cc88 100644 --- a/SYCL/Matrix/XMX8/element_wise_all_ops_half.cpp +++ b/SYCL/Matrix/XMX8/element_wise_all_ops_half.cpp @@ -9,7 +9,7 @@ // Only runs on DPAS because AMX implementation does not support half data type // yet -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %GPU_RUN_PLACEHOLDER %t.out #include diff --git a/SYCL/Matrix/XMX8/element_wise_all_ops_int8.cpp b/SYCL/Matrix/XMX8/element_wise_all_ops_int8.cpp index 48669fc155..2605e89e30 100644 --- a/SYCL/Matrix/XMX8/element_wise_all_ops_int8.cpp +++ b/SYCL/Matrix/XMX8/element_wise_all_ops_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/element_wise_all_ops_int8_packed.cpp b/SYCL/Matrix/XMX8/element_wise_all_ops_int8_packed.cpp index e949af57bb..7124332923 100644 --- a/SYCL/Matrix/XMX8/element_wise_all_ops_int8_packed.cpp +++ b/SYCL/Matrix/XMX8/element_wise_all_ops_int8_packed.cpp @@ -7,11 +7,11 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out -// This test store the matrix B that is VNNIed (packed) in a row major fashion. +// This test stores the matrix B that is VNNIed (packed) in a row major fashion. // This is expected to fail on the GPU because the implementation does not // support automatic transformation YET, in this case: VNNI to row major in the // store. diff --git a/SYCL/Matrix/XMX8/element_wise_irreg_sum_rows.cpp b/SYCL/Matrix/XMX8/element_wise_irreg_sum_rows.cpp index 481e0b198b..cc3dd78a63 100644 --- a/SYCL/Matrix/XMX8/element_wise_irreg_sum_rows.cpp +++ b/SYCL/Matrix/XMX8/element_wise_irreg_sum_rows.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/element_wise_ops.cpp b/SYCL/Matrix/XMX8/element_wise_ops.cpp index 0a3a535724..1d7b64e406 100644 --- a/SYCL/Matrix/XMX8/element_wise_ops.cpp +++ b/SYCL/Matrix/XMX8/element_wise_ops.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_bf16.cpp b/SYCL/Matrix/XMX8/joint_matrix_bf16.cpp index 064c553712..3a7e5c67f0 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_bf16.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_bf16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_bfloat16.cpp b/SYCL/Matrix/XMX8/joint_matrix_bfloat16.cpp index c91bb539dd..e1f67c435e 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_bfloat16.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_bfloat16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_bfloat16_32x64.cpp b/SYCL/Matrix/XMX8/joint_matrix_bfloat16_32x64.cpp new file mode 100644 index 0000000000..e7c1b42dd7 --- /dev/null +++ b/SYCL/Matrix/XMX8/joint_matrix_bfloat16_32x64.cpp @@ -0,0 +1,25 @@ +//==----- joint_matrix_bfloat16_32x64.cpp - DPC++ joint_matrix-------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: matrix + +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 +// RUN: %CPU_RUN_PLACEHOLDER %t.out +// RUN: %GPU_RUN_PLACEHOLDER %t.out + +// XFAIL: * + +#include +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using bfloat16 = sycl::ext::oneapi::bfloat16; + +#define SG_SZ 8 + +#include "../joint_matrix_bfloat16_32x64_impl.hpp" diff --git a/SYCL/Matrix/XMX8/joint_matrix_bfloat16_use.cpp b/SYCL/Matrix/XMX8/joint_matrix_bfloat16_use.cpp index 827b18af0e..a97bdf4bf0 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_bfloat16_use.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_bfloat16_use.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_half.cpp b/SYCL/Matrix/XMX8/joint_matrix_half.cpp index cf781fb5a3..355fef88e2 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_half.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_half.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // Only run on the GPU because half is not supported on AMX hardware // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_int8_vnni.cpp b/SYCL/Matrix/XMX8/joint_matrix_int8_vnni.cpp index e9749a88d1..0af6a21b85 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_int8_vnni.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_int8_vnni.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_ss_int8.cpp b/SYCL/Matrix/XMX8/joint_matrix_ss_int8.cpp index 43b3017a8f..86d7f75308 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_ss_int8.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_ss_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_su_int8.cpp b/SYCL/Matrix/XMX8/joint_matrix_su_int8.cpp index be3b903d7f..252a647f5d 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_su_int8.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_su_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_us_int8.cpp b/SYCL/Matrix/XMX8/joint_matrix_us_int8.cpp index 453fc0e586..e74e7ad46b 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_us_int8.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_us_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/XMX8/joint_matrix_uu_int8.cpp b/SYCL/Matrix/XMX8/joint_matrix_uu_int8.cpp index 15ee1f8d38..06934de225 100644 --- a/SYCL/Matrix/XMX8/joint_matrix_uu_int8.cpp +++ b/SYCL/Matrix/XMX8/joint_matrix_uu_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix-xmx8 -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/element_wise_all_ops_bf16.cpp b/SYCL/Matrix/element_wise_all_ops_bf16.cpp index 320c08b839..58b8bc01aa 100644 --- a/SYCL/Matrix/element_wise_all_ops_bf16.cpp +++ b/SYCL/Matrix/element_wise_all_ops_bf16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/element_wise_all_ops_bf16_impl.hpp b/SYCL/Matrix/element_wise_all_ops_bf16_impl.hpp index 5beb3fa8e1..c84e9633c4 100644 --- a/SYCL/Matrix/element_wise_all_ops_bf16_impl.hpp +++ b/SYCL/Matrix/element_wise_all_ops_bf16_impl.hpp @@ -52,19 +52,20 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, make_bf16(5.0)); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + make_bf16(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -85,19 +86,20 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, make_bf16(5.0)); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - make_bf16(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -118,19 +120,19 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, make_bf16(5.0)); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * make_bf16(3.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -151,19 +153,20 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, make_bf16(4.0)); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / make_bf16(2.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -183,12 +186,12 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, make_bf16(5.0)); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > make_bf16(2.0) || @@ -211,10 +214,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); diff --git a/SYCL/Matrix/element_wise_all_ops_half.cpp b/SYCL/Matrix/element_wise_all_ops_half.cpp index 27b1fbef22..e860180c03 100644 --- a/SYCL/Matrix/element_wise_all_ops_half.cpp +++ b/SYCL/Matrix/element_wise_all_ops_half.cpp @@ -9,7 +9,7 @@ // Only runs on DPAS because AMX implementation does not support half data type // yet -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %GPU_RUN_PLACEHOLDER %t.out #include diff --git a/SYCL/Matrix/element_wise_all_ops_half_impl.hpp b/SYCL/Matrix/element_wise_all_ops_half_impl.hpp index 49b7e165eb..a1ae5692a6 100644 --- a/SYCL/Matrix/element_wise_all_ops_half_impl.hpp +++ b/SYCL/Matrix/element_wise_all_ops_half_impl.hpp @@ -38,19 +38,20 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5.0); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + static_cast(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -71,19 +72,20 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5.0); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - static_cast(2); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -104,19 +106,20 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5.0); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * static_cast(3.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -137,19 +140,20 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 4.0); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / static_cast(2.0); } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -170,12 +174,12 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5.0); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > static_cast(2.0) || @@ -198,10 +202,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); diff --git a/SYCL/Matrix/element_wise_all_ops_int8.cpp b/SYCL/Matrix/element_wise_all_ops_int8.cpp index 5201c163f2..adcee2a750 100644 --- a/SYCL/Matrix/element_wise_all_ops_int8.cpp +++ b/SYCL/Matrix/element_wise_all_ops_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/element_wise_all_ops_int8_impl.hpp b/SYCL/Matrix/element_wise_all_ops_int8_impl.hpp index 8323695cf1..683eb59fa0 100644 --- a/SYCL/Matrix/element_wise_all_ops_int8_impl.hpp +++ b/SYCL/Matrix/element_wise_all_ops_int8_impl.hpp @@ -38,19 +38,20 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] + 2; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -71,19 +72,20 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] - 2; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -104,19 +106,20 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] * 3; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -137,19 +140,20 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 4); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { wi_slice_a[i] = wi_slice_a[i] / 2; } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); @@ -170,12 +174,12 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix sub_a; joint_matrix_fill(sg, sub_a, 5); - auto wi_slice_a = sub_a.get_wi_data(); + auto wi_slice_a = get_wi_data(sg, sub_a); for (int i = 0; i < wi_slice_a.length(); i++) { if (wi_slice_a[i]) { if (wi_slice_a[i] > 2 || wi_slice_a[i] >= 2 || @@ -194,10 +198,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N); }); // parallel for }).wait(); assert_ops_ref(bufA.get_access(), ref); diff --git a/SYCL/Matrix/element_wise_all_ops_int8_packed.cpp b/SYCL/Matrix/element_wise_all_ops_int8_packed.cpp index 231565a510..6008079449 100644 --- a/SYCL/Matrix/element_wise_all_ops_int8_packed.cpp +++ b/SYCL/Matrix/element_wise_all_ops_int8_packed.cpp @@ -7,14 +7,15 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out -// This test store the matrix B that is VNNIed (packed) in a row major fashion. +// This test stores the matrix B that is VNNIed (packed) in a row major fashion. // This is expected to fail on the GPU because the implementation does not // support automatic transformation YET, in this case: VNNI to row major in the // store. + // XFAIL: gpu #include diff --git a/SYCL/Matrix/element_wise_all_ops_int8_packed_impl.hpp b/SYCL/Matrix/element_wise_all_ops_int8_packed_impl.hpp index 21fecb2e4f..718a9df202 100644 --- a/SYCL/Matrix/element_wise_all_ops_int8_packed_impl.hpp +++ b/SYCL/Matrix/element_wise_all_ops_int8_packed_impl.hpp @@ -38,19 +38,22 @@ void matrix_verify_add(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_b(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sub_b.get_wi_data(); + auto wi_slice_b = get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] + 2; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4); }); // parallel for }).wait(); assert_ops_ref(bufB.get_access(), ref); @@ -71,19 +74,22 @@ void matrix_verify_sub(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_b(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sub_b.get_wi_data(); + auto wi_slice_b = get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] - 2; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4); }); // parallel for }).wait(); assert_ops_ref(bufB.get_access(), ref); @@ -104,19 +110,22 @@ void matrix_verify_mul(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_b(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sub_b.get_wi_data(); + auto wi_slice_b = get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] * 3; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4); }); // parallel for }).wait(); assert_ops_ref(bufB.get_access(), ref); @@ -137,19 +146,22 @@ void matrix_verify_div(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_b(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_b; joint_matrix_fill(sg, sub_b, 4); - auto wi_slice_b = sub_b.get_wi_data(); + auto wi_slice_b = get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { wi_slice_b[i] = wi_slice_b[i] / 2; } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4); }); // parallel for }).wait(); assert_ops_ref(bufB.get_access(), ref); @@ -170,12 +182,14 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_b(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_b; joint_matrix_fill(sg, sub_b, 5); - auto wi_slice_b = sub_b.get_wi_data(); + auto wi_slice_b = get_wi_data(sg, sub_b); for (int i = 0; i < wi_slice_b.length(); i++) { if (wi_slice_b[i]) { if (wi_slice_b[i] > 2 || wi_slice_b[i] >= 2 || @@ -194,10 +208,11 @@ void matrix_verify_logic(queue q, big_matrix &A, nd_range<2> &r, } } } - joint_matrix_store(sg, sub_b, - accA.get_pointer() + (sg_startx * TM) * N * 4 + - sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::row_major); + ext::intel::experimental::matrix::joint_matrix_store( + sg, sub_b, + accA.get_pointer() + (sg_startx * TM) * N * 4 + + sg_starty / SG_SZ * TN * 4, + N * 4); }); // parallel for }).wait(); assert_ops_ref(bufB.get_access(), ref); diff --git a/SYCL/Matrix/element_wise_irreg_sum_rows.cpp b/SYCL/Matrix/element_wise_irreg_sum_rows.cpp index cf1b7229c7..76e24de5c6 100644 --- a/SYCL/Matrix/element_wise_irreg_sum_rows.cpp +++ b/SYCL/Matrix/element_wise_irreg_sum_rows.cpp @@ -7,11 +7,11 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out -// this code calculates the sum of rows into a global array of number of rows +// This code calculates the sum of rows into a global array of number of rows // elements. First, partial reduction is computed inside each SG, then atomic // add is used to reduce between SG leaders diff --git a/SYCL/Matrix/element_wise_irreg_sum_rows_impl.hpp b/SYCL/Matrix/element_wise_irreg_sum_rows_impl.hpp index 736dfac507..724daab206 100644 --- a/SYCL/Matrix/element_wise_irreg_sum_rows_impl.hpp +++ b/SYCL/Matrix/element_wise_irreg_sum_rows_impl.hpp @@ -47,17 +47,19 @@ void matrix_sum_rows(queue q, big_matrix &B, nd_range<2> &r) { ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_b(sg); + joint_matrix + sub_b; joint_matrix_load(sg, sub_b, accB.get_pointer() + (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, - N, matrix_layout::packed_b); + N); // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_b // (tK/4) int32_t sum_local_rows[M] = {0}; // 8 local rows, M total // sub_b has 32x8 elements, 32 elements per WI, 4 per WI per row - auto data = sub_b.get_wi_data(); + auto data = get_wi_data(sg, sub_b); // each WI calculates local sum of rows for (int row = 0; row < TK / 4; row++) { // there are 8 rows diff --git a/SYCL/Matrix/element_wise_ops.cpp b/SYCL/Matrix/element_wise_ops.cpp index 9150e8b632..c3b949fd9f 100644 --- a/SYCL/Matrix/element_wise_ops.cpp +++ b/SYCL/Matrix/element_wise_ops.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/element_wise_ops_impl.hpp b/SYCL/Matrix/element_wise_ops_impl.hpp index aba213ea13..2238064d61 100644 --- a/SYCL/Matrix/element_wise_ops_impl.hpp +++ b/SYCL/Matrix/element_wise_ops_impl.hpp @@ -49,39 +49,36 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. + K); joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_c = sub_c.get_wi_data(); + auto wi_slice_c = get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] *= 2; } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/elemwise_irreg_size_ops_bf16.cpp b/SYCL/Matrix/elemwise_irreg_size_ops_bf16.cpp index c389871438..48df762058 100644 --- a/SYCL/Matrix/elemwise_irreg_size_ops_bf16.cpp +++ b/SYCL/Matrix/elemwise_irreg_size_ops_bf16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // This test is for element wise operations when matrix size does not multiply // SG size. This corner case only applies to AMX. Also, it tests bf16 type. // only run this on AMX @@ -75,36 +75,36 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the packed_b layout. - // By default, the layout is row_major - joint_matrix sub_b( - sg); - joint_matrix sub_c(sg); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K; k += TK) { - joint_matrix_load(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * K + k, K, - matrix_layout::row_major); + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k, K); // Assume we alreay in vnni format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k) * (N) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } - auto wi_slice_c = sub_c.get_wi_data(); + auto wi_slice_c = get_wi_data(sg, sub_c); for (int i = 0; i < wi_slice_c.length(); i++) { wi_slice_c[i] += 5.0; } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_bf16.cpp b/SYCL/Matrix/joint_matrix_bf16.cpp index e720054f4f..abbf8a744e 100644 --- a/SYCL/Matrix/joint_matrix_bf16.cpp +++ b/SYCL/Matrix/joint_matrix_bf16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_bf16_impl.hpp b/SYCL/Matrix/joint_matrix_bf16_impl.hpp index cad12e0df0..bc0b92a6ef 100644 --- a/SYCL/Matrix/joint_matrix_bf16_impl.hpp +++ b/SYCL/Matrix/joint_matrix_bf16_impl.hpp @@ -52,32 +52,31 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the packed_b layout. - // By default, the layout is row_major - joint_matrix sub_b( - sg); - joint_matrix sub_c(sg); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K; k += TK) { - joint_matrix_load(sg, sub_a, - accA.get_pointer() + (sg_startx * TM) * K + k, K, - matrix_layout::row_major); - // Assume we alreay in vnni format. + joint_matrix_load( + sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k, K); joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k) * (N) + + accB.get_pointer() + k * N + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } @@ -105,14 +104,12 @@ unsigned short make_bf16(float x) { void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) { - // tiling for (int m = 0; m < M; m++) for (int n = 0; n < N; n++) { for (int k = 0; k < K; k++) { short *va = (short *)(A_mem + m * K + k); short *vb = (short *)(B_mem + k * N + n); float acc = *((float *)(C_mem + m * N + n)); - // FIXME: Should we do reduce-add in another version? for (int i = 0; i < 2; i++) { acc += (make_fp32(va[i]) * make_fp32(vb[i])); } diff --git a/SYCL/Matrix/joint_matrix_bfloat16.cpp b/SYCL/Matrix/joint_matrix_bfloat16.cpp index e665617156..c19aa33768 100644 --- a/SYCL/Matrix/joint_matrix_bfloat16.cpp +++ b/SYCL/Matrix/joint_matrix_bfloat16.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out @@ -16,7 +16,7 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; -using bfloat16 = sycl::ext::oneapi::bfloat16; +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; #define SG_SZ 16 diff --git a/SYCL/Matrix/joint_matrix_bfloat16_32x64.cpp b/SYCL/Matrix/joint_matrix_bfloat16_32x64.cpp index 1c2d47f93b..5c955ec422 100644 --- a/SYCL/Matrix/joint_matrix_bfloat16_32x64.cpp +++ b/SYCL/Matrix/joint_matrix_bfloat16_32x64.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out @@ -22,160 +22,4 @@ using bfloat16 = sycl::ext::oneapi::bfloat16; #define SG_SZ 16 -#define TM 32 -#define TN 64 -#define TK 16 - -#define BF16_EPSILON 0.00781250 - -template struct big_matrix { -private: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void matrix_multiply(big_matrix &C, big_matrix &A, - big_matrix &B) { - size_t NDRangeM = M / TM; - size_t NDRangeN = N / TN; - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); - buffer bufC((float *)C.get_data(), range<2>(M, N)); - - queue q; - q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); - - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] - - { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); - - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { // - joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for - }).wait(); -} - -static constexpr size_t MATRIX_M = TM * 2; -static constexpr size_t MATRIX_N = TN * 2; -static constexpr size_t MATRIX_K = TK * 2; -bfloat16 A[MATRIX_M][MATRIX_K]; -bfloat16 B[MATRIX_K / 2][MATRIX_N * 2]; -unsigned short Aref[MATRIX_M][MATRIX_K]; -unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2]; -float C[MATRIX_M][MATRIX_N]; -float D[MATRIX_M][MATRIX_N]; - -float make_fp32(short x) { - unsigned int y = x; - y = y << 16; - float *res = reinterpret_cast(&y); - return *res; -} - -unsigned short make_bf16(float x) { - int *res = reinterpret_cast(&x); - *res = *res >> 16; - return (unsigned short)*res; -} - -void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, - int K) { - // tiling - for (int m = 0; m < M; m++) - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - short *va = (short *)(A_mem + m * K + k); - short *vb = (short *)(B_mem + k * N + n); - float acc = *((float *)(C_mem + m * N + n)); - // FIXME: Should we do reduce-add in another version? - for (int i = 0; i < 2; i++) { - acc += (make_fp32(va[i]) * make_fp32(vb[i])); - } - *((float *)(C_mem + m * N + n)) = acc; - } - } -} - -int main() { - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_K; j++) { - // bfloat16 is created using unsigned short since conversion from float to - // bfloat16 is not supported on the host side yet - A[i][j] = make_bf16(1.0f * (i + j)); - Aref[i][j] = make_bf16(1.0f * (i + j)); - } - } - for (int i = 0; i < MATRIX_K / 2; i++) { - for (int j = 0; j < MATRIX_N * 2; j++) { - B[i][j] = make_bf16(2.0f * i + 3.0f * j); - Bref[i][j] = make_bf16(2.0f * i + 3.0f * j); - } - } - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - C[i][j] = 1.0; - D[i][j] = 1.0; - } - } - - big_matrix MC((float *)&C); - big_matrix MD((float *)&D); - big_matrix MA((bfloat16 *)&A); - big_matrix MB((bfloat16 *)&B); - matrix_multiply(MC, MA, MB); - matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M, - MATRIX_N, MATRIX_K / 2); - - bool res = true; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON) - res = false; - } - } - std::cout << (res ? "passed" : "failed") << std::endl; - return !res; -} +#include "joint_matrix_bfloat16_32x64_impl.hpp" diff --git a/SYCL/Matrix/joint_matrix_bfloat16_use_impl.hpp b/SYCL/Matrix/joint_matrix_bfloat16_32x64_impl.hpp similarity index 66% rename from SYCL/Matrix/joint_matrix_bfloat16_use_impl.hpp rename to SYCL/Matrix/joint_matrix_bfloat16_32x64_impl.hpp index b7c6dfe76c..1e2308b202 100644 --- a/SYCL/Matrix/joint_matrix_bfloat16_use_impl.hpp +++ b/SYCL/Matrix/joint_matrix_bfloat16_32x64_impl.hpp @@ -1,5 +1,5 @@ -#define TM 8 -#define TN SG_SZ +#define TM 32 +#define TN 64 #define TK 16 #define BF16_EPSILON 0.00781250 @@ -14,55 +14,56 @@ template struct big_matrix { big_matrix(T *data) : mat(data) {} }; -template -void matrix_multiply(big_matrix &C, - big_matrix &A, - big_matrix &B) { - size_t M = NUM_ROWS_C; - size_t N = NUM_COLS_C; - size_t K = NUM_COLS_A; - static_assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 2); +template +void matrix_multiply(big_matrix &C, big_matrix &A, + big_matrix &B) { size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; - sycl::buffer bufA(A.get_data(), sycl::range<2>(M, K)); - sycl::buffer bufB(B.get_data(), sycl::range<2>(K, N)); - sycl::buffer bufC((float *)C.get_data(), sycl::range<2>(M, N)); + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC((float *)C.get_data(), range<2>(M, N)); - sycl::queue q; - q.submit([&](sycl::handler &cgh) { - sycl::accessor accC{bufC, cgh}; - sycl::accessor accA{bufA, cgh}; - sycl::accessor accB{bufB, cgh}; + queue q; + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); cgh.parallel_for( - sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [=](sycl::nd_item<2> spmd_item) + nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), + [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { // The submatrix API has to be accessed by all the workitems in a - // subgroup + // subgroup these functions will be called once by the subgroup no + // code divergence between the workitems const auto global_idx = spmd_item.get_global_id(0); const auto global_idy = spmd_item.get_global_id(1); const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - sycl::ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); - - joint_matrix_fill(sg, sub_c, 1.0); - for (int k = 0; k < K / TK; k += 1) { + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; + + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, layout::row_major); + for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, layout::row_major); + K); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, layout::packed_b); + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, @@ -98,12 +99,14 @@ unsigned short make_bf16(float x) { void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) { + // tiling for (int m = 0; m < M; m++) for (int n = 0; n < N; n++) { for (int k = 0; k < K; k++) { short *va = (short *)(A_mem + m * K + k); short *vb = (short *)(B_mem + k * N + n); float acc = *((float *)(C_mem + m * N + n)); + // FIXME: Should we do reduce-add in another version? for (int i = 0; i < 2; i++) { acc += (make_fp32(va[i]) * make_fp32(vb[i])); } @@ -115,8 +118,8 @@ void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N, int main() { for (int i = 0; i < MATRIX_M; i++) { for (int j = 0; j < MATRIX_K; j++) { - // bfloat16 is created from unsigned short since float-to-bfloat's - // conversion is not allowed. + // bfloat16 is created using unsigned short since conversion from float to + // bfloat16 is not supported on the host side yet A[i][j] = bfloat16::from_bits(make_bf16(1.0f * (i + j))); Aref[i][j] = make_bf16(1.0f * (i + j)); } diff --git a/SYCL/Matrix/joint_matrix_bfloat16_impl.hpp b/SYCL/Matrix/joint_matrix_bfloat16_impl.hpp index 7d8ebda013..a784a88290 100644 --- a/SYCL/Matrix/joint_matrix_bfloat16_impl.hpp +++ b/SYCL/Matrix/joint_matrix_bfloat16_impl.hpp @@ -42,34 +42,33 @@ void matrix_multiply(big_matrix &C, big_matrix &A, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { // joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. + K); joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); + N * 2); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_half.cpp b/SYCL/Matrix/joint_matrix_half.cpp index 33c69d4f1f..1d131a64a8 100644 --- a/SYCL/Matrix/joint_matrix_half.cpp +++ b/SYCL/Matrix/joint_matrix_half.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // Only run on the GPU because half is not supported on AMX hardware // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_half_impl.hpp b/SYCL/Matrix/joint_matrix_half_impl.hpp index 0cb26903e5..bada4fdc7c 100644 --- a/SYCL/Matrix/joint_matrix_half_impl.hpp +++ b/SYCL/Matrix/joint_matrix_half_impl.hpp @@ -37,45 +37,44 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { - // The submatrix API has to be accessed by all the workitems in a - // subgroup these functions will be called once by the subgroup no - // code divergence between the workitems - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) + [[intel::reqd_sub_group_size(SG_SZ)]] { + // The submatrix API has to be accessed by all the workitems in a + // subgroup these functions will be called once by the subgroup + // no code divergence between the workitems + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; - joint_matrix_load(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - for (int k = 0; k < K / TK; k += 1) { - joint_matrix_load( - sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK / 2) * (N * 2) + - sg_starty / SG_SZ * TN * 2, - N * 2, matrix_layout::packed_b); - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store(sg, sub_c, - accC.get_pointer() + (sg_startx * TM) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); - }); // parallel for + joint_matrix_load(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, layout::row_major); + for (int k = 0; k < K / TK; k += 1) { + joint_matrix_load( + sg, sub_a, + accA.get_pointer() + (sg_startx * TM) * K + k * TK, K); + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * TK / 2) * (N * 2) + + sg_starty / SG_SZ * TN * 2, + N * 2); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store(sg, sub_c, + accC.get_pointer() + (sg_startx * TM) * N + + sg_starty / SG_SZ * TN, + N, layout::row_major); + }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_int8_vnni.cpp b/SYCL/Matrix/joint_matrix_int8_vnni.cpp index 09740aa7f3..f8ae1a8cf7 100644 --- a/SYCL/Matrix/joint_matrix_int8_vnni.cpp +++ b/SYCL/Matrix/joint_matrix_int8_vnni.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_int8_vnni_impl.hpp b/SYCL/Matrix/joint_matrix_int8_vnni_impl.hpp index 303bd1819d..2092124e0c 100644 --- a/SYCL/Matrix/joint_matrix_int8_vnni_impl.hpp +++ b/SYCL/Matrix/joint_matrix_int8_vnni_impl.hpp @@ -48,26 +48,28 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + joint_matrix + sub_a; + joint_matrix + sub_b; + joint_matrix sub_c; joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); + K); // VNNI transform is done automatically at this level - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * TK) * N + - sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + joint_matrix_load( + sg, sub_b, + accB.get_pointer() + (k * TK) * N + sg_starty / SG_SZ * TN, N); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_query_default.cpp b/SYCL/Matrix/joint_matrix_query_default.cpp index 931b21cc14..d77d467efb 100644 --- a/SYCL/Matrix/joint_matrix_query_default.cpp +++ b/SYCL/Matrix/joint_matrix_query_default.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out #include @@ -60,7 +60,7 @@ void matrix_multiply(big_matrix &C, cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [ accA, accB, accC, M, N, K ](nd_item<2> spmd_item) + [accA, accB, accC, M, N, K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { @@ -74,29 +74,29 @@ void matrix_multiply(big_matrix &C, ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - myparams2::joint_matrix_a sub_a(sg); - myparams2::joint_matrix_b sub_b(sg); - myparams2::joint_matrix_c sub_c(sg); + myparams2::joint_matrix_a sub_a; + myparams2::joint_matrix_b sub_b; + myparams2::joint_matrix_c sub_c; joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); + K); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_query_use_default.cpp b/SYCL/Matrix/joint_matrix_query_use_default.cpp index ac47aa88b5..cf6fcf62cc 100644 --- a/SYCL/Matrix/joint_matrix_query_use_default.cpp +++ b/SYCL/Matrix/joint_matrix_query_use_default.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out #include diff --git a/SYCL/Matrix/joint_matrix_ss_int8.cpp b/SYCL/Matrix/joint_matrix_ss_int8.cpp index 2b09ada4cd..860e590357 100644 --- a/SYCL/Matrix/joint_matrix_ss_int8.cpp +++ b/SYCL/Matrix/joint_matrix_ss_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_ss_int8_impl.hpp b/SYCL/Matrix/joint_matrix_ss_int8_impl.hpp index d83332bfb4..172789f085 100644 --- a/SYCL/Matrix/joint_matrix_ss_int8_impl.hpp +++ b/SYCL/Matrix/joint_matrix_ss_int8_impl.hpp @@ -48,31 +48,30 @@ void matrix_multiply(big_matrix &C, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; joint_matrix_fill(sg, sub_c, 0); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. + K); joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_su_int8.cpp b/SYCL/Matrix/joint_matrix_su_int8.cpp index d040278b2c..bd89977fc3 100644 --- a/SYCL/Matrix/joint_matrix_su_int8.cpp +++ b/SYCL/Matrix/joint_matrix_su_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_su_int8_impl.hpp b/SYCL/Matrix/joint_matrix_su_int8_impl.hpp index 07a5d6a41a..2b6cb314f1 100644 --- a/SYCL/Matrix/joint_matrix_su_int8_impl.hpp +++ b/SYCL/Matrix/joint_matrix_su_int8_impl.hpp @@ -48,36 +48,33 @@ void matrix_multiply(big_matrix &C, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); - // Assuming B data is already in VNNI format. + K); joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_us_int8.cpp b/SYCL/Matrix/joint_matrix_us_int8.cpp index ca8e66af0f..0690636c59 100644 --- a/SYCL/Matrix/joint_matrix_us_int8.cpp +++ b/SYCL/Matrix/joint_matrix_us_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_us_int8_impl.hpp b/SYCL/Matrix/joint_matrix_us_int8_impl.hpp index bd90a4e929..694787f408 100644 --- a/SYCL/Matrix/joint_matrix_us_int8_impl.hpp +++ b/SYCL/Matrix/joint_matrix_us_int8_impl.hpp @@ -50,36 +50,34 @@ void matrix_multiply(big_matrix &C, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); + K); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); } diff --git a/SYCL/Matrix/joint_matrix_uu_int8.cpp b/SYCL/Matrix/joint_matrix_uu_int8.cpp index 2809101929..42f2ff8fe6 100644 --- a/SYCL/Matrix/joint_matrix_uu_int8.cpp +++ b/SYCL/Matrix/joint_matrix_uu_int8.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // REQUIRES: matrix -// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 // RUN: %CPU_RUN_PLACEHOLDER %t.out // RUN: %GPU_RUN_PLACEHOLDER %t.out diff --git a/SYCL/Matrix/joint_matrix_uu_int8_impl.hpp b/SYCL/Matrix/joint_matrix_uu_int8_impl.hpp index c02097a46b..ce42d45d01 100644 --- a/SYCL/Matrix/joint_matrix_uu_int8_impl.hpp +++ b/SYCL/Matrix/joint_matrix_uu_int8_impl.hpp @@ -48,36 +48,34 @@ void matrix_multiply(big_matrix &C, const auto sg_startx = global_idx - spmd_item.get_local_id(0); const auto sg_starty = global_idy - spmd_item.get_local_id(1); - ext::oneapi::sub_group sg = spmd_item.get_sub_group(); - joint_matrix sub_a(sg); - // For B, since current implementation does not support non-packed - // layout, users need to specify the updated VNNI sizes along with - // the packed_b layout. By default, the layout is row_major and size - // is (TK, TN). - joint_matrix sub_b(sg); - joint_matrix sub_c(sg); + sub_group sg = spmd_item.get_sub_group(); + joint_matrix + sub_a; + // For B, we assume B has been already VNNIed. + joint_matrix + sub_b; + joint_matrix sub_c; - // AMX: 8 register tiles : 1k byte size, SMmaxxSKmax =16x64 - // strideX = X's cols, so strideC = N, strideA = K, strideB = N*4 joint_matrix_load(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK, - K, matrix_layout::row_major); + K); // Assuming B data is already in VNNI format. joint_matrix_load(sg, sub_b, accB.get_pointer() + (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4, matrix_layout::packed_b); + N * 4); sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); } joint_matrix_store(sg, sub_c, accC.get_pointer() + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, - N, matrix_layout::row_major); + N, layout::row_major); }); // parallel for }).wait(); }