
Commit 591d437

[BACKEND] Implement BF16x3 trick
Implements emulation of a 32-bit floating point dot operation using three BF16 values. This is based on https://arxiv.org/abs/1904.06376 and works because the mantissas of three BF16s together cover the mantissa of an FP32.

Storing one FP32 in three BF16s:

```python
import torch

def BF16(v):
    return v.to(torch.bfloat16)

def FP32(v):
    return v.to(torch.float32)

def BF16x3(v):
    b0 = BF16(v)
    b1 = BF16(v - FP32(b0))
    b2 = BF16(v - FP32(b0) - FP32(b1))
    return (b0, b1, b2)

original = torch.rand(1, 1, dtype=torch.float32)
bf16x3 = BF16x3(original)
```

Emulating multiplication of two FP32s (index 0 is the high part, 1 the middle, 2 the low):

```python
def mul_bf16x3(a, b, c):
    a0, a1, a2 = BF16x3(a)
    b0, b1, b2 = BF16x3(b)
    c = c + (a0 * b0)  # hi * hi
    c = c + (a1 * b0)  # mid * hi
    c = c + (a0 * b1)  # hi * mid
    c = c + (a1 * b1)  # mid * mid
    c = c + (a0 * b2)  # hi * lo
    c = c + (a2 * b0)  # lo * hi
    c = c + (a1 * b2)  # mid * lo
    c = c + (a2 * b1)  # lo * mid
    c = c + (a2 * b2)  # lo * lo
    return c

a = torch.rand(1, 1, dtype=torch.float32)
b = torch.rand(1, 1, dtype=torch.float32)
c = torch.zeros(1, 1, dtype=torch.float32)  # accumulator
result = mul_bf16x3(a, b, c)
```

The BF16x3 emulation is used when `tl.dot` is invoked with input precision 'bf16x3'. The pass is implemented in a GPU-agnostic manner, but it is needed to support MI350, which lacks TF32 support. That part is still a work in progress and will build on this patch.
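As a quick sanity check of the mantissa-coverage claim (this snippet is an editorial addition, not part of the commit), splitting random FP32 values into three BF16 parts and summing them back recovers far more precision than a single BF16 cast:

```python
import torch

def split_bf16x3(x):
    # b0 holds the top mantissa bits, b1 and b2 hold successive residuals.
    b0 = x.to(torch.bfloat16)
    b1 = (x - b0.float()).to(torch.bfloat16)
    b2 = (x - b0.float() - b1.float()).to(torch.bfloat16)
    return b0, b1, b2

x = torch.rand(1 << 16, dtype=torch.float32)
b0, b1, b2 = split_bf16x3(x)
recovered = b0.float() + b1.float() + b2.float()
print("single BF16 max error:", (x - b0.float()).abs().max().item())
print("BF16x3 max error:     ", (x - recovered).abs().max().item())
```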
1 parent 16b25e1 commit 591d437

File tree

11 files changed: +200, -5 lines changed

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 3 additions & 1 deletion
```diff
@@ -129,7 +129,9 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   [
     I32EnumAttrCase<"TF32", 0, "tf32">,
     I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
-    I32EnumAttrCase<"IEEE", 2, "ieee">
+    I32EnumAttrCase<"IEEE", 2, "ieee">,
+    I32EnumAttrCase<"BF16", 3, "bf16">,
+    I32EnumAttrCase<"BF16x3", 4, "bf16x3">
   ]>{
   let cppNamespace = "::mlir::triton";
 }
```

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
```diff
@@ -201,6 +201,17 @@ def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
                         "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
 }
 
+def TritonGPUBF16x3Dot : Pass<"tritongpu-BF16x3Dot", "mlir::ModuleOp"> {
+  let summary = "3xBF16 trick";
+
+  let description = [{
+    Decompose fp32 `DotOp` instructions into BF16 operations.
+    See https://arxiv.org/abs/1904.06376
+  }];
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
   let summary = "prefetch";
```

lib/Dialect/TritonGPU/Transforms/BF16x3Dot.cpp (new file; path inferred from the CMakeLists.txt change below)

Lines changed: 135 additions & 0 deletions

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUBF16X3DOT
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// Implement 3xBF16 https://arxiv.org/abs/1904.06376
class BF16x3 : public OpRewritePattern<DotOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DotOp dotOp,
                                PatternRewriter &rewriter) const override {

    auto isBF16x3Candidate = [](Value operand) {
      return cast<RankedTensorType>(operand.getType()).getElementType().isF32();
    };

    if (!(dotOp.getInputPrecision() == InputPrecision::BF16x3 &&
          isBF16x3Candidate(dotOp.getA()) && isBF16x3Candidate(dotOp.getB()))) {
      return failure();
    }

    // Aux functions
    auto f32ToBF16 = [&](Value value) -> Value {
      auto fp32Type = cast<RankedTensorType>(value.getType());
      auto bf16Type =
          RankedTensorType::get(fp32Type.getShape(), rewriter.getBF16Type());
      return rewriter.create<arith::TruncFOp>(dotOp.getLoc(), bf16Type, value)
          .getResult();
    };
    auto bf16ToF32 = [&](Value value) -> Value {
      auto bf16Type = cast<RankedTensorType>(value.getType());
      auto fp32Type =
          RankedTensorType::get(bf16Type.getShape(), rewriter.getF32Type());
      return rewriter.create<arith::ExtFOp>(dotOp.getLoc(), fp32Type, value)
          .getResult();
    };
    auto zeroLike = [&](Value c) -> Value {
      return rewriter.create<SplatOp>(
          dotOp->getLoc(), c.getType(),
          rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
                                             rewriter.getF32FloatAttr(0)));
    };
    auto add = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
    };
    auto sub = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
    };
    auto dot = [&](Value a, Value b, Value c) -> Value {
      return rewriter.create<DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
                                    InputPrecision::BF16,
                                    dotOp.getMaxNumImpreciseAcc());
    };
    auto replaceNansWithZeros = [&](Value value) -> Value {
      auto nans = rewriter.create<arith::CmpFOp>(
          dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
      auto zero = zeroLike(value);
      return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
                                              value);
    };

    auto SplitF32 = [&](Value input, unsigned N) -> std::vector<Value> {
      std::vector<Value> split_inputs;
      split_inputs.reserve(N);
      for (int i = 0; i < N; ++i) {
        Value input_as_bf16 = f32ToBF16(input);
        if (i != N - 1) {
          Value input_as_f32 = bf16ToF32(input_as_bf16);
          input = rewriter.create<arith::SubFOp>(dotOp->getLoc(), input,
                                                 input_as_f32);
        }
        split_inputs.push_back(input_as_bf16);
      }
      return split_inputs;
    };

    const int hi = 0;
    const int med = 1;
    const int lo = 2;

    const unsigned N = 3;
    auto lhs_parts = SplitF32(dotOp.getA(), N);
    auto rhs_parts = SplitF32(dotOp.getB(), N);

    auto result = zeroLike(dotOp.getC());

    result = dot(lhs_parts[lo], rhs_parts[lo], result);
    result = dot(lhs_parts[med], rhs_parts[lo], result);
    result = dot(lhs_parts[lo], rhs_parts[med], result);

    result = dot(lhs_parts[med], rhs_parts[med], result);

    result = dot(lhs_parts[lo], rhs_parts[hi], result);
    result = dot(lhs_parts[hi], rhs_parts[lo], result);

    result = dot(lhs_parts[med], rhs_parts[hi], result);
    result = dot(lhs_parts[hi], rhs_parts[med], result);

    result = replaceNansWithZeros(result);
    result = dot(lhs_parts[hi], rhs_parts[hi], result);
    result = add(result, dotOp.getC());

    rewriter.replaceOp(dotOp, result);
    return success();
  }
};

} // anonymous namespace

struct BF16x3DotPass : public impl::TritonGPUBF16x3DotBase<BF16x3DotPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    RewritePatternSet decomposePatterns(context);
    decomposePatterns.add<BF16x3>(context);
    if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
      signalPassFailure();
    }
  }
};

} // namespace gpu
} // namespace triton
} // namespace mlir
```
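To see what the rewrite computes, here is a PyTorch reference sketch (an editorial addition, not part of the commit) that mirrors the pattern: the same three-way split, the same nine BF16 partial dots accumulated from the least to the most significant cross terms, the NaN flush before the final hi*hi dot, and the original accumulator added back at the end. The `bf16_dot` helper models a BF16 dot with FP32 accumulation; the NaN flush presumably serves the same purpose as in the existing TF32x3 path, where infinite operands make low-order partial products produce `inf * 0 = NaN`.

```python
import torch

def bf16x3_dot_reference(a, b, c):
    # All tensors are fp32; mirrors the BF16x3 rewrite pattern above.
    def split(x):
        hi = x.to(torch.bfloat16)
        mid = (x - hi.float()).to(torch.bfloat16)
        lo = (x - hi.float() - mid.float()).to(torch.bfloat16)
        return hi, mid, lo

    def bf16_dot(x, y, acc):
        # Stand-in for tt.dot with inputPrecision = bf16: BF16 inputs, FP32 accumulation.
        # BF16 products are exactly representable in FP32, so an fp32 matmul of the
        # upcast operands models the MMA closely, up to accumulation order.
        return acc + x.float() @ y.float()

    a_hi, a_mid, a_lo = split(a)
    b_hi, b_mid, b_lo = split(b)

    acc = torch.zeros_like(c)
    for x, y in [(a_lo, b_lo), (a_mid, b_lo), (a_lo, b_mid),
                 (a_mid, b_mid),
                 (a_lo, b_hi), (a_hi, b_lo),
                 (a_mid, b_hi), (a_hi, b_mid)]:
        acc = bf16_dot(x, y, acc)

    acc = torch.where(torch.isnan(acc), torch.zeros_like(acc), acc)  # replaceNansWithZeros
    acc = bf16_dot(a_hi, b_hi, acc)
    return acc + c  # the original accumulator C is added last

a = torch.rand(64, 64)
b = torch.rand(64, 64)
c = torch.zeros(64, 64)
print((bf16x3_dot_reference(a, b, c) - a @ b).abs().max().item())
```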

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,6 +2,7 @@ add_triton_library(TritonGPUTransforms
   AccelerateMatmul.cpp
   Coalesce.cpp
   F32DotTC.cpp
+  BF16x3Dot.cpp
   FuseNestedLoops.cpp
   CombineTensorSelectAndIf.cpp
   DecomposeScaledBlocked.cpp
```

python/src/ir.cc

Lines changed: 2 additions & 0 deletions
```diff
@@ -272,6 +272,8 @@ void init_triton_ir(py::module &&m) {
       .value("TF32", InputPrecision::TF32)
       .value("TF32x3", InputPrecision::TF32x3)
       .value("IEEE", InputPrecision::IEEE)
+      .value("BF16", InputPrecision::BF16)
+      .value("BF16x3", InputPrecision::BF16x3)
       .export_values();
 
   py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())
```
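With these bindings rebuilt, the new members appear on `ir.INPUT_PRECISION` like the existing ones; a hypothetical quick check, assuming a local build that includes this patch:

```python
from triton._C.libtriton import ir

# Both members are exported by the enum above; semantic.py resolves the
# user-facing string via getattr(ir.INPUT_PRECISION, "BF16x3").
print(ir.INPUT_PRECISION.BF16)
print(ir.INPUT_PRECISION.BF16x3)
```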

python/src/passes.cc

Lines changed: 1 addition & 0 deletions
```diff
@@ -70,6 +70,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul);
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      createTritonGPUReorderInstructions);
+  ADD_PASS_WRAPPER_0("add_bf16x3_dot", createTritonGPUBF16x3Dot);
   ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC);
   ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands",
                             createTritonGPUOptimizeDotOperands, bool);
```
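The wrapper makes the pass reachable from Python pipelines as `passes.ttgpuir.add_bf16x3_dot`. A sketch of how a backend stage might schedule it, assuming the usual pass-manager setup; the placement is illustrative and not prescribed by this commit:

```python
from triton._C.libtriton import ir, passes

def run_bf16x3_decomposition(mod):
    # mod is a TTGIR-level MLIR module; decompose fp32 dots marked bf16x3
    # into BF16 dots before further GPU-specific lowering.
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.ttgpuir.add_bf16x3_dot(pm)
    pm.run(mod)
    return mod
```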

python/test/unit/language/test_core.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -3761,7 +3761,7 @@ def get_test_dot_base_cases():
     return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None)
             for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
             for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
-            for input_precision in ['tf32', 'tf32x3', 'ieee']
+            for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16', 'bf16x3']
             for in_dtype, out_dtype in [('float16', 'float16'), ('float16',
                                          'float32'), ('float32',
                                          'float32'), ('float64', 'float64')]
@@ -3915,7 +3915,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
         pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12")
     if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3():
         pytest.skip(f"{in_dtype} only supported on CDNA3")
-    if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())):
+    if not ((input_precision == "bf16x3") or (input_precision == "ieee") or
+            (input_precision == "tf32" and is_hip_cdna3())):
         pytest.skip(f"{input_precision} not supported on HIP")
     if kpack == 2 and in_dtype == 'int8' and K < 64:
         pytest.skip("kpack too large for K")
```

python/triton/language/semantic.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1465,6 +1465,8 @@ def _str_to_dot_input_precision(self, input_precision):
         input_precision = input_precision.upper()
         if input_precision == "TF32X3":
             input_precision = "TF32x3"
+        if input_precision == "BF16X3":
+            input_precision = "BF16x3"
         return getattr(ir.INPUT_PRECISION, input_precision)
 
     def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
```
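From the language side, the only user-visible change is that `tl.dot` now accepts the new precision string. A minimal, hypothetical kernel sketch (shapes, pointers, and the single-block launch are illustrative; note this commit only whitelists 'bf16x3' for the AMD backend):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def bf16x3_dot_kernel(a_ptr, b_ptr, c_ptr,
                      M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b, input_precision="bf16x3")  # fp32 dot, emulated via BF16x3
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

# "cuda" is also the device string for HIP GPUs in ROCm builds of PyTorch.
a = torch.rand(16, 16, device="cuda", dtype=torch.float32)
b = torch.rand(16, 16, device="cuda", dtype=torch.float32)
c = torch.empty(16, 16, device="cuda", dtype=torch.float32)
bf16x3_dot_kernel[(1,)](a, b, c, M=16, N=16, K=16)
print((c - a @ b).abs().max().item())
```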

test/TritonGPU/bf16x3-matmul.mlir

Lines changed: 39 additions & 0 deletions
```mlir
// RUN: triton-opt %s -tritongpu-BF16x3Dot -canonicalize | FileCheck %s --check-prefixes=CHECK

// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0
// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]]
// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]]
// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]]
// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]]
// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]]
// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]]

// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1
// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]]
// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]]
// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]]
// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]]
// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]]
// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]]

// CHECK: %[[val14:.*]] = tt.dot %[[lhs_lo]], %[[rhs_lo]]
// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]], inputPrecision = bf16
// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]], inputPrecision = bf16
// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]], inputPrecision = bf16
// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]], inputPrecision = bf16
// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]], inputPrecision = bf16
// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]], inputPrecision = bf16
// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]], inputPrecision = bf16

// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]]
// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]]

// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]], inputPrecision = bf16
// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2

module {
  tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
    %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32>
    tt.return %4 : tensor<16x16xf32>
  }
}
```

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -44,7 +44,7 @@ class HIPOptions:
     supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "ieee"
-    allowed_dot_input_precisions: Tuple[str] = ("ieee", )
+    allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3')
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
     matrix_instr_nonkdim: int = 0
```
