Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
"must be a vector transfer op");
if (xferOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
if (xferOp.getMask())
return rewriter.notifyMatchFailure(xferOp, "masked transfer");
if (!subviewOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(
xferOp, "non-1 stride subview, need to track strides in folded memref");
Expand Down Expand Up @@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
op.getPadding(), op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
Expand Down Expand Up @@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
AffineMapAttr::get(expandDimsToRank(
op.getPermutationMap(), subViewOp.getSourceType().getRank(),
subViewOp.getDroppedDims())),
op.getInBoundsAttr());
op.getMask(), op.getInBoundsAttr());
})
.Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
Expand Down
57 changes: 57 additions & 0 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,63 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(

// -----

func.func @fold_masked_vector_transfer_read_with_subview(
%arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
%arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
%arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
%cst = arith.constant 0.0 : f32
%0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
: memref<?x?xf32, strided<[?, ?], offset: ?>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
: memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
return %1 : vector<4xf32>
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: func @fold_masked_vector_transfer_read_with_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
// CHECK: vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32

// -----

func.func @fold_masked_vector_transfer_write_with_subview(
%arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
%arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
%arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
%cst = arith.constant 0.0 : f32
%0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
: memref<?x?xf32, strided<[?, ?], offset: ?>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
: vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
return
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: func @fold_masked_vector_transfer_write_with_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32

// -----

// Test with affine.load/store ops. We only do a basic test here since the
// logic is identical to that with memref.load/store ops. The same affine.apply
// ops would be generated.
Expand Down