Skip to content

Batch kernels for forward pass of Preprocessing #2

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 57 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
6aed776
test function for rasterszaton tests
sandeepnmenon Apr 20, 2024
b7b08ba
add mock of improved preproc
Apr 20, 2024
54302e4
batched rasterization
sandeepnmenon Apr 21, 2024
a5e505a
Merge branch 'prapti/preproc_gauss' of github.com:TarzanZhao/diff-gau…
sandeepnmenon Apr 21, 2024
53e12d2
add rough idea for kernel
Apr 21, 2024
f0e0469
Refactor rasterizer import in rasterization_tests.py
sandeepnmenon Apr 22, 2024
a16acd0
Refactor GaussianRasterizerBatches class to support batched preproces…
sandeepnmenon Apr 22, 2024
aa07ce2
Merge branch 'prapti/preproc_gauss' of github.com:TarzanZhao/diff-gau…
sandeepnmenon Apr 22, 2024
268f46a
Refactor preprocess_gaussians function to remove flag_batched paramet…
sandeepnmenon Apr 22, 2024
7361323
batched forward pass kernel
sandeepnmenon Apr 22, 2024
7f4935d
added headers and changed kernel structure to 1d block
sandeepnmenon Apr 23, 2024
5f05af5
solved syntax errors
Apr 23, 2024
543d4b8
fixed import syntax in test
Apr 24, 2024
8ca5a9f
formatting changes
sandeepnmenon Apr 24, 2024
0dbe8fd
Refactor GaussianRasterizerBatches class to use torch.tensor instead …
sandeepnmenon Apr 24, 2024
193fa82
Refactor variable name in test_batched_gaussian_rasterizer_batch_proc…
sandeepnmenon Apr 24, 2024
ac43fc4
Refactor preprocess_gaussians function to handle batched and non-batc…
sandeepnmenon Apr 24, 2024
4115266
Refactor test_batched_gaussian_rasterizer and test_batched_gaussian_r…
sandeepnmenon Apr 24, 2024
162e7d0
Refactor test_batched_gaussian_rasterizer and test_batched_gaussian_r…
sandeepnmenon Apr 24, 2024
fdf3bf5
add parity test
prapti19 Apr 25, 2024
cace4fd
Refactor preprocess_gaussians function to handle batched and non-batc…
sandeepnmenon Apr 25, 2024
7cad1b0
Merge branch 'prapti_mlsys/batched_preprocess' of github.com:TarzanZh…
sandeepnmenon Apr 25, 2024
eaf0d42
Refactor test_batched_gaussian_rasterizer and test_batched_gaussian_r…
sandeepnmenon Apr 25, 2024
d9eb4e8
Refactor test_batched_gaussian_rasterizer and test_batched_gaussian_r…
sandeepnmenon Apr 25, 2024
c38cfa9
add debug flag to extra_compile_args
sandeepnmenon Apr 25, 2024
24905aa
Refactor tan_fovy parameter to be const in CUDA rasterizer files
sandeepnmenon Apr 25, 2024
d376d41
Refactor tan_fovy parameter to be const in CUDA rasterizer files
sandeepnmenon Apr 25, 2024
8c82fa7
Refactor tan_fovy parameter to be const in CUDA rasterizer files
sandeepnmenon Apr 25, 2024
4cca118
Refactor CUDA rasterizer files to use CUDA tensors for batched calcul…
sandeepnmenon Apr 25, 2024
34ebced
Refactor test_batched_gaussian_rasterizer and test_batched_gaussian_r…
sandeepnmenon Apr 25, 2024
f6374d9
Refactor test_batched_gaussian_rasterizer and test_batched_gaussian_r…
sandeepnmenon Apr 25, 2024
d0230a6
Refactor assert_tensor_equal function to compare_tensors in rasteriza…
sandeepnmenon Apr 25, 2024
e529d2a
tile_grid calculated before kernel launch
sandeepnmenon Apr 26, 2024
09b853e
Fix indexing bug in preprocessCUDABatched function
sandeepnmenon Apr 26, 2024
44f8fc1
Refactor indexing in preprocessCUDABatched function in forward.cu
sandeepnmenon Apr 26, 2024
a00921d
Refactor tile_grid calculation in rasterizer_impl.cu
sandeepnmenon Apr 26, 2024
890d95f
Refactor indexing in preprocessCUDABatched function in forward.cu
sandeepnmenon Apr 26, 2024
b3ad196
Refactor compare_tensors function in rasterization_tests.py to handle…
sandeepnmenon Apr 26, 2024
abfb8b4
Fix indexing bug in preprocessCUDABatched function
sandeepnmenon Apr 26, 2024
18f9c20
Update rasterization_tests.py
prapti19 Apr 26, 2024
a04b34d
Refactor compare_tensors function in rasterization_tests.py to handle…
sandeepnmenon Apr 26, 2024
e593132
Refactor compare_tensors function to fix indexing bug and handle non-…
sandeepnmenon Apr 26, 2024
e109969
Update forward.cu
prapti19 Apr 26, 2024
edfea2e
Update rasterization_tests.py
prapti19 Apr 27, 2024
53a14e2
Update forward.cu
prapti19 Apr 27, 2024
32a601f
fixed sh_sdegree
sandeepnmenon Apr 27, 2024
147f71d
merged commits
sandeepnmenon Apr 27, 2024
22eb043
Refactor GaussianRasterizationSettings class to handle raster_setting…
sandeepnmenon Apr 28, 2024
7ff2fd3
Refactor rasterization_tests.py to use raster_settings_batch instead …
sandeepnmenon Apr 28, 2024
fc48eec
fixed namedtuple setting bug
sandeepnmenon Apr 28, 2024
49c5179
Refactor GaussianRasterizationSettings class to handle raster_setting…
sandeepnmenon Apr 28, 2024
a0d7127
Update setup.py to remove debug flag from extra_compile_args
sandeepnmenon Apr 28, 2024
a21c4b9
Fix formatting issues in forward.cu and __init__.py
sandeepnmenon Apr 28, 2024
25c6812
Refactor computeColorFromSH function in forward.cu to use point_idx a…
sandeepnmenon Apr 28, 2024
1b7fdc4
replaced python time with torch event records
sandeepnmenon May 8, 2024
3c4c667
fixed cuda illegal memory bug and can run for 1M gaussians
sandeepnmenon May 8, 2024
1e4cbc9
chore: Update .gitignore to ignore *.pyc files
sandeepnmenon May 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ diff_gaussian_rasterization.egg-info/
dist/
diff_gaussian_rasterization/__pycache__/
*so
*.pyc
170 changes: 160 additions & 10 deletions cuda_rasterizer/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ namespace cg = cooperative_groups;

