diff --git a/torch_topological/nn/sliced_wasserstein_distance.py b/torch_topological/nn/sliced_wasserstein_distance.py index 60741ce..783ef21 100644 --- a/torch_topological/nn/sliced_wasserstein_distance.py +++ b/torch_topological/nn/sliced_wasserstein_distance.py @@ -83,8 +83,8 @@ def forward(self, X, Y): diag = torch.tensor([0.5, 0.5], dtype=torch.float32) # Project both the diagrams onto the diagonals. - D1_diag = torch.vstack([torch.sum(x) * diag for x in D1]) - D2_diag = torch.vstack([torch.sum(x) * diag for x in D2]) + D1_diag = torch.sum(D1, dim=1, keepdim=True) * diag + D2_diag = torch.sum(D2, dim=1, keepdim=True) * diag cost = 0.0