Batch kernels for backward pass of Preprocessing #3

Open

wants to merge 36 commits into base: mlsys/forward_preprocess_batch

Commits (36)
c408be1
preprocess batches for backward outline
sandeepnmenon Apr 27, 2024
6444aa7
solve syntax errors in backward
prapti19 Apr 27, 2024
8df1b63
Refactor GaussianRasterizationSettings class to handle raster_setting…
sandeepnmenon Apr 28, 2024
f59e67c
Merge branch 'mlsys/batched_preprocess' of github.com:TarzanZhao/diff…
sandeepnmenon Apr 28, 2024
529710b
added focal_x and focal_y calculation inside the kernel
sandeepnmenon Apr 28, 2024
3ac1ad3
Refactor rasterization_tests.py to use raster_settings_batch instead …
sandeepnmenon Apr 28, 2024
ce314e2
fixed namedtuple setting bug
sandeepnmenon Apr 28, 2024
fdd3b4f
Refactor GaussianRasterizationSettings class to handle raster_setting…
sandeepnmenon Apr 28, 2024
cdf3bc1
remove focal_x and focal_y calculations
sandeepnmenon Apr 28, 2024
591a5c1
Refactor CUDA rasterizer code to include width and height parameters …
sandeepnmenon Apr 28, 2024
1b54023
Renamed W and H to image_width and image_height parameters in prepr…
sandeepnmenon Apr 28, 2024
9294a07
reverted focal_x and focal_y removal in normal preprocessBackward
sandeepnmenon Apr 28, 2024
877677f
grad_means2D to handle more than 2 dimensions
sandeepnmenon Apr 28, 2024
444e8a5
add tests for backward
prapti19 Apr 28, 2024
46b83eb
ruff formatting and gradients for remaining inputs
sandeepnmenon Apr 28, 2024
710b56f
Add pyproject.toml file with ruff line-length set to 120
sandeepnmenon Apr 28, 2024
2ca5ae6
Refactor ruff.toml file to set line-length to 120 and indent-width to 4
sandeepnmenon Apr 28, 2024
7a4b6b4
Refactor compare_tensors function to handle None values in rasterizat…
sandeepnmenon Apr 28, 2024
14889e6
Update ruff.toml file to set line-length to 120
sandeepnmenon Apr 28, 2024
5682d26
Refactor rasterization_backward_tests.py to include gradient checks f…
sandeepnmenon Apr 28, 2024
945e8cf
gradients calculated for all the variables to check and cloning them
sandeepnmenon Apr 28, 2024
c84d7cd
converted to pytest testing
sandeepnmenon Apr 28, 2024
6f38446
fixed colon bug and ruff formatting
sandeepnmenon Apr 28, 2024
7b30782
Add __pycache__/ to .gitignore
sandeepnmenon Apr 28, 2024
13a5559
renamed to *_test.py
sandeepnmenon Apr 28, 2024
5b24881
Update .gitignore to include __pycache__/
sandeepnmenon Apr 28, 2024
9e6f4a9
moved test into tests folder
sandeepnmenon Apr 28, 2024
7be38fa
Add instructions for running tests in README.md
sandeepnmenon Apr 28, 2024
1887e14
Merge branch 'mlsys/forward_preprocess_batch' of github.com:TarzanZha…
sandeepnmenon Apr 28, 2024
307e156
deleted old test file
sandeepnmenon Apr 28, 2024
21ee225
renamed idx to point_idx and view_idx to result_idx in backward
sandeepnmenon Apr 29, 2024
363b4ee
moved from python time to torch record
sandeepnmenon May 8, 2024
91d1582
fixed num_points in preprocessForwardBatches
sandeepnmenon May 8, 2024
ee767da
Refactor test function names for clarity and consistency
sandeepnmenon May 11, 2024
2e7f032
fixed bug in printing only first 5 non-matching indices
sandeepnmenon May 11, 2024
e8edb86
fixed backward bug of backward kernel not getting executed
sandeepnmenon May 11, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ diff_gaussian_rasterization.egg-info/
dist/
diff_gaussian_rasterization/__pycache__/
*so
__pycache__/
6 changes: 5 additions & 1 deletion README.md
@@ -16,4 +16,8 @@ Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-T
url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
}</code></pre>
</div>
</section>
</section>

## Running tests
Use pytest to run the tests. The tests are located in the `tests` directory. To run all of them, run `pytest` from the root directory of the project.
Pass the `--capture=no` flag to see the output of the tests, including the performance metrics.
300 changes: 286 additions & 14 deletions cuda_rasterizer/backward.cu
@@ -17,21 +17,21 @@ namespace cg = cooperative_groups;

// Backward pass for conversion of spherical harmonics to RGB for
// each Gaussian.
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
__device__ void computeColorFromSH(int point_idx, int result_idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
Collaborator: We should make sure your two pull requests will not conflict with each other when we merge both of them. Maybe we could do this together in tomorrow's meeting.

Collaborator (Author): They are separate PRs. We'll resolve the merge conflict after the forward pass PR is merged in.

{
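// Indexing convention: point_idx selects the per-Gaussian inputs (means, input SHs),
// while result_idx selects the clamp flags and the SH gradient outputs. The
// non-batched preprocess passes the same index for both.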
// Compute intermediate values, as it is done during forward
glm::vec3 pos = means[idx];
glm::vec3 pos = means[point_idx];
glm::vec3 dir_orig = pos - campos;
glm::vec3 dir = dir_orig / glm::length(dir_orig);

glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
glm::vec3* sh = ((glm::vec3*)shs) + point_idx * max_coeffs;

// Use PyTorch rule for clamping: if clamping was applied,
// gradient becomes 0.
glm::vec3 dL_dRGB = dL_dcolor[idx];
dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
glm::vec3 dL_dRGB = dL_dcolor[point_idx];
dL_dRGB.x *= clamped[3 * result_idx + 0] ? 0 : 1;
dL_dRGB.y *= clamped[3 * result_idx + 1] ? 0 : 1;
dL_dRGB.z *= clamped[3 * result_idx + 2] ? 0 : 1;

glm::vec3 dRGBdx(0, 0, 0);
glm::vec3 dRGBdy(0, 0, 0);
@@ -41,9 +41,9 @@ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::
float z = dir.z;

// Target location for this Gaussian to write SH gradients to
glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
glm::vec3 *dL_dsh = dL_dshs + result_idx * max_coeffs;

// No tricks here, just high school-level calculus.
float dRGBdsh0 = SH_C0;
dL_dsh[0] = dRGBdsh0 * dL_dRGB;
if (deg > 0)
@@ -55,7 +55,7 @@ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::
dL_dsh[2] = dRGBdsh2 * dL_dRGB;
dL_dsh[3] = dRGBdsh3 * dL_dRGB;

dRGBdx = -SH_C1 * sh[3];
dRGBdy = -SH_C1 * sh[1];
dRGBdz = SH_C1 * sh[2];

@@ -75,7 +75,7 @@ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::
dL_dsh[7] = dRGBdsh7 * dL_dRGB;
dL_dsh[8] = dRGBdsh8 * dL_dRGB;

dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];

@@ -96,7 +96,7 @@ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::
dL_dsh[14] = dRGBdsh14 * dL_dRGB;
dL_dsh[15] = dRGBdsh15 * dL_dRGB;

dRGBdx += (
SH_C3[0] * sh[9] * 3.f * 2.f * xy +
SH_C3[1] * sh[10] * yz +
SH_C3[2] * sh[11] * -2.f * xy +
@@ -135,7 +135,7 @@ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::
// Gradients of loss w.r.t. Gaussian means, but only the portion
// that is caused because the mean affects the view-dependent color.
// Additional mean gradient is accumulated in below methods.
dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
dL_dmeans[point_idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
}

// Backward version of INVERSE 2D covariance matrix computation
@@ -273,6 +273,150 @@ __global__ void computeCov2DCUDA(int P,
dL_dmeans[idx] = dL_dmean;
}

__global__ void computeCov2DCUDABatched(
const int num_viewpoints,
const int P,
const float3* means,
const int* radii,
const float* cov3Ds,
const int image_width, const int image_height,
const float* tan_fovx, const float* tan_fovy,
const float* viewmatrix_arr,
const float* dL_dconics,
float3* dL_dmeans,
float* dL_dcov)
{
auto point_idx = blockIdx.x * blockDim.x + threadIdx.x;
auto viewpoint_idx = blockIdx.y;

if (point_idx >= P || viewpoint_idx >= num_viewpoints)
return;

auto idx = viewpoint_idx * P + point_idx;
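// idx flattens (viewpoint, Gaussian): results for viewpoint v occupy the
// contiguous range [v * P, (v + 1) * P) in the per-view buffers.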
if (!(radii[idx] > 0))
return;

const float* view_matrix = viewmatrix_arr + viewpoint_idx * 16;
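// view_matrix points at this viewpoint's 4x4 view matrix (16 contiguous floats).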

// Reading location of 3D covariance for this Gaussian
const float* cov3D = cov3Ds + 6 * idx;

// Fetch gradients, recompute 2D covariance and relevant
// intermediate forward results needed in the backward.
float3 mean = means[point_idx];
float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
float3 t = transformPoint4x3(mean, view_matrix);

const float limx = 1.3f * tan_fovx[viewpoint_idx];
const float limy = 1.3f * tan_fovy[viewpoint_idx];
const float txtz = t.x / t.z;
const float tytz = t.y / t.z;
t.x = min(limx, max(-limx, txtz)) * t.z;
t.y = min(limy, max(-limy, tytz)) * t.z;

const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;

const float h_x = image_width / (2.0f * tan_fovx[viewpoint_idx]);
const float h_y = image_height / (2.0f * tan_fovy[viewpoint_idx]);
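// h_x and h_y are per-view focal lengths in pixels, derived from the shared
// image size and this viewpoint's FOV tangents.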
glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
0, 0, 0);

glm::mat3 W = glm::mat3(
view_matrix[0], view_matrix[4], view_matrix[8],
view_matrix[1], view_matrix[5], view_matrix[9],
view_matrix[2], view_matrix[6], view_matrix[10]);

glm::mat3 Vrk = glm::mat3(
cov3D[0], cov3D[1], cov3D[2],
cov3D[1], cov3D[3], cov3D[4],
cov3D[2], cov3D[4], cov3D[5]);

glm::mat3 T = W * J;

glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;

// Use helper variables for 2D covariance entries. More compact.
float a = cov2D[0][0] += 0.3f;
float b = cov2D[0][1];
float c = cov2D[1][1] += 0.3f;

float denom = a * c - b * b;
float dL_da = 0, dL_db = 0, dL_dc = 0;
float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);

if (denom2inv != 0)
{
// Gradients of loss w.r.t. entries of 2D covariance matrix,
// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);

// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
// given gradients w.r.t. 2D covariance matrix (diagonal).
// cov2D = transpose(T) * transpose(Vrk) * T;
dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);

// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
// given gradients w.r.t. 2D covariance matrix (off-diagonal).
// Off-diagonal elements appear twice --> double the gradient.
// cov2D = transpose(T) * transpose(Vrk) * T;
dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
}
else
{
for (int i = 0; i < 6; i++)
dL_dcov[6 * idx + i] = 0;
}

// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
// cov2D = transpose(T) * transpose(Vrk) * T;
float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;

// Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
// T = W * J
float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;

float tz = 1.f / t.z;
float tz2 = tz * tz;
float tz3 = tz2 * tz;

// Gradients of loss w.r.t. transformed Gaussian mean t
float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;

// Account for transformation of mean to t
// t = transformPoint4x3(mean, view_matrix);
float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);

// Gradients of loss w.r.t. Gaussian means, but only the portion
// that is caused because the mean affects the covariance matrix.
// Additional mean gradient is accumulated in BACKWARD::preprocess.
dL_dmeans[idx] = dL_dmean;
}

// Backward pass for the conversion of scale and rotation to a
// 3D covariance matrix for each Gaussian.
__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
@@ -388,13 +532,72 @@ __global__ void preprocessCUDA(

// Compute gradient updates due to computing colors from SHs
if (shs)
computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
computeColorFromSH(idx, idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);

// Compute gradient updates due to computing covariance from scale/rotation
if (scales)
computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
}

template<int C>
__global__ void preprocessCUDABatched(
const int num_viewpoints,
const int P, const int D, const int M,
const float3* means,
const int* radii,
const float* shs,
const bool* clamped,
const glm::vec3* scales,
const glm::vec4* rotations,
const float scale_modifier,
const float* projmatrix_arr,
const glm::vec3* campos,
const float3* dL_dmean2D,
glm::vec3* dL_dmeans,
float* dL_dcolor, // TODO: this should be changed to const float*, because we do not modify dL_dcolor in the preprocessCUDA backward pass.
float* dL_dcov3D,
float* dL_dsh,
glm::vec3* dL_dscale,
glm::vec4* dL_drot)
{
auto point_idx = blockIdx.x * blockDim.x + threadIdx.x;
auto viewpoint_idx = blockIdx.y;
if (viewpoint_idx >= num_viewpoints || point_idx >= P) return;

auto idx = viewpoint_idx * P + point_idx;
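// idx flattens (viewpoint, Gaussian) into the per-view buffers, using the same
// layout as computeCov2DCUDABatched.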
if (!(radii[idx] > 0))
return;

const float* proj = projmatrix_arr + viewpoint_idx * 16;
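// proj points at this viewpoint's 4x4 projection matrix (16 contiguous floats).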

float3 m = means[idx];

// Taking care of gradients from the screenspace points
float4 m_hom = transformPoint4x4(m, proj);
float m_w = 1.0f / (m_hom.w + 0.0000001f);

// Compute loss gradient w.r.t. 3D means due to gradients of 2D means
// from rendering procedure
glm::vec3 dL_dmean;
float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;

// That's the second part of the mean gradient. Previous computation
// of cov2D and following SH conversion also affects it.
dL_dmeans[idx] += dL_dmean;

// Compute gradient updates due to computing colors from SHs
if (shs)
computeColorFromSH(point_idx, idx, D, M, (glm::vec3*)means, campos[viewpoint_idx], shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);

// Compute gradient updates due to computing covariance from scale/rotation
if (scales)
computeCov3D(idx, scales[point_idx], scale_modifier, rotations[point_idx], dL_dcov3D, dL_dscale, dL_drot);
}

// Backward version of the rendering procedure.
template <uint32_t C>
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
@@ -691,4 +894,73 @@ void BACKWARD::render(
dL_dopacity,
dL_dcolors
);
}

void BACKWARD::preprocess_batch(
const int num_viewpoints,
const int P, const int D, const int M,
const float3* means3D,
const int* radii,
const float* shs,
const bool* clamped,
const glm::vec3* scales,
const glm::vec4* rotations,
const float scale_modifier,
const float* cov3Ds,
const float* viewmatrix,
const float* projmatrix,
const int W, const int H,
const float* tan_fovx, const float* tan_fovy,
const glm::vec3* campos,
const float3* dL_dmean2D,
const float* dL_dconic,
glm::vec3* dL_dmean3D,
float* dL_dcolor,
float* dL_dcov3D,
float* dL_dsh,
glm::vec3* dL_dscale,
glm::vec4* dL_drot)
{
// Propagate gradients for the path of 2D conic matrix computation.
// Somewhat long, thus it is its own kernel rather than being part of
// "preprocess". When done, loss gradient w.r.t. 3D means has been
// modified and gradient w.r.t. 3D covariance matrix has been computed.
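// The launch grid is 2D: blocks along x tile the P Gaussians in chunks of
// ONE_DIM_BLOCK_SIZE threads; the y dimension holds one row of blocks per viewpoint.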
dim3 tile_grid(cdiv(P, ONE_DIM_BLOCK_SIZE), num_viewpoints);

computeCov2DCUDABatched << <tile_grid, ONE_DIM_BLOCK_SIZE >> > (
num_viewpoints,
P,
means3D,
radii,
cov3Ds,
W, H,
tan_fovx,
tan_fovy,
viewmatrix,
dL_dconic,
(float3*)dL_dmean3D,
dL_dcov3D);

// Propagate gradients for remaining steps: finish 3D mean gradients,
// propagate color gradients to SH (if desired), propagate 3D covariance
// matrix gradients to scale and rotation.
preprocessCUDABatched<NUM_CHANNELS> << < tile_grid, ONE_DIM_BLOCK_SIZE >> > (
num_viewpoints,
P, D, M,
(float3*)means3D,
radii,
shs,
clamped,
(glm::vec3*)scales,
(glm::vec4*)rotations,
scale_modifier,
projmatrix,
campos,
(float3*)dL_dmean2D,
(glm::vec3*)dL_dmean3D,
dL_dcolor,
dL_dcov3D,
dL_dsh,
dL_dscale,
dL_drot);
}
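A minimal sketch of the flattened (viewpoint, Gaussian) indexing used by the batched kernels above, for reference while reviewing; `flat_result_index` and `cov_grad_offset` are illustrative helpers and not part of this PR:

```cpp
#include <cstddef>

// Results for viewpoint v and Gaussian p live at flat index v * P + p; the per-view
// buffers (radii, dL_dconics, dL_dcov, ...) are addressed view-major with this index.
inline std::size_t flat_result_index(std::size_t viewpoint_idx, std::size_t point_idx, std::size_t P)
{
    return viewpoint_idx * P + point_idx; // same rule as the batched kernels
}

// Each Gaussian stores 6 unique 3D-covariance gradient entries per view, matching
// dL_dcov[6 * idx + 0..5] in computeCov2DCUDABatched.
inline std::size_t cov_grad_offset(std::size_t viewpoint_idx, std::size_t point_idx, std::size_t P)
{
    return 6 * flat_result_index(viewpoint_idx, point_idx, P);
}
```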