// Forward method for converting the input spherical harmonics
// coefficients of each Gaussian to a simple RGB color.
__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
__device__ glm::vec3 computeColorFromSH(int point_idx, int result_idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
{
// The implementation is loosely based on code for
// "Differentiable Point-Based Radiance Fields for
// Efficient View Synthesis" by Zhang et al. (2022)
glm::vec3 pos = means[idx];
glm::vec3 pos = means[point_idx];
glm::vec3 dir = pos - campos;
dir = dir / glm::length(dir);

glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
glm::vec3* sh = ((glm::vec3*)shs) + point_idx * max_coeffs;
glm::vec3 result = SH_C0 * sh[0];

if (deg > 0)
Expand Down Expand Up @@ -65,9 +65,9 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const

// RGB colors are clamped to positive values. If values are
// clamped, we need to keep track of this for the backward pass.
clamped[3 * idx + 0] = (result.x < 0);
clamped[3 * idx + 1] = (result.y < 0);
clamped[3 * idx + 2] = (result.z < 0);
clamped[3 * result_idx + 0] = (result.x < 0);
clamped[3 * result_idx + 1] = (result.y < 0);
clamped[3 * result_idx + 2] = (result.z < 0);
return glm::max(result, 0.0f);
}

Expand Down Expand Up @@ -213,7 +213,6 @@ __global__ void preprocessCUDA(int P, int D, int M,
computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
cov3D = cov3Ds + idx * 6;
}

// Compute 2D screen-space covariance matrix
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);

