Skip to content

Commit

Permalink
Merge pull request #180 from eigenvivek/pytorch3d-remove
Browse files Browse the repository at this point in the history
Remove dependency on `pytorch3d`
  • Loading branch information
eigenvivek authored Jan 24, 2024
2 parents 6831c49 + 5979636 commit e671d45
Show file tree
Hide file tree
Showing 11 changed files with 4,326 additions and 115 deletions.
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ To install `DiffDRR` from PyPI:
pip install diffdrr
```

`DiffDRR` also requires `PyTorch3D`, which gives us the ability to use multiple parameterizations of SO(3) when constructing camera poses! For most users,
```zsh
conda install pytorch3d -c pytorch3d
```
should work perfectly well. Otherwise, see PyTorch3D's [installation guide](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md).

## Usage

The following minimal example specifies the geometry of the projectional radiograph imaging system and traces rays through a CT volume:
Expand Down
89 changes: 87 additions & 2 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,108 @@
'diffdrr.siddon._get_index': ('api/siddon.html#_get_index', 'diffdrr/siddon.py'),
'diffdrr.siddon._get_voxel': ('api/siddon.html#_get_voxel', 'diffdrr/siddon.py'),
'diffdrr.siddon.siddon_raycast': ('api/siddon.html#siddon_raycast', 'diffdrr/siddon.py')},
'diffdrr.utils': { 'diffdrr.utils._10vec_to_4x4symmetric': ('api/utils.html#_10vec_to_4x4symmetric', 'diffdrr/utils.py'),
'diffdrr.utils': { 'diffdrr.utils.Rotate': ('api/utils.html#rotate', 'diffdrr/utils.py'),
'diffdrr.utils.Rotate.__init__': ('api/utils.html#rotate.__init__', 'diffdrr/utils.py'),
'diffdrr.utils.Rotate._get_matrix_inverse': ( 'api/utils.html#rotate._get_matrix_inverse',
'diffdrr/utils.py'),
'diffdrr.utils.RotateAxisAngle': ('api/utils.html#rotateaxisangle', 'diffdrr/utils.py'),
'diffdrr.utils.RotateAxisAngle.__init__': ('api/utils.html#rotateaxisangle.__init__', 'diffdrr/utils.py'),
'diffdrr.utils.Scale': ('api/utils.html#scale', 'diffdrr/utils.py'),
'diffdrr.utils.Scale.__init__': ('api/utils.html#scale.__init__', 'diffdrr/utils.py'),
'diffdrr.utils.Scale._get_matrix_inverse': ('api/utils.html#scale._get_matrix_inverse', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d': ('api/utils.html#transform3d', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.__getitem__': ('api/utils.html#transform3d.__getitem__', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.__init__': ('api/utils.html#transform3d.__init__', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.__len__': ('api/utils.html#transform3d.__len__', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d._get_matrix_inverse': ( 'api/utils.html#transform3d._get_matrix_inverse',
'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.clone': ('api/utils.html#transform3d.clone', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.compose': ('api/utils.html#transform3d.compose', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.cpu': ('api/utils.html#transform3d.cpu', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.cuda': ('api/utils.html#transform3d.cuda', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.get_matrix': ('api/utils.html#transform3d.get_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.get_se3_log': ('api/utils.html#transform3d.get_se3_log', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.inverse': ('api/utils.html#transform3d.inverse', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.rotate': ('api/utils.html#transform3d.rotate', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.rotate_axis_angle': ( 'api/utils.html#transform3d.rotate_axis_angle',
'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.scale': ('api/utils.html#transform3d.scale', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.stack': ('api/utils.html#transform3d.stack', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.to': ('api/utils.html#transform3d.to', 'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.transform_normals': ( 'api/utils.html#transform3d.transform_normals',
'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.transform_points': ( 'api/utils.html#transform3d.transform_points',
'diffdrr/utils.py'),
'diffdrr.utils.Transform3d.translate': ('api/utils.html#transform3d.translate', 'diffdrr/utils.py'),
'diffdrr.utils.Translate': ('api/utils.html#translate', 'diffdrr/utils.py'),
'diffdrr.utils.Translate.__init__': ('api/utils.html#translate.__init__', 'diffdrr/utils.py'),
'diffdrr.utils.Translate._get_matrix_inverse': ( 'api/utils.html#translate._get_matrix_inverse',
'diffdrr/utils.py'),
'diffdrr.utils._10vec_to_4x4symmetric': ('api/utils.html#_10vec_to_4x4symmetric', 'diffdrr/utils.py'),
'diffdrr.utils._acos_linear_approximation': ( 'api/utils.html#_acos_linear_approximation',
'diffdrr/utils.py'),
'diffdrr.utils._angle_from_tan': ('api/utils.html#_angle_from_tan', 'diffdrr/utils.py'),
'diffdrr.utils._axis_angle_rotation': ('api/utils.html#_axis_angle_rotation', 'diffdrr/utils.py'),
'diffdrr.utils._broadcast_bmm': ('api/utils.html#_broadcast_bmm', 'diffdrr/utils.py'),
'diffdrr.utils._check_valid_rotation_matrix': ( 'api/utils.html#_check_valid_rotation_matrix',
'diffdrr/utils.py'),
'diffdrr.utils._convert_from_rotation_matrix': ( 'api/utils.html#_convert_from_rotation_matrix',
'diffdrr/utils.py'),
'diffdrr.utils._convert_to_rotation_matrix': ( 'api/utils.html#_convert_to_rotation_matrix',
'diffdrr/utils.py'),
'diffdrr.utils._copysign': ('api/utils.html#_copysign', 'diffdrr/utils.py'),
'diffdrr.utils._dacos_dx': ('api/utils.html#_dacos_dx', 'diffdrr/utils.py'),
'diffdrr.utils._get_se3_V_input': ('api/utils.html#_get_se3_v_input', 'diffdrr/utils.py'),
'diffdrr.utils._handle_angle_input': ('api/utils.html#_handle_angle_input', 'diffdrr/utils.py'),
'diffdrr.utils._handle_coord': ('api/utils.html#_handle_coord', 'diffdrr/utils.py'),
'diffdrr.utils._handle_input': ('api/utils.html#_handle_input', 'diffdrr/utils.py'),
'diffdrr.utils._index_from_letter': ('api/utils.html#_index_from_letter', 'diffdrr/utils.py'),
'diffdrr.utils._safe_det_3x3': ('api/utils.html#_safe_det_3x3', 'diffdrr/utils.py'),
'diffdrr.utils._se3_V_matrix': ('api/utils.html#_se3_v_matrix', 'diffdrr/utils.py'),
'diffdrr.utils._so3_exp_map': ('api/utils.html#_so3_exp_map', 'diffdrr/utils.py'),
'diffdrr.utils._sqrt_positive_part': ('api/utils.html#_sqrt_positive_part', 'diffdrr/utils.py'),
'diffdrr.utils.acos_linear_extrapolation': ('api/utils.html#acos_linear_extrapolation', 'diffdrr/utils.py'),
'diffdrr.utils.axis_angle_to_matrix': ('api/utils.html#axis_angle_to_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.axis_angle_to_quaternion': ('api/utils.html#axis_angle_to_quaternion', 'diffdrr/utils.py'),
'diffdrr.utils.convert': ('api/utils.html#convert', 'diffdrr/utils.py'),
'diffdrr.utils.euler_angles_to_matrix': ('api/utils.html#euler_angles_to_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.get_device': ('api/utils.html#get_device', 'diffdrr/utils.py'),
'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
'diffdrr.utils.hat': ('api/utils.html#hat', 'diffdrr/utils.py'),
'diffdrr.utils.hat_inv': ('api/utils.html#hat_inv', 'diffdrr/utils.py'),
'diffdrr.utils.make_device': ('api/utils.html#make_device', 'diffdrr/utils.py'),
'diffdrr.utils.matrix_to_axis_angle': ('api/utils.html#matrix_to_axis_angle', 'diffdrr/utils.py'),
'diffdrr.utils.matrix_to_euler_angles': ('api/utils.html#matrix_to_euler_angles', 'diffdrr/utils.py'),
'diffdrr.utils.matrix_to_quaternion': ('api/utils.html#matrix_to_quaternion', 'diffdrr/utils.py'),
'diffdrr.utils.matrix_to_rotation_6d': ('api/utils.html#matrix_to_rotation_6d', 'diffdrr/utils.py'),
'diffdrr.utils.parse_intrinsic_matrix': ('api/utils.html#parse_intrinsic_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_adjugate_to_quaternion': ( 'api/utils.html#quaternion_adjugate_to_quaternion',
'diffdrr/utils.py'),
'diffdrr.utils.quaternion_apply': ('api/utils.html#quaternion_apply', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_invert': ('api/utils.html#quaternion_invert', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_multiply': ('api/utils.html#quaternion_multiply', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_raw_multiply': ('api/utils.html#quaternion_raw_multiply', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_to_axis_angle': ('api/utils.html#quaternion_to_axis_angle', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_to_matrix': ('api/utils.html#quaternion_to_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.quaternion_to_quaternion_adjugate': ( 'api/utils.html#quaternion_to_quaternion_adjugate',
'diffdrr/utils.py'),
'diffdrr.utils.quaternion_to_rotation_10d': ( 'api/utils.html#quaternion_to_rotation_10d',
'diffdrr/utils.py'),
'diffdrr.utils.random_quaternions': ('api/utils.html#random_quaternions', 'diffdrr/utils.py'),
'diffdrr.utils.random_rotation': ('api/utils.html#random_rotation', 'diffdrr/utils.py'),
'diffdrr.utils.random_rotations': ('api/utils.html#random_rotations', 'diffdrr/utils.py'),
'diffdrr.utils.rotation_10d_to_quaternion': ( 'api/utils.html#rotation_10d_to_quaternion',
'diffdrr/utils.py')},
'diffdrr/utils.py'),
'diffdrr.utils.rotation_6d_to_matrix': ('api/utils.html#rotation_6d_to_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.se3_exp_map': ('api/utils.html#se3_exp_map', 'diffdrr/utils.py'),
'diffdrr.utils.se3_log_map': ('api/utils.html#se3_log_map', 'diffdrr/utils.py'),
'diffdrr.utils.so3_exp_map': ('api/utils.html#so3_exp_map', 'diffdrr/utils.py'),
'diffdrr.utils.so3_exponential_map': ('api/utils.html#so3_exponential_map', 'diffdrr/utils.py'),
'diffdrr.utils.so3_log_map': ('api/utils.html#so3_log_map', 'diffdrr/utils.py'),
'diffdrr.utils.so3_relative_angle': ('api/utils.html#so3_relative_angle', 'diffdrr/utils.py'),
'diffdrr.utils.so3_rotation_angle': ('api/utils.html#so3_rotation_angle', 'diffdrr/utils.py'),
'diffdrr.utils.standardize_quaternion': ('api/utils.html#standardize_quaternion', 'diffdrr/utils.py')},
'diffdrr.visualization': { 'diffdrr.visualization._make_camera_frustum_mesh': ( 'api/visualization.html#_make_camera_frustum_mesh',
'diffdrr/visualization.py'),
'diffdrr.visualization.animate': ('api/visualization.html#animate', 'diffdrr/visualization.py'),
Expand Down
23 changes: 6 additions & 17 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@
# %% auto 0
__all__ = ['Detector', 'diffdrr_to_deepdrr']

# %% ../notebooks/api/02_detector.ipynb 4
try:
import pytorch3d
except ModuleNotFoundError:
install = "https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md"
raise ModuleNotFoundError(
f"PyTorch3D is not installed, which is required to parameterize camera poses. See installation instructions here: {install}"
)

# %% ../notebooks/api/02_detector.ipynb 6
# %% ../notebooks/api/02_detector.ipynb 5
class Detector(torch.nn.Module):
"""Construct a 6 DoF X-ray detector system. This model is based on a C-Arm."""

Expand Down Expand Up @@ -53,7 +44,7 @@ def __init__(
self.register_buffer("source", source)
self.register_buffer("target", target)

# %% ../notebooks/api/02_detector.ipynb 7
# %% ../notebooks/api/02_detector.ipynb 6
@patch
def _initialize_carm(self: Detector):
"""Initialize the default position for the source and detector plane."""
Expand Down Expand Up @@ -100,10 +91,8 @@ def _initialize_carm(self: Detector):
self.subsamples.append(sample.tolist())
return source, target

# %% ../notebooks/api/02_detector.ipynb 8
from pytorch3d.transforms import Transform3d

from .utils import convert
# %% ../notebooks/api/02_detector.ipynb 7
from .utils import Transform3d, convert


@patch
Expand All @@ -128,13 +117,13 @@ def forward(
source, target = make_xrays(t, self.source, self.target)
return source, target

# %% ../notebooks/api/02_detector.ipynb 9
# %% ../notebooks/api/02_detector.ipynb 8
def make_xrays(t: Transform3d, source: torch.Tensor, target: torch.Tensor):
source = t.transform_points(source)
target = t.transform_points(target)
return source, target

# %% ../notebooks/api/02_detector.ipynb 10
# %% ../notebooks/api/02_detector.ipynb 9
def diffdrr_to_deepdrr(euler_angles):
alpha, beta, gamma = euler_angles.unbind(-1)
return torch.stack([beta, alpha, gamma], dim=1)
3 changes: 1 addition & 2 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ def reshape_subsampled_drr(
return drr

# %% ../notebooks/api/00_drr.ipynb 10
from pytorch3d.transforms import Transform3d

from .detector import make_xrays
from .utils import Transform3d


@patch
Expand Down
Loading

0 comments on commit e671d45

Please # to comment.