diff --git a/include/aie/Dialect/AIEVec/Pipelines/Passes.h b/include/aie/Dialect/AIEVec/Pipelines/Passes.h
index 182e7a887ef..f2a9d8fae0b 100644
--- a/include/aie/Dialect/AIEVec/Pipelines/Passes.h
+++ b/include/aie/Dialect/AIEVec/Pipelines/Passes.h
@@ -50,6 +50,13 @@ struct CanonicalizeVectorForAIEVecOptions
                        "will determine the aievec operations used to convert "
                        "from vector dialect."),
         llvm::cl::init("cpp")};
+  PassOptions::Option<bool> enableBF16Emulation{
+      *this, "bf16-emulation",
+      llvm::cl::desc(
+          "Emulate f32 vector arithmetic using bf16 operations. Inserts "
+          "arith.truncf/arith.extf around f32 vector ops to compute in bf16. "
+          "Trades precision for performance."),
+      llvm::cl::init(false)};
 };
 
 /// Options for the "lower-vector-to-aievec" pipeline.
@@ -118,6 +125,13 @@ struct ConvertVectorToAIEVecOptions
                        "will determine the aievec operations used to convert "
                        "from vector dialect."),
         llvm::cl::init("cpp")};
+  PassOptions::Option<bool> enableBF16Emulation{
+      *this, "bf16-emulation",
+      llvm::cl::desc(
+          "Emulate f32 vector arithmetic using bf16 operations. Inserts "
+          "arith.truncf/arith.extf around f32 vector ops to compute in bf16. "
+          "Trades precision for performance."),
+      llvm::cl::init(false)};
 
   mlir::LogicalResult parseFromString(mlir::StringRef options) {
     auto res = PassPipelineOptions::parseFromString(options);
@@ -126,6 +140,7 @@
     lowerOptions.targetBackend = targetBackend;
     canonicalizeOptions.aieTarget = aieTarget;
     canonicalizeOptions.targetBackend = targetBackend;
+    canonicalizeOptions.enableBF16Emulation = enableBF16Emulation;
     optimizeOptions.aieTarget = aieTarget;
     optimizeOptions.targetBackend = targetBackend;
     optimizeOptions.shiftParam = shiftParam;
@@ -190,6 +205,10 @@ void buildDynamicSizeNoImplicitBroadcastPass(mlir::OpPassManager &pm);
 /// operations for AIE2p targets.
 std::unique_ptr<::mlir::Pass> createSplitVectorLoadUpsChainsPass();
 
+/// Create a pass that emulates f32 vector arithmetic using bf16 operations.
+/// Inserts arith.truncf/arith.extf around f32 vector ops to compute in bf16.
+std::unique_ptr<::mlir::Pass> createBF16EmulationPass();
+
 } // namespace aievec
 } // namespace xilinx
diff --git a/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
index 8446ebf6055..e61c0df5c4b 100644
--- a/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
+++ b/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
@@ -873,6 +873,243 @@ populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
 //=================== Common AIE Canonicalization Passes =====================//
 //============================================================================//
 
