diff --git a/networks/models.py b/networks/models.py index 118ad3b..a0ea431 100644 --- a/networks/models.py +++ b/networks/models.py @@ -260,7 +260,7 @@ def fba_fusion(alpha, img, F, B): F = torch.clamp(F, 0, 1) B = torch.clamp(B, 0, 1) la = 0.1 - alpha = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (torch.sum((F - B) * (F - B), 1) + la) + alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la) alpha = torch.clamp(alpha, 0, 1) return alpha, F, B