Mistake in the code? #18

Open
Twice22 opened this issue Jul 2, 2020 · 3 comments

Twice22 commented Jul 2, 2020

Hello!

Thank you for releasing your implementation. However, it looks like fba_fusion doesn't do what you intend. Or am I missing something?

Before calling the fba_fusion function, you define alpha, F, and B as follows:

        alpha = torch.clamp(output[:, 0][:, None], 0, 1)

        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        alpha, F, B = fba_fusion(alpha, img, F, B)

So, indexing with [:, None] adds a singleton channel dimension, and alpha has size (B, 1, H, W).
Moreover, F and B each have size (B, 3, H, W).

Now, if we look at how alpha is computed in the fba_fusion function, we have:

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (torch.sum((F - B) * (F - B), 1) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B

So, by the broadcasting rules, we have:

size = ((B, 1, H, W) * scalar + sum((B, 3, H, W), 1)) / (sum((B, 3, H, W), 1) + scalar)
size = ((B, 1, H, W) + (B, H, W)) / (B, H, W)
size = ((B, 1, H, W) + (1, B, H, W)) / (1, B, H, W)
size = (B, B, H, W) / (B, B, H, W)
size = (B, B, H, W)

So, in the end, alpha is of size (B, B, H, W)
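
The shape arithmetic above can be checked without PyTorch. The following is a minimal pure-Python sketch of the usual broadcasting rule (align shapes from the right; a dimension of size 1 stretches). The helpers `reduced_shape` and `broadcast_shape` are hypothetical names introduced here for illustration, and the sizes B=2, H=4, W=5 are arbitrary:

```python
from itertools import zip_longest


def reduced_shape(shape, dim, keepdim=False):
    """Shape after summing over `dim`, mimicking the effect of a keepdim flag."""
    shape = list(shape)
    if keepdim:
        shape[dim] = 1   # the reduced dimension is kept with size 1
    else:
        del shape[dim]   # the reduced dimension is dropped
    return tuple(shape)


def broadcast_shape(a, b):
    """Broadcast two shapes: align from the right, size-1 dims stretch."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(max(x, y))
    return tuple(reversed(out))


B, H, W = 2, 4, 5               # arbitrary example sizes
alpha_shape = (B, 1, H, W)      # alpha after [:, None]
fb_shape = (B, 3, H, W)         # shape of (img - B) * (F - B)

without_keepdim = broadcast_shape(alpha_shape, reduced_shape(fb_shape, 1))
with_keepdim = broadcast_shape(alpha_shape, reduced_shape(fb_shape, 1, keepdim=True))

print(without_keepdim)  # (2, 2, 4, 5), i.e. (B, B, H, W)
print(with_keepdim)     # (2, 1, 4, 5), i.e. (B, 1, H, W)
```

Under these assumptions, dropping the channel dimension produces a (B, B, H, W) result, while keeping it as a singleton preserves the expected (B, 1, H, W).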

Weren't you supposed to add keepdim=True in torch.sum?
Did your final .pth model use this flawed operation?

I hope you can answer my questions.
Thank you

@raphychek

Well, there actually is a keepdim=True in the torch.sum calls. In networks/models.py, the code on line 256 is as follows:

def fba_fusion(alpha, img, F, B):
    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)
    la = 0.1
    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
    alpha = torch.clamp(alpha, 0, 1)
    return alpha, F, B
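
To see why alpha should keep a single channel, it may help to evaluate the alpha update for one pixel. The sketch below is a pure-Python illustration, not the repository's code: `fused_alpha` is a hypothetical helper, and the colour values are made up. The sum over channels acts as a dot product, so each pixel yields one scalar:

```python
def fused_alpha(alpha, img, F, B, la=0.1):
    """Alpha update from fba_fusion, written for a single pixel (RGB tuples)."""
    # Numerator: prior alpha weighted by la, plus the channel-wise dot product
    num = alpha * la + sum((i - b) * (f - b) for i, f, b in zip(img, F, B))
    # Denominator: squared distance between F and B, plus la
    den = sum((f - b) ** 2 for f, b in zip(F, B)) + la
    return min(max(num / den, 0.0), 1.0)  # clamp to [0, 1]


# Illustrative values (assumed): a pixel composited exactly as a*F + (1-a)*B
F = (0.9, 0.2, 0.1)
B = (0.1, 0.3, 0.8)
a = 0.25
img = tuple(a * f + (1 - a) * b for f, b in zip(F, B))

result = fused_alpha(a, img, F, B)
print(result)  # approximately 0.25
```

When the pixel is an exact composite and the prior alpha already matches, the update returns that same alpha up to floating-point error, which is consistent with one alpha value per pixel rather than one per pair of batch elements.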

Twice22 commented Jul 2, 2020

Oh, OK. I hadn't seen this because I was working from the implementation before your last commit.

@MarcoForte
Owner

Hi, thanks for your interest and for taking the time to inform me of this issue. As raphychek pointed out, it has already been corrected; see #7.
