Skip to content

Commit bcc0fab

Browse files
[SYCL] Enable dynamic linking of AOT compiled images for OpenCL GPU (#14778)
Currently, only the OpenCL and Level Zero GPU backends support linking of AOT-compiled images. Enabling this for Level Zero requires further changes in the adapter, but OpenCL works as-is, so this change enables dynamic linking of AOT images for OpenCL GPU devices only.
1 parent 65ee744 commit bcc0fab

File tree

4 files changed: 99 additions & 52 deletions

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ static ur_program_handle_t createSpirvProgram(const ContextImplPtr Context,
117117
}
118118

119119
// TODO replace this with a new UR API function
120-
static bool
121-
isDeviceBinaryTypeSupported(const context &C,
122-
ur::DeviceBinaryType Format) {
120+
static bool isDeviceBinaryTypeSupported(const context &C,
121+
ur::DeviceBinaryType Format) {
123122
// All formats except SYCL_DEVICE_BINARY_TYPE_SPIRV are supported.
124123
if (Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
125124
return true;
@@ -532,21 +531,19 @@ static const char *getUrDeviceTarget(const char *URDeviceTarget) {
532531
return UR_DEVICE_BINARY_TARGET_SPIRV32;
533532
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0)
534533
return UR_DEVICE_BINARY_TARGET_SPIRV64;
535-
else if (strcmp(URDeviceTarget,
536-
__SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) == 0)
537-
return UR_DEVICE_BINARY_TARGET_SPIRV64_X86_64;
538-
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) ==
534+
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_X86_64) ==
539535
0)
536+
return UR_DEVICE_BINARY_TARGET_SPIRV64_X86_64;
537+
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0)
540538
return UR_DEVICE_BINARY_TARGET_SPIRV64_GEN;
541-
else if (strcmp(URDeviceTarget,
542-
__SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) == 0)
539+
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_FPGA) ==
540+
0)
543541
return UR_DEVICE_BINARY_TARGET_SPIRV64_FPGA;
544542
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_NVPTX64) == 0)
545543
return UR_DEVICE_BINARY_TARGET_NVPTX64;
546544
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_AMDGCN) == 0)
547545
return UR_DEVICE_BINARY_TARGET_AMDGCN;
548-
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_NATIVE_CPU) ==
549-
0)
546+
else if (strcmp(URDeviceTarget, __SYCL_DEVICE_BINARY_TARGET_NATIVE_CPU) == 0)
550547
return "native_cpu"; // todo: define UR_DEVICE_BINARY_TARGET_NATIVE_CPU;
551548

552549
return UR_DEVICE_BINARY_TARGET_UNKNOWN;
@@ -581,6 +578,18 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
581578
return (0 == SuitableImageID);
582579
}
583580

