-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathoptim_kernel.cu
267 lines (238 loc) · 9.64 KB
/
optim_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
// Copyright 2021 Alex Yu
// Optimizer-related kernels
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_util.cuh"
namespace {
// Threads per block for every optimizer kernel in this file (the SGD kernels
// reuse the RMSProp constant).
constexpr int RMSPROP_STEP_CUDA_THREADS = 256;
// Minimum resident blocks per SM requested through __launch_bounds__.
constexpr int MIN_BLOCKS_PER_SM = 4;
namespace device {

// Flat global index of the calling thread for a 1-D launch. Computed in 64-bit
// so that launches covering more than 2^31 elements cannot overflow.
__device__ __forceinline__ int64_t flat_thread_index() {
    return static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
}

// Learning rate for channel `chnl` of a row with `n_chnl` channels: the last
// channel uses lr_last, every other channel uses lr.
__device__ __forceinline__ float lr_for_channel(
        int64_t chnl, int64_t n_chnl, float lr, float lr_last) {
    return (chnl == n_chnl - 1) ? lr_last : lr;
}

// RMSPROP
// In-place RMSProp update of a single element:
//   rms  <- g^2                         if rms == 0 (e.g. first step)
//           lerp(g^2, rms, beta)        otherwise
//   data <- max(data - lr * g / (sqrt(rms) + epsilon), minval)
//   grad <- 0
// NOTE(review): assuming cuda_util's lerp(a, b, w) == a + w * (b - a), the
// blend equals beta * rms + (1 - beta) * g^2 — confirm against cuda_util.cuh.
__inline__ __device__ void rmsprop_once(
        float* __restrict__ ptr_data,
        float* __restrict__ ptr_rms,
        float* __restrict__ ptr_grad,
        const float beta, const float lr, const float epsilon, float minval) {
    const float g = *ptr_grad;
    const float g2 = _SQR(g);
    float rms = *ptr_rms;
    rms = (rms == 0.f) ? g2 : lerp(g2, rms, beta);
    *ptr_rms = rms;
    *ptr_data = fmaxf(*ptr_data - lr * g / (sqrtf(rms) + epsilon), minval);
    *ptr_grad = 0.f;
}

// Dense RMSProp over every element of the (N, C) tensor all_data.
// Launch: 1-D grid with at least N*C threads in total, RMSPROP_STEP_CUDA_THREADS
// threads per block, no shared memory. all_rms / all_grad must have all_data's
// shape (checked by the host wrapper).
__launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void rmsprop_step_kernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
        float beta,
        float lr,
        float epsilon,
        float minval,
        float lr_last) {
    const int64_t n_chnl = all_data.size(1);
    const int64_t tid = flat_thread_index();
    if (tid >= all_data.size(0) * n_chnl) return;
    const int64_t row = tid / n_chnl;
    const int64_t chnl = tid % n_chnl;
    rmsprop_once(&all_data[row][chnl],
                 &all_rms[row][chnl],
                 &all_grad[row][chnl],
                 beta,
                 lr_for_channel(chnl, n_chnl, lr, lr_last),
                 epsilon,
                 minval);
}

// Masked RMSProp: like rmsprop_step_kernel, but rows with mask[row] == false are
// left untouched. mask must hold at least N entries (checked by the host).
__launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void rmsprop_mask_step_kernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
        const bool* __restrict__ mask,
        float beta,
        float lr,
        float epsilon,
        float minval,
        float lr_last) {
    const int64_t n_chnl = all_data.size(1);
    const int64_t tid = flat_thread_index();
    if (tid >= all_data.size(0) * n_chnl) return;
    const int64_t row = tid / n_chnl;
    if (!mask[row]) return;
    const int64_t chnl = tid % n_chnl;
    rmsprop_once(&all_data[row][chnl],
                 &all_rms[row][chnl],
                 &all_grad[row][chnl],
                 beta,
                 lr_for_channel(chnl, n_chnl, lr, lr_last),
                 epsilon,
                 minval);
}

// Indexed RMSProp: updates only the rows listed in `indices`.
// Launch: 1-D grid with at least indices.size(0) * C threads in total.
// Row ids must lie in [0, N); duplicated ids would make threads race on the
// same elements (no atomics are used).
__launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void rmsprop_index_step_kernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
        torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> indices,
        float beta,
        float lr,
        float epsilon,
        float minval,
        float lr_last) {
    const int64_t n_chnl = all_data.size(1);
    const int64_t tid = flat_thread_index();
    if (tid >= indices.size(0) * n_chnl) return;
    // Keep the row id at full int64 width (it used to be truncated to int32).
    const int64_t row = indices[static_cast<int32_t>(tid / n_chnl)];
    const int64_t chnl = tid % n_chnl;
    rmsprop_once(&all_data[row][chnl],
                 &all_rms[row][chnl],
                 &all_grad[row][chnl],
                 beta,
                 lr_for_channel(chnl, n_chnl, lr, lr_last),
                 epsilon,
                 minval);
}

