Skip to content

Commit a6a28f2

Browse files
committed
improve lit tests
1 parent 31d1bf8 commit a6a28f2

File tree

1 file changed

+78
-9
lines changed

1 file changed

+78
-9
lines changed

test/TritonGPU/bf16x3-matmul.mlir

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,64 @@
11
// RUN: triton-opt %s -tritongpu-BF16DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK
22

3+
4+
5+
//// Tests for BF16x3:
6+
7+
// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0
8+
// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]]
9+
// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]]
10+
// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]]
11+
12+
// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1
13+
// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]]
14+
// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]]
15+
// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]]
16+
17+
// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]]
18+
// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]]
19+
20+
// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]]
21+
// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]]
22+
23+
// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]]
24+
// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2
25+
26+
27+
28+
//// Tests for BF16x6:
29+
30+
// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0
31+
// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]]
32+
// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]]
33+
// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]]
34+
// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]]
35+
// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]]
36+
// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]]
37+
38+
// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1
39+
// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]]
40+
// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]]
41+
// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]]
42+
// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]]
43+
// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]]
44+
// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]]
45+
46+
// CHECK: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]]
47+
// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]]
48+
// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]]
49+
// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]]
50+
// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]]
51+
52+
// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]]
53+
// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]]
54+
55+
// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]]
56+
// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2
57+
58+
59+
60+
//// Tests for BF16x9:
61+
362
// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0
463
// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]]
564
// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]]
@@ -17,23 +76,33 @@
1776
// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]]
1877

1978
// CHECK: %[[val14:.*]] = tt.dot %[[lhs_lo]], %[[rhs_lo]]
20-
// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]], inputPrecision = bf16
21-
// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]], inputPrecision = bf16
22-
// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]], inputPrecision = bf16
23-
// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]], inputPrecision = bf16
24-
// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]], inputPrecision = bf16
25-
// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]], inputPrecision = bf16
26-
// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]], inputPrecision = bf16
79+
// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]]
80+
// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]]
81+
// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]]
82+
// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]]
83+
// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]]
84+
// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]]
85+
// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]]
2786

2887
// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]]
2988
// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]]
3089

31-
// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]], inputPrecision = bf16
90+
// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]]
3291
// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2
3392

3493
module {
35-
tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
94+
tt.func @dot_test_BF16x3(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
3695
%4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32>
3796
tt.return %4 : tensor<16x16xf32>
3897
}
98+
99+
tt.func @dot_test_BF16x6(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
100+
%4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x6 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32>
101+
tt.return %4 : tensor<16x16xf32>
102+
}
103+
104+
tt.func @dot_test_BF16x9(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
105+
%4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x9 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32>
106+
tt.return %4 : tensor<16x16xf32>
107+
}
39108
}

0 commit comments

Comments
 (0)