Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
1084821
[BACKEND] Implement BF16x3 trick
plotfi Jul 22, 2025
df46410
move pass to after coalescer
plotfi Sep 30, 2025
5ac391b
[NFC] drop raw_ostream include
plotfi Sep 30, 2025
18daed1
[NFC] drop sub lambda unused
plotfi Sep 30, 2025
94d4bd6
[NFC] collapse reused code handling BF16x3
plotfi Sep 30, 2025
5cc88bd
[NFC] flatted add lambda
plotfi Sep 30, 2025
e284c85
[NFCi] skip 2xdots for 1xBF16: 1xBF16 will likely get dropped
plotfi Sep 30, 2025
56d86c8
[NFCi] change into to unsigned and med to mid
plotfi Sep 30, 2025
967e686
[NFCi] flatten zeroLike into a single reusable Constant zero fp32 value
plotfi Sep 30, 2025
5dd2ce6
[NFCi] clean up Split32 into its own helper function
plotfi Sep 30, 2025
ae0ab91
[NFCi] move replaceNansWithZeros to where it is used
plotfi Sep 30, 2025
220412a
[NFCi] clean up the check for f32 operands
plotfi Sep 30, 2025
347103c
[NFCi] clean up placement for hi,mid,low,N, rename struct to BF16xN
plotfi Sep 30, 2025
c72616c
[NFCi] pre-commit
plotfi Sep 30, 2025
f246280
[NFCi] clean up lambdas
plotfi Sep 30, 2025
7922370
[NFCi] pre-commit
plotfi Sep 30, 2025
03e2cf2
[NFCi] loc
plotfi Sep 30, 2025
72e8aa9
[NFCi] pre-commit
plotfi Sep 30, 2025
9494268
remove 03-matrix-multiplication.py
plotfi Oct 1, 2025
a8a7bf9
more cleanup
plotfi Oct 2, 2025
f7fa3fb
pre-commit
plotfi Oct 2, 2025
a4c29a2
drop bf16x1
plotfi Oct 2, 2025
80e339d
pre-commit
plotfi Oct 2, 2025
4abb793
improve lit tests
plotfi Oct 2, 2025
7e9c0f0
addressing Lei's feedback, drop BF16x9 with exception of a small comment
plotfi Oct 10, 2025
8c005f4
leave door open for BF16x1 to match hipblas
plotfi Oct 10, 2025
ca6b7e3
Fix cuda tests, and disable BF16xN for interpreter
plotfi Oct 15, 2025
1ebfaaf
attempting to fix blackwell tl.dot tests
plotfi Oct 16, 2025
3d552a6
Fix B200 tl.dot test. BF16 dot sub-opcode mnemonic is ".f16"
plotfi Oct 16, 2025
400427d
Merge BF16DotTC into F32DotTC pass, separate pattern however
plotfi Oct 18, 2025
c177343
lint
plotfi Oct 18, 2025
6fd2e44
fix lit tests
plotfi Oct 18, 2025
8d853bc
drop TritonNvidiaGPUDialect dep from TritonGPUF32DotTC
plotfi Oct 18, 2025
9228dac
fix input prec pytest skip reason for interpreter
plotfi Oct 18, 2025
e16f951
merge functionality of the BF16xN and TF32x3 pattern writers
plotfi Oct 18, 2025
c540446
lint
plotfi Oct 18, 2025
70ba6e0
improve TT_DotOp description
plotfi Oct 18, 2025
aaf3faf
address more PR feedback
plotfi Oct 18, 2025
04305b4
lint
plotfi Oct 18, 2025
6b10688
Merge branch 'main' into plotfi-bf16x3-dot
antiagainst Oct 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def TT_InputPrecisionAttr : I32EnumAttr<
[
I32EnumAttrCase<"TF32", 0, "tf32">,
I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
I32EnumAttrCase<"IEEE", 2, "ieee">
I32EnumAttrCase<"IEEE", 2, "ieee">,
I32EnumAttrCase<"BF16x3", 3, "bf16x3">,
I32EnumAttrCase<"BF16x6", 4, "bf16x6">
]>{
let cppNamespace = "::mlir::triton";
}
Expand Down
4 changes: 3 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,11 @@ def TT_DotOp : TT_Op<"dot", [Pure,

let description = [{
$d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC
when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
when the inputs are f32. It can be one of: tf32, tf32x3, ieee, bf16x3, bf16x6.
tf32: use TC with tf32 ops.
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
bf16x3: implement the 3xBF16 trick. For more info see the pass in F32DotTC.cpp
bf16x6: implement the 6xBF16 trick. For more info see the pass in F32DotTC.cpp
ieee: don't use TC, implement dot in software.
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
}];
Expand Down
17 changes: 12 additions & 5 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,22 @@ def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir:
}

