|
39 | 39 |
|
40 | 40 | #include "OCLUtil.h" |
41 | 41 | #include "SPIRVInternal.h" |
| 42 | +#include "SPIRVMDWalker.h" |
42 | 43 | #include "libSPIRV/SPIRVDebug.h" |
43 | 44 |
|
44 | 45 | #include "llvm/ADT/StringExtras.h" // llvm::isDigit |
@@ -72,6 +73,11 @@ class SPIRVRegularizeLLVMBase { |
72 | 73 | // Lower functions |
73 | 74 | bool regularize(); |
74 | 75 |
|
| 76 | + // SPIR-V disallows functions being entrypoints and called |
| 77 | + // LLVM doesn't. This adds a wrapper around the entry point |
| 78 | + // that later SPIR-V writer renames. |
| 79 | + void addKernelEntryPoint(Module *M); |
| 80 | + |
75 | 81 | /// Erase cast inst of function and replace with the function. |
76 | 82 | /// Assuming F is a SPIR-V builtin function with op code \param OC. |
77 | 83 | void lowerFuncPtr(Function *F, Op OC); |
@@ -437,6 +443,7 @@ bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) { |
437 | 443 | bool SPIRVRegularizeLLVMBase::regularize() { |
438 | 444 | eraseUselessFunctions(M); |
439 | 445 | lowerFuncPtr(M); |
| 446 | + addKernelEntryPoint(M); |
440 | 447 |
|
441 | 448 | for (auto I = M->begin(), E = M->end(); I != E;) { |
442 | 449 | Function *F = &(*I++); |
@@ -605,6 +612,69 @@ void SPIRVRegularizeLLVMBase::lowerFuncPtr(Module *M) { |
605 | 612 | lowerFuncPtr(I.first, I.second); |
606 | 613 | } |
607 | 614 |
|
| 615 | +void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) { |
| 616 | + std::vector<Function *> Work; |
| 617 | + |
| 618 | + // Get a list of all functions that have SPIR kernel calling conv |
| 619 | + for (auto &F : *M) { |
| 620 | + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) |
| 621 | + Work.push_back(&F); |
| 622 | + } |
| 623 | + for (auto &F : Work) { |
| 624 | + // for declarations just make them into SPIR functions. |
| 625 | + F->setCallingConv(CallingConv::SPIR_FUNC); |
| 626 | + if (F->isDeclaration()) |
| 627 | + continue; |
| 628 | + |
| 629 | + // Otherwise add a wrapper around the function to act as an entry point. |
| 630 | + FunctionType *FType = F->getFunctionType(); |
| 631 | + std::string WrapName = |
| 632 | + kSPIRVName::EntrypointPrefix + static_cast<std::string>(F->getName()); |
| 633 | + Function *WrapFn = |
| 634 | + getOrCreateFunction(M, F->getReturnType(), FType->params(), WrapName); |
| 635 | + |
| 636 | + auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn); |
| 637 | + IRBuilder<> Builder(CallBB); |
| 638 | + |
| 639 | + Function::arg_iterator DestI = WrapFn->arg_begin(); |
| 640 | + for (const Argument &I : F->args()) { |
| 641 | + DestI->setName(I.getName()); |
| 642 | + DestI++; |
| 643 | + } |
| 644 | + SmallVector<Value *, 1> Args; |
| 645 | + for (Argument &I : WrapFn->args()) { |
| 646 | + Args.emplace_back(&I); |
| 647 | + } |
| 648 | + auto *CI = CallInst::Create(F, ArrayRef<Value *>(Args), "", CallBB); |
| 649 | + CI->setCallingConv(F->getCallingConv()); |
| 650 | + CI->setAttributes(F->getAttributes()); |
| 651 | + |
| 652 | + // copy over all the metadata (should it be removed from F?) |
| 653 | + SmallVector<std::pair<unsigned, MDNode *>> MDs; |
| 654 | + F->getAllMetadata(MDs); |
| 655 | + WrapFn->setAttributes(F->getAttributes()); |
| 656 | + for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) { |
| 657 | + WrapFn->addMetadata(MD->first, *MD->second); |
| 658 | + } |
| 659 | + WrapFn->setCallingConv(CallingConv::SPIR_KERNEL); |
| 660 | + WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage); |
| 661 | + |
| 662 | + Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI); |
| 663 | + |
| 664 | + // Have to find the spir-v metadata for execution mode and transfer it to |
| 665 | + // the wrapper. |
| 666 | + if (auto NMD = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::ExecutionMode)) { |
| 667 | + while (!NMD.atEnd()) { |
| 668 | + Function *MDF = nullptr; |
| 669 | + auto N = NMD.nextOp(); /* execution mode MDNode */ |
| 670 | + N.get(MDF); |
| 671 | + if (MDF == F) |
| 672 | + N.M->replaceOperandWith(0, ValueAsMetadata::get(WrapFn)); |
| 673 | + } |
| 674 | + } |
| 675 | + } |
| 676 | +} |
| 677 | + |
608 | 678 | } // namespace SPIRV |
609 | 679 |
|
610 | 680 | INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular", |
|
0 commit comments