@@ -331,6 +331,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
331
331
332
332
private:
333
333
bool CheckSYCLType (QualType Ty, SourceRange Loc) {
334
+ llvm::DenseSet<QualType> visited;
335
+ return CheckSYCLType (Ty, Loc, visited);
336
+ }
337
+
338
+ bool CheckSYCLType (QualType Ty, SourceRange Loc, llvm::DenseSet<QualType> &Visited) {
334
339
if (Ty->isVariableArrayType ()) {
335
340
SemaRef.Diag (Loc.getBegin (), diag::err_vla_unsupported);
336
341
return false ;
@@ -339,6 +344,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
339
344
while (Ty->isAnyPointerType () || Ty->isArrayType ())
340
345
Ty = QualType{Ty->getPointeeOrArrayElementType (), 0 };
341
346
347
+ // Pointers complicate recursion. Add this type to Visited.
348
+ // If already there, bail out.
349
+ if (!Visited.insert (Ty).second )
350
+ return true ;
351
+
342
352
if (const auto *CRD = Ty->getAsCXXRecordDecl ()) {
343
353
if (CRD->isPolymorphic ()) {
344
354
SemaRef.Diag (CRD->getLocation (), diag::err_sycl_virtual_types);
@@ -347,25 +357,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
347
357
}
348
358
349
359
for (const auto &Field : CRD->fields ()) {
350
- if (!CheckSYCLType (Field->getType (), Field->getSourceRange ())) {
360
+ if (!CheckSYCLType (Field->getType (), Field->getSourceRange (), Visited )) {
351
361
SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
352
362
return false ;
353
363
}
354
364
}
355
365
} else if (const auto *RD = Ty->getAsRecordDecl ()) {
356
366
for (const auto &Field : RD->fields ()) {
357
- if (!CheckSYCLType (Field->getType (), Field->getSourceRange ())) {
367
+ if (!CheckSYCLType (Field->getType (), Field->getSourceRange (), Visited )) {
358
368
SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
359
369
return false ;
360
370
}
361
371
}
362
372
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
363
373
for (const auto &ParamTy : FPTy->param_types ())
364
- if (!CheckSYCLType (ParamTy, Loc))
374
+ if (!CheckSYCLType (ParamTy, Loc, Visited ))
365
375
return false ;
366
- return CheckSYCLType (FPTy->getReturnType (), Loc);
376
+ return CheckSYCLType (FPTy->getReturnType (), Loc, Visited );
367
377
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
368
- return CheckSYCLType (FTy->getReturnType (), Loc);
378
+ return CheckSYCLType (FTy->getReturnType (), Loc, Visited );
369
379
}
370
380
return true ;
371
381
}
@@ -766,6 +776,16 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
766
776
767
777
// Create descriptors for each accessor field in the class or struct
768
778
createParamDescForWrappedAccessors (Fld, ArgTy);
779
+ } else if (ArgTy->isPointerType ()) {
780
+ // Pointer Arguments need to be in the global address space
781
+ QualType PointeeTy = ArgTy->getPointeeType ();
782
+ Qualifiers Quals = PointeeTy.getQualifiers ();
783
+ Quals.setAddressSpace (LangAS::opencl_global);
784
+ PointeeTy = Context.getQualifiedType (PointeeTy.getUnqualifiedType (),
785
+ Quals);
786
+ QualType ModTy = Context.getPointerType (PointeeTy);
787
+
788
+ CreateAndAddPrmDsc (Fld, ModTy);
769
789
} else if (ArgTy->isScalarType ()) {
770
790
CreateAndAddPrmDsc (Fld, ArgTy);
771
791
} else {
@@ -853,6 +873,10 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
853
873
uint64_t Sz = Ctx.getTypeSizeInChars (SamplerArg->getType ()).getQuantity ();
854
874
H.addParamDesc (SYCLIntegrationHeader::kind_sampler,
855
875
static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
876
+ } else if (ArgTy->isPointerType ()) {
877
+ uint64_t Sz = Ctx.getTypeSizeInChars (Fld->getType ()).getQuantity ();
878
+ H.addParamDesc (SYCLIntegrationHeader::kind_pointer,
879
+ static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
856
880
} else if (ArgTy->isStructureOrClassType () || ArgTy->isScalarType ()) {
857
881
// the parameter is an object of standard layout type or scalar;
858
882
// the check for standard layout is done elsewhere
@@ -1017,6 +1041,7 @@ static const char *paramKind2Str(KernelParamKind K) {
1017
1041
CASE (accessor);
1018
1042
CASE (std_layout);
1019
1043
CASE (sampler);
1044
+ CASE (pointer);
1020
1045
default :
1021
1046
return " <ERROR>" ;
1022
1047
}
0 commit comments