@@ -274,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout.
274274```c++
275275namespace sycl::ext::oneapi::experimental::matrix {
276276
277- template <typename Group, typename Ta, typename Tb, typename Tc,
278- std::size_t M, std::size_t K, std::size_t N, layout LayoutA, layout
279- LayoutB, typename Td = Tc >
280- joint_matrix< Group, Td, use::accumulator, M, N, layout::dynamic>
281- joint_matrix_mad( Group g ,
277+ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
278+ std::size_t M, std::size_t K, std::size_t N,
279+ layout LayoutA, layout LayoutB >
280+ void joint_matrix_mad( Group g,
281+ joint_matrix< Group, Td, use::accumulator, M, N, layout::dynamic> &D ,
282282 const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
283283 const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
284284 const joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> &C);
@@ -287,7 +287,7 @@ joint_matrix_mad(Group g,
287287```
288288The matrix multiply and add function performs the multiply operation
289289on the matrices `A` and `B`, accumulates the result with `C` and returns
290- the result.
290+ the result into the matrix `D` .
291291
292292Each device supports only certain combinations of types for the `A`,
293293`B`, and `C` matrices. The application must use the query operations
@@ -505,6 +505,12 @@ range<2> L = {1, SG_SIZE};
505505int8_t *memA = malloc_shared<int8_t>(M*K, q);
506506int8_t *memB = malloc_shared<int8_t>(K*N, q);
507507int32_t *memC = malloc_shared<int32_t>(M*N, q);
508+ auto pA = address_space_cast<sycl::access::address_space::global_space,
509+ sycl::access::decorated::no>(memA);
510+ auto pB = address_space_cast<sycl::access::address_space::global_space,
511+ sycl::access::decorated::no>(memB);
512+ auto pC = address_space_cast<sycl::access::address_space::global_space,
513+ sycl::access::decorated::no>(memC);
508514q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
509515 [[sycl::reqd_sub_group_size(SG_SIZE)]] {
510516 const auto global_idx = item.get_global_id(0);
@@ -517,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
517523 joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
518524 joint_matrix_fill(sg, tC, 0);
519525 for (int k = 0; k < K; k += tK) {
520- joint_matrix_load(sg, tA,
521- multi_ptr<int8_t, sycl::access::address_space::global_space>(memA) +
522- sg_startx * tM * K + k, K);
523- joint_matrix_load(sg, tB,
524- multi_ptr<int8_t, sycl::access::address_space::global_space>(memB) +
525- k * N + sg_starty/SG_SIZE*tN, N);
526- tC = joint_matrix_mad(sg, tA, tB, tC);
526+ joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K);
527+ joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N);
528+ joint_matrix_mad(sg, tC, tA, tB, tC);
527529 }
528530 joint_matrix_apply(sg, tC, [=](int8_t x) {
529531 x *= alpha;
530532 });
531- joint_matrix_store(sg, tC,
532- multi_ptr<int32_t, sycl::access::address_space::global_space>(memC) +
533- sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
533+ joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN,
534+ N, layout::row_major);
534535}).wait();
535536```
536537
0 commit comments