|
| 1 | +#include "mlir/IR/PatternMatch.h" |
| 2 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 3 | +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" |
| 4 | +#include "llvm/Support/raw_ostream.h" |
| 5 | + |
| 6 | +namespace mlir { |
| 7 | +namespace triton { |
| 8 | +namespace gpu { |
| 9 | + |
| 10 | +#define GEN_PASS_DEF_TRITONGPUBF16X3DOT |
| 11 | +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" |
| 12 | + |
| 13 | +namespace { |
| 14 | + |
| 15 | +// Implement 3xBF16 https://arxiv.org/abs/1904.06376 |
| 16 | +class BF16x3 : public OpRewritePattern<DotOp> { |
| 17 | +public: |
| 18 | + using OpRewritePattern::OpRewritePattern; |
| 19 | + |
| 20 | + LogicalResult matchAndRewrite(DotOp dotOp, |
| 21 | + PatternRewriter &rewriter) const override { |
| 22 | + |
| 23 | + auto isBF16x3Candidate = [](Value operand) { |
| 24 | + return cast<RankedTensorType>(operand.getType()).getElementType().isF32(); |
| 25 | + }; |
| 26 | + |
| 27 | + if (!(dotOp.getInputPrecision() == InputPrecision::BF16x3 && |
| 28 | + isBF16x3Candidate(dotOp.getA()) && isBF16x3Candidate(dotOp.getB()))) { |
| 29 | + return failure(); |
| 30 | + } |
| 31 | + |
| 32 | + // Aux functions |
| 33 | + auto f32ToBF16 = [&](Value value) -> Value { |
| 34 | + auto fp32Type = cast<RankedTensorType>(value.getType()); |
| 35 | + auto bf16Type = |
| 36 | + RankedTensorType::get(fp32Type.getShape(), rewriter.getBF16Type()); |
| 37 | + return rewriter.create<arith::TruncFOp>(dotOp.getLoc(), bf16Type, value) |
| 38 | + .getResult(); |
| 39 | + }; |
| 40 | + auto bf16ToF32 = [&](Value value) -> Value { |
| 41 | + auto bf16Type = cast<RankedTensorType>(value.getType()); |
| 42 | + auto fp32Type = |
| 43 | + RankedTensorType::get(bf16Type.getShape(), rewriter.getF32Type()); |
| 44 | + return rewriter.create<arith::ExtFOp>(dotOp.getLoc(), fp32Type, value) |
| 45 | + .getResult(); |
| 46 | + }; |
| 47 | + auto zeroLike = [&](Value c) -> Value { |
| 48 | + return rewriter.create<SplatOp>( |
| 49 | + dotOp->getLoc(), c.getType(), |
| 50 | + rewriter.create<arith::ConstantOp>(dotOp->getLoc(), |
| 51 | + rewriter.getF32FloatAttr(0))); |
| 52 | + }; |
| 53 | + auto add = [&](Value a, Value b) -> Value { |
| 54 | + return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b); |
| 55 | + }; |
| 56 | + auto sub = [&](Value a, Value b) -> Value { |
| 57 | + return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b); |
| 58 | + }; |
| 59 | + auto dot = [&](Value a, Value b, Value c) -> Value { |
| 60 | + return rewriter.create<DotOp>(dotOp->getLoc(), c.getType(), a, b, c, |
| 61 | + InputPrecision::BF16, |
| 62 | + dotOp.getMaxNumImpreciseAcc()); |
| 63 | + }; |
| 64 | + auto replaceNansWithZeros = [&](Value value) -> Value { |
| 65 | + auto nans = rewriter.create<arith::CmpFOp>( |
| 66 | + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); |
| 67 | + auto zero = zeroLike(value); |
| 68 | + return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero, |
| 69 | + value); |
| 70 | + }; |
| 71 | + |
| 72 | + auto SplitF32 = [&](Value input, unsigned N) -> std::vector<Value> { |
| 73 | + std::vector<Value> split_inputs; |
| 74 | + split_inputs.reserve(N); |
| 75 | + for (int i = 0; i < N; ++i) { |
| 76 | + Value input_as_bf16 = f32ToBF16(input); |
| 77 | + if (i != N - 1) { |
| 78 | + Value input_as_f32 = bf16ToF32(input_as_bf16); |
| 79 | + input = rewriter.create<arith::SubFOp>(dotOp->getLoc(), input, |
| 80 | + input_as_f32); |
| 81 | + } |
| 82 | + split_inputs.push_back(input_as_bf16); |
| 83 | + } |
| 84 | + return split_inputs; |
| 85 | + }; |
| 86 | + |
| 87 | + const int hi = 0; |
| 88 | + const int med = 1; |
| 89 | + const int lo = 2; |
| 90 | + |
| 91 | + const unsigned N = 3; |
| 92 | + auto lhs_parts = SplitF32(dotOp.getA(), N); |
| 93 | + auto rhs_parts = SplitF32(dotOp.getB(), N); |
| 94 | + |
| 95 | + auto result = zeroLike(dotOp.getC()); |
| 96 | + |
| 97 | + result = dot(lhs_parts[lo], rhs_parts[lo], result); |
| 98 | + result = dot(lhs_parts[med], rhs_parts[lo], result); |
| 99 | + result = dot(lhs_parts[lo], rhs_parts[med], result); |
| 100 | + |
| 101 | + result = dot(lhs_parts[med], rhs_parts[med], result); |
| 102 | + |
| 103 | + result = dot(lhs_parts[lo], rhs_parts[hi], result); |
| 104 | + result = dot(lhs_parts[hi], rhs_parts[lo], result); |
| 105 | + |
| 106 | + result = dot(lhs_parts[med], rhs_parts[hi], result); |
| 107 | + result = dot(lhs_parts[hi], rhs_parts[med], result); |
| 108 | + |
| 109 | + result = replaceNansWithZeros(result); |
| 110 | + result = dot(lhs_parts[hi], rhs_parts[hi], result); |
| 111 | + result = add(result, dotOp.getC()); |
| 112 | + |
| 113 | + rewriter.replaceOp(dotOp, result); |
| 114 | + return success(); |
| 115 | + } |
| 116 | +}; |
| 117 | + |
| 118 | +} // anonymous namespace |
| 119 | + |
| 120 | +struct BF16x3DotPass : public impl::TritonGPUBF16x3DotBase<BF16x3DotPass> { |
| 121 | + void runOnOperation() override { |
| 122 | + MLIRContext *context = &getContext(); |
| 123 | + ModuleOp m = getOperation(); |
| 124 | + |
| 125 | + RewritePatternSet decomposePatterns(context); |
| 126 | + decomposePatterns.add<BF16x3>(context); |
| 127 | + if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) { |
| 128 | + signalPassFailure(); |
| 129 | + } |
| 130 | + } |
| 131 | +}; |
| 132 | + |
| 133 | +} // namespace gpu |
| 134 | +} // namespace triton |
| 135 | +} // namespace mlir |
0 commit comments