Expand Down Expand Up @@ -242,7 +241,7 @@ __global__ void preprocessCUDA(int P, int D, int M,
// spherical harmonics coefficients to RGB color.
if (colors_precomp == nullptr)
{
glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
glm::vec3 result = computeColorFromSH(idx, idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
rgb[idx * C + 0] = result.x;
rgb[idx * C + 1] = result.y;
rgb[idx * C + 2] = result.z;
Expand Down Expand Up @@ -471,7 +470,7 @@ void FORWARD::preprocess(int P, int D, int M,
uint32_t* tiles_touched,
bool prefiltered)
{
preprocessCUDA<NUM_CHANNELS> << <(P + ONE_DIM_BLOCK_SIZE - 1) / ONE_DIM_BLOCK_SIZE, ONE_DIM_BLOCK_SIZE >> > (
preprocessCUDA<NUM_CHANNELS> << <cdiv(P, ONE_DIM_BLOCK_SIZE), ONE_DIM_BLOCK_SIZE >> > (
P, D, M,
means3D,
scales,
Expand All @@ -498,4 +497,155 @@ void FORWARD::preprocess(int P, int D, int M,
tiles_touched,
prefiltered
);
}
}


template<int C>
__global__ void preprocessCUDABatched(
int P, int D, int M,
const float* orig_points, const glm::vec3* scales, const float scale_modifier,
const glm::vec4* rotations, const float* opacities, const float* shs,
bool* clamped, const float* cov3D_precomp, const float* colors_precomp,
const float* viewmatrix_arr, const float* projmatrix_arr, const glm::vec3* cam_pos,
const int W, int H, const float* tan_fovx, const float* tan_fovy,
int* radii, float2* points_xy_image, float* depths, float* cov3Ds,
float* rgb, float4* conic_opacity, const dim3 grid, uint32_t* tiles_touched,
bool prefiltered, const int num_viewpoints)
{
auto point_idx = blockIdx.x * blockDim.x + threadIdx.x;
auto viewpoint_idx = blockIdx.y;

if (viewpoint_idx >= num_viewpoints || point_idx >= P) return;

auto idx = viewpoint_idx * P + point_idx;
const float* viewmatrix = viewmatrix_arr + viewpoint_idx * 16;
const float* projmatrix = projmatrix_arr + viewpoint_idx * 16;

// Initialize radius and touched tiles to 0. If this isn't changed,
// this Gaussian will not be processed further.
radii[idx] = 0;
tiles_touched[idx] = 0;

// Perform near culling, quit if outside.
float3 p_view;
if (!in_frustum(point_idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) return;

// Transform point by projecting
float3 p_orig = { orig_points[3 * point_idx], orig_points[3 * point_idx + 1], orig_points[3 * point_idx + 2] };

float4 p_hom = transformPoint4x4(p_orig, projmatrix);
float p_w = 1.0f / (p_hom.w + 0.0000001f);
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };

// If 3D covariance matrix is precomputed, use it, otherwise compute
// from scaling and rotation parameters.
const float* cov3D;
if (cov3D_precomp != nullptr) {
cov3D = cov3D_precomp + idx * 6;
} else {
computeCov3D(scales[point_idx], scale_modifier, rotations[point_idx], cov3Ds + idx * 6);
cov3D = cov3Ds + idx * 6;
}


// Compute 2D screen-space covariance matrix
const float focal_x = W / (2.0f * tan_fovx[viewpoint_idx]);
const float focal_y = H / (2.0f * tan_fovy[viewpoint_idx]);
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx[viewpoint_idx], tan_fovy[viewpoint_idx], cov3D, viewmatrix);


// Invert covariance (EWA algorithm)
float det = (cov.x * cov.z - cov.y * cov.y);
if (det == 0.0f) return;
float det_inv = 1.f / det;
float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };

// Compute extent in screen space (by finding eigenvalues of
// 2D covariance matrix). Use extent to compute a bounding rectangle
// of screen-space tiles that this Gaussian overlaps with. Quit if
// rectangle covers 0 tiles.
float mid = 0.5f * (cov.x + cov.z);
float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
uint2 rect_min, rect_max;
getRect(point_image, my_radius, rect_min, rect_max, grid);
if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) return;

// If colors have been precomputed, use them, otherwise convert
// spherical harmonics coefficients to RGB color.

if (colors_precomp == nullptr) {

glm::vec3 result = computeColorFromSH(point_idx, idx, D, M, (glm::vec3*)orig_points, cam_pos[viewpoint_idx], shs, clamped);
rgb[idx * C + 0] = result.x;
rgb[idx * C + 1] = result.y;
rgb[idx * C + 2] = result.z;
}

// Store some useful helper data for the next steps.
depths[idx] = p_view.z;
radii[idx] = my_radius;
points_xy_image[idx] = point_image;

// Inverse 2D covariance and opacity neatly pack into one float4
conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[point_idx] };
tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
}

void FORWARD::preprocess_batch(int P, int D, int M,
const float* means3D,
const glm::vec3* scales,
const float scale_modifier,
const glm::vec4* rotations,
const float* opacities,
const float* shs,
bool* clamped,
const float* cov3D_precomp,
const float* colors_precomp,
const float* viewmatrix,
const float* projmatrix,
const glm::vec3* cam_pos,
const int W, int H,
const float* tan_fovx, const float* tan_fovy,
int* radii,
float2* means2D,
float* depths,
float* cov3Ds,
float* rgb,
float4* conic_opacity,
const dim3 grid,
uint32_t* tiles_touched,
bool prefiltered,
const int num_viewpoints)
{
dim3 tile_grid(cdiv(P, ONE_DIM_BLOCK_SIZE), num_viewpoints);
preprocessCUDABatched<NUM_CHANNELS><<<tile_grid, ONE_DIM_BLOCK_SIZE>>>(
P, D, M,
means3D,
scales,
scale_modifier,
rotations,
opacities,
shs,
clamped,
cov3D_precomp,
colors_precomp,
viewmatrix,
projmatrix,
cam_pos,
W, H,
tan_fovx, tan_fovy,
radii,
means2D,
depths,
cov3Ds,
rgb,
conic_opacity,
grid,
tiles_touched,
prefiltered,
num_viewpoints
);
}
33 changes: 31 additions & 2 deletions cuda_rasterizer/forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,35 @@ namespace FORWARD
float4* conic_opacity,
const dim3 grid,
uint32_t* tiles_touched,
bool prefiltered);
bool prefiltered
);

void preprocess_batch(int P, int D, int M,
const float* means3D,
const glm::vec3* scales,
const float scale_modifier,
const glm::vec4* rotations,
const float* opacities,
const float* shs,
bool* clamped,
const float* cov3D_precomp,
const float* colors_precomp,
const float* viewmatrix,
const float* projmatrix,
const glm::vec3* cam_pos,
const int W, int H,
const float* tan_fovx, const float* tan_fovy,
int* radii,
float2* means2D,
float* depths,
float* cov3Ds,
float* rgb,
float4* conic_opacity,
const dim3 grid,
uint32_t* tiles_touched,
bool prefiltered,
const int num_viewpoints
);

// Main rasterization method.
void render(
Expand All @@ -61,7 +89,8 @@ namespace FORWARD
uint32_t* n_contrib2loss,
const int* compute_locally_1D_2D_map,
const float* bg_color,
float* out_color);
float* out_color
);
}


Expand Down
25 changes: 25 additions & 0 deletions cuda_rasterizer/rasterizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ namespace CudaRasterizer
bool debug,//raster_settings
const pybind11::dict &args);

static int preprocessForwardBatches(
float2* means2D,
float* depths,
int* radii,
float* cov3D,
float4* conic_opacity,
float* rgb,
bool* clamped,//the above are all per-Gaussian intemediate results.
const int P, int D, int M,
const int width, int height,
const float* means3D,
const float* scales,
const float* rotations,
const float* shs,
const float* opacities,//3dgs parameters
const float scale_modifier,
const float* viewmatrix,
const float* projmatrix,
const float* cam_pos,
const float* tan_fovx, const float* tan_fovy,
const bool prefiltered,
const int num_viewpoints,
bool debug,//raster_settings
const pybind11::dict &args);

static void preprocessBackward(
const int* radii,
const float* cov3D,
Expand Down
Loading