Skip to content

Commit 7e42ad4

Browse files
AviadCojorickert
authored andcommitted
[mlir]: Added properties/attributes ignore flags to OperationEquivalence (#142623)
Those flags are useful for cases and operation which we may consider equivalent even when their attributes/properties are not the same.
1 parent 0531ca1 commit 7e42ad4

3 files changed

Lines changed: 48 additions & 7 deletions

File tree

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,14 @@ struct OperationEquivalence {
12881288
// When provided, the location attached to the operation are ignored.
12891289
IgnoreLocations = 1,
12901290

1291-
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations)
1291+
// When provided, the discardable attributes attached to the operation are
1292+
// ignored.
1293+
IgnoreDiscardableAttrs = 2,
1294+
1295+
// When provided, the properties attached to the operation are ignored.
1296+
IgnoreProperties = 4,
1297+
1298+
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreProperties)
12921299
};
12931300

12941301
/// Compute a hash for the given operation.

mlir/lib/IR/OperationSupport.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -674,9 +674,13 @@ llvm::hash_code OperationEquivalence::computeHash(
674674
// - Operation Name
675675
// - Attributes
676676
// - Result Types
677+
DictionaryAttr dictAttrs;
678+
if (!(flags & Flags::IgnoreDiscardableAttrs))
679+
dictAttrs = op->getRawDictionaryAttrs();
677680
llvm::hash_code hash =
678-
llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
679-
op->getResultTypes(), op->hashProperties());
681+
llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
682+
if (!(flags & Flags::IgnoreProperties))
683+
hash = llvm::hash_combine(hash, op->hashProperties());
680684

681685
// - Location if required
682686
if (!(flags & Flags::IgnoreLocations))
@@ -830,14 +834,19 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
830834
return true;
831835

832836
// 1. Compare the operation properties.
837+
if (!(flags & IgnoreDiscardableAttrs) &&
838+
lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs())
839+
return false;
840+
833841
if (lhs->getName() != rhs->getName() ||
834-
lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
835842
lhs->getNumRegions() != rhs->getNumRegions() ||
836843
lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
837844
lhs->getNumOperands() != rhs->getNumOperands() ||
838-
lhs->getNumResults() != rhs->getNumResults() ||
839-
!lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
840-
rhs->getPropertiesStorage()))
845+
lhs->getNumResults() != rhs->getNumResults())
846+
return false;
847+
if (!(flags & IgnoreProperties) &&
848+
!(lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
849+
rhs->getPropertiesStorage())))
841850
return false;
842851
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
843852
return false;

mlir/unittests/IR/OperationSupportTest.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ TEST(OperandStorageTest, PopulateDefaultAttrs) {
295295
TEST(OperationEquivalenceTest, HashWorksWithFlags) {
296296
MLIRContext context;
297297
context.getOrLoadDialect<test::TestDialect>();
298+
OpBuilder b(&context);
298299

299300
auto *op1 = createOp(&context);
300301
// `op1` has an unknown loc.
@@ -305,12 +306,36 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
305306
op, OperationEquivalence::ignoreHashValue,
306307
OperationEquivalence::ignoreHashValue, flags);
307308
};
309+
// Check ignore location.
308310
EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreLocations),
309311
getHash(op2, OperationEquivalence::IgnoreLocations));
310312
EXPECT_NE(getHash(op1, OperationEquivalence::None),
311313
getHash(op2, OperationEquivalence::None));
314+
op1->setLoc(NameLoc::get(StringAttr::get(&context, "foo")));
315+
// Check ignore discardable dictionary attributes.
316+
SmallVector<NamedAttribute> newAttrs = {
317+
b.getNamedAttr("foo", b.getStringAttr("f"))};
318+
op1->setAttrs(newAttrs);
319+
EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreDiscardableAttrs),
320+
getHash(op2, OperationEquivalence::IgnoreDiscardableAttrs));
321+
EXPECT_NE(getHash(op1, OperationEquivalence::None),
322+
getHash(op2, OperationEquivalence::None));
312323
op1->destroy();
313324
op2->destroy();
325+
326+
// Check ignore properties.
327+
auto req1 = b.getI32IntegerAttr(10);
328+
Operation *opWithProperty1 = b.create<test::OpAttrMatch1>(
329+
b.getUnknownLoc(), req1, nullptr, nullptr, req1);
330+
auto req2 = b.getI32IntegerAttr(60);
331+
Operation *opWithProperty2 = b.create<test::OpAttrMatch1>(
332+
b.getUnknownLoc(), req2, nullptr, nullptr, req2);
333+
EXPECT_EQ(getHash(opWithProperty1, OperationEquivalence::IgnoreProperties),
334+
getHash(opWithProperty2, OperationEquivalence::IgnoreProperties));
335+
EXPECT_NE(getHash(opWithProperty1, OperationEquivalence::None),
336+
getHash(opWithProperty2, OperationEquivalence::None));
337+
opWithProperty1->destroy();
338+
opWithProperty2->destroy();
314339
}
315340

316341
} // namespace

0 commit comments

Comments
 (0)