diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index 4cc0021ee287..7913a0ccebac 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -141,6 +141,31 @@ struct SimplifySelect : public OpRewritePattern { } }; +struct SimplifyVecSplat : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(VecSplatOp op, + PatternRewriter &rewriter) const override { + mlir::Value splatValue = op.getValue(); + auto constant = + mlir::dyn_cast_if_present(splatValue.getDefiningOp()); + if (!constant) + return mlir::failure(); + + auto value = constant.getValue(); + if (!mlir::isa_and_nonnull(value) && + !mlir::isa_and_nonnull(value)) + return mlir::failure(); + + cir::VectorType resultType = op.getResult().getType(); + SmallVector elements(resultType.getSize(), value); + auto constVecAttr = cir::ConstVectorAttr::get( + resultType, mlir::ArrayAttr::get(getContext(), elements)); + + rewriter.replaceOpWithNewOp(op, constVecAttr); + return mlir::success(); + } +}; + //===----------------------------------------------------------------------===// // CIRSimplifyPass //===----------------------------------------------------------------------===// @@ -155,7 +180,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) { // clang-format off patterns.add< SimplifyTernary, - SimplifySelect + SimplifySelect, + SimplifyVecSplat >(patterns.getContext()); // clang-format on } @@ -168,7 +194,7 @@ void CIRSimplifyPass::runOnOperation() { // Collect operations to apply patterns. llvm::SmallVector ops; getOperation()->walk([&](Operation *op) { - if (isa(op)) + if (isa(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/Transforms/vector-splat.cir b/clang/test/CIR/Transforms/vector-splat.cir new file mode 100644 index 000000000000..76195c8a289e --- /dev/null +++ b/clang/test/CIR/Transforms/vector-splat.cir @@ -0,0 +1,16 @@ +// RUN: cir-opt %s -cir-simplify -o - | FileCheck %s + +!s32i = !cir.int + +module { + cir.func @fold_splat_vector_op_test() -> !cir.vector { + %v = cir.const #cir.int<3> : !s32i + %vec = cir.vec.splat %v : !s32i, !cir.vector + cir.return %vec : !cir.vector + } + + // CHECK: cir.func @fold_splat_vector_op_test() -> !cir.vector { + // CHECK-NEXT: %0 = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i, + // CHECK-SAME: #cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector + // CHECK-NEXT: cir.return %0 : !cir.vector +}