def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
let summary = "3xTF32 trick";
let summary = "Emulate dot-product tensor core precision using TF32s or BF16s";

let description = [{
Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s
to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385
Generic pass to emulate/decompose f32 `DotOp` instructions.
* TF32x3: decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 TF32 `DotOp`s
to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385.
* BF16x3 / BF16x6: decompose fp32 `DotOp` instructions into 3 or 6 BF16 `DotOp`s.
See https://arxiv.org/abs/1904.06376
}];

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"emuTF32", "emu-tf32",
"bool", /*default*/"false",
"whether to handle InputPrecision TF32xN for Nvidia GPUs">
];
}

def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
Expand Down
171 changes: 136 additions & 35 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,134 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

namespace mlir {
namespace triton {
namespace gpu {
namespace mlir::triton::gpu {

#define GEN_PASS_DEF_TRITONGPUF32DOTTC
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// Element-wise cast helper: builds an op of kind `T` (e.g. arith::TruncFOp or
// arith::ExtFOp) that converts the ranked tensor `value` to a tensor whose
// element type is `scalarToType`, and returns the op's result.
template <typename T>
auto convertValue(Value value, const FloatType &scalarToType,
                  PatternRewriter &rewriter) -> mlir::Value {
  auto srcTensorTy = cast<RankedTensorType>(value.getType());
  // No new shape is supplied (std::nullopt); only the element type is swapped.
  auto dstTensorTy = srcTensorTy.cloneWith(std::nullopt, scalarToType);
  auto castOp = rewriter.create<T>(value.getLoc(), dstTensorTy, value);
  return castOp.getResult();
}

// Splits an f32 tensor into `N` bf16 tensors, most significant part first.
// Each step truncates the current remainder to bf16; for every part except
// the last, the truncated value is extended back to f32 and subtracted so the
// next step only sees what the previous bf16 value could not represent.
auto splitF32(Value input, unsigned N, PatternRewriter &rewriter)
    -> llvm::SmallVector<Value, 3> {
  llvm::SmallVector<Value, 3> parts;
  Value remainder = input;
  for (unsigned idx = 0; idx != N; ++idx) {
    Value truncated = convertValue<arith::TruncFOp>(
        remainder, rewriter.getBF16Type(), rewriter);
    parts.push_back(truncated);

    // The last part needs no remainder computation.
    if (idx + 1 == N)
      break;

    Value widened = convertValue<arith::ExtFOp>(truncated,
                                                rewriter.getF32Type(), rewriter);
    remainder = rewriter.create<arith::SubFOp>(remainder.getLoc(), remainder,
                                               widened);
  }
  return parts;
}

// Returns true when `operand` is a ranked tensor whose element type is f32.
// The operand must be a RankedTensorType (enforced by `cast`).
bool isF32(Value operand) {
  auto tensorTy = cast<RankedTensorType>(operand.getType());
  return tensorTy.getElementType().isF32();
}

// Builds a tensor with the same type as `c` by splatting an f32 constant 0.
// Both created ops carry the location of `c`.
Value zeroLike(Value c, PatternRewriter &rewriter) {
  Location loc = c.getLoc();
  Value zeroScalar =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(0));
  return rewriter.create<SplatOp>(loc, c.getType(), zeroScalar);
}

// Creates a DotOp computing `lhs * rhs + acc` at the location of `lhs` and
// returns its result. `precision` defaults to IEEE and `maxNumImpreciseAcc`
// to 0 when the caller does not override them.
Value dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter,
          InputPrecision precision = InputPrecision::IEEE,
          uint32_t maxNumImpreciseAcc = 0) {
  Location loc = lhs.getLoc();
  DotOp newDot = rewriter.create<DotOp>(loc, lhs, rhs, acc, precision,
                                        maxNumImpreciseAcc);
  return newDot.getResult();
}

// Returns `value` with every NaN element replaced by zero.
Value replaceNansWithZeros(Value value, PatternRewriter &rewriter) {
  Location loc = value.getLoc();
  // Comparing a value against itself with the "unordered" predicate flags
  // exactly the NaN elements.
  Value isNan = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
                                               value, value);
  Value zeros = zeroLike(value, rewriter);
  return rewriter.create<arith::SelectOp>(loc, isNan, zeros, value);
}

// Number of bf16 parts each f32 operand must be split into for the given
// input precision; 0 means the precision is not a BF16xN mode.
unsigned getBF16Count(triton::InputPrecision precision) {
  if (precision == InputPrecision::BF16x6)
    return 3;
  // BF16x3 only needs the first 2 values derived from splitting an F32
  if (precision == InputPrecision::BF16x3)
    return 2;
  return 0;
}

// Implements 3xBF16 / 6xBF16 https://arxiv.org/abs/1904.06376
// See also
// https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152
// As well as
// https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330
//
// Rewrites an f32 x f32 DotOp whose input precision is BF16x3 or BF16x6 into a
// chain of bf16 DotOps. Each operand is split into bf16 parts (hi, mid and,
// for BF16x6, lo); the partial products are accumulated into a zero
// accumulator and the original accumulator C is added at the very end.
//   BF16x3: mid*hi, hi*mid, hi*hi
//   BF16x6: mid*mid, lo*hi, hi*lo, then the three BF16x3 dots
struct BF16xN : public OpRewritePattern<DotOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DotOp dotOp,
                                PatternRewriter &rewriter) const override {
    // BF16 indices and count
    const unsigned hi = 0;
    const unsigned mid = 1;
    const unsigned lo = 2;
    const unsigned N = getBF16Count(dotOp.getInputPrecision());

    // Reject non-f32 operands and non-BF16xN precisions before any op is
    // created, so a failed match leaves the IR untouched.
    if (!isF32(dotOp.getA()) || !isF32(dotOp.getB()) || N == 0)
      return failure();

    // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator
    const auto lhsParts = splitF32(dotOp.getA(), N, rewriter);
    const auto rhsParts = splitF32(dotOp.getB(), N, rewriter);
    auto result = zeroLike(dotOp.getC(), rewriter);

    switch (dotOp.getInputPrecision()) {
    default:
      // Unreachable: N != 0 only for BF16x3 and BF16x6 (see getBF16Count).
      assert(false && "BF16xN expects BF16x6 or BF16x3");
      return failure();

      // clang-format off
    // NOTE: 9 dots possible; handled like so if not for lack of speedup:
    // case InputPrecision::BF16x9:
    //   result = dot(lhsParts[lo], rhsParts[lo], result, rewriter);
    //   result = dot(lhsParts[mid], rhsParts[lo], result, rewriter);
    //   result = dot(lhsParts[lo], rhsParts[mid], result, rewriter);
      // clang-format on

    case InputPrecision::BF16x6:
      result = dot(lhsParts[mid], rhsParts[mid], result, rewriter);

      result = dot(lhsParts[lo], rhsParts[hi], result, rewriter);
      result = dot(lhsParts[hi], rhsParts[lo], result, rewriter);
      // BF16x6 is the three dots above plus everything BF16x3 does.
      [[fallthrough]];

    case InputPrecision::BF16x3:
      result = dot(lhsParts[mid], rhsParts[hi], result, rewriter);
      result = dot(lhsParts[hi], rhsParts[mid], result, rewriter);
      // Clear NaNs accumulated by the lower-order terms before the hi*hi
      // product is added.
      result = replaceNansWithZeros(result, rewriter);

      // NOTE: For BF16x1 bail without replaceNansWithZeros
      // case InputPrecision::BF16x1: break;
      break;
    }

    // The dominant hi*hi term goes last, then the original accumulator.
    result = dot(lhsParts[hi], rhsParts[hi], result, rewriter);
    result =
        rewriter.create<arith::AddFOp>(dotOp.getLoc(), result, dotOp.getC());

    rewriter.replaceOp(dotOp, result);
    return success();
  }
};

