|
1 | 1 | // RUN: triton-opt %s -tritongpu-BF16DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK |
2 | 2 |
|
| 3 | + |
| 4 | + |
| 5 | +//// Tests for BF16x3: |
| 6 | + |
| 7 | +// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 |
| 8 | +// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] |
| 9 | +// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] |
| 10 | +// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] |
| 11 | + |
| 12 | +// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 |
| 13 | +// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] |
| 14 | +// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] |
| 15 | +// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] |
| 16 | + |
| 17 | +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]] |
| 18 | +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] |
| 19 | + |
| 20 | +// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] |
| 21 | +// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] |
| 22 | + |
| 23 | +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] |
| 24 | +// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 |
| 25 | + |
| 26 | + |
| 27 | + |
| 28 | +//// Tests for BF16x6: |
| 29 | + |
| 30 | +// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 |
| 31 | +// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] |
| 32 | +// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] |
| 33 | +// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] |
| 34 | +// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]] |
| 35 | +// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]] |
| 36 | +// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]] |
| 37 | + |
| 38 | +// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 |
| 39 | +// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] |
| 40 | +// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] |
| 41 | +// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] |
| 42 | +// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]] |
| 43 | +// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]] |
| 44 | +// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] |
| 45 | + |
| 46 | +// CHECK: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]] |
| 47 | +// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] |
| 48 | +// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] |
| 49 | +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] |
| 50 | +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] |
| 51 | + |
| 52 | +// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] |
| 53 | +// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] |
| 54 | + |
| 55 | +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] |
| 56 | +// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 |
| 57 | + |
| 58 | + |
| 59 | + |
| 60 | +//// Tests for BF16x9: |
| 61 | + |
3 | 62 | // CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 |
4 | 63 | // CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] |
5 | 64 | // CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] |
|
17 | 76 | // CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] |
18 | 77 |
|
19 | 78 | // CHECK: %[[val14:.*]] = tt.dot %[[lhs_lo]], %[[rhs_lo]] |
20 | | -// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]], inputPrecision = bf16 |
21 | | -// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]], inputPrecision = bf16 |
22 | | -// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]], inputPrecision = bf16 |
23 | | -// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]], inputPrecision = bf16 |
24 | | -// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]], inputPrecision = bf16 |
25 | | -// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]], inputPrecision = bf16 |
26 | | -// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]], inputPrecision = bf16 |
| 79 | +// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]] |
| 80 | +// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]] |
| 81 | +// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]] |
| 82 | +// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] |
| 83 | +// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] |
| 84 | +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] |
| 85 | +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] |
27 | 86 |
|
28 | 87 | // CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] |
29 | 88 | // CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] |
30 | 89 |
|
31 | | -// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]], inputPrecision = bf16 |
| 90 | +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] |
32 | 91 | // CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 |
33 | 92 |
|
34 | 93 | module { |
35 | | - tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { |
| 94 | + tt.func @dot_test_BF16x3(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { |
36 | 95 | %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> |
37 | 96 | tt.return %4 : tensor<16x16xf32> |
38 | 97 | } |
| 98 | + |
| 99 | + tt.func @dot_test_BF16x6(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { |
| 100 | + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x6 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> |
| 101 | + tt.return %4 : tensor<16x16xf32> |
| 102 | + } |
| 103 | + |
| 104 | + tt.func @dot_test_BF16x9(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { |
| 105 | + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x9 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> |
| 106 | + tt.return %4 : tensor<16x16xf32> |
| 107 | + } |
39 | 108 | } |
0 commit comments