// SGD
// In-place SGD update of a single element: data -= lr * grad, then grad = 0.
__inline__ __device__ void sgd_once(
        float* __restrict__ ptr_data,
        float* __restrict__ ptr_grad,
        const float lr) {
    *ptr_data -= lr * (*ptr_grad);
    *ptr_grad = 0.f;
}

// Dense SGD over every element of the (N, C) tensor all_data.
// Launch: 1-D grid with at least N*C threads in total, no shared memory.
__launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void sgd_step_kernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
        float lr,
        float lr_last) {
    const int64_t n_chnl = all_data.size(1);
    const int64_t tid = flat_thread_index();
    if (tid >= all_data.size(0) * n_chnl) return;
    const int64_t row = tid / n_chnl;
    const int64_t chnl = tid % n_chnl;
    sgd_once(&all_data[row][chnl],
             &all_grad[row][chnl],
             lr_for_channel(chnl, n_chnl, lr, lr_last));
}

// Masked SGD: rows with mask[row] == false are left untouched.
__launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void sgd_mask_step_kernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
        const bool* __restrict__ mask,
        float lr,
        float lr_last) {
    const int64_t n_chnl = all_data.size(1);
    const int64_t tid = flat_thread_index();
    if (tid >= all_data.size(0) * n_chnl) return;
    const int64_t row = tid / n_chnl;
    if (!mask[row]) return;
    const int64_t chnl = tid % n_chnl;
    sgd_once(&all_data[row][chnl],
             &all_grad[row][chnl],
             lr_for_channel(chnl, n_chnl, lr, lr_last));
}

