Skip to content

Commit

Permalink
pinned POT in pyproject; reintroducing linesearch test with the funct…
Browse files Browse the repository at this point in the history
…ion in ot library
  • Loading branch information
anushka255 committed Nov 8, 2024
1 parent bf1b013 commit 560c08d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
dependencies = [
"anndata",
"scanpy",
"POT",
"POT>=0.9.5",
"numpy<2",
"scipy",
"scikit-learn",
Expand Down
24 changes: 24 additions & 0 deletions tests/test_paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,30 @@ def test_fused_gromov_wasserstein(spot_distance_matrix):
assert_checksum_equals(temp_dir, "fused_gromov_wasserstein.csv")


def test_gromov_linesearch(spot_distance_matrix):
nx = ot.backend.TorchBackend()

G = 1.509115054931788e-05 * torch.ones((251, 264)).double()
deltaG = torch.Tensor(
np.genfromtxt(input_dir / "deltaG.csv", delimiter=",")
).double()
costG = 6.0935270338235075

alpha, fc, cost_G = ot.gromov.solve_gromov_linesearch(
G=G,
deltaG=deltaG,
cost_G=costG,
C1=spot_distance_matrix[1],
C2=spot_distance_matrix[2],
M=0.0,
reg=2 * 1.0,
nx=nx,
)
assert alpha == 1.0
assert fc == 1
assert pytest.approx(cost_G) == -11.20545


def test_line_search_partial(spot_distance_matrix):
G = 1.509115054931788e-05 * torch.ones((251, 264)).double()
deltaG = torch.Tensor(
Expand Down

0 comments on commit 560c08d

Please # to comment.