
Gradient penalty "interpolates" term #176

Open
rl-max opened this issue Nov 3, 2022 · 1 comment

rl-max commented Nov 3, 2022

Hi,

This is the 'cal_grad_penalty' function in /src/utils/losses.py:

def cal_grad_penalty(real_images, real_labels, fake_images, discriminator, device):
    batch_size, c, h, w = real_images.shape
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_images.nelement() // batch_size).contiguous().view(batch_size, c, h, w)
    alpha = alpha.to(device)

    real_images = real_images.to(device)
    interpolates = alpha * real_images + ((1 - alpha) * fake_images)
    interpolates = interpolates.to(device)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    fake_dict = discriminator(interpolates, real_labels, eval=False)
    grads = cal_deriv(inputs=interpolates, outputs=fake_dict["adv_output"], device=device)
    grads = grads.view(grads.size(0), -1)

    grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:,0,0,0].mean()*0
    return grad_penalty
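
For context, cal_deriv is not shown above; as far as I can tell it is just a thin wrapper around torch.autograd.grad, roughly like the sketch below (the exact helper in losses.py may differ, and cal_deriv_sketch is only a placeholder name):

import torch

def cal_deriv_sketch(inputs, outputs, device):
    # d(outputs)/d(inputs), keeping the graph so the penalty itself stays differentiable
    grads = torch.autograd.grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=torch.ones_like(outputs, device=device),
        create_graph=True,
        retain_graph=True,
    )[0]
    return grads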

In the last line of cal_grad_penalty, grad_penalty = ((grads.norm(2, dim=1) - 1)**2).mean() + interpolates[:,0,0,0].mean()*0, I'd like to know what the additive term + interpolates[:,0,0,0].mean()*0 means. Since it is multiplied by zero, I think it has no actual effect on the computation.

I'll be waiting for your answer.

Thank you!

@mingukkang
Collaborator

The term has no effect on the computed loss; it was introduced to work around a PyTorch bug that arises during DDP training with R1 regularization.
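
To illustrate the general pattern (a simplified sketch with a toy model, not the actual StudioGAN code): DDP can complain when some model outputs never reach the loss (e.g., the "Expected to have finished reduction in the prior iteration" error), and a zero-multiplied term keeps those tensors in the autograd graph without changing the loss value.

import torch
import torch.nn as nn

class ToyDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 8)
        self.adv_head = nn.Linear(8, 1)  # used in the loss below
        self.aux_head = nn.Linear(8, 4)  # not used in the loss below

    def forward(self, x):
        h = self.features(x)
        return {"adv_output": self.adv_head(h), "aux_output": self.aux_head(h)}

def loss_with_ddp_workaround(out):
    loss = out["adv_output"].mean()
    # Zero-multiplied term: numerically a no-op, but it ties aux_output (and the
    # parameters that produced it) into the graph so DDP's gradient reducer is satisfied.
    loss = loss + 0.0 * out["aux_output"].sum()
    return loss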
