Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4738,6 +4738,7 @@ class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
}

def CeilOp : UnaryFPToFPBuiltinOp<"ceil", "FCeilOp">;
def ACosOp : UnaryFPToFPBuiltinOp<"acos", "ACosOp">;
def CosOp : UnaryFPToFPBuiltinOp<"cos", "CosOp">;
def ExpOp : UnaryFPToFPBuiltinOp<"exp", "ExpOp">;
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2", "Exp2Op">;
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
return RValue::get(result);
}
case Builtin::BI__builtin_elementwise_acos: {
return emitBuiltinWithOneOverloadedType<1>(E, "acos");
return emitUnaryMaybeConstrainedFPBuiltin<cir::ACosOp>(*this, *E);
}
case Builtin::BI__builtin_elementwise_asin:
llvm_unreachable("BI__builtin_elementwise_asin NYI");
Expand Down
65 changes: 39 additions & 26 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ struct ConvertCIRToMLIRPass
: public mlir::PassWrapper<ConvertCIRToMLIRPass,
mlir::OperationPass<mlir::ModuleOp>> {
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::BuiltinDialect, mlir::func::FuncDialect,
mlir::affine::AffineDialect, mlir::memref::MemRefDialect,
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::scf::SCFDialect, mlir::math::MathDialect,
mlir::vector::VectorDialect>();
registry.insert<mlir::LLVM::LLVMDialect, mlir::BuiltinDialect,
mlir::func::FuncDialect, mlir::affine::AffineDialect,
mlir::memref::MemRefDialect, mlir::arith::ArithDialect,
mlir::cf::ControlFlowDialect, mlir::scf::SCFDialect,
mlir::math::MathDialect, mlir::vector::VectorDialect>();
}
void runOnOperation() final;

Expand Down Expand Up @@ -279,6 +279,18 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
}
};

class CIRACosOpLowering : public mlir::OpConversionPattern<cir::ACosOp> {
public:
using OpConversionPattern<cir::ACosOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::ACosOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::ACosOp>(op, adaptor.getSrc());
return mlir::LogicalResult::success();
}
};

class CIRCosOpLowering : public mlir::OpConversionPattern<cir::CosOp> {
public:
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
Expand Down Expand Up @@ -1356,22 +1368,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns.add<
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
patterns.getContext());
patterns
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRACosOpLowering, CIRCosOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(
converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down Expand Up @@ -1453,11 +1466,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {

mlir::ConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
target
.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::math::MathDialect, mlir::vector::VectorDialect>();
target.addLegalDialect<mlir::LLVM::LLVMDialect, mlir::affine::AffineDialect,
mlir::arith::ArithDialect, mlir::memref::MemRefDialect,
mlir::func::FuncDialect, mlir::scf::SCFDialect,
mlir::cf::ControlFlowDialect, mlir::math::MathDialect,
mlir::vector::VectorDialect>();
target.addIllegalDialect<cir::CIRDialect>();

if (failed(applyPartialConversion(module, target, std::move(patterns))))
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/builtins-elementwise.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ void test_builtin_elementwise_acos(float f, double d, vfloat4 vf4,
vdouble4 vd4) {
// CIR-LABEL: test_builtin_elementwise_acos
// LLVM-LABEL: test_builtin_elementwise_acos
// CIR: {{%.*}} = cir.llvm.intrinsic "acos" {{%.*}} : (!cir.float) -> !cir.float
// CIR: {{%.*}} = cir.acos {{%.*}} : !cir.float
// LLVM: {{%.*}} = call float @llvm.acos.f32(float {{%.*}})
f = __builtin_elementwise_acos(f);

// CIR: {{%.*}} = cir.llvm.intrinsic "acos" {{%.*}} : (!cir.double) -> !cir.double
// CIR: {{%.*}} = cir.acos {{%.*}} : !cir.double
// LLVM: {{%.*}} = call double @llvm.acos.f64(double {{%.*}})
d = __builtin_elementwise_acos(d);

// CIR: {{%.*}} = cir.llvm.intrinsic "acos" {{%.*}} : (!cir.vector<!cir.float x 4>) -> !cir.vector<!cir.float x 4>
// CIR: {{%.*}} = cir.acos {{%.*}} : !cir.vector<!cir.float x 4>
// LLVM: {{%.*}} = call <4 x float> @llvm.acos.v4f32(<4 x float> {{%.*}})
vf4 = __builtin_elementwise_acos(vf4);

// CIR: {{%.*}} = cir.llvm.intrinsic "acos" {{%.*}} : (!cir.vector<!cir.double x 4>) -> !cir.vector<!cir.double x 4>
// CIR: {{%.*}} = cir.acos {{%.*}} : !cir.vector<!cir.double x 4>
// LLVM: {{%.*}} = call <4 x double> @llvm.acos.v4f64(<4 x double> {{%.*}})
vd4 = __builtin_elementwise_acos(vd4);
}
Expand Down
30 changes: 30 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/acos.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
// RUN: FileCheck %s --input-file %t.mlir

module {
cir.func @foo() {
%1 = cir.const #cir.fp<1.0> : !cir.float
%2 = cir.const #cir.fp<1.0> : !cir.double
%3 = cir.const #cir.fp<1.0> : !cir.long_double<!cir.f80>
%4 = cir.const #cir.fp<1.0> : !cir.long_double<!cir.double>
%5 = cir.acos %1 : !cir.float
%6 = cir.acos %2 : !cir.double
%7 = cir.acos %3 : !cir.long_double<!cir.f80>
%8 = cir.acos %4 : !cir.long_double<!cir.double>
cir.return
}
}

// CHECK: module {
// CHECK-NEXT: func.func @foo() {
// CHECK-NEXT: %[[C0:.+]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %[[C1:.+]] = arith.constant 1.000000e+00 : f64
// CHECK-NEXT: %[[C2:.+]] = arith.constant 1.000000e+00 : f80
// CHECK-NEXT: %[[C3:.+]] = arith.constant 1.000000e+00 : f64
// CHECK-NEXT: %{{.+}} = llvm.intr.acos(%[[C0]]) : (f32) -> f32
// CHECK-NEXT: %{{.+}} = llvm.intr.acos(%[[C1]]) : (f64) -> f64
// CHECK-NEXT: %{{.+}} = llvm.intr.acos(%[[C2]]) : (f80) -> f80
// CHECK-NEXT: %{{.+}} = llvm.intr.acos(%[[C3]]) : (f64) -> f64
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT: }