forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTensorModeKernel.cu
301 lines (263 loc) · 9.75 KB
/
TensorModeKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <c10/util/Exception.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/extrema.h>
#include <thrust/inner_product.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/TensorModeKernel.cuh>
#include <THC/THCThrustAllocator.cuh>
namespace at {
namespace native {
// Computes the mode of a single innermost-dimension slice of `self` with
// Thrust and writes the result into `values` / `indices` at the coordinates
// given by `position`.
//
// Parameters:
//   values   - output tensor that receives the mode value; the write location
//              is derived from `position` and the tensor's own strides.
//   indices  - int64 output tensor that receives the position of the mode
//              within the slice; addressed the same way as `values`.
//   self     - input tensor; must be contiguous (asserted). The selected
//              slice is sorted IN PLACE, so callers must pass storage they
//              are allowed to overwrite.
//   position - coordinates of the slice along every dimension except the
//              innermost one.
//   dim      - not used by this function.
//
// This function blocks the host: intermediate results are read back through
// thrust device pointers, and the final memcpy_and_sync synchronizes the
// current stream.
template <typename scalar_t>
void calculate_mode(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    std::vector<int64_t>& position,
    int dim) {
  THCThrustAllocator thrust_allocator(globalContext().lazyInitCUDA());
  auto stream = at::cuda::getCurrentCUDAStream();
  auto policy = thrust::cuda::par(thrust_allocator).on(stream);

  TORCH_INTERNAL_ASSERT(self.is_contiguous());

  const int64_t n_position = static_cast<int64_t>(position.size());

  // Because the input is contiguous, we want to get a reference to the
  // location of the buffer at the innermost dimension that we are going
  // to calculate the mode for --> we do this by manually doing the stride
  // calculations to get an offset
  scalar_t* data = self.data_ptr<scalar_t>();
  for (int64_t i = 0; i < n_position; i++) {
    data += position[i] * ensure_nonempty_stride(self, i);
  }

  int64_t ndim = ensure_nonempty_dim(self.dim());
  int64_t n_element = ensure_nonempty_size(self, ndim - 1);

  scalar_t* iter_begin = data;
  scalar_t* iter_end = data + n_element;

  Tensor sort_buffer = at::arange(0, n_element, self.options().dtype(kLong));
  auto sort_buffer_ptr =
      thrust::device_pointer_cast(sort_buffer.data_ptr<int64_t>());

  // Sort the slice in place. sort_buffer is permuted alongside it, so
  // sort_buffer_ptr[k] is the original index of the k-th sorted element.
  thrust::sort_by_key(policy, iter_begin, iter_end, sort_buffer_ptr);

  // Count # of unique elements: 1 + the number of adjacent sorted pairs that
  // differ. Accumulate in int64_t so very long slices cannot overflow.
  int64_t unique = 1 +
      thrust::inner_product(
          policy,
          iter_begin,
          iter_end - 1,
          iter_begin + 1,
          int64_t{0},
          thrust::plus<int64_t>(),
          thrust::not_equal_to<scalar_t>());

  // Count frequency of each distinct element (equal values are adjacent
  // after sorting, so reduce_by_key yields one run per value).
  Tensor keys = at::empty(unique, self.options());
  Tensor counts = at::empty(unique, self.options().dtype(kLong));
  auto keys_ptr = thrust::device_pointer_cast(keys.data_ptr<scalar_t>());
  auto counts_ptr = thrust::device_pointer_cast(counts.data_ptr<int64_t>());
  thrust::reduce_by_key(
      policy,
      iter_begin,
      iter_end,
      thrust::constant_iterator<int>(1),
      keys_ptr,
      counts_ptr);

  // The mode is the key with the largest count (read back to the host).
  auto it = thrust::max_element(policy, counts_ptr, counts_ptr + unique);
  scalar_t mode = keys_ptr[it - counts_ptr];

  // Locate the first occurrence of the mode in the sorted slice and map it
  // back to its original position via sort_buffer.
  // NOTE(review): thrust::sort_by_key is not documented as stable, so among
  // equal values this is not guaranteed to be the smallest original index.
  auto position_iter = thrust::find(policy, iter_begin, iter_end, mode);
  TORCH_INTERNAL_ASSERT(position_iter != iter_end);
  int64_t index = sort_buffer_ptr[position_iter - iter_begin];

  // Place mode, index in output
  scalar_t* values_data = values.data_ptr<scalar_t>();
  int64_t* indices_data = indices.data_ptr<int64_t>();
  for (int64_t i = 0; i < n_position; i++) {
    int64_t pos = position[i];
    values_data += ensure_nonempty_stride(values, i) * pos;
    indices_data += ensure_nonempty_stride(indices, i) * pos;
  }

  AT_CUDA_CHECK(cudaMemcpyAsync(
      values_data, &mode, sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
  // `index` is an int64_t, so copy sizeof(int64_t) bytes. Copying
  // sizeof(scalar_t) bytes wrote only part of the index for dtypes narrower
  // than 8 bytes, leaving the rest of the output element stale.
  // memcpy_and_sync synchronizes the stream before returning, so both copies
  // complete while the host-side `mode` and `index` are still alive.
  at::cuda::memcpy_and_sync(
      indices_data, &index, sizeof(int64_t), cudaMemcpyHostToDevice, stream);
}
// Recursively walks every coordinate of the leading (non-innermost)
// dimensions of `self`, recording the current coordinate in `position`, and
// calls calculate_mode once for each innermost slice.
//
// The caller transposes the reduction dimension into the innermost position,
// so each innermost slice is exactly one set of elements to take the mode of.
//
// Parameters:
//   values, indices - output tensors, forwarded to calculate_mode.
//   self            - contiguous input; calculate_mode asserts contiguity and
//                     sorts each slice in place.
//   position        - scratch coordinate vector with one entry per leading
//                     dimension; entries are overwritten during the walk.
//   dim             - forwarded unchanged to calculate_mode.
//   curDim          - dimension iterated at this recursion level; the
//                     top-level call passes 0.
template <typename scalar_t>
void apply_mode(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    std::vector<int64_t>& position,
    int dim,
    int curDim) {
  // Because we have transposed the Tensor, the data for the dimension we are
  // mode'ing along is always in the innermost dimension
  int64_t ndim = ensure_nonempty_dim(self.dim());
  if (curDim == ndim - 1) {
    calculate_mode<scalar_t>(values, indices, self, position, dim);
  } else {
    // Iterate with int64_t to match tensor sizes and the element type of
    // `position`; an `int` counter would overflow for dimensions larger than
    // INT_MAX, which the non-32-bit-indexable fallback path can encounter.
    const int64_t dim_size = ensure_nonempty_size(self, curDim);
    for (int64_t i = 0; i < dim_size; ++i) {
      position[curDim] = i;
      apply_mode<scalar_t>(values, indices, self, position, dim, curDim + 1);
    }
  }
}
// Launches the compute_mode kernel specialized for a power-of-two bucket
// `size`.
//
// Launch layout: `grid` as supplied by the caller, and size / 2 threads per
// block (two input elements per thread). Dynamic shared memory is sized for
// `size` scalar_t values plus 2 * size unsigned ints. The launch goes on the
// current CUDA stream and is checked with C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// The kernel receives `self`'s raw data pointer plus slice_size / slices, so
// it presumably expects the slices to be laid out contiguously (the caller
// passes a .contiguous() tensor) — TODO confirm against compute_mode.
template <int64_t size, typename scalar_t>
void handle_fused_mode(
    dim3 grid,
    const Tensor& self,
    cuda::detail::TensorInfo<scalar_t, unsigned int>& ti_values,
    cuda::detail::TensorInfo<int64_t, unsigned int>& ti_indices,
    int64_t slice_size,
    int64_t slices) {
  constexpr int64_t kThreadsPerBlock = size / 2;

  const auto value_bytes = sizeof(scalar_t) * size;
  const auto counter_bytes = 2 * size * sizeof(unsigned int);
  const auto shared_bytes = value_bytes + counter_bytes;

  scalar_t* input = self.data_ptr<scalar_t>();
  cudaStream_t launch_stream = at::cuda::getCurrentCUDAStream();

  compute_mode<scalar_t, size>
      <<<grid, dim3(kThreadsPerBlock), shared_bytes, launch_stream>>>(
          input, ti_values, ti_indices, slice_size, slices);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Computes the mode of every slice with the fused single-kernel path.
//
// One block is used per slice (the grid is laid out by getGridFromTiles).
// The kernel is specialized on the slice length rounded up to a power of two.
// To bound compile time, only four specializations (32, 128, 1024, 2048) are
// instantiated, and each rounded length is routed to the smallest one that
// covers it.
//
// Preconditions: slice_size must round up to one of the handled powers of
// two (2 .. 2048). A bucket of 1, or any other value, trips
// TORCH_INTERNAL_ASSERT. Outputs are described with 32-bit index math
// (TensorInfo<..., unsigned int>).
template <typename scalar_t>
void fused_mode(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t slice_size,
    int64_t slices) {
  // Kernel-side descriptions of the output tensors.
  auto values_info =
      cuda::detail::getTensorInfo<scalar_t, unsigned int>(values);
  auto indices_info =
      cuda::detail::getTensorInfo<int64_t, unsigned int>(indices);

  // Each block computes the mode of exactly one slice.
  dim3 grid;
  getGridFromTiles(slices, grid);

  // Two elements per thread, so the specialization is chosen from the slice
  // length rounded up to a power of two.
  const auto bucket = nextHighestPowerOf2(slice_size);

  if (bucket == 2048) {
    handle_fused_mode<2048, scalar_t>(
        grid, self, values_info, indices_info, slice_size, slices);
  } else if (bucket == 1024 || bucket == 512 || bucket == 256) {
    handle_fused_mode<1024, scalar_t>(
        grid, self, values_info, indices_info, slice_size, slices);
  } else if (bucket == 128 || bucket == 64) {
    handle_fused_mode<128, scalar_t>(
        grid, self, values_info, indices_info, slice_size, slices);
  } else if (
      bucket == 32 || bucket == 16 || bucket == 8 || bucket == 4 ||
      bucket == 2) {
    handle_fused_mode<32, scalar_t>(
        grid, self, values_info, indices_info, slice_size, slices);
  } else {
    // Covers bucket == 1 and anything outside the handled range.
    TORCH_INTERNAL_ASSERT(false);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
// CUDA implementation registered for mode_stub. For every slice of `self`
// along `dim`, it writes the most frequent value into `values` and an index
// within that slice at which the value occurs into `indices`.
//
// Parameters:
//   values   - output tensor; resized in place with resize_output.
//   indices  - output tensor for the indices; resized the same way.
//   self     - input tensor. The fallback path sorts a contiguous copy in
//              place, cloning when needed so that `self` itself is not
//              modified.
//   dim      - reduction dimension. NOTE(review): it is used directly as an
//              index into self_sizes, so presumably the caller has already
//              wrapped it to a non-negative value — confirm.
//   keepdim  - if false, the reduced dimension is squeezed out of both
//              outputs before returning.
//
// Strategy: a fused single-kernel path (fused_mode) when the slices are
// small enough and 32-bit indexing suffices; otherwise a host-driven,
// per-slice Thrust path (apply_mode -> calculate_mode).
void mode_kernel_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  auto self_sizes = ensure_nonempty_vec(self.sizes().vec());
  int64_t ndim = ensure_nonempty_dim(self.dim());
  int64_t slice_size = ensure_nonempty_size(self, dim);
  // NOTE(review): a zero-sized reduction dim would make this a division by
  // zero; presumably empty reductions are rejected by the caller — confirm.
  int64_t slices = self.numel() / slice_size;

  // Resize output value, index Tensors to appropriate sizes (i.e. the same as
  // the input Tensor, except at dim=dimension, the size is 1)
  self_sizes[dim] = 1;

  // When keepdim is false, re-insert the reduced dimension so the outputs
  // have the same rank as self_sizes for resize_output. It is squeezed back
  // out before returning.
  if (!keepdim) {
    if (values.ndimension() >= dim) {
      values.unsqueeze_(dim);
    }
    if (indices.ndimension() >= dim) {
      indices.unsqueeze_(dim);
    }
  }

  at::native::resize_output(values, self_sizes);
  at::native::resize_output(indices, self_sizes);

  // If slice_size is 1, each slice holds a single element: it is its own
  // mode and sits at index 0, so copy the input and zero the indices.
  if (slice_size == 1) {
    values.copy_(self);
    indices.fill_(0);
    if (!keepdim) {
      values.squeeze_(dim);
      indices.squeeze_(dim);
    }
    return;
  }

  // Beginning our optimized implementation. First thing we want to do is to
  // transpose the input Tensor along the sort dimension, and then make it
  // contiguous.
  auto transposed = self.transpose(dim, ndim - 1);
  auto contiguous = transposed.contiguous();

  // We also need to view the values and indices Tensors as transposed in order
  // to properly determine the offset into the underlying storage in which to
  // place the mode and index for a particular set of dimension values.
  auto values_transposed = values.transpose(dim, ndim - 1);
  auto indices_transposed = indices.transpose(dim, ndim - 1);

  // Call mode
  AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, self.scalar_type(), "cuda_mode", [&] {
    // Requirements for fused kernel implementation:
    //
    // 1. sliceSize <= 2 * max threads per block
    // 2. uses one block per slice, so number of slices must be less than the
    // maximum number of blocks for a kernel launch
    // 3. Can use 32-bit index math for indexing (mainly just for implementation
    // conciseness, could be changed)
    //
    // MAX_BLOCK_SIZE and MAX_GRID_SIZE come from:
    // ATen/native/cuda/SortingCommon.cuh
    if (slice_size <= 2 * MAX_BLOCK_SIZE &&
        slices <= MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE &&
        cuda::detail::canUse32BitIndexMath(self)) {
      fused_mode<scalar_t>(
          values_transposed,
          indices_transposed,
          contiguous,
          slice_size,
          slices);
    } else {
      // If transposed is already contiguous, it will return a tensor with the
      // same storage. So, since we do not want to modify self, we clone it.
      // (calculate_mode sorts each slice of `contiguous` in place.)
      if (transposed.is_contiguous()) {
        contiguous = contiguous.clone();
      }

      // Position will store the dimension values we are processing
      std::vector<int64_t> position(ndim - 1, 0);

      apply_mode<scalar_t>(
          values_transposed, indices_transposed, contiguous, position, dim, 0);
    }
  });

  // Undo the unsqueeze performed above so callers get the reduced shape.
  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }
}
// MAX_GRID_SIZE and MAX_BLOCK_SIZE are macros used by mode_kernel_impl above
// (documented there as coming from ATen/native/cuda/SortingCommon.cuh).
// Undefine them so they stay out of any code that follows in this file.
#undef MAX_GRID_SIZE
#undef MAX_BLOCK_SIZE
// Register mode_kernel_impl as the implementation behind mode_stub.
REGISTER_DISPATCH(mode_stub, &mode_kernel_impl);
} // namespace native
} // namespace at