Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fixed the bug in generalized procrustes analysis #92

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/paste3/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,16 @@ def generalized_procrustes_analysis(
assert target_coordinates.shape[1] == 2

weighted_source = pi.sum(axis=1).matmul(source_coordinates)
weighted_targed = pi.sum(axis=0).matmul(target_coordinates)
source_coordinates = source_coordinates - weighted_source
target_coordinates = target_coordinates - weighted_targed
weighted_target = pi.sum(axis=0).matmul(target_coordinates)

if is_partial:
m = torch.sum(pi)
source_coordinates = source_coordinates * (1.0 / m)
target_coordinates = target_coordinates * (1.0 / m)
weighted_source = weighted_source * (1.0 / m)
weighted_target = weighted_target * (1.0 / m)

source_coordinates = source_coordinates - weighted_source
target_coordinates = target_coordinates - weighted_target

covariance_matrix = target_coordinates.T.matmul(pi.T.matmul(source_coordinates))
U, S, Vt = torch.linalg.svd(covariance_matrix, full_matrices=True)
rotation_matrix = Vt.T.matmul(U.T)
Expand All @@ -275,14 +278,14 @@ def generalized_procrustes_analysis(
target_coordinates,
rotation_angle,
weighted_source,
weighted_targed,
weighted_target,
)
if return_params and return_as_matrix:
return (
source_coordinates,
target_coordinates,
rotation_matrix,
weighted_source,
weighted_targed,
weighted_target,
)
return source_coordinates, target_coordinates
Binary file added tests/data/output/partial_procrustes_analysis.npz
Binary file not shown.
29 changes: 9 additions & 20 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,27 +165,16 @@ def test_partial_stack_slices_pairwise(slices):
)


def test_partial_procrustes_analysis(slices):
center_slice = sc.read_h5ad(input_dir / "center_slice.h5ad")
def test_partial_procrustes_analysis(slices2):
data = np.load(output_dir / "partial_procrustes_analysis.npz")

pairwise_info = torch.Tensor(
np.genfromtxt(input_dir / "center_slice1_pairwise.csv", delimiter=",")
).double()
assert torch.sum(torch.Tensor(data["pi"])) < 0.99999999

aligned_center, aligned_slice = generalized_procrustes_analysis(
torch.Tensor(center_slice.obsm["spatial"]).double(),
torch.Tensor(slices[0].obsm["spatial"]).double(),
pairwise_info,
x_aligned, y_aligned = generalized_procrustes_analysis(
torch.Tensor(slices2[0].obsm["spatial"]).double(),
torch.Tensor(slices2[1].obsm["spatial"]).double(),
torch.Tensor(data["pi"]).double(),
is_partial=True,
)

assert_frame_equal(
pd.DataFrame(aligned_center, columns=["0", "1"]),
pd.read_csv(output_dir / "aligned_center.csv"),
atol=1e-6,
)
assert_frame_equal(
pd.DataFrame(aligned_slice, columns=["0", "1"]),
pd.read_csv(output_dir / "aligned_slice.csv"),
atol=1e-6,
)
assert np.allclose(x_aligned.cpu().numpy(), data["x_aligned"])
assert np.allclose(y_aligned.cpu().numpy(), data["y_aligned"], atol=1e-06)
vineetbansal marked this conversation as resolved.
Show resolved Hide resolved
Loading