Skip to content

Commit 85815e7

Browse files
authored
Add an entry point wrapper around functions (llvm pass) (#1149)
SPIR-V spec states: "It is invalid for any function to be targeted by both an OpEntryPoint instruction and an OpFunctionCall instruction." In order to satisfy SPIR-V that entrypoints and functions must be different, this introduces an entrypoint wrapper around functions at the LLVM IR level, then fixes up a few things like naming at the SPIRV translation.
1 parent 2db19de commit 85815e7

25 files changed

Lines changed: 155 additions & 36 deletions

lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ const static char TranslateOCLMemScope[] = "__translate_ocl_memory_scope";
377377
const static char TranslateSPIRVMemOrder[] = "__translate_spirv_memory_order";
378378
const static char TranslateSPIRVMemScope[] = "__translate_spirv_memory_scope";
379379
const static char TranslateSPIRVMemFence[] = "__translate_spirv_memory_fence";
380+
const static char EntrypointPrefix[] = "__spirv_entry_";
380381
} // namespace kSPIRVName
381382

382383
namespace kSPIRVPostfix {

lib/SPIRV/SPIRVReader.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,6 +2768,24 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
27682768
return Loc->second;
27692769

27702770
auto IsKernel = isKernel(BF);
2771+
2772+
if (IsKernel) {
2773+
// search for a previous function with the same name
2774+
// upgrade it to a kernel and drop this if it's found
2775+
for (auto &I : FuncMap) {
2776+
auto BFName = I.getFirst()->getName();
2777+
if (BF->getName() == BFName) {
2778+
auto *F = I.getSecond();
2779+
F->setCallingConv(CallingConv::SPIR_KERNEL);
2780+
F->setLinkage(GlobalValue::ExternalLinkage);
2781+
F->setDSOLocal(false);
2782+
F = cast<Function>(mapValue(BF, F));
2783+
mapFunction(BF, F);
2784+
return F;
2785+
}
2786+
}
2787+
}
2788+
27712789
auto Linkage = IsKernel ? GlobalValue::ExternalLinkage : transLinkageType(BF);
27722790
FunctionType *FT = dyn_cast<FunctionType>(transType(BF->getFunctionType()));
27732791
std::string FuncName = BF->getName();

lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
#include "OCLUtil.h"
4141
#include "SPIRVInternal.h"
42+
#include "SPIRVMDWalker.h"
4243
#include "libSPIRV/SPIRVDebug.h"
4344

4445
#include "llvm/ADT/StringExtras.h" // llvm::isDigit
@@ -72,6 +73,11 @@ class SPIRVRegularizeLLVMBase {
7273
// Lower functions
7374
bool regularize();
7475

76+
// SPIR-V disallows functions being entrypoints and called
77+
// LLVM doesn't. This adds a wrapper around the entry point
78+
// that later SPIR-V writer renames.
79+
void addKernelEntryPoint(Module *M);
80+
7581
/// Erase cast inst of function and replace with the function.
7682
/// Assuming F is a SPIR-V builtin function with op code \param OC.
7783
void lowerFuncPtr(Function *F, Op OC);
@@ -437,6 +443,7 @@ bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
437443
bool SPIRVRegularizeLLVMBase::regularize() {
438444
eraseUselessFunctions(M);
439445
lowerFuncPtr(M);
446+
addKernelEntryPoint(M);
440447

441448
for (auto I = M->begin(), E = M->end(); I != E;) {
442449
Function *F = &(*I++);
@@ -605,6 +612,69 @@ void SPIRVRegularizeLLVMBase::lowerFuncPtr(Module *M) {
605612
lowerFuncPtr(I.first, I.second);
606613
}
607614

615+
void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) {
616+
std::vector<Function *> Work;
617+
618+
// Get a list of all functions that have SPIR kernel calling conv
619+
for (auto &F : *M) {
620+
if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
621+
Work.push_back(&F);
622+
}
623+
for (auto &F : Work) {
624+
// for declarations just make them into SPIR functions.
625+
F->setCallingConv(CallingConv::SPIR_FUNC);
626+
if (F->isDeclaration())
627+
continue;
628+
629+
// Otherwise add a wrapper around the function to act as an entry point.
630+
FunctionType *FType = F->getFunctionType();
631+
std::string WrapName =
632+
kSPIRVName::EntrypointPrefix + static_cast<std::string>(F->getName());
633+
Function *WrapFn =
634+
getOrCreateFunction(M, F->getReturnType(), FType->params(), WrapName);
635+
636+
auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn);
637+
IRBuilder<> Builder(CallBB);
638+
639+
Function::arg_iterator DestI = WrapFn->arg_begin();
640+
for (const Argument &I : F->args()) {
641+
DestI->setName(I.getName());
642+
DestI++;
643+
}
644+
SmallVector<Value *, 1> Args;
645+
for (Argument &I : WrapFn->args()) {
646+
Args.emplace_back(&I);
647+
}
648+
auto *CI = CallInst::Create(F, ArrayRef<Value *>(Args), "", CallBB);
649+
CI->setCallingConv(F->getCallingConv());
650+
CI->setAttributes(F->getAttributes());
651+
652+
// copy over all the metadata (should it be removed from F?)
653+
SmallVector<std::pair<unsigned, MDNode *>> MDs;
654+
F->getAllMetadata(MDs);
655+
WrapFn->setAttributes(F->getAttributes());
656+
for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) {
657+
WrapFn->addMetadata(MD->first, *MD->second);
658+
}
659+
WrapFn->setCallingConv(CallingConv::SPIR_KERNEL);
660+
WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage);
661+
662+
Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI);
663+
664+
// Have to find the spir-v metadata for execution mode and transfer it to
665+
// the wrapper.
666+
if (auto NMD = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::ExecutionMode)) {
667+
while (!NMD.atEnd()) {
668+
Function *MDF = nullptr;
669+
auto N = NMD.nextOp(); /* execution mode MDNode */
670+
N.get(MDF);
671+
if (MDF == F)
672+
N.M->replaceOperandWith(0, ValueAsMetadata::get(WrapFn));
673+
}
674+
}
675+
}
676+
}
677+
608678
} // namespace SPIRV
609679

