diff --git a/diffdrr/detector.py b/diffdrr/detector.py index 59b4f796b..ad92d9587 100644 --- a/diffdrr/detector.py +++ b/diffdrr/detector.py @@ -57,22 +57,29 @@ def __init__( @patch def _initialize_carm(self: Detector): """Initialize the default position for the source and detector plane.""" - # Initialize the source on the x-axis - source = torch.tensor([[self.sdr, 0.0, 0.0]]) + try: + device = self.sdr.device + except AttributeError: + device = torch.device("cpu") - # Initialize the center of the detector plane on the negative x-axis - center = torch.tensor([[-self.sdr, 0.0, 0.0]]) + # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis + source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr + center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr # Use the standard basis for the detector plane - basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device) # Construct the detector plane with different offsets for even or odd heights h_off = 1.0 if self.height % 2 else 0.5 w_off = 1.0 if self.width % 2 else 0.5 # Construct equally spaced points along the basis vectors - t = (torch.arange(-self.height // 2, self.height // 2) + h_off) * self.delx - s = (torch.arange(-self.width // 2, self.width // 2) + w_off) * self.dely + t = ( + torch.arange(-self.height // 2, self.height // 2, device=device) + h_off + ) * self.delx + s = ( + torch.arange(-self.width // 2, self.width // 2, device=device) + w_off + ) * self.dely if self.reverse_x_axis: s = -s coefs = torch.cartesian_prod(t, s).reshape(-1, 2) diff --git a/notebooks/api/02_detector.ipynb b/notebooks/api/02_detector.ipynb index b8c8a1463..ed6b966c2 100644 --- a/notebooks/api/02_detector.ipynb +++ b/notebooks/api/02_detector.ipynb @@ -128,22 +128,25 @@ "@patch\n", "def _initialize_carm(self: Detector):\n", " \"\"\"Initialize the default position for the source and detector plane.\"\"\"\n", - " # Initialize the source on the x-axis\n", - " source = torch.tensor([[self.sdr, 0.0, 0.0]])\n", - "\n", - " # Initialize the center of the detector plane on the negative x-axis\n", - " center = torch.tensor([[-self.sdr, 0.0, 0.0]])\n", + " try:\n", + " device = self.sdr.device\n", + " except AttributeError:\n", + " device = torch.device(\"cpu\")\n", + " \n", + " # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis\n", + " source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr\n", + " center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr\n", "\n", " # Use the standard basis for the detector plane\n", - " basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n", + " basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)\n", "\n", " # Construct the detector plane with different offsets for even or odd heights\n", " h_off = 1.0 if self.height % 2 else 0.5\n", " w_off = 1.0 if self.width % 2 else 0.5\n", "\n", " # Construct equally spaced points along the basis vectors\n", - " t = (torch.arange(-self.height // 2, self.height // 2) + h_off) * self.delx\n", - " s = (torch.arange(-self.width // 2, self.width // 2) + w_off) * self.dely\n", + " t = (torch.arange(-self.height // 2, self.height // 2, device=device) + h_off) * self.delx\n", + " s = (torch.arange(-self.width // 2, self.width // 2, device=device) + w_off) * self.dely\n", " if self.reverse_x_axis:\n", " s = -s\n", " coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",