diff --git a/src/paste3/visualization.py b/src/paste3/visualization.py index 8382ea5..b02308b 100644 --- a/src/paste3/visualization.py +++ b/src/paste3/visualization.py @@ -254,13 +254,16 @@ def generalized_procrustes_analysis( assert target_coordinates.shape[1] == 2 weighted_source = pi.sum(axis=1).matmul(source_coordinates) - weighted_targed = pi.sum(axis=0).matmul(target_coordinates) - source_coordinates = source_coordinates - weighted_source - target_coordinates = target_coordinates - weighted_targed + weighted_target = pi.sum(axis=0).matmul(target_coordinates) + if is_partial: m = torch.sum(pi) - source_coordinates = source_coordinates * (1.0 / m) - target_coordinates = target_coordinates * (1.0 / m) + weighted_source = weighted_source * (1.0 / m) + weighted_target = weighted_target * (1.0 / m) + + source_coordinates = source_coordinates - weighted_source + target_coordinates = target_coordinates - weighted_target + covariance_matrix = target_coordinates.T.matmul(pi.T.matmul(source_coordinates)) U, S, Vt = torch.linalg.svd(covariance_matrix, full_matrices=True) rotation_matrix = Vt.T.matmul(U.T) @@ -275,7 +278,7 @@ def generalized_procrustes_analysis( target_coordinates, rotation_angle, weighted_source, - weighted_targed, + weighted_target, ) if return_params and return_as_matrix: return ( @@ -283,6 +286,6 @@ def generalized_procrustes_analysis( target_coordinates, rotation_matrix, weighted_source, - weighted_targed, + weighted_target, ) return source_coordinates, target_coordinates diff --git a/tests/data/output/partial_procrustes_analysis.npz b/tests/data/output/partial_procrustes_analysis.npz new file mode 100644 index 0000000..bb49ef4 Binary files /dev/null and b/tests/data/output/partial_procrustes_analysis.npz differ diff --git a/tests/test_visualization.py b/tests/test_visualization.py index a10fe6c..6894e61 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -165,27 +165,16 @@ def test_partial_stack_slices_pairwise(slices): ) -def test_partial_procrustes_analysis(slices): - center_slice = sc.read_h5ad(input_dir / "center_slice.h5ad") +def test_partial_procrustes_analysis(slices2): + data = np.load(output_dir / "partial_procrustes_analysis.npz") - pairwise_info = torch.Tensor( - np.genfromtxt(input_dir / "center_slice1_pairwise.csv", delimiter=",") - ).double() + assert torch.sum(torch.Tensor(data["pi"])) < 0.99999999 - aligned_center, aligned_slice = generalized_procrustes_analysis( - torch.Tensor(center_slice.obsm["spatial"]).double(), - torch.Tensor(slices[0].obsm["spatial"]).double(), - pairwise_info, + x_aligned, y_aligned = generalized_procrustes_analysis( + torch.Tensor(slices2[0].obsm["spatial"]).double(), + torch.Tensor(slices2[1].obsm["spatial"]).double(), + torch.Tensor(data["pi"]).double(), is_partial=True, ) - - assert_frame_equal( - pd.DataFrame(aligned_center, columns=["0", "1"]), - pd.read_csv(output_dir / "aligned_center.csv"), - atol=1e-6, - ) - assert_frame_equal( - pd.DataFrame(aligned_slice, columns=["0", "1"]), - pd.read_csv(output_dir / "aligned_slice.csv"), - atol=1e-6, - ) + assert np.allclose(x_aligned.cpu().numpy(), data["x_aligned"]) + assert np.allclose(y_aligned.cpu().numpy(), data["y_aligned"], atol=1e-06)