Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Predefined constant_mask kinds.
enum class ConstantMaskKind { AllFalse = 0, AllTrue };

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void buildTerminatedBody(OpBuilder &builder, Location loc);
Expand Down Expand Up @@ -163,6 +166,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
SmallVector<arith::ConstantIndexOp>
getAsConstantIndexOps(ArrayRef<Value> values);

/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);

//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2362,6 +2362,11 @@ def Vector_ConstantMaskOp :
```
}];

let builders = [
// Build with mixed static/dynamic operands.
OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
];

let extraClassDeclaration = [{
/// Return the result type of this op.
VectorType getVectorType() {
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir {
class MLIRContext;
Expand Down Expand Up @@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
RewriterBase &rewriter);

// Structure to hold the range of `vector.vscale`.
struct VscaleRange {
unsigned vscaleMin;
unsigned vscaleMax;
};

/// Attempts to eliminate redundant vector masks by replacing them with all-true
/// constants at the top of the function (which results in the masks folding
/// away). Note: Currently, this only runs for vector.create_mask ops and
/// requires `vscaleRange`. If `vscaleRange` is not provided this transform does
/// nothing. This is because these redundant masks are much more likely for
/// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size
/// code has static sizes, so simpler folds remove the masks.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
std::optional<VscaleRange> vscaleRange = {});

} // namespace vector
} // namespace mlir

Expand Down
120 changes: 59 additions & 61 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5749,6 +5749,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
// ConstantMaskOp
//===----------------------------------------------------------------------===//

void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
VectorType type, ConstantMaskKind kind) {
assert(kind == ConstantMaskKind::AllTrue ||
kind == ConstantMaskKind::AllFalse);
build(builder, result, type,
kind == ConstantMaskKind::AllTrue
? type.getShape()
: SmallVector<int64_t>(type.getRank(), 0));
}

LogicalResult ConstantMaskOp::verify() {
auto resultType = llvm::cast<VectorType>(getResult().getType());
// Check the corner case of 0-D vectors first.
Expand Down Expand Up @@ -5831,6 +5841,21 @@ LogicalResult CreateMaskOp::verify() {
return success();
}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
if (value.getDefiningOp<vector::VectorScaleOp>())
return 1;
auto mul = value.getDefiningOp<arith::MulIOp>();
if (!mul)
return {};
auto lhs = mul.getLhs();
auto rhs = mul.getRhs();
if (lhs.getDefiningOp<vector::VectorScaleOp>())
return getConstantIntValue(rhs);
if (rhs.getDefiningOp<vector::VectorScaleOp>())
return getConstantIntValue(lhs);
return {};
}

namespace {

/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
Expand Down Expand Up @@ -5862,73 +5887,46 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {

LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
VectorType retTy = createMaskOp.getResult().getType();
bool isScalable = retTy.isScalable();

// Check every mask operand
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto cst = getConstantIntValue(operand)) {
// Most basic case - this operand is a constant value. Note that for
// scalable dimensions, CreateMaskOp can be folded only if the
// corresponding operand is negative or zero.
if (retTy.getScalableDims()[opIdx] && *cst > 0)
return failure();

continue;
}

// Non-constant operands are not allowed for non-scalable vectors.
if (!isScalable)
return failure();

// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
// true" mask, so can also be treated as constant.
auto mul = operand.getDefiningOp<arith::MulIOp>();
if (!mul)
return failure();
auto mulLHS = mul.getRhs();
auto mulRHS = mul.getLhs();
bool isOneOpVscale =
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));

auto isConstantValMatchingDim =
[=, dim = retTy.getShape()[opIdx]](Value operand) {
auto constantVal = getConstantIntValue(operand);
return (constantVal.has_value() && constantVal.value() == dim);
};

bool isOneOpConstantMatchingDim =
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);

if (!isOneOpVscale || !isOneOpConstantMatchingDim)
return failure();
VectorType maskType = createMaskOp.getVectorType();
ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

// Special case: Rank zero shape.
constexpr std::array<int64_t, 1> rankZeroShape{1};
constexpr std::array<bool, 1> rankZeroScalableDims{false};
if (maskType.getRank() == 0) {
maskTypeDimSizes = rankZeroShape;
maskTypeDimScalableFlags = rankZeroScalableDims;
}

// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
std::optional dimSize = getConstantIntValue(operand);
if (!dimSize) {
// Although not a constant, it is safe to assume that `operand` is
// "vscale * maxDimSize".
maskDimSizes.push_back(maxDimSize);
continue;
}
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
break;
SmallVector<int64_t, 4> constantDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Non scalable dims can have any value. Scalable dims can only be zero.
if (intSize >= 0 && maskTypeDimScalableFlags[i])
return failure();
constantDims.push_back(*intSize);
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Scalable dims must be all-true.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
constantDims.push_back(*vscaleMultiplier);
} else {
return failure();
}
maskDimSizes.push_back(dimSizeVal);
}

// Clamp values to constant_mask bounds.
for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
value = std::clamp<int64_t>(value, 0, maskDimSize);

// If one of dim sizes is zero, set all dims to zero.
if (llvm::is_contained(constantDims, 0))
constantDims.assign(constantDims.size(), 0);

// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
maskDimSizes);
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
constantDims);
return success();
}
};
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp
VectorMaskElimination.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
Expand Down
114 changes: 114 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
using namespace mlir::vector;
namespace {

/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
/// All-true masks can then be eliminated by simple folds.
LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
vector::CreateMaskOp createMaskOp,
VscaleRange vscaleRange) {
auto maskType = createMaskOp.getVectorType();
auto maskTypeDimScalableFlags = maskType.getScalableDims();
auto maskTypeDimSizes = maskType.getShape();

struct UnknownMaskDim {
size_t position;
Value dimSize;
};

// Check for any dims that could be (partially) false before doing the more
// expensive value bounds computations.
SmallVector<UnknownMaskDim> unknownDims;
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
if (auto intSize = getConstantIntValue(dimSize)) {
// Mask not all-true for this dim.
if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
return failure();
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
// Mask not all-true for this dim.
if (vscaleMultiplier < maskTypeDimSizes[i])
return failure();
} else {
// Unknown (without further analysis).
unknownDims.push_back(UnknownMaskDim{i, dimSize});
}
}

for (auto [i, dimSize] : unknownDims) {
// Compute the lower bound for the unknown dimension (i.e. the smallest
// value it could be).
FailureOr<ConstantOrScalableBound> dimLowerBound =
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
presburger::BoundType::LB);
if (failed(dimLowerBound))
return failure();
auto dimLowerBoundSize = dimLowerBound->getSize();
if (failed(dimLowerBoundSize))
return failure();
if (dimLowerBoundSize->scalable) {
// 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
// this dim is not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
} else {
// 2. The lower bound, LB, is a constant.
// - If the mask dim size is scalable then this dim is not all-true.
if (maskTypeDimScalableFlags[i])
return failure();
// - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
}
}

// Replace createMaskOp with an all-true constant. This should result in the
// mask being removed in most cases (as xfer ops + vector.mask have folds to
// remove all-true masks).
auto allTrue = rewriter.create<vector::ConstantMaskOp>(
createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
rewriter.replaceAllUsesWith(createMaskOp, allTrue);
return success();
}

} // namespace

namespace mlir::vector {

void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
std::optional<VscaleRange> vscaleRange) {
// TODO: Support fixed-size case. This is less likely to be useful as for
// fixed-size code dimensions are all static so masks tend to fold away.
if (!vscaleRange)
return;

OpBuilder::InsertionGuard g(rewriter);

// Build worklist so we can safely insert new ops in
// `resolveAllTrueCreateMaskOp()`.
SmallVector<vector::CreateMaskOp> worklist;
function.walk([&](vector::CreateMaskOp createMaskOp) {
worklist.push_back(createMaskOp);
});

rewriter.setInsertionPointToStart(&function.front());
for (auto mask : worklist)
(void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
}

} // namespace mlir::vector
Loading