// nb. We call the trick TF32x3 as C++ disallows variables starting with numbers
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
Expand All @@ -28,11 +147,6 @@ class TF32x3 : public OpRewritePattern<DotOp> {

LogicalResult matchAndRewrite(DotOp dotOp,
PatternRewriter &rewriter) const override {

auto isF32 = [](Value operand) {
return cast<RankedTensorType>(operand.getType()).getElementType().isF32();
};

if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 &&
isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
return failure();
Expand All @@ -47,41 +161,25 @@ class TF32x3 : public OpRewritePattern<DotOp> {
ArrayRef<Value>{value})
.getResult()[0];
};
auto zeroLike = [&](Value c) -> Value {
return rewriter.create<SplatOp>(
dotOp->getLoc(), c.getType(),
rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
rewriter.getF32FloatAttr(0)));
};
auto add = [&](Value a, Value b) -> Value {
return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
};
auto sub = [&](Value a, Value b) -> Value {
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
};
auto dot = [&](Value a, Value b, Value c) -> Value {
return rewriter.create<DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
};
auto replaceNansWithZeros = [&](Value value) -> Value {
auto nans = rewriter.create<arith::CmpFOp>(
dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
auto zero = zeroLike(value);
return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
value);
};

auto aBig = f32ToTF32(dotOp.getA());
auto aSmall = sub(dotOp.getA(), aBig);

