diff --git a/README.md b/README.md index e1a23e8..f92594d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ -# pytorchse3 +pytorchse3 +================ diff --git a/notebooks/01_se3.ipynb b/notebooks/01_se3.ipynb index b118e9a..5d97dfa 100644 --- a/notebooks/01_se3.ipynb +++ b/notebooks/01_se3.ipynb @@ -109,7 +109,7 @@ " B = taylor_B(theta)\n", " D = (1 - A / (2 * B)) / theta.pow(2)\n", " \n", - " V_inv = torch.eye(3) - 0.5 * log_R + D * log_R_2\n", + " V_inv = torch.eye(3, dtype=A.dtype, device=A.device) - 0.5 * log_R + D * log_R_2\n", " log_t_vee = torch.einsum(\"bij, bj -> bi\", V_inv, t)\n", "\n", " return torch.concat([log_R_vee, log_t_vee], dim=-1)" diff --git a/pytorchse3/se3.py b/pytorchse3/se3.py index f19250a..4cb4178 100644 --- a/pytorchse3/se3.py +++ b/pytorchse3/se3.py @@ -58,7 +58,7 @@ def se3_log_map(T: torch.Tensor): B = taylor_B(theta) D = (1 - A / (2 * B)) / theta.pow(2) - V_inv = torch.eye(3) - 0.5 * log_R + D * log_R_2 + V_inv = torch.eye(3, dtype=A.dtype, device=A.device) - 0.5 * log_R + D * log_R_2 log_t_vee = torch.einsum("bij, bj -> bi", V_inv, t) return torch.concat([log_R_vee, log_t_vee], dim=-1)