11 changes: 8 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1300,9 +1300,9 @@ class DropInnerMostUnitDimsTransferRead
if (dimsToDrop == 0)
return failure();

// Make sure that the indices to be dropped are equal 0.
// TODO: Deal with cases when the indices are not 0.
if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIndex))
auto inBounds = readOp.getInBoundsValues();
auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false))
return failure();

auto resultTargetVecType =
@@ -1394,6 +1394,11 @@ class DropInnerMostUnitDimsTransferWrite
if (dimsToDrop == 0)
return failure();

auto inBounds = writeOp.getInBoundsValues();
auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false))
return failure();

auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType(),
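Both patterns now use the same guard: bail out when any of the trailing unit dims that would be dropped is not marked in-bounds. An in-bounds unit dim can only be addressed at index 0, which is why the previous explicit zero-index check is no longer needed. Below is a minimal standalone sketch of that condition in plain C++ (no MLIR dependencies; the helper name canDropTrailingUnitDims is hypothetical), intended only to illustrate the logic of getInBoundsValues() + take_back(dimsToDrop) + llvm::is_contained:

#include <cstddef>
#include <iostream>
#include <vector>

// Returns true when the last `dimsToDrop` in_bounds flags are all true, i.e.
// when every dim that would be collapsed is known to be in bounds. This
// corresponds to `llvm::is_contained(droppedInBounds, false)` being false.
static bool canDropTrailingUnitDims(const std::vector<bool> &inBounds,
                                    std::size_t dimsToDrop) {
  if (dimsToDrop > inBounds.size())
    return false;
  for (std::size_t i = inBounds.size() - dimsToDrop; i < inBounds.size(); ++i)
    if (!inBounds[i])
      return false; // an out-of-bounds dim would be dropped -> give up
  return true;
}

int main() {
  // in_bounds = [true, true], dropping 1 trailing unit dim -> safe to collapse.
  std::cout << canDropTrailingUnitDims({true, true}, 1) << "\n";  // prints 1
  // in_bounds = [true, false] -> the pattern returns failure() instead.
  std::cout << canDropTrailingUnitDims({true, false}, 1) << "\n"; // prints 0
}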
148 changes: 116 additions & 32 deletions mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -113,46 +113,69 @@ func.func @contiguous_inner_most_outer_dim_dyn_scalable_inner_dim(%a: index, %b:

// -----

func.func @contiguous_inner_most_dim_non_zero_idx(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
// Test the impact of changing the in_bounds attribute. The behaviour will
// depend on whether the index is == 0 or != 0.

// The index to be dropped is == 0, so it's safe to collapse. The other index
// should be preserved correctly.
func.func @contiguous_inner_most_zero_idx_in_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%pad = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<8x1xf32>
%1 = vector.transfer_read %A[%i, %c0], %pad {in_bounds = [true, true]} : memref<16x1xf32>, vector<8x1xf32>
return %1 : vector<8x1xf32>
}
// CHECK: func @contiguous_inner_most_dim_non_zero_idx(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<8x1xf32>
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
// CHECK: return %[[RESULT]]

// The index to be dropped is != 0 - this is currently not supported.
func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%f0 = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %i], %f0 : memref<16x1xf32>, vector<8x1xf32>
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_in_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> {
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32>
// CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32>

// The index to be dropped is == 0, so it's safe to collapse. The "out of
// bounds" attribute is too conservative and will be folded to "in bounds"
// before the pattern runs. The other index should be preserved correctly.
func.func @contiguous_inner_most_zero_idx_out_of_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%pad = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
%1 = vector.transfer_read %A[%i, %c0], %pad {in_bounds = [true, false]} : memref<16x1xf32>, vector<8x1xf32>
return %1 : vector<8x1xf32>
}
// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idxs
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_out_of_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> {
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32>
// CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32>

// The index to be dropped is unknown, but since it's "in bounds", it has to be
// == 0. It's safe to collapse the corresponding dim.
func.func @contiguous_inner_most_non_zero_idx_in_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%pad = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %i], %pad {in_bounds = [true, true]} : memref<16x1xf32>, vector<8x1xf32>
return %1 : vector<8x1xf32>
}
// CHECK-LABEL: func.func @contiguous_inner_most_non_zero_idx_in_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x1xf32> {
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[SV]]{{\[}}%[[IDX]]], %[[PAD]] {in_bounds = [true]} : memref<16xf32, strided<[1]>>, vector<8xf32>
// CHECK: vector.shape_cast %[[READ]] : vector<8xf32> to vector<8x1xf32>

// The index to be dropped is unknown and "out of bounds" - not safe to
// collapse.
func.func @negative_contiguous_inner_most_non_zero_idx_out_of_bounds(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%pad = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %i], %pad {in_bounds = [true, false]} : memref<16x1xf32>, vector<8x1xf32>
return %1 : vector<8x1xf32>
}
// CHECK-LABEL: func.func @negative_contiguous_inner_most_non_zero_idx_out_of_bounds(
// CHECK-NOT: memref.subview
// CHECK-NOT: memref.shape_cast
// CHECK: vector.transfer_read

// Same as the top example within this split, but with the outer vector
// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e.
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.

func.func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(%A: memref<16x1xf32>, %i:index) -> (vector<[8]x1xf32>) {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<[8]x1xf32>
return %1 : vector<[8]x1xf32>
}
// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(
// CHECK-SAME: %[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<[8]x1xf32>
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<[8]xf32> to vector<[8]x1xf32>
// CHECK: return %[[RESULT]]

// -----

@@ -367,6 +390,67 @@ func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(%a: index, %b:

// -----

// Test the impact of changing the in_bounds attribute. The behaviour will
// depend on whether the index is == 0 or != 0.

// The index to be dropped is == 0, so it's safe to collapse. The other index
// should be preserved correctly.
func.func @contiguous_inner_most_zero_idx_in_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_in_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>

// The index to be dropped is == 0, so it's safe to collapse. The "out of
// bounds" attribute is too conservative and will be folded to "in bounds"
// before the pattern runs. The other index should be preserved correctly.
func.func @contiguous_inner_most_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_out_of_bounds
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>

// The index to be dropped is unknown, but since it's "in bounds", it has to be
// == 0. It's safe to collapse the corresponding dim.
func.func @contiguous_inner_most_dim_non_zero_idx_in_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idx_in_bounds
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>

// The index to be dropped is unknown and "out of bounds" - not safe to
// collapse.
func.func @negative_contiguous_inner_most_dim_non_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idx_out_of_bounds
// CHECK-NOT: memref.subview
// CHECK-NOT: memref.shape_cast
// CHECK: vector.transfer_write

// -----

func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]