1414#include " clang/AST/RecordLayout.h"
1515#include " clang/AST/RecursiveASTVisitor.h"
1616#include " clang/Sema/Sema.h"
17- #include " llvm/ADT/SmallVector.h"
1817#include " llvm/ADT/SmallPtrSet.h"
18+ #include " llvm/ADT/SmallVector.h"
1919#include " llvm/Support/FileSystem.h"
2020#include " llvm/Support/Path.h"
2121#include " llvm/Support/raw_ostream.h"
@@ -154,14 +154,28 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
154154 S.Context , NestedNameSpecifierLoc (), SourceLocation (), LambdaVD, false ,
155155 DeclarationNameInfo (), QualType (LC->getTypeForDecl (), 0 ), VK_LValue);
156156
157- // Initialize Lambda fields
158- llvm::SmallVector<Expr *, 16 > InitCaptures;
159-
160157 auto TargetFunc = dyn_cast<FunctionDecl>(DC);
161158 auto TargetFuncParam =
162159 TargetFunc->param_begin (); // Iterator to ParamVarDecl (VarDecl)
163160 if (TargetFuncParam) {
164161 for (auto Field : LC->fields ()) {
162+ auto getExprForPointer = [](Sema &S, const QualType ¶mTy,
163+ DeclRefExpr *DRE) {
164+ // C++ address space attribute != OpenCL address space attribute
165+ Expr *qualifiersCast = ImplicitCastExpr::Create (
166+ S.Context , paramTy, CK_NoOp, DRE, nullptr , VK_LValue);
167+ Expr *Res =
168+ ImplicitCastExpr::Create (S.Context , paramTy, CK_LValueToRValue,
169+ qualifiersCast, nullptr , VK_RValue);
170+ return Res;
171+ };
172+ auto getExprForRange = [](Sema &S, const QualType ¶mTy,
173+ DeclRefExpr *DRE) {
174+ Expr *Res = ImplicitCastExpr::Create (S.Context , paramTy, CK_NoOp, DRE,
175+ nullptr , VK_RValue);
176+ return Res;
177+ };
178+
165179 QualType ParamType = (*TargetFuncParam)->getOriginalType ();
166180 auto DRE =
167181 DeclRefExpr::Create (S.Context , NestedNameSpecifierLoc (),
@@ -171,18 +185,20 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
171185 QualType FieldType = Field->getType ();
172186 CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl ();
173187 if (CRD) {
174- llvm::SmallVector<Expr *, 16 > ParamStmts;
175188 DeclAccessPair FieldDAP = DeclAccessPair::make (Field, AS_none);
189+ // lambda.accessor
176190 auto AccessorME = MemberExpr::Create (
177191 S.Context , LambdaDRE, false , SourceLocation (),
178192 NestedNameSpecifierLoc (), SourceLocation (), Field, FieldDAP,
179193 DeclarationNameInfo (Field->getDeclName (), SourceLocation ()),
180194 nullptr , Field->getType (), VK_LValue, OK_Ordinary);
181-
195+ bool PointerOfAccesorWasSet = false ;
182196 for (auto Method : CRD->methods ()) {
197+ llvm::SmallVector<Expr *, 16 > ParamStmts;
183198 if (Method->getNameInfo ().getName ().getAsString () ==
184199 " __set_pointer" ) {
185200 DeclAccessPair MethodDAP = DeclAccessPair::make (Method, AS_none);
201+ // lambda.accessor.__set_pointer
186202 auto ME = MemberExpr::Create (
187203 S.Context , AccessorME, false , SourceLocation (),
188204 NestedNameSpecifierLoc (), SourceLocation (), Method, MethodDAP,
@@ -199,19 +215,75 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
199215 // __set_pointer needs one parameter
200216 QualType paramTy = (*(Method->param_begin ()))->getOriginalType ();
201217
202- // C++ address space attribute != OpenCL address space attribute
203- Expr *qualifiersCast = ImplicitCastExpr::Create (
204- S.Context , paramTy, CK_NoOp, DRE, nullptr , VK_LValue);
205- Expr *Res = ImplicitCastExpr::Create (
206- S.Context , paramTy, CK_LValueToRValue, qualifiersCast,
207- nullptr , VK_RValue);
218+ Expr *Res = getExprForPointer (S, paramTy, DRE);
208219
220+ // kernel_parameter
209221 ParamStmts.push_back (Res);
210-
211222 // lambda.accessor.__set_pointer(kernel_parameter)
212223 CXXMemberCallExpr *Call = CXXMemberCallExpr::Create (
213224 S.Context , ME, ParamStmts, ResultTy, VK, SourceLocation ());
214225 BodyStmts.push_back (Call);
226+ PointerOfAccesorWasSet = true ;
227+ }
228+ }
229+ if (PointerOfAccesorWasSet) {
230+ TargetFuncParam++;
231+
232+ ParamType = (*TargetFuncParam)->getOriginalType ();
233+ DRE = DeclRefExpr::Create (S.Context , NestedNameSpecifierLoc (),
234+ SourceLocation (), *TargetFuncParam, false ,
235+ DeclarationNameInfo (), ParamType,
236+ VK_LValue);
237+
238+ FieldType = Field->getType ();
239+ CRD = FieldType->getAsCXXRecordDecl ();
240+ if (CRD) {
241+ FieldDAP = DeclAccessPair::make (Field, AS_none);
242+ // lambda.accessor
243+ AccessorME = MemberExpr::Create (
244+ S.Context , LambdaDRE, false , SourceLocation (),
245+ NestedNameSpecifierLoc (), SourceLocation (), Field, FieldDAP,
246+ DeclarationNameInfo (Field->getDeclName (), SourceLocation ()),
247+ nullptr , Field->getType (), VK_LValue, OK_Ordinary);
248+
249+ for (auto Method : CRD->methods ()) {
250+ llvm::SmallVector<Expr *, 16 > ParamStmts;
251+ if (Method->getNameInfo ().getName ().getAsString () ==
252+ " __set_range" ) {
253+ // lambda.accessor.__set_range
254+ DeclAccessPair MethodDAP =
255+ DeclAccessPair::make (Method, AS_none);
256+ auto ME = MemberExpr::Create (
257+ S.Context , AccessorME, false , SourceLocation (),
258+ NestedNameSpecifierLoc (), SourceLocation (), Method,
259+ MethodDAP, Method->getNameInfo (), nullptr ,
260+ Method->getType (), VK_LValue, OK_Ordinary);
261+
262+ // Not referenced -> not emitted
263+ S.MarkFunctionReferenced (SourceLocation (), Method, true );
264+
265+ QualType ResultTy = Method->getReturnType ();
266+ ExprValueKind VK = Expr::getValueKindForType (ResultTy);
267+ ResultTy = ResultTy.getNonLValueExprType (S.Context );
268+
269+ // __set_range needs one parameter
270+ QualType paramTy =
271+ (*(Method->param_begin ()))->getOriginalType ();
272+
273+ Expr *Res = getExprForRange (S, paramTy, DRE);
274+
275+ // kernel_parameter
276+ ParamStmts.push_back (Res);
277+ // lambda.accessor.__set_range(kernel_parameter)
278+ CXXMemberCallExpr *Call = CXXMemberCallExpr::Create (
279+ S.Context , ME, ParamStmts, ResultTy, VK,
280+ SourceLocation ());
281+ BodyStmts.push_back (Call);
282+ }
283+ }
284+ } else {
285+ llvm_unreachable (
286+ " unsupported accessor and without initialized range" );
215287 }
216288 }
217289 } else if (FieldType->isBuiltinType ()) {
@@ -279,6 +351,7 @@ class Util {
279351// / invocation.
280352enum VisitorContext {
281353 pre_visit,
354+ pre_visit_class_field,
282355 visit_accessor,
283356 visit_scalar,
284357 visit_stream,
@@ -308,7 +381,7 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
308381 QualType ArgTy = V->getType ();
309382 auto F1 = std::get<pre_visit>(Vis);
310383 F1 (Cnt, V, *Fld);
311-
384+ FieldDecl *AccessorRangeField = nullptr ;
312385 if (Util::isSyclAccessorType (ArgTy)) {
313386 // accessor parameter context
314387 const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
@@ -317,6 +390,26 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
317390 dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
318391 assert (TemplateDecl && " templated accessor type expected" );
319392
393+ auto getFieldByName = [](const CXXRecordDecl *RecordDecl,
394+ std::string Name) {
395+ FieldDecl *result = nullptr ;
396+ for (auto jt = RecordDecl->field_begin (); jt != RecordDecl->field_end ();
397+ ++jt) {
398+ if (jt->getNameAsString () == Name) {
399+ result = *jt;
400+ break ;
401+ }
402+ }
403+ return result;
404+ };
405+ FieldDecl *AccessorImplField = getFieldByName (RecordDecl, " __impl" );
406+ assert (AccessorImplField && " no __impl found in accessor" );
407+ const auto *AccessorImplRecord =
408+ AccessorImplField->getType ()->getAsCXXRecordDecl ();
409+ assert (AccessorImplRecord && " accessor __impl must be of a record type" );
410+ AccessorRangeField = getFieldByName (AccessorImplRecord, " Range" );
411+ assert (AccessorRangeField && " no Range found in __impl of accessor" );
412+
320413 // First accessor template parameter - data type
321414 QualType PointeeType = TemplateDecl->getTemplateArgs ()[0 ].getAsType ();
322415 // Fourth parameter - access target
@@ -335,9 +428,20 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
335428 } else {
336429 llvm_unreachable (" unsupported kernel parameter type" );
337430 }
338- // pos -visit context
431+ // post -visit context
339432 auto F2 = std::get<post_visit>(Vis);
340433 F2 (Cnt, V, *Fld);
434+
435+ if (AccessorRangeField) {
436+ // pre-visit context the same like for accessor
437+ auto F1Range = std::get<pre_visit_class_field>(Vis);
438+ F1Range (Cnt, V, *Fld, AccessorRangeField);
439+ auto FRange = std::get<visit_scalar>(Vis);
440+ FRange (Cnt, V, AccessorRangeField);
441+ // post-visit context
442+ auto F2Range = std::get<post_visit>(Vis);
443+ F2Range (Cnt, nullptr , AccessorRangeField);
444+ }
341445 }
342446 assert ((Cpt == CptEnd) && (Fld == FldEnd) &&
343447 " captures inconsistent with fields" );
@@ -350,6 +454,8 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
350454 auto Vis = std::make_tuple (
351455 // pre_visit
352456 [&](int , VarDecl *, FieldDecl *) {},
457+ // pre_visit_class_field
458+ [&](int , VarDecl *, FieldDecl *, FieldDecl *) {},
353459 // visit_accessor
354460 [&](int CaptureN, target AccTrg, QualType PointeeType,
355461 DeclaratorDecl *CapturedVar, FieldDecl *CapturedVal) {
@@ -390,7 +496,10 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
390496 IdentifierInfo *VarName = 0 ;
391497 SmallString<8 > Str;
392498 llvm::raw_svector_ostream OS (Str);
393- OS << " _arg_" << CapturedVar->getIdentifier ()->getName ();
499+ IdentifierInfo *Identifier = (CapturedVar != nullptr )
500+ ? CapturedVar->getIdentifier ()
501+ : CapturedVal->getIdentifier ();
502+ OS << " _arg_" << Identifier->getName ();
394503 VarName = &Context.Idents .get (OS.str ());
395504
396505 auto NewVarDecl = VarDecl::Create (
@@ -422,7 +531,24 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
422531 [&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
423532 // Set offset in bytes
424533 Offset = static_cast <unsigned >(
425- Layout.getFieldOffset (CapturedVal->getFieldIndex ()))/8 ;
534+ Layout.getFieldOffset (CapturedVal->getFieldIndex ())) /
535+ 8 ;
536+ },
537+ // pre_visit_class_field
538+ [&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal,
539+ FieldDecl *MemberVal) {
540+ // Set offset of parent in bytes
541+ Offset = static_cast <unsigned >(
542+ Layout.getFieldOffset (CapturedVal->getFieldIndex ())) /
543+ 8 ;
544+ const RecordDecl *parent = MemberVal->getParent ();
545+ ASTContext &CtxMember = parent->getASTContext ();
546+ const ASTRecordLayout &LayoutMember =
547+ CtxMember.getASTRecordLayout (parent);
548+ // Add offset relative to parent in bytes
549+ Offset += static_cast <unsigned >(
550+ LayoutMember.getFieldOffset (MemberVal->getFieldIndex ())) /
551+ 8 ;
426552 },
427553 // visit_accessor
428554 [&](int CaptureN, target AccTrg, QualType PointeeType,
@@ -453,15 +579,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
453579// are removed to make the name shorter. Non-alphanumeric characters in a kernel
454580// name are OK - SPIRV and runtimes allow that.
455581static std::string constructKernelName (QualType KernelNameType) {
456- static const std::string Kwds[] = {
457- std::string (" class" ),
458- std::string (" struct" )
459- };
582+ static const std::string Kwds[] = {std::string (" class" ),
583+ std::string (" struct" )};
460584 std::string TStr = KernelNameType.getAsString ();
461585
462586 for (const std::string &Kwd : Kwds) {
463- for (size_t Pos = TStr.find (Kwd);
464- Pos != StringRef::npos;
587+ for (size_t Pos = TStr.find (Kwd); Pos != StringRef::npos;
465588 Pos = TStr.find (Kwd, Pos)) {
466589
467590 size_t EndPos = Pos + Kwd.length ();
@@ -593,12 +716,13 @@ static void printDecl(raw_ostream &O, const Decl *D) {
593716// \param Depth
594717// recursion depth
595718//
596- static void emitForwardClassDecls (raw_ostream &O,
597- QualType T,
598- llvm::SmallPtrSetImpl<const void *> &Printed) {
719+ static void
720+ emitForwardClassDecls (raw_ostream &O, QualType T,
721+ llvm::SmallPtrSetImpl<const void *> &Printed) {
599722
600723 // peel off the pointer types and get the class/struct type:
601- for (; T->isPointerType (); T = T->getPointeeType ());
724+ for (; T->isPointerType (); T = T->getPointeeType ())
725+ ;
602726 const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
603727
604728 if (!RD)
@@ -657,7 +781,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
657781 O << " \n " ;
658782 O << " // Forward declarations of templated kernel function types:\n " ;
659783
660- llvm::SmallPtrSet<const void *, 4 > Printed;
784+ llvm::SmallPtrSet<const void *, 4 > Printed;
661785
662786 for (const KernelDesc &K : KernelDescs) {
663787 emitForwardClassDecls (O, K.NameType , Printed);
@@ -737,12 +861,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
737861
738862 for (const KernelDesc &K : KernelDescs) {
739863 const size_t N = K.Params .size ();
740- O << " template <> struct KernelInfo<" <<
741- K.NameType .getAsString () << " > {\n " ;
742- O << " static constexpr const char* getName() { return \" "
743- << K.Name << " \" ; }\n " ;
744- O << " static constexpr unsigned getNumParams() { return "
745- << N << " ; }\n " ;
864+ O << " template <> struct KernelInfo<" << K.NameType .getAsString ()
865+ << " > {\n " ;
866+ O << " static constexpr const char* getName() { return \" " << K.Name
867+ << " \" ; }\n " ;
868+ O << " static constexpr unsigned getNumParams() { return " << N << " ; }\n " ;
746869 O << " static constexpr const kernel_param_desc_t& " ;
747870 O << " getParamDesc(unsigned i) {\n " ;
748871 O << " return kernel_signatures[i+" << CurStart << " ];\n " ;
@@ -757,7 +880,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
757880 O << " \n " ;
758881}
759882
760-
761883bool SYCLIntegrationHeader::emit (const StringRef &IntHeaderName) {
762884 if (IntHeaderName.empty ())
763885 return false ;
0 commit comments