Skip to content

Commit 6b7fbd9

Browse files
jbrodmanbader
authored andcommitted
[SYCL][USM] Initial clang support for USM (#256)
- Add support pointer types in Integration Header - Put kernel pointer parameters in the global address space - Fix CheckSYCLType to support proper recursion through pointers Signed-off-by: James Brodman <james.brodman@intel.com>
1 parent a5ad7a1 commit 6b7fbd9

File tree

5 files changed

+71
-8
lines changed

5 files changed

+71
-8
lines changed

clang/include/clang/Sema/Sema.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ class SYCLIntegrationHeader {
301301
kind_accessor = kind_first,
302302
kind_std_layout,
303303
kind_sampler,
304-
kind_last = kind_sampler
304+
kind_pointer,
305+
kind_last = kind_pointer
305306
};
306307

307308
public:

clang/lib/Sema/SemaSYCL.cpp

+30-5
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
331331

332332
private:
333333
bool CheckSYCLType(QualType Ty, SourceRange Loc) {
334+
llvm::DenseSet<QualType> visited;
335+
return CheckSYCLType(Ty, Loc, visited);
336+
}
337+
338+
bool CheckSYCLType(QualType Ty, SourceRange Loc, llvm::DenseSet<QualType> &Visited) {
334339
if (Ty->isVariableArrayType()) {
335340
SemaRef.Diag(Loc.getBegin(), diag::err_vla_unsupported);
336341
return false;
@@ -339,6 +344,11 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
339344
while (Ty->isAnyPointerType() || Ty->isArrayType())
340345
Ty = QualType{Ty->getPointeeOrArrayElementType(), 0};
341346

347+
// Pointers complicate recursion. Add this type to Visited.
348+
// If already there, bail out.
349+
if (!Visited.insert(Ty).second)
350+
return true;
351+
342352
if (const auto *CRD = Ty->getAsCXXRecordDecl()) {
343353
if (CRD->isPolymorphic()) {
344354
SemaRef.Diag(CRD->getLocation(), diag::err_sycl_virtual_types);
@@ -347,25 +357,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
347357
}
348358

349359
for (const auto &Field : CRD->fields()) {
350-
if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) {
360+
if (!CheckSYCLType(Field->getType(), Field->getSourceRange(), Visited)) {
351361
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
352362
return false;
353363
}
354364
}
355365
} else if (const auto *RD = Ty->getAsRecordDecl()) {
356366
for (const auto &Field : RD->fields()) {
357-
if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) {
367+
if (!CheckSYCLType(Field->getType(), Field->getSourceRange(), Visited)) {
358368
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
359369
return false;
360370
}
361371
}
362372
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
363373
for (const auto &ParamTy : FPTy->param_types())
364-
if (!CheckSYCLType(ParamTy, Loc))
374+
if (!CheckSYCLType(ParamTy, Loc, Visited))
365375
return false;
366-
return CheckSYCLType(FPTy->getReturnType(), Loc);
376+
return CheckSYCLType(FPTy->getReturnType(), Loc, Visited);
367377
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
368-
return CheckSYCLType(FTy->getReturnType(), Loc);
378+
return CheckSYCLType(FTy->getReturnType(), Loc, Visited);
369379
}
370380
return true;
371381
}
@@ -766,6 +776,16 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
766776

767777
// Create descriptors for each accessor field in the class or struct
768778
createParamDescForWrappedAccessors(Fld, ArgTy);
779+
} else if (ArgTy->isPointerType()) {
780+
// Pointer Arguments need to be in the global address space
781+
QualType PointeeTy = ArgTy->getPointeeType();
782+
Qualifiers Quals = PointeeTy.getQualifiers();
783+
Quals.setAddressSpace(LangAS::opencl_global);
784+
PointeeTy = Context.getQualifiedType(PointeeTy.getUnqualifiedType(),
785+
Quals);
786+
QualType ModTy = Context.getPointerType(PointeeTy);
787+
788+
CreateAndAddPrmDsc(Fld, ModTy);
769789
} else if (ArgTy->isScalarType()) {
770790
CreateAndAddPrmDsc(Fld, ArgTy);
771791
} else {
@@ -853,6 +873,10 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
853873
uint64_t Sz = Ctx.getTypeSizeInChars(SamplerArg->getType()).getQuantity();
854874
H.addParamDesc(SYCLIntegrationHeader::kind_sampler,
855875
static_cast<unsigned>(Sz), static_cast<unsigned>(Offset));
876+
} else if (ArgTy->isPointerType()) {
877+
uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity();
878+
H.addParamDesc(SYCLIntegrationHeader::kind_pointer,
879+
static_cast<unsigned>(Sz), static_cast<unsigned>(Offset));
856880
} else if (ArgTy->isStructureOrClassType() || ArgTy->isScalarType()) {
857881
// the parameter is an object of standard layout type or scalar;
858882
// the check for standard layout is done elsewhere
@@ -1017,6 +1041,7 @@ static const char *paramKind2Str(KernelParamKind K) {
10171041
CASE(accessor);
10181042
CASE(std_layout);
10191043
CASE(sampler);
1044+
CASE(pointer);
10201045
default:
10211046
return "<ERROR>";
10221047
}
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clang_cc1 -std=c++11 -I %S/Inputs -fsycl-is-device -ast-dump %s | FileCheck %s
2+
// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s -c -o kernel.spv
3+
// RUN: FileCheck -input-file=%t.h %s --check-prefix=INT-HEADER
4+
5+
// INT-HEADER:{ kernel_param_kind_t::kind_pointer, 8, 0 },
6+
// INT-HEADER:{ kernel_param_kind_t::kind_pointer, 8, 8 },
7+
8+
//==--usm-int-header.cpp - USM kernel param aspace and int header test -----==//
9+
//
10+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
11+
// See https://llvm.org/LICENSE.txt for license information.
12+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
17+
#include <sycl.hpp>
18+
19+
template <typename name, typename Func>
20+
__attribute__((sycl_kernel)) void kernel(Func kernelFunc) {
21+
kernelFunc();
22+
}
23+
24+
int main() {
25+
int* x;
26+
float* y;
27+
28+
kernel<class usm_test>([=]() {
29+
*x = 42;
30+
*y = 3.14;
31+
});
32+
}
33+
34+
// CHECK: FunctionDecl {{.*}}usm_test 'void (__global int *, __global float *)'

sycl/include/CL/sycl/detail/kernel_desc.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class half;
3333
enum class kernel_param_kind_t {
3434
kind_accessor,
3535
kind_std_layout, // standard layout object parameters
36-
kind_sampler
36+
kind_sampler,
37+
kind_pointer
3738
};
3839

3940
// describes a kernel parameter

sycl/include/CL/sycl/handler.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,11 @@ class handler {
239239
const auto kind_std_layout = detail::kernel_param_kind_t::kind_std_layout;
240240
const auto kind_accessor = detail::kernel_param_kind_t::kind_accessor;
241241
const auto kind_sampler = detail::kernel_param_kind_t::kind_sampler;
242+
const auto kind_pointer = detail::kernel_param_kind_t::kind_pointer;
242243

243244
switch (Kind) {
244-
case kind_std_layout: {
245+
case kind_std_layout:
246+
case kind_pointer: {
245247
MArgs.emplace_back(Kind, Ptr, Size, Index + IndexShift);
246248
break;
247249
}

0 commit comments

Comments
 (0)