-
Notifications
You must be signed in to change notification settings - Fork 552
/
Copy pathdevice_utils.cuh
96 lines (90 loc) · 3.19 KB
/
device_utils.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_utils.cuh>
namespace MLCommon {
/**
 * @brief Batched warp-level sum reduction
 *
 * Performs NThreads independent sum reductions inside one warp: reduction `k`
 * (0 <= k < NThreads) sums `val` over every lane whose `laneId() % NThreads`
 * equals `k`.
 *
 * @tparam T data type
 * @tparam NThreads Number of threads in the warp doing independent reductions.
 *                  Must be a power of 2 no larger than WarpSize; otherwise the
 *                  doubling shuffle offsets do not line up with the groups.
 *
 * @param[in] val input value
 * @return for the first "group" of threads (lanes 0 .. NThreads-1), the
 *         reduced value. All others will contain unusable values!
 *
 * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
 *       number of warps in a block and also doesn't support this kind of
 *       batched reduction operation
 * @note All threads in the warp must enter this function together
 *
 * @todo Expand this to support arbitrary reduction ops
 */
template <typename T, int NThreads>
DI T batchedWarpReduce(T val) {
  static_assert(isPo2(NThreads) && NThreads <= WarpSize,
                "NThreads must be a power of 2 no larger than WarpSize!");
  // Tree reduction: at each step lane j adds the value held by lane j + i.
  // After the loop, lane j < NThreads holds the sum over lanes j, j+NThreads,
  // j+2*NThreads, ... Lanes whose source index runs past the warp read a
  // wrapped (presumably `shfl` wraps like __shfl_sync -- confirm) or partially
  // reduced value, which is why their results are unusable.
#pragma unroll
  for (int i = NThreads; i < WarpSize; i <<= 1) {
    val += shfl(val, laneId() + i);
  }
  return val;
}

/**
 * @brief 1-D block-level batched sum reduction
 *
 * Performs NThreads independent sum reductions across the whole block:
 * reduction `k` sums `val` over every thread whose `laneId() % NThreads`
 * equals `k`. Each warp's partial sums are staged in `smem` and folded again
 * with batchedWarpReduce until a single warp holds the final result.
 *
 * @tparam T data type
 * @tparam NThreads Number of threads in the warp doing independent reductions
 *                  (a power of 2, at most WarpSize)
 *
 * @param val input value
 * @param smem shared memory region needed for storing intermediate results. It
 *             must be at least of size: `sizeof(T) * nWarps * NThreads`
 * @return for the first "group" of threads in the block (threadIdx.x <
 *         NThreads), the reduced value. All others will contain unusable
 *         values!
 *
 * @note Why not cub? Because cub doesn't seem to allow working with arbitrary
 *       number of warps in a block and also doesn't support this kind of
 *       batched reduction operation
 * @note All threads in the block must enter this function together
 * @note NOTE(review): the intra-warp shuffles presumably assume fully
 *       populated warps, i.e. blockDim.x a multiple of WarpSize -- confirm
 *       with callers.
 *
 * @todo Expand this to support arbitrary reduction ops
 */
template <typename T, int NThreads>
DI T batchedBlockReduce(T val, char *smem) {
  auto *sTemp = reinterpret_cast<T *>(smem);
  constexpr int nGroupsPerWarp = WarpSize / NThreads;
  static_assert(isPo2(nGroupsPerWarp), "nGroupsPerWarp must be a PO2!");
  const int nGroups = (blockDim.x + NThreads - 1) / NThreads;
  const int lid = laneId();
  const int lgid = lid % NThreads;
  const int gid = threadIdx.x / NThreads;
  // each warp stores its NThreads partial sums in its own slot of sTemp
  const auto wrIdx = (gid / nGroupsPerWarp) * NThreads + lgid;
  // group `gid` picks up the partial sums previously written by warp `gid`
  const auto rdIdx = gid * NThreads + lgid;

  if (nGroupsPerWarp == 1) {
    // NThreads == WarpSize: every warp is a single group, so shuffles cannot
    // merge groups and the folding loop below would never shrink. Stage each
    // warp's values and let warp 0 accumulate them serially.
    sTemp[wrIdx] = val;
    __syncthreads();
    if (gid == 0) {
      for (int g = 1; g < nGroups; ++g) val += sTemp[g * NThreads + lgid];
    }
    __syncthreads();
    return val;
  }

  // nLive = number of groups still holding live partial sums (block-uniform,
  // so the `break` below is taken by all threads together)
  for (int nLive = nGroups;;) {
    // round up to whole warps so the shuffle branch is warp-uniform
    const int nLiveAligned =
      ((nLive + nGroupsPerWarp - 1) / nGroupsPerWarp) * nGroupsPerWarp;
    if (gid < nLiveAligned) {
      val = batchedWarpReduce<T, NThreads>(val);
      if (lid < NThreads) sTemp[wrIdx] = val;
    }
    __syncthreads();
    // Every participating warp wrote one partial: round UP, otherwise the
    // partial of a trailing, partially filled warp would be dropped.
    nLive = (nLive + nGroupsPerWarp - 1) / nGroupsPerWarp;
    if (nLive <= 1) break;  // warp 0's first NThreads lanes hold the result
    val = gid < nLive ? sTemp[rdIdx] : T(0);
    __syncthreads();
  }
  return val;
}
}  // namespace MLCommon