@@ -37,6 +37,12 @@ auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter)
3737 return split_inputs;
3838}
3939
40+ Value IEEEDot (PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) {
41+ return rewriter.create <DotOp>(lhs.getLoc (), lhs, rhs, acc,
42+ /* inputPrecision=*/ InputPrecision::IEEE,
43+ /* maxNumImpreciseAcc=*/ 0 );
44+ }
45+
4046auto getBF16Count (triton::InputPrecision precision) -> unsigned {
4147 switch (precision) {
4248 default :
@@ -51,7 +57,9 @@ auto getBF16Count(triton::InputPrecision precision) -> unsigned {
5157 }
5258}
5359
54- // Implement 3xBF16 https://arxiv.org/abs/1904.06376
60+ // Implements 3xBF16 https://arxiv.org/abs/1904.06376
61+ // See also https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152
62+ // As well as https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330
5563struct BF16xN : public OpRewritePattern <DotOp> {
5664 using OpRewritePattern::OpRewritePattern;
5765
@@ -63,62 +71,55 @@ struct BF16xN : public OpRewritePattern<DotOp> {
6371 const unsigned lo = 2 ;
6472 const unsigned N = getBF16Count (dotOp.getInputPrecision ());
6573 Location loc = dotOp.getLoc ();
74+ auto typeA = dotOp.getA ().getType ();
75+ auto typeB = dotOp.getB ().getType ();
6676
67- // Checks for FP32 inputs and BF16 InputPrecision
68- if (N == 0 )
77+ if (!cast<RankedTensorType>(typeA). getElementType (). isF32 () ||
78+ !cast<RankedTensorType>(typeB). getElementType (). isF32 () || !N )
6979 return failure ();
70- for (auto type : {dotOp.getA ().getType (), dotOp.getB ().getType ()})
71- if (!cast<RankedTensorType>(type).getElementType ().isF32 ())
72- return failure ();
73-
74- // Helper Lambdas
75- auto dot = [&](Value a, Value b, Value c) -> Value {
76- return rewriter.create <DotOp>(loc, c.getType (), a, b, c,
77- InputPrecision::BF16,
78- dotOp.getMaxNumImpreciseAcc ());
79- };
80- auto zeroLike = [&]() -> Value {
80+
81+ // Aux functions
82+ auto zeroLike = [&](Value c) -> Value {
8183 return rewriter.create <SplatOp>(
82- loc, dotOp.getC ().getType (),
83- rewriter.create <arith::ConstantOp>(loc, rewriter.getF32FloatAttr (0 )));
84+ dotOp->getLoc (), c.getType (),
85+ rewriter.create <arith::ConstantOp>(dotOp->getLoc (),
86+ rewriter.getF32FloatAttr (0 )));
8487 };
8588 auto replaceNansWithZeros = [&](Value value) -> Value {
86- auto isNaN = rewriter.create <arith::CmpFOp>(
87- loc, arith::CmpFPredicate::UNO, value, value);
88- return rewriter.create <arith::SelectOp>(loc, isNaN, zeroLike (), value);
89+ auto nans = rewriter.create <arith::CmpFOp>(
90+ dotOp->getLoc (), arith::CmpFPredicate::UNO, value, value);
91+ auto zero = zeroLike (value);
92+ return rewriter.create <arith::SelectOp>(dotOp->getLoc (), nans, zero,
93+ value);
8994 };
9095
9196 // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator
92- auto lhs_parts = SplitF32 (dotOp.getA (), N, rewriter);
93- auto rhs_parts = SplitF32 (dotOp.getB (), N, rewriter);
94- auto result = zeroLike ();
97+ const auto lhs_parts = SplitF32 (dotOp.getA (), N, rewriter);
98+ const auto rhs_parts = SplitF32 (dotOp.getB (), N, rewriter);
99+ auto result = zeroLike (dotOp. getC () );
95100
96- if (dotOp.getInputPrecision () == InputPrecision::BF16x9) {
97- result = dot (lhs_parts[lo], rhs_parts[lo], result);
98- result = dot (lhs_parts[mid], rhs_parts[lo], result);
99- result = dot (lhs_parts[lo], rhs_parts[mid], result);
100-
101- // Identical to BF16x6 handling code:
102- result = dot (lhs_parts[mid], rhs_parts[mid], result);
103-
104- result = dot (lhs_parts[lo], rhs_parts[hi], result);
105- result = dot (lhs_parts[hi], rhs_parts[lo], result);
101+ switch (dotOp.getInputPrecision ()) {
102+ default :
103+ assert (false && " BF16DotTCPass expects BF16x9, BF16x6 or BF16x3" );
104+ return failure ();
105+ case InputPrecision::BF16x9:
106+ result = IEEEDot (rewriter, lhs_parts[lo], rhs_parts[lo], result);
107+ result = IEEEDot (rewriter, lhs_parts[mid], rhs_parts[lo], result);
108+ result = IEEEDot (rewriter, lhs_parts[lo], rhs_parts[mid], result);
106109
107- } else if (dotOp. getInputPrecision () == InputPrecision::BF16x6) {
108- result = dot ( lhs_parts[mid], rhs_parts[mid], result);
110+ case InputPrecision::BF16x6:
111+ result = IEEEDot (rewriter, lhs_parts[mid], rhs_parts[mid], result);
109112
110- result = dot (lhs_parts[lo], rhs_parts[hi], result);
111- result = dot (lhs_parts[hi], rhs_parts[lo], result);
112- }
113+ result = IEEEDot (rewriter, lhs_parts[lo], rhs_parts[hi], result);
114+ result = IEEEDot (rewriter, lhs_parts[hi], rhs_parts[lo], result);
113115
114- // BF16x3, BF16x6, BF16x9 all need this
115- if (dotOp.getInputPrecision () != InputPrecision::BF16) {
116- result = dot (lhs_parts[mid], rhs_parts[hi], result);
117- result = dot (lhs_parts[hi], rhs_parts[mid], result);
116+ case InputPrecision::BF16x3:
117+ result = IEEEDot (rewriter, lhs_parts[mid], rhs_parts[hi], result);
118+ result = IEEEDot (rewriter, lhs_parts[hi], rhs_parts[mid], result);
118119 }
119120
120121 result = replaceNansWithZeros (result);
121- result = dot ( lhs_parts[hi], rhs_parts[hi], result);
122+ result = IEEEDot (rewriter, lhs_parts[hi], rhs_parts[hi], result);
122123 result = rewriter.create <arith::AddFOp>(loc, result, dotOp.getC ());
123124
124125 rewriter.replaceOp (dotOp, result);
0 commit comments