@@ -805,31 +805,11 @@ class __SYCL_EXPORT handler {
805
805
decltype (getRangeRoundedKernelLambda<TransformedArgType, Dims>(
806
806
KernelFunc, NumWorkItems));
807
807
using NameWT = typename detail::get_kernel_wrapper_name_t <NameT>::name;
808
- #ifdef __SYCL_NONCONST_FUNCTOR__
809
- using WrapperKernelParamTy = WrapperTy;
810
- using KernelParamTy = KernelFunc;
811
- #else
812
- using WrapperKernelParamTy = const WrapperTy &;
813
- using KernelParamTy = const KernelFunc &;
814
- #endif
815
-
816
- if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
817
- WrapperTy, ElementType>()) {
818
- (void )static_cast <void (*)(WrapperKernelParamTy, kernel_handler)>(
819
- kernel_parallel_for<NameWT, TransformedArgType, WrapperTy>);
820
- } else {
821
- (void )static_cast <void (*)(WrapperKernelParamTy)>(
822
- kernel_parallel_for<NameWT, TransformedArgType, WrapperTy>);
823
- }
824
808
825
- if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
826
- KernelType, ElementType>()) {
827
- (void )static_cast <void (*)(KernelParamTy, kernel_handler)>(
828
- kernel_parallel_for<NameT, TransformedArgType, KernelType>);
829
- } else {
830
- (void )static_cast <void (*)(KernelParamTy)>(
831
- kernel_parallel_for<NameT, TransformedArgType, KernelType>);
832
- }
809
+ (void )kernel_parallel_for_wrapper_instantiator<NameWT, TransformedArgType,
810
+ WrapperTy>::value;
811
+ (void )kernel_parallel_for_wrapper_instantiator<NameT, TransformedArgType,
812
+ KernelType>::value;
833
813
834
814
using KI = detail::KernelInfo<KernelName>;
835
815
bool DisableRounding =
@@ -1050,6 +1030,34 @@ class __SYCL_EXPORT handler {
1050
1030
}
1051
1031
}
1052
1032
1033
+ // Helper instantiator type for kernel_parallel_for_wrapper that
1034
+ // instantiates but not calls the appropriate kernel. Needed to support use of
1035
+ // KernelInfo in parallel_for_lambda_impl when supporting
1036
+ // SCYL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING.
1037
+ template <typename KernelName, typename ElementType, typename KernelType>
1038
+ class kernel_parallel_for_wrapper_instantiator {
1039
+ static constexpr auto func_loader () {
1040
+ #ifdef __SYCL_NONCONST_FUNCTOR__
1041
+ using ParamTy = KernelType;
1042
+ #else
1043
+ using ParamTy = const KernelType &;
1044
+ #endif
1045
+ if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
1046
+ KernelType, ElementType>()) {
1047
+ using FuncTy = void (*)(ParamTy, kernel_handler);
1048
+ return static_cast <FuncTy>(
1049
+ kernel_parallel_for<KernelName, ElementType, KernelType>);
1050
+ } else {
1051
+ using FuncTy = void (*)(ParamTy);
1052
+ return static_cast <FuncTy>(
1053
+ kernel_parallel_for<KernelName, ElementType, KernelType>);
1054
+ }
1055
+ }
1056
+
1057
+ public:
1058
+ static constexpr auto value = func_loader();
1059
+ };
1060
+
1053
1061
// Wrappers for kernel_parallel_for_work_group(...)
1054
1062
1055
1063
template <typename KernelName, typename ElementType, typename KernelType>
0 commit comments