diff --git a/pytorch3d/ops/points_to_volumes.py b/pytorch3d/ops/points_to_volumes.py index 6b9efc994..59f43fe1b 100644 --- a/pytorch3d/ops/points_to_volumes.py +++ b/pytorch3d/ops/points_to_volumes.py @@ -181,8 +181,11 @@ def add_points_features_to_volume_densities_features( # init the volumetric grid sizes if uninitialized if grid_sizes is None: - grid_sizes = torch.LongTensor(list(volume_densities.shape[2:])).to( - volume_densities + # grid sizes shape (minibatch, 3) + grid_sizes = ( + torch.LongTensor(list(volume_densities.shape[2:])) + .to(volume_densities) + .expand(volume_densities.shape[0], 3) ) # flatten densities and features