Skip to content

[SYCL] Fix bugs with recursion in SYCL kernel #3958

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 6 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion clang/include/clang/Analysis/CallGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

namespace clang {

class ASTContext;
class CallGraphNode;
class Decl;
class DeclContext;
Expand All @@ -51,6 +52,12 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
/// This is a virtual root node that has edges to all the functions.
CallGraphNode *Root;

/// A setting to determine whether this should include calls that are done in
/// a constant expression's context. This DOES require the ASTContext object
/// for constexpr-if, so setting it requires a valid ASTContext.
bool ShouldSkipConstexpr = false;
ASTContext *Ctx;

public:
CallGraph();
~CallGraph();
Expand All @@ -66,7 +73,7 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
/// Determine if a declaration should be included in the graph.
static bool includeInGraph(const Decl *D);

/// Determine if a declaration should be included in the graph for the
/// Determine if a declaration should be included in the graph for the
/// purposes of being a callee. This is similar to includeInGraph except
/// it permits declarations, not just definitions.
static bool includeCalleeInGraph(const Decl *D);
Expand Down Expand Up @@ -138,6 +145,15 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
bool shouldWalkTypesOfTypeLocs() const { return false; }
bool shouldVisitTemplateInstantiations() const { return true; }
bool shouldVisitImplicitCode() const { return true; }
bool shouldSkipConstantExpressions() const { return ShouldSkipConstexpr; }
void setSkipConstantExpressions(ASTContext &Context) {
Ctx = &Context;
ShouldSkipConstexpr = true;
}
ASTContext &getASTContext() {
assert(Ctx);
return *Ctx;
}

private:
/// Add the given declaration to the call graph.
Expand Down
32 changes: 32 additions & 0 deletions clang/lib/Analysis/CallGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "clang/Analysis/CallGraph.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/DeclObjC.h"
Expand Down Expand Up @@ -136,6 +137,37 @@ class CGBuilder : public StmtVisitor<CGBuilder> {
}
}

void VisitIfStmt(IfStmt *If) {
if (G->shouldSkipConstantExpressions()) {
if (llvm::Optional<Stmt *> ActiveStmt =
If->getNondiscardedCase(G->getASTContext())) {
if (*ActiveStmt)
this->Visit(*ActiveStmt);
return;
}
}

StmtVisitor::VisitIfStmt(If);
}

void VisitDeclStmt(DeclStmt *DS) {
if (G->shouldSkipConstantExpressions()) {
auto IsConstexprVarDecl = [](Decl *D) {
if (const auto *VD = dyn_cast<VarDecl>(D))
return VD->isConstexpr();
return false;
};
if (llvm::any_of(DS->decls(), IsConstexprVarDecl)) {
assert(llvm::all_of(DS->decls(), IsConstexprVarDecl) &&
"Situation where a decl-group would be a mix of decl types, or "
"constexpr and not?");
return;
}
}

StmtVisitor::VisitDeclStmt(DS);
}

