@@ -123,13 +123,16 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
123123 // / CUDA related
124124 // / ------------
125125
126- // Maps CUDA device stub name to kernel name .
127- llvm::DenseMap<llvm::StringRef, std::string > cudaKernelMap;
126+ // Maps CUDA kernel name to device stub function .
127+ llvm::StringMap<FuncOp > cudaKernelMap;
128128
129129 void buildCUDAModuleCtor ();
130130 void buildCUDAModuleDtor ();
131131 std::optional<FuncOp> buildCUDARegisterGlobals ();
132132
133+ void buildCUDARegisterGlobalFunctions (cir::CIRBaseBuilderTy &builder,
134+ FuncOp regGlobalFunc);
135+
133136 // /
134137 // / AST related
135138 // / -----------
@@ -185,6 +188,18 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
185188 // / List of annotations in the module
186189 llvm::SmallVector<mlir::Attribute, 4 > globalAnnotations;
187190};
191+
192+ std::string getCUDAPrefix (clang::ASTContext *astCtx) {
193+ if (astCtx->getLangOpts ().HIP )
194+ return " hip" ;
195+ return " cuda" ;
196+ }
197+
198+ std::string addUnderscoredPrefix (llvm::StringRef cudaPrefix,
199+ llvm::StringRef cudaFunctionName) {
200+ return (" __" + cudaPrefix + cudaFunctionName).str ();
201+ }
202+
188203} // namespace
189204
190205GlobalOp LoweringPreparePass::buildRuntimeVariable (
@@ -983,6 +998,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
983998 if (astCtx->getLangOpts ().GPURelocatableDeviceCode )
984999 llvm_unreachable (" NYI" );
9851000
1001+ // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
1002+ // if there's nothing to register.
1003+ if (cudaKernelMap.empty ())
1004+ return ;
1005+
9861006 // There's no device-side binary, so no need to proceed for CUDA.
9871007 // HIP has to create an external symbol in this case, which is NYI.
9881008 auto cudaBinaryHandleAttr =
@@ -995,18 +1015,14 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
9951015 std::string cudaGPUBinaryName =
9961016 cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr).getName ();
9971017
998- llvm::StringRef prefix = " cuda" ;
999-
10001018 constexpr unsigned cudaFatMagic = 0x466243b1 ;
10011019 constexpr unsigned hipFatMagic = 0x48495046 ; // "HIPF"
10021020
1021+ auto cudaPrefix = getCUDAPrefix (astCtx);
1022+
10031023 const unsigned fatMagic =
10041024 astCtx->getLangOpts ().HIP ? hipFatMagic : cudaFatMagic;
10051025
1006- auto addUnderscoredPrefix = [&](llvm::StringRef name) -> std::string {
1007- return (" __" + prefix + name).str ();
1008- };
1009-
10101026 // MAC OS X needs special care, but we haven't supported that in CIR yet.
10111027 assert (!cir::MissingFeatures::checkMacOSXTriple ());
10121028
@@ -1015,15 +1031,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10151031
10161032 mlir::Location loc = theModule.getLoc ();
10171033
1018- // Extract types from the module.
1019- auto typeSizesAttr = cast<TypeSizeInfoAttr>(
1020- theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1021-
10221034 auto voidTy = VoidType::get (&getContext ());
10231035 auto voidPtrTy = PointerType::get (voidTy);
10241036 auto voidPtrPtrTy = PointerType::get (voidPtrTy);
1025- auto intTy = typeSizesAttr. getIntType (&getContext ());
1026- auto charTy = typeSizesAttr. getCharType (&getContext ());
1037+ auto intTy = datalayout-> getIntType (&getContext ());
1038+ auto charTy = datalayout-> getCharType (&getContext ());
10271039
10281040 // Read the GPU binary and create a constant array for it.
10291041 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cudaGPUBinaryOrErr =
@@ -1046,7 +1058,7 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10461058
10471059 // OG gives an empty name to this global constant,
10481060 // which is not allowed in CIR.
1049- std::string fatbinStrName = addUnderscoredPrefix (" _fatbin_str" );
1061+ std::string fatbinStrName = addUnderscoredPrefix (cudaPrefix, " _fatbin_str" );
10501062 GlobalOp fatbinStr = builder.create <GlobalOp>(
10511063 loc, fatbinStrName, fatbinType, /* isConstant=*/ true ,
10521064 /* linkage=*/ cir::GlobalLinkageKind::PrivateLinkage);
@@ -1064,59 +1076,186 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
10641076 &getContext (), {intTy, intTy, voidPtrTy, voidPtrTy}, /* packed=*/ false ,
10651077 /* padded=*/ false , StructType::RecordKind::Struct);
10661078
1067- std::string fatbinWrapperName = addUnderscoredPrefix (" _fatbin_wrapper" );
1079+ std::string fatbinWrapperName =
1080+ addUnderscoredPrefix (cudaPrefix, " _fatbin_wrapper" );
10681081 GlobalOp fatbinWrapper = builder.create <GlobalOp>(
1069- loc, fatbinWrapperName, fatbinWrapperType, /* isConstant=*/ false ,
1082+ loc, fatbinWrapperName, fatbinWrapperType, /* isConstant=*/ true ,
10701083 /* linkage=*/ cir::GlobalLinkageKind::InternalLinkage);
10711084 fatbinWrapper.setPrivate ();
10721085 fatbinWrapper.setSection (fatbinSectionName);
10731086
10741087 auto magicInit = IntAttr::get (intTy, fatMagic);
10751088 auto versionInit = IntAttr::get (intTy, 1 );
1076- // `fatbinInit` is only a placeholder. The value will be initialized at the
1077- // beginning of module ctor.
1078- auto fatbinInit = builder. getConstNullPtrAttr (voidPtrTy);
1089+ auto fatbinStrSymbol =
1090+ mlir::FlatSymbolRefAttr::get (fatbinStr. getSymNameAttr ());
1091+ auto fatbinInit = GlobalViewAttr::get (voidPtrTy, fatbinStrSymbol );
10791092 auto unusedInit = builder.getConstNullPtrAttr (voidPtrTy);
10801093 fatbinWrapper.setInitialValueAttr (cir::ConstStructAttr::get (
10811094 fatbinWrapperType,
10821095 ArrayAttr::get (&getContext (),
10831096 {magicInit, versionInit, fatbinInit, unusedInit})));
10841097
1098+ // GPU fat binary handle is also a global variable in OG.
1099+ std::string gpubinHandleName =
1100+ addUnderscoredPrefix (cudaPrefix, " _gpubin_handle" );
1101+ auto gpubinHandle = builder.create <GlobalOp>(
1102+ loc, gpubinHandleName, voidPtrPtrTy,
1103+ /* isConstant=*/ false , /* linkage=*/ GlobalLinkageKind::InternalLinkage);
1104+ gpubinHandle.setInitialValueAttr (builder.getConstNullPtrAttr (voidPtrPtrTy));
1105+ gpubinHandle.setPrivate ();
1106+
10851107 // Declare this function:
10861108 // void **__{cuda|hip}RegisterFatBinary(void *);
10871109
1088- std::string regFuncName = addUnderscoredPrefix (" RegisterFatBinary" );
1110+ std::string regFuncName =
1111+ addUnderscoredPrefix (cudaPrefix, " RegisterFatBinary" );
10891112 auto regFuncType = FuncType::get ({voidPtrTy}, voidPtrPtrTy);
10901113 auto regFunc = buildRuntimeFunction (builder, regFuncName, loc, regFuncType);
10911114
10921115 // Create the module constructor.
10931116
1094- std::string moduleCtorName = addUnderscoredPrefix (" _module_ctor" );
1117+ std::string moduleCtorName = addUnderscoredPrefix (cudaPrefix, " _module_ctor" );
10951118 auto moduleCtor = buildRuntimeFunction (builder, moduleCtorName, loc,
10961119 FuncType::get ({}, voidTy),
10971120 GlobalLinkageKind::InternalLinkage);
10981121 globalCtorList.push_back (GlobalCtorAttr::get (&getContext (), moduleCtorName));
10991122 builder.setInsertionPointToStart (moduleCtor.addEntryBlock ());
11001123
1101- auto wrapper = builder.createGetGlobal (fatbinWrapper);
1102- // Put fatbinStr inside fatbinWrapper.
1103- mlir::Value fatbinStrValue = builder.createGetGlobal (fatbinStr);
1104- mlir::Value fatbinField = builder.createGetMemberOp (loc, wrapper, " " , 2 );
1105- builder.createStore (loc, fatbinStrValue, fatbinField);
1106-
11071124 // Register binary with CUDA runtime. This is substantially different in
11081125 // default mode vs. separate compilation.
11091126 // Corresponding code:
11101127 // gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
1128+ auto wrapper = builder.createGetGlobal (fatbinWrapper);
11111129 auto fatbinVoidPtr = builder.createBitcast (wrapper, voidPtrTy);
1112- builder.createCallOp (loc, regFunc, fatbinVoidPtr);
1130+ auto gpuBinaryHandleCall = builder.createCallOp (loc, regFunc, fatbinVoidPtr);
1131+ auto gpuBinaryHandle = gpuBinaryHandleCall.getResult ();
1132+ // Store the value back to the global `__cuda_gpubin_handle`.
1133+ auto gpuBinaryHandleGlobal = builder.createGetGlobal (gpubinHandle);
1134+ builder.createStore (loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
1135+
1136+ // Generate __cuda_register_globals and call it.
1137+ std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals ();
1138+ if (regGlobal) {
1139+ builder.createCallOp (loc, *regGlobal, gpuBinaryHandle);
1140+ }
11131141
1114- // This is currently incomplete.
1115- // TODO(cir): create __cuda_register_globals(), and call it here.
1142+ // From CUDA 10.1 onwards, we must call this function to end registration:
1143+ // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
1144+ // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
1145+ if (clang::CudaFeatureEnabled (
1146+ astCtx->getTargetInfo ().getSDKVersion (),
1147+ clang::CudaFeature::CUDA_USES_FATBIN_REGISTER_END)) {
1148+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1149+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1150+ FuncOp endFunc =
1151+ buildRuntimeFunction (globalBuilder, " __cudaRegisterFatBinaryEnd" , loc,
1152+ FuncType::get ({voidPtrPtrTy}, voidTy));
1153+ builder.createCallOp (loc, endFunc, gpuBinaryHandle);
1154+ }
11161155
11171156 builder.create <cir::ReturnOp>(loc);
11181157}
11191158
1159+ std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals () {
1160+ // There is nothing to register.
1161+ if (cudaKernelMap.empty ())
1162+ return {};
1163+
1164+ cir::CIRBaseBuilderTy builder (getContext ());
1165+ builder.setInsertionPointToStart (theModule.getBody ());
1166+
1167+ auto loc = theModule.getLoc ();
1168+ auto cudaPrefix = getCUDAPrefix (astCtx);
1169+
1170+ auto voidTy = VoidType::get (&getContext ());
1171+ auto voidPtrPtrTy = PointerType::get (PointerType::get (voidTy));
1172+
1173+ // Create the function:
1174+ // void __cuda_register_globals(void **fatbinHandle)
1175+ std::string regGlobalFuncName =
1176+ addUnderscoredPrefix (cudaPrefix, " _register_globals" );
1177+ auto regGlobalFuncTy = FuncType::get ({voidPtrPtrTy}, voidTy);
1178+ FuncOp regGlobalFunc =
1179+ buildRuntimeFunction (builder, regGlobalFuncName, loc, regGlobalFuncTy,
1180+ /* linkage=*/ GlobalLinkageKind::InternalLinkage);
1181+ builder.setInsertionPointToStart (regGlobalFunc.addEntryBlock ());
1182+
1183+ buildCUDARegisterGlobalFunctions (builder, regGlobalFunc);
1184+
1185+ // TODO(cir): registration for global variables.
1186+
1187+ builder.create <ReturnOp>(loc);
1188+ return regGlobalFunc;
1189+ }
1190+
1191+ void LoweringPreparePass::buildCUDARegisterGlobalFunctions (
1192+ cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
1193+ auto loc = theModule.getLoc ();
1194+ auto cudaPrefix = getCUDAPrefix (astCtx);
1195+
1196+ auto voidTy = VoidType::get (&getContext ());
1197+ auto voidPtrTy = PointerType::get (voidTy);
1198+ auto voidPtrPtrTy = PointerType::get (voidPtrTy);
1199+ auto intTy = datalayout->getIntType (&getContext ());
1200+ auto charTy = datalayout->getCharType (&getContext ());
1201+
1202+ // Extract the GPU binary handle argument.
1203+ mlir::Value fatbinHandle = *regGlobalFunc.args_begin ();
1204+
1205+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1206+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1207+
1208+ // Declare CUDA internal functions:
1209+ // int __cudaRegisterFunction(
1210+ // void **fatbinHandle,
1211+ // const char *hostFunc,
1212+ // char *deviceFunc,
1213+ // const char *deviceName,
1214+ // int threadLimit,
1215+ // uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
1216+ // int *wsize
1217+ // )
1218+ // OG doesn't care about the types at all. They're treated as void*.
1219+
1220+ FuncOp cudaRegisterFunction = buildRuntimeFunction (
1221+ globalBuilder, addUnderscoredPrefix (cudaPrefix, " RegisterFunction" ), loc,
1222+ FuncType::get ({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
1223+ voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
1224+ intTy));
1225+
1226+ auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
1227+ auto strType = ArrayType::get (&getContext (), charTy, 1 + str.size ());
1228+
1229+ auto tmpString = globalBuilder.create <GlobalOp>(
1230+ loc, (" .str" + str).str (), strType, /* isConstant=*/ true ,
1231+ /* linkage=*/ cir::GlobalLinkageKind::PrivateLinkage);
1232+
1233+ // We must make the string zero-terminated.
1234+ tmpString.setInitialValueAttr (ConstArrayAttr::get (
1235+ strType, StringAttr::get (&getContext (), str + " \0 " )));
1236+ tmpString.setPrivate ();
1237+ return tmpString;
1238+ };
1239+
1240+ auto cirNullPtr = builder.getNullPtr (voidPtrTy, loc);
1241+ for (auto kernelName : cudaKernelMap.keys ()) {
1242+ FuncOp deviceStub = cudaKernelMap[kernelName];
1243+ GlobalOp deviceFuncStr = makeConstantString (kernelName);
1244+ mlir::Value deviceFunc = builder.createBitcast (
1245+ builder.createGetGlobal (deviceFuncStr), voidPtrTy);
1246+ mlir::Value hostFunc = builder.createBitcast (
1247+ builder.create <GetGlobalOp>(
1248+ loc, PointerType::get (deviceStub.getFunctionType ()),
1249+ mlir::FlatSymbolRefAttr::get (deviceStub.getSymNameAttr ())),
1250+ voidPtrTy);
1251+ builder.createCallOp (
1252+ loc, cudaRegisterFunction,
1253+ {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
1254+ builder.create <ConstantOp>(loc, IntAttr::get (intTy, -1 )), cirNullPtr,
1255+ cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
1256+ }
1257+ }
1258+
11201259void LoweringPreparePass::lowerDynamicCastOp (DynamicCastOp op) {
11211260 CIRBaseBuilderTy builder (getContext ());
11221261 builder.setInsertionPointAfter (op);
@@ -1378,11 +1517,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
13781517 globalDtorList.push_back (globalDtor);
13791518 }
13801519 if (auto attr = fnOp.getExtraAttrs ().getElements ().get (
1381- CIRDialect::getCUDABinaryHandleAttrName ())) {
1382- auto cudaBinaryAttr = dyn_cast<CUDABinaryHandleAttr>(attr);
1383- std::string kernelName = cudaBinaryAttr.getName ();
1384- llvm::StringRef stubName = fnOp.getSymName ();
1385- cudaKernelMap[stubName] = kernelName;
1520+ CUDAKernelNameAttr::getMnemonic ())) {
1521+ auto cudaBinaryAttr = dyn_cast<CUDAKernelNameAttr>(attr);
1522+ std::string kernelName = cudaBinaryAttr.getKernelName ();
1523+ cudaKernelMap[kernelName] = fnOp;
13861524 }
13871525 if (std::optional<mlir::ArrayAttr> annotations = fnOp.getAnnotations ())
13881526 addGlobalAnnotations (fnOp, annotations.value ());
@@ -1399,6 +1537,9 @@ void LoweringPreparePass::runOnOperation() {
13991537 datalayout.emplace (theModule);
14001538 }
14011539
1540+ auto typeSizeInfo = cast<TypeSizeInfoAttr>(
1541+ theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1542+
14021543 llvm::SmallVector<Operation *> opsToTransform;
14031544
14041545 op->walk ([&](Operation *op) {
0 commit comments