[mlir] Implement DestinationStyleOpInterface for scf::ForallOp #66981
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-scf
For this change to work, we need to add an exception to the Pattern that folds `tensor.cast` ops into DPS ops. Full diff: https://github.com/llvm/llvm-project/pull/66981.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 915ab3016b688e7..644118ca884c6b1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 08b71e20a2bc079..adc7b2e4170cb89 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -333,6 +334,7 @@ def ForallOp : SCF_Op<"forall", [
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ DestinationStyleOpInterface
]> {
let summary = "evaluate a block multiple times in parallel";
let description = [{
@@ -630,6 +632,14 @@ def ForallOp : SCF_Op<"forall", [
Location loc);
InParallelOp getTerminator();
+
+ // Implement this to declare all shared_outs as inits/outs to
+ // DestinationStyleOpInterface
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ int64_t numOperands = getNumOperands();
+ int64_t numOuts = getOutputs().size();
+ return {numOperands - numOuts, numOperands};
+ }
}];
}
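The `getDpsInitsPositionRange` hook in the diff above reports the DPS inits as the trailing `shared_outs` operands. A minimal standalone sketch of that arithmetic (hypothetical Python, mirroring the C++ in the patch rather than any real MLIR API):

```python
def dps_inits_position_range(num_operands: int, num_outs: int) -> tuple[int, int]:
    """Half-open [start, end) operand range of the trailing `num_outs`
    operands, mirroring ForallOp::getDpsInitsPositionRange in the patch."""
    return (num_operands - num_outs, num_operands)

# E.g. a forall with 7 operands (bounds/steps) of which the last 2 are
# shared_outs: the inits occupy operand positions [5, 7).
start, end = dps_inits_position_range(7, 2)
```

Because the `shared_outs` always come last in the operand list, subtracting their count from the total yields the start of the init range.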
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3e30e320bee8f83..fa91471f33d4bd3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -3970,6 +3971,9 @@ struct FoldTensorCastProducerOp
if (isa<InsertSliceOp>(op.getOperation()))
return failure();
+ if (isa<scf::ForallOp>(op.getOperation()))
+ return failure();
+
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
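To illustrate why the exception is needed, here is a hedged sketch of the IR shape the pattern must now skip (hypothetical example; exact region contents elided). Folding the cast would change the operand type without updating the block argument type inside the region:

```mlir
// Without the bail-out, FoldTensorCastProducerOp would fold %cast into the
// shared_outs operand, but the forall's block argument %o would keep its
// old type, producing invalid IR.
%cast = tensor.cast %t : tensor<?xf32> to tensor<64xf32>
%res = scf.forall (%i) in (%c4) shared_outs(%o = %cast) -> (tensor<64xf32>) {
  // ... body inserting partial results into %o ...
}
```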
These lines are problematic as the tensor dialect does not depend on the SCF dialect.
To make matters even worse, the SCF dialect currently depends on the tensor dialect, creating a circular dependency between the two.
This causes BUILD_SHARED_LIBS builds to fail, and any downstream user depending on the tensor dialect now gets linker errors telling them to also link SCF:
FAILED: lib/libMLIRTensorDialect.so.18git
ld.lld: error: undefined symbol: mlir::detail::TypeIDResolver<mlir::scf::ForallOp, void>::id
>>> referenced by TensorOps.cpp
>>> tools/mlir/lib/Dialect/Tensor/IR/CMakeFiles/obj.MLIRTensorDialect.dir/TensorOps.cpp.o:(FoldTensorCastProducerOp::matchAndRewrite(mlir::DestinationStyleOpInterface, mlir::PatternRewriter&) const)
clang-16: error: linker command failed with exit code 1 (use -v to see invocation)
I sadly know very little about the tensor dialect/bufferization/all that stuff, but is this kind of layering really necessary?
Oh right, I missed that. We should avoid such references to other dialects. FoldTensorCastProducerOp should not fold casts into destination style ops that have a region without further checks. At the moment there is no way to query whether such folding is safe or not.
Maybe we can do something with a new interface. We already have CastOpInterface. We could have another FoldCastOpInterface that ops can implement that support folding of cast ops into themselves.
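A rough TableGen sketch of the interface proposed above (the name `FoldCastOpInterface` and its method are hypothetical, taken from this comment, not an existing MLIR interface):

```tablegen
// Hypothetical opt-in interface: ops implementing it declare when a
// producing tensor.cast may be folded into one of their operands.
def FoldCastOpInterface : OpInterface<"FoldCastOpInterface"> {
  let description = [{
    Ops implementing this interface support folding producing cast ops
    into themselves, possibly subject to per-operand checks.
  }];
  let methods = [
    InterfaceMethod<"Return true if folding a cast into `operand` is safe.",
      "bool", "canFoldCastInto", (ins "OpOperand &":$operand)>
  ];
}
```

A default implementation returning `true` could keep the extra effort for existing DPS ops near zero.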
Thank you for pointing that out!
I just tried adding a simple check for `op->getNumRegions() != 0`, but this also doesn't work because it excludes all of the Linalg ops that are some of the primary targets of this pattern. How about excluding ops that have tensor bbArgs, or is this getting too specific?
About the potential Interface - is there a way to implement this so that it doesn't add extra implementation effort to every DPS op?
I think you can just look for LoopLikeOpInterface here... That adds a dependence from DestPassingStyleOpInterface to LoopLikeOpInterface, but we might be able to live with that.
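Expressed as a sketch in the diff style used above (assuming the suggestion is taken literally; this fragment is illustrative, not part of the actual patch), the tensor-dialect check would become interface-based instead of naming `scf::ForallOp`:

```
   if (isa<InsertSliceOp>(op.getOperation()))
     return failure();
-  if (isa<scf::ForallOp>(op.getOperation()))
+  if (isa<LoopLikeOpInterface>(op.getOperation()))
     return failure();
```

This removes the tensor-dialect reference to SCF, at the cost of a dependence on `LoopLikeOpInterface`.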
matthias-springer left a comment:
Layering issues (scf.forall is not a tensor dialect op)
I am in favor of this change, but given the dependence issue can you comment a bit more on the motivation for this change?
My original idea was to implement … Additionally, I think … One caveat about the …
This has changed in #67015.
Thanks, I updated that part.
`scf::ForallOp` has `shared_outs` tensor operands into which partial results are inserted in the parallel terminator. The `scf::ForallOp` returns one tensor for each `shared_out`, which then contains the combined result from all threads. Since the parallel terminator cannot change the shape of the `shared_out`, ForallOp is a `DestinationStyleOp`, and this patch implements the interface by declaring the `outputs` operands as `inits` in the language of the DPS interface. For this change to work, we need to add an exception to the Pattern that folds `tensor.cast` ops into DPS ops, because `scf::Forall` needs special handling of its `BlockArgument` type during this folding.
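The semantics described above can be sketched as follows (hypothetical IR; offsets, sizes, and the body are illustrative placeholders):

```mlir
// %out acts as the DPS init: it supplies the result type/shape, and each
// thread inserts its partial slice through the in_parallel terminator
// without changing the shape of %out.
%res = scf.forall (%i) in (%c4) shared_outs(%o = %out) -> (tensor<128xf32>) {
  // ... compute a 32-element partial result %slice and an offset %off ...
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %slice into %o[%off] [32] [1]
        : tensor<32xf32> into tensor<128xf32>
  }
}
```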