Skip to content

Commit

Permalink
fixed the bug in generalized procrustes analysis and added tolerance …
Browse files Browse the repository at this point in the history
…in the test case
  • Loading branch information
anushka255 committed Nov 15, 2024
1 parent 3b1787b commit 5493ffa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 10 additions & 7 deletions src/paste3/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,16 @@ def generalized_procrustes_analysis(
assert target_coordinates.shape[1] == 2

weighted_source = pi.sum(axis=1).matmul(source_coordinates)
weighted_targed = pi.sum(axis=0).matmul(target_coordinates)
source_coordinates = source_coordinates - weighted_source
target_coordinates = target_coordinates - weighted_targed
weighted_target = pi.sum(axis=0).matmul(target_coordinates)

if is_partial:
m = torch.sum(pi)
source_coordinates = source_coordinates * (1.0 / m)
target_coordinates = target_coordinates * (1.0 / m)
weighted_source = weighted_source * (1.0 / m)
weighted_target = weighted_target * (1.0 / m)

source_coordinates = source_coordinates - weighted_source
target_coordinates = target_coordinates - weighted_target

covariance_matrix = target_coordinates.T.matmul(pi.T.matmul(source_coordinates))
U, S, Vt = torch.linalg.svd(covariance_matrix, full_matrices=True)
rotation_matrix = Vt.T.matmul(U.T)
Expand All @@ -275,14 +278,14 @@ def generalized_procrustes_analysis(
target_coordinates,
rotation_angle,
weighted_source,
weighted_targed,
weighted_target,
)
if return_params and return_as_matrix:
return (
source_coordinates,
target_coordinates,
rotation_matrix,
weighted_source,
weighted_targed,
weighted_target,
)
return source_coordinates, target_coordinates
4 changes: 2 additions & 2 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,5 @@ def test_partial_procrustes_analysis(slices2):
torch.Tensor(data["pi"]).double(),
is_partial=True,
)
assert np.allclose(x_aligned, data["x_aligned"])
assert np.allclose(y_aligned, data["y_aligned"])
assert np.allclose(x_aligned.cpu().numpy(), data["x_aligned"])
assert np.allclose(y_aligned.cpu().numpy(), data["y_aligned"], atol=1e-06)

0 comments on commit 5493ffa

Please # to comment.