diff --git a/clang/include/clang/Analysis/CallGraph.h b/clang/include/clang/Analysis/CallGraph.h index 999ac5da8acb6..a63ded97dd030 100644 --- a/clang/include/clang/Analysis/CallGraph.h +++ b/clang/include/clang/Analysis/CallGraph.h @@ -29,6 +29,7 @@ namespace clang { +class ASTContext; class CallGraphNode; class Decl; class DeclContext; @@ -51,6 +52,12 @@ class CallGraph : public RecursiveASTVisitor { /// This is a virtual root node that has edges to all the functions. CallGraphNode *Root; + /// A setting to determine whether this should include calls that are done in + /// a constant expression's context. This DOES require the ASTContext object + /// for constexpr-if, so setting it requires a valid ASTContext. + bool ShouldSkipConstexpr = false; + ASTContext *Ctx; + public: CallGraph(); ~CallGraph(); @@ -66,7 +73,7 @@ class CallGraph : public RecursiveASTVisitor { /// Determine if a declaration should be included in the graph. static bool includeInGraph(const Decl *D); - /// Determine if a declaration should be included in the graph for the + /// Determine if a declaration should be included in the graph for the /// purposes of being a callee. This is similar to includeInGraph except /// it permits declarations, not just definitions. static bool includeCalleeInGraph(const Decl *D); @@ -138,6 +145,15 @@ class CallGraph : public RecursiveASTVisitor { bool shouldWalkTypesOfTypeLocs() const { return false; } bool shouldVisitTemplateInstantiations() const { return true; } bool shouldVisitImplicitCode() const { return true; } + bool shouldSkipConstantExpressions() const { return ShouldSkipConstexpr; } + void setSkipConstantExpressions(ASTContext &Context) { + Ctx = &Context; + ShouldSkipConstexpr = true; + } + ASTContext &getASTContext() { + assert(Ctx); + return *Ctx; + } private: /// Add the given declaration to the call graph. diff --git a/clang/lib/Analysis/CallGraph.cpp b/clang/lib/Analysis/CallGraph.cpp index 2299ba32db501..78197ea5e65c2 100644 --- a/clang/lib/Analysis/CallGraph.cpp +++ b/clang/lib/Analysis/CallGraph.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "clang/Analysis/CallGraph.h" +#include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclBase.h" #include "clang/AST/DeclObjC.h" @@ -136,6 +137,37 @@ class CGBuilder : public StmtVisitor { } } + void VisitIfStmt(IfStmt *If) { + if (G->shouldSkipConstantExpressions()) { + if (llvm::Optional ActiveStmt = + If->getNondiscardedCase(G->getASTContext())) { + if (*ActiveStmt) + this->Visit(*ActiveStmt); + return; + } + } + + StmtVisitor::VisitIfStmt(If); + } + + void VisitDeclStmt(DeclStmt *DS) { + if (G->shouldSkipConstantExpressions()) { + auto IsConstexprVarDecl = [](Decl *D) { + if (const auto *VD = dyn_cast(D)) + return VD->isConstexpr(); + return false; + }; + if (llvm::any_of(DS->decls(), IsConstexprVarDecl)) { + assert(llvm::all_of(DS->decls(), IsConstexprVarDecl) && + "Situation where a decl-group would be a mix of decl types, or " + "constexpr and not?"); + return; + } + } + + StmtVisitor::VisitDeclStmt(DS); + } + void VisitChildren(Stmt *S) { for (Stmt *SubStmt : S->children()) if (SubStmt) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index f2e484eaa131d..a51ae72667e4c 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -579,27 +579,9 @@ static void collectSYCLAttributes(Sema &S, FunctionDecl *FD, } class DiagDeviceFunction : public RecursiveASTVisitor { - // Used to keep track of the constexpr depth, so we know whether to skip - // diagnostics. - unsigned ConstexprDepth = 0; Sema &SemaRef; const llvm::SmallPtrSetImpl &RecursiveFuncs; - struct ConstexprDepthRAII { - DiagDeviceFunction &DDF; - bool Increment; - - ConstexprDepthRAII(DiagDeviceFunction &DDF, bool Increment = true) - : DDF(DDF), Increment(Increment) { - if (Increment) - ++DDF.ConstexprDepth; - } - ~ConstexprDepthRAII() { - if (Increment) - --DDF.ConstexprDepth; - } - }; - public: DiagDeviceFunction( Sema &S, @@ -617,7 +599,7 @@ class DiagDeviceFunction : public RecursiveASTVisitor { // instantiation as template functions. It means that // all functions used by kernel have already been parsed and have // definitions. - if (RecursiveFuncs.count(Callee) && !ConstexprDepth) { + if (RecursiveFuncs.count(Callee)) { SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) << Sema::KernelCallRecursiveFunction; SemaRef.Diag(Callee->getSourceRange().getBegin(), @@ -670,45 +652,41 @@ class DiagDeviceFunction : public RecursiveASTVisitor { // Skip checking rules on variables initialized during constant evaluation. bool TraverseVarDecl(VarDecl *VD) { - ConstexprDepthRAII R(*this, VD->isConstexpr()); + if (VD->isConstexpr()) + return true; return RecursiveASTVisitor::TraverseVarDecl(VD); } // Skip checking rules on template arguments, since these are constant // expressions. bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) { - ConstexprDepthRAII R(*this); - return RecursiveASTVisitor::TraverseTemplateArgumentLoc(ArgLoc); + return true; } // Skip checking the static assert, both components are required to be // constant expressions. - bool TraverseStaticAssertDecl(StaticAssertDecl *D) { - ConstexprDepthRAII R(*this); - return RecursiveASTVisitor::TraverseStaticAssertDecl(D); - } + bool TraverseStaticAssertDecl(StaticAssertDecl *D) { return true; } // Make sure we skip the condition of the case, since that is a constant // expression. bool TraverseCaseStmt(CaseStmt *S) { - { - ConstexprDepthRAII R(*this); - if (!TraverseStmt(S->getLHS())) - return false; - if (!TraverseStmt(S->getRHS())) - return false; - } return TraverseStmt(S->getSubStmt()); } // Skip checking the size expr, since a constant array type loc's size expr is // a constant expression. bool TraverseConstantArrayTypeLoc(const ConstantArrayTypeLoc &ArrLoc) { - if (!TraverseTypeLoc(ArrLoc.getElementLoc())) - return false; + return true; + } - ConstexprDepthRAII R(*this); - return TraverseStmt(ArrLoc.getSizeExpr()); + bool TraverseIfStmt(IfStmt *S) { + if (llvm::Optional ActiveStmt = + S->getNondiscardedCase(SemaRef.Context)) { + if (*ActiveStmt) + return TraverseStmt(*ActiveStmt); + return true; + } + return RecursiveASTVisitor::TraverseIfStmt(S); } }; @@ -749,6 +727,7 @@ class DeviceFunctionTracker { public: DeviceFunctionTracker(Sema &S) : SemaRef(S) { + CG.setSkipConstantExpressions(S.Context); CG.addToCallGraph(S.getASTContext().getTranslationUnitDecl()); CollectSyclExternalFuncs(); } diff --git a/clang/test/SemaSYCL/allow-constexpr-recursion.cpp b/clang/test/SemaSYCL/allow-constexpr-recursion.cpp index 4dde146900849..0e888eb15d888 100644 --- a/clang/test/SemaSYCL/allow-constexpr-recursion.cpp +++ b/clang/test/SemaSYCL/allow-constexpr-recursion.cpp @@ -8,7 +8,7 @@ sycl::queue q; constexpr int constexpr_recurse1(int n); -// expected-note@+1 3{{function implemented using recursion declared here}} +// expected-note@+1 5{{function implemented using recursion declared here}} constexpr int constexpr_recurse(int n) { if (n) return constexpr_recurse1(n - 1); @@ -20,6 +20,10 @@ constexpr int constexpr_recurse1(int n) { return constexpr_recurse(n) + 1; } +constexpr int test_constexpr_context(int n) { + return n; +} + template void bar() {} @@ -55,15 +59,13 @@ void ConstexprIf2() { // they should not diagnose. void constexpr_recurse_test() { constexpr int i = constexpr_recurse(1); + constexpr int j = test_constexpr_context(constexpr_recurse(1)); bar(); bar2<1, 2, constexpr_recurse(2)>(); static_assert(constexpr_recurse(2) == 105, ""); - int j; switch (105) { case constexpr_recurse(2): - // expected-error@+1{{SYCL kernel cannot call a recursive function}} - j = constexpr_recurse(5); break; } @@ -78,14 +80,40 @@ void constexpr_recurse_test() { ConditionallyExplicitCtor c(1); - ConstexprIf1<0>(); // Should not cause a diagnostic. - // expected-error@+1{{SYCL kernel cannot call a recursive function}} - ConstexprIf2<1>(); + ConstexprIf1<0>(); + + int k; + if constexpr (false) + k = constexpr_recurse(1); + else + constexpr int l = test_constexpr_context(constexpr_recurse(1)); } void constexpr_recurse_test_err() { // expected-error@+1{{SYCL kernel cannot call a recursive function}} int i = constexpr_recurse(1); + + // expected-error@+1{{SYCL kernel cannot call a recursive function}} + ConstexprIf2<1>(); + + int j, k; + if constexpr (true) + // expected-error@+1{{SYCL kernel cannot call a recursive function}} + j = constexpr_recurse(1); + + if constexpr (false) + j = constexpr_recurse(1); // Should not diagnose in discarded branch + else + // expected-error@+1{{SYCL kernel cannot call a recursive function}} + k = constexpr_recurse(1); + + switch (105) { + case constexpr_recurse(2): + constexpr int l = test_constexpr_context(constexpr_recurse(1)); + // expected-error@+1{{SYCL kernel cannot call a recursive function}} + j = constexpr_recurse(5); + break; + } } int main() {