diff --git a/Src/Base/AMReX_GpuReduce.H b/Src/Base/AMReX_GpuReduce.H
index 9b48138940c..3907ca385f6 100644
--- a/Src/Base/AMReX_GpuReduce.H
+++ b/Src/Base/AMReX_GpuReduce.H
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 
 #if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
 #define AMREX_USE_CUB 1
@@ -249,15 +250,54 @@ void deviceReduceLogicalOr (int * dest, int source, Gpu::Handler const& h) noexc
 
 #elif defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP)
 
+namespace detail {
+
+template <typename T>
+AMREX_GPU_DEVICE AMREX_FORCE_INLINE
+T shuffle_down (T x, int offset) noexcept
+{
+    return AMREX_HIP_OR_CUDA(__shfl_down(x, offset),
+                             __shfl_down_sync(0xffffffff, x, offset));
+}
+
+// If other sizeof is needed, we can implement it later.
+template <class T, std::enable_if_t<sizeof(T)%sizeof(unsigned int) == 0, int> = 0>
+AMREX_GPU_DEVICE AMREX_FORCE_INLINE
+T multi_shuffle_down (T x, int offset) noexcept
+{
+    constexpr int nwords = (sizeof(T) + sizeof(unsigned int) - 1) / sizeof(unsigned int);
+    T y;
+    auto py = reinterpret_cast<unsigned int*>(&y);
+    auto px = reinterpret_cast<unsigned int*>(&x);
+    for (int i = 0; i < nwords; ++i) {
+        py[i] = shuffle_down(px[i],offset);
+    }
+    return y;
+}
+
+}
+
 template <int warpSize, typename T, typename F>
 struct warpReduce
 {
+    // Not all arithmetic types can be taken by shuffle_down, but it's good enough.
+    template <class U=T, std::enable_if_t<std::is_arithmetic<U>::value,int> = 0>
+    AMREX_GPU_DEVICE AMREX_FORCE_INLINE
+    T operator() (T x) const noexcept
+    {
+        for (int offset = warpSize/2; offset > 0; offset /= 2) {
+            T y = detail::shuffle_down(x, offset);
+            x = F()(x,y);
+        }
+        return x;
+    }
+
+    template <class U=T, std::enable_if_t<!std::is_arithmetic<U>::value,int> = 0>
     AMREX_GPU_DEVICE AMREX_FORCE_INLINE
     T operator() (T x) const noexcept
     {
         for (int offset = warpSize/2; offset > 0; offset /= 2) {
-            AMREX_HIP_OR_CUDA(T y = __shfl_down(x, offset);,
-                              T y = __shfl_down_sync(0xffffffff, x, offset); )
+            T y = detail::multi_shuffle_down(x, offset);
             x = F()(x,y);
         }
         return x;
diff --git a/Src/Base/AMReX_Reduce.H b/Src/Base/AMReX_Reduce.H
index 9c07b7b4a2a..05b56b97fa9 100644
--- a/Src/Base/AMReX_Reduce.H
+++ b/Src/Base/AMReX_Reduce.H
@@ -9,9 +9,35 @@
 
 #include
 #include
+#include
 
 namespace amrex {
 
+template <typename TV, typename TI>
+struct ValLocPair
+{
+    TV value;
+    TI index;
+
+    static constexpr ValLocPair<TV,TI> max () {
+        return ValLocPair<TV,TI>{std::numeric_limits<TV>::max(), TI()};
+    }
+
+    static constexpr ValLocPair<TV,TI> lowest () {
+        return ValLocPair<TV,TI>{std::numeric_limits<TV>::lowest(), TI()};
+    }
+
+    friend constexpr bool operator< (ValLocPair<TV,TI> const& a, ValLocPair<TV,TI> const& b)
+    {
+        return a.value < b.value;
+    }
+
+    friend constexpr bool operator> (ValLocPair<TV,TI> const& a, ValLocPair<TV,TI> const& b)
+    {
+        return a.value > b.value;
+    }
+};
+
 namespace Reduce { namespace detail {
 
 #ifdef AMREX_USE_GPU
@@ -133,7 +159,12 @@ struct ReduceOpMin
     void local_update (T& d, T const& s) const noexcept { d = amrex::min(d,s); }
 
     template <typename T>
-    constexpr void init (T& t) const noexcept { t = std::numeric_limits<T>::max(); }
+    constexpr std::enable_if_t<std::numeric_limits<T>::is_specialized>
+    init (T& t) const noexcept { t = std::numeric_limits<T>::max(); }
+
+    template <typename T>
+    constexpr std::enable_if_t<!std::numeric_limits<T>::is_specialized>
+    init (T& t) const noexcept { t = T::max(); }
 };
 
 struct ReduceOpMax
@@ -161,7 +192,12 @@ struct ReduceOpMax
     void local_update (T& d, T const& s) const noexcept { d = amrex::max(d,s); }
 
     template <typename T>
-    constexpr void init (T& t) const noexcept { t = std::numeric_limits<T>::lowest(); }
+    constexpr std::enable_if_t<std::numeric_limits<T>::is_specialized>
+    init (T& t) const noexcept { t = std::numeric_limits<T>::lowest(); }
+
+    template <typename T>
+    constexpr std::enable_if_t<!std::numeric_limits<T>::is_specialized>
+    init (T& t) const noexcept { t = T::lowest(); }
 };
 
 struct ReduceOpLogicalAnd
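
The two headers work together: detail::multi_shuffle_down lets warpReduce shuffle any suitably sized trivially copyable type word by word, and the new init overloads let ReduceOpMin/ReduceOpMax start from T::max()/T::lowest() when std::numeric_limits<T> is not specialized, which is what ValLocPair supplies. The sketch below shows how a caller might combine these pieces for a min-loc reduction. It is illustrative only and not part of this patch: min_loc and mf are made-up names, the choice of IntVect as the index type is an assumption, and the surrounding ReduceOps/ReduceData/MFIter code is the usual AMReX reduction idiom.

#include <AMReX.H>
#include <AMReX_MultiFab.H>
#include <AMReX_Reduce.H>

// Hypothetical helper (not part of this patch): find the minimum value in a
// MultiFab and the cell at which it occurs.
amrex::ValLocPair<amrex::Real, amrex::IntVect> min_loc (amrex::MultiFab const& mf)
{
    using Pair = amrex::ValLocPair<amrex::Real, amrex::IntVect>;

    amrex::ReduceOps<amrex::ReduceOpMin> reduce_op;
    amrex::ReduceData<Pair> reduce_data(reduce_op);
    using ReduceTuple = typename decltype(reduce_data)::Type;

    for (amrex::MFIter mfi(mf); mfi.isValid(); ++mfi) {
        amrex::Box const& bx = mfi.validbox();
        auto const& fab = mf.const_array(mfi);
        reduce_op.eval(bx, reduce_data,
        [=] AMREX_GPU_DEVICE (int i, int j, int k) -> ReduceTuple
        {
            // Pack the value and its location into a single reducible object.
            return {Pair{fab(i,j,k), amrex::IntVect(AMREX_D_DECL(i,j,k))}};
        });
    }

    // ValLocPair orders by value only, so ReduceOpMin keeps the pair whose
    // value is smallest and carries its location along.  ReduceOpMin::init
    // falls back to Pair::max() because std::numeric_limits<Pair> is not
    // specialized.
    ReduceTuple hv = reduce_data.value();
    return amrex::get<0>(hv);
}

Note that ValLocPair<Real,IntVect> also satisfies the sizeof(T) % sizeof(unsigned int) == 0 constraint on multi_shuffle_down (its size is padded to a multiple of four bytes), so the non-arithmetic warpReduce::operator() added above can shuffle it on the GPU.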