From 10848214a2aeacf039f96354de2e5127723becf8 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 22 Jul 2025 00:16:15 -0700 Subject: [PATCH 01/39] [BACKEND] Implement BF16x3 trick Implements emulation of a 32-bit floating point dot operation using 3 BF16s. This is based on https://arxiv.org/abs/1904.06376 and works because the mantissas of 3 BF16s add up to the mantissa of an fp32. Storing 1 fp32 in 3 bf16s: ``` def BF16(v): return v.to(torch.bfloat16) def FP32(v): return v.to(torch.float32) def BF16x3(v): b0 = BF16(v) b1 = BF16(v - FP32(b0)) b2 = BF16(v - FP32(b0) - FP32(b1)) return (b0, b1, b2) original = torch.rand(1, 1, dtype=torch.float32) bf16x3 = BF16x3(original) ``` Emulating multiplication of two fp32s: ``` def mul_bf16x3(a, b, c): a0, a1, a2 = BF16x3(a) b0, b1, b2 = BF16x3(b) c = c + (a0 * b0) # hi hi c = c + (a1 * b0) # mid hi c = c + (a0 * b1) # hi mid c = c + (a1 * b1) # mid mid c = c + (a0 * b2) # hi lo c = c + (a2 * b0) # lo hi c = c + (a1 * b2) # mid lo c = c + (a2 * b1) # lo mid c = c + (a2 * b2) # lo lo return c a = torch.rand(1, 1, dtype=torch.float32) b = torch.rand(1, 1, dtype=torch.float32) c = torch.zeros(1, 1, dtype=torch.float32) # accumulator result = mul_bf16x3(a, b, c) ``` The emulation using BF16x3 is used when invoking tl.dot with input precision 'BF16x3'. This pass is implemented in a GPU agnostic manner, but it is needed to support MI350, which lacks TF32 support. This part is a work in progress but will be based on this patch. 
--- 03-matrix-multiplication.py | 160 ++++++++++++++++++ .../Dialect/Triton/IR/TritonAttrDefs.td | 6 +- .../Dialect/TritonGPU/Transforms/Passes.td | 11 ++ .../TritonGPU/Transforms/BF16DotTC.cpp | 157 +++++++++++++++++ .../TritonGPU/Transforms/CMakeLists.txt | 1 + python/src/ir.cc | 4 + python/src/passes.cc | 1 + python/test/unit/language/test_core.py | 5 +- python/triton/language/semantic.py | 6 + test/TritonGPU/bf16x3-matmul.mlir | 39 +++++ third_party/amd/backend/compiler.py | 3 +- third_party/nvidia/backend/compiler.py | 3 +- 12 files changed, 391 insertions(+), 5 deletions(-) create mode 100644 03-matrix-multiplication.py create mode 100644 lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp create mode 100644 test/TritonGPU/bf16x3-matmul.mlir diff --git a/03-matrix-multiplication.py b/03-matrix-multiplication.py new file mode 100644 index 000000000000..13448035b6b8 --- /dev/null +++ b/03-matrix-multiplication.py @@ -0,0 +1,160 @@ +import torch + +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 
'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,num_warps=8), + ] + +def get_hip_autotune_config(): + sizes = [ + {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, + {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, + ] + return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes] + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + PRECISION: tl.constexpr +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + 
tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b, input_precision=PRECISION) + # accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b, precision="ieee"): + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + PRECISION=precision + ) + return c + + +precisions = ["ieee", "bf16", "bf16x3", "bf16x6", "bf16x9"] +torch.manual_seed(0) + +for precision in precisions: + a = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5 + b = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5 + 
triton_output = matmul(a, b, precision=precision) + torch_output = torch.matmul(a, b) + #print(f"triton_output_with_fp32_inputs={triton_output}") + #print(f"torch_output_with_fp32_inputs={torch_output}") + + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): + print(f'✅ Triton and Torch match for input_precision={precision}') + else: + print(f'❌ Triton and Torch differ for input_precision={precision}') + +ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' + +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=["M", "N", "K"], + x_vals=[128 * i for i in range(2, 33)], + line_arg="provider", + line_vals=[ref_lib.lower(), "triton-ieee", "triton-bf16", "triton-bf16x3", "triton-bf16x6", "triton-bf16x9"], + line_names=[ref_lib, "Triton-IEEE", "Triton-BF16", "Triton-BF16x3", "Triton-BF16x6", "Triton-BF16x9"], + styles=[("green", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-")], + ylabel="TFLOPS", + plot_name="matmul-performance-f32", + args={}, + )) + +@triton.testing.perf_report(configs) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device=DEVICE, dtype=torch.float32) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float32) + quantiles = [0.5, 0.2, 0.8] + if provider == ref_lib.lower(): + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + if provider.startswith('triton-'): + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, provider.removeprefix('triton-')), quantiles=quantiles) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + +benchmark.run(show_plots=False, print_data=True) diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index 2414fb9d76e9..24d345a21d53 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -129,7 +129,11 @@ def 
TT_InputPrecisionAttr : I32EnumAttr< [ I32EnumAttrCase<"TF32", 0, "tf32">, I32EnumAttrCase<"TF32x3", 1, "tf32x3">, - I32EnumAttrCase<"IEEE", 2, "ieee"> + I32EnumAttrCase<"IEEE", 2, "ieee">, + I32EnumAttrCase<"BF16", 3, "bf16">, + I32EnumAttrCase<"BF16x3", 4, "bf16x3">, + I32EnumAttrCase<"BF16x6", 5, "bf16x6">, + I32EnumAttrCase<"BF16x9", 6, "bf16x9"> ]>{ let cppNamespace = "::mlir::triton"; } diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 51c9ee17709d..1d62c4a4271f 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -188,6 +188,17 @@ def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; } +def TritonGPUBF16DotTC : Pass<"tritongpu-BF16DotTC", "mlir::ModuleOp"> { + let summary = "3xBF16 trick"; + + let description = [{ + Decompose fp32 `DotOp` instructions into BF16 operations. 
+ See https://arxiv.org/abs/1904.06376 + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { let summary = "prefetch"; diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp new file mode 100644 index 000000000000..7c60755884ae --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -0,0 +1,157 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUBF16DOTTC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +// Implement 3xBF16 https://arxiv.org/abs/1904.06376 +class BF16x3 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + switch (dotOp.getInputPrecision()) { + case InputPrecision::BF16: + case InputPrecision::BF16x3: + case InputPrecision::BF16x6: + case InputPrecision::BF16x9: + break; + default: + return failure(); + } + + auto isF32 = [](Value operand) { + return cast(operand.getType()).getElementType().isF32(); + }; + if (!isF32(dotOp.getA()) || !isF32(dotOp.getB())) { + return failure(); + } + + // Aux functions + auto f32ToBF16 = [&](Value value) -> Value { + auto fp32Type = cast(value.getType()); + auto bf16Type = + RankedTensorType::get(fp32Type.getShape(), rewriter.getBF16Type(), fp32Type.getEncoding()); + return rewriter.create(dotOp.getLoc(), bf16Type, value) + .getResult(); + }; + auto bf16ToF32 = [&](Value value) -> Value { + auto bf16Type = cast(value.getType()); + auto fp32Type = + RankedTensorType::get(bf16Type.getShape(), rewriter.getF32Type(), bf16Type.getEncoding()); + return rewriter.create(dotOp.getLoc(), 
fp32Type, value) + .getResult(); + }; + auto zeroLike = [&](Value c) -> Value { + return rewriter.create( + dotOp->getLoc(), c.getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); + }; + auto add = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto sub = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto dot = [&](Value a, Value b, Value c) -> Value { + return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, + InputPrecision::BF16, + dotOp.getMaxNumImpreciseAcc()); + }; + auto replaceNansWithZeros = [&](Value value) -> Value { + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); + }; + + auto SplitF32 = [&](Value input, unsigned N) -> std::vector { + std::vector split_inputs; + split_inputs.reserve(N); + for (int i = 0; i < N; ++i) { + Value input_as_bf16 = f32ToBF16(input); + if (i != N - 1) { + Value input_as_f32 = bf16ToF32(input_as_bf16); + input = rewriter.create(dotOp->getLoc(), input, + input_as_f32); + } + split_inputs.push_back(input_as_bf16); + } + return split_inputs; + }; + + const int hi = 0; + const int med = 1; + const int lo = 2; + + const unsigned N = 3; + auto lhs_parts = SplitF32(dotOp.getA(), N); + auto rhs_parts = SplitF32(dotOp.getB(), N); + + auto result = zeroLike(dotOp.getC()); + + if (dotOp.getInputPrecision() == InputPrecision::BF16x9) { + result = dot(lhs_parts[lo], rhs_parts[lo], result); + result = dot(lhs_parts[med], rhs_parts[lo], result); + result = dot(lhs_parts[lo], rhs_parts[med], result); + + result = dot(lhs_parts[med], rhs_parts[med], result); + + result = dot(lhs_parts[lo], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[lo], result); + + result = dot(lhs_parts[med], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[med], result); + + } else if 
(dotOp.getInputPrecision() == InputPrecision::BF16x6) { + result = dot(lhs_parts[med], rhs_parts[med], result); + + result = dot(lhs_parts[lo], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[lo], result); + + result = dot(lhs_parts[med], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[med], result); + + } else if (dotOp.getInputPrecision() == InputPrecision::BF16x3) { + result = dot(lhs_parts[med], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[med], result); + } + + result = replaceNansWithZeros(result); + result = dot(lhs_parts[hi], rhs_parts[hi], result); + result = add(result, dotOp.getC()); + + rewriter.replaceOp(dotOp, result); + return success(); + } +}; + +} // anonymous namespace + +struct BF16DotTCPass : public impl::TritonGPUBF16DotTCBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { + signalPassFailure(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 965b1e1e7d0e..99e756c6c430 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonGPUTransforms AccelerateMatmul.cpp Coalesce.cpp F32DotTC.cpp + BF16DotTC.cpp FuseNestedLoops.cpp CombineTensorSelectAndIf.cpp DecomposeScaledBlocked.cpp diff --git a/python/src/ir.cc b/python/src/ir.cc index 1ac95724a169..f4fee5e05f1b 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -308,6 +308,10 @@ void init_triton_ir(py::module &&m) { .value("TF32", InputPrecision::TF32) .value("TF32x3", InputPrecision::TF32x3) .value("IEEE", InputPrecision::IEEE) + .value("BF16", InputPrecision::BF16) + .value("BF16x3", 
InputPrecision::BF16x3) + .value("BF16x6", InputPrecision::BF16x6) + .value("BF16x9", InputPrecision::BF16x9) .export_values(); py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) diff --git a/python/src/passes.cc b/python/src/passes.cc index 3cd6f79c1084..a1162acfb7f6 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -71,6 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); ADD_PASS_WRAPPER_0("add_reorder_instructions", createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_bf16_dot_tc", createTritonGPUBF16DotTC); ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", createTritonGPUOptimizeDotOperands, bool); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index bf3d85417cff..8927ecfa31ae 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3084,7 +3084,7 @@ def get_test_dot_base_cases(): return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] - for input_precision in ['tf32', 'tf32x3', 'ieee'] + for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16', 'bf16x3'] for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32'), ('float64', 'float64')] @@ -3238,7 +3238,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12") if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3(): pytest.skip(f"{in_dtype} only supported on CDNA3") - if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())): + if not ((input_precision == "bf16x3") or (input_precision == 
"ieee") or + (input_precision == "tf32" and is_hip_cdna3())): pytest.skip(f"{input_precision} not supported on HIP") if kpack == 2 and in_dtype == 'int8' and K < 64: pytest.skip("kpack too large for K") diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 70a6672c8733..c0c5b383da96 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1467,6 +1467,12 @@ def _str_to_dot_input_precision(self, input_precision): input_precision = input_precision.upper() if input_precision == "TF32X3": input_precision = "TF32x3" + if input_precision == "BF16X3": + input_precision = "BF16x3" + if input_precision == "BF16X6": + input_precision = "BF16x6" + if input_precision == "BF16X9": + input_precision = "BF16x9" return getattr(ir.INPUT_PRECISION, input_precision) def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], diff --git a/test/TritonGPU/bf16x3-matmul.mlir b/test/TritonGPU/bf16x3-matmul.mlir new file mode 100644 index 000000000000..e5b46fbbc016 --- /dev/null +++ b/test/TritonGPU/bf16x3-matmul.mlir @@ -0,0 +1,39 @@ +// RUN: triton-opt %s -tritongpu-BF16DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK + +// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 +// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] +// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] +// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] +// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]] +// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]] +// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]] + +// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 +// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] +// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] +// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] +// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]] +// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]] +// CHECK-NEXT: %[[rhs_lo:.*]] = 
arith.truncf %[[val12]] + +// CHECK: %[[val14:.*]] = tt.dot %[[lhs_lo]], %[[rhs_lo]] +// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]], inputPrecision = bf16 +// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]], inputPrecision = bf16 +// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]], inputPrecision = bf16 +// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]], inputPrecision = bf16 +// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]], inputPrecision = bf16 +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]], inputPrecision = bf16 +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]], inputPrecision = bf16 + +// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] +// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] + +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]], inputPrecision = bf16 +// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 + +module { + tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + tt.return %4 : tensor<16x16xf32> + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index e7c5a6674dde..e478b6e8c76f 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -45,7 +45,7 @@ class HIPOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8") deprecated_fp8_dot_operand_dtypes: Tuple[str] = () default_dot_input_precision: str = "ieee" - allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3', 'bf16', 'bf16x6', 'bf16x9') enable_fp_fusion: bool = True launch_cooperative_grid: bool = False 
matrix_instr_nonkdim: int = 0 @@ -207,6 +207,7 @@ def make_ttgir(mod, metadata, options): pm.run(mod, 'make_ttgir_early') pm = ir.pass_manager(mod.context) pm.enable_debug() + passes.ttgpuir.add_bf16_dot_tc(pm) passes.ttgpuir.add_coalesce(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 1586fa54909e..29dbc1efe53f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -121,7 +121,7 @@ class CUDAOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15") deprecated_fp8_dot_operand_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" - allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16', 'bf16x6', 'bf16x9') max_num_imprecise_acc_default: bool = None extern_libs: dict = None debug: bool = False @@ -262,6 +262,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) # optimize TTGIR passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_bf16_dot_tc(pm) if capability // 10 >= 8: passes.ttgpuir.add_f32_dot_tc(pm) # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass From df46410236ec0a1162466ef4f9de6130af3900e4 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Mon, 29 Sep 2025 23:52:16 -0700 Subject: [PATCH 02/39] move pass to after coalescer --- third_party/amd/backend/compiler.py | 2 +- third_party/nvidia/backend/compiler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index e478b6e8c76f..2aad479baa8e 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -207,8 +207,8 @@ def make_ttgir(mod, metadata, options): pm.run(mod, 
'make_ttgir_early') pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.ttgpuir.add_bf16_dot_tc(pm) passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_bf16_dot_tc(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 29dbc1efe53f..e32025abbc4b 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -262,9 +262,9 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) # optimize TTGIR passes.ttgpuir.add_coalesce(pm) - passes.ttgpuir.add_bf16_dot_tc(pm) if capability // 10 >= 8: passes.ttgpuir.add_f32_dot_tc(pm) + passes.ttgpuir.add_bf16_dot_tc(pm) # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) passes.ttgpuir.add_remove_layout_conversions(pm) From 5ac391b07cf655c80c70b53080f50cb7b9e4900c Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Mon, 29 Sep 2025 23:53:52 -0700 Subject: [PATCH 03/39] [NFC] drop raw_ostream include --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 7c60755884ae..1a025cad1a8f 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -1,7 +1,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "llvm/Support/raw_ostream.h" namespace mlir { namespace triton { From 18daed190ef9734ad389904d23cc7f7913cbc59a Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Mon, 29 Sep 2025 23:54:17 -0700 
Subject: [PATCH 04/39] [NFC] drop sub lambda unused --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 1a025cad1a8f..aae83afae4f9 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -59,9 +59,6 @@ class BF16x3 : public OpRewritePattern { auto add = [&](Value a, Value b) -> Value { return rewriter.create(dotOp.getLoc(), a, b); }; - auto sub = [&](Value a, Value b) -> Value { - return rewriter.create(dotOp.getLoc(), a, b); - }; auto dot = [&](Value a, Value b, Value c) -> Value { return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, InputPrecision::BF16, From 94d4bd6ad277697e3d4512866ee1a2adec4d9feb Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Mon, 29 Sep 2025 23:58:05 -0700 Subject: [PATCH 05/39] [NFC] collapse reused code handling BF16x3 --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index aae83afae4f9..8d945c2d4e9c 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -107,23 +107,17 @@ class BF16x3 : public OpRewritePattern { result = dot(lhs_parts[lo], rhs_parts[hi], result); result = dot(lhs_parts[hi], rhs_parts[lo], result); - result = dot(lhs_parts[med], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[med], result); - } else if (dotOp.getInputPrecision() == InputPrecision::BF16x6) { result = dot(lhs_parts[med], rhs_parts[med], result); result = dot(lhs_parts[lo], rhs_parts[hi], result); result = dot(lhs_parts[hi], rhs_parts[lo], result); - - result = dot(lhs_parts[med], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[med], result); - - } else if 
(dotOp.getInputPrecision() == InputPrecision::BF16x3) { - result = dot(lhs_parts[med], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[med], result); } + // BF16x3, BF16x6, BF16x9 all need this + result = dot(lhs_parts[med], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[med], result); + result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); result = add(result, dotOp.getC()); From 5cc88bd5373dfedbef34b1ec00b8b9d2eff3587a Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:19:55 -0700 Subject: [PATCH 06/39] [NFC] flatted add lambda --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 8d945c2d4e9c..8fb6f1f08e46 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -56,9 +56,6 @@ class BF16x3 : public OpRewritePattern { rewriter.create(dotOp->getLoc(), rewriter.getF32FloatAttr(0))); }; - auto add = [&](Value a, Value b) -> Value { - return rewriter.create(dotOp.getLoc(), a, b); - }; auto dot = [&](Value a, Value b, Value c) -> Value { return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, InputPrecision::BF16, @@ -120,7 +117,7 @@ class BF16x3 : public OpRewritePattern { result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); - result = add(result, dotOp.getC()); + result = rewriter.create(dotOp.getLoc(), result, dotOp.getC()); rewriter.replaceOp(dotOp, result); return success(); From e284c855a272b07a537879db729dc0695dafa8d4 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:32:24 -0700 Subject: [PATCH 07/39] [NFCi] skip 2xdots for 1xBF16: 1xBF16 will likely get dropped --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 8fb6f1f08e46..4aae0bc0b0d9 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -112,8 +112,10 @@ class BF16x3 : public OpRewritePattern { } // BF16x3, BF16x6, BF16x9 all need this - result = dot(lhs_parts[med], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[med], result); + if (dotOp.getInputPrecision() != InputPrecision::BF16) { + result = dot(lhs_parts[med], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[med], result); + } result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); From 56d86c801990023cfb9017945f5ae094883bf709 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:34:17 -0700 Subject: [PATCH 08/39] [NFCi] change into to unsigned and med to mid --- .../TritonGPU/Transforms/BF16DotTC.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 4aae0bc0b0d9..300e0b04d2dd 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -72,7 +72,7 @@ class BF16x3 : public OpRewritePattern { auto SplitF32 = [&](Value input, unsigned N) -> std::vector { std::vector split_inputs; split_inputs.reserve(N); - for (int i = 0; i < N; ++i) { + for (unsigned i = 0; i < N; ++i) { Value input_as_bf16 = f32ToBF16(input); if (i != N - 1) { Value input_as_f32 = bf16ToF32(input_as_bf16); @@ -84,9 +84,9 @@ class BF16x3 : public OpRewritePattern { return split_inputs; }; - const int hi = 0; - const int med = 1; - const int lo = 2; + const unsigned hi = 0; + const unsigned mid = 1; + const unsigned lo = 2; const unsigned N = 3; auto lhs_parts = SplitF32(dotOp.getA(), N); @@ -96,16 +96,16 @@ class BF16x3 : public OpRewritePattern { if 
(dotOp.getInputPrecision() == InputPrecision::BF16x9) { result = dot(lhs_parts[lo], rhs_parts[lo], result); - result = dot(lhs_parts[med], rhs_parts[lo], result); - result = dot(lhs_parts[lo], rhs_parts[med], result); + result = dot(lhs_parts[mid], rhs_parts[lo], result); + result = dot(lhs_parts[lo], rhs_parts[mid], result); - result = dot(lhs_parts[med], rhs_parts[med], result); + result = dot(lhs_parts[mid], rhs_parts[mid], result); result = dot(lhs_parts[lo], rhs_parts[hi], result); result = dot(lhs_parts[hi], rhs_parts[lo], result); } else if (dotOp.getInputPrecision() == InputPrecision::BF16x6) { - result = dot(lhs_parts[med], rhs_parts[med], result); + result = dot(lhs_parts[mid], rhs_parts[mid], result); result = dot(lhs_parts[lo], rhs_parts[hi], result); result = dot(lhs_parts[hi], rhs_parts[lo], result); @@ -113,8 +113,8 @@ class BF16x3 : public OpRewritePattern { // BF16x3, BF16x6, BF16x9 all need this if (dotOp.getInputPrecision() != InputPrecision::BF16) { - result = dot(lhs_parts[med], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[med], result); + result = dot(lhs_parts[mid], rhs_parts[hi], result); + result = dot(lhs_parts[hi], rhs_parts[mid], result); } result = replaceNansWithZeros(result); From 967e686957e03c8f5eb40098f7b61f6dc8447b86 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:38:53 -0700 Subject: [PATCH 09/39] [NFCi] flatten zeroLike into a single reusable Constant zero fp32 value --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 300e0b04d2dd..5a5dcc53ae07 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -50,12 +50,10 @@ class BF16x3 : public OpRewritePattern { return rewriter.create(dotOp.getLoc(), fp32Type, value) .getResult(); }; - auto zeroLike = 
[&](Value c) -> Value { - return rewriter.create( - dotOp->getLoc(), c.getType(), - rewriter.create(dotOp->getLoc(), - rewriter.getF32FloatAttr(0))); - }; + Value zero = rewriter.create( + dotOp->getLoc(), dotOp.getC().getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); auto dot = [&](Value a, Value b, Value c) -> Value { return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, InputPrecision::BF16, @@ -64,7 +62,6 @@ class BF16x3 : public OpRewritePattern { auto replaceNansWithZeros = [&](Value value) -> Value { auto nans = rewriter.create( dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - auto zero = zeroLike(value); return rewriter.create(dotOp->getLoc(), nans, zero, value); }; @@ -92,7 +89,7 @@ class BF16x3 : public OpRewritePattern { auto lhs_parts = SplitF32(dotOp.getA(), N); auto rhs_parts = SplitF32(dotOp.getB(), N); - auto result = zeroLike(dotOp.getC()); + auto result = zero; if (dotOp.getInputPrecision() == InputPrecision::BF16x9) { result = dot(lhs_parts[lo], rhs_parts[lo], result); From 5dd2ce69dbdae01c659a930168a37b3b25979401 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:46:44 -0700 Subject: [PATCH 10/39] [NFCi] clean up Split32 into its own helper function --- .../TritonGPU/Transforms/BF16DotTC.cpp | 55 ++++++++----------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 5a5dcc53ae07..4b121f433eb7 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -11,6 +11,27 @@ namespace gpu { namespace { +template +auto convertValue(const Value &value, const FloatType &scalarToType, PatternRewriter &rewriter) -> mlir::Value { + auto fromType = cast(value.getType()); + auto toType = RankedTensorType::get(fromType.getShape(), scalarToType, fromType.getEncoding()); + return rewriter.create(value.getLoc(), toType, 
value).getResult(); +} + +auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) -> llvm::SmallVector { + llvm::SmallVector split_inputs; + for (unsigned i = 0; i < N; ++i) { + Value input_as_bf16 = convertValue(input, rewriter.getBF16Type(), rewriter); + if (i != N - 1) { + Value input_as_f32 = convertValue(input_as_bf16, rewriter.getF32Type(), rewriter); + input = rewriter.create(input.getLoc(), input, + input_as_f32); + } + split_inputs.push_back(input_as_bf16); + } + return split_inputs; +} + // Implement 3xBF16 https://arxiv.org/abs/1904.06376 class BF16x3 : public OpRewritePattern { public: @@ -35,21 +56,6 @@ class BF16x3 : public OpRewritePattern { return failure(); } - // Aux functions - auto f32ToBF16 = [&](Value value) -> Value { - auto fp32Type = cast(value.getType()); - auto bf16Type = - RankedTensorType::get(fp32Type.getShape(), rewriter.getBF16Type(), fp32Type.getEncoding()); - return rewriter.create(dotOp.getLoc(), bf16Type, value) - .getResult(); - }; - auto bf16ToF32 = [&](Value value) -> Value { - auto bf16Type = cast(value.getType()); - auto fp32Type = - RankedTensorType::get(bf16Type.getShape(), rewriter.getF32Type(), bf16Type.getEncoding()); - return rewriter.create(dotOp.getLoc(), fp32Type, value) - .getResult(); - }; Value zero = rewriter.create( dotOp->getLoc(), dotOp.getC().getType(), rewriter.create(dotOp->getLoc(), @@ -66,28 +72,13 @@ class BF16x3 : public OpRewritePattern { value); }; - auto SplitF32 = [&](Value input, unsigned N) -> std::vector { - std::vector split_inputs; - split_inputs.reserve(N); - for (unsigned i = 0; i < N; ++i) { - Value input_as_bf16 = f32ToBF16(input); - if (i != N - 1) { - Value input_as_f32 = bf16ToF32(input_as_bf16); - input = rewriter.create(dotOp->getLoc(), input, - input_as_f32); - } - split_inputs.push_back(input_as_bf16); - } - return split_inputs; - }; - const unsigned hi = 0; const unsigned mid = 1; const unsigned lo = 2; const unsigned N = 3; - auto lhs_parts = SplitF32(dotOp.getA(), N); - 
auto rhs_parts = SplitF32(dotOp.getB(), N); + auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); + auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter); auto result = zero; From ae0ab9120d10e170f78fe7cde74af7a2e8ccaf6a Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:50:56 -0700 Subject: [PATCH 11/39] [NFCi] move replaceNansWithZeros to where it is used --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 4b121f433eb7..5e2b7400a264 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -65,12 +65,6 @@ class BF16x3 : public OpRewritePattern { InputPrecision::BF16, dotOp.getMaxNumImpreciseAcc()); }; - auto replaceNansWithZeros = [&](Value value) -> Value { - auto nans = rewriter.create( - dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - return rewriter.create(dotOp->getLoc(), nans, zero, - value); - }; const unsigned hi = 0; const unsigned mid = 1; @@ -105,6 +99,13 @@ class BF16x3 : public OpRewritePattern { result = dot(lhs_parts[hi], rhs_parts[mid], result); } + auto replaceNansWithZeros = [&](Value value) -> Value { + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); + }; + result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); result = rewriter.create(dotOp.getLoc(), result, dotOp.getC()); From 220412ad634649a2061d640e3500224ffe6bd404 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 00:51:37 -0700 Subject: [PATCH 12/39] [NFCi] clean up the check for f32 operands --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp 
b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 5e2b7400a264..8370b40196d2 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -49,12 +49,9 @@ class BF16x3 : public OpRewritePattern { return failure(); } - auto isF32 = [](Value operand) { - return cast(operand.getType()).getElementType().isF32(); - }; - if (!isF32(dotOp.getA()) || !isF32(dotOp.getB())) { - return failure(); - } + for (auto type : {dotOp.getA().getType(), dotOp.getB().getType()}) + if (!cast(type).getElementType().isF32()) + return failure(); Value zero = rewriter.create( dotOp->getLoc(), dotOp.getC().getType(), From 347103ced3845931956e6b887af961360ef7e03f Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 01:21:53 -0700 Subject: [PATCH 13/39] [NFCi] clean up placement for hi,mid,low,N, rename struct to BF16xN --- .../TritonGPU/Transforms/BF16DotTC.cpp | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 8370b40196d2..6abb0d4e6c56 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -32,23 +32,33 @@ auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) -> llvm::Small return split_inputs; } +auto getBF16Count(triton::InputPrecision precision) -> unsigned { + switch (precision) { + default: + return 0; + case InputPrecision::BF16: + return 1; + case InputPrecision::BF16x3: + return 2; + case InputPrecision::BF16x6: + case InputPrecision::BF16x9: + return 3; + } +} + // Implement 3xBF16 https://arxiv.org/abs/1904.06376 -class BF16x3 : public OpRewritePattern { -public: +struct BF16xN : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(DotOp dotOp, PatternRewriter &rewriter) const override { - switch (dotOp.getInputPrecision()) { - case InputPrecision::BF16: - 
case InputPrecision::BF16x3: - case InputPrecision::BF16x6: - case InputPrecision::BF16x9: - break; - default: - return failure(); - } + const unsigned hi = 0; + const unsigned mid = 1; + const unsigned lo = 2; + const unsigned N = getBF16Count(dotOp.getInputPrecision()); + if (N == 0) + return failure(); for (auto type : {dotOp.getA().getType(), dotOp.getB().getType()}) if (!cast(type).getElementType().isF32()) return failure(); @@ -63,11 +73,6 @@ class BF16x3 : public OpRewritePattern { dotOp.getMaxNumImpreciseAcc()); }; - const unsigned hi = 0; - const unsigned mid = 1; - const unsigned lo = 2; - - const unsigned N = 3; auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter); @@ -120,7 +125,7 @@ struct BF16DotTCPass : public impl::TritonGPUBF16DotTCBase { ModuleOp m = getOperation(); RewritePatternSet decomposePatterns(context); - decomposePatterns.add(context); + decomposePatterns.add(context); if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { signalPassFailure(); } From c72616c9219fd52b069183f8380e477e485d8287 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 01:23:30 -0700 Subject: [PATCH 14/39] [NFCi] pre-commit --- .../TritonGPU/Transforms/BF16DotTC.cpp | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 6abb0d4e6c56..07166f439811 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -12,20 +12,25 @@ namespace gpu { namespace { template -auto convertValue(const Value &value, const FloatType &scalarToType, PatternRewriter &rewriter) -> mlir::Value { +auto convertValue(const Value &value, const FloatType &scalarToType, + PatternRewriter &rewriter) -> mlir::Value { auto fromType = cast(value.getType()); - auto toType = RankedTensorType::get(fromType.getShape(), scalarToType, 
fromType.getEncoding()); + auto toType = RankedTensorType::get(fromType.getShape(), scalarToType, + fromType.getEncoding()); return rewriter.create(value.getLoc(), toType, value).getResult(); } -auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) -> llvm::SmallVector { +auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) + -> llvm::SmallVector { llvm::SmallVector split_inputs; for (unsigned i = 0; i < N; ++i) { - Value input_as_bf16 = convertValue(input, rewriter.getBF16Type(), rewriter); + Value input_as_bf16 = + convertValue(input, rewriter.getBF16Type(), rewriter); if (i != N - 1) { - Value input_as_f32 = convertValue(input_as_bf16, rewriter.getF32Type(), rewriter); - input = rewriter.create(input.getLoc(), input, - input_as_f32); + Value input_as_f32 = convertValue( + input_as_bf16, rewriter.getF32Type(), rewriter); + input = + rewriter.create(input.getLoc(), input, input_as_f32); } split_inputs.push_back(input_as_bf16); } @@ -110,7 +115,8 @@ struct BF16xN : public OpRewritePattern { result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); - result = rewriter.create(dotOp.getLoc(), result, dotOp.getC()); + result = + rewriter.create(dotOp.getLoc(), result, dotOp.getC()); rewriter.replaceOp(dotOp, result); return success(); From f246280fe2c96ba9f865724d5e0d2cba45bb9982 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 01:47:29 -0700 Subject: [PATCH 15/39] [NFCi] clean up lambdas --- .../TritonGPU/Transforms/BF16DotTC.cpp | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 07166f439811..055227f60cba 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -57,37 +57,49 @@ struct BF16xN : public OpRewritePattern { LogicalResult matchAndRewrite(DotOp dotOp, PatternRewriter &rewriter) const 
override { + // BF16 indices and count const unsigned hi = 0; const unsigned mid = 1; const unsigned lo = 2; const unsigned N = getBF16Count(dotOp.getInputPrecision()); + // Checks for FP32 inputs and BF16 InputPrecision if (N == 0) return failure(); for (auto type : {dotOp.getA().getType(), dotOp.getB().getType()}) if (!cast(type).getElementType().isF32()) return failure(); - Value zero = rewriter.create( - dotOp->getLoc(), dotOp.getC().getType(), - rewriter.create(dotOp->getLoc(), - rewriter.getF32FloatAttr(0))); + // Helper Lambdas auto dot = [&](Value a, Value b, Value c) -> Value { return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, InputPrecision::BF16, dotOp.getMaxNumImpreciseAcc()); }; + auto zeroLike = [&]() -> Value { + return rewriter.create( + dotOp->getLoc(), dotOp.getC().getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); + }; + auto replaceNansWithZeros = [&](Value value) -> Value { + auto isNaN = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + return rewriter.create(dotOp->getLoc(), isNaN, zeroLike(), + value); + }; + // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator start value auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter); - - auto result = zero; + auto result = zeroLike(); if (dotOp.getInputPrecision() == InputPrecision::BF16x9) { result = dot(lhs_parts[lo], rhs_parts[lo], result); result = dot(lhs_parts[mid], rhs_parts[lo], result); result = dot(lhs_parts[lo], rhs_parts[mid], result); + // Identical to BF16x6 handling code: result = dot(lhs_parts[mid], rhs_parts[mid], result); result = dot(lhs_parts[lo], rhs_parts[hi], result); @@ -106,13 +118,6 @@ struct BF16xN : public OpRewritePattern { result = dot(lhs_parts[hi], rhs_parts[mid], result); } - auto replaceNansWithZeros = [&](Value value) -> Value { - auto nans = rewriter.create( - dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - 
return rewriter.create(dotOp->getLoc(), nans, zero, - value); - }; - result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); result = From 79223701f84d44069cdaef2887e10be3dc61eb03 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 01:48:41 -0700 Subject: [PATCH 16/39] [NFCi] pre-commit --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 055227f60cba..6656182643c4 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -85,11 +85,11 @@ struct BF16xN : public OpRewritePattern { auto replaceNansWithZeros = [&](Value value) -> Value { auto isNaN = rewriter.create( dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - return rewriter.create(dotOp->getLoc(), isNaN, zeroLike(), - value); + return rewriter.create(dotOp->getLoc(), isNaN, + zeroLike(), value); }; - // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator start value + // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter); auto result = zeroLike(); From 03e2cf2a4a1ac1d31583370e44f8f8399c186580 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 01:52:32 -0700 Subject: [PATCH 17/39] [NFCi] loc --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 6656182643c4..17a8944e0aee 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -62,6 +62,7 @@ struct BF16xN : public OpRewritePattern { const unsigned mid = 1; const unsigned lo = 2; 
const unsigned N = getBF16Count(dotOp.getInputPrecision()); + Location loc = dotOp.getLoc(); // Checks for FP32 inputs and BF16 InputPrecision if (N == 0) @@ -72,20 +73,20 @@ struct BF16xN : public OpRewritePattern { // Helper Lambdas auto dot = [&](Value a, Value b, Value c) -> Value { - return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, + return rewriter.create(loc, c.getType(), a, b, c, InputPrecision::BF16, dotOp.getMaxNumImpreciseAcc()); }; auto zeroLike = [&]() -> Value { return rewriter.create( - dotOp->getLoc(), dotOp.getC().getType(), - rewriter.create(dotOp->getLoc(), + loc, dotOp.getC().getType(), + rewriter.create(loc, rewriter.getF32FloatAttr(0))); }; auto replaceNansWithZeros = [&](Value value) -> Value { auto isNaN = rewriter.create( - dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - return rewriter.create(dotOp->getLoc(), isNaN, + loc, arith::CmpFPredicate::UNO, value, value); + return rewriter.create(loc, isNaN, zeroLike(), value); }; @@ -121,7 +122,7 @@ struct BF16xN : public OpRewritePattern { result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); result = - rewriter.create(dotOp.getLoc(), result, dotOp.getC()); + rewriter.create(loc, result, dotOp.getC()); rewriter.replaceOp(dotOp, result); return success(); From 72e8aa9b090d18768012f26b5ac3d9cb2d664916 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 01:52:49 -0700 Subject: [PATCH 18/39] [NFCi] pre-commit --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 17a8944e0aee..89257654c2cb 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -80,14 +80,12 @@ struct BF16xN : public OpRewritePattern { auto zeroLike = [&]() -> Value { return rewriter.create( loc, dotOp.getC().getType(), - 
rewriter.create(loc, - rewriter.getF32FloatAttr(0))); + rewriter.create(loc, rewriter.getF32FloatAttr(0))); }; auto replaceNansWithZeros = [&](Value value) -> Value { auto isNaN = rewriter.create( loc, arith::CmpFPredicate::UNO, value, value); - return rewriter.create(loc, isNaN, - zeroLike(), value); + return rewriter.create(loc, isNaN, zeroLike(), value); }; // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator @@ -121,8 +119,7 @@ struct BF16xN : public OpRewritePattern { result = replaceNansWithZeros(result); result = dot(lhs_parts[hi], rhs_parts[hi], result); - result = - rewriter.create(loc, result, dotOp.getC()); + result = rewriter.create(loc, result, dotOp.getC()); rewriter.replaceOp(dotOp, result); return success(); From 949426870a6e8c643b2bcb34e3cd740dd80dba9a Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 30 Sep 2025 17:31:45 -0700 Subject: [PATCH 19/39] remove 03-matrix-multiplication.py --- 03-matrix-multiplication.py | 160 ------------------------------------ 1 file changed, 160 deletions(-) delete mode 100644 03-matrix-multiplication.py diff --git a/03-matrix-multiplication.py b/03-matrix-multiplication.py deleted file mode 100644 index 13448035b6b8..000000000000 --- a/03-matrix-multiplication.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch - -import triton -import triton.language as tl - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - -def is_cuda(): - return triton.runtime.driver.active.get_current_target().backend == "cuda" - -def get_cuda_autotune_config(): - return [ - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,num_warps=2), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,num_warps=8), - ] - -def get_hip_autotune_config(): - sizes = [ - {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, - {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, - {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, - {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, - {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, - {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, - {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, - {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}, - ] - return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes] - -def get_autotune_config(): - if is_cuda(): - return get_cuda_autotune_config() - else: - return get_hip_autotune_config() - -@triton.autotune( - configs=get_autotune_config(), - key=['M', 'N', 'K'], -) -@triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - PRECISION: tl.constexpr -): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - 
num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_cm > 0) - tl.assume(stride_cn > 0) - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator += tl.dot(a, b, input_precision=PRECISION) - # accumulator = tl.dot(a, b, accumulator) - # Advance the ptrs to the next K block. 
- a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator.to(tl.float32) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -def matmul(a, b, precision="ieee"): - M, K = a.shape - K, N = b.shape - c = torch.empty((M, N), device=a.device, dtype=torch.float32) - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - PRECISION=precision - ) - return c - - -precisions = ["ieee", "bf16", "bf16x3", "bf16x6", "bf16x9"] -torch.manual_seed(0) - -for precision in precisions: - a = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5 - b = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5 - triton_output = matmul(a, b, precision=precision) - torch_output = torch.matmul(a, b) - #print(f"triton_output_with_fp32_inputs={triton_output}") - #print(f"torch_output_with_fp32_inputs={torch_output}") - - if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): - print(f'✅ Triton and Torch match for input_precision={precision}') - else: - print(f'❌ Triton and Torch differ for input_precision={precision}') - -ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' - -configs = [] -configs.append( - triton.testing.Benchmark( - x_names=["M", "N", "K"], - x_vals=[128 * i for i in range(2, 33)], - line_arg="provider", - line_vals=[ref_lib.lower(), "triton-ieee", "triton-bf16", "triton-bf16x3", "triton-bf16x6", "triton-bf16x9"], - line_names=[ref_lib, "Triton-IEEE", "Triton-BF16", "Triton-BF16x3", "Triton-BF16x6", "Triton-BF16x9"], - styles=[("green", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-"), ("blue", 
"-"), ("blue", "-")], - ylabel="TFLOPS", - plot_name="matmul-performance-f32", - args={}, - )) - -@triton.testing.perf_report(configs) -def benchmark(M, N, K, provider): - a = torch.randn((M, K), device=DEVICE, dtype=torch.float32) - b = torch.randn((K, N), device=DEVICE, dtype=torch.float32) - quantiles = [0.5, 0.2, 0.8] - if provider == ref_lib.lower(): - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - if provider.startswith('triton-'): - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, provider.removeprefix('triton-')), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - return perf(ms), perf(max_ms), perf(min_ms) - -benchmark.run(show_plots=False, print_data=True) From a8a7bf9733f668d8725a36e46f9bdc402d12d685 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 1 Oct 2025 17:20:36 -0700 Subject: [PATCH 20/39] more cleanup --- .../TritonGPU/Transforms/BF16DotTC.cpp | 85 ++++++++++--------- 1 file changed, 43 insertions(+), 42 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 89257654c2cb..33c197df232f 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -37,6 +37,12 @@ auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) return split_inputs; } +Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) { + return rewriter.create(lhs.getLoc(), lhs, rhs, acc, + /*inputPrecision=*/InputPrecision::IEEE, + /*maxNumImpreciseAcc=*/0); +} + auto getBF16Count(triton::InputPrecision precision) -> unsigned { switch (precision) { default: @@ -51,7 +57,9 @@ auto getBF16Count(triton::InputPrecision precision) -> unsigned { } } -// Implement 3xBF16 https://arxiv.org/abs/1904.06376 +// Implements 3xBF16 https://arxiv.org/abs/1904.06376 +// See also 
https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152 +// As well as https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330 struct BF16xN : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -63,62 +71,55 @@ struct BF16xN : public OpRewritePattern { const unsigned lo = 2; const unsigned N = getBF16Count(dotOp.getInputPrecision()); Location loc = dotOp.getLoc(); + auto typeA = dotOp.getA().getType(); + auto typeB = dotOp.getB().getType(); - // Checks for FP32 inputs and BF16 InputPrecision - if (N == 0) + if (!cast(typeA).getElementType().isF32() || + !cast(typeB).getElementType().isF32() || !N) return failure(); - for (auto type : {dotOp.getA().getType(), dotOp.getB().getType()}) - if (!cast(type).getElementType().isF32()) - return failure(); - - // Helper Lambdas - auto dot = [&](Value a, Value b, Value c) -> Value { - return rewriter.create(loc, c.getType(), a, b, c, - InputPrecision::BF16, - dotOp.getMaxNumImpreciseAcc()); - }; - auto zeroLike = [&]() -> Value { + + // Aux functions + auto zeroLike = [&](Value c) -> Value { return rewriter.create( - loc, dotOp.getC().getType(), - rewriter.create(loc, rewriter.getF32FloatAttr(0))); + dotOp->getLoc(), c.getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); }; auto replaceNansWithZeros = [&](Value value) -> Value { - auto isNaN = rewriter.create( - loc, arith::CmpFPredicate::UNO, value, value); - return rewriter.create(loc, isNaN, zeroLike(), value); + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); }; // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator - auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); - auto rhs_parts = 
SplitF32(dotOp.getB(), N, rewriter); - auto result = zeroLike(); + const auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); + const auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter); + auto result = zeroLike(dotOp.getC()); - if (dotOp.getInputPrecision() == InputPrecision::BF16x9) { - result = dot(lhs_parts[lo], rhs_parts[lo], result); - result = dot(lhs_parts[mid], rhs_parts[lo], result); - result = dot(lhs_parts[lo], rhs_parts[mid], result); - - // Identical to BF16x6 handling code: - result = dot(lhs_parts[mid], rhs_parts[mid], result); - - result = dot(lhs_parts[lo], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[lo], result); + switch (dotOp.getInputPrecision()) { + default: + assert(false && "BF16DotTCPass expects BF16x9, BF16x6 or BF16x3"); + return failure(); + case InputPrecision::BF16x9: + result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result); + result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result); + result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result); - } else if (dotOp.getInputPrecision() == InputPrecision::BF16x6) { - result = dot(lhs_parts[mid], rhs_parts[mid], result); + case InputPrecision::BF16x6: + result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[mid], result); - result = dot(lhs_parts[lo], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[lo], result); - } + result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[hi], result); + result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[lo], result); - // BF16x3, BF16x6, BF16x9 all need this - if (dotOp.getInputPrecision() != InputPrecision::BF16) { - result = dot(lhs_parts[mid], rhs_parts[hi], result); - result = dot(lhs_parts[hi], rhs_parts[mid], result); + case InputPrecision::BF16x3: + result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[hi], result); + result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[mid], result); } result = replaceNansWithZeros(result); - result = dot(lhs_parts[hi], rhs_parts[hi], result); + result = 
IEEEDot(rewriter, lhs_parts[hi], rhs_parts[hi], result); result = rewriter.create(loc, result, dotOp.getC()); rewriter.replaceOp(dotOp, result); From f7fa3fb7dda50802a4a1575e3bb4238fc653e956 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 1 Oct 2025 17:21:31 -0700 Subject: [PATCH 21/39] pre-commit --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 33c197df232f..85daa55f0f70 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -39,8 +39,8 @@ auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) { return rewriter.create(lhs.getLoc(), lhs, rhs, acc, - /*inputPrecision=*/InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); + /*inputPrecision=*/InputPrecision::IEEE, + /*maxNumImpreciseAcc=*/0); } auto getBF16Count(triton::InputPrecision precision) -> unsigned { @@ -58,8 +58,10 @@ auto getBF16Count(triton::InputPrecision precision) -> unsigned { } // Implements 3xBF16 https://arxiv.org/abs/1904.06376 -// See also https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152 -// As well as https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330 +// See also +// https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152 +// As well as +// https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330 struct BF16xN : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 
a4c29a2589df9efc8c5dae104eb92cb38d27b4de Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 1 Oct 2025 18:04:19 -0700 Subject: [PATCH 22/39] drop bf16x1 --- include/triton/Dialect/Triton/IR/TritonAttrDefs.td | 7 +++---- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 2 -- python/src/ir.cc | 1 - python/test/unit/language/test_core.py | 5 +++-- third_party/amd/backend/compiler.py | 2 +- third_party/nvidia/backend/compiler.py | 2 +- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index 24d345a21d53..1f48d3c5f4c8 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -130,10 +130,9 @@ def TT_InputPrecisionAttr : I32EnumAttr< I32EnumAttrCase<"TF32", 0, "tf32">, I32EnumAttrCase<"TF32x3", 1, "tf32x3">, I32EnumAttrCase<"IEEE", 2, "ieee">, - I32EnumAttrCase<"BF16", 3, "bf16">, - I32EnumAttrCase<"BF16x3", 4, "bf16x3">, - I32EnumAttrCase<"BF16x6", 5, "bf16x6">, - I32EnumAttrCase<"BF16x9", 6, "bf16x9"> + I32EnumAttrCase<"BF16x3", 3, "bf16x3">, + I32EnumAttrCase<"BF16x6", 4, "bf16x6">, + I32EnumAttrCase<"BF16x9", 5, "bf16x9"> ]>{ let cppNamespace = "::mlir::triton"; } diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 85daa55f0f70..e93989b2bd85 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -47,8 +47,6 @@ auto getBF16Count(triton::InputPrecision precision) -> unsigned { switch (precision) { default: return 0; - case InputPrecision::BF16: - return 1; case InputPrecision::BF16x3: return 2; case InputPrecision::BF16x6: diff --git a/python/src/ir.cc b/python/src/ir.cc index f4fee5e05f1b..b890a2af05c1 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -308,7 +308,6 @@ void init_triton_ir(py::module &&m) { .value("TF32", InputPrecision::TF32) .value("TF32x3", 
InputPrecision::TF32x3) .value("IEEE", InputPrecision::IEEE) - .value("BF16", InputPrecision::BF16) .value("BF16x3", InputPrecision::BF16x3) .value("BF16x6", InputPrecision::BF16x6) .value("BF16x9", InputPrecision::BF16x9) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8927ecfa31ae..51a80c2d8686 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3084,7 +3084,7 @@ def get_test_dot_base_cases(): return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] - for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16', 'bf16x3'] + for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16x3', 'bf16x6', 'bf16x9'] for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32'), ('float64', 'float64')] @@ -3238,7 +3238,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12") if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3(): pytest.skip(f"{in_dtype} only supported on CDNA3") - if not ((input_precision == "bf16x3") or (input_precision == "ieee") or + if not ((input_precision == "bf16x3") or (input_precision == "bf16x6") or + (input_precision == "bf16x9") or (input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())): pytest.skip(f"{input_precision} not supported on HIP") if kpack == 2 and in_dtype == 'int8' and K < 64: diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 2aad479baa8e..56516e85504c 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -45,7 +45,7 @@ class HIPOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8") 
deprecated_fp8_dot_operand_dtypes: Tuple[str] = () default_dot_input_precision: str = "ieee" - allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3', 'bf16', 'bf16x6', 'bf16x9') + allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3', 'bf16x6', 'bf16x9') enable_fp_fusion: bool = True launch_cooperative_grid: bool = False matrix_instr_nonkdim: int = 0 diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index e32025abbc4b..6b9c2b311185 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -121,7 +121,7 @@ class CUDAOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15") deprecated_fp8_dot_operand_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" - allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16', 'bf16x6', 'bf16x9') + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6', 'bf16x9') max_num_imprecise_acc_default: bool = None extern_libs: dict = None debug: bool = False From 80e339d4ed1c21fed0926055bd048245f5c26061 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 1 Oct 2025 18:08:11 -0700 Subject: [PATCH 23/39] pre-commit --- python/test/unit/language/test_core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 51a80c2d8686..25d399065b11 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3238,9 +3238,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12") if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3(): pytest.skip(f"{in_dtype} only supported on CDNA3") - if not ((input_precision == "bf16x3") or (input_precision == "bf16x6") or - (input_precision == "bf16x9") or (input_precision == "ieee") or - 
(input_precision == "tf32" and is_hip_cdna3())): + if not ((input_precision == "bf16x3") or (input_precision == "bf16x6") or (input_precision == "bf16x9") or + (input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())): pytest.skip(f"{input_precision} not supported on HIP") if kpack == 2 and in_dtype == 'int8' and K < 64: pytest.skip("kpack too large for K") From 4abb7938c14ad36022bd675c72719052f3418b68 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 1 Oct 2025 18:23:31 -0700 Subject: [PATCH 24/39] improve lit tests --- test/TritonGPU/bf16x3-matmul.mlir | 87 +++++++++++++++++++++++++++---- 1 file changed, 78 insertions(+), 9 deletions(-) diff --git a/test/TritonGPU/bf16x3-matmul.mlir b/test/TritonGPU/bf16x3-matmul.mlir index e5b46fbbc016..0a13c7278d18 100644 --- a/test/TritonGPU/bf16x3-matmul.mlir +++ b/test/TritonGPU/bf16x3-matmul.mlir @@ -1,5 +1,64 @@ // RUN: triton-opt %s -tritongpu-BF16DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK + + +//// Tests for BF16x3: + +// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 +// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] +// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] +// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] + +// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 +// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] +// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] +// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] + +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]] +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] + +// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] +// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] + +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] +// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 + + + +//// Tests for BF16x6: + +// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 +// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] +// 
CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] +// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] +// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]] +// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]] +// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]] + +// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 +// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] +// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] +// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] +// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]] +// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]] +// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] + +// CHECK: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]] +// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] +// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] + +// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] +// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] + +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] +// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 + + + +//// Tests for BF16x9: + // CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 // CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] // CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] @@ -17,23 +76,33 @@ // CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] // CHECK: %[[val14:.*]] = tt.dot %[[lhs_lo]], %[[rhs_lo]] -// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]], inputPrecision = bf16 -// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]], inputPrecision = bf16 -// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]], inputPrecision = bf16 -// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], 
%[[val17]], inputPrecision = bf16 -// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]], inputPrecision = bf16 -// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]], inputPrecision = bf16 -// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]], inputPrecision = bf16 +// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]] +// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]] +// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]] +// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] +// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] +// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] +// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] // CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] // CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] -// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]], inputPrecision = bf16 +// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] // CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 module { - tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + tt.func @dot_test_BF16x3(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> tt.return %4 : tensor<16x16xf32> } + + tt.func @dot_test_BF16x6(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x6 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + tt.return %4 : tensor<16x16xf32> + } + + tt.func @dot_test_BF16x9(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, 
%arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x9 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + tt.return %4 : tensor<16x16xf32> + } } From 7e9c0f0632a2ce3c93d5479241259e8fe2debdfb Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 10 Oct 2025 09:51:30 -0700 Subject: [PATCH 25/39] addressing Lei's feedback, drop BF16x9 with exception of a small comment --- .../Dialect/Triton/IR/TritonAttrDefs.td | 3 +- .../Dialect/TritonGPU/Transforms/Passes.td | 2 +- .../TritonGPU/Transforms/BF16DotTC.cpp | 48 +++--- python/src/ir.cc | 1 - python/test/unit/language/test_core.py | 6 +- python/triton/language/semantic.py | 2 - test/TritonGPU/bf16x3-matmul.mlir | 138 ++++++------------ third_party/amd/backend/compiler.py | 2 +- third_party/nvidia/backend/compiler.py | 2 +- 9 files changed, 76 insertions(+), 128 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index 1f48d3c5f4c8..5a76a1d7b15f 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -131,8 +131,7 @@ def TT_InputPrecisionAttr : I32EnumAttr< I32EnumAttrCase<"TF32x3", 1, "tf32x3">, I32EnumAttrCase<"IEEE", 2, "ieee">, I32EnumAttrCase<"BF16x3", 3, "bf16x3">, - I32EnumAttrCase<"BF16x6", 4, "bf16x6">, - I32EnumAttrCase<"BF16x9", 5, "bf16x9"> + I32EnumAttrCase<"BF16x6", 4, "bf16x6"> ]>{ let cppNamespace = "::mlir::triton"; } diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 1d62c4a4271f..9bfd2023d263 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -189,7 +189,7 @@ def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { } def TritonGPUBF16DotTC : Pass<"tritongpu-BF16DotTC", "mlir::ModuleOp"> { - let summary = "3xBF16 trick"; + let 
summary = "Use 3xBF16 dot ops to compute F32 dot result"; let description = [{ Decompose fp32 `DotOp` instructions into BF16 operations. diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index e93989b2bd85..a81787d52b7e 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -2,9 +2,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" -namespace mlir { -namespace triton { -namespace gpu { +namespace mlir::triton::gpu { #define GEN_PASS_DEF_TRITONGPUBF16DOTTC #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -12,29 +10,27 @@ namespace gpu { namespace { template -auto convertValue(const Value &value, const FloatType &scalarToType, +auto convertValue(Value value, const FloatType &scalarToType, PatternRewriter &rewriter) -> mlir::Value { auto fromType = cast(value.getType()); - auto toType = RankedTensorType::get(fromType.getShape(), scalarToType, - fromType.getEncoding()); + auto toType = fromType.cloneWith(std::nullopt, scalarToType); return rewriter.create(value.getLoc(), toType, value).getResult(); } -auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter) +auto splitF32(Value input, unsigned N, PatternRewriter &rewriter) -> llvm::SmallVector { - llvm::SmallVector split_inputs; + llvm::SmallVector splitInputs; for (unsigned i = 0; i < N; ++i) { - Value input_as_bf16 = + Value inputAsBF16 = convertValue(input, rewriter.getBF16Type(), rewriter); if (i != N - 1) { - Value input_as_f32 = convertValue( - input_as_bf16, rewriter.getF32Type(), rewriter); - input = - rewriter.create(input.getLoc(), input, input_as_f32); + Value inputAsF32 = convertValue( + inputAsBF16, rewriter.getF32Type(), rewriter); + input = rewriter.create(input.getLoc(), input, inputAsF32); } - split_inputs.push_back(input_as_bf16); + splitInputs.push_back(inputAsBF16); } - return split_inputs; + 
return splitInputs; } Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) { @@ -48,9 +44,9 @@ auto getBF16Count(triton::InputPrecision precision) -> unsigned { default: return 0; case InputPrecision::BF16x3: + // BF16x3 only needs the first 2 values derived from splitting an F32 return 2; case InputPrecision::BF16x6: - case InputPrecision::BF16x9: return 3; } } @@ -94,18 +90,20 @@ struct BF16xN : public OpRewritePattern { }; // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator - const auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter); - const auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter); + const auto lhs_parts = splitF32(dotOp.getA(), N, rewriter); + const auto rhs_parts = splitF32(dotOp.getB(), N, rewriter); auto result = zeroLike(dotOp.getC()); switch (dotOp.getInputPrecision()) { default: - assert(false && "BF16DotTCPass expects BF16x9, BF16x6 or BF16x3"); + assert(false && "BF16DotTCPass expects BF16x6 or BF16x3"); return failure(); - case InputPrecision::BF16x9: - result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result); - result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result); - result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result); + + // NOTE: 9 dots possible; handled like so if not for lack of speedup: + // case InputPrecision::BF16x9: + // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result); + // result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result); + // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result); case InputPrecision::BF16x6: result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[mid], result); @@ -142,6 +140,4 @@ struct BF16DotTCPass : public impl::TritonGPUBF16DotTCBase { } }; -} // namespace gpu -} // namespace triton -} // namespace mlir +} // namespace mlir::triton::gpu diff --git a/python/src/ir.cc b/python/src/ir.cc index b890a2af05c1..93be162289d6 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -310,7 +310,6 
@@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .value("BF16x3", InputPrecision::BF16x3) .value("BF16x6", InputPrecision::BF16x6) - .value("BF16x9", InputPrecision::BF16x9) .export_values(); py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 25d399065b11..6c55b94931d3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3084,7 +3084,7 @@ def get_test_dot_base_cases(): return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] - for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16x3', 'bf16x6', 'bf16x9'] + for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16x3', 'bf16x6'] for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32'), ('float64', 'float64')] @@ -3238,8 +3238,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12") if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3(): pytest.skip(f"{in_dtype} only supported on CDNA3") - if not ((input_precision == "bf16x3") or (input_precision == "bf16x6") or (input_precision == "bf16x9") or - (input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())): + if not ((input_precision in ("bf16x3", "bf16x6")) or (input_precision == "ieee") or + (input_precision == "tf32" and is_hip_cdna3())): pytest.skip(f"{input_precision} not supported on HIP") if kpack == 2 and in_dtype == 'int8' and K < 64: pytest.skip("kpack too large for K") diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index c0c5b383da96..4b75d8e7cdb6 100644 --- a/python/triton/language/semantic.py +++ 
b/python/triton/language/semantic.py @@ -1471,8 +1471,6 @@ def _str_to_dot_input_precision(self, input_precision): input_precision = "BF16x3" if input_precision == "BF16X6": input_precision = "BF16x6" - if input_precision == "BF16X9": - input_precision = "BF16x9" return getattr(ir.INPUT_PRECISION, input_precision) def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], diff --git a/test/TritonGPU/bf16x3-matmul.mlir b/test/TritonGPU/bf16x3-matmul.mlir index 0a13c7278d18..85a8aef6464f 100644 --- a/test/TritonGPU/bf16x3-matmul.mlir +++ b/test/TritonGPU/bf16x3-matmul.mlir @@ -1,108 +1,64 @@ // RUN: triton-opt %s -tritongpu-BF16DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK +module { + tt.func @dot_test_BF16x3(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK-LABEL: dot_test_BF16x3 + // CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 + // CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] + // CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] + // CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] -//// Tests for BF16x3: - -// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 -// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] -// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] -// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] - -// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 -// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] -// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] -// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] - -// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]] -// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] - -// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] -// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] - -// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] -// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 - - - -//// Tests 
for BF16x6: - -// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 -// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] -// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] -// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] -// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]] -// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]] -// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]] - -// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 -// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] -// CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] -// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] -// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]] -// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]] -// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] - -// CHECK: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]] -// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] -// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] -// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] -// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] - -// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] -// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] - -// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] -// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 - - - -//// Tests for BF16x9: - -// CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 -// CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] -// CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] -// CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] -// CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]] -// CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]] -// CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]] - -// CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 -// CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] -// CHECK-NEXT: %[[val9:.*]] = arith.subf 
%arg1, %[[val8]] -// CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] -// CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]] -// CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]] -// CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] + // CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 + // CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] + // CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] + // CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] -// CHECK: %[[val14:.*]] = tt.dot %[[lhs_lo]], %[[rhs_lo]] -// CHECK-NEXT: %[[val15:.*]] = tt.dot %[[lhs_mid]], %[[rhs_lo]], %[[val14]] -// CHECK-NEXT: %[[val16:.*]] = tt.dot %[[lhs_lo]], %[[rhs_mid]], %[[val15]] -// CHECK-NEXT: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]], %[[val16]] -// CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] -// CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] -// CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] -// CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] + // CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]] + // CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] -// CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] -// CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] + // CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] + // CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] -// CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] -// CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 + // CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] + // CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 -module { - tt.func @dot_test_BF16x3(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> tt.return %4 : 
tensor<16x16xf32> } tt.func @dot_test_BF16x6(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { - %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x6 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> - tt.return %4 : tensor<16x16xf32> - } + // CHECK-LABEL: dot_test_BF16x6 + + // CHECK: %[[lhs_hi:.*]] = arith.truncf %arg0 + // CHECK-NEXT: %[[val1:.*]] = arith.extf %[[lhs_hi]] + // CHECK-NEXT: %[[val2:.*]] = arith.subf %arg0, %[[val1]] + // CHECK-NEXT: %[[lhs_mid:.*]] = arith.truncf %[[val2]] + // CHECK-NEXT: %[[val4:.*]] = arith.extf %[[lhs_mid]] + // CHECK-NEXT: %[[val5:.*]] = arith.subf %[[val2]], %[[val4]] + // CHECK-NEXT: %[[lhs_lo:.*]] = arith.truncf %[[val5]] + + // CHECK: %[[rhs_hi:.*]] = arith.truncf %arg1 + // CHECK-NEXT: %[[val8:.*]] = arith.extf %[[rhs_hi]] + // CHECK-NEXT: %[[val9:.*]] = arith.subf %arg1, %[[val8]] + // CHECK-NEXT: %[[rhs_mid:.*]] = arith.truncf %[[val9]] + // CHECK-NEXT: %[[val11:.*]] = arith.extf %[[rhs_mid]] + // CHECK-NEXT: %[[val12:.*]] = arith.subf %[[val9]], %[[val11]] + // CHECK-NEXT: %[[rhs_lo:.*]] = arith.truncf %[[val12]] + + // CHECK: %[[val17:.*]] = tt.dot %[[lhs_mid]], %[[rhs_mid]] + // CHECK-NEXT: %[[val18:.*]] = tt.dot %[[lhs_lo]], %[[rhs_hi]], %[[val17]] + // CHECK-NEXT: %[[val19:.*]] = tt.dot %[[lhs_hi]], %[[rhs_lo]], %[[val18]] + // CHECK-NEXT: %[[val20:.*]] = tt.dot %[[lhs_mid]], %[[rhs_hi]], %[[val19]] + // CHECK-NEXT: %[[val21:.*]] = tt.dot %[[lhs_hi]], %[[rhs_mid]], %[[val20]] + + // CHECK: %[[val22:.*]] = arith.cmpf uno, %[[val21]], %[[val21]] + // CHECK-NEXT: %[[val23:.*]] = arith.select %[[val22]] + + // CHECK: %[[val24:.*]] = tt.dot %[[lhs_hi]], %[[rhs_hi]], %[[val23]] + // CHECK-NEXT: %[[val25:.*]] = arith.addf %[[val24]], %arg2 - tt.func @dot_test_BF16x9(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { - %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x9 : tensor<16x16xf32> * 
tensor<16x16xf32> -> tensor<16x16xf32> + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = bf16x6 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> tt.return %4 : tensor<16x16xf32> } } diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 56516e85504c..52027a293977 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -45,7 +45,7 @@ class HIPOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8") deprecated_fp8_dot_operand_dtypes: Tuple[str] = () default_dot_input_precision: str = "ieee" - allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3', 'bf16x6', 'bf16x9') + allowed_dot_input_precisions: Tuple[str] = ("ieee", 'bf16x3', 'bf16x6') enable_fp_fusion: bool = True launch_cooperative_grid: bool = False matrix_instr_nonkdim: int = 0 diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 6b9c2b311185..a9b0e8f5eb9f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -121,7 +121,7 @@ class CUDAOptions: supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15") deprecated_fp8_dot_operand_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" - allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6', 'bf16x9') + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6') max_num_imprecise_acc_default: bool = None extern_libs: dict = None debug: bool = False From 8c005f4d82321eeba41b1ab40d4f0670323b557f Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 10 Oct 2025 10:29:51 -0700 Subject: [PATCH 26/39] leave door open for BF16x1 to match hipblas --- lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp index 
a81787d52b7e..912cc785aa8d 100644 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp @@ -114,9 +114,12 @@ struct BF16xN : public OpRewritePattern { case InputPrecision::BF16x3: result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[hi], result); result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[mid], result); + result = replaceNansWithZeros(result); + + // NOTE: For BF16x1 bail without replaceNansWithZeros + // case InputPrecision::BF16x1: break; } - result = replaceNansWithZeros(result); result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[hi], result); result = rewriter.create(loc, result, dotOp.getC()); From ca6b7e3cb365adb9b12be5e79b9c7c6c4add6858 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 15 Oct 2025 13:50:07 -0700 Subject: [PATCH 27/39] Fix cuda tests, and disable BF16xN for interpreter --- python/test/unit/language/test_core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6c55b94931d3..0ee9c431e818 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3209,6 +3209,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty if is_interpreter(): if in_dtype == 'bfloat16': pytest.skip("bfloat16 is not supported in the interpreter") + if input_precision == "bf16x3" or input_precision == "bf16x6": + pytest.skip(f"{input_precision} not currently supported on CUDA") else: if not is_hip() and K < 16: pytest.skip("small dots are supported only on HIP at the moment") @@ -3428,6 +3430,8 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid if in_dtype == 'float32' and input_precision != "ieee": if is_tcgen5: assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) + elif input_precision in ("bf16x3", "bf16x6"): + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', 
ptx) else: assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) elif in_dtype == 'float16' and out_dtype == tl.float32: From 1ebfaaf1617407aeda2287b074d9f399e645763b Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 15 Oct 2025 23:44:21 -0700 Subject: [PATCH 28/39] attempting to fix blackwell tl.dot tests --- python/test/unit/language/test_core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 0ee9c431e818..99f743f0820d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3429,7 +3429,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid if in_dtype == 'float32' and input_precision != "ieee": if is_tcgen5: - assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) + if input_precision in ("bf16x3", "bf16x6"): + assert re.search(r'tcgen05.mma.cta_group::1.kind::bf16', ptx) + else: + assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) elif input_precision in ("bf16x3", "bf16x6"): assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) else: From 3d552a67573e54397a464c8d73919284f1c7fa0c Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Thu, 16 Oct 2025 15:41:36 -0700 Subject: [PATCH 29/39] Fix B200 tl.dot test. 
BF16 dot sub-opcode mnemonic is ".f16" https://github.com/triton-lang/triton/blob/main/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp#L261-L262 --- python/test/unit/language/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 99f743f0820d..315a4caad6fa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3430,7 +3430,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid if in_dtype == 'float32' and input_precision != "ieee": if is_tcgen5: if input_precision in ("bf16x3", "bf16x6"): - assert re.search(r'tcgen05.mma.cta_group::1.kind::bf16', ptx) + assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) else: assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx) elif input_precision in ("bf16x3", "bf16x6"): From 400427ddf2aa97136aa182c4406256ae9abd04a9 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 17 Oct 2025 20:51:38 -0700 Subject: [PATCH 30/39] Merge BF16DotTC into F32DotTC pass, separate pattern however (cherry picked from commit f2dcc4e71280f76dddead1390a0459f7d3f93a8f) --- .../Dialect/TritonGPU/Transforms/Passes.td | 25 ++- .../TritonGPU/Transforms/BF16DotTC.cpp | 146 ------------------ .../TritonGPU/Transforms/CMakeLists.txt | 1 - lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 134 +++++++++++++++- python/src/passes.cc | 3 +- third_party/amd/backend/compiler.py | 3 +- third_party/nvidia/backend/compiler.py | 5 +- 7 files changed, 143 insertions(+), 174 deletions(-) delete mode 100644 lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 9bfd2023d263..73b057dea6cd 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ 
-177,26 +177,23 @@ def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir: } def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { - let summary = "3xTF32 trick"; + let summary = "Emulate dot-product tensor core precision using TF32s or BF16s"; let description = [{ - Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s - to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385 + Generic pass to emulate/decompose f32 `DotOp` instructions. + * Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s + to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385. + * Decompose fp32 `DotOp` instructions into BF16 operations. + See https://arxiv.org/abs/1904.06376 }]; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; -} - -def TritonGPUBF16DotTC : Pass<"tritongpu-BF16DotTC", "mlir::ModuleOp"> { - let summary = "Use 3xBF16 dot ops to compute F32 dot result"; - - let description = [{ - Decompose fp32 `DotOp` instructions into BF16 operations. 
- See https://arxiv.org/abs/1904.06376 - }]; - - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; + let options = [ + Option<"emuTF32", "emu-tf32", + "bool", /*default*/"false", + "whether to handle InputPrecision TF32xN for Nvidia GPUs"> + ]; } def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { diff --git a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp deleted file mode 100644 index 912cc785aa8d..000000000000 --- a/lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" - -namespace mlir::triton::gpu { - -#define GEN_PASS_DEF_TRITONGPUBF16DOTTC -#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" - -namespace { - -template -auto convertValue(Value value, const FloatType &scalarToType, - PatternRewriter &rewriter) -> mlir::Value { - auto fromType = cast(value.getType()); - auto toType = fromType.cloneWith(std::nullopt, scalarToType); - return rewriter.create(value.getLoc(), toType, value).getResult(); -} - -auto splitF32(Value input, unsigned N, PatternRewriter &rewriter) - -> llvm::SmallVector { - llvm::SmallVector splitInputs; - for (unsigned i = 0; i < N; ++i) { - Value inputAsBF16 = - convertValue(input, rewriter.getBF16Type(), rewriter); - if (i != N - 1) { - Value inputAsF32 = convertValue( - inputAsBF16, rewriter.getF32Type(), rewriter); - input = rewriter.create(input.getLoc(), input, inputAsF32); - } - splitInputs.push_back(inputAsBF16); - } - return splitInputs; -} - -Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) { - return rewriter.create(lhs.getLoc(), lhs, rhs, acc, - /*inputPrecision=*/InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); -} - -auto getBF16Count(triton::InputPrecision precision) -> unsigned { - switch (precision) { - default: - return 0; - case 
InputPrecision::BF16x3: - // BF16x3 only needs the first 2 values derived from splitting an F32 - return 2; - case InputPrecision::BF16x6: - return 3; - } -} - -// Implements 3xBF16 https://arxiv.org/abs/1904.06376 -// See also -// https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152 -// As well as -// https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330 -struct BF16xN : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DotOp dotOp, - PatternRewriter &rewriter) const override { - // BF16 indices and count - const unsigned hi = 0; - const unsigned mid = 1; - const unsigned lo = 2; - const unsigned N = getBF16Count(dotOp.getInputPrecision()); - Location loc = dotOp.getLoc(); - auto typeA = dotOp.getA().getType(); - auto typeB = dotOp.getB().getType(); - - if (!cast(typeA).getElementType().isF32() || - !cast(typeB).getElementType().isF32() || !N) - return failure(); - - // Aux functions - auto zeroLike = [&](Value c) -> Value { - return rewriter.create( - dotOp->getLoc(), c.getType(), - rewriter.create(dotOp->getLoc(), - rewriter.getF32FloatAttr(0))); - }; - auto replaceNansWithZeros = [&](Value value) -> Value { - auto nans = rewriter.create( - dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - auto zero = zeroLike(value); - return rewriter.create(dotOp->getLoc(), nans, zero, - value); - }; - - // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator - const auto lhs_parts = splitF32(dotOp.getA(), N, rewriter); - const auto rhs_parts = splitF32(dotOp.getB(), N, rewriter); - auto result = zeroLike(dotOp.getC()); - - switch (dotOp.getInputPrecision()) { - default: - assert(false && "BF16DotTCPass expects BF16x6 or BF16x3"); - return failure(); - - // NOTE: 9 dots possible; handled like so if not for lack 
of speedup: - // case InputPrecision::BF16x9: - // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result); - // result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result); - // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result); - - case InputPrecision::BF16x6: - result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[mid], result); - - result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[hi], result); - result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[lo], result); - - case InputPrecision::BF16x3: - result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[hi], result); - result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[mid], result); - result = replaceNansWithZeros(result); - - // NOTE: For BF16x1 bail without replaceNansWithZeros - // case InputPrecision::BF16x1: break; - } - - result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[hi], result); - result = rewriter.create(loc, result, dotOp.getC()); - - rewriter.replaceOp(dotOp, result); - return success(); - } -}; - -} // anonymous namespace - -struct BF16DotTCPass : public impl::TritonGPUBF16DotTCBase { - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp m = getOperation(); - - RewritePatternSet decomposePatterns(context); - decomposePatterns.add(context); - if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { - signalPassFailure(); - } - } -}; - -} // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 99e756c6c430..965b1e1e7d0e 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -2,7 +2,6 @@ add_triton_library(TritonGPUTransforms AccelerateMatmul.cpp Coalesce.cpp F32DotTC.cpp - BF16DotTC.cpp FuseNestedLoops.cpp CombineTensorSelectAndIf.cpp DecomposeScaledBlocked.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp 
index 6fe35aebdc37..ef83c954785f 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -2,15 +2,132 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" -namespace mlir { -namespace triton { -namespace gpu { +namespace mlir::triton::gpu { #define GEN_PASS_DEF_TRITONGPUF32DOTTC #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" namespace { +template +auto convertValue(Value value, const FloatType &scalarToType, + PatternRewriter &rewriter) -> mlir::Value { + auto fromType = cast(value.getType()); + auto toType = fromType.cloneWith(std::nullopt, scalarToType); + return rewriter.create(value.getLoc(), toType, value).getResult(); +} + +auto splitF32(Value input, unsigned N, PatternRewriter &rewriter) + -> llvm::SmallVector { + llvm::SmallVector splitInputs; + for (unsigned i = 0; i < N; ++i) { + Value inputAsBF16 = + convertValue(input, rewriter.getBF16Type(), rewriter); + if (i != N - 1) { + Value inputAsF32 = convertValue( + inputAsBF16, rewriter.getF32Type(), rewriter); + input = rewriter.create(input.getLoc(), input, inputAsF32); + } + splitInputs.push_back(inputAsBF16); + } + return splitInputs; +} + +Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) { + return rewriter.create(lhs.getLoc(), lhs, rhs, acc, + /*inputPrecision=*/InputPrecision::IEEE, + /*maxNumImpreciseAcc=*/0); +} + +auto getBF16Count(triton::InputPrecision precision) -> unsigned { + switch (precision) { + default: + return 0; + case InputPrecision::BF16x3: + // BF16x3 only needs the first 2 values derived from splitting an F32 + return 2; + case InputPrecision::BF16x6: + return 3; + } +} + +// Implements 3xBF16 https://arxiv.org/abs/1904.06376 +// See also +// https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152 +// As well as +// 
https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330 +struct BF16xN : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + // BF16 indices and count + const unsigned hi = 0; + const unsigned mid = 1; + const unsigned lo = 2; + const unsigned N = getBF16Count(dotOp.getInputPrecision()); + Location loc = dotOp.getLoc(); + auto typeA = dotOp.getA().getType(); + auto typeB = dotOp.getB().getType(); + + if (!cast(typeA).getElementType().isF32() || + !cast(typeB).getElementType().isF32() || !N) + return failure(); + + // Aux functions + auto zeroLike = [&](Value c) -> Value { + return rewriter.create( + dotOp->getLoc(), c.getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); + }; + auto replaceNansWithZeros = [&](Value value) -> Value { + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); + }; + + // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator + const auto lhs_parts = splitF32(dotOp.getA(), N, rewriter); + const auto rhs_parts = splitF32(dotOp.getB(), N, rewriter); + auto result = zeroLike(dotOp.getC()); + + switch (dotOp.getInputPrecision()) { + default: + assert(false && "BF16DotTCPass expects BF16x6 or BF16x3"); + return failure(); + + // NOTE: 9 dots possible; handled like so if not for lack of speedup: + // case InputPrecision::BF16x9: + // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result); + // result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result); + // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result); + + case InputPrecision::BF16x6: + result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[mid], result); + + result = IEEEDot(rewriter, 
lhs_parts[lo], rhs_parts[hi], result); + result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[lo], result); + + case InputPrecision::BF16x3: + result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[hi], result); + result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[mid], result); + result = replaceNansWithZeros(result); + + // NOTE: For BF16x1 bail without replaceNansWithZeros + // case InputPrecision::BF16x1: break; + } + + result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[hi], result); + result = rewriter.create(loc, result, dotOp.getC()); + + rewriter.replaceOp(dotOp, result); + return success(); + } +}; + // nb. We call the trick TF32x3 as C++ disallows variables starting with numbers // Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385 // For a, b f32 @@ -103,18 +220,21 @@ class TF32x3 : public OpRewritePattern { } // anonymous namespace struct F32DotTCPass : public impl::TritonGPUF32DotTCBase { + using impl::TritonGPUF32DotTCBase< + F32DotTCPass>::TritonGPUF32DotTCBase; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); RewritePatternSet decomposePatterns(context); - decomposePatterns.add(context); + if (this->emuTF32) { + decomposePatterns.add(context); + } + decomposePatterns.add(context); if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { signalPassFailure(); } } }; -} // namespace gpu -} // namespace triton -} // namespace mlir +} // namespace mlir::triton::gpu diff --git a/python/src/passes.cc b/python/src/passes.cc index a1162acfb7f6..e9b2f6e0d123 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -71,8 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); ADD_PASS_WRAPPER_0("add_reorder_instructions", createTritonGPUReorderInstructions); - ADD_PASS_WRAPPER_0("add_bf16_dot_tc", createTritonGPUBF16DotTC); - ADD_PASS_WRAPPER_0("add_f32_dot_tc", 
createTritonGPUF32DotTC); + ADD_PASS_OPTION_WRAPPER_1("add_f32_dot_tc", createTritonGPUF32DotTC, bool); ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", createTritonGPUOptimizeDotOperands, bool); ADD_PASS_WRAPPER_0("add_remove_layout_conversions", diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 52027a293977..9675eb5d4c2a 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -207,8 +207,9 @@ def make_ttgir(mod, metadata, options): pm.run(mod, 'make_ttgir_early') pm = ir.pass_manager(mod.context) pm.enable_debug() + emuTF32 = False passes.ttgpuir.add_coalesce(pm) - passes.ttgpuir.add_bf16_dot_tc(pm) + passes.ttgpuir.add_f32_dot_tc(pm, emuTF32) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index a9b0e8f5eb9f..2b52e4112cd8 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -259,12 +259,11 @@ def make_ttgir(mod, metadata, opt, capability): cluster_info.clusterDimZ = opt.cluster_dims[2] pm = ir.pass_manager(mod.context) dump_enabled = pm.enable_debug() + emuTF32 = (capability // 10 >= 8) passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) # optimize TTGIR passes.ttgpuir.add_coalesce(pm) - if capability // 10 >= 8: - passes.ttgpuir.add_f32_dot_tc(pm) - passes.ttgpuir.add_bf16_dot_tc(pm) + passes.ttgpuir.add_f32_dot_tc(pm, emuTF32) # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) passes.ttgpuir.add_remove_layout_conversions(pm) From c1773434f8b16cdbe1cbb285d886a45ed949b27e Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 17 Oct 2025 20:52:35 -0700 Subject: [PATCH 31/39] lint 
(cherry picked from commit 787bb8e0e81985e3666eaa8e19358c6002b97ec3) --- lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index ef83c954785f..7b34cfddf266 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -220,8 +220,7 @@ class TF32x3 : public OpRewritePattern { } // anonymous namespace struct F32DotTCPass : public impl::TritonGPUF32DotTCBase { - using impl::TritonGPUF32DotTCBase< - F32DotTCPass>::TritonGPUF32DotTCBase; + using impl::TritonGPUF32DotTCBase::TritonGPUF32DotTCBase; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); From 6fd2e44cb7acdce67a2f28d1dc98e96eb00b7608 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 17 Oct 2025 21:07:34 -0700 Subject: [PATCH 32/39] fix lit tests --- test/TritonGPU/bf16x3-matmul.mlir | 2 +- test/TritonGPU/tf32x3-matmul.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/TritonGPU/bf16x3-matmul.mlir b/test/TritonGPU/bf16x3-matmul.mlir index 85a8aef6464f..e4dc3737f2bf 100644 --- a/test/TritonGPU/bf16x3-matmul.mlir +++ b/test/TritonGPU/bf16x3-matmul.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -tritongpu-BF16DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK +// RUN: triton-opt %s -tritongpu-F32DotTC="emu-tf32=0" -canonicalize | FileCheck %s --check-prefixes=CHECK module { tt.func @dot_test_BF16x3(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { diff --git a/test/TritonGPU/tf32x3-matmul.mlir b/test/TritonGPU/tf32x3-matmul.mlir index 7f7f3a11aa07..9e77679d82cc 100644 --- a/test/TritonGPU/tf32x3-matmul.mlir +++ b/test/TritonGPU/tf32x3-matmul.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -tritongpu-F32DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK +// RUN: triton-opt 
%s -tritongpu-F32DotTC="emu-tf32=1" -canonicalize | FileCheck %s --check-prefixes=CHECK // CHECK: %[[DOT1:.*]] = tt.dot %[[LHS_LOW:.*]], %[[RHS_HIGH:.*]], %cst, inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> // CHECK: %[[DOT2:.*]] = tt.dot %[[LHS_HIGH:.*]], %[[RHS_LOW:.*]], %[[DOT1]], inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> From 8d853bcde34b153db057b3c82d1e55f24dad0d42 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 17 Oct 2025 21:10:05 -0700 Subject: [PATCH 33/39] drop TritonNvidiaGPUDialect dep from TritonGPUF32DotTC --- include/triton/Dialect/TritonGPU/Transforms/Passes.td | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 73b057dea6cd..f1db3d49a5a7 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -187,8 +187,7 @@ def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { See https://arxiv.org/abs/1904.06376 }]; - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", - "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; let options = [ Option<"emuTF32", "emu-tf32", "bool", /*default*/"false", From 9228dac45185bfa6d2f56b304f9b038e081ce23d Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 17 Oct 2025 21:16:22 -0700 Subject: [PATCH 34/39] fix input prec pytest skip reason for interpreter --- python/test/unit/language/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 315a4caad6fa..e7d7ccd3eb3e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3210,7 +3210,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, 
input_precision, in_dty if in_dtype == 'bfloat16': pytest.skip("bfloat16 is not supported in the interpreter") if input_precision == "bf16x3" or input_precision == "bf16x6": - pytest.skip(f"{input_precision} not currently supported on CUDA") + pytest.skip(f"input_precision {input_precision} is not supported in the interpreter") else: if not is_hip() and K < 16: pytest.skip("small dots are supported only on HIP at the moment") From e16f9516a737ee0181ef86236835f8d9c57d3170 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Sat, 18 Oct 2025 01:45:14 -0700 Subject: [PATCH 35/39] merge functionality of the BF16xN and TF32x3 pattern writers --- lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 107 +++++++----------- 1 file changed, 42 insertions(+), 65 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index 7b34cfddf266..bbc9384c928a 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -33,11 +33,30 @@ auto splitF32(Value input, unsigned N, PatternRewriter &rewriter) return splitInputs; } -Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) { - return rewriter.create(lhs.getLoc(), lhs, rhs, acc, - /*inputPrecision=*/InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); -} +auto isF32(Value operand) { + return cast(operand.getType()).getElementType().isF32(); +}; + +auto zeroLike(Value c, PatternRewriter &rewriter) -> Value { + return rewriter.create( + c.getLoc(), c.getType(), + rewriter.create(c.getLoc(), + rewriter.getF32FloatAttr(0))); +}; + +template +Value Dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, uint32_t maxNumImpreciseAcc = 0) { + return rewriter.create(lhs.getLoc(), lhs, rhs, acc, precision, + maxNumImpreciseAcc); +}; + +auto replaceNansWithZeros(Value value, PatternRewriter &rewriter) -> Value { + auto nans = rewriter.create( + value.getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero 
= zeroLike(value, rewriter); + return rewriter.create(value.getLoc(), nans, zero, + value); +}; auto getBF16Count(triton::InputPrecision precision) -> unsigned { switch (precision) { @@ -66,33 +85,14 @@ struct BF16xN : public OpRewritePattern { const unsigned mid = 1; const unsigned lo = 2; const unsigned N = getBF16Count(dotOp.getInputPrecision()); - Location loc = dotOp.getLoc(); - auto typeA = dotOp.getA().getType(); - auto typeB = dotOp.getB().getType(); - if (!cast(typeA).getElementType().isF32() || - !cast(typeB).getElementType().isF32() || !N) + if (!isF32(dotOp.getA()) || !isF32(dotOp.getB()) || !N) return failure(); - // Aux functions - auto zeroLike = [&](Value c) -> Value { - return rewriter.create( - dotOp->getLoc(), c.getType(), - rewriter.create(dotOp->getLoc(), - rewriter.getF32FloatAttr(0))); - }; - auto replaceNansWithZeros = [&](Value value) -> Value { - auto nans = rewriter.create( - dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - auto zero = zeroLike(value); - return rewriter.create(dotOp->getLoc(), nans, zero, - value); - }; - // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator const auto lhs_parts = splitF32(dotOp.getA(), N, rewriter); const auto rhs_parts = splitF32(dotOp.getB(), N, rewriter); - auto result = zeroLike(dotOp.getC()); + auto result = zeroLike(dotOp.getC(), rewriter); switch (dotOp.getInputPrecision()) { default: @@ -101,27 +101,27 @@ struct BF16xN : public OpRewritePattern { // NOTE: 9 dots possible; handled like so if not for lack of speedup: // case InputPrecision::BF16x9: - // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result); - // result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result); - // result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result); + // result = Dot(lhs_parts[lo], rhs_parts[lo], result, rewriter); + // result = Dot(lhs_parts[mid], rhs_parts[lo], result, rewriter); + // result = Dot(lhs_parts[lo], rhs_parts[mid], result, rewriter); case 
InputPrecision::BF16x6: - result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[mid], result); + result = Dot(lhs_parts[mid], rhs_parts[mid], result, rewriter); - result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[hi], result); - result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[lo], result); + result = Dot(lhs_parts[lo], rhs_parts[hi], result, rewriter); + result = Dot(lhs_parts[hi], rhs_parts[lo], result, rewriter); case InputPrecision::BF16x3: - result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[hi], result); - result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[mid], result); - result = replaceNansWithZeros(result); + result = Dot(lhs_parts[mid], rhs_parts[hi], result, rewriter); + result = Dot(lhs_parts[hi], rhs_parts[mid], result, rewriter); + result = replaceNansWithZeros(result, rewriter); // NOTE: For BF16x1 bail without replaceNansWithZeros // case InputPrecision::BF16x1: break; } - result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[hi], result); - result = rewriter.create(loc, result, dotOp.getC()); + result = Dot(lhs_parts[hi], rhs_parts[hi], result, rewriter); + result = rewriter.create(dotOp.getLoc(), result, dotOp.getC()); rewriter.replaceOp(dotOp, result); return success(); @@ -145,11 +145,6 @@ class TF32x3 : public OpRewritePattern { LogicalResult matchAndRewrite(DotOp dotOp, PatternRewriter &rewriter) const override { - - auto isF32 = [](Value operand) { - return cast(operand.getType()).getElementType().isF32(); - }; - if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 && isF32(dotOp.getA()) && isF32(dotOp.getB()))) { return failure(); @@ -164,30 +159,12 @@ class TF32x3 : public OpRewritePattern { ArrayRef{value}) .getResult()[0]; }; - auto zeroLike = [&](Value c) -> Value { - return rewriter.create( - dotOp->getLoc(), c.getType(), - rewriter.create(dotOp->getLoc(), - rewriter.getF32FloatAttr(0))); - }; auto add = [&](Value a, Value b) -> Value { return rewriter.create(dotOp.getLoc(), a, b); }; auto sub = [&](Value a, Value b) -> Value 
{ return rewriter.create(dotOp.getLoc(), a, b); }; - auto dot = [&](Value a, Value b, Value c) -> Value { - return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, - InputPrecision::TF32, - dotOp.getMaxNumImpreciseAcc()); - }; - auto replaceNansWithZeros = [&](Value value) -> Value { - auto nans = rewriter.create( - dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); - auto zero = zeroLike(value); - return rewriter.create(dotOp->getLoc(), nans, zero, - value); - }; auto aBig = f32ToTF32(dotOp.getA()); auto aSmall = sub(dotOp.getA(), aBig); @@ -195,10 +172,10 @@ class TF32x3 : public OpRewritePattern { auto bBig = f32ToTF32(dotOp.getB()); auto bSmall = sub(dotOp.getB(), bBig); - auto zero = zeroLike(dotOp.getC()); + auto zero = zeroLike(dotOp.getC(), rewriter); - auto dot1 = dot(aSmall, bBig, zero); - auto dot2 = dot(aBig, bSmall, dot1); + auto dot1 = Dot(aSmall, bBig, zero, rewriter, dotOp.getMaxNumImpreciseAcc()); + auto dot2 = Dot(aBig, bSmall, dot1, rewriter, dotOp.getMaxNumImpreciseAcc()); // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. // If rhs is +infinity, we will have: @@ -207,8 +184,8 @@ class TF32x3 : public OpRewritePattern { // We would get the wrong result if we sum these partial products. Instead, // we must override any accumulated result if the last partial product is // non-finite. 
- auto dot2withZeroedNans = replaceNansWithZeros(dot2); - auto dot3 = dot(aBig, bBig, dot2withZeroedNans); + auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter); + auto dot3 = Dot(aBig, bBig, dot2withZeroedNans, rewriter, dotOp.getMaxNumImpreciseAcc()); auto sum = add(dot3, dotOp.getC()); From c540446d61bc9713401cea95d73479a6d4428d98 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Sat, 18 Oct 2025 01:45:29 -0700 Subject: [PATCH 36/39] lint --- lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 52 ++++++++++++------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index bbc9384c928a..7efe14dd8569 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -38,14 +38,14 @@ auto isF32(Value operand) { }; auto zeroLike(Value c, PatternRewriter &rewriter) -> Value { - return rewriter.create( - c.getLoc(), c.getType(), - rewriter.create(c.getLoc(), - rewriter.getF32FloatAttr(0))); + return rewriter.create(c.getLoc(), c.getType(), + rewriter.create( + c.getLoc(), rewriter.getF32FloatAttr(0))); }; template -Value Dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, uint32_t maxNumImpreciseAcc = 0) { +Value Dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, + uint32_t maxNumImpreciseAcc = 0) { return rewriter.create(lhs.getLoc(), lhs, rhs, acc, precision, maxNumImpreciseAcc); }; @@ -54,8 +54,7 @@ auto replaceNansWithZeros(Value value, PatternRewriter &rewriter) -> Value { auto nans = rewriter.create( value.getLoc(), arith::CmpFPredicate::UNO, value, value); auto zero = zeroLike(value, rewriter); - return rewriter.create(value.getLoc(), nans, zero, - value); + return rewriter.create(value.getLoc(), nans, zero, value); }; auto getBF16Count(triton::InputPrecision precision) -> unsigned { @@ -101,27 +100,36 @@ struct BF16xN : public OpRewritePattern { // NOTE: 9 dots possible; handled 
like so if not for lack of speedup: // case InputPrecision::BF16x9: - // result = Dot(lhs_parts[lo], rhs_parts[lo], result, rewriter); - // result = Dot(lhs_parts[mid], rhs_parts[lo], result, rewriter); - // result = Dot(lhs_parts[lo], rhs_parts[mid], result, rewriter); + // result = Dot(lhs_parts[lo], rhs_parts[lo], + // result, rewriter); result = Dot(lhs_parts[mid], + // rhs_parts[lo], result, rewriter); result = + // Dot(lhs_parts[lo], rhs_parts[mid], result, + // rewriter); case InputPrecision::BF16x6: - result = Dot(lhs_parts[mid], rhs_parts[mid], result, rewriter); + result = Dot(lhs_parts[mid], rhs_parts[mid], result, + rewriter); - result = Dot(lhs_parts[lo], rhs_parts[hi], result, rewriter); - result = Dot(lhs_parts[hi], rhs_parts[lo], result, rewriter); + result = Dot(lhs_parts[lo], rhs_parts[hi], result, + rewriter); + result = Dot(lhs_parts[hi], rhs_parts[lo], result, + rewriter); case InputPrecision::BF16x3: - result = Dot(lhs_parts[mid], rhs_parts[hi], result, rewriter); - result = Dot(lhs_parts[hi], rhs_parts[mid], result, rewriter); + result = Dot(lhs_parts[mid], rhs_parts[hi], result, + rewriter); + result = Dot(lhs_parts[hi], rhs_parts[mid], result, + rewriter); result = replaceNansWithZeros(result, rewriter); // NOTE: For BF16x1 bail without replaceNansWithZeros // case InputPrecision::BF16x1: break; } - result = Dot(lhs_parts[hi], rhs_parts[hi], result, rewriter); - result = rewriter.create(dotOp.getLoc(), result, dotOp.getC()); + result = Dot(lhs_parts[hi], rhs_parts[hi], result, + rewriter); + result = + rewriter.create(dotOp.getLoc(), result, dotOp.getC()); rewriter.replaceOp(dotOp, result); return success(); @@ -174,8 +182,10 @@ class TF32x3 : public OpRewritePattern { auto zero = zeroLike(dotOp.getC(), rewriter); - auto dot1 = Dot(aSmall, bBig, zero, rewriter, dotOp.getMaxNumImpreciseAcc()); - auto dot2 = Dot(aBig, bSmall, dot1, rewriter, dotOp.getMaxNumImpreciseAcc()); + auto dot1 = Dot(aSmall, bBig, zero, rewriter, + 
dotOp.getMaxNumImpreciseAcc()); + auto dot2 = Dot(aBig, bSmall, dot1, rewriter, + dotOp.getMaxNumImpreciseAcc()); // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. // If rhs is +infinity, we will have: @@ -185,7 +195,9 @@ class TF32x3 : public OpRewritePattern { // we must override any accumulated result if the last partial product is // non-finite. auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter); - auto dot3 = Dot(aBig, bBig, dot2withZeroedNans, rewriter, dotOp.getMaxNumImpreciseAcc()); + auto dot3 = + Dot(aBig, bBig, dot2withZeroedNans, rewriter, + dotOp.getMaxNumImpreciseAcc()); auto sum = add(dot3, dotOp.getC()); From 70ba6e0e01f3691eea4c429719c4f4c9c9e2c1d1 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Sat, 18 Oct 2025 01:52:08 -0700 Subject: [PATCH 37/39] improve TT_DotOp description --- include/triton/Dialect/Triton/IR/TritonOps.td | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 06b6b5d0726d..a745dd12e850 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -664,9 +664,11 @@ def TT_DotOp : TT_Op<"dot", [Pure, let description = [{ $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC - when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + when the inputs are f32. It can be one of: tf32, tf32x3, ieee, bf16x3, bf16x6. tf32: use TC with tf32 ops. tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + bf16x3: implement the 3xBF16 trick. For more info see the pass in F32DotTC.cpp + bf16x6: implement the 6xBF16 trick. For more info see the pass in F32DotTC.cpp ieee: don't use TC, implement dot in software. If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. 
}]; From aaf3faf165bf77068b717ebde573f61adfc52804 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Sat, 18 Oct 2025 11:32:34 -0700 Subject: [PATCH 38/39] address more PR feedback --- lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 51 ++++++++----------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index 7efe14dd8569..064d8aec5207 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -33,31 +33,30 @@ auto splitF32(Value input, unsigned N, PatternRewriter &rewriter) return splitInputs; } -auto isF32(Value operand) { +bool isF32(Value operand) { return cast(operand.getType()).getElementType().isF32(); }; -auto zeroLike(Value c, PatternRewriter &rewriter) -> Value { +Value zeroLike(Value c, PatternRewriter &rewriter) { return rewriter.create(c.getLoc(), c.getType(), rewriter.create( c.getLoc(), rewriter.getF32FloatAttr(0))); }; -template -Value Dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, - uint32_t maxNumImpreciseAcc = 0) { +Value dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, + InputPrecision precision = InputPrecision::IEEE, uint32_t maxNumImpreciseAcc = 0) { return rewriter.create(lhs.getLoc(), lhs, rhs, acc, precision, maxNumImpreciseAcc); }; -auto replaceNansWithZeros(Value value, PatternRewriter &rewriter) -> Value { +Value replaceNansWithZeros(Value value, PatternRewriter &rewriter) { auto nans = rewriter.create( value.getLoc(), arith::CmpFPredicate::UNO, value, value); auto zero = zeroLike(value, rewriter); return rewriter.create(value.getLoc(), nans, zero, value); }; -auto getBF16Count(triton::InputPrecision precision) -> unsigned { +unsigned getBF16Count(triton::InputPrecision precision) { switch (precision) { default: return 0; @@ -98,36 +97,30 @@ struct BF16xN : public OpRewritePattern { assert(false && "BF16DotTCPass expects BF16x6 or BF16x3"); return 
failure(); - // NOTE: 9 dots possible; handled like so if not for lack of speedup: - // case InputPrecision::BF16x9: - // result = Dot(lhs_parts[lo], rhs_parts[lo], - // result, rewriter); result = Dot(lhs_parts[mid], - // rhs_parts[lo], result, rewriter); result = - // Dot(lhs_parts[lo], rhs_parts[mid], result, - // rewriter); + // clang-format off + // NOTE: 9 dots possible; handled like so if not for lack of speedup: + // case InputPrecision::BF16x9: + // result = dot(lhs_parts[lo], rhs_parts[lo], result, rewriter); + // result = dot(lhs_parts[mid], rhs_parts[lo], result, rewriter); + // result = dot(lhs_parts[lo], rhs_parts[mid], result, rewriter); + // clang-format on case InputPrecision::BF16x6: - result = Dot(lhs_parts[mid], rhs_parts[mid], result, - rewriter); + result = dot(lhs_parts[mid], rhs_parts[mid], result, rewriter); - result = Dot(lhs_parts[lo], rhs_parts[hi], result, - rewriter); - result = Dot(lhs_parts[hi], rhs_parts[lo], result, - rewriter); + result = dot(lhs_parts[lo], rhs_parts[hi], result, rewriter); + result = dot(lhs_parts[hi], rhs_parts[lo], result, rewriter); case InputPrecision::BF16x3: - result = Dot(lhs_parts[mid], rhs_parts[hi], result, - rewriter); - result = Dot(lhs_parts[hi], rhs_parts[mid], result, - rewriter); + result = dot(lhs_parts[mid], rhs_parts[hi], result, rewriter); + result = dot(lhs_parts[hi], rhs_parts[mid], result, rewriter); result = replaceNansWithZeros(result, rewriter); // NOTE: For BF16x1 bail without replaceNansWithZeros // case InputPrecision::BF16x1: break; } - result = Dot(lhs_parts[hi], rhs_parts[hi], result, - rewriter); + result = dot(lhs_parts[hi], rhs_parts[hi], result, rewriter); result = rewriter.create(dotOp.getLoc(), result, dotOp.getC()); @@ -182,9 +175,9 @@ class TF32x3 : public OpRewritePattern { auto zero = zeroLike(dotOp.getC(), rewriter); - auto dot1 = Dot(aSmall, bBig, zero, rewriter, + auto dot1 = dot(aSmall, bBig, zero, rewriter, InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); - auto 
dot2 = Dot(aBig, bSmall, dot1, rewriter, + auto dot2 = dot(aBig, bSmall, dot1, rewriter, InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. @@ -196,7 +189,7 @@ class TF32x3 : public OpRewritePattern { // non-finite. auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter); auto dot3 = - Dot(aBig, bBig, dot2withZeroedNans, rewriter, + dot(aBig, bBig, dot2withZeroedNans, rewriter, InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); auto sum = add(dot3, dotOp.getC()); From 04305b450c84397246a396e529fa0ab57bf5ed79 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Sat, 18 Oct 2025 11:32:46 -0700 Subject: [PATCH 39/39] lint --- lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index 064d8aec5207..e5db6abbd973 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -44,7 +44,8 @@ Value zeroLike(Value c, PatternRewriter &rewriter) { }; Value dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter, - InputPrecision precision = InputPrecision::IEEE, uint32_t maxNumImpreciseAcc = 0) { + InputPrecision precision = InputPrecision::IEEE, + uint32_t maxNumImpreciseAcc = 0) { return rewriter.create(lhs.getLoc(), lhs, rhs, acc, precision, maxNumImpreciseAcc); }; @@ -97,13 +98,13 @@ struct BF16xN : public OpRewritePattern { assert(false && "BF16DotTCPass expects BF16x6 or BF16x3"); return failure(); - // clang-format off + // clang-format off // NOTE: 9 dots possible; handled like so if not for lack of speedup: // case InputPrecision::BF16x9: // result = dot(lhs_parts[lo], rhs_parts[lo], result, rewriter); // result = dot(lhs_parts[mid], rhs_parts[lo], result, rewriter); // result = dot(lhs_parts[lo], rhs_parts[mid], result, rewriter); - // clang-format on + // clang-format 
on case InputPrecision::BF16x6: result = dot(lhs_parts[mid], rhs_parts[mid], result, rewriter); @@ -176,9 +177,9 @@ class TF32x3 : public OpRewritePattern { auto zero = zeroLike(dotOp.getC(), rewriter); auto dot1 = dot(aSmall, bBig, zero, rewriter, InputPrecision::TF32, - dotOp.getMaxNumImpreciseAcc()); + dotOp.getMaxNumImpreciseAcc()); auto dot2 = dot(aBig, bSmall, dot1, rewriter, InputPrecision::TF32, - dotOp.getMaxNumImpreciseAcc()); + dotOp.getMaxNumImpreciseAcc()); // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. // If rhs is +infinity, we will have: @@ -188,9 +189,8 @@ class TF32x3 : public OpRewritePattern { // we must override any accumulated result if the last partial product is // non-finite. auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter); - auto dot3 = - dot(aBig, bBig, dot2withZeroedNans, rewriter, InputPrecision::TF32, - dotOp.getMaxNumImpreciseAcc()); + auto dot3 = dot(aBig, bBig, dot2withZeroedNans, rewriter, + InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); auto sum = add(dot3, dotOp.getC());