void VisitChildren(Stmt *S) {
for (Stmt *SubStmt : S->children())
if (SubStmt)
Expand Down
53 changes: 16 additions & 37 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,27 +579,9 @@ static void collectSYCLAttributes(Sema &S, FunctionDecl *FD,
}

class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
// Used to keep track of the constexpr depth, so we know whether to skip
// diagnostics.
unsigned ConstexprDepth = 0;
Sema &SemaRef;
const llvm::SmallPtrSetImpl<const FunctionDecl *> &RecursiveFuncs;

struct ConstexprDepthRAII {
DiagDeviceFunction &DDF;
bool Increment;

ConstexprDepthRAII(DiagDeviceFunction &DDF, bool Increment = true)
: DDF(DDF), Increment(Increment) {
if (Increment)
++DDF.ConstexprDepth;
}
~ConstexprDepthRAII() {
if (Increment)
--DDF.ConstexprDepth;
}
};

public:
DiagDeviceFunction(
Sema &S,
Expand All @@ -617,7 +599,7 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
// instantiation as template functions. It means that
// all functions used by kernel have already been parsed and have
// definitions.
if (RecursiveFuncs.count(Callee) && !ConstexprDepth) {
if (RecursiveFuncs.count(Callee)) {
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
<< Sema::KernelCallRecursiveFunction;
SemaRef.Diag(Callee->getSourceRange().getBegin(),
Expand Down Expand Up @@ -670,45 +652,41 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {

// Skip checking rules on variables initialized during constant evaluation.
bool TraverseVarDecl(VarDecl *VD) {
ConstexprDepthRAII R(*this, VD->isConstexpr());
if (VD->isConstexpr())
return true;
return RecursiveASTVisitor::TraverseVarDecl(VD);
}

// Skip checking rules on template arguments, since these are constant
// expressions.
bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) {
ConstexprDepthRAII R(*this);
return RecursiveASTVisitor::TraverseTemplateArgumentLoc(ArgLoc);
return true;
}

// Skip checking the static assert, both components are required to be
// constant expressions.
bool TraverseStaticAssertDecl(StaticAssertDecl *D) {
ConstexprDepthRAII R(*this);
return RecursiveASTVisitor::TraverseStaticAssertDecl(D);
}
bool TraverseStaticAssertDecl(StaticAssertDecl *D) { return true; }

// Make sure we skip the condition of the case, since that is a constant
// expression.
bool TraverseCaseStmt(CaseStmt *S) {
{
ConstexprDepthRAII R(*this);
if (!TraverseStmt(S->getLHS()))
return false;
if (!TraverseStmt(S->getRHS()))
return false;
}
return TraverseStmt(S->getSubStmt());
}

// Skip checking the size expr, since a constant array type loc's size expr is
// a constant expression.
bool TraverseConstantArrayTypeLoc(const ConstantArrayTypeLoc &ArrLoc) {
if (!TraverseTypeLoc(ArrLoc.getElementLoc()))
return false;
return true;
}

ConstexprDepthRAII R(*this);
return TraverseStmt(ArrLoc.getSizeExpr());
bool TraverseIfStmt(IfStmt *S) {
if (llvm::Optional<Stmt *> ActiveStmt =
S->getNondiscardedCase(SemaRef.Context)) {
if (*ActiveStmt)
return TraverseStmt(*ActiveStmt);
return true;
}
return RecursiveASTVisitor::TraverseIfStmt(S);
}
};

Expand Down Expand Up @@ -749,6 +727,7 @@ class DeviceFunctionTracker {

public:
DeviceFunctionTracker(Sema &S) : SemaRef(S) {
CG.setSkipConstantExpressions(S.Context);
CG.addToCallGraph(S.getASTContext().getTranslationUnitDecl());
CollectSyclExternalFuncs();
}
Expand Down
42 changes: 35 additions & 7 deletions clang/test/SemaSYCL/allow-constexpr-recursion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ sycl::queue q;

constexpr int constexpr_recurse1(int n);

// expected-note@+1 3{{function implemented using recursion declared here}}
// expected-note@+1 5{{function implemented using recursion declared here}}
constexpr int constexpr_recurse(int n) {
if (n)
return constexpr_recurse1(n - 1);
Expand All @@ -20,6 +20,10 @@ constexpr int constexpr_recurse1(int n) {
return constexpr_recurse(n) + 1;
}

constexpr int test_constexpr_context(int n) {
return n;
}

template <int I>
void bar() {}

Expand Down Expand Up @@ -55,15 +59,13 @@ void ConstexprIf2() {
// they should not diagnose.
void constexpr_recurse_test() {
constexpr int i = constexpr_recurse(1);
constexpr int j = test_constexpr_context(constexpr_recurse(1));
bar<constexpr_recurse(2)>();
bar2<1, 2, constexpr_recurse(2)>();
static_assert(constexpr_recurse(2) == 105, "");

int j;
switch (105) {
case constexpr_recurse(2):
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
j = constexpr_recurse(5);
break;
}

Expand All @@ -78,14 +80,40 @@ void constexpr_recurse_test() {

ConditionallyExplicitCtor c(1);

ConstexprIf1<0>(); // Should not cause a diagnostic.
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
ConstexprIf2<1>();
ConstexprIf1<0>();

int k;
if constexpr (false)
k = constexpr_recurse(1);
else
constexpr int l = test_constexpr_context(constexpr_recurse(1));
}

void constexpr_recurse_test_err() {
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
int i = constexpr_recurse(1);

// expected-error@+1{{SYCL kernel cannot call a recursive function}}
ConstexprIf2<1>();

int j, k;
if constexpr (true)
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
j = constexpr_recurse(1);

if constexpr (false)
j = constexpr_recurse(1); // Should not diagnose in discarded branch
else
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
k = constexpr_recurse(1);

switch (105) {
case constexpr_recurse(2):
constexpr int l = test_constexpr_context(constexpr_recurse(1));
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
j = constexpr_recurse(5);
break;
}
}

int main() {
Expand Down