losses.py
import torch
import torch.nn.functional as F
import torchvision
from typing import Union

class VanillaDiscriminatorLoss(torch.nn.Module):
    """
    Implements the vanilla (minimax) GAN discriminator loss on raw logits:
    softplus(-real_pred) + softplus(fake_pred) = -log(sigmoid(real_pred)) - log(1 - sigmoid(fake_pred)).
    """
    def __init__(self):
        super().__init__()

    def forward(self, real_pred : torch.Tensor, fake_pred : torch.Tensor) -> torch.Tensor:
        return (F.softplus(-real_pred) + F.softplus(fake_pred)).mean()

class VanillaGeneratorLossNS(torch.nn.Module):
    """
    Implements the non-saturating vanilla GAN generator loss.
    """
    def __init__(self):
        super().__init__()

    def forward(self, fake_pred : torch.Tensor) -> torch.Tensor:
        return F.softplus(-fake_pred).mean()
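
# Usage sketch (not part of the original file): a minimal alternating update showing how the two vanilla
# losses above are typically applied. `generator`, `discriminator`, `g_opt`, `d_opt`, `real` and `noise`
# are assumed user-provided objects; the discriminator is assumed to return raw (pre-sigmoid) logits,
# which is what the softplus formulation expects.
def vanilla_gan_step_example(generator, discriminator, g_opt, d_opt, real : torch.Tensor, noise : torch.Tensor):
    d_loss_fn = VanillaDiscriminatorLoss()
    g_loss_fn = VanillaGeneratorLossNS()
    # Discriminator update: real logits vs. logits of detached fake samples.
    fake = generator(noise).detach()
    d_loss = d_loss_fn(discriminator(real), discriminator(fake))
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()
    # Generator update: non-saturating loss on freshly generated samples.
    g_loss = g_loss_fn(discriminator(generator(noise)))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    return d_loss.detach(), g_loss.detach()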

class WGANDiscriminatorLoss(torch.nn.Module):
    """
    Implements the WGAN critic/discriminator loss; minimizing it maximizes the critic's approximation
    to the Wasserstein-1 distance between the real and fake data distributions.
    """
    def __init__(self):
        super().__init__()

    def forward(self, real_pred : torch.Tensor, fake_pred : torch.Tensor) -> torch.Tensor:
        return (fake_pred - real_pred).mean()

class WGANGeneratorLoss(torch.nn.Module):
    """
    Implements the WGAN generator loss, i.e. minimizing the critic's estimate of the Wasserstein-1
    distance between the real and fake data distributions.
    """
    def __init__(self):
        super().__init__()

    def forward(self, fake_pred : torch.Tensor) -> torch.Tensor:
        return (-fake_pred).mean()

class GradientPenalty(torch.nn.Module):
    """
    The WGAN gradient penalty ("wgan-gp") penalizes deviations of the critic/discriminator gradient norm
    from 1 on interpolates between real and fake samples, softly enforcing a 1-Lipschitz constraint on
    the critic, as discussed in https://arxiv.org/pdf/1704.00028.
    The R1 penalty ("r1") instead penalizes the expected squared norm of the discriminator gradient on
    the distribution of real samples: https://arxiv.org/pdf/1801.04406v4
    This module supports both regimes via `gp_type`.
    """
    def __init__(self, reg_weight : float = 10.0, gp_type : str = "r1"):
        assert gp_type in ["wgan-gp", "r1"]
        super().__init__()
        self.reg_weight = reg_weight
        self.gp_type = gp_type

    def forward(self, real : torch.Tensor, real_pred : Union[torch.Tensor, None] = None, fake_samples : Union[torch.Tensor, None] = None, critic = None) -> torch.Tensor:
        if self.gp_type == "wgan-gp":
            # Interpolate between real and fake samples; detach so the penalty only trains the critic,
            # then re-enable gradients so d(critic)/d(interpolates) can be computed.
            eps = torch.rand((fake_samples.shape[0], 1, 1, 1), device = fake_samples.device)
            real = (eps * real + (1 - eps) * fake_samples).detach().requires_grad_(True)
            real_pred = critic(real)
        # For "r1", the caller must pass `real` with requires_grad=True and `real_pred` computed from it.
        grad, = torch.autograd.grad(inputs = real, outputs = real_pred.sum(), create_graph = True)
        norm = torch.sum(grad ** 2, dim = (1, 2, 3))  # squared gradient norm per sample
        if self.gp_type == "wgan-gp":
            return self.reg_weight * ((norm.sqrt() - 1) ** 2).mean()
        return (self.reg_weight / 2) * norm.mean()
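
# Usage sketch (an assumed calling convention, not taken from the original file): how GradientPenalty is
# typically combined with a critic loss. In "r1" mode the caller enables gradients on the real batch
# before the critic forward pass, since the penalty differentiates the critic output with respect to
# `real`; in "wgan-gp" mode the interpolation and the extra critic pass happen inside the module.
# `critic` is an assumed user-provided network mapping images to scalar predictions.
def critic_step_with_penalty_example(critic, real : torch.Tensor, fake : torch.Tensor, gp_type : str = "r1") -> torch.Tensor:
    gp = GradientPenalty(reg_weight = 10.0, gp_type = gp_type)
    if gp_type == "r1":
        # StyleGAN-style setup: non-saturating GAN loss + R1 on real samples.
        real = real.detach().requires_grad_(True)   # R1 needs d(critic)/d(real)
        real_pred = critic(real)
        fake_pred = critic(fake.detach())
        loss = VanillaDiscriminatorLoss()(real_pred, fake_pred) + gp(real, real_pred = real_pred)
    else:
        # WGAN-GP setup: Wasserstein critic loss + gradient penalty on interpolates.
        real_pred = critic(real)
        fake_pred = critic(fake.detach())
        loss = WGANDiscriminatorLoss()(real_pred, fake_pred) + gp(real, fake_samples = fake.detach(), critic = critic)
    return loss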

class PathLengthPenalty(torch.nn.Module):
    def __init__(self, reg_weight = 2.0, beta = 0.99):
        super().__init__()
        self.reg_weight = reg_weight
        self.beta = beta
        self.steps = 0
        # Exponential moving average of the path length; stored as a non-trainable parameter so it is saved with the module.
        self.a = torch.nn.Parameter(torch.zeros([]), requires_grad = False)

    def forward(self, w : torch.Tensor, x : torch.Tensor):
        """
        w: (2 * log2(image_size) - 2, batch_size, latent_dim) - per-layer latents fed to the synthesis network.
        x: (batch_size, channels, h, w) - images generated by the generator from w.
        """
        rh, rw = x.shape[2], x.shape[3]
        device = x.device
        """
        The following noise scaling is not mentioned in the paper but is present in the official implementations.
        StyleGAN2 TensorFlow implementation: https://github.com/NVlabs/stylegan2/blob/master/training/loss.py#L167
        StyleGAN2-ADA implementation: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/training/loss.py#L81C17-L81C26
        """
        y = torch.randn(x.shape, device = device) / (rh * rw) ** 0.5
        out = (x * y).sum()  # Sum of <x_i, y_i> where x_i is an image obtained from w_i, and y_i are IID Gaussians with mean 0 and variance 1 / (rh * rw)
        grad, = torch.autograd.grad(inputs = w, outputs = out, create_graph = True)  # (2 * log2(image_size) - 2, batch_size, latent_dim)
        gnorm = torch.sum(grad ** 2, dim = -1).mean(dim = 0).sqrt()  # (batch_size,)
        """
        Quote from the StyleGAN2 paper: 'To ensure that our regularizer interacts correctly with style mixing regularization, we compute it as an average of
        all individual layers of the synthesis network.' Dimensions for w in this implementation are (2 * log2(image_size) - 2, batch_size, latent_dim), while the
        official StyleGAN2-ADA implementation uses the ordering (batch_size, 2 * log2(image_size) - 2, latent_dim).
        In the paper, the regularization weight is computed as ln(2) / (target_res ** 2 * (ln(target_res) - ln(2))). However,
        in the official TensorFlow implementation of StyleGAN2 (https://github.com/NVlabs/stylegan2/blob/master/training/loss.py#L185),
        they elaborate that setting the regularization weight to 2, dividing the noise by (rh * rw) ** 0.5 and taking mean(dim = 0) has the following effect on the regularization weight:
        reg_weight = 2 / (target_res ** 2) / (2 * log2(target_res) - 2), which equals the original regularization weight proposed in the paper.
        """
        # Inverse of https://github.com/NVlabs/stylegan2-ada-pytorch/blob/d72cc7d041b42ec8e806021a205ed9349f87c6a4/training/loss.py#L85C69-L85C77
        # because of a different beta parameterization: lerp(gnorm.mean(), a, beta) = beta * a + (1 - beta) * gnorm.mean().
        a = torch.lerp(gnorm.mean(), self.a, self.beta)  # https://pytorch.org/docs/stable/generated/torch.lerp.html
        self.a.copy_(a.detach())
        res = self.reg_weight * ((gnorm - a) ** 2).mean()  # Monte-Carlo estimate of Equation (4) from the StyleGAN2 paper.
        self.steps += 1
        return res
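
# Usage sketch (assumed generator interface, not taken from the original file): the path length penalty
# is usually applied lazily, e.g. every few generator steps, and the module above expects per-layer
# latents `w` of shape (num_layers, batch_size, latent_dim). `mapping` and `synthesis` are assumed
# user-provided networks; the same w is broadcast to every synthesis layer here, mirroring the case
# without style mixing.
def path_length_step_example(mapping, synthesis, z : torch.Tensor, num_layers : int, plp : PathLengthPenalty) -> torch.Tensor:
    w = mapping(z)                               # (batch_size, latent_dim)
    w = w.unsqueeze(0).repeat(num_layers, 1, 1)  # (num_layers, batch_size, latent_dim)
    if not w.requires_grad:                      # gradients are taken with respect to w
        w.requires_grad_(True)
    x = synthesis(w)                             # (batch_size, channels, h, w)
    return plp(w, x)                             # add this term to the generator loss on regularization steps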