Skip to content

Commit d4f709d

Browse files
erwei-xilinx and claude authored
Add bf16-emulation option to convert-vector-to-aievec pipeline (#2942)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1bb6106 commit d4f709d

3 files changed

Lines changed: 448 additions & 0 deletions

File tree

include/aie/Dialect/AIEVec/Pipelines/Passes.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ struct CanonicalizeVectorForAIEVecOptions
5050
"will determine the aievec operations used to convert "
5151
"from vector dialect."),
5252
llvm::cl::init("cpp")};
53+
PassOptions::Option<bool> enableBF16Emulation{
54+
*this, "bf16-emulation",
55+
llvm::cl::desc(
56+
"Emulate f32 vector arithmetic using bf16 operations. Inserts "
57+
"arith.truncf/arith.extf around f32 vector ops to compute in bf16. "
58+
"Trades precision for performance."),
59+
llvm::cl::init(false)};
5360
};
5461

5562
/// Options for the "lower-vector-to-aievec" pipeline.
@@ -118,6 +125,13 @@ struct ConvertVectorToAIEVecOptions
118125
"will determine the aievec operations used to convert "
119126
"from vector dialect."),
120127
llvm::cl::init("cpp")};
128+
PassOptions::Option<bool> enableBF16Emulation{
129+
*this, "bf16-emulation",
130+
llvm::cl::desc(
131+
"Emulate f32 vector arithmetic using bf16 operations. Inserts "
132+
"arith.truncf/arith.extf around f32 vector ops to compute in bf16. "
133+
"Trades precision for performance."),
134+
llvm::cl::init(false)};
121135

122136
mlir::LogicalResult parseFromString(mlir::StringRef options) {
123137
auto res = PassPipelineOptions::parseFromString(options);
@@ -126,6 +140,7 @@ struct ConvertVectorToAIEVecOptions
126140
lowerOptions.targetBackend = targetBackend;
127141
canonicalizeOptions.aieTarget = aieTarget;
128142
canonicalizeOptions.targetBackend = targetBackend;
143+
canonicalizeOptions.enableBF16Emulation = enableBF16Emulation;
129144
optimizeOptions.aieTarget = aieTarget;
130145
optimizeOptions.targetBackend = targetBackend;
131146
optimizeOptions.shiftParam = shiftParam;
@@ -190,6 +205,10 @@ void buildDynamicSizeNoImplicitBroadcastPass(mlir::OpPassManager &pm);
190205
/// operations for AIE2p targets.
191206
std::unique_ptr<::mlir::Pass> createSplitVectorLoadUpsChainsPass();
192207

208+
/// Create a pass that emulates f32 vector arithmetic using bf16 operations.
209+
/// Inserts arith.truncf/arith.extf around f32 vector ops to compute in bf16.
210+
std::unique_ptr<::mlir::Pass> createBF16EmulationPass();
211+
193212
} // namespace aievec
194213
} // namespace xilinx
195214

lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,243 @@ populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
873873
//=================== Common AIE Canonicalization Passes =====================//
874874
//============================================================================//
875875