+//===----------------------------------------------------------------------===//
+// BF16 Emulation: Emulate f32 vector arithmetic using bf16 operations.
+//===----------------------------------------------------------------------===//
+
+// Smart truncation helper: if the value was produced by arith.extf from bf16,
+// reuse the bf16 source directly to avoid redundant extf->truncf chains.
+static Value smartTruncF32ToBF16(PatternRewriter &rewriter, Location loc,
+                                 Value val, Type bf16Type) {
+  if (auto extfOp = val.getDefiningOp<arith::ExtFOp>()) {
+    if (extfOp.getIn().getType() == bf16Type)
+      return extfOp.getIn();
+  }
+  return arith::TruncFOp::create(rewriter, loc, bf16Type, val);
+}
+
+// Smart truncation for scalar values (used by reduction patterns).
+static Value smartTruncScalarF32ToBF16(PatternRewriter &rewriter, Location loc,
+                                       Value val) {
+  Type bf16Ty = rewriter.getBF16Type();
+  if (auto extfOp = val.getDefiningOp<arith::ExtFOp>()) {
+    if (extfOp.getIn().getType() == bf16Ty)
+      return extfOp.getIn();
+  }
+  return arith::TruncFOp::create(rewriter, loc, bf16Ty, val);
+}
+
+/// Pattern to emulate f32 binary vector arithmetic ops in bf16.
+/// For an op like: %r = arith.addf %a, %b : vector<16xf32>
+/// Produces:
+/// %a_bf16 = arith.truncf %a : vector<16xf32> to vector<16xbf16>
+/// %b_bf16 = arith.truncf %b : vector<16xf32> to vector<16xbf16>
+/// %r_bf16 = arith.addf %a_bf16, %b_bf16 : vector<16xbf16>
+/// %r = arith.extf %r_bf16 : vector<16xbf16> to vector<16xf32>
+template <typename OpTy>
+struct EmulateBinaryF32InBF16Pattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType || !resultType.getElementType().isF32())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto bf16VecType =
+        VectorType::get(resultType.getShape(), rewriter.getBF16Type());
+
+    Value lhsBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
+    Value rhsBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
+
+    Value newResult =
+        OpTy::create(rewriter, loc, bf16VecType, lhsBF16, rhsBF16);
+    auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
+    rewriter.replaceOp(op, extOp);
+    return success();
+  }
+};
+
+/// Pattern to emulate f32 comparison ops in bf16.
+/// Result type stays vector<i1>, only operands are truncated.
+struct EmulateCmpFF32InBF16Pattern : public OpRewritePattern<arith::CmpFOp> {
+  using OpRewritePattern<arith::CmpFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
+    if (!lhsType || !lhsType.getElementType().isF32())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto bf16VecType =
+        VectorType::get(lhsType.getShape(), rewriter.getBF16Type());
+
+    Value lhsBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
+    Value rhsBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
+
+    rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, op.getPredicate(), lhsBF16,
+                                               rhsBF16);
+    return success();
+  }
+};
+
+/// Pattern to emulate f32 select ops in bf16.
+/// Condition stays vector<i1>, true/false values are truncated.
+struct EmulateSelectF32InBF16Pattern
+    : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType || !resultType.getElementType().isF32())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto bf16VecType =
+        VectorType::get(resultType.getShape(), rewriter.getBF16Type());
+
+    Value trueValBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getTrueValue(), bf16VecType);
+    Value falseValBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getFalseValue(), bf16VecType);
+
+    Value newResult = arith::SelectOp::create(rewriter, loc, op.getCondition(),
+                                              trueValBF16, falseValBF16);
+    auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
+    rewriter.replaceOp(op, extOp);
+    return success();
+  }
+};
+
+/// Pattern to emulate f32 vector.fma in bf16.
+/// All three operands (lhs, rhs, acc) are truncated.
+struct EmulateFMAF32InBF16Pattern : public OpRewritePattern<vector::FMAOp> {
+  using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FMAOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType || !resultType.getElementType().isF32())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto bf16VecType =
+        VectorType::get(resultType.getShape(), rewriter.getBF16Type());
+
+    Value lhsBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getLhs(), bf16VecType);
+    Value rhsBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getRhs(), bf16VecType);
+    Value accBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getAcc(), bf16VecType);
+
+    Value newResult =
+        vector::FMAOp::create(rewriter, loc, lhsBF16, rhsBF16, accBF16);
+    auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
+    rewriter.replaceOp(op, extOp);
+    return success();
+  }
+};
+
+/// Pattern to emulate f32 unary vector ops in bf16.
+template <typename OpTy>
+struct EmulateUnaryF32InBF16Pattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType || !resultType.getElementType().isF32())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto bf16VecType =
+        VectorType::get(resultType.getShape(), rewriter.getBF16Type());
+
+    Value inputBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op->getOperand(0), bf16VecType);
+
+    Value newResult = OpTy::create(rewriter, loc, bf16VecType, inputBF16);
+    auto extOp = arith::ExtFOp::create(rewriter, loc, resultType, newResult);
+    rewriter.replaceOp(op, extOp);
+    return success();
+  }
+};
+
+/// Pattern to emulate f32 vector.reduction in bf16.
+struct EmulateReductionF32InBF16Pattern
+    : public OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern<vector::ReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getType().isF32())
+      return failure();
+    auto vectorType = dyn_cast<VectorType>(op.getVector().getType());
+    if (!vectorType || !vectorType.getElementType().isF32())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto bf16VecType =
+        VectorType::get(vectorType.getShape(), rewriter.getBF16Type());
+
+    Value vectorBF16 =
+        smartTruncF32ToBF16(rewriter, loc, op.getVector(), bf16VecType);
+
+    Value accBF16 = nullptr;
+    if (op.getAcc())
+      accBF16 = smartTruncScalarF32ToBF16(rewriter, loc, op.getAcc());
+
+    Value newResult = vector::ReductionOp::create(rewriter, loc, op.getKind(),
+                                                  vectorBF16, accBF16);
+    auto extOp =
+        arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), newResult);
+    rewriter.replaceOp(op, extOp);
+    return success();
+  }
+};
+
+struct BF16EmulationPass
+    : public PassWrapper<BF16EmulationPass, OperationPass<>> {
+
+  void runOnOperation() override {
+    auto *op = getOperation();
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+
+    // Binary arithmetic ops
+    patterns.add<EmulateBinaryF32InBF16Pattern<arith::AddFOp>,
+                 EmulateBinaryF32InBF16Pattern<arith::SubFOp>,
+                 EmulateBinaryF32InBF16Pattern<arith::MulFOp>,
+                 EmulateBinaryF32InBF16Pattern<arith::MaximumFOp>,
+                 EmulateBinaryF32InBF16Pattern<arith::MinimumFOp>>(context);
+
+    // Note: arith.divf is NOT demoted because bf16 vector divf is unsupported
+    // on all AIE targets (Peano does not legalize G_FDIV on <16 x s16>).
+
+    // Special-case ops
+    patterns.add<EmulateCmpFF32InBF16Pattern, EmulateSelectF32InBF16Pattern,
+                 EmulateFMAF32InBF16Pattern, EmulateReductionF32InBF16Pattern>(
+        context);
+
+    // Unary ops
+    patterns.add<EmulateUnaryF32InBF16Pattern<arith::NegFOp>>(context);
+
+    (void)applyPatternsGreedily(op, std::move(patterns));
+  }
+};
+
+std::unique_ptr<::mlir::Pass> xilinx::aievec::createBF16EmulationPass() {
+  return std::make_unique<BF16EmulationPass>();
+}
+
 struct VectorBroadcastLoweringPass
     : public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {
@@ -1051,6 +1288,11 @@ void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
   // Add `Vector` code canonicalization passes
   // TODO: Add passes to unroll vector with unsupported types
   // TODO: Add passes to split vectors that won't fit in registers
+
+  // If bf16-emulation is enabled, demote f32 vector arithmetic to bf16 first.
+  if (options.enableBF16Emulation)
+    pm.addPass(createBF16EmulationPass());
+
   if (decodeTargetBackend(options.targetBackend) == TargetBackend::LLVMIR)
     pm.addPass(createReorderOperationsPass());
   pm.addPass(createCopyRemovalPass());
diff --git a/test/Conversion/VectorToAIEVec/test-bf16-emulation.mlir b/test/Conversion/VectorToAIEVec/test-bf16-emulation.mlir
new file mode 100644
index 00000000000..2ce1c0d53e1
--- /dev/null
+++ b/test/Conversion/VectorToAIEVec/test-bf16-emulation.mlir
@@ -0,0 +1,187 @@
+//===- test-bf16-emulation.mlir - bf16 emulation of f32 ops --------*- MLIR -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Copyright (C) 2026, Advanced Micro Devices, Inc.
+//
+//===----------------------------------------------------------------------===//
+
+// Test the bf16-emulation option which demotes f32 vector arithmetic to bf16.
+
+// RUN: aie-opt %s -split-input-file --canonicalize-vector-for-aievec="aie-target=aie2 target-backend=llvmir bf16-emulation=true" | FileCheck %s
+
+// Test: basic addf demotion
+// CHECK-LABEL: func @test_addf
+// CHECK-SAME: (%[[A:.*]]: vector<16xf32>, %[[B:.*]]: vector<16xf32>)
+// CHECK: %[[A_BF16:.*]] = arith.truncf %[[A]] : vector<16xf32> to vector<16xbf16>
+// CHECK: %[[B_BF16:.*]] = arith.truncf %[[B]] : vector<16xf32> to vector<16xbf16>
+// CHECK: %[[RES_BF16:.*]] = arith.addf %[[A_BF16]], %[[B_BF16]] : vector<16xbf16>
+// CHECK: %[[RES:.*]] = arith.extf %[[RES_BF16]] : vector<16xbf16> to vector<16xf32>
+// CHECK: return %[[RES]]
+func.func @test_addf(%a: vector<16xf32>, %b: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.addf %a, %b : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// Test: basic mulf demotion
+// CHECK-LABEL: func @test_mulf
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.mulf {{.*}} : vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_mulf(%a: vector<16xf32>, %b: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.mulf %a, %b : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// Test: basic subf demotion
+// CHECK-LABEL: func @test_subf
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.subf {{.*}} : vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_subf(%a: vector<16xf32>, %b: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.subf %a, %b : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// Test: chain optimization - intermediate extf->truncf should be eliminated
+// CHECK-LABEL: func @test_chain_optimization
+// CHECK-SAME: (%[[A:.*]]: vector<16xf32>, %[[B:.*]]: vector<16xf32>, %[[C:.*]]: vector<16xf32>)
+// CHECK: %[[A_BF16:.*]] = arith.truncf %[[A]] : vector<16xf32> to vector<16xbf16>
+// CHECK: %[[B_BF16:.*]] = arith.truncf %[[B]] : vector<16xf32> to vector<16xbf16>
+// CHECK: %[[ADD:.*]] = arith.addf %[[A_BF16]], %[[B_BF16]] : vector<16xbf16>
+// No intermediate extf->truncf between add and mul:
+// CHECK: %[[C_BF16:.*]] = arith.truncf %[[C]] : vector<16xf32> to vector<16xbf16>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[ADD]], %[[C_BF16]] : vector<16xbf16>
+// CHECK: %[[RES:.*]] = arith.extf %[[MUL]] : vector<16xbf16> to vector<16xf32>
+// CHECK: return %[[RES]]
+func.func @test_chain_optimization(%a: vector<16xf32>, %b: vector<16xf32>, %c: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.addf %a, %b : vector<16xf32>
+  %1 = arith.mulf %0, %c : vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// Test: cmpf + select demotion
+// CHECK-LABEL: func @test_cmpf_select
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.cmpf ogt, {{.*}} : vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.select {{.*}} : vector<16xi1>, vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_cmpf_select(%a: vector<16xf32>, %b: vector<16xf32>) -> vector<16xf32> {
+  %cmp = arith.cmpf ogt, %a, %b : vector<16xf32>
+  %sel = arith.select %cmp, %a, %b : vector<16xi1>, vector<16xf32>
+  return %sel : vector<16xf32>
+}
+
+// -----
+
+// Test: vector.fma demotion
+// CHECK-LABEL: func @test_fma
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: vector.fma {{.*}} : vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_fma(%a: vector<16xf32>, %b: vector<16xf32>, %c: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.fma %a, %b, %c : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// Test: maximumf demotion
+// CHECK-LABEL: func @test_maximumf
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.maximumf {{.*}} : vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_maximumf(%a: vector<16xf32>, %b: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.maximumf %a, %b : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// Test: divf is NOT demoted (bf16 vector divf unsupported on all AIE targets)
+// CHECK-LABEL: func @test_divf_not_demoted
+// CHECK-NOT: arith.truncf
+// CHECK: arith.divf {{.*}} : vector<16xf32>
+// CHECK-NOT: arith.extf
+func.func @test_divf_not_demoted(%a: vector<16xf32>, %b: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.divf %a, %b : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// -----
+
+// Test: chain with divf - addf/mulf are bf16, divf stays f32
+// CHECK-LABEL: func @test_chain_with_divf
+// CHECK-SAME: (%[[A:.*]]: vector<16xf32>, %[[B:.*]]: vector<16xf32>, %[[C:.*]]: vector<16xf32>)
+// addf is demoted to bf16:
+// CHECK: arith.truncf %[[A]] : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf %[[B]] : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.addf {{.*}} : vector<16xbf16>
+// divf stays in f32 (with extf from the addf result):
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+// CHECK: arith.divf {{.*}} : vector<16xf32>
+// mulf is demoted to bf16 (with truncf from the divf result):
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.truncf %[[C]] : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.mulf {{.*}} : vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_chain_with_divf(%a: vector<16xf32>, %b: vector<16xf32>, %c: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.addf %a, %b : vector<16xf32>
+  %1 = arith.divf %0, %b : vector<16xf32>
+  %2 = arith.mulf %1, %c : vector<16xf32>
+  return %2 : vector<16xf32>
+}
+
+// -----
+
+// Test: bf16 ops are NOT affected (only f32 ops are demoted)
+// CHECK-LABEL: func @test_bf16_unchanged
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf {{.*}} : vector<16xbf16>
+func.func @test_bf16_unchanged(%a: vector<16xbf16>, %b: vector<16xbf16>) -> vector<16xbf16> {
+  %0 = arith.addf %a, %b : vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
+// -----
+
+// Test: scalar f32 ops are NOT demoted (only vector ops)
+// CHECK-LABEL: func @test_scalar_unchanged
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf
+// CHECK: arith.addf {{.*}} : f32
+func.func @test_scalar_unchanged(%a: f32, %b: f32) -> f32 {
+  %0 = arith.addf %a, %b : f32
+  return %0 : f32
+}
+
+// -----
+
+// Test: negf demotion
+// CHECK-LABEL: func @test_negf
+// CHECK: arith.truncf {{.*}} : vector<16xf32> to vector<16xbf16>
+// CHECK: arith.negf {{.*}} : vector<16xbf16>
+// CHECK: arith.extf {{.*}} : vector<16xbf16> to vector<16xf32>
+func.func @test_negf(%a: vector<16xf32>) -> vector<16xf32> {
+  %0 = arith.negf %a : vector<16xf32>
+  return %0 : vector<16xf32>
+}