Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/aie/Dialect/AIEVec/Pipelines/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ struct CanonicalizeVectorForAIEVecOptions
"will determine the aievec operations used to convert "
"from vector dialect."),
llvm::cl::init("cpp")};
PassOptions::Option<bool> enableBF16Emulation{
*this, "bf16-emulation",
llvm::cl::desc(
"Emulate f32 vector arithmetic using bf16 operations. Inserts "
"arith.truncf/arith.extf around f32 vector ops to compute in bf16. "
"Trades precision for performance."),
llvm::cl::init(false)};
};

/// Options for the "lower-vector-to-aievec" pipeline.
Expand Down Expand Up @@ -118,6 +125,13 @@ struct ConvertVectorToAIEVecOptions
"will determine the aievec operations used to convert "
"from vector dialect."),
llvm::cl::init("cpp")};
PassOptions::Option<bool> enableBF16Emulation{
*this, "bf16-emulation",
llvm::cl::desc(
"Emulate f32 vector arithmetic using bf16 operations. Inserts "
"arith.truncf/arith.extf around f32 vector ops to compute in bf16. "
"Trades precision for performance."),
llvm::cl::init(false)};

mlir::LogicalResult parseFromString(mlir::StringRef options) {
auto res = PassPipelineOptions::parseFromString(options);
Expand All @@ -126,6 +140,7 @@ struct ConvertVectorToAIEVecOptions
lowerOptions.targetBackend = targetBackend;
canonicalizeOptions.aieTarget = aieTarget;
canonicalizeOptions.targetBackend = targetBackend;
canonicalizeOptions.enableBF16Emulation = enableBF16Emulation;
optimizeOptions.aieTarget = aieTarget;
optimizeOptions.targetBackend = targetBackend;
optimizeOptions.shiftParam = shiftParam;
Expand Down Expand Up @@ -190,6 +205,10 @@ void buildDynamicSizeNoImplicitBroadcastPass(mlir::OpPassManager &pm);
/// operations for AIE2p targets.
std::unique_ptr<::mlir::Pass> createSplitVectorLoadUpsChainsPass();

/// Create a pass that emulates f32 vector arithmetic using bf16 operations.
/// Inserts arith.truncf/arith.extf around f32 vector ops to compute in bf16.
std::unique_ptr<::mlir::Pass> createBF16EmulationPass();

} // namespace aievec
} // namespace xilinx

Expand Down
243 changes: 243 additions & 0 deletions lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,244 @@ populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
//=================== Common AIE Canonicalization Passes =====================//
//============================================================================//

//===----------------------------------------------------------------------===//
// BF16 Emulation: Emulate f32 vector arithmetic using bf16 operations.
//===----------------------------------------------------------------------===//

// Truncate an f32 value down to bf16, folding away round-trips: when `val`
// was itself produced by an `arith.extf` from bf16, return that original
// bf16 source instead of emitting a redundant extf->truncf chain.
static Value smartTruncF32ToBF16(PatternRewriter &rewriter, Location loc,
                                 Value val, Type bf16Type) {
  auto extfOp = val.getDefiningOp<arith::ExtFOp>();
  if (extfOp && extfOp.getIn().getType() == bf16Type)
    return extfOp.getIn();
  return arith::TruncFOp::create(rewriter, loc, bf16Type, val);
}

// Smart truncation for scalar values (used by reduction patterns).
// Delegates to smartTruncF32ToBF16 with a scalar bf16 type so the
// extf->truncf folding logic lives in exactly one place instead of being
// duplicated for the scalar case.
static Value smartTruncScalarF32ToBF16(PatternRewriter &rewriter, Location loc,
                                       Value val) {
  return smartTruncF32ToBF16(rewriter, loc, val, rewriter.getBF16Type());
}

