// TensorTopK.cu (forked from pytorch/pytorch)
#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/LegacyTHFunctionsCUDA.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <c10/macros/Macros.h>
using namespace at::native;
namespace at {
namespace native {
namespace {
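
// gatherTopK: one thread block handles one slice of the input along `dim`.
// The block first radix-selects the k-th value of its slice, then makes two
// passes over the slice: the first pass writes out every element that is
// strictly better than the k-th value (greater if Order, smaller otherwise),
// and the second pass fills the remaining output slots with elements equal to
// it. The output of this kernel is not sorted; sorting, when requested, is
// handled afterwards in topk_out_cuda below.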
template <typename T, typename IndexType, int Dim, bool Order>
C10_LAUNCH_BOUNDS_1(512)
__global__ void gatherTopK(at::cuda::detail::TensorInfo<T, IndexType> input,
                           IndexType inputSliceSize,
                           IndexType outputSliceSize, // aka `k`

                           IndexType numInputSlices,
                           IndexType inputWithinSliceStride,

                           at::cuda::detail::TensorInfo<T, IndexType> topK,
                           IndexType numTopKSlices,
                           IndexType topKWithinSliceStride,

                           at::cuda::detail::TensorInfo<int64_t, IndexType> indices,
                           IndexType indicesWithinSliceStride) {
  // Indices are limited to integer fp precision, so counts can fit in
  // int32, regardless of IndexType
#ifdef __HIP_PLATFORM_HCC__
  __shared__ int smem[64];
#else
  __shared__ int smem[32]; // one per each warp, up to warp limit
#endif

  IndexType slice = getLinearBlockId<IndexType>();
  if (slice >= numInputSlices) {
    return;
  }

  // Find the start offset for our slice
  IndexType sliceStartIndex =
      at::cuda::detail::IndexToOffset<T, IndexType, Dim>::get(slice, input);
  IndexType topKSliceStartIndex =
      at::cuda::detail::IndexToOffset<T, IndexType, Dim>::get(slice, topK);
  IndexType indicesSliceStartIndex =
      at::cuda::detail::IndexToOffset<int64_t, IndexType, Dim>::get(slice, indices);

  T* inputSliceStart = &input.data[sliceStartIndex];
  T* topKSliceStart = &topK.data[topKSliceStartIndex];
  int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex];

  // Find the k-th highest element in our input
  T topKValue = ScalarConvert<int, T>::to(0);
  radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType, Order>(
      inputSliceStart, outputSliceSize,
      inputSliceSize, inputWithinSliceStride,
      smem, &topKValue);
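
  // At this point topKValue holds the k-th value of the slice (the k-th
  // largest if Order, the k-th smallest otherwise); radixSelect() finds it by
  // operating on the radix-sortable integer representation of T.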
  const auto topKConverted = at::native::TopKTypeConfig<T>::convert(topKValue);

  // Every value that is strictly less/greater than `topKValue`
  // (depending on sort dir) in sorted int format is in the top-K.
  // The top-K value itself might not be unique.
  //
  // Since there are a variable number of elements that we see that
  // are within the top-k, we don't know at what index to write out
  // the resulting values.
  // In order to get this, we perform an exclusive prefix sum of
  // `hasTopK`. This will return the resulting index into which we
  // need to write the result, if a thread has a result.
  // All threads need to participate in the loop and the prefix sum,
  // but not necessarily in the load; hence loop bounds being rounded
  // up to a multiple of the block dim.
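  //
  // For example, if in one loop iteration the block's threads see
  // hasTopK = [1, 0, 1, 1, 0, ...], the exclusive prefix scan yields
  // index = [0, 1, 1, 2, 3, ...] and carry = <number of set flags>, so every
  // writing thread gets a unique output slot and writeIndexStart advances by
  // carry for the next iteration.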
  IndexType numIterations = THCRoundUp(inputSliceSize, (IndexType) blockDim.x);
  IndexType writeIndexStart = 0;

  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
        inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
    const auto convertedV = at::native::TopKTypeConfig<T>::convert(v);

    bool hasTopK;
    if (Order) {
      hasTopK = inRange && (convertedV > topKConverted);
    } else {
      hasTopK = inRange && (convertedV < topKConverted);
    }

    int index;
    int carry;
    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK) {
      int writeIndex = writeIndexStart + index;
      CUDA_KERNEL_ASSERT(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    writeIndexStart += carry;
  }

  // We need to fill in the rest with actual == top-K values.
  // The number that we need is outputSliceSize -
  // writeIndexStart. There might be more than that number available,
  // in which case we have to choose the first seen set. We do this
  // via a prefix sum to calculate indices for writing results.
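  //
  // For example, with k == 5, if the pass above wrote 3 strictly-better
  // values, then topKRemaining == 2 and the loop below keeps only the first
  // two elements that compare equal to the k-th value (lowest indices within
  // the slice), then stops early.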
  CUDA_KERNEL_ASSERT(outputSliceSize >= writeIndexStart);
  IndexType topKRemaining = (outputSliceSize - writeIndexStart);

  for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < inputSliceSize);
    T v =
        inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : ScalarConvert<int, T>::to(0);
    const auto convertedV = at::native::TopKTypeConfig<T>::convert(v);

    bool hasTopK = inRange && (convertedV == topKConverted);

    int index;
    int carry;
    exclusiveBinaryPrefixScan<int, true>(smem, hasTopK, &index, &carry, AddOp<int>());

    if (hasTopK && index < topKRemaining) {
      int writeIndex = writeIndexStart + index;
      CUDA_KERNEL_ASSERT(writeIndex < outputSliceSize);

      IndexType topKOffset = writeIndex * topKWithinSliceStride;
      IndexType indexOffset = writeIndex * indicesWithinSliceStride;

      topKSliceStart[topKOffset] = v;
      indicesSliceStart[indexOffset] = i;
    }

    if (carry >= topKRemaining) {
      break;
    }

    topKRemaining -= carry;
    writeIndexStart += carry;
  }
};
} // namespace
TORCH_IMPL_FUNC(topk_out_cuda)
  (const Tensor& self,
   int64_t k, int64_t dim, bool largest, bool sorted,
   const Tensor& values,
   const Tensor& indices) {
  TensorArg topK_arg{values, "topK", 1},
      indices_arg{indices, "indices", 2},
      input_arg{self, "self", 3};
  checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg});

  dim = at::maybe_wrap_dim(dim, self);

  int numDims = self.dim();
  numDims = numDims == 0 ? 1 : numDims;
  TORCH_CHECK(numDims <= MAX_DIMS, "input tensor has too many dimensions");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);

  Tensor input = self.contiguous();

  // If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
  if (k == 0) {
    return;
  }

  // static_cast is required to ensure that the correct type (INDEX_T)
  // is provided to the kernel for the arguments.
#define RUN_K(INDEX_T, DIM, DIR) \
  gatherTopK<scalar_t, INDEX_T, DIM, DIR> \
      <<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>( \
          inputInfo, \
          static_cast<INDEX_T>(sliceSize), \
          static_cast<INDEX_T>(k), \
          static_cast<INDEX_T>(inputSlices), \
          /* The actual dimension that the k-selection is running in */ \
          /* may have changed from collapseDims() */ \
          static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]), \
          topKInfo, \
          static_cast<INDEX_T>(topKSlices), \
          static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]), \
          indicesInfo, \
          static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim])); \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#define RUN_DIR(INDEX_T, DIM) \
  if (largest) { \
    RUN_K(INDEX_T, DIM, true); \
  } else { \
    RUN_K(INDEX_T, DIM, false); \
  }

#define RUN_DIM(INDEX_T) \
  if (allDims == 1) { \
    RUN_DIR(INDEX_T, 1); \
  } else if (allDims == 2) { \
    RUN_DIR(INDEX_T, 2); \
  } else if (allDims == 3) { \
    RUN_DIR(INDEX_T, 3); \
  } else { \
    RUN_DIR(INDEX_T, -1); \
  }

#define RUN_T(INDEX_T) \
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "topk_out_cuda", [&] { \
    at::cuda::detail::TensorInfo<scalar_t, INDEX_T> inputInfo = \
        at::cuda::detail::getTensorInfo<scalar_t, INDEX_T>(input); \
    at::cuda::detail::TensorInfo<scalar_t, INDEX_T> topKInfo = \
        at::cuda::detail::getTensorInfo<scalar_t, INDEX_T>(values); \
    at::cuda::detail::TensorInfo<int64_t, INDEX_T> indicesInfo = \
        at::cuda::detail::getTensorInfo<int64_t, INDEX_T>(indices); \
    /* tensorInfoLegacyIfScalar */ \
    if (!input.dim()) { \
      inputInfo.dims = 1; \
      inputInfo.sizes[0] = 1; \
      inputInfo.strides[0] = 1; \
      topKInfo.dims = 1; \
      topKInfo.sizes[0] = 1; \
      topKInfo.strides[0] = 1; \
      indicesInfo.dims = 1; \
      indicesInfo.sizes[0] = 1; \
      indicesInfo.strides[0] = 1; \
    } \
    /* We use these structures solely to find the offset to */ \
    /* each slice we are operating on */ \
    inputInfo.sizes[dim] = 1; \
    topKInfo.sizes[dim] = 1; \
    indicesInfo.sizes[dim] = 1; \
    /* Collapse all other dims */ \
    int collapseInputDim = inputInfo.collapseDims(dim); \
    int collapseTopKDim = topKInfo.collapseDims(dim); \
    int collapseIndicesDim = indicesInfo.collapseDims(dim); \
    int64_t inputSlices = 1; \
    for (int i = 0; i < inputInfo.dims; ++i) { \
      inputSlices *= inputInfo.sizes[i]; \
    } \
    int64_t topKSlices = 1; \
    for (int i = 0; i < topKInfo.dims; ++i) { \
      topKSlices *= topKInfo.sizes[i]; \
    } \
    \
    dim3 grid; \
    TORCH_INTERNAL_ASSERT(getGridFromTiles(inputSlices, grid), "Too many slices to sort"); \
    \
    dim3 block(std::min(at::cuda::ATenCeilDiv(sliceSize, (int64_t) C10_WARP_SIZE)*(int64_t) C10_WARP_SIZE, (int64_t) 512)); \
    \
    /* This is used as a template parameter to calculate indices. */ \
    /* We only specialize it if all collapsed dim sizes are the */ \
    /* same; otherwise, we use -1 which is the specialization */ \
    /* parameter for arbitrary dimensions */ \
    int allDims = inputInfo.dims; \
    if (topKInfo.dims != allDims || indicesInfo.dims != allDims) { \
      allDims = -1; \
    } \
    \
    RUN_DIM(INDEX_T); \
  });
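
  // The macros above form a small dispatch tree: RUN_T selects the scalar
  // type, RUN_DIM selects the collapsed-dimensionality specialization (1, 2,
  // 3, or -1 for arbitrary dims), RUN_DIR selects the sort direction, and
  // RUN_K performs the actual kernel launch. For example, a contiguous 2-d
  // float input that fits 32-bit indexing with largest=true would typically
  // end up launching gatherTopK<float, uint32_t, 2, true>.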
  // the below is safe with 0-dimensional tensors because it is based on
  // TensorInfo which implicitly expands to 1-dimensional.
  if (input.numel() > 0) {
    // Based on required index size, run the algorithm with the
    // appropriate index type
    if (at::cuda::detail::canUse32BitIndexMath(input) &&
        at::cuda::detail::canUse32BitIndexMath(values) &&
        at::cuda::detail::canUse32BitIndexMath(indices)) {
      RUN_T(uint32_t);
    } else {
      RUN_T(uint64_t);
    }
  }

#undef RUN_T
#undef RUN_DIM
#undef RUN_DIR
#undef RUN_K

  // Sort the results if the user wants them sorted, since our
  // selection routine does not ensure sorting
  if (sorted && values.numel() > 1) {
    if (should_use_small_sort(values, dim)) {
      // This avoids any memory allocations and performs all sorting
      // work inplace along the slice
      sortKeyValueInplace(values, indices, dim, largest);
    } else {
      // Depend upon the backup sort that returns indices, which we
      // can use in conjunction with gather to produce the original
      // indices.
      // This is not the most efficient implementation, especially since
      // there are memory allocations performed here. If the user desires
      // greater performance, they should torch.gather() the results
      // themselves using the reported indices, providing previously
      // allocated tensors to receive the results.
      Tensor sortedIndices = at::empty_like(indices);
      Tensor sortedValues = at::empty_like(values);
      sort_out_cuda(values, dim, largest, sortedValues, sortedIndices);
      indices.copy_(indices.gather(dim, sortedIndices));
      values.copy_(sortedValues);
    }
  }
}
} // at::native
} // at
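
// Usage sketch (host side, not part of this file): topk_out_cuda is the CUDA
// structured implementation of `topk`, so on a CUDA tensor it is reached via
// the usual ATen call, e.g.
//
//   auto t = at::rand({8, 1024}, at::kCUDA);
//   auto [values, indices] = at::topk(t, /*k=*/5, /*dim=*/1,
//                                     /*largest=*/true, /*sorted=*/true);
//
// where `values` holds the top-5 entries per row and `indices` their
// positions along dim 1.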