Skip to content

Commit a401090

Browse files
author
Erich Keane
committed
Extract the instantiator out into its own type to simplify its use
1 parent 6c94611 commit a401090

File tree

1 file changed

+32
-24
lines changed

1 file changed

+32
-24
lines changed

sycl/include/CL/sycl/handler.hpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -805,31 +805,11 @@ class __SYCL_EXPORT handler {
805805
decltype(getRangeRoundedKernelLambda<TransformedArgType, Dims>(
806806
KernelFunc, NumWorkItems));
807807
using NameWT = typename detail::get_kernel_wrapper_name_t<NameT>::name;
808-
#ifdef __SYCL_NONCONST_FUNCTOR__
809-
using WrapperKernelParamTy = WrapperTy;
810-
using KernelParamTy = KernelFunc;
811-
#else
812-
using WrapperKernelParamTy = const WrapperTy &;
813-
using KernelParamTy = const KernelFunc &;
814-
#endif
815-
816-
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
817-
WrapperTy, ElementType>()) {
818-
(void)static_cast<void (*)(WrapperKernelParamTy, kernel_handler)>(
819-
kernel_parallel_for<NameWT, TransformedArgType, WrapperTy>);
820-
} else {
821-
(void)static_cast<void (*)(WrapperKernelParamTy)>(
822-
kernel_parallel_for<NameWT, TransformedArgType, WrapperTy>);
823-
}
824808

825-
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
826-
KernelType, ElementType>()) {
827-
(void)static_cast<void (*)(KernelParamTy, kernel_handler)>(
828-
kernel_parallel_for<NameT, TransformedArgType, KernelType>);
829-
} else {
830-
(void)static_cast<void (*)(KernelParamTy)>(
831-
kernel_parallel_for<NameT, TransformedArgType, KernelType>);
832-
}
809+
(void)kernel_parallel_for_wrapper_instantiator<NameWT, TransformedArgType,
810+
WrapperTy>::value;
811+
(void)kernel_parallel_for_wrapper_instantiator<NameT, TransformedArgType,
812+
KernelType>::value;
833813

834814
using KI = detail::KernelInfo<KernelName>;
835815
bool DisableRounding =
@@ -1050,6 +1030,34 @@ class __SYCL_EXPORT handler {
10501030
}
10511031
}
10521032

1033+
// Helper instantiator type for kernel_parallel_for_wrapper that
1034+
// instantiates but does not call the appropriate kernel. Needed to support use of
1035+
// KernelInfo in parallel_for_lambda_impl when supporting
1036+
// SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING.
1037+
template <typename KernelName, typename ElementType, typename KernelType>
1038+
class kernel_parallel_for_wrapper_instantiator {
1039+
static constexpr auto func_loader() {
1040+
#ifdef __SYCL_NONCONST_FUNCTOR__
1041+
using ParamTy = KernelType;
1042+
#else
1043+
using ParamTy = const KernelType &;
1044+
#endif
1045+
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
1046+
KernelType, ElementType>()) {
1047+
using FuncTy = void (*)(ParamTy, kernel_handler);
1048+
return static_cast<FuncTy>(
1049+
kernel_parallel_for<KernelName, ElementType, KernelType>);
1050+
} else {
1051+
using FuncTy = void (*)(ParamTy);
1052+
return static_cast<FuncTy>(
1053+
kernel_parallel_for<KernelName, ElementType, KernelType>);
1054+
}
1055+
}
1056+
1057+
public:
1058+
static constexpr auto value = func_loader();
1059+
};
1060+
10531061
// Wrappers for kernel_parallel_for_work_group(...)
10541062

10551063
template <typename KernelName, typename ElementType, typename KernelType>

0 commit comments

Comments
 (0)