@@ -873,6 +873,243 @@ populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
873873// =================== Common AIE Canonicalization Passes =====================//
874874// ============================================================================//
875875
876+ // ===----------------------------------------------------------------------===//
877+ // BF16 Emulation: Emulate f32 vector arithmetic using bf16 operations.
878+ // ===----------------------------------------------------------------------===//
879+
880+ // Smart truncation helper: if the value was produced by arith.extf from bf16,
881+ // reuse the bf16 source directly to avoid redundant extf->truncf chains.
882+ static Value smartTruncF32ToBF16 (PatternRewriter &rewriter, Location loc,
883+ Value val, Type bf16Type) {
884+ if (auto extfOp = val.getDefiningOp <arith::ExtFOp>()) {
885+ if (extfOp.getIn ().getType () == bf16Type)
886+ return extfOp.getIn ();
887+ }
888+ return arith::TruncFOp::create (rewriter, loc, bf16Type, val);
889+ }
890+
891+ // Smart truncation for scalar values (used by reduction patterns).
892+ static Value smartTruncScalarF32ToBF16 (PatternRewriter &rewriter, Location loc,
893+ Value val) {
894+ Type bf16Ty = rewriter.getBF16Type ();
895+ if (auto extfOp = val.getDefiningOp <arith::ExtFOp>()) {
896+ if (extfOp.getIn ().getType () == bf16Ty)
897+ return extfOp.getIn ();
898+ }
899+ return arith::TruncFOp::create (rewriter, loc, bf16Ty, val);
900+ }
901+
902+ // / Pattern to emulate f32 binary vector arithmetic ops in bf16.
903+ // / For an op like: %r = arith.addf %a, %b : vector<16xf32>
904+ // / Produces:
905+ // / %a_bf16 = arith.truncf %a : vector<16xf32> to vector<16xbf16>
906+ // / %b_bf16 = arith.truncf %b : vector<16xf32> to vector<16xbf16>
907+ // / %r_bf16 = arith.addf %a_bf16, %b_bf16 : vector<16xbf16>
908+ // / %r = arith.extf %r_bf16 : vector<16xbf16> to vector<16xf32>
909+ template <typename OpTy>
910+ struct EmulateBinaryF32InBF16Pattern : public OpRewritePattern <OpTy> {
911+ using OpRewritePattern<OpTy>::OpRewritePattern;
912+
913+ LogicalResult matchAndRewrite (OpTy op,
914+ PatternRewriter &rewriter) const override {
915+ auto resultType = dyn_cast<VectorType>(op.getType ());
916+ if (!resultType || !resultType.getElementType ().isF32 ())
917+ return failure ();
918+
919+ Location loc = op.getLoc ();
920+ auto bf16VecType =
921+ VectorType::get (resultType.getShape (), rewriter.getBF16Type ());
922+
923+ Value lhsBF16 =
924+ smartTruncF32ToBF16 (rewriter, loc, op.getLhs (), bf16VecType);
925+ Value rhsBF16 =
926+ smartTruncF32ToBF16 (rewriter, loc, op.getRhs (), bf16VecType);
927+
928+ Value newResult =
929+ OpTy::create (rewriter, loc, bf16VecType, lhsBF16, rhsBF16);
930+ auto extOp = arith::ExtFOp::create (rewriter, loc, resultType, newResult);
931+ rewriter.replaceOp (op, extOp);
932+ return success ();
933+ }
934+ };
935+
936+ // / Pattern to emulate f32 comparison ops in bf16.
937+ // / Result type stays vector<Nxi1>, only operands are truncated.
938+ struct EmulateCmpFF32InBF16Pattern : public OpRewritePattern <arith::CmpFOp> {
939+ using OpRewritePattern::OpRewritePattern;
940+
941+ LogicalResult matchAndRewrite (arith::CmpFOp op,
942+ PatternRewriter &rewriter) const override {
943+ auto lhsType = dyn_cast<VectorType>(op.getLhs ().getType ());
944+ if (!lhsType || !lhsType.getElementType ().isF32 ())
945+ return failure ();
946+
947+ Location loc = op.getLoc ();
948+ auto bf16VecType =
949+ VectorType::get (lhsType.getShape (), rewriter.getBF16Type ());
950+
951+ Value lhsBF16 =
952+ smartTruncF32ToBF16 (rewriter, loc, op.getLhs (), bf16VecType);
953+ Value rhsBF16 =
954+ smartTruncF32ToBF16 (rewriter, loc, op.getRhs (), bf16VecType);
955+
956+ rewriter.replaceOpWithNewOp <arith::CmpFOp>(op, op.getPredicate (), lhsBF16,
957+ rhsBF16);
958+ return success ();
959+ }
960+ };
961+
962+ // / Pattern to emulate f32 select ops in bf16.
963+ // / Condition stays vector<Nxi1>, true/false values are truncated.
964+ struct EmulateSelectF32InBF16Pattern
965+ : public OpRewritePattern<arith::SelectOp> {
966+ using OpRewritePattern::OpRewritePattern;
967+
968+ LogicalResult matchAndRewrite (arith::SelectOp op,
969+ PatternRewriter &rewriter) const override {
970+ auto resultType = dyn_cast<VectorType>(op.getType ());
971+ if (!resultType || !resultType.getElementType ().isF32 ())
972+ return failure ();
973+
974+ Location loc = op.getLoc ();
975+ auto bf16VecType =
976+ VectorType::get (resultType.getShape (), rewriter.getBF16Type ());
977+
978+ Value trueValBF16 =
979+ smartTruncF32ToBF16 (rewriter, loc, op.getTrueValue (), bf16VecType);
980+ Value falseValBF16 =
981+ smartTruncF32ToBF16 (rewriter, loc, op.getFalseValue (), bf16VecType);
982+
983+ Value newResult = arith::SelectOp::create (rewriter, loc, op.getCondition (),
984+ trueValBF16, falseValBF16);
985+ auto extOp = arith::ExtFOp::create (rewriter, loc, resultType, newResult);
986+ rewriter.replaceOp (op, extOp);
987+ return success ();
988+ }
989+ };
990+
991+ // / Pattern to emulate f32 vector.fma in bf16.
992+ // / All three operands (lhs, rhs, acc) are truncated.
993+ struct EmulateFMAF32InBF16Pattern : public OpRewritePattern <vector::FMAOp> {
994+ using OpRewritePattern::OpRewritePattern;
995+
996+ LogicalResult matchAndRewrite (vector::FMAOp op,
997+ PatternRewriter &rewriter) const override {
998+ auto resultType = dyn_cast<VectorType>(op.getType ());
999+ if (!resultType || !resultType.getElementType ().isF32 ())
1000+ return failure ();
1001+
1002+ Location loc = op.getLoc ();
1003+ auto bf16VecType =
1004+ VectorType::get (resultType.getShape (), rewriter.getBF16Type ());
1005+
1006+ Value lhsBF16 =
1007+ smartTruncF32ToBF16 (rewriter, loc, op.getLhs (), bf16VecType);
1008+ Value rhsBF16 =
1009+ smartTruncF32ToBF16 (rewriter, loc, op.getRhs (), bf16VecType);
1010+ Value accBF16 =
1011+ smartTruncF32ToBF16 (rewriter, loc, op.getAcc (), bf16VecType);
1012+
1013+ Value newResult =
1014+ vector::FMAOp::create (rewriter, loc, lhsBF16, rhsBF16, accBF16);
1015+ auto extOp = arith::ExtFOp::create (rewriter, loc, resultType, newResult);
1016+ rewriter.replaceOp (op, extOp);
1017+ return success ();
1018+ }
1019+ };
1020+
1021+ // / Pattern to emulate f32 unary vector ops in bf16.
1022+ template <typename OpTy>
1023+ struct EmulateUnaryF32InBF16Pattern : public OpRewritePattern <OpTy> {
1024+ using OpRewritePattern<OpTy>::OpRewritePattern;
1025+
1026+ LogicalResult matchAndRewrite (OpTy op,
1027+ PatternRewriter &rewriter) const override {
1028+ auto resultType = dyn_cast<VectorType>(op.getType ());
1029+ if (!resultType || !resultType.getElementType ().isF32 ())
1030+ return failure ();
1031+
1032+ Location loc = op.getLoc ();
1033+ auto bf16VecType =
1034+ VectorType::get (resultType.getShape (), rewriter.getBF16Type ());
1035+
1036+ Value inputBF16 =
1037+ smartTruncF32ToBF16 (rewriter, loc, op->getOperand (0 ), bf16VecType);
1038+
1039+ Value newResult = OpTy::create (rewriter, loc, bf16VecType, inputBF16);
1040+ auto extOp = arith::ExtFOp::create (rewriter, loc, resultType, newResult);
1041+ rewriter.replaceOp (op, extOp);
1042+ return success ();
1043+ }
1044+ };
1045+
1046+ // / Pattern to emulate f32 vector.reduction in bf16.
1047+ struct EmulateReductionF32InBF16Pattern
1048+ : public OpRewritePattern<vector::ReductionOp> {
1049+ using OpRewritePattern::OpRewritePattern;
1050+
1051+ LogicalResult matchAndRewrite (vector::ReductionOp op,
1052+ PatternRewriter &rewriter) const override {
1053+ if (!op.getType ().isF32 ())
1054+ return failure ();
1055+ auto vectorType = dyn_cast<VectorType>(op.getVector ().getType ());
1056+ if (!vectorType || !vectorType.getElementType ().isF32 ())
1057+ return failure ();
1058+
1059+ Location loc = op.getLoc ();
1060+ auto bf16VecType =
1061+ VectorType::get (vectorType.getShape (), rewriter.getBF16Type ());
1062+
1063+ Value vectorBF16 =
1064+ smartTruncF32ToBF16 (rewriter, loc, op.getVector (), bf16VecType);
1065+
1066+ Value accBF16 = nullptr ;
1067+ if (op.getAcc ())
1068+ accBF16 = smartTruncScalarF32ToBF16 (rewriter, loc, op.getAcc ());
1069+
1070+ Value newResult = vector::ReductionOp::create (rewriter, loc, op.getKind (),
1071+ vectorBF16, accBF16);
1072+ auto extOp =
1073+ arith::ExtFOp::create (rewriter, loc, rewriter.getF32Type (), newResult);
1074+ rewriter.replaceOp (op, extOp);
1075+ return success ();
1076+ }
1077+ };
1078+
1079+ struct BF16EmulationPass
1080+ : public PassWrapper<BF16EmulationPass, OperationPass<>> {
1081+
1082+ void runOnOperation () override {
1083+ auto *op = getOperation ();
1084+ MLIRContext *context = &getContext ();
1085+ RewritePatternSet patterns (context);
1086+
1087+ // Binary arithmetic ops
1088+ patterns.add <EmulateBinaryF32InBF16Pattern<arith::AddFOp>,
1089+ EmulateBinaryF32InBF16Pattern<arith::SubFOp>,
1090+ EmulateBinaryF32InBF16Pattern<arith::MulFOp>,
1091+ EmulateBinaryF32InBF16Pattern<arith::MaximumFOp>,
1092+ EmulateBinaryF32InBF16Pattern<arith::MinimumFOp>>(context);
1093+
1094+ // Note: arith.divf is NOT demoted because bf16 vector divf is unsupported
1095+ // on all AIE targets (Peano does not legalize G_FDIV on <16 x s16>).
1096+
1097+ // Special-case ops
1098+ patterns.add <EmulateCmpFF32InBF16Pattern, EmulateSelectF32InBF16Pattern,
1099+ EmulateFMAF32InBF16Pattern, EmulateReductionF32InBF16Pattern>(
1100+ context);
1101+
1102+ // Unary ops
1103+ patterns.add <EmulateUnaryF32InBF16Pattern<arith::NegFOp>>(context);
1104+
1105+ (void )applyPatternsGreedily (op, std::move (patterns));
1106+ }
1107+ };
1108+
1109+ std::unique_ptr<::mlir::Pass> xilinx::aievec::createBF16EmulationPass () {
1110+ return std::make_unique<BF16EmulationPass>();
1111+ }
1112+
8761113struct VectorBroadcastLoweringPass
8771114 : public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
8781115
@@ -1051,6 +1288,11 @@ void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
10511288 // Add `Vector` code canonicalization passes
10521289 // TODO: Add passes to unroll vector with unsupported types
10531290 // TODO: Add passes to split vectors that won't fit in registers
1291+
1292+ // If bf16-emulation is enabled, demote f32 vector arithmetic to bf16 first.
1293+ if (options.enableBF16Emulation )
1294+ pm.addPass (createBF16EmulationPass ());
1295+
10541296 if (decodeTargetBackend (options.targetBackend ) == TargetBackend::LLVMIR)
10551297 pm.addPass (createReorderOperationsPass ());
10561298 pm.addPass (createCopyRemovalPass ());
0 commit comments