@@ -74,7 +74,7 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) {
7474}
7575
7676template <typename T1, typename T2, size_t Sub_Tiles_M, size_t Sub_Tiles_K,
77- size_t Sub_Tiles_N, size_t M, size_t K, size_t N>
77+ size_t Sub_Tiles_N, size_t M, size_t K, size_t N, typename T3 = T1 >
7878void test () {
7979
8080 constexpr auto Big_M =
@@ -131,19 +131,19 @@ void test() {
131131 range<2 > GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};
132132
133133 cgh.parallel_for <KernelName<T1, T2, M, K, N>>(
134- nd_range<2 >(GlobalRange, LocalRange), [=
135- ](nd_item<2 > item) [[sycl::reqd_work_group_size (1 , 1 , 32 )]] {
134+ nd_range<2 >(GlobalRange, LocalRange),
135+ [= ](nd_item<2 > item) [[sycl::reqd_work_group_size (1 , 1 , 32 )]] {
136136 sycl::sub_group sg = item.get_sub_group ();
137137 const auto m =
138- item.get_group ()
139- . get_id ()[ 0 ]; // row id of current submatrix of BIG C matrix
138+ item.get_group (). get_group_id ()[ 0 ]; // row id of current submatrix
139+ // of BIG C matrix
140140 const auto n =
141- item.get_group ().get_id ()[1 ]; // column id of current
142- // submatrix of BIG C matrix
141+ item.get_group ().get_group_id ()[1 ]; // column id of current
142+ // submatrix of BIG C matrix
143143
144- joint_matrix<T1 , matrix_use::a, M, K, matrix_layout::row_major> sub_a;
144+ joint_matrix<T3 , matrix_use::a, M, K, matrix_layout::row_major> sub_a;
145145
146- joint_matrix<T1 , matrix_use::b, K, N, matrix_layout::row_major> sub_b;
146+ joint_matrix<T3 , matrix_use::b, K, N, matrix_layout::row_major> sub_b;
147147
148148 joint_matrix<T2, matrix_use::accumulator, M, N,
149149 matrix_layout::row_major>
@@ -163,6 +163,14 @@ void test() {
163163 accB.get_pointer () + (k * K * Big_N) + (n * N),
164164 Big_N);
165165
166+ // Convert values if using tf32
167+ if constexpr (std::is_same<T3, precision::tf32>::value) {
168+ for (auto i = 0 ; i < 4 ; ++i) {
169+ sub_a.data [i] = round_to_tf32 (sub_a.data [i]);
170+ sub_b.data [i] = round_to_tf32 (sub_b.data [i]);
171+ }
172+ }
173+
166174 sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
167175 }
168176 joint_matrix_store (
@@ -182,7 +190,6 @@ void test() {
182190};
183191
184192int main () {
185-
186193 // A/B half, Accumulator float
187194 test<half, float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16 , 16 , 16 >();
188195 test<half, float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8 , 16 , 32 >();
@@ -208,5 +215,9 @@ int main() {
208215 test<uint16_t , float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 8 , 16 , 32 >();
209216 test<uint16_t , float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 32 , 16 , 8 >();
210217
218+ // A/B tf32
219+ test<float , float , SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16 , 8 , 16 ,
220+ precision::tf32>();
221+
211222 return 0 ;
212223};
0 commit comments