auto bBig = f32ToTF32(dotOp.getB());
auto bSmall = sub(dotOp.getB(), bBig);

auto zero = zeroLike(dotOp.getC());
auto zero = zeroLike(dotOp.getC(), rewriter);

auto dot1 = dot(aSmall, bBig, zero);
auto dot2 = dot(aBig, bSmall, dot1);
auto dot1 = dot(aSmall, bBig, zero, rewriter, InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
auto dot2 = dot(aBig, bSmall, dot1, rewriter, InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());

// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
// If rhs is +infinity, we will have:
Expand All @@ -90,8 +188,9 @@ class TF32x3 : public OpRewritePattern<DotOp> {
// We would get the wrong result if we sum these partial products. Instead,
// we must override any accumulated result if the last partial product is
// non-finite.
auto dot2withZeroedNans = replaceNansWithZeros(dot2);
auto dot3 = dot(aBig, bBig, dot2withZeroedNans);
auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter);
auto dot3 = dot(aBig, bBig, dot2withZeroedNans, rewriter,
InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc());

auto sum = add(dot3, dotOp.getC());

Expand All @@ -103,18 +202,20 @@ class TF32x3 : public OpRewritePattern<DotOp> {
} // anonymous namespace

struct F32DotTCPass : public impl::TritonGPUF32DotTCBase<F32DotTCPass> {
using impl::TritonGPUF32DotTCBase<F32DotTCPass>::TritonGPUF32DotTCBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

RewritePatternSet decomposePatterns(context);
decomposePatterns.add<TF32x3>(context);
if (this->emuTF32) {
decomposePatterns.add<TF32x3>(context);
}
decomposePatterns.add<BF16xN>(context);
if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
signalPassFailure();
}
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir
} // namespace mlir::triton::gpu
2 changes: 2 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ void init_triton_ir(py::module &&m) {
.value("TF32", InputPrecision::TF32)
.value("TF32x3", InputPrecision::TF32x3)
.value("IEEE", InputPrecision::IEEE)
.value("BF16x3", InputPrecision::BF16x3)
.value("BF16x6", InputPrecision::BF16x6)
.export_values();

py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())
Expand Down
2 changes: 1 addition & 1 deletion python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul);
ADD_PASS_WRAPPER_0("add_reorder_instructions",
createTritonGPUReorderInstructions);
ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC);
ADD_PASS_OPTION_WRAPPER_1("add_f32_dot_tc", createTritonGPUF32DotTC, bool);
ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands",
createTritonGPUOptimizeDotOperands, bool);
ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
Expand Down
14 changes: 11 additions & 3 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,7 +3084,7 @@ def get_test_dot_base_cases():
return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None)
for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for input_precision in ['tf32', 'tf32x3', 'ieee']
for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16x3', 'bf16x6']
for in_dtype, out_dtype in [('float16', 'float16'), ('float16',
'float32'), ('float32',
'float32'), ('float64', 'float64')]
Expand Down Expand Up @@ -3209,6 +3209,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
if is_interpreter():
if in_dtype == 'bfloat16':
pytest.skip("bfloat16 is not supported in the interpreter")
if input_precision == "bf16x3" or input_precision == "bf16x6":
pytest.skip(f"input_precision {input_precision} is not supported in the interpreter")
else:
if not is_hip() and K < 16:
pytest.skip("small dots are supported only on HIP at the moment")
Expand Down Expand Up @@ -3238,7 +3240,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12")
if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3():
pytest.skip(f"{in_dtype} only supported on CDNA3")
if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())):
if not ((input_precision in ("bf16x3", "bf16x6")) or (input_precision == "ieee") or
(input_precision == "tf32" and is_hip_cdna3())):
pytest.skip(f"{input_precision} not supported on HIP")
if kpack == 2 and in_dtype == 'int8' and K < 64:
pytest.skip("kpack too large for K")
Expand Down Expand Up @@ -3426,7 +3429,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid

if in_dtype == 'float32' and input_precision != "ieee":
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx)
if input_precision in ("bf16x3", "bf16x6"):
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx)
elif input_precision in ("bf16x3", "bf16x6"):
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
else:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
elif in_dtype == 'float16' and out_dtype == tl.float32:
Expand Down
4 changes: 4 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,10 @@ def _str_to_dot_input_precision(self, input_precision):
input_precision = input_precision.upper()
if input_precision == "TF32X3":
input_precision = "TF32x3"
if input_precision == "BF16X3":
input_precision = "BF16x3"
if input_precision == "BF16X6":
input_precision = "BF16x6"
return getattr(ir.INPUT_PRECISION, input_precision)

def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
Expand Down
Loading
Loading