@@ -117,9 +117,8 @@ static ur_program_handle_t createSpirvProgram(const ContextImplPtr Context,
117
117
}
118
118
119
119
// TODO replace this with a new UR API function
120
- static bool
121
- isDeviceBinaryTypeSupported (const context &C,
122
- ur::DeviceBinaryType Format) {
120
+ static bool isDeviceBinaryTypeSupported (const context &C,
121
+ ur::DeviceBinaryType Format) {
123
122
// All formats except SYCL_DEVICE_BINARY_TYPE_SPIRV are supported.
124
123
if (Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
125
124
return true ;
@@ -532,21 +531,19 @@ static const char *getUrDeviceTarget(const char *URDeviceTarget) {
532
531
return UR_DEVICE_BINARY_TARGET_SPIRV32;
533
532
else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0 )
534
533
return UR_DEVICE_BINARY_TARGET_SPIRV64;
535
- else if (strcmp (URDeviceTarget,
536
- __SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) == 0 )
537
- return UR_DEVICE_BINARY_TARGET_SPIRV64_X86_64;
538
- else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) ==
534
+ else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) ==
539
535
0 )
536
+ return UR_DEVICE_BINARY_TARGET_SPIRV64_X86_64;
537
+ else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0 )
540
538
return UR_DEVICE_BINARY_TARGET_SPIRV64_GEN;
541
- else if (strcmp (URDeviceTarget,
542
- __SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) == 0 )
539
+ else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) ==
540
+ 0 )
543
541
return UR_DEVICE_BINARY_TARGET_SPIRV64_FPGA;
544
542
else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_NVPTX64) == 0 )
545
543
return UR_DEVICE_BINARY_TARGET_NVPTX64;
546
544
else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_AMDGCN) == 0 )
547
545
return UR_DEVICE_BINARY_TARGET_AMDGCN;
548
- else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_NATIVE_CPU) ==
549
- 0 )
546
+ else if (strcmp (URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_NATIVE_CPU) == 0 )
550
547
return " native_cpu" ; // todo: define UR_DEVICE_BINARY_TARGET_NATIVE_CPU;
551
548
552
549
return UR_DEVICE_BINARY_TARGET_UNKNOWN;
@@ -581,6 +578,18 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
581
578
return (0 == SuitableImageID);
582
579
}
583
580
581
+ static bool checkLinkingSupport (device Dev, const RTDeviceBinaryImage &Img) {
582
+ const char *Target = Img.getRawData ().DeviceTargetSpec ;
583
+ // TODO replace with extension checks once implemented in UR.
584
+ if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0 ) {
585
+ return true ;
586
+ }
587
+ if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0 ) {
588
+ return Dev.is_gpu () && Dev.get_backend () == backend::opencl;
589
+ }
590
+ return false ;
591
+ }
592
+
584
593
std::set<RTDeviceBinaryImage *>
585
594
ProgramManager::collectDeviceImageDepsForImportedSymbols (
586
595
const RTDeviceBinaryImage &MainImg, device Dev) {
@@ -593,9 +602,10 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
593
602
HandledSymbols.insert (ISProp->Name );
594
603
}
595
604
ur::DeviceBinaryType Format = MainImg.getFormat ();
596
- if (!WorkList.empty () && Format != SYCL_DEVICE_BINARY_TYPE_SPIRV )
605
+ if (!WorkList.empty () && ! checkLinkingSupport (Dev, MainImg) )
597
606
throw exception (make_error_code (errc::feature_not_supported),
598
- " Dynamic linking is not supported for AOT compilation yet" );
607
+ " Cannot resolve external symbols, linking is unsupported "
608
+ " for the backend" );
599
609
while (!WorkList.empty ()) {
600
610
std::string Symbol = WorkList.front ();
601
611
WorkList.pop ();
@@ -831,7 +841,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
831
841
ProgramPtr BuiltProgram =
832
842
build (std::move (ProgramManaged), ContextImpl, CompileOpts, LinkOpts,
833
843
getSyclObjImpl (Device).get ()->getHandleRef (), DeviceLibReqMask,
834
- ProgramsToLink);
844
+ ProgramsToLink, /* CreatedFromBinary*/ Img.getFormat () !=
845
+ SYCL_DEVICE_BINARY_TYPE_SPIRV);
835
846
// Those extra programs won't be used anymore, just the final linked result
836
847
for (ur_program_handle_t Prg : ProgramsToLink)
837
848
Plugin->call (urProgramRelease, Prg);
@@ -1453,7 +1464,8 @@ ProgramManager::ProgramPtr ProgramManager::build(
1453
1464
ProgramPtr Program, const ContextImplPtr Context,
1454
1465
const std::string &CompileOptions, const std::string &LinkOptions,
1455
1466
ur_device_handle_t Device, uint32_t DeviceLibReqMask,
1456
- const std::vector<ur_program_handle_t > &ExtraProgramsToLink) {
1467
+ const std::vector<ur_program_handle_t > &ExtraProgramsToLink,
1468
+ bool CreatedFromBinary) {
1457
1469
1458
1470
if constexpr (DbgProgMgr > 0 ) {
1459
1471
std::cerr << " >>> ProgramManager::build(" << Program.get () << " , "
@@ -1501,16 +1513,19 @@ ProgramManager::ProgramPtr ProgramManager::build(
1501
1513
}
1502
1514
1503
1515
// Include the main program and compile/link everything together
1504
- auto Res = doCompile (Plugin, Program.get (), /* num devices =*/ 1 , &Device,
1505
- Context->getHandleRef (), CompileOptions.c_str ());
1506
- Plugin->checkUrResult <errc::build>(Res);
1516
+ if (!CreatedFromBinary) {
1517
+ auto Res = doCompile (Plugin, Program.get (), /* num devices =*/ 1 , &Device,
1518
+ Context->getHandleRef (), CompileOptions.c_str ());
1519
+ Plugin->checkUrResult <errc::build>(Res);
1520
+ }
1507
1521
LinkPrograms.push_back (Program.get ());
1508
1522
1509
1523
for (ur_program_handle_t Prg : ExtraProgramsToLink) {
1510
- auto Res = doCompile (Plugin, Prg, /* num devices =*/ 1 , &Device,
1511
- Context->getHandleRef (), CompileOptions.c_str ());
1512
- Plugin->checkUrResult (Res);
1513
-
1524
+ if (!CreatedFromBinary) {
1525
+ auto Res = doCompile (Plugin, Prg, /* num devices =*/ 1 , &Device,
1526
+ Context->getHandleRef (), CompileOptions.c_str ());
1527
+ Plugin->checkUrResult <errc::build>(Res);
1528
+ }
1514
1529
LinkPrograms.push_back (Prg);
1515
1530
}
1516
1531
@@ -2700,8 +2715,8 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
2700
2715
/* For non SPIR-V devices DeviceLibReqdMask is always 0*/ 0 ,
2701
2716
ExtraProgramsToLink);
2702
2717
ur_kernel_handle_t UrKernel{nullptr };
2703
- Plugin->call <errc::kernel_not_supported>(urKernelCreate,
2704
- BuildProgram. get (), KernelName.c_str (), &UrKernel);
2718
+ Plugin->call <errc::kernel_not_supported>(urKernelCreate, BuildProgram. get (),
2719
+ KernelName.c_str (), &UrKernel);
2705
2720
{
2706
2721
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
2707
2722
m_MaterializedKernels[KernelName][SpecializationConsts] = UrKernel;
0 commit comments