A bug in the code, influencing the training when batch size > 1 #7

Closed

xymsh opened this issue Apr 12, 2020 · 1 comment
xymsh commented Apr 12, 2020

https://github.com/MarcoForte/FBA-Matting/blob/76751dd752d4d3b40bf58c64185fc77c0195cbeb/networks/models.py#L263

Hi, I think you forgot to set the "keepdim" parameter to True in the "torch.sum()" operations. The corrected line should be:
alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)

Without keepdim=True, the output alpha ends up with shape [batch, batch, height, width] because of incorrect broadcasting, while it is supposed to be [batch, 1, height, width].
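
Here is a minimal shape check (not the repo code; the tensors are random stand-ins with the shapes the network produces, and la is a placeholder constant):

import torch

batch, height, width = 4, 8, 8
img = torch.rand(batch, 3, height, width)    # input image
F = torch.rand(batch, 3, height, width)      # foreground prediction
B = torch.rand(batch, 3, height, width)      # background prediction
alpha = torch.rand(batch, 1, height, width)  # alpha prediction
la = 0.1                                     # placeholder regularisation constant

# Without keepdim the two sums have shape [batch, height, width];
# broadcasting them against alpha ([batch, 1, height, width]) gives [batch, batch, height, width].
bad = (alpha * la + torch.sum((img - B) * (F - B), 1)) / (torch.sum((F - B) * (F - B), 1) + la)
print(bad.shape)   # torch.Size([4, 4, 8, 8])

# With keepdim=True the sums stay [batch, 1, height, width], and so does the result.
good = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
print(good.shape)  # torch.Size([4, 1, 8, 8])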

Apparently, when batch size > 1, the shape of the alpha prediction is not correct. The loss calculation is therefore negatively affected, because the alpha, fg, and bg predictions are concatenated together in the last step of the forward pass.
https://github.com/MarcoForte/FBA-Matting/blob/76751dd752d4d3b40bf58c64185fc77c0195cbeb/networks/models.py#L361

For example, with batch size 4, this bug gives an alpha prediction of size [4, 4, height, width]. The fg and bg predictions are both of size [4, 3, height, width]. After concatenation, we get an output of size [4, 10, height, width] instead of the expected [4, 7, height, width].

When calculating the loss, we slice the output by channel indices to extract the alpha, fg, and bg predictions.

alpha_pred = output[:, 0:1, :, :]
fg_pred = output[:, 1:4, :, :]
bg_pred = output[:, 4:7, :, :]

Here fg_pred is actually still part of the alpha prediction, because the first 4 channels of the output are alpha instead of only 1; bg_pred likewise picks up the wrong channels. The loss computed for the fg and bg predictions is therefore meaningless.
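
A quick sketch of the mismatch, using random tensors as stand-ins for the three predictions:

import torch

batch, height, width = 4, 8, 8
alpha_buggy = torch.rand(batch, batch, height, width)  # [4, 4, H, W] produced by the broadcast bug
fg = torch.rand(batch, 3, height, width)
bg = torch.rand(batch, 3, height, width)

output = torch.cat((alpha_buggy, fg, bg), 1)
print(output.shape)  # torch.Size([4, 10, 8, 8]) instead of the expected [4, 7, 8, 8]

# The fixed slice indices now pick out the wrong channels:
alpha_pred = output[:, 0:1, :, :]  # only the first of the 4 alpha channels
fg_pred = output[:, 1:4, :, :]     # channels 1-3: still alpha, not the foreground
bg_pred = output[:, 4:7, :, :]     # channels 4-6: foreground channels, not the background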

My experimental results for models trained with batch size > 1 confirmed this bug: the errors are extremely high. I'm wondering if this bug had a negative influence on your experiments.

MarcoForte (Owner) commented
Hi, thanks for pointing this out. I will update the code to use your implementation.
I did not use the FBA fusion for my models with batch size above one, and I also did not use it during training, so my way of coding it should not have influenced the results negatively.

MarcoForte added a commit that referenced this issue Apr 12, 2020