Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace arm_sme {
/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
bool ifContainsScalableVectors = false);

/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,14 @@ def EnableArmStreaming
"not be used for input and/or output and the "
"function must return with ZA unchanged")
)}]>,
Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
Option<"ifRequiredByOps", "if-required-by-ops", "bool",
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function "
" contains ops that require them.">
"Apply the selected streaming/ZA modes if the function contains ops "
"that require them.">,
Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
"bool", /*default=*/"false",
"Apply the selected streaming/ZA modes if the function contains "
"operations that use scalable vector types.">
];
let dependentDialects = ["func::FuncDialect"];
}
Expand Down
49 changes: 38 additions & 11 deletions mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,25 @@ constexpr StringLiteral
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
bool onlyIfRequiredByOps) {
bool ifRequiredByOps, bool ifContainsScalableVectors) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->onlyIfRequiredByOps = onlyIfRequiredByOps;
this->ifRequiredByOps = ifRequiredByOps;
this->ifContainsScalableVectors = ifContainsScalableVectors;
}
void runOnOperation() override {
auto op = getOperation();
auto function = getOperation();

if (onlyIfRequiredByOps) {
if (ifRequiredByOps && ifContainsScalableVectors) {
function->emitOpError(
"enable-arm-streaming: `if-required-by-ops` and "
"`if-contains-scalable-vectors` are mutually exclusive");
return signalPassFailure();
}

if (ifRequiredByOps) {
bool foundTileOp = false;
op.walk([&](Operation *op) {
function.walk([&](Operation *op) {
if (llvm::isa<ArmSMETileOpInterface>(op)) {
foundTileOp = true;
return WalkResult::interrupt();
Expand All @@ -79,27 +87,46 @@ struct EnableArmStreamingPass
return;
}

if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
if (ifContainsScalableVectors) {
bool foundScalableVector = false;
auto isScalableVector = [&](Type type) {
if (auto vectorType = dyn_cast<VectorType>(type))
return vectorType.isScalable();
return false;
};
function.walk([&](Operation *op) {
if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
llvm::any_of(op->getResultTypes(), isScalableVector)) {
foundScalableVector = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (!foundScalableVector)
return;
}

if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;

auto unitAttr = UnitAttr::get(&getContext());

op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);

// The pass currently only supports enabling ZA when in streaming-mode, but
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
// supporting this later.
if (zaMode != ArmZaMode::Disabled)
op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
}
};
} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
bool onlyIfRequiredByOps) {
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
onlyIfRequiredByOps);
bool ifRequiredByOps, bool ifContainsScalableVectors) {
return std::make_unique<EnableArmStreamingPass>(
streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
}
4 changes: 4 additions & 0 deletions mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics

// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
func.func @test() { return }
17 changes: 16 additions & 1 deletion mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE

// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
Expand Down Expand Up @@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
// IF-REQUIRED: @does_not_require_arm_streaming
// IF-REQUIRED-NOT: arm_streaming
func.func @does_not_require_arm_streaming() { return }

// IF-SCALABLE-LABEL: @contains_scalable_vectors
// IF-SCALABLE-SAME: attributes {arm_streaming}
func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
%0 = arith.addf %vec, %vec : vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// IF-SCALABLE-LABEL: @no_scalable_vectors
// IF-SCALABLE-NOT: arm_streaming
func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
%0 = arith.addf %vec, %vec : vector<4xf32>
return %0 : vector<4xf32>
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// RUN: -arm-sme-vector-legalization -canonicalize -cse \
// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
// Enable streaming-mode and ZA.
pm.addPass(arm_sme::createEnableArmStreamingPass(
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
/*onlyIfRequiredByOps=*/true));
/*ifRequiredByOps=*/true));

// Convert SCF to CF (required for ArmSME tile allocation).
pm.addPass(createConvertSCFToCFPass());
Expand Down