// Indexed SGD: updates only the rows listed in `indices` (ids in [0, N);
// duplicates race, as no atomics are used).
__launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void sgd_index_step_kernel(
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
        torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
        torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> indices,
        float lr,
        float lr_last) {
    const int64_t n_chnl = all_data.size(1);
    const int64_t tid = flat_thread_index();
    if (tid >= indices.size(0) * n_chnl) return;
    // Keep the row id at full int64 width (it used to be truncated to int32).
    const int64_t row = indices[static_cast<int32_t>(tid / n_chnl)];
    const int64_t chnl = tid % n_chnl;
    sgd_once(&all_data[row][chnl],
             &all_grad[row][chnl],
             lr_for_channel(chnl, n_chnl, lr, lr_last));
}
}  // namespace device
}  // namespace
// Fused in-place RMSProp step on a 2-D float CUDA tensor `data`.
// Updates `data` and `rms` and zeroes `grad` for the selected rows.
//
// `indexer` selects the rows to update:
//   * 0-dim tensor               -> every row (dense update)
//   * indexer.size(0) == 0       -> nothing is updated
//   * bool tensor                -> rows where indexer[row] is true; must have
//                                   one entry per row of `data`
//   * int64 tensor               -> the listed row ids (must be in range;
//                                   duplicated ids race in the kernel)
// `lr` applies to every channel except the last, which uses `lr_last`; a
// negative `lr_last` means "same as lr". Updated values are clamped below at
// `minval`. Kernels are enqueued on PyTorch's current CUDA stream.
void rmsprop_step(
        torch::Tensor data,
        torch::Tensor rms,
        torch::Tensor grad,
        torch::Tensor indexer,
        float beta,
        float lr,
        float epsilon,
        float minval,
        float lr_last) {
    DEVICE_GUARD(data);
    CHECK_INPUT(data);
    CHECK_INPUT(rms);
    CHECK_INPUT(grad);
    CHECK_INPUT(indexer);
    TORCH_CHECK(data.dim() == 2, "rmsprop_step: data must be 2-D");
    // The kernels index rms/grad with data's shape; a mismatch would be an
    // out-of-bounds access on the device.
    TORCH_CHECK(rms.sizes() == data.sizes(),
                "rmsprop_step: rms must have the same shape as data");
    TORCH_CHECK(grad.sizes() == data.sizes(),
                "rmsprop_step: grad must have the same shape as data");
    if (lr_last < 0.f) lr_last = lr;

    const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS;
    // Enqueue on the current PyTorch stream so the update is ordered with the
    // tensor ops that produced `grad` (matters under torch.cuda.stream(...)).
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (indexer.dim() == 0) {
        const size_t Q = data.size(0) * data.size(1);
        // A 0-block launch is an invalid configuration; nothing to do anyway.
        if (Q > 0) {
            const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
            device::rmsprop_step_kernel<<<blocks, cuda_n_threads, 0, stream>>>(
                    data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    beta,
                    lr,
                    epsilon,
                    minval,
                    lr_last);
        }
    } else if (indexer.size(0) == 0) {
        // Empty selection: skip.
    } else if (indexer.scalar_type() == at::ScalarType::Bool) {
        // The kernel reads mask[row] for every row of data.
        TORCH_CHECK(indexer.size(0) == data.size(0),
                    "rmsprop_step: bool mask must have one entry per row of data");
        const size_t Q = data.size(0) * data.size(1);
        if (Q > 0) {
            const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
            device::rmsprop_mask_step_kernel<<<blocks, cuda_n_threads, 0, stream>>>(
                    data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    indexer.data_ptr<bool>(),
                    beta,
                    lr,
                    epsilon,
                    minval,
                    lr_last);
        }
    } else {
        TORCH_CHECK(indexer.scalar_type() == at::ScalarType::Long,
                    "rmsprop_step: index tensor must be int64 (or bool for a mask)");
        const size_t Q = indexer.size(0) * data.size(1);
        if (Q > 0) {
            const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
            device::rmsprop_index_step_kernel<<<blocks, cuda_n_threads, 0, stream>>>(
                    data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    indexer.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
                    beta,
                    lr,
                    epsilon,
                    minval,
                    lr_last);
        }
    }
    CUDA_CHECK_ERRORS;
}
// Fused in-place SGD step on a 2-D float CUDA tensor `data`:
// data -= lr * grad for the selected rows, then grad is zeroed there.
//
// `indexer` selects the rows to update:
//   * 0-dim tensor               -> every row (dense update)
//   * indexer.size(0) == 0       -> nothing is updated
//   * bool tensor                -> rows where indexer[row] is true; must have
//                                   one entry per row of `data`
//   * int64 tensor               -> the listed row ids (must be in range;
//                                   duplicated ids race in the kernel)
// `lr` applies to every channel except the last, which uses `lr_last`; a
// negative `lr_last` means "same as lr". Kernels are enqueued on PyTorch's
// current CUDA stream.
void sgd_step(
        torch::Tensor data,
        torch::Tensor grad,
        torch::Tensor indexer,
        float lr,
        float lr_last) {
    DEVICE_GUARD(data);
    CHECK_INPUT(data);
    CHECK_INPUT(grad);
    CHECK_INPUT(indexer);
    TORCH_CHECK(data.dim() == 2, "sgd_step: data must be 2-D");
    // The kernels index grad with data's shape; a mismatch would be an
    // out-of-bounds access on the device.
    TORCH_CHECK(grad.sizes() == data.sizes(),
                "sgd_step: grad must have the same shape as data");
    if (lr_last < 0.f) lr_last = lr;

    const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS;
    // Enqueue on the current PyTorch stream so the update is ordered with the
    // tensor ops that produced `grad` (matters under torch.cuda.stream(...)).
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (indexer.dim() == 0) {
        const size_t Q = data.size(0) * data.size(1);
        // A 0-block launch is an invalid configuration; nothing to do anyway.
        if (Q > 0) {
            const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
            device::sgd_step_kernel<<<blocks, cuda_n_threads, 0, stream>>>(
                    data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    lr,
                    lr_last);
        }
    } else if (indexer.size(0) == 0) {
        // Empty selection: skip.
    } else if (indexer.scalar_type() == at::ScalarType::Bool) {
        // The kernel reads mask[row] for every row of data.
        TORCH_CHECK(indexer.size(0) == data.size(0),
                    "sgd_step: bool mask must have one entry per row of data");
        const size_t Q = data.size(0) * data.size(1);
        if (Q > 0) {
            const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
            device::sgd_mask_step_kernel<<<blocks, cuda_n_threads, 0, stream>>>(
                    data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    indexer.data_ptr<bool>(),
                    lr,
                    lr_last);
        }
    } else {
        TORCH_CHECK(indexer.scalar_type() == at::ScalarType::Long,
                    "sgd_step: index tensor must be int64 (or bool for a mask)");
        const size_t Q = indexer.size(0) * data.size(1);
        if (Q > 0) {
            const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
            device::sgd_index_step_kernel<<<blocks, cuda_n_threads, 0, stream>>>(
                    data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
                    indexer.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
                    lr,
                    lr_last);
        }
    }
    CUDA_CHECK_ERRORS;
}