/// Pattern to emulate f32 binary vector arithmetic ops in bf16.
/// For an op like: %r = arith.addf %a, %b : vector<16xf32>
/// Produces:
///   %a_bf16 = arith.truncf %a : vector<16xf32> to vector<16xbf16>
///   %b_bf16 = arith.truncf %b : vector<16xf32> to vector<16xbf16>
///   %r_bf16 = arith.addf %a_bf16, %b_bf16 : vector<16xbf16>
///   %r = arith.extf %r_bf16 : vector<16xbf16> to vector<16xf32>
template <typename OpTy>
struct EmulateBinaryF32InBF16Pattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Only f32 vector results are rewritten; everything else is untouched.
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || !vecTy.getElementType().isF32())
      return failure();

    Location loc = op.getLoc();
    auto narrowTy = VectorType::get(vecTy.getShape(), rewriter.getBF16Type());

    // Truncate both operands, reusing an existing bf16 source when the
    // operand is itself an extf from bf16.
    Value lhs = smartTruncF32ToBF16(rewriter, loc, op.getLhs(), narrowTy);
    Value rhs = smartTruncF32ToBF16(rewriter, loc, op.getRhs(), narrowTy);

    // Compute in bf16, then widen the result back to the original f32 type.
    Value narrowResult = OpTy::create(rewriter, loc, narrowTy, lhs, rhs);
    rewriter.replaceOpWithNewOp<arith::ExtFOp>(op, vecTy, narrowResult);
    return success();
  }
};

/// Pattern to emulate f32 comparison ops in bf16.
/// The result type stays vector<Nxi1>; only the operands are truncated.
struct EmulateCmpFF32InBF16Pattern : public OpRewritePattern<arith::CmpFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::CmpFOp op,
                                PatternRewriter &rewriter) const override {
    // Match on the operand type: the comparison result itself is i1.
    auto operandTy = dyn_cast<VectorType>(op.getLhs().getType());
    if (!operandTy || !operandTy.getElementType().isF32())
      return failure();

    Location loc = op.getLoc();
    auto narrowTy =
        VectorType::get(operandTy.getShape(), rewriter.getBF16Type());

    Value lhs = smartTruncF32ToBF16(rewriter, loc, op.getLhs(), narrowTy);
    Value rhs = smartTruncF32ToBF16(rewriter, loc, op.getRhs(), narrowTy);

    // Same predicate, narrower operands; no extf is needed on the result.
    rewriter.replaceOpWithNewOp<arith::CmpFOp>(op, op.getPredicate(), lhs,
                                               rhs);
    return success();
  }
};

/// Pattern to emulate f32 select ops in bf16.
/// The condition stays vector<Nxi1>; the true/false values are truncated to
/// bf16 and the selected value is widened back to f32.
struct EmulateSelectF32InBF16Pattern
    : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || !vecTy.getElementType().isF32())
      return failure();

    Location loc = op.getLoc();
    auto narrowTy = VectorType::get(vecTy.getShape(), rewriter.getBF16Type());

    Value onTrue =
        smartTruncF32ToBF16(rewriter, loc, op.getTrueValue(), narrowTy);
    Value onFalse =
        smartTruncF32ToBF16(rewriter, loc, op.getFalseValue(), narrowTy);

    // The i1 condition is passed through unchanged.
    Value narrowSel = arith::SelectOp::create(rewriter, loc, op.getCondition(),
                                              onTrue, onFalse);
    rewriter.replaceOpWithNewOp<arith::ExtFOp>(op, vecTy, narrowSel);
    return success();
  }
};

/// Pattern to emulate f32 vector.fma in bf16.
/// All three operands (lhs, rhs, acc) are truncated to bf16 and the fused
/// result is widened back to f32.
struct EmulateFMAF32InBF16Pattern : public OpRewritePattern<vector::FMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || !vecTy.getElementType().isF32())
      return failure();

    Location loc = op.getLoc();
    auto narrowTy = VectorType::get(vecTy.getShape(), rewriter.getBF16Type());

    Value lhs = smartTruncF32ToBF16(rewriter, loc, op.getLhs(), narrowTy);
    Value rhs = smartTruncF32ToBF16(rewriter, loc, op.getRhs(), narrowTy);
    Value acc = smartTruncF32ToBF16(rewriter, loc, op.getAcc(), narrowTy);

    // vector.fma infers its result type from the operands, so no explicit
    // result type is passed here.
    Value narrowResult = vector::FMAOp::create(rewriter, loc, lhs, rhs, acc);
    rewriter.replaceOpWithNewOp<arith::ExtFOp>(op, vecTy, narrowResult);
    return success();
  }
};

/// Pattern to emulate f32 unary vector ops in bf16: the single operand is
/// truncated, the op is re-created on bf16, and the result widened to f32.
template <typename OpTy>
struct EmulateUnaryF32InBF16Pattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || !vecTy.getElementType().isF32())
      return failure();

    Location loc = op.getLoc();
    auto narrowTy = VectorType::get(vecTy.getShape(), rewriter.getBF16Type());

    // Generic operand access keeps this template usable for any single-
    // operand op type.
    Value operand =
        smartTruncF32ToBF16(rewriter, loc, op->getOperand(0), narrowTy);

    Value narrowResult = OpTy::create(rewriter, loc, narrowTy, operand);
    rewriter.replaceOpWithNewOp<arith::ExtFOp>(op, vecTy, narrowResult);
    return success();
  }
};