610680
INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular",

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,8 +638,15 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
638638
SPIRVFunction *BF =
639639
static_cast<SPIRVFunction *>(mapValue(F, BM->addFunction(BFT)));
640640
BF->setFunctionControlMask(transFunctionControlMask(F));
641-
if (F->hasName())
642-
BM->setName(BF, F->getName().str());
641+
if (F->hasName()) {
642+
if (isKernel(F)) {
643+
/* strip the prefix as the runtime will be looking for this name */
644+
std::string Prefix = kSPIRVName::EntrypointPrefix;
645+
std::string Name = F->getName().str();
646+
BM->setName(BF, Name.substr(Prefix.size()));
647+
} else
648+
BM->setName(BF, F->getName().str());
649+
}
643650
if (!isKernel(F) && F->getLinkage() != GlobalValue::InternalLinkage)
644651
BF->setLinkageType(transLinkageType(F));
645652

@@ -3735,7 +3742,7 @@ void LLVMToSPIRVBase::transFunction(Function *I) {
37353742

37363743
if (isKernel(I)) {
37373744
auto Interface = collectEntryPointInterfaces(BF, I);
3738-
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), I->getName().str(),
3745+
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), BF->getName(),
37393746
Interface);
37403747
}
37413748
}
@@ -4064,8 +4071,9 @@ bool LLVMToSPIRVBase::transMetadata() {
40644071
// Work around to translate kernel_arg_type and kernel_arg_type_qual metadata
40654072
static void transKernelArgTypeMD(SPIRVModule *BM, Function *F, MDNode *MD,
40664073
std::string MDName) {
4067-
std::string KernelArgTypesMDStr =
4068-
std::string(MDName) + "." + F->getName().str() + ".";
4074+
std::string Prefix = kSPIRVName::EntrypointPrefix;
4075+
std::string Name = F->getName().str().substr(Prefix.size());
4076+
std::string KernelArgTypesMDStr = std::string(MDName) + "." + Name + ".";
40694077
for (const auto &TyOp : MD->operands())
40704078
KernelArgTypesMDStr += cast<MDString>(TyOp)->getString().str() + ",";
40714079
BM->getString(KernelArgTypesMDStr);

test/entry_point_func.ll

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
;; Test to check that an LLVM spir_kernel gets translated into an
2+
;; Entrypoint wrapper and Function with LinkageAttributes
3+
; RUN: llvm-as %s -o %t.bc
4+
; RUN: llvm-spirv %t.bc -o - -spirv-text | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv %t.bc -o %t.spv
6+
; RUN: spirv-val %t.spv
7+
8+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir64-unknown-unknown"
10+
11+
define spir_kernel void @testfunction() {
12+
ret void
13+
}
14+
15+
; Check there is an entrypoint and a function produced.
16+
; CHECK-SPIRV: EntryPoint 6 [[EP:[0-9]+]] "testfunction"
17+
; CHECK-SPIRV: Name [[FUNC:[0-9]+]] "testfunction"
18+
; CHECK-SPIRV: Decorate [[FUNC]] LinkageAttributes "testfunction" Export
19+
; CHECK-SPIRV: Function 2 [[FUNC]] 0 3
20+
; CHECK-SPIRV: Function 2 [[EP]] 0 3
21+
; CHECK-SPIRV: FunctionCall 2 8 [[FUNC]]

test/mem2reg.cl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
// RUN: %clang_cc1 -O0 -S -triple spir-unknown-unknown -cl-std=CL2.0 -x cl -disable-O0-optnone %s -emit-llvm-bc -o %t.bc
22
// RUN: llvm-spirv -s %t.bc
3-
// RUN: llvm-dis < %t.bc | FileCheck %s --check-prefixes=CHECK,CHECK-WO
3+
// RUN: llvm-dis < %t.bc | FileCheck %s --check-prefixes=CHECK-WO
44
// RUN: llvm-spirv -s -spirv-mem2reg %t.bc -o %t.opt.bc
5-
// RUN: llvm-dis < %t.opt.bc | FileCheck %s --check-prefixes=CHECK,CHECK-W
6-
// CHECK-LABEL: spir_kernel void @foo
5+
// RUN: llvm-dis < %t.opt.bc | FileCheck %s --check-prefixes=CHECK-W
6+
// CHECK-W-LABEL: spir_func void @foo
77
// CHECK-W-NOT: alloca i32
8+
// CHECK-WO-LABEL: spir_kernel void @foo
89
// CHECK-WO: alloca i32
910
__kernel void foo(__global int *a) {
1011
*a = *a + 1;

test/transcoding/FPGAUnstructuredLoopAttr.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
; CHECK-SPIRV: 2 Capability FPGALoopControlsINTEL
1010
; CHECK-SPIRV: 9 Extension "SPV_INTEL_fpga_loop_controls"
1111
; CHECK-SPIRV: 11 Extension "SPV_INTEL_unstructured_loop_controls"
12-
; CHECK-SPIRV: 4 EntryPoint 6 [[FOO:[0-9]+]] "foo"
13-
; CHECK-SPIRV: 4 EntryPoint 6 [[BOO:[0-9]+]] "boo"
12+
; CHECK-SPIRV: 3 Name [[FOO:[0-9]+]] "foo"
1413
; CHECK-SPIRV: 4 Name [[ENTRY_1:[0-9]+]] "entry"
1514
; CHECK-SPIRV: 5 Name [[FOR:[0-9]+]] "for.cond"
15+
; CHECK-SPIRV: 3 Name [[BOO:[0-9]+]] "boo"
1616
; CHECK-SPIRV: 4 Name [[ENTRY_2:[0-9]+]] "entry"
1717
; CHECK-SPIRV: 5 Name [[WHILE:[0-9]+]] "while.body"
1818

test/transcoding/KernelArgTypeInOpString.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
4040
target triple = "spir-unknown-unknown"
4141

42-
; CHECK-SPIRV-WORKAROUND: String 14 "kernel_arg_type.foo.image_kernel_data*,myInt,struct struct_name*,"
43-
; CHECK-SPIRV-WORKAROUND-NEGATIVE-NOT: String 14 "kernel_arg_type.foo.image_kernel_data*,myInt,struct struct_name*,"
42+
; CHECK-SPIRV-WORKAROUND: String 20 "kernel_arg_type.foo.image_kernel_data*,myInt,struct struct_name*,"
43+
; CHECK-SPIRV-WORKAROUND-NEGATIVE-NOT: String 20 "kernel_arg_type.foo.image_kernel_data*,myInt,struct struct_name*,"
4444

4545
; CHECK-LLVM-WORKAROUND: !kernel_arg_type [[TYPE:![0-9]+]]
4646
; CHECK-LLVM-WORKAROUND: [[TYPE]] = !{!"image_kernel_data*", !"myInt", !"struct struct_name*"}

test/transcoding/KernelArgTypeInOpString2.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
4242
target triple = "spir"
4343

44-
; CHECK-SPIRV-WORKAROUND: String 17 "kernel_arg_type.foo.cl::tt::vec<float, 4>*,"
45-
; CHECK-SPIRV-WORKAROUND-NEGATIVE-NOT: String 17 "kernel_arg_type.foo.cl::tt::vec<float, 4>*,"
44+
; CHECK-SPIRV-WORKAROUND: String 21 "kernel_arg_type.foo.cl::tt::vec<float, 4>*,"
45+
; CHECK-SPIRV-WORKAROUND-NEGATIVE-NOT: String 21 "kernel_arg_type.foo.cl::tt::vec<float, 4>*,"
4646

4747
; CHECK-LLVM-WORKAROUND: !kernel_arg_type [[TYPE:![0-9]+]]
4848
; CHECK-LLVM-WORKAROUND: [[TYPE]] = !{!"cl::tt::vec<float, 4>*"}

test/transcoding/OpenCL/atomic_cmpxchg.cl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ __kernel void test_atomic_cmpxchg(__global int *p, int cmp, int val) {
1717
atomic_cmpxchg(up, ucmp, uval);
1818
}
1919

20-
// CHECK-SPIRV: EntryPoint {{[0-9]+}} [[TEST:[0-9]+]] "test_atomic_cmpxchg"
20+
// CHECK-SPIRV: Name [[TEST:[0-9]+]] "test_atomic_cmpxchg"
2121
// CHECK-SPIRV-DAG: TypeInt [[UINT:[0-9]+]] 32 0
2222
// CHECK-SPIRV-DAG: TypePointer [[UINT_PTR:[0-9]+]] 5 [[UINT]]
2323
//

0 commit comments

Comments
 (0)