-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss.py
executable file
·123 lines (96 loc) · 4.1 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss supporting three GAN formulations, selected by ``type``:

    * ``'nsgan'`` -- ``nn.BCELoss`` against constant real/fake labels, so the
      discriminator outputs must already be probabilities in ``[0, 1]``.
    * ``'lsgan'`` -- ``nn.MSELoss`` against the real/fake labels.
    * ``'hinge'`` -- hinge loss for the discriminator and ``-mean(outputs)``
      for the generator.

    Reference: https://arxiv.org/abs/1711.10337
    """

    _SUPPORTED_TYPES = ('nsgan', 'lsgan', 'hinge')

    def __init__(self,
                 type='nsgan',
                 target_real_label=1.0,
                 target_fake_label=0.0):
        r"""
        Args:
            type: one of ``'nsgan'``, ``'lsgan'`` or ``'hinge'``. (The name
                shadows the builtin but is kept for keyword-argument callers.)
            target_real_label: target value for real samples (nsgan/lsgan).
            target_fake_label: target value for fake samples (nsgan/lsgan).

        Raises:
            ValueError: if ``type`` is not a supported loss type.
        """
        super(AdversarialLoss, self).__init__()
        # Fail fast: an unknown type would otherwise leave ``self.criterion``
        # undefined and only surface as an AttributeError on the first call.
        if type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported adversarial loss type {type!r}; "
                f"expected one of {self._SUPPORTED_TYPES}")
        self.type = type
        # Registered as buffers so they follow the module across .to(...)
        # and are stored in the state dict.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if type == 'nsgan':
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        else:  # 'hinge'
            self.criterion = nn.ReLU()

    def forward(self, outputs, is_real, is_disc=None):
        """Compute the adversarial loss for a batch of discriminator outputs.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``nn.Module.__call__`` and hooks are honoured.

        Args:
            outputs: discriminator outputs. For ``'nsgan'`` these must be
                probabilities in ``[0, 1]`` because ``nn.BCELoss`` is used.
            is_real: whether ``outputs`` come from real samples. Selects the
                target label (nsgan/lsgan) or the sign of the hinge term.
            is_disc: only used by ``'hinge'``. Truthy -> discriminator hinge
                loss; falsy/None -> generator loss ``-mean(outputs)``, in
                which case ``is_real`` is ignored.

        Returns:
            A scalar tensor (mean-reduced loss).
        """
        if self.type == 'hinge':
            if is_disc:
                # real: mean(relu(1 - D)); fake: mean(relu(1 + D))
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            return (-outputs).mean()

        labels = (self.real_label
                  if is_real else self.fake_label).expand_as(outputs)
        return self.criterion(outputs, labels)
# Standard ImageNet per-channel normalization statistics, reshaped to
# (1, 3, 1, 1) so they broadcast against 4-D tensors whose second dimension
# has size 3 (presumably NCHW image batches).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)
# @LOSS_REGISTRY.register()
class LaMaPerceptualLoss(nn.Module):
    """VGG19-based perceptual loss in the style of LaMa.

    Both images go through a frozen, ImageNet-pretrained VGG19 feature stack
    in which every ``MaxPool2d`` is replaced by a 2x2/stride-2
    ``AvgPool2d``. After every ReLU among the first 30 layers, the per-sample
    mean squared difference between the two activations is recorded.

    Args:
        loss_weight: multiplier applied to the summed loss in ``forward``.
        normalize_inputs: if True, inputs are normalized with the ImageNet
            mean/std before entering VGG (inputs then expected in [0, 1]).
    """

    def __init__(self, loss_weight=1.0, normalize_inputs=True):
        super(LaMaPerceptualLoss, self).__init__()
        self.loss_weight = loss_weight
        self.normalize_inputs = normalize_inputs
        # Plain attributes (not buffers): moved to the input's device/dtype
        # on every call in ``do_normalize_inputs``.
        self.mean_ = IMAGENET_MEAN
        self.std_ = IMAGENET_STD

        vgg = self._load_vgg19_features()
        # Frozen: VGG is used only as a fixed feature metric.
        for weights in vgg.parameters():
            weights.requires_grad = False

        # Flatten the feature stack, swapping max pooling for average pooling.
        layers = []
        for module in vgg.modules():
            if isinstance(module, nn.Sequential):
                continue  # the container itself; only its children are kept
            if isinstance(module, nn.MaxPool2d):
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
            else:
                layers.append(module)
        self.vgg = nn.Sequential(*layers)

    @staticmethod
    def _load_vgg19_features():
        """Return ``features`` of an ImageNet-pretrained torchvision VGG19.

        torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        ``weights=``; older releases reject the ``weights`` keyword with a
        TypeError, in which case we fall back to ``pretrained=True``. Both
        paths select the ``IMAGENET1K_V1`` checkpoint.
        """
        try:
            model = torchvision.models.vgg19(weights='IMAGENET1K_V1')
        except TypeError:
            model = torchvision.models.vgg19(pretrained=True)
        return model.features

    def do_normalize_inputs(self, x):
        """Normalize ``x`` with the ImageNet channel mean/std.

        For floating-point inputs the statistics are cast to ``x``'s dtype as
        well as its device, so half/bfloat16 inputs are not silently promoted
        to float32. Non-floating inputs keep the original promotion behaviour.
        """
        dtype = x.dtype if x.is_floating_point() else self.mean_.dtype
        mean = self.mean_.to(device=x.device, dtype=dtype)
        std = self.std_.to(device=x.device, dtype=dtype)
        return (x - mean) / std

    def partial_losses(self, input, target, mask=None):
        """Per-layer perceptual losses between ``input`` and ``target``.

        Args:
            input: image batch; expected in [0, 1] when ``normalize_inputs``.
                Presumably (B, 3, H, W) -- TODO confirm with callers.
            target: image batch shaped like ``input``.
            mask: optional mask, bilinearly resized to each feature map's
                spatial size; the loss is multiplied by ``1 - mask``, so
                positions where the mask is 1 contribute nothing. Presumably
                (B, 1, H, W) -- confirm with callers.

        Returns:
            A list with one tensor per ReLU in ``self.vgg[:30]``, each of
            shape (B,): the loss averaged over every dim except the first.
        """
        # check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
        if self.normalize_inputs:
            features_input = self.do_normalize_inputs(input)
            features_target = self.do_normalize_inputs(target)
        else:
            features_input = input
            features_target = target

        losses = []
        # NOTE(review): in torchvision's VGG19 layout the first 30 layers
        # presumably end at relu5_1 -- verify if the backbone changes.
        for layer in self.vgg[:30]:
            features_input = layer(features_input)
            features_target = layer(features_target)
            if not isinstance(layer, nn.ReLU):
                continue
            loss = F.mse_loss(features_input, features_target, reduction='none')
            if mask is not None:
                cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
                                         mode='bilinear', align_corners=False)
                loss = loss * (1 - cur_mask)
            losses.append(loss.mean(dim=tuple(range(1, loss.dim()))))
        return losses

    def forward(self, input, target, mask=None):
        """Total perceptual loss, scaled by ``loss_weight``.

        Note: the per-layer, per-sample terms are summed over layers AND over
        the batch, so the magnitude grows with the batch size.
        """
        losses = self.partial_losses(input, target, mask=mask)
        return torch.stack(losses).sum() * self.loss_weight

    def get_global_features(self, input):
        """Run ``input`` (optionally normalized) through the full VGG stack."""
        # check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
        if self.normalize_inputs:
            features_input = self.do_normalize_inputs(input)
        else:
            features_input = input
        return self.vgg(features_input)