Skip to content

Commit 5a4c5fe

Browse files
committed
[CIR][CUDA] Handle clang builtin functions
1 parent 79d0d74 commit 5a4c5fe

File tree

3 files changed

+133
-9
lines changed

3 files changed

+133
-9
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,46 @@ static bool isMemBuiltinOutOfBoundPossible(const clang::Expr *sizeArg,
467467
return size.ugt(dstSize);
468468
}
469469

470+
static mlir::Type
471+
decodeFixedType(ArrayRef<llvm::Intrinsic::IITDescriptor> &infos,
472+
mlir::MLIRContext *context) {
473+
using namespace llvm::Intrinsic;
474+
475+
IITDescriptor descriptor = infos.front();
476+
infos = infos.slice(1);
477+
478+
switch (descriptor.Kind) {
479+
case IITDescriptor::Void:
480+
return VoidType::get(context);
481+
case IITDescriptor::Integer:
482+
return IntType::get(context, descriptor.Integer_Width, /*signed=*/true);
483+
case IITDescriptor::Float:
484+
return SingleType::get(context);
485+
case IITDescriptor::Double:
486+
return DoubleType::get(context);
487+
default:
488+
llvm_unreachable("NYI");
489+
}
490+
}
491+
492+
// llvm::Intrinsics accepts only LLVMContext. We need to reimplement it here.
493+
static cir::FuncType getIntrinsicType(mlir::MLIRContext *context,
494+
llvm::Intrinsic::ID id) {
495+
using namespace llvm::Intrinsic;
496+
497+
SmallVector<IITDescriptor, 8> table;
498+
getIntrinsicInfoTableEntries(id, table);
499+
500+
ArrayRef<IITDescriptor> tableRef = table;
501+
mlir::Type resultTy = decodeFixedType(tableRef, context);
502+
503+
SmallVector<mlir::Type, 8> argTypes;
504+
while (!tableRef.empty())
505+
argTypes.push_back(decodeFixedType(tableRef, context));
506+
507+
return FuncType::get(argTypes, resultTy);
508+
}
509+
470510
RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
471511
const CallExpr *E,
472512
ReturnValueSlot ReturnValue) {
@@ -2525,25 +2565,46 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
25252565

25262566
// See if we have a target specific intrinsic.
25272567
std::string Name = getContext().BuiltinInfo.getName(BuiltinID);
2528-
Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic;
2568+
Intrinsic::ID intrinsicID = Intrinsic::not_intrinsic;
25292569
StringRef Prefix =
25302570
llvm::Triple::getArchTypePrefix(getTarget().getTriple().getArch());
25312571
if (!Prefix.empty()) {
2532-
IntrinsicID = Intrinsic::getIntrinsicForClangBuiltin(Prefix.data(), Name);
2572+
intrinsicID = Intrinsic::getIntrinsicForClangBuiltin(Prefix.data(), Name);
25332573
// NOTE we don't need to perform a compatibility flag check here since the
25342574
// intrinsics are declared in Builtins*.def via LANGBUILTIN which filter the
25352575
// MS builtins via ALL_MS_LANGUAGES and are filtered earlier.
2536-
if (IntrinsicID == Intrinsic::not_intrinsic)
2537-
IntrinsicID = Intrinsic::getIntrinsicForMSBuiltin(Prefix.data(), Name);
2576+
if (intrinsicID == Intrinsic::not_intrinsic)
2577+
intrinsicID = Intrinsic::getIntrinsicForMSBuiltin(Prefix.data(), Name);
25382578
}
25392579

2540-
if (IntrinsicID != Intrinsic::not_intrinsic) {
2580+
if (intrinsicID != Intrinsic::not_intrinsic) {
25412581
unsigned iceArguments = 0;
25422582
ASTContext::GetBuiltinTypeError error;
25432583
getContext().GetBuiltinType(BuiltinID, error, &iceArguments);
25442584
assert(error == ASTContext::GE_None && "Should not codegen an error");
2545-
if (iceArguments > 0)
2546-
llvm_unreachable("NYI");
2585+
2586+
llvm::StringRef name = llvm::Intrinsic::getName(intrinsicID);
2587+
// cir::LLVMIntrinsicCallOp expects intrinsic name to not have prefix
2588+
// "llvm." For example, `llvm.nvvm.barrier0` should be passed as
2589+
// `nvvm.barrier0`.
2590+
if (!name.consume_front("llvm."))
2591+
assert(false && "bad intrinsic name!");
2592+
2593+
cir::FuncType intrinsicType =
2594+
getIntrinsicType(&getMLIRContext(), intrinsicID);
2595+
2596+
SmallVector<mlir::Value> args;
2597+
for (unsigned i = 0; i < E->getNumArgs(); i++) {
2598+
mlir::Value arg = emitScalarExpr(E->getArg(i));
2599+
if (arg.getType() != intrinsicType.getInput(i))
2600+
llvm_unreachable("NYI");
2601+
2602+
args.push_back(arg);
2603+
}
2604+
auto intrinsicCall = builder.create<cir::LLVMIntrinsicCallOp>(
2605+
getLoc(E->getExprLoc()), builder.getStringAttr(name),
2606+
intrinsicType.getReturnType(), args);
2607+
return RValue::get(intrinsicCall.getResult());
25472608
}
25482609

25492610
// Some target-specific builtins can have aggregate return values, e.g.

clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,8 @@ mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
7575
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.z");
7676
case NVPTX::BI__nvvm_read_ptx_sreg_nctaid_w:
7777
return getIntrinsic("nvvm.read.ptx.sreg.nctaid.w");
78-
7978
default:
80-
llvm_unreachable("NYI");
79+
return nullptr;
8180
}
8281
}
8382

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "../Inputs/cuda.h"
2+
3+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
4+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
6+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
7+
8+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
9+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.ll
11+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
12+
13+
__device__ void builtins() {
14+
float f1, f2;
15+
double d1, d2;
16+
17+
// CIR: cir.llvm.intrinsic "nvvm.fmax.f" {{.*}} : (!cir.float, !cir.float) -> !cir.float
18+
// LLVM: call float @llvm.nvvm.fmax.f(float {{.*}}, float {{.*}})
19+
float t1 = __nvvm_fmax_f(f1, f2);
20+
// CIR: cir.llvm.intrinsic "nvvm.fmin.f" {{.*}} : (!cir.float, !cir.float) -> !cir.float
21+
// LLVM: call float @llvm.nvvm.fmin.f(float {{.*}}, float {{.*}})
22+
float t2 = __nvvm_fmin_f(f1, f2);
23+
// CIR: cir.llvm.intrinsic "nvvm.sqrt.rn.f" {{.*}} : (!cir.float) -> !cir.float
24+
// LLVM: call float @llvm.nvvm.sqrt.rn.f(float {{.*}})
25+
float t3 = __nvvm_sqrt_rn_f(f1);
26+
// CIR: cir.llvm.intrinsic "nvvm.rcp.rn.f" {{.*}} : (!cir.float) -> !cir.float
27+
// LLVM: call float @llvm.nvvm.rcp.rn.f(float {{.*}})
28+
float t4 = __nvvm_rcp_rn_f(f2);
29+
// CIR: cir.llvm.intrinsic "nvvm.add.rn.f" {{.*}} : (!cir.float, !cir.float) -> !cir.float
30+
// LLVM: call float @llvm.nvvm.add.rn.f(float {{.*}}, float {{.*}})
31+
float t5 = __nvvm_add_rn_f(f1, f2);
32+
33+
// CIR: cir.llvm.intrinsic "nvvm.fmax.d" {{.*}} : (!cir.double, !cir.double) -> !cir.double
34+
// LLVM: call double @llvm.nvvm.fmax.d(double {{.*}}, double {{.*}})
35+
double td1 = __nvvm_fmax_d(d1, d2);
36+
// CIR: cir.llvm.intrinsic "nvvm.fmin.d" {{.*}} : (!cir.double, !cir.double) -> !cir.double
37+
// LLVM: call double @llvm.nvvm.fmin.d(double {{.*}}, double {{.*}})
38+
double td2 = __nvvm_fmin_d(d1, d2);
39+
// CIR: cir.llvm.intrinsic "nvvm.sqrt.rn.d" {{.*}} : (!cir.double) -> !cir.double
40+
// LLVM: call double @llvm.nvvm.sqrt.rn.d(double {{.*}})
41+
double td3 = __nvvm_sqrt_rn_d(d1);
42+
// CIR: cir.llvm.intrinsic "nvvm.rcp.rn.d" {{.*}} : (!cir.double) -> !cir.double
43+
// LLVM: call double @llvm.nvvm.rcp.rn.d(double {{.*}})
44+
double td4 = __nvvm_rcp_rn_d(d2);
45+
46+
int i1, i2;
47+
48+
// CIR: cir.llvm.intrinsic "nvvm.mulhi.i" {{.*}} : (!s32i, !s32i) -> !s32i
49+
// LLVM: call i32 @llvm.nvvm.mulhi.i(i32 {{.*}}, i32 {{.*}})
50+
int ti1 = __nvvm_mulhi_i(i1, i2);
51+
52+
// CIR: cir.llvm.intrinsic "nvvm.membar.cta"
53+
// LLVM: call void @llvm.nvvm.membar.cta()
54+
__nvvm_membar_cta();
55+
// CIR: cir.llvm.intrinsic "nvvm.membar.gl"
56+
// LLVM: call void @llvm.nvvm.membar.gl()
57+
__nvvm_membar_gl();
58+
// CIR: cir.llvm.intrinsic "nvvm.membar.sys"
59+
// LLVM: call void @llvm.nvvm.membar.sys()
60+
__nvvm_membar_sys();
61+
// CIR: cir.llvm.intrinsic "nvvm.barrier0"
62+
// LLVM: call void @llvm.nvvm.barrier0()
63+
__syncthreads();
64+
}

0 commit comments

Comments
 (0)