Skip to content

Commit

Permalink
Removed for loop in sliced wasserstein distance
Browse files Browse the repository at this point in the history
  • Loading branch information
Clancy97 authored and Pseudomanifold committed Jan 30, 2024
1 parent 5426f42 commit aba1474
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch_topological/nn/sliced_wasserstein_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def forward(self, X, Y):
diag = torch.tensor([0.5, 0.5], dtype=torch.float32)

# Project both the diagrams onto the diagonals.
D1_diag = torch.vstack([torch.sum(x) * diag for x in D1])
D2_diag = torch.vstack([torch.sum(x) * diag for x in D2])
D1_diag = torch.sum(D1, dim=1, keepdim=True) * diag
D2_diag = torch.sum(D2, dim=1, keepdim=True) * diag

cost = 0.0

Expand Down

0 comments on commit aba1474

Please # to comment.