diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 83861a3..f7162a6 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -7,6 +7,7 @@ stack_slices_center, generalized_procrustes_analysis, ) +from paste2.projection import partial_stack_slices_pairwise, partial_procrustes_analysis from pandas.testing import assert_frame_equal test_dir = Path(__file__).parent @@ -112,11 +113,11 @@ def test_generalized_procrustes_analysis(slices): ) assert_frame_equal( - pd.DataFrame(aligned_center, columns=['0', '1']), + pd.DataFrame(aligned_center, columns=["0", "1"]), pd.read_csv(output_dir / "aligned_center.csv"), ) assert_frame_equal( - pd.DataFrame(aligned_slice, columns=['0', '1']), + pd.DataFrame(aligned_slice, columns=["0", "1"]), pd.read_csv(output_dir / "aligned_slice.csv"), ) expected_theta = 0.0 @@ -144,3 +145,43 @@ def test_generalized_procrustes_analysis(slices): equal_nan=True, ) ) + + +def test_partial_stack_slices_pairwise(slices): + n_slices = len(slices) + + pairwise_info = [ + np.genfromtxt(input_dir / f"slices_{i}_{i + 1}_pairwise.csv", delimiter=",") + for i in range(1, n_slices) + ] + + new_slices = partial_stack_slices_pairwise(slices, pairwise_info) + + for i, slice in enumerate(new_slices, start=1): + assert_frame_equal( + pd.DataFrame(slice.obsm["spatial"], columns=["0", "1"]), + pd.read_csv(output_dir / f"aligned_spatial_{i}_{i + 1}.csv"), + ) + + +def test_partial_procrustes_analysis(slices): + center_slice = sc.read_h5ad(input_dir / "center_slice.h5ad") + + pairwise_info = np.genfromtxt( + input_dir / "center_slice1_pairwise.csv", delimiter="," + ) + + aligned_center, aligned_slice = partial_procrustes_analysis( + center_slice.obsm["spatial"], + slices[0].obsm["spatial"], + pairwise_info, + ) + + assert_frame_equal( + pd.DataFrame(aligned_center, columns=["0", "1"]), + pd.read_csv(output_dir / "aligned_center.csv"), + ) + assert_frame_equal( + pd.DataFrame(aligned_slice, columns=["0", "1"]), + pd.read_csv(output_dir / "aligned_slice.csv"), + )