-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathrender_lerp_kernel_nvol.cu
502 lines (440 loc) · 17.6 KB
/
render_lerp_kernel_nvol.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
// Copyright 2021 Alex Yu
// This is an alternate implementation using the volume rendering formula from Neural Volumes (Lombardi, ToG 2019)
// NOTE: this still uses density as in NeRF, but the key difference is using absolute, instead of relative, transmittance.
// This formula allows for parallel evaluation of points along the ray, since it's just an ablation this is not optimized
// Background is NOT supported
#include <torch/extension.h>
#include "cuda_util.cuh"
#include "data_spec_packed.cuh"
#include "render_util.cuh"
#include <iostream>
#include <cstdint>
#include <tuple>
namespace {
const int WARP_SIZE = 32;
const int TRACE_RAY_CUDA_THREADS = 128;
const int TRACE_RAY_CUDA_RAYS_PER_BLOCK = TRACE_RAY_CUDA_THREADS / WARP_SIZE;
const int TRACE_RAY_BKWD_CUDA_THREADS = 128;
const int TRACE_RAY_BKWD_CUDA_RAYS_PER_BLOCK = TRACE_RAY_BKWD_CUDA_THREADS / WARP_SIZE;
const int MIN_BLOCKS_PER_SM = 8;
typedef cub::WarpReduce<float> WarpReducef;
namespace device {
// * For ray rendering
__device__ __inline__ void trace_ray_nvol(
const PackedSparseGridSpec& __restrict__ grid,
SingleRaySpec& __restrict__ ray,
const RenderOptions& __restrict__ opt,
uint32_t lane_id,
float* __restrict__ sphfunc_val,
WarpReducef::TempStorage& __restrict__ temp_storage,
float* __restrict__ out) {
const uint32_t lane_colorgrp_id = lane_id % grid.basis_dim;
const uint32_t lane_colorgrp = lane_id / grid.basis_dim;
if (ray.tmin > ray.tmax) {
out[lane_colorgrp] = (grid.background_nlayers == 0) ? opt.background_brightness : 0.f;
return;
}
float t = ray.tmin;
float outv = 0.f;
float total_alpha = 0.f;
while (t <= ray.tmax) {
#pragma unroll 3
for (int j = 0; j < 3; ++j) {
ray.pos[j] = fmaf(t, ray.dir[j], ray.origin[j]);
ray.pos[j] = min(max(ray.pos[j], 0.f), grid.size[j] - 1.f);
ray.l[j] = min(static_cast<int32_t>(ray.pos[j]), grid.size[j] - 2);
ray.pos[j] -= static_cast<float>(ray.l[j]);
}
const float skip = compute_skip_dist(ray,
grid.links, grid.stride_x,
grid.size[2], 0);
if (skip >= opt.step_size) {
// For consistency, we skip the by step size
t += ceilf(skip / opt.step_size) * opt.step_size;
continue;
}
float sigma = trilerp_cuvol_one(
grid.links, grid.density_data,
grid.stride_x,
grid.size[2],
1,
ray.l, ray.pos,
0);
if (sigma > opt.sigma_thresh) {
float lane_color = trilerp_cuvol_one(
grid.links,
grid.sh_data,
grid.stride_x,
grid.size[2],
grid.sh_data_dim,
ray.l, ray.pos, lane_id);
lane_color *= sphfunc_val[lane_colorgrp_id];
const float new_total_alpha = fminf(total_alpha + 1.f - _EXP(
-ray.world_step * sigma), 1.f);
const float weight = new_total_alpha - total_alpha;
total_alpha = new_total_alpha;
float lane_color_total = WarpReducef(temp_storage).HeadSegmentedSum(
lane_color, lane_colorgrp_id == 0);
outv += weight * fmaxf(lane_color_total + 0.5f, 0.f); // Clamp to [+0, infty)
if (total_alpha >= 1.f) break;
}
t += opt.step_size;
}
if (grid.background_nlayers == 0) {
outv += (1.f - total_alpha) * opt.background_brightness;
}
if (lane_colorgrp_id == 0) {
out[lane_colorgrp] = outv;
}
}
__device__ __inline__ void trace_ray_nvol_backward(
const PackedSparseGridSpec& __restrict__ grid,
const float* __restrict__ grad_output,
const float* __restrict__ color_cache,
SingleRaySpec& __restrict__ ray,
const RenderOptions& __restrict__ opt,
uint32_t lane_id,
const float* __restrict__ sphfunc_val,
float* __restrict__ grad_sphfunc_val,
WarpReducef::TempStorage& __restrict__ temp_storage,
float sparsity_loss,
PackedGridOutputGrads& __restrict__ grads
) {
const uint32_t lane_colorgrp_id = lane_id % grid.basis_dim;
const uint32_t lane_colorgrp = lane_id / grid.basis_dim;
const uint32_t leader_mask = 1U | (1U << grid.basis_dim) | (1U << (2 * grid.basis_dim));
if (ray.tmin > ray.tmax) {
return;
}
float t = ray.tmin;
const float gout = grad_output[lane_colorgrp];
float total_alpha = 0.f;
float last_total_color = 0.f;
// remat samples
while (t <= ray.tmax) {
#pragma unroll 3
for (int j = 0; j < 3; ++j) {
ray.pos[j] = fmaf(t, ray.dir[j], ray.origin[j]);
ray.pos[j] = min(max(ray.pos[j], 0.f), grid.size[j] - 1.f);
ray.l[j] = min(static_cast<int32_t>(ray.pos[j]), grid.size[j] - 2);
ray.pos[j] -= static_cast<float>(ray.l[j]);
}
const float skip = compute_skip_dist(ray,
grid.links, grid.stride_x,
grid.size[2], 0);
if (skip >= opt.step_size) {
// For consistency, we skip the by step size
t += ceilf(skip / opt.step_size) * opt.step_size;
continue;
}
float sigma = trilerp_cuvol_one(
grid.links,
grid.density_data,
grid.stride_x,
grid.size[2],
1,
ray.l, ray.pos,
0);
if (sigma > opt.sigma_thresh) {
float lane_color = trilerp_cuvol_one(
grid.links,
grid.sh_data,
grid.stride_x,
grid.size[2],
grid.sh_data_dim,
ray.l, ray.pos, lane_id);
float weighted_lane_color = lane_color * sphfunc_val[lane_colorgrp_id];
const float curr_transmit = _EXP(-ray.world_step * sigma);
const float new_total_alpha = fminf(total_alpha + 1.f - curr_transmit, 1.f);
const float weight = new_total_alpha - total_alpha;
bool not_last = new_total_alpha < 1.f;
total_alpha = new_total_alpha;
const float lane_color_total = WarpReducef(temp_storage).HeadSegmentedSum(
weighted_lane_color, lane_colorgrp_id == 0) + 0.5f;
float total_color = fmaxf(lane_color_total, 0.f);
float color_in_01 = total_color == lane_color_total;
total_color *= gout; // Clamp to [+0, infty)
float total_color_c1 = __shfl_sync(leader_mask, total_color, grid.basis_dim);
total_color += __shfl_sync(leader_mask, total_color, 2 * grid.basis_dim);
total_color += total_color_c1;
color_in_01 = __shfl_sync((1U << grid.sh_data_dim) - 1, color_in_01, lane_colorgrp * grid.basis_dim);
const float grad_common = weight * color_in_01 * gout;
const float curr_grad_color = sphfunc_val[lane_colorgrp_id] * grad_common;
if (grid.basis_type != BASIS_TYPE_SH) {
float curr_grad_sphfunc = lane_color * grad_common;
const float curr_grad_up2 = __shfl_down_sync((1U << grid.sh_data_dim) - 1,
curr_grad_sphfunc, 2 * grid.basis_dim);
curr_grad_sphfunc += __shfl_down_sync((1U << grid.sh_data_dim) - 1,
curr_grad_sphfunc, grid.basis_dim);
curr_grad_sphfunc += curr_grad_up2;
if (lane_id < grid.basis_dim) {
grad_sphfunc_val[lane_id] += curr_grad_sphfunc;
}
}
trilerp_backward_cuvol_one(grid.links, grads.grad_sh_out,
grid.stride_x,
grid.size[2],
grid.sh_data_dim,
ray.l, ray.pos,
curr_grad_color, lane_id);
if (not_last) {
float curr_grad_sigma = ray.world_step * curr_transmit * total_color;
if (sparsity_loss > 0.f) {
// Cauchy version (from SNeRG)
curr_grad_sigma += sparsity_loss * (4 * sigma / (1 + 2 * (sigma * sigma)));
// Alphs version (from PlenOctrees)
// curr_grad_sigma += sparsity_loss * _EXP(-pcnt) * ray.world_step;
}
if (lane_id == 0) {
trilerp_backward_cuvol_one_density(
grid.links,
grads.grad_density_out,
grads.mask_out,
grid.stride_x,
grid.size[2],
ray.l, ray.pos, curr_grad_sigma);
}
} else {
ray.tmax = t;
last_total_color = total_color;
break;
}
}
t += opt.step_size;
}
if (total_alpha < 1.f) {
// Never saturatedo
last_total_color = opt.background_brightness * (
grad_output[0] + grad_output[1] + grad_output[2]);
}
if (last_total_color != 0.f) {
t = ray.tmin;
total_alpha = 0.f;
while (t <= ray.tmax) {
#pragma unroll 3
for (int j = 0; j < 3; ++j) {
ray.pos[j] = fmaf(t, ray.dir[j], ray.origin[j]);
ray.pos[j] = min(max(ray.pos[j], 0.f), grid.size[j] - 1.f);
ray.l[j] = min(static_cast<int32_t>(ray.pos[j]), grid.size[j] - 2);
ray.pos[j] -= static_cast<float>(ray.l[j]);
}
const float skip = compute_skip_dist(ray,
grid.links, grid.stride_x,
grid.size[2], 0);
if (skip >= opt.step_size) {
// For consistency, we skip the by step size
t += ceilf(skip / opt.step_size) * opt.step_size;
continue;
}
float sigma = trilerp_cuvol_one(
grid.links,
grid.density_data,
grid.stride_x,
grid.size[2],
1,
ray.l, ray.pos,
0);
const float curr_transmit = _EXP(-ray.world_step * sigma);
total_alpha = fminf(total_alpha + 1.f - curr_transmit, 1.f);
// const float weight = new_total_alpha - total_alpha;
// total_alpha = new_total_alpha;
if (total_alpha >= 1.f) break;
float curr_grad_sigma = -ray.world_step * curr_transmit * last_total_color;
if (lane_id == 0) {
trilerp_backward_cuvol_one_density(
grid.links,
grads.grad_density_out,
grads.mask_out,
grid.stride_x,
grid.size[2],
ray.l, ray.pos, curr_grad_sigma);
}
t += opt.step_size;
}
}
}
// BEGIN KERNELS
__launch_bounds__(TRACE_RAY_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void render_ray_kernel(
PackedSparseGridSpec grid,
PackedRaysSpec rays,
RenderOptions opt,
torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> out) {
CUDA_GET_THREAD_ID(tid, int(rays.origins.size(0)) * WARP_SIZE);
const int ray_id = tid >> 5;
const int ray_blk_id = threadIdx.x >> 5;
const int lane_id = threadIdx.x & 0x1F;
if (lane_id >= grid.sh_data_dim)
return;
__shared__ float sphfunc_val[TRACE_RAY_CUDA_RAYS_PER_BLOCK][10];
__shared__ SingleRaySpec ray_spec[TRACE_RAY_CUDA_RAYS_PER_BLOCK];
__shared__ typename WarpReducef::TempStorage temp_storage[
TRACE_RAY_CUDA_RAYS_PER_BLOCK];
ray_spec[ray_blk_id].set(rays.origins[ray_id].data(),
rays.dirs[ray_id].data());
calc_sphfunc(grid, lane_id,
ray_id,
ray_spec[ray_blk_id].dir,
sphfunc_val[ray_blk_id]);
ray_find_bounds(ray_spec[ray_blk_id], grid, opt, ray_id);
__syncwarp((1U << grid.sh_data_dim) - 1);
trace_ray_nvol(
grid,
ray_spec[ray_blk_id],
opt,
lane_id,
sphfunc_val[ray_blk_id],
temp_storage[ray_blk_id],
out[ray_id].data());
}
__launch_bounds__(TRACE_RAY_BKWD_CUDA_THREADS, MIN_BLOCKS_PER_SM)
__global__ void render_ray_backward_kernel(
PackedSparseGridSpec grid,
const float* __restrict__ grad_output,
const float* __restrict__ color_cache,
PackedRaysSpec rays,
RenderOptions opt,
bool grad_out_is_rgb,
float sparsity_loss,
PackedGridOutputGrads grads) {
CUDA_GET_THREAD_ID(tid, int(rays.origins.size(0)) * WARP_SIZE);
const int ray_id = tid >> 5;
const int ray_blk_id = threadIdx.x >> 5;
const int lane_id = threadIdx.x & 0x1F;
if (lane_id >= grid.sh_data_dim)
return;
__shared__ float sphfunc_val[TRACE_RAY_BKWD_CUDA_RAYS_PER_BLOCK][10];
__shared__ float grad_sphfunc_val[TRACE_RAY_CUDA_RAYS_PER_BLOCK][10];
__shared__ SingleRaySpec ray_spec[TRACE_RAY_BKWD_CUDA_RAYS_PER_BLOCK];
__shared__ typename WarpReducef::TempStorage temp_storage[
TRACE_RAY_CUDA_RAYS_PER_BLOCK];
ray_spec[ray_blk_id].set(rays.origins[ray_id].data(),
rays.dirs[ray_id].data());
const float vdir[3] = {ray_spec[ray_blk_id].dir[0],
ray_spec[ray_blk_id].dir[1],
ray_spec[ray_blk_id].dir[2] };
if (lane_id < grid.basis_dim) {
grad_sphfunc_val[ray_blk_id][lane_id] = 0.f;
}
calc_sphfunc(grid, lane_id,
ray_id,
vdir, sphfunc_val[ray_blk_id]);
if (lane_id == 0) {
ray_find_bounds(ray_spec[ray_blk_id], grid, opt, ray_id);
}
float grad_out[3];
if (grad_out_is_rgb) {
const float norm_factor = 2.f / (3 * int(rays.origins.size(0)));
#pragma unroll 3
for (int i = 0; i < 3; ++i) {
const float resid = color_cache[ray_id * 3 + i] - grad_output[ray_id * 3 + i];
grad_out[i] = resid * norm_factor;
}
} else {
#pragma unroll 3
for (int i = 0; i < 3; ++i) {
grad_out[i] = grad_output[ray_id * 3 + i];
}
}
__syncwarp((1U << grid.sh_data_dim) - 1);
trace_ray_nvol_backward(
grid,
grad_out,
color_cache + ray_id * 3,
ray_spec[ray_blk_id],
opt,
lane_id,
sphfunc_val[ray_blk_id],
grad_sphfunc_val[ray_blk_id],
temp_storage[ray_blk_id],
sparsity_loss,
grads);
calc_sphfunc_backward(
grid, lane_id,
ray_id,
vdir,
sphfunc_val[ray_blk_id],
grad_sphfunc_val[ray_blk_id],
grads.grad_basis_out);
}
} // namespace device
} // namespace
torch::Tensor volume_render_nvol(SparseGridSpec& grid, RaysSpec& rays, RenderOptions& opt) {
DEVICE_GUARD(grid.sh_data);
grid.check();
rays.check();
const auto Q = rays.origins.size(0);
torch::Tensor results = torch::empty_like(rays.origins);
const int cuda_n_threads = TRACE_RAY_CUDA_THREADS;
const int blocks = CUDA_N_BLOCKS_NEEDED(Q * WARP_SIZE, cuda_n_threads);
device::render_ray_kernel<<<blocks, cuda_n_threads>>>(
grid, rays, opt,
// Output
results.packed_accessor32<float, 2, torch::RestrictPtrTraits>());
CUDA_CHECK_ERRORS;
return results;
}
void volume_render_nvol_backward(
SparseGridSpec& grid,
RaysSpec& rays,
RenderOptions& opt,
torch::Tensor grad_out,
torch::Tensor color_cache,
GridOutputGrads& grads) {
DEVICE_GUARD(grid.sh_data);
grid.check();
rays.check();
grads.check();
const auto Q = rays.origins.size(0);
const int cuda_n_threads_render_backward = TRACE_RAY_BKWD_CUDA_THREADS;
const int blocks = CUDA_N_BLOCKS_NEEDED(Q * WARP_SIZE, cuda_n_threads_render_backward);
device::render_ray_backward_kernel<<<blocks,
cuda_n_threads_render_backward>>>(
grid,
grad_out.data_ptr<float>(),
color_cache.data_ptr<float>(),
rays, opt,
false,
0.f,
// Output
grads);
CUDA_CHECK_ERRORS;
}
void volume_render_nvol_fused(
SparseGridSpec& grid,
RaysSpec& rays,
RenderOptions& opt,
torch::Tensor rgb_gt,
float _, // not supported
float sparsity_loss,
torch::Tensor rgb_out,
GridOutputGrads& grads) {
DEVICE_GUARD(grid.sh_data);
CHECK_INPUT(rgb_gt);
CHECK_INPUT(rgb_out);
grid.check();
rays.check();
grads.check();
const auto Q = rays.origins.size(0);
{
const int blocks = CUDA_N_BLOCKS_NEEDED(Q * WARP_SIZE, TRACE_RAY_CUDA_THREADS);
device::render_ray_kernel<<<blocks, TRACE_RAY_CUDA_THREADS>>>(
grid, rays, opt,
// Output
rgb_out.packed_accessor32<float, 2, torch::RestrictPtrTraits>());
}
{
const int blocks = CUDA_N_BLOCKS_NEEDED(Q * WARP_SIZE, TRACE_RAY_BKWD_CUDA_THREADS);
device::render_ray_backward_kernel<<<blocks, TRACE_RAY_BKWD_CUDA_THREADS>>>(
grid,
rgb_gt.data_ptr<float>(),
rgb_out.data_ptr<float>(),
rays,
opt,
true,
sparsity_loss,
// Output
grads);
}
CUDA_CHECK_ERRORS;
}