@@ -917,7 +917,7 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
917917// / This accounts for cases where there are multiple unit-dims, but only a
918918// / subset of those are dropped. For MemRefTypes these can be disambiguated
919919// / using the strides. If a dimension is dropped the stride must be dropped too.
920- static std::optional <llvm::SmallBitVector>
920+ static FailureOr <llvm::SmallBitVector>
921921computeMemRefRankReductionMask (MemRefType originalType, MemRefType reducedType,
922922 ArrayRef<OpFoldResult> sizes) {
923923 llvm::SmallBitVector unusedDims (originalType.getRank ());
@@ -941,7 +941,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
941941 getStridesAndOffset (originalType, originalStrides, originalOffset)) ||
942942 failed (
943943 getStridesAndOffset (reducedType, candidateStrides, candidateOffset)))
944- return std:: nullopt ;
944+ return failure () ;
945945
946946 // For memrefs, a dimension is truly dropped if its corresponding stride is
947947 // also dropped. This is particularly important when more than one of the dims
@@ -976,22 +976,22 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
976976 candidateStridesNumOccurences[originalStride]) {
977977 // This should never happen. Cant have a stride in the reduced rank type
978978 // that wasnt in the original one.
979- return std:: nullopt ;
979+ return failure () ;
980980 }
981981 }
982982
983983 if ((int64_t )unusedDims.count () + reducedType.getRank () !=
984984 originalType.getRank ())
985- return std:: nullopt ;
985+ return failure () ;
986986 return unusedDims;
987987}
988988
989989llvm::SmallBitVector SubViewOp::getDroppedDims () {
990990 MemRefType sourceType = getSourceType ();
991991 MemRefType resultType = getType ();
992- std::optional <llvm::SmallBitVector> unusedDims =
992+ FailureOr <llvm::SmallBitVector> unusedDims =
993993 computeMemRefRankReductionMask (sourceType, resultType, getMixedSizes ());
994- assert (unusedDims && " unable to find unused dims of subview" );
994+ assert (succeeded ( unusedDims) && " unable to find unused dims of subview" );
995995 return *unusedDims;
996996}
997997
@@ -2745,7 +2745,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
27452745// / For ViewLikeOpInterface.
27462746Value SubViewOp::getViewSource () { return getSource (); }
27472747
2748- // / Return true if t1 and t2 have equal offsets (both dynamic or of same
2748+ // / Return true if `t1` and `t2` have equal offsets (both dynamic or of same
27492749// / static value).
27502750static bool haveCompatibleOffsets (MemRefType t1, MemRefType t2) {
27512751 int64_t t1Offset, t2Offset;
@@ -2755,56 +2755,41 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
27552755 return succeeded (res1) && succeeded (res2) && t1Offset == t2Offset;
27562756}
27572757
2758- // / Checks if `original` Type type can be rank reduced to `reduced` type.
2759- // / This function is slight variant of `is subsequence` algorithm where
2760- // / not matching dimension must be 1.
2761- static SliceVerificationResult
2762- isRankReducedMemRefType (MemRefType originalType,
2763- MemRefType candidateRankReducedType,
2764- ArrayRef<OpFoldResult> sizes) {
2765- auto partialRes = isRankReducedType (originalType, candidateRankReducedType);
2766- if (partialRes != SliceVerificationResult::Success)
2767- return partialRes;
2768-
2769- auto optionalUnusedDimsMask = computeMemRefRankReductionMask (
2770- originalType, candidateRankReducedType, sizes);
2771-
2772- // Sizes cannot be matched in case empty vector is returned.
2773- if (!optionalUnusedDimsMask)
2774- return SliceVerificationResult::LayoutMismatch;
2775-
2776- if (originalType.getMemorySpace () !=
2777- candidateRankReducedType.getMemorySpace ())
2778- return SliceVerificationResult::MemSpaceMismatch;
2779-
2780- // No amount of stride dropping can reconcile incompatible offsets.
2781- if (!haveCompatibleOffsets (originalType, candidateRankReducedType))
2782- return SliceVerificationResult::LayoutMismatch;
2783-
2784- return SliceVerificationResult::Success;
2758+ // / Return true if `t1` and `t2` have equal strides (both dynamic or of same
2759+ // / static value).
2760+ static bool haveCompatibleStrides (MemRefType t1, MemRefType t2) {
2761+ int64_t t1Offset, t2Offset;
2762+ SmallVector<int64_t > t1Strides, t2Strides;
2763+ auto res1 = getStridesAndOffset (t1, t1Strides, t1Offset);
2764+ auto res2 = getStridesAndOffset (t2, t2Strides, t2Offset);
2765+ if (failed (res1) || failed (res2))
2766+ return false ;
2767+ for (auto [s1, s2] : llvm::zip_equal (t1Strides, t2Strides))
2768+ if (s1 != s2)
2769+ return false ;
2770+ return true ;
27852771}
27862772
2787- template <typename OpTy>
27882773static LogicalResult produceSubViewErrorMsg (SliceVerificationResult result,
2789- OpTy op, Type expectedType) {
2774+ Operation * op, Type expectedType) {
27902775 auto memrefType = llvm::cast<ShapedType>(expectedType);
27912776 switch (result) {
27922777 case SliceVerificationResult::Success:
27932778 return success ();
27942779 case SliceVerificationResult::RankTooLarge:
2795- return op. emitError (" expected result rank to be smaller or equal to " )
2780+ return op-> emitError (" expected result rank to be smaller or equal to " )
27962781 << " the source rank. " ;
27972782 case SliceVerificationResult::SizeMismatch:
2798- return op. emitError (" expected result type to be " )
2783+ return op-> emitError (" expected result type to be " )
27992784 << expectedType
28002785 << " or a rank-reduced version. (mismatch of result sizes) " ;
28012786 case SliceVerificationResult::ElemTypeMismatch:
2802- return op. emitError (" expected result element type to be " )
2787+ return op-> emitError (" expected result element type to be " )
28032788 << memrefType.getElementType ();
28042789 case SliceVerificationResult::MemSpaceMismatch:
2805- return op. emitError (" expected result and source memory spaces to match." );
2790+ return op-> emitError (" expected result and source memory spaces to match." );
28062791 case SliceVerificationResult::LayoutMismatch:
2807- return op. emitError (" expected result type to be " )
2792+ return op-> emitError (" expected result type to be " )
28082793 << expectedType
28092794 << " or a rank-reduced version. (mismatch of result layout) " ;
28102795 }
@@ -2826,13 +2811,46 @@ LogicalResult SubViewOp::verify() {
28262811 if (!isStrided (baseType))
28272812 return emitError (" base type " ) << baseType << " is not strided" ;
28282813
2829- // Verify result type against inferred type.
2830- auto expectedType = SubViewOp::inferResultType (
2831- baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ());
2814+ // Compute the expected result type, assuming that there are no rank
2815+ // reductions.
2816+ auto expectedType = cast<MemRefType>(SubViewOp::inferResultType (
2817+ baseType, getStaticOffsets (), getStaticSizes (), getStaticStrides ()));
2818+
2819+ // Verify all properties of a shaped type: rank, element type and dimension
2820+ // sizes. This takes into account potential rank reductions.
2821+ auto shapedTypeVerification = isRankReducedType (
2822+ /* originalType=*/ expectedType, /* candidateReducedType=*/ subViewType);
2823+ if (shapedTypeVerification != SliceVerificationResult::Success)
2824+ return produceSubViewErrorMsg (shapedTypeVerification, *this , expectedType);
2825+
2826+ // Make sure that the memory space did not change.
2827+ if (expectedType.getMemorySpace () != subViewType.getMemorySpace ())
2828+ return produceSubViewErrorMsg (SliceVerificationResult::MemSpaceMismatch,
2829+ *this , expectedType);
2830+
2831+ // Verify the offset of the layout map.
2832+ if (!haveCompatibleOffsets (expectedType, subViewType))
2833+ return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2834+ *this , expectedType);
2835+
2836+ // The only thing that's left to verify now are the strides. First, compute
2837+ // the unused dimensions due to rank reductions. We have to look at sizes and
2838+ // strides to decide which dimensions were dropped. This function also
2839+ // partially verifies strides in case of rank reductions.
2840+ auto unusedDims = computeMemRefRankReductionMask (expectedType, subViewType,
2841+ getMixedSizes ());
2842+ if (failed (unusedDims))
2843+ return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2844+ *this , expectedType);
2845+
2846+ // Strides must match if there are no rank reductions.
2847+ // TODO: Verify strides when there are rank reductions. Strides are partially
2848+ // checked in `computeMemRefRankReductionMask`.
2849+ if (unusedDims->none () && !haveCompatibleStrides (expectedType, subViewType))
2850+ return produceSubViewErrorMsg (SliceVerificationResult::LayoutMismatch,
2851+ *this , expectedType);
28322852
2833- auto result = isRankReducedMemRefType (llvm::cast<MemRefType>(expectedType),
2834- subViewType, getMixedSizes ());
2835- return produceSubViewErrorMsg (result, *this , expectedType);
2853+ return success ();
28362854}
28372855
28382856raw_ostream &mlir::operator <<(raw_ostream &os, const Range &range) {
@@ -2882,11 +2900,9 @@ static MemRefType getCanonicalSubViewResultType(
28822900 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
28832901 auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType (
28842902 sourceType, mixedOffsets, mixedSizes, mixedStrides));
2885- std::optional<llvm::SmallBitVector> unusedDims =
2886- computeMemRefRankReductionMask (currentSourceType, currentResultType,
2887- mixedSizes);
2888- // Return nullptr as failure mode.
2889- if (!unusedDims)
2903+ FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask (
2904+ currentSourceType, currentResultType, mixedSizes);
2905+ if (failed (unusedDims))
28902906 return nullptr ;
28912907
28922908 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout ());
0 commit comments