diff --git a/sycl/include/CL/sycl/detail/cg.hpp b/sycl/include/CL/sycl/detail/cg.hpp index fc6042bf69183..170c0f39906c3 100644 --- a/sycl/include/CL/sycl/detail/cg.hpp +++ b/sycl/include/CL/sycl/detail/cg.hpp @@ -94,6 +94,10 @@ namespace detail { enum class ExtendedMembersType : unsigned int { HANDLER_KERNEL_BUNDLE = 0, HANDLER_MEM_ADVICE, + // handler_impl is stored in the exended members to avoid breaking ABI. + // TODO: This should be made a member of the handler class once ABI can be + // broken. + HANDLER_IMPL, }; // Holds a pointer to an object of an arbitrary type and an ID value which diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index cbe982540e07b..e64bc4d77eea8 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -80,6 +80,7 @@ template class buffer; namespace detail { +class handler_impl; class kernel_impl; class queue_impl; class stream_impl; @@ -1116,6 +1117,12 @@ class __SYCL_EXPORT handler { kernel_parallel_for_work_group(KernelFunc); } + std::shared_ptr getHandlerImpl() const; + + void setStateExplicitKernelBundle(); + void setStateSpecConstSet(); + bool isStateExplicitKernelBundle() const; + std::shared_ptr getOrInsertHandlerKernelBundle(bool Insert) const; @@ -1150,6 +1157,8 @@ class __SYCL_EXPORT handler { void set_specialization_constant( typename std::remove_reference_t::value_type Value) { + setStateSpecConstSet(); + std::shared_ptr KernelBundleImplPtr = getOrInsertHandlerKernelBundle(/*Insert=*/true); @@ -1162,6 +1171,11 @@ class __SYCL_EXPORT handler { typename std::remove_reference_t::value_type get_specialization_constant() const { + if (isStateExplicitKernelBundle()) + throw sycl::exception(make_error_code(errc::invalid), + "Specialization constants cannot be read after " + "explicitly setting the used kernel bundle"); + std::shared_ptr KernelBundleImplPtr = getOrInsertHandlerKernelBundle(/*Insert=*/true); @@ -1174,6 +1188,7 @@ class __SYCL_EXPORT handler { void use_kernel_bundle(const kernel_bundle &ExecBundle) { + setStateExplicitKernelBundle(); setHandlerKernelBundle(detail::getSyclObjImpl(ExecBundle)); } diff --git a/sycl/source/detail/handler_impl.hpp b/sycl/source/detail/handler_impl.hpp new file mode 100644 index 0000000000000..96f1621d28d34 --- /dev/null +++ b/sycl/source/detail/handler_impl.hpp @@ -0,0 +1,58 @@ +//==---------------- handler_impl.hpp - SYCL handler -----------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace detail { + +using KernelBundleImplPtr = std::shared_ptr; + +enum class HandlerSubmissionState : std::uint8_t { + NO_STATE = 0, + EXPLICIT_KERNEL_BUNDLE_STATE, + SPEC_CONST_SET_STATE, +}; + +class handler_impl { +public: + handler_impl() = default; + + void setStateExplicitKernelBundle() { + if (MSubmissionState == HandlerSubmissionState::SPEC_CONST_SET_STATE) + throw sycl::exception( + make_error_code(errc::invalid), + "Kernel bundle cannot be explicitly set after a specialization " + "constant has been set"); + MSubmissionState = HandlerSubmissionState::EXPLICIT_KERNEL_BUNDLE_STATE; + } + + void setStateSpecConstSet() { + if (MSubmissionState == + HandlerSubmissionState::EXPLICIT_KERNEL_BUNDLE_STATE) + throw sycl::exception(make_error_code(errc::invalid), + "Specialization constants cannot be set after " + "explicitly setting the used kernel bundle"); + MSubmissionState = HandlerSubmissionState::SPEC_CONST_SET_STATE; + } + + bool isStateExplicitKernelBundle() const { + return MSubmissionState == + HandlerSubmissionState::EXPLICIT_KERNEL_BUNDLE_STATE; + } + + /// Registers mutually exclusive submission states. + HandlerSubmissionState MSubmissionState = HandlerSubmissionState::NO_STATE; +}; + +} // namespace detail +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 7f2b0f14dc7da..8f12abaee8a1f 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -27,8 +28,56 @@ namespace sycl { handler::handler(std::shared_ptr Queue, bool IsHost) : MQueue(std::move(Queue)), MIsHost(IsHost) { - MSharedPtrStorage.emplace_back( - std::make_shared>()); + // Create extended members and insert handler_impl + // TODO: When allowed to break ABI the handler_impl should be made a member + // of the handler class. + auto ExtendedMembers = + std::make_shared>(); + detail::ExtendedMemberT HandlerImplMember = { + detail::ExtendedMembersType::HANDLER_IMPL, + std::make_shared()}; + ExtendedMembers->push_back(std::move(HandlerImplMember)); + MSharedPtrStorage.push_back(std::move(ExtendedMembers)); +} + +/// Gets the handler_impl at the start of the extended members. +std::shared_ptr handler::getHandlerImpl() const { + std::lock_guard Lock( + detail::GlobalHandler::instance().getHandlerExtendedMembersMutex()); + + assert(!MSharedPtrStorage.empty()); + + std::shared_ptr> ExtendedMembersVec = + detail::convertToExtendedMembers(MSharedPtrStorage[0]); + + assert(ExtendedMembersVec->size() > 0); + + auto HandlerImplMember = (*ExtendedMembersVec)[0]; + + assert(detail::ExtendedMembersType::HANDLER_IMPL == HandlerImplMember.MType); + + return std::static_pointer_cast( + HandlerImplMember.MData); +} + +// Sets the submission state to indicate that an explicit kernel bundle has been +// set. Throws a sycl::exception with errc::invalid if the current state +// indicates that a specialization constant has been set. +void handler::setStateExplicitKernelBundle() { + getHandlerImpl()->setStateExplicitKernelBundle(); +} + +// Sets the submission state to indicate that a specialization constant has been +// set. Throws a sycl::exception with errc::invalid if the current state +// indicates that an explicit kernel bundle has been set. +void handler::setStateSpecConstSet() { + getHandlerImpl()->setStateSpecConstSet(); +} + +// Returns true if the submission state is EXPLICIT_KERNEL_BUNDLE_STATE and +// false otherwise. +bool handler::isStateExplicitKernelBundle() const { + return getHandlerImpl()->isStateExplicitKernelBundle(); } // Returns a shared_ptr to kernel_bundle stored in the extended members vector. @@ -43,12 +92,11 @@ handler::getOrInsertHandlerKernelBundle(bool Insert) const { assert(!MSharedPtrStorage.empty()); - std::shared_ptr> ExendedMembersVec = + std::shared_ptr> ExtendedMembersVec = detail::convertToExtendedMembers(MSharedPtrStorage[0]); - // Look for the kernel bundle in extended members std::shared_ptr KernelBundleImpPtr; - for (const detail::ExtendedMemberT &EMember : *ExendedMembersVec) + for (const detail::ExtendedMemberT &EMember : *ExtendedMembersVec) if (detail::ExtendedMembersType::HANDLER_KERNEL_BUNDLE == EMember.MType) { KernelBundleImpPtr = std::static_pointer_cast(EMember.MData); @@ -66,8 +114,7 @@ handler::getOrInsertHandlerKernelBundle(bool Insert) const { detail::ExtendedMemberT EMember = { detail::ExtendedMembersType::HANDLER_KERNEL_BUNDLE, KernelBundleImpPtr}; - - ExendedMembersVec->push_back(EMember); + ExtendedMembersVec->push_back(EMember); } return KernelBundleImpPtr; @@ -85,16 +132,18 @@ void handler::setHandlerKernelBundle( std::shared_ptr> ExendedMembersVec = detail::convertToExtendedMembers(MSharedPtrStorage[0]); - for (detail::ExtendedMemberT &EMember : *ExendedMembersVec) + // Look for kernel bundle in extended members and overwrite it. + for (detail::ExtendedMemberT &EMember : *ExendedMembersVec) { if (detail::ExtendedMembersType::HANDLER_KERNEL_BUNDLE == EMember.MType) { EMember.MData = NewKernelBundleImpPtr; return; } + } + // Kernel bundle was set found so we add it. detail::ExtendedMemberT EMember = { detail::ExtendedMembersType::HANDLER_KERNEL_BUNDLE, NewKernelBundleImpPtr}; - ExendedMembersVec->push_back(EMember); } diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index 7bbd5f580fa52..3e891d3977e2d 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3920,7 +3920,9 @@ _ZN2cl4sycl7handler18ext_oneapi_barrierERKSt6vectorINS0_5eventESaIS3_EE _ZN2cl4sycl7handler18extractArgsAndReqsEv _ZN2cl4sycl7handler20DisableRangeRoundingEv _ZN2cl4sycl7handler20associateWithHandlerEPNS0_6detail16AccessorBaseHostENS0_6access6targetE +_ZN2cl4sycl7handler20setStateSpecConstSetEv _ZN2cl4sycl7handler22setHandlerKernelBundleERKSt10shared_ptrINS0_6detail18kernel_bundle_implEE +_ZN2cl4sycl7handler28setStateExplicitKernelBundleEv _ZN2cl4sycl7handler24GetRangeRoundingSettingsERmS2_S2_ _ZN2cl4sycl7handler28extractArgsAndReqsFromLambdaEPcmPKNS0_6detail19kernel_param_desc_tE _ZN2cl4sycl7handler28extractArgsAndReqsFromLambdaEPcmPKNS0_6detail19kernel_param_desc_tEb @@ -4263,6 +4265,8 @@ _ZNK2cl4sycl7context8get_infoILNS0_4info7contextE4225EEENS3_12param_traitsIS4_XT _ZNK2cl4sycl7context8get_infoILNS0_4info7contextE4228EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl7context8get_infoILNS0_4info7contextE65552EEENS3_12param_traitsIS4_XT_EE11return_typeEv _ZNK2cl4sycl7context9getNativeEv +_ZNK2cl4sycl7handler14getHandlerImplEv +_ZNK2cl4sycl7handler27isStateExplicitKernelBundleEv _ZNK2cl4sycl7handler30getOrInsertHandlerKernelBundleEb _ZNK2cl4sycl7program10get_kernelENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE _ZNK2cl4sycl7program10get_kernelENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb diff --git a/sycl/test/abi/sycl_symbols_windows.dump b/sycl/test/abi/sycl_symbols_windows.dump index 9b347ceb7b47a..0fab219338d04 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -2056,6 +2056,7 @@ ?getElementSize@?$image_impl@$01@detail@sycl@cl@@QEBA_KXZ ?getElementSize@?$image_impl@$02@detail@sycl@cl@@QEBA_KXZ ?getEndTime@HostProfilingInfo@detail@sycl@cl@@QEBA_KXZ +?getHandlerImpl@handler@sycl@cl@@AEBA?AV?$shared_ptr@Vhandler_impl@detail@sycl@cl@@@std@@XZ ?getImageDesc@?$image_impl@$00@detail@sycl@cl@@AEAA?AU_pi_image_desc@@_N@Z ?getImageDesc@?$image_impl@$01@detail@sycl@cl@@AEAA?AU_pi_image_desc@@_N@Z ?getImageDesc@?$image_impl@$02@detail@sycl@cl@@AEAA?AU_pi_image_desc@@_N@Z @@ -2331,6 +2332,7 @@ ?isInterop@SYCLMemObjT@detail@sycl@cl@@QEBA_NXZ ?isOutOfRange@detail@sycl@cl@@YA_NV?$vec@H$03@23@W4addressing_mode@23@V?$range@$02@23@@Z ?isPathPresent@OSUtil@detail@sycl@cl@@SA_NAEBV?$basic_string@DU?$char_traits@D@std@@V?$allocator@D@2@@std@@@Z +?isStateExplicitKernelBundle@handler@sycl@cl@@AEBA_NXZ ?isValidModeForDestinationAccessor@handler@sycl@cl@@CA_NW4mode@access@23@@Z ?isValidModeForSourceAccessor@handler@sycl@cl@@CA_NW4mode@access@23@@Z ?isValidTargetForExplicitOp@handler@sycl@cl@@CA_NW4target@access@23@@Z @@ -3812,6 +3814,8 @@ ?setPitches@?$image_impl@$00@detail@sycl@cl@@AEAAXXZ ?setPitches@?$image_impl@$01@detail@sycl@cl@@AEAAXXZ ?setPitches@?$image_impl@$02@detail@sycl@cl@@AEAAXXZ +?setStateExplicitKernelBundle@handler@sycl@cl@@AEAAXXZ +?setStateSpecConstSet@handler@sycl@cl@@AEAAXXZ ?setType@handler@sycl@cl@@AEAAXW4CGTYPE@CG@detail@23@@Z ?set_final_data@SYCLMemObjT@detail@sycl@cl@@QEAAX$$T@Z ?set_final_data_from_storage@SYCLMemObjT@detail@sycl@cl@@QEAAXXZ diff --git a/sycl/unittests/SYCL2020/CMakeLists.txt b/sycl/unittests/SYCL2020/CMakeLists.txt index 247adc4b7e11d..f4b64df26afb0 100644 --- a/sycl/unittests/SYCL2020/CMakeLists.txt +++ b/sycl/unittests/SYCL2020/CMakeLists.txt @@ -4,7 +4,7 @@ set(CMAKE_CXX_EXTENSIONS OFF) set(LLVM_REQUIRES_EH 1) add_sycl_unittest(SYCL2020Tests OBJECT GetNativeOpenCL.cpp - SpecConstDefaultValues.cpp + SpecializationConstant.cpp KernelBundle.cpp KernelID.cpp ) diff --git a/sycl/unittests/SYCL2020/SpecConstDefaultValues.cpp b/sycl/unittests/SYCL2020/SpecConstDefaultValues.cpp deleted file mode 100644 index 8bdf85b162c9d..0000000000000 --- a/sycl/unittests/SYCL2020/SpecConstDefaultValues.cpp +++ /dev/null @@ -1,171 +0,0 @@ -//==---- DefaultValues.cpp --- Spec constants default values unit test -----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#define SYCL2020_DISABLE_DEPRECATION_WARNINGS - -#include -#include - -#include -#include -#include - -#include - -class TestKernel; -const static sycl::specialization_id SpecConst1{42}; - -__SYCL_INLINE_NAMESPACE(cl) { -namespace sycl { -namespace detail { -template <> struct KernelInfo { - static constexpr unsigned getNumParams() { return 0; } - static const kernel_param_desc_t &getParamDesc(int) { - static kernel_param_desc_t Dummy; - return Dummy; - } - static constexpr const char *getName() { - return "SpecConstDefaultValues_TestKernel"; - } - static constexpr bool isESIMD() { return false; } - static constexpr bool callsThisItem() { return false; } - static constexpr bool callsAnyThisFreeFunction() { return false; } -}; - -template <> const char *get_spec_constant_symbolic_ID() { - return "SC1"; -} -} // namespace detail -} // namespace sycl -} // __SYCL_INLINE_NAMESPACE(cl) - -static sycl::unittest::PiImage generateImageWithSpecConsts() { - using namespace sycl::unittest; - - std::vector SpecConstData; - PiProperty SC1 = makeSpecConstant(SpecConstData, "SC1", {0}, {0}, {42}); - PiProperty SC2 = makeSpecConstant(SpecConstData, "SC2", {1}, {0}, {8}); - - PiPropertySet PropSet; - addSpecConstants({SC1, SC2}, std::move(SpecConstData), PropSet); - - std::vector Bin{0, 1, 2, 3, 4, 5}; // Random data - - PiArray Entries = - makeEmptyKernels({"SpecConstDefaultValues_TestKernel"}); - - PiImage Img{PI_DEVICE_BINARY_TYPE_SPIRV, // Format - __SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64, // DeviceTargetSpec - "", // Compile options - "", // Link options - std::move(Bin), - std::move(Entries), - std::move(PropSet)}; - - return Img; -} - -static sycl::unittest::PiImage Img = generateImageWithSpecConsts(); -static sycl::unittest::PiImageArray<1> ImgArray{&Img}; - -TEST(SpecConstDefaultValues, DefaultValuesAreSet) { - sycl::platform Plt{sycl::default_selector()}; - if (Plt.is_host()) { - std::cerr << "Test is not supported on host, skipping\n"; - return; // test is not supported on host. - } - - if (Plt.get_backend() == sycl::backend::cuda) { - std::cerr << "Test is not supported on CUDA platform, skipping\n"; - return; - } - - if (Plt.get_backend() == sycl::backend::hip) { - std::cerr << "Test is not supported on HIP platform, skipping\n"; - return; - } - - sycl::unittest::PiMock Mock{Plt}; - setupDefaultMockAPIs(Mock); - - const sycl::device Dev = Plt.get_devices()[0]; - - sycl::queue Queue{Dev}; - - const sycl::context Ctx = Queue.get_context(); - - sycl::kernel_bundle KernelBundle = - sycl::get_kernel_bundle(Ctx, {Dev}); - - sycl::kernel_id TestKernelID = sycl::get_kernel_id(); - auto DevImage = - std::find_if(KernelBundle.begin(), KernelBundle.end(), - [&](auto Image) { return Image.has_kernel(TestKernelID); }); - EXPECT_NE(DevImage, KernelBundle.end()); - - auto DevImageImpl = sycl::detail::getSyclObjImpl(*DevImage); - const auto &Blob = DevImageImpl->get_spec_const_blob_ref(); - - int SpecConstVal1 = *reinterpret_cast(Blob.data()); - int SpecConstVal2 = *(reinterpret_cast(Blob.data()) + 1); - - EXPECT_EQ(SpecConstVal1, 42); - EXPECT_EQ(SpecConstVal2, 8); -} - -TEST(SpecConstDefaultValues, DefaultValuesAreOverriden) { - sycl::platform Plt{sycl::default_selector()}; - if (Plt.is_host()) { - std::cerr << "Test is not supported on host, skipping\n"; - return; // test is not supported on host. - } - - if (Plt.get_backend() == sycl::backend::cuda) { - std::cerr << "Test is not supported on CUDA platform, skipping\n"; - return; - } - - if (Plt.get_backend() == sycl::backend::hip) { - std::cerr << "Test is not supported on HIP platform, skipping\n"; - return; - } - - sycl::unittest::PiMock Mock{Plt}; - setupDefaultMockAPIs(Mock); - - const sycl::device Dev = Plt.get_devices()[0]; - - sycl::queue Queue{Dev}; - - const sycl::context Ctx = Queue.get_context(); - - sycl::kernel_bundle KernelBundle = - sycl::get_kernel_bundle(Ctx, {Dev}); - - sycl::kernel_id TestKernelID = sycl::get_kernel_id(); - auto DevImage = - std::find_if(KernelBundle.begin(), KernelBundle.end(), - [&](auto Image) { return Image.has_kernel(TestKernelID); }); - EXPECT_NE(DevImage, KernelBundle.end()); - - auto DevImageImpl = sycl::detail::getSyclObjImpl(*DevImage); - auto &Blob = DevImageImpl->get_spec_const_blob_ref(); - int SpecConstVal1 = *reinterpret_cast(Blob.data()); - int SpecConstVal2 = *(reinterpret_cast(Blob.data()) + 1); - - EXPECT_EQ(SpecConstVal1, 42); - EXPECT_EQ(SpecConstVal2, 8); - - KernelBundle.set_specialization_constant(80); - - SpecConstVal1 = *reinterpret_cast(Blob.data()); - SpecConstVal2 = *(reinterpret_cast(Blob.data()) + 1); - - EXPECT_EQ(SpecConstVal1, 80); - EXPECT_EQ(SpecConstVal2, 8); -} diff --git a/sycl/unittests/SYCL2020/SpecializationConstant.cpp b/sycl/unittests/SYCL2020/SpecializationConstant.cpp new file mode 100644 index 0000000000000..8d89113049aa5 --- /dev/null +++ b/sycl/unittests/SYCL2020/SpecializationConstant.cpp @@ -0,0 +1,351 @@ +//==------ SpecializationConstant.cpp --- Spec constants unit tests --------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#define SYCL2020_DISABLE_DEPRECATION_WARNINGS + +#include +#include + +#include +#include +#include + +#include + +class TestKernel; +const static sycl::specialization_id SpecConst1{42}; + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace detail { +template <> struct KernelInfo { + static constexpr unsigned getNumParams() { return 0; } + static const kernel_param_desc_t &getParamDesc(int) { + static kernel_param_desc_t Dummy; + return Dummy; + } + static constexpr const char *getName() { + return "SpecializationConstant_TestKernel"; + } + static constexpr bool isESIMD() { return false; } + static constexpr bool callsThisItem() { return false; } + static constexpr bool callsAnyThisFreeFunction() { return false; } +}; + +template <> const char *get_spec_constant_symbolic_ID() { + return "SC1"; +} +} // namespace detail +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) + +static sycl::unittest::PiImage generateImageWithSpecConsts() { + using namespace sycl::unittest; + + std::vector SpecConstData; + PiProperty SC1 = makeSpecConstant(SpecConstData, "SC1", {0}, {0}, {42}); + PiProperty SC2 = makeSpecConstant(SpecConstData, "SC2", {1}, {0}, {8}); + + PiPropertySet PropSet; + addSpecConstants({SC1, SC2}, std::move(SpecConstData), PropSet); + + std::vector Bin{0, 1, 2, 3, 4, 5}; // Random data + + PiArray Entries = + makeEmptyKernels({"SpecializationConstant_TestKernel"}); + + PiImage Img{PI_DEVICE_BINARY_TYPE_SPIRV, // Format + __SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64, // DeviceTargetSpec + "", // Compile options + "", // Link options + std::move(Bin), + std::move(Entries), + std::move(PropSet)}; + + return Img; +} + +static sycl::unittest::PiImage Img = generateImageWithSpecConsts(); +static sycl::unittest::PiImageArray<1> ImgArray{&Img}; + +TEST(SpecializationConstant, DefaultValuesAreSet) { + sycl::platform Plt{sycl::default_selector()}; + if (Plt.is_host()) { + std::cerr << "Test is not supported on host, skipping\n"; + return; // test is not supported on host. + } + + if (Plt.get_backend() == sycl::backend::cuda) { + std::cerr << "Test is not supported on CUDA platform, skipping\n"; + return; + } + + if (Plt.get_backend() == sycl::backend::hip) { + std::cerr << "Test is not supported on HIP platform, skipping\n"; + return; + } + + sycl::unittest::PiMock Mock{Plt}; + setupDefaultMockAPIs(Mock); + + const sycl::device Dev = Plt.get_devices()[0]; + + sycl::queue Queue{Dev}; + + const sycl::context Ctx = Queue.get_context(); + + sycl::kernel_bundle KernelBundle = + sycl::get_kernel_bundle(Ctx, {Dev}); + + sycl::kernel_id TestKernelID = sycl::get_kernel_id(); + auto DevImage = + std::find_if(KernelBundle.begin(), KernelBundle.end(), + [&](auto Image) { return Image.has_kernel(TestKernelID); }); + EXPECT_NE(DevImage, KernelBundle.end()); + + auto DevImageImpl = sycl::detail::getSyclObjImpl(*DevImage); + const auto &Blob = DevImageImpl->get_spec_const_blob_ref(); + + int SpecConstVal1 = *reinterpret_cast(Blob.data()); + int SpecConstVal2 = *(reinterpret_cast(Blob.data()) + 1); + + EXPECT_EQ(SpecConstVal1, 42); + EXPECT_EQ(SpecConstVal2, 8); +} + +TEST(SpecializationConstant, DefaultValuesAreOverriden) { + sycl::platform Plt{sycl::default_selector()}; + if (Plt.is_host()) { + std::cerr << "Test is not supported on host, skipping\n"; + return; // test is not supported on host. + } + + if (Plt.get_backend() == sycl::backend::cuda) { + std::cerr << "Test is not supported on CUDA platform, skipping\n"; + return; + } + + if (Plt.get_backend() == sycl::backend::hip) { + std::cerr << "Test is not supported on HIP platform, skipping\n"; + return; + } + + sycl::unittest::PiMock Mock{Plt}; + setupDefaultMockAPIs(Mock); + + const sycl::device Dev = Plt.get_devices()[0]; + + sycl::queue Queue{Dev}; + + const sycl::context Ctx = Queue.get_context(); + + sycl::kernel_bundle KernelBundle = + sycl::get_kernel_bundle(Ctx, {Dev}); + + sycl::kernel_id TestKernelID = sycl::get_kernel_id(); + auto DevImage = + std::find_if(KernelBundle.begin(), KernelBundle.end(), + [&](auto Image) { return Image.has_kernel(TestKernelID); }); + EXPECT_NE(DevImage, KernelBundle.end()); + + auto DevImageImpl = sycl::detail::getSyclObjImpl(*DevImage); + auto &Blob = DevImageImpl->get_spec_const_blob_ref(); + int SpecConstVal1 = *reinterpret_cast(Blob.data()); + int SpecConstVal2 = *(reinterpret_cast(Blob.data()) + 1); + + EXPECT_EQ(SpecConstVal1, 42); + EXPECT_EQ(SpecConstVal2, 8); + + KernelBundle.set_specialization_constant(80); + + SpecConstVal1 = *reinterpret_cast(Blob.data()); + SpecConstVal2 = *(reinterpret_cast(Blob.data()) + 1); + + EXPECT_EQ(SpecConstVal1, 80); + EXPECT_EQ(SpecConstVal2, 8); +} + +TEST(SpecializationConstant, SetSpecConstAfterUseKernelBundle) { + sycl::platform Plt{sycl::default_selector()}; + if (Plt.is_host()) { + std::cerr << "Test is not supported on host, skipping\n"; + return; // test is not supported on host. + } + + if (Plt.get_backend() == sycl::backend::cuda) { + std::cerr << "Test is not supported on CUDA platform, skipping\n"; + return; + } + + if (Plt.get_backend() == sycl::backend::hip) { + std::cerr << "Test is not supported on HIP platform, skipping\n"; + return; + } + + sycl::unittest::PiMock Mock{Plt}; + setupDefaultMockAPIs(Mock); + + const sycl::device Dev = Plt.get_devices()[0]; + + sycl::queue Queue{Dev}; + + const sycl::context Ctx = Queue.get_context(); + + sycl::kernel_bundle KernelBundle = + sycl::get_kernel_bundle(Ctx, {Dev}); + + // Create uniquely identifyable class to throw on expected exception + class UniqueException {}; + + try { + Queue.submit([&](sycl::handler &CGH) { + CGH.use_kernel_bundle(KernelBundle); + try { + CGH.set_specialization_constant(80); + FAIL() << "No exception was thrown."; + } catch (const sycl::exception &e) { + if (static_cast(e.code().value()) != sycl::errc::invalid) { + FAIL() << "Unexpected SYCL exception was thrown."; + throw; + } + throw UniqueException{}; + } catch (...) { + FAIL() << "Unexpected non-SYCL exception was thrown."; + throw; + } + CGH.single_task([]() {}); + }); + } catch (const sycl::exception &e) { + if (static_cast(e.code().value()) == sycl::errc::invalid) { + FAIL() << "SYCL exception with error code sycl::errc::invalid was " + "thrown at the wrong level."; + } + throw; + } catch (const UniqueException &) { + // Expected path + } +} + +TEST(SpecializationConstant, GetSpecConstAfterUseKernelBundle) { + sycl::platform Plt{sycl::default_selector()}; + if (Plt.is_host()) { + std::cerr << "Test is not supported on host, skipping\n"; + return; // test is not supported on host. + } + + if (Plt.get_backend() == sycl::backend::cuda) { + std::cerr << "Test is not supported on CUDA platform, skipping\n"; + return; + } + + if (Plt.get_backend() == sycl::backend::hip) { + std::cerr << "Test is not supported on HIP platform, skipping\n"; + return; + } + + sycl::unittest::PiMock Mock{Plt}; + setupDefaultMockAPIs(Mock); + + const sycl::device Dev = Plt.get_devices()[0]; + sycl::queue Queue{Dev}; + const sycl::context Ctx = Queue.get_context(); + + sycl::kernel_bundle KernelBundle = + sycl::get_kernel_bundle(Ctx, {Dev}); + + // Create uniquely identifyable class to throw on expected exception + class UniqueException {}; + + try { + Queue.submit([&](sycl::handler &CGH) { + CGH.use_kernel_bundle(KernelBundle); + try { + auto SpecConst1Val = CGH.get_specialization_constant(); + (void)SpecConst1Val; + FAIL() << "No exception was thrown."; + } catch (const sycl::exception &e) { + if (static_cast(e.code().value()) != sycl::errc::invalid) { + FAIL() << "Unexpected SYCL exception was thrown."; + throw; + } + throw UniqueException{}; + } catch (...) { + FAIL() << "Unexpected non-SYCL exception was thrown."; + throw; + } + CGH.single_task([]() {}); + }); + } catch (const sycl::exception &e) { + if (static_cast(e.code().value()) == sycl::errc::invalid) { + FAIL() << "SYCL exception with error code sycl::errc::invalid was " + "thrown at the wrong level."; + } + throw; + } catch (const UniqueException &) { + // Expected path + } +} + +TEST(SpecializationConstant, UseKernelBundleAfterSetSpecConst) { + sycl::platform Plt{sycl::default_selector()}; + if (Plt.is_host()) { + std::cerr << "Test is not supported on host, skipping\n"; + return; // test is not supported on host. + } + + if (Plt.get_backend() == sycl::backend::cuda) { + std::cerr << "Test is not supported on CUDA platform, skipping\n"; + return; + } + + if (Plt.get_backend() == sycl::backend::hip) { + std::cerr << "Test is not supported on HIP platform, skipping\n"; + return; + } + + sycl::unittest::PiMock Mock{Plt}; + setupDefaultMockAPIs(Mock); + + const sycl::device Dev = Plt.get_devices()[0]; + sycl::queue Queue{Dev}; + const sycl::context Ctx = Queue.get_context(); + + sycl::kernel_bundle KernelBundle = + sycl::get_kernel_bundle(Ctx, {Dev}); + + // Create uniquely identifyable class to throw on expected exception + class UniqueException {}; + + try { + Queue.submit([&](sycl::handler &CGH) { + CGH.set_specialization_constant(80); + try { + CGH.use_kernel_bundle(KernelBundle); + FAIL() << "No exception was thrown."; + } catch (const sycl::exception &e) { + if (static_cast(e.code().value()) != sycl::errc::invalid) { + FAIL() << "Unexpected SYCL exception was thrown."; + throw; + } + throw UniqueException{}; + } catch (...) { + FAIL() << "Unexpected non-SYCL exception was thrown."; + throw; + } + CGH.single_task([]() {}); + }); + } catch (const sycl::exception &e) { + if (static_cast(e.code().value()) == sycl::errc::invalid) { + FAIL() << "SYCL exception with error code sycl::errc::invalid was " + "thrown at the wrong level."; + } + throw; + } catch (const UniqueException &) { + // Expected path + } +}