581+
static bool checkLinkingSupport(device Dev, const RTDeviceBinaryImage &Img) {
582+
const char *Target = Img.getRawData().DeviceTargetSpec;
583+
// TODO replace with extension checks once implemented in UR.
584+
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0) {
585+
return true;
586+
}
587+
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0) {
588+
return Dev.is_gpu() && Dev.get_backend() == backend::opencl;
589+
}
590+
return false;
591+
}
592+
584593
std::set<RTDeviceBinaryImage *>
585594
ProgramManager::collectDeviceImageDepsForImportedSymbols(
586595
const RTDeviceBinaryImage &MainImg, device Dev) {
@@ -593,9 +602,10 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
593602
HandledSymbols.insert(ISProp->Name);
594603
}
595604
ur::DeviceBinaryType Format = MainImg.getFormat();
596-
if (!WorkList.empty() && Format != SYCL_DEVICE_BINARY_TYPE_SPIRV)
605+
if (!WorkList.empty() && !checkLinkingSupport(Dev, MainImg))
597606
throw exception(make_error_code(errc::feature_not_supported),
598-
"Dynamic linking is not supported for AOT compilation yet");
607+
"Cannot resolve external symbols, linking is unsupported "
608+
"for the backend");
599609
while (!WorkList.empty()) {
600610
std::string Symbol = WorkList.front();
601611
WorkList.pop();
@@ -831,7 +841,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
831841
ProgramPtr BuiltProgram =
832842
build(std::move(ProgramManaged), ContextImpl, CompileOpts, LinkOpts,
833843
getSyclObjImpl(Device).get()->getHandleRef(), DeviceLibReqMask,
834-
ProgramsToLink);
844+
ProgramsToLink, /*CreatedFromBinary*/ Img.getFormat() !=
845+
SYCL_DEVICE_BINARY_TYPE_SPIRV);
835846
// Those extra programs won't be used anymore, just the final linked result
836847
for (ur_program_handle_t Prg : ProgramsToLink)
837848
Plugin->call(urProgramRelease, Prg);
@@ -1453,7 +1464,8 @@ ProgramManager::ProgramPtr ProgramManager::build(
14531464
ProgramPtr Program, const ContextImplPtr Context,
14541465
const std::string &CompileOptions, const std::string &LinkOptions,
14551466
ur_device_handle_t Device, uint32_t DeviceLibReqMask,
1456-
const std::vector<ur_program_handle_t> &ExtraProgramsToLink) {
1467+
const std::vector<ur_program_handle_t> &ExtraProgramsToLink,
1468+
bool CreatedFromBinary) {
14571469

14581470
if constexpr (DbgProgMgr > 0) {
14591471
std::cerr << ">>> ProgramManager::build(" << Program.get() << ", "
@@ -1501,16 +1513,19 @@ ProgramManager::ProgramPtr ProgramManager::build(
15011513
}
15021514

15031515
// Include the main program and compile/link everything together
1504-
auto Res = doCompile(Plugin, Program.get(), /*num devices =*/1, &Device,
1505-
Context->getHandleRef(), CompileOptions.c_str());
1506-
Plugin->checkUrResult<errc::build>(Res);
1516+
if (!CreatedFromBinary) {
1517+
auto Res = doCompile(Plugin, Program.get(), /*num devices =*/1, &Device,
1518+
Context->getHandleRef(), CompileOptions.c_str());
1519+
Plugin->checkUrResult<errc::build>(Res);
1520+
}
15071521
LinkPrograms.push_back(Program.get());
15081522

15091523
for (ur_program_handle_t Prg : ExtraProgramsToLink) {
1510-
auto Res = doCompile(Plugin, Prg, /*num devices =*/1, &Device,
1511-
Context->getHandleRef(), CompileOptions.c_str());
1512-
Plugin->checkUrResult(Res);
1513-
1524+
if (!CreatedFromBinary) {
1525+
auto Res = doCompile(Plugin, Prg, /*num devices =*/1, &Device,
1526+
Context->getHandleRef(), CompileOptions.c_str());
1527+
Plugin->checkUrResult<errc::build>(Res);
1528+
}
15141529
LinkPrograms.push_back(Prg);
15151530
}
15161531

@@ -2700,8 +2715,8 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
27002715
/*For non SPIR-V devices DeviceLibReqdMask is always 0*/ 0,
27012716
ExtraProgramsToLink);
27022717
ur_kernel_handle_t UrKernel{nullptr};
2703-
Plugin->call<errc::kernel_not_supported>(urKernelCreate,
2704-
BuildProgram.get(), KernelName.c_str(), &UrKernel);
2718+
Plugin->call<errc::kernel_not_supported>(urKernelCreate, BuildProgram.get(),
2719+
KernelName.c_str(), &UrKernel);
27052720
{
27062721
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
27072722
m_MaterializedKernels[KernelName][SpecializationConsts] = UrKernel;

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,11 @@ class ProgramManager {
123123
/// \return A pair consisting of the UR program created with the corresponding
124124
/// device code binary and a boolean that is true if the device code
125125
/// binary was found in the persistent cache and false otherwise.
126-
std::pair<ur_program_handle_t, bool>
127-
getOrCreateURProgram(
126+
std::pair<ur_program_handle_t, bool> getOrCreateURProgram(
128127
const RTDeviceBinaryImage &Img,
129128
const std::vector<const RTDeviceBinaryImage *> &AllImages,
130-
const context &Context,
131-
const device &Device,
132-
const std::string &CompileAndLinkOptions,
133-
SerializedObj SpecConsts);
129+
const context &Context, const device &Device,
130+
const std::string &CompileAndLinkOptions, SerializedObj SpecConsts);
134131
/// Builds or retrieves from cache a program defining the kernel with given
135132
/// name.
136133
/// \param M identifies the OS module the kernel comes from (multiple OS
@@ -306,7 +303,8 @@ class ProgramManager {
306303
const std::string &CompileOptions,
307304
const std::string &LinkOptions, ur_device_handle_t Device,
308305
uint32_t DeviceLibReqMask,
309-
const std::vector<ur_program_handle_t> &ProgramsToLink);
306+
const std::vector<ur_program_handle_t> &ProgramsToLink,
307+
bool CreatedFromBinary = false);
310308

311309
/// Dumps image to current directory
312310
void dumpImage(const RTDeviceBinaryImage &Img, uint32_t SequenceID = 0) const;

sycl/unittests/helpers/RuntimeLinkingCommon.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// Helper holder for all the data we want to capture from mocked APIs
66
struct LinkingCapturesHolder {
77
unsigned NumOfUrProgramCreateCalls = 0;
8+
unsigned NumOfUrProgramCreateWithBinaryCalls = 0;
89
unsigned NumOfUrProgramLinkCalls = 0;
910
unsigned ProgramUsedToCreateKernel = 0;
1011
std::vector<unsigned> LinkedPrograms;
@@ -37,6 +38,16 @@ static ur_result_t redefined_urProgramCreateWithIL(void *pParams) {
3738
return UR_RESULT_SUCCESS;
3839
}
3940

41+
static ur_result_t redefined_urProgramCreateWithBinary(void *pParams) {
42+
auto Params = *static_cast<ur_program_create_with_binary_params_t *>(pParams);
43+
auto *Magic = reinterpret_cast<const unsigned char *>(*Params.ppBinary);
44+
ur_program_handle_t *res = *Params.pphProgram;
45+
*res = mock::createDummyHandle<ur_program_handle_t>(sizeof(unsigned));
46+
reinterpret_cast<mock::dummy_handle_t>(*res)->setDataAs<unsigned>(*Magic);
47+
++CapturedLinkingData.NumOfUrProgramCreateWithBinaryCalls;
48+
return UR_RESULT_SUCCESS;
49+
}
50+
4051
static ur_result_t redefined_urProgramLinkExp(void *pParams) {
4152
auto Params = *static_cast<ur_program_link_exp_params_t *>(pParams);
4253
unsigned ResProgram = 1;
@@ -69,6 +80,8 @@ static ur_result_t redefined_urKernelCreate(void *pParams) {
6980
static void setupRuntimeLinkingMock() {
7081
mock::getCallbacks().set_replace_callback("urProgramCreateWithIL",
7182
redefined_urProgramCreateWithIL);
83+
mock::getCallbacks().set_replace_callback(
84+
"urProgramCreateWithBinary", redefined_urProgramCreateWithBinary);
7285
mock::getCallbacks().set_replace_callback("urProgramLinkExp",
7386
redefined_urProgramLinkExp);
7487
mock::getCallbacks().set_replace_callback("urKernelCreate",

sycl/unittests/program_manager/DynamicLinking.cpp

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,12 @@ createPropertySet(const std::vector<std::string> &Symbols) {
5555
return Props;
5656
}
5757

58-
sycl::unittest::UrImage
59-
generateImage(std::initializer_list<std::string> KernelNames,
60-
const std::vector<std::string> &ExportedSymbols,
61-
const std::vector<std::string> &ImportedSymbols,
62-
unsigned char Magic,
63-
sycl::detail::ur::DeviceBinaryType BinType =
64-
SYCL_DEVICE_BINARY_TYPE_SPIRV) {
58+
sycl::unittest::UrImage generateImage(
59+
std::initializer_list<std::string> KernelNames,
60+
const std::vector<std::string> &ExportedSymbols,
61+
const std::vector<std::string> &ImportedSymbols, unsigned char Magic,
62+
sycl::detail::ur::DeviceBinaryType BinType = SYCL_DEVICE_BINARY_TYPE_SPIRV,
63+
const char *DeviceTargetSpec = __SYCL_DEVICE_BINARY_TARGET_SPIRV64) {
6564
sycl::unittest::UrPropertySet PropSet;
6665
if (!ExportedSymbols.empty())
6766
PropSet.insert(__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS,
@@ -74,14 +73,13 @@ generateImage(std::initializer_list<std::string> KernelNames,
7473
sycl::unittest::UrArray<sycl::unittest::UrOffloadEntry> Entries =
7574
sycl::unittest::makeEmptyKernels(KernelNames);
7675

77-
sycl::unittest::UrImage Img{
78-
BinType,
79-
__SYCL_DEVICE_BINARY_TARGET_SPIRV64, // DeviceTargetSpec
80-
"", // Compile options
81-
"", // Link options
82-
std::move(Bin),
83-
std::move(Entries),
84-
std::move(PropSet)};
76+
sycl::unittest::UrImage Img{BinType,
77+
DeviceTargetSpec,
78+
"", // Compile options
79+
"", // Link options
80+
std::move(Bin),
81+
std::move(Entries),
82+
std::move(PropSet)};
8583

8684
return Img;
8785
}
@@ -103,7 +101,8 @@ static sycl::unittest::UrImage Imgs[] = {
103101
{"BasicCaseKernelDepDep"}, BASIC_CASE_PRG_DEP),
104102
generateImage({"BasicCaseKernelDep"}, {"BasicCaseKernelDep"},
105103
{"BasicCaseKernelDepDep"}, BASIC_CASE_PRG_DEP_NATIVE,
106-
SYCL_DEVICE_BINARY_TYPE_NATIVE),
104+
SYCL_DEVICE_BINARY_TYPE_NATIVE,
105+
__SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN),
107106
generateImage({"BasicCaseKernelDepDep"}, {"BasicCaseKernelDepDep"}, {},
108107
BASIC_CASE_PRG_DEP_DEP),
109108
generateImage({"UnresolvedDepKernel"}, {},
@@ -115,9 +114,11 @@ static sycl::unittest::UrImage Imgs[] = {
115114
{"MutualDepKernelADep"}, {"MutualDepKernelBDep"},
116115
MUTUAL_DEP_PRG_B),
117116
generateImage({"AOTCaseKernel"}, {}, {"AOTCaseKernelDep"},
118-
AOT_CASE_PRG_NATIVE, SYCL_DEVICE_BINARY_TYPE_NATIVE),
117+
AOT_CASE_PRG_NATIVE, SYCL_DEVICE_BINARY_TYPE_NATIVE,
118+
__SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN),
119119
generateImage({"AOTCaseKernelDep"}, {"AOTCaseKernelDep"}, {},
120-
AOT_CASE_PRG_DEP_NATIVE, SYCL_DEVICE_BINARY_TYPE_NATIVE)};
120+
AOT_CASE_PRG_DEP_NATIVE, SYCL_DEVICE_BINARY_TYPE_NATIVE,
121+
__SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN)};
121122

122123
// Registers mock devices images in the SYCL RT
123124
static sycl::unittest::UrImageArray<9> ImgArray{Imgs};
@@ -184,16 +185,36 @@ TEST(DynamicLinking, MutualDependency) {
184185
}
185186

186187
TEST(DynamicLinking, AheadOfTime) {
188+
sycl::unittest::UrMock<> Mock;
189+
setupRuntimeLinkingMock();
190+
191+
sycl::platform Plt = sycl::platform();
192+
sycl::queue Q(Plt.get_devices()[0]);
193+
194+
CapturedLinkingData.clear();
195+
196+
Q.single_task<DynamicLinkingTest::AOTCaseKernel>([=]() {});
197+
ASSERT_EQ(CapturedLinkingData.NumOfUrProgramCreateWithBinaryCalls, 2u);
198+
// Both programs should be linked together.
199+
ASSERT_EQ(CapturedLinkingData.NumOfUrProgramLinkCalls, 1u);
200+
ASSERT_TRUE(CapturedLinkingData.LinkedProgramsContains(
201+
{AOT_CASE_PRG_NATIVE, AOT_CASE_PRG_DEP_NATIVE}));
202+
// And the linked program should be used to create a kernel.
203+
ASSERT_EQ(CapturedLinkingData.ProgramUsedToCreateKernel,
204+
AOT_CASE_PRG_NATIVE * AOT_CASE_PRG_DEP_NATIVE);
205+
}
206+
207+
TEST(DynamicLinking, AheadOfTimeUnsupported) {
187208
try {
188-
sycl::unittest::UrMock<> Mock;
209+
sycl::unittest::UrMock<sycl::backend::ext_oneapi_level_zero> Mock;
189210
sycl::platform Plt = sycl::platform();
190211
sycl::queue Q(Plt.get_devices()[0]);
191212
Q.single_task<DynamicLinkingTest::AOTCaseKernel>([=]() {});
192213
FAIL();
193214
} catch (sycl::exception &e) {
194215
EXPECT_EQ(e.code(), sycl::errc::feature_not_supported);
195-
EXPECT_STREQ(e.what(),
196-
"Dynamic linking is not supported for AOT compilation yet");
216+
EXPECT_STREQ(e.what(), "Cannot resolve external symbols, linking is "
217+
"unsupported for the backend");
197218
}
198219
}
199220

0 commit comments

Comments
 (0)