876+
//===----------------------------------------------------------------------===//
877+
// BF16 Emulation: Emulate f32 vector arithmetic using bf16 operations.
878+
//===----------------------------------------------------------------------===//
879+
880+
// Smart truncation helper: if the value was produced by arith.extf from bf16,
881+
// reuse the bf16 source directly to avoid redundant extf->truncf chains.
882+
static Value smartTruncF32ToBF16(PatternRewriter &rewriter, Location loc,
883+
Value val, Type bf16Type) {
884+
if (auto extfOp = val.getDefiningOp<arith::ExtFOp>()) {
885+
if (extfOp.getIn().getType() == bf16Type)
886+
return extfOp.getIn();
887+
}
888+
return arith::TruncFOp::create(rewriter, loc, bf16Type, val);
889+
}
890+
891+
// Smart truncation for scalar values (used by reduction patterns).
892+
static Value smartTruncScalarF32ToBF16(PatternRewriter &rewriter, Location loc,
893+
Value val) {
894+
Type bf16Ty = rewriter.getBF16Type();
895+
if (auto extfOp = val.getDefiningOp<arith::ExtFOp>()) {
896+
if (extfOp.getIn().getType() == bf16Ty)
897+
return extfOp.getIn();
898+
}
899+
return arith::TruncFOp::create(rewriter, loc, bf16Ty, val);
900+
}
901+
902+
/// Pattern to emulate f32 binary vector arithmetic ops in bf16.
903+
/// For an op like: %r = arith.addf %a, %b : vector<16xf32>
904+
/// Produces:
905+
/// %a_bf16 = arith.truncf %a : vector<16xf32> to vector<16xbf16>
906+
/// %b_bf16 = arith.truncf %b : vector<16xf32> to vector<16xbf16>
907+
/// %r_bf16 = arith.addf %a_bf16, %b_bf16 : vector<16xbf16>
908+
/// %r = arith.extf %r_bf16 : vector<16xbf16> to vector<16xf32>
909+
template <typename OpTy>
910+
struct EmulateBinaryF32InBF16Pattern : public OpRewritePattern<OpTy> {
911+
using OpRewritePattern<OpTy>::OpRewritePattern;
912+
913+
LogicalResult matchAndRewrite(OpTy op,
914+
PatternRewriter &rewriter) const override {
915+
auto resultType = dyn_cast<VectorType>(op.getType());
916+
if (!resultType || !resultType.getElementType().isF32())
917+
return failure();
918+
919+
Location loc = op.getLoc();
920+
auto bf16VecType =
921+
VectorType::get(resultType.getShape(), rewriter.getBF16Type());
922+
923+
Value lhsBF16 =
924+
smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
925+
Value rhsBF16 =
926+
smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
927+
928+
Value newResult =
929+
OpTy::create(rewriter, loc, bf16VecType, lhsBF16, rhsBF16);
930+
auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
931+
rewriter.replaceOp(op, extOp);
932+
return success();
933+
}
934+
};
935+
936+
/// Pattern to emulate f32 comparison ops in bf16.
937+
/// Result type stays vector<Nxi1>, only operands are truncated.
938+
struct EmulateCmpFF32InBF16Pattern : public OpRewritePattern<arith::CmpFOp> {
939+
using OpRewritePattern::OpRewritePattern;
940+
941+
LogicalResult matchAndRewrite(arith::CmpFOp op,
942+
PatternRewriter &rewriter) const override {
943+
auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
944+
if (!lhsType || !lhsType.getElementType().isF32())
945+
return failure();
946+
947+
Location loc = op.getLoc();
948+
auto bf16VecType =
949+
VectorType::get(lhsType.getShape(), rewriter.getBF16Type());
950+
951+
Value lhsBF16 =
952+
smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
953+
Value rhsBF16 =
954+
smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
955+
956+
rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, op.getPredicate(), lhsBF16,
957+
rhsBF16);
958+
return success();
959+
}
960+
};
961+
962+
/// Pattern to emulate f32 select ops in bf16.
963+
/// Condition stays vector<Nxi1>, true/false values are truncated.
964+
struct EmulateSelectF32InBF16Pattern
965+
: public OpRewritePattern<arith::SelectOp> {
966+
using OpRewritePattern::OpRewritePattern;
967+
968+
LogicalResult matchAndRewrite(arith::SelectOp op,
969+
PatternRewriter &rewriter) const override {
970+
auto resultType = dyn_cast<VectorType>(op.getType());
971+
if (!resultType || !resultType.getElementType().isF32())
972+
return failure();
973+
974+
Location loc = op.getLoc();
975+
auto bf16VecType =
976+
VectorType::get(resultType.getShape(), rewriter.getBF16Type());
977+
978+
Value trueValBF16 =
979+
smartTruncF32ToBF16(rewriter, loc, op.getTrueValue(), bf16VecType);
980+
Value falseValBF16 =
981+
smartTruncF32ToBF16(rewriter, loc, op.getFalseValue(), bf16VecType);
982+
983+
Value newResult = arith::SelectOp::create(rewriter, loc, op.getCondition(),
984+
trueValBF16, falseValBF16);
985+
auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
986+
rewriter.replaceOp(op, extOp);
987+
return success();
988+
}
989+
};
990+
991+
/// Pattern to emulate f32 vector.fma in bf16.
992+
/// All three operands (lhs, rhs, acc) are truncated.
993+
struct EmulateFMAF32InBF16Pattern : public OpRewritePattern<vector::FMAOp> {
994+
using OpRewritePattern::OpRewritePattern;
995+
996+
LogicalResult matchAndRewrite(vector::FMAOp op,
997+
PatternRewriter &rewriter) const override {
998+
auto resultType = dyn_cast<VectorType>(op.getType());
999+
if (!resultType || !resultType.getElementType().isF32())
1000+
return failure();
1001+
1002+
Location loc = op.getLoc();
1003+
auto bf16VecType =
1004+
VectorType::get(resultType.getShape(), rewriter.getBF16Type());
1005+
1006+
Value lhsBF16 =
1007+
smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
1008+
Value rhsBF16 =
1009+
smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
1010+
Value accBF16 =
1011+
smartTruncF32ToBF16(rewriter, loc, op.getAcc(), bf16VecType);
1012+
1013+
Value newResult =
1014+
vector::FMAOp::create(rewriter, loc, lhsBF16, rhsBF16, accBF16);
1015+
auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
1016+
rewriter.replaceOp(op, extOp);
1017+
return success();
1018+
}
1019+
};
1020+
1021+
/// Pattern to emulate f32 unary vector ops in bf16.
1022+
template <typename OpTy>
1023+
struct EmulateUnaryF32InBF16Pattern : public OpRewritePattern<OpTy> {
1024+
using OpRewritePattern<OpTy>::OpRewritePattern;
1025+
1026+
LogicalResult matchAndRewrite(OpTy op,
1027+
PatternRewriter &rewriter) const override {
1028+
auto resultType = dyn_cast<VectorType>(op.getType());
1029+
if (!resultType || !resultType.getElementType().isF32())
1030+
return failure();
1031+
1032+
Location loc = op.getLoc();
1033+
auto bf16VecType =
1034+
VectorType::get(resultType.getShape(), rewriter.getBF16Type());
1035+
1036+
Value inputBF16 =
1037+
smartTruncF32ToBF16(rewriter, loc, op->getOperand(0), bf16VecType);
1038+
1039+
Value newResult = OpTy::create(rewriter, loc, bf16VecType, inputBF16);
1040+
auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
1041+
rewriter.replaceOp(op, extOp);
1042+
return success();
1043+
}
1044+
};
1045+
1046+
/// Pattern to emulate f32 vector.reduction in bf16.
1047+
struct EmulateReductionF32InBF16Pattern
1048+
: public OpRewritePattern<vector::ReductionOp> {
1049+
using OpRewritePattern::OpRewritePattern;
1050+
1051+
LogicalResult matchAndRewrite(vector::ReductionOp op,
1052+
PatternRewriter &rewriter) const override {
1053+
if (!op.getType().isF32())
1054+
return failure();
1055+
auto vectorType = dyn_cast<VectorType>(op.getVector().getType());
1056+
if (!vectorType || !vectorType.getElementType().isF32())
1057+
return failure();
1058+
1059+
Location loc = op.getLoc();
1060+
auto bf16VecType =
1061+
VectorType::get(vectorType.getShape(), rewriter.getBF16Type());
1062+
1063+
Value vectorBF16 =
1064+
smartTruncF32ToBF16(rewriter, loc, op.getVector(), bf16VecType);
1065+
1066+
Value accBF16 = nullptr;
1067+
if (op.getAcc())
1068+
accBF16 = smartTruncScalarF32ToBF16(rewriter, loc, op.getAcc());
1069+
1070+
Value newResult = vector::ReductionOp::create(rewriter, loc, op.getKind(),
1071+
vectorBF16, accBF16);
1072+
auto extOp =
1073+
arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), newResult);
1074+
rewriter.replaceOp(op, extOp);
1075+
return success();
1076+
}
1077+
};
1078+
1079+
struct BF16EmulationPass
1080+
: public PassWrapper<BF16EmulationPass, OperationPass<>> {
1081+
1082+
void runOnOperation() override {
1083+
auto *op = getOperation();
1084+
MLIRContext *context = &getContext();
1085+
RewritePatternSet patterns(context);
1086+
1087+
// Binary arithmetic ops
1088+
patterns.add<EmulateBinaryF32InBF16Pattern<arith::AddFOp>,
1089+
EmulateBinaryF32InBF16Pattern<arith::SubFOp>,
1090+
EmulateBinaryF32InBF16Pattern<arith::MulFOp>,
1091+
EmulateBinaryF32InBF16Pattern<arith::MaximumFOp>,
1092+
EmulateBinaryF32InBF16Pattern<arith::MinimumFOp>>(context);
1093+
1094+
// Note: arith.divf is NOT demoted because bf16 vector divf is unsupported
1095+
// on all AIE targets (Peano does not legalize G_FDIV on <16 x s16>).
1096+
1097+
// Special-case ops
1098+
patterns.add<EmulateCmpFF32InBF16Pattern, EmulateSelectF32InBF16Pattern,
1099+
EmulateFMAF32InBF16Pattern, EmulateReductionF32InBF16Pattern>(
1100+
context);
1101+
1102+
// Unary ops
1103+
patterns.add<EmulateUnaryF32InBF16Pattern<arith::NegFOp>>(context);
1104+
1105+
(void)applyPatternsGreedily(op, std::move(patterns));
1106+
}
1107+
};
1108+
1109+
std::unique_ptr<::mlir::Pass> xilinx::aievec::createBF16EmulationPass() {
1110+
return std::make_unique<BF16EmulationPass>();
1111+
}
1112+
8761113
struct VectorBroadcastLoweringPass
8771114
: public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
8781115

@@ -1051,6 +1288,11 @@ void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
10511288
// Add `Vector` code canonicalization passes
10521289
// TODO: Add passes to unroll vector with unsupported types
10531290
// TODO: Add passes to split vectors that won't fit in registers
1291+
1292+
// If bf16-emulation is enabled, demote f32 vector arithmetic to bf16 first.
1293+
if (options.enableBF16Emulation)
1294+
pm.addPass(createBF16EmulationPass());
1295+
10541296
if (decodeTargetBackend(options.targetBackend) == TargetBackend::LLVMIR)
10551297
pm.addPass(createReorderOperationsPass());
10561298
pm.addPass(createCopyRemovalPass());

0 commit comments

Comments
 (0)