/// Pattern to emulate f32 vector.reduction in bf16.
/// The input vector (and the optional scalar accumulator) are truncated to
/// bf16, the reduction is performed in bf16, and the scalar result is widened
/// back to f32.
/// NOTE(review): any fastmath flags on the original reduction are not carried
/// over to the newly created op — confirm that is intended.
struct EmulateReductionF32InBF16Pattern
    : public OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // Only f32 reductions over f32 vectors are rewritten.
    if (!op.getType().isF32())
      return failure();
    auto vectorType = dyn_cast<VectorType>(op.getVector().getType());
    if (!vectorType || !vectorType.getElementType().isF32())
      return failure();

    Location loc = op.getLoc();
    auto bf16VecType =
        VectorType::get(vectorType.getShape(), rewriter.getBF16Type());

    Value vectorBF16 =
        smartTruncF32ToBF16(rewriter, loc, op.getVector(), bf16VecType);

    // The accumulator is optional; a null Value is passed through when absent.
    Value accBF16 = nullptr;
    if (op.getAcc())
      accBF16 = smartTruncScalarF32ToBF16(rewriter, loc, op.getAcc());

    Value newResult = vector::ReductionOp::create(rewriter, loc, op.getKind(),
                                                  vectorBF16, accBF16);
    // The reduction yields a scalar, so the widening extf is scalar f32.
    auto extOp =
        arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), newResult);
    rewriter.replaceOp(op, extOp);
    return success();
  }
};

/// Pass that emulates f32 vector arithmetic using bf16 operations: the
/// patterns above insert arith.truncf/arith.extf around f32 vector ops so
/// the actual computation happens in bf16, trading precision for
/// performance.
struct BF16EmulationPass
    : public PassWrapper<BF16EmulationPass, OperationPass<>> {

  void runOnOperation() override {
    auto *op = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    // Binary arithmetic ops
    patterns
        .add<EmulateBinaryF32InBF16Pattern<arith::AddFOp>,
             EmulateBinaryF32InBF16Pattern<arith::SubFOp>,
             EmulateBinaryF32InBF16Pattern<arith::MulFOp>,
             EmulateBinaryF32InBF16Pattern<arith::MaximumFOp>,
             EmulateBinaryF32InBF16Pattern<arith::MinimumFOp>>(context);

    // Note: arith.divf is NOT demoted because bf16 vector divf is unsupported
    // on all AIE targets (Peano does not legalize G_FDIV on <16 x s16>).

    // Special-case ops
    patterns.add<EmulateCmpFF32InBF16Pattern, EmulateSelectF32InBF16Pattern,
                 EmulateFMAF32InBF16Pattern,
                 EmulateReductionF32InBF16Pattern>(context);

    // Unary ops
    patterns.add<EmulateUnaryF32InBF16Pattern<arith::NegFOp>>(context);

    // Surface non-convergence of the greedy driver as a pass failure instead
    // of silently discarding it.
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      signalPassFailure();
  }
};

/// Factory for the standalone bf16-emulation pass; see BF16EmulationPass for
/// the patterns it applies.
std::unique_ptr<::mlir::Pass> xilinx::aievec::createBF16EmulationPass() {
  return std::make_unique<BF16EmulationPass>();
}

struct VectorBroadcastLoweringPass
: public PassWrapper<VectorBroadcastLoweringPass, OperationPass<>> {

Expand Down Expand Up @@ -1051,6 +1289,11 @@ void xilinx::aievec::buildCanonicalizeVectorForAIEVec(
// Add `Vector` code canonicalization passes
// TODO: Add passes to unroll vector with unsupported types
// TODO: Add passes to split vectors that won't fit in registers

// If bf16-emulation is enabled, demote f32 vector arithmetic to bf16 first.
if (options.enableBF16Emulation)
pm.addPass(createBF16EmulationPass());

if (decodeTargetBackend(options.targetBackend) == TargetBackend::LLVMIR)
pm.addPass(createReorderOperationsPass());
pm.addPass(createCopyRemovalPass());
Expand Down
Loading
Loading