-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Batch kernels for backward pass of Preprocessing #3
base: mlsys/forward_preprocess_batch
Are you sure you want to change the base?
Batch kernels for backward pass of Preprocessing #3
Conversation
…-gaussian-rasterization into mlsys/batched_preprocess
…of batched_raster_settings
…in preprocess functions
…ion_backward_tests.py
…o/diff-gaussian-rasterization into mlsys/batched_preprocess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tested everything, works great!
@@ -17,21 +17,21 @@ namespace cg = cooperative_groups; | |||
|
|||
// Backward pass for conversion of spherical harmonics to RGB for | |||
// each Gaussian. | |||
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs) | |||
__device__ void computeColorFromSH(int point_idx, int result_idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make sure your two pull request will not conflict with each other when we merge both of them. Maybe we could do this together in tomorrow meeting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are separate PRs. We'll resolve merge conflict after the forward pass PR is merged in.
@@ -472,8 +472,8 @@ int CudaRasterizer::Rasterizer::preprocessForwardBatches( | |||
// In sep_rendering==True case, we will compute tiles_touched in the renderForward. | |||
// TODO: remove it later by modifying FORWARD::preprocess when we deprecate sep_rendering==False case | |||
uint32_t* tiles_touched_temp_buffer; | |||
CHECK_CUDA(cudaMalloc(&tiles_touched_temp_buffer, P * sizeof(uint32_t)), debug); | |||
CHECK_CUDA(cudaMemset(tiles_touched_temp_buffer, 0, P * sizeof(uint32_t)), debug); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as comments in the other PR. We could delete these memory allocation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Created issue #4
torch.cuda.synchronize() | ||
start_backward_event.record() | ||
|
||
loss = compute_dummy_loss(means3D, scales, rotations, shs, opacity) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be wrong. To test the correctness of your batched implementation, we should indeed execute the backward kernels. However, the dummy loss only use original parameters as input. Therefore, we won't call the backward kernel of preprocess when you do loss.backward()
. You should change to
loss = compute_dummy_loss( batched_means2D, batched_conic_opacity, batched_rgb, batched_depths, batched_radii)
This is a follow up for PR #2 and should be merged after.
Changes
preprocess_gaussians_backward_batched
Results
Tests ran on V100
Time taken by run_batched_gaussian_rasterizer: 95.2484 ms
Time taken by run_batched_gaussian_rasterizer BACKWARD: 84.6223 ms
Time taken by run_batched_gaussian_rasterizer_batch_processing: 28.9444 ms
Time taken by run_batched_gaussian_rasterizer_batch_processing BACKWARD: 4.9838 ms