Skip to content

Commit 49975c1

Browse files
committed
more cleanup
1 parent cf586ae commit 49975c1

File tree

1 file changed

+43
-42
lines changed

1 file changed

+43
-42
lines changed

lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ auto SplitF32(Value input, unsigned N, PatternRewriter &rewriter)
3737
return split_inputs;
3838
}
3939

40+
/// Emits a DotOp computing `acc + lhs * rhs` with full IEEE input precision
/// and no imprecise accumulation (`maxNumImpreciseAcc == 0`).
///
/// This is the building block used by the BF16xN emulation below: every
/// partial dot over the split operands is accumulated without any further
/// precision loss.
///
/// \param rewriter pattern rewriter used to create the op.
/// \param lhs      left-hand operand tensor (provides the location).
/// \param rhs      right-hand operand tensor.
/// \param acc      accumulator; its type determines the result type.
/// \return the result Value of the newly created DotOp.
Value IEEEDot(PatternRewriter &rewriter, Value lhs, Value rhs, Value acc) {
  return rewriter.create<DotOp>(lhs.getLoc(), lhs, rhs, acc,
                                /*inputPrecision=*/InputPrecision::IEEE,
                                /*maxNumImpreciseAcc=*/0);
}
45+
4046
auto getBF16Count(triton::InputPrecision precision) -> unsigned {
4147
switch (precision) {
4248
default:
@@ -51,7 +57,9 @@ auto getBF16Count(triton::InputPrecision precision) -> unsigned {
5157
}
5258
}
5359

54-
// Implement 3xBF16 https://arxiv.org/abs/1904.06376
60+
// Implements 3xBF16 https://arxiv.org/abs/1904.06376
61+
// See also https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152
62+
// As well as https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330
5563
struct BF16xN : public OpRewritePattern<DotOp> {
5664
using OpRewritePattern::OpRewritePattern;
5765

@@ -63,62 +71,55 @@ struct BF16xN : public OpRewritePattern<DotOp> {
6371
const unsigned lo = 2;
6472
const unsigned N = getBF16Count(dotOp.getInputPrecision());
6573
Location loc = dotOp.getLoc();
74+
auto typeA = dotOp.getA().getType();
75+
auto typeB = dotOp.getB().getType();
6676

67-
// Checks for FP32 inputs and BF16 InputPrecision
68-
if (N == 0)
77+
if (!cast<RankedTensorType>(typeA).getElementType().isF32() ||
78+
!cast<RankedTensorType>(typeB).getElementType().isF32() || !N)
6979
return failure();
70-
for (auto type : {dotOp.getA().getType(), dotOp.getB().getType()})
71-
if (!cast<RankedTensorType>(type).getElementType().isF32())
72-
return failure();
73-
74-
// Helper Lambdas
75-
auto dot = [&](Value a, Value b, Value c) -> Value {
76-
return rewriter.create<DotOp>(loc, c.getType(), a, b, c,
77-
InputPrecision::BF16,
78-
dotOp.getMaxNumImpreciseAcc());
79-
};
80-
auto zeroLike = [&]() -> Value {
80+
81+
// Aux functions
82+
auto zeroLike = [&](Value c) -> Value {
8183
return rewriter.create<SplatOp>(
82-
loc, dotOp.getC().getType(),
83-
rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(0)));
84+
dotOp->getLoc(), c.getType(),
85+
rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
86+
rewriter.getF32FloatAttr(0)));
8487
};
8588
auto replaceNansWithZeros = [&](Value value) -> Value {
86-
auto isNaN = rewriter.create<arith::CmpFOp>(
87-
loc, arith::CmpFPredicate::UNO, value, value);
88-
return rewriter.create<arith::SelectOp>(loc, isNaN, zeroLike(), value);
89+
auto nans = rewriter.create<arith::CmpFOp>(
90+
dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
91+
auto zero = zeroLike(value);
92+
return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
93+
value);
8994
};
9095

9196
// Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator
92-
auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter);
93-
auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter);
94-
auto result = zeroLike();
97+
const auto lhs_parts = SplitF32(dotOp.getA(), N, rewriter);
98+
const auto rhs_parts = SplitF32(dotOp.getB(), N, rewriter);
99+
auto result = zeroLike(dotOp.getC());
95100

96-
if (dotOp.getInputPrecision() == InputPrecision::BF16x9) {
97-
result = dot(lhs_parts[lo], rhs_parts[lo], result);
98-
result = dot(lhs_parts[mid], rhs_parts[lo], result);
99-
result = dot(lhs_parts[lo], rhs_parts[mid], result);
100-
101-
// Identical to BF16x6 handling code:
102-
result = dot(lhs_parts[mid], rhs_parts[mid], result);
103-
104-
result = dot(lhs_parts[lo], rhs_parts[hi], result);
105-
result = dot(lhs_parts[hi], rhs_parts[lo], result);
101+
switch (dotOp.getInputPrecision()) {
102+
default:
103+
assert(false && "BF16DotTCPass expects BF16x9, BF16x6 or BF16x3");
104+
return failure();
105+
case InputPrecision::BF16x9:
106+
result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[lo], result);
107+
result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[lo], result);
108+
result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[mid], result);
106109

107-
} else if (dotOp.getInputPrecision() == InputPrecision::BF16x6) {
108-
result = dot(lhs_parts[mid], rhs_parts[mid], result);
110+
case InputPrecision::BF16x6:
111+
result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[mid], result);
109112

110-
result = dot(lhs_parts[lo], rhs_parts[hi], result);
111-
result = dot(lhs_parts[hi], rhs_parts[lo], result);
112-
}
113+
result = IEEEDot(rewriter, lhs_parts[lo], rhs_parts[hi], result);
114+
result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[lo], result);
113115

114-
// BF16x3, BF16x6, BF16x9 all need this
115-
if (dotOp.getInputPrecision() != InputPrecision::BF16) {
116-
result = dot(lhs_parts[mid], rhs_parts[hi], result);
117-
result = dot(lhs_parts[hi], rhs_parts[mid], result);
116+
case InputPrecision::BF16x3:
117+
result = IEEEDot(rewriter, lhs_parts[mid], rhs_parts[hi], result);
118+
result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[mid], result);
118119
}
119120

120121
result = replaceNansWithZeros(result);
121-
result = dot(lhs_parts[hi], rhs_parts[hi], result);
122+
result = IEEEDot(rewriter, lhs_parts[hi], rhs_parts[hi], result);
122123
result = rewriter.create<arith::AddFOp>(loc, result, dotOp.getC());
123124

124125
rewriter.replaceOp(dotOp, result);

0 commit comments

Comments
 (0)