-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
117 lines (93 loc) · 4.06 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import Dict, List, Optional
import torch
import torch.backends.cudnn
import torch.nn.functional as F
def gram_matrix(
    feature_maps: torch.Tensor, normalize: bool = True
) -> torch.Tensor:
    """Compute the (optionally normalized) Gram matrix of feature maps.

    Given feature maps of shape (n, c, h, w), each sample's maps are
    flattened to F of shape (c, h * w) and the Gram matrix is G = F F^T,
    so G_ij is the inner product of feature vectors i and j. When
    ``normalize`` is True, G is divided by c * h * w.
    """
    batch, channels, height, width = feature_maps.shape
    flat = feature_maps.view(batch, channels, height * width)
    gram = flat @ flat.transpose(1, 2)
    if normalize:
        gram = gram / (channels * height * width)
    return gram
def _tv_isotropic(x: torch.Tensor) -> torch.Tensor:
"""Isotropic TV loss."""
# sum right neighbor pixel differences
loss = torch.sum((x[:, :, :, :-1] - x[:, :, :, 1:]) ** 2)
# sum lower neighbor pixel differences
loss += torch.sum((x[:, :, :-1, :] - x[:, :, 1:, :]) ** 2)
return loss / x.numel()
def _tv_anisotropic(x: torch.Tensor) -> torch.Tensor:
"""Anisotropic TV loss."""
# sum right neighbor pixel differences
loss = torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
# sum lower neighbor pixel differences
loss += torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
return loss / x.numel()
def total_variation_loss(
    image: torch.Tensor, anisotropic: bool = False
) -> torch.Tensor:
    """Return the normalized total variation (TV) loss of ``image``.

    Two variants are supported:

    1. Isotropic (default): drops the square root and aggregates the squared
       horizontal and vertical derivatives together, so neither direction is
       biased and edges in both directions are weighted "equally."
    2. Anisotropic: uses absolute values and aggregates each direction's
       terms separately, which can favor one direction (e.g. vertical) when
       its loss is much larger than the other's.

    See https://en.wikipedia.org/wiki/Total_variation_denoising for more
    details.
    """
    if anisotropic:
        return _tv_anisotropic(image)
    return _tv_isotropic(image)
def style_loss(
    generated_features: torch.Tensor,
    target_features: torch.Tensor
) -> torch.Tensor:
    """Return the style loss: MSE between the two feature maps' Gram matrices."""
    return F.mse_loss(
        gram_matrix(generated_features), gram_matrix(target_features)
    )
def perceptual_loss(
    generated_content: Dict[str, torch.Tensor],
    generated_style: Dict[str, torch.Tensor],
    content_targets: Dict[str, torch.Tensor],
    style_targets: Dict[str, torch.Tensor],
    generated_image: Optional[torch.Tensor] = None,
    content_weight: float = 1.0,
    style_weight: float = 1e5,
    tv_weight: float = 1.0e-10,
    anisotropic: bool = False
) -> Dict[str, torch.Tensor]:
    """Calculates the perceptual loss, which combines content and style losses.

    Optionally, uses TV regularization to enhance edge preservation and pixel
    smoothening (reduction of small artifacts/noise); TV is applied only when
    ``tv_weight > 0`` AND ``generated_image`` is provided.

    Returns the perceptual loss, as well as each loss term (unweighted).
    """
    c_loss = s_loss = tv_loss = 0
    # total content loss: per-layer MSE between generated and target features
    for label, target_feat in content_targets.items():
        c_loss += F.mse_loss(generated_content[label], target_feat)
    # total style loss: per-layer MSE between Gram matrices
    for label, target_feat in style_targets.items():
        s_loss += style_loss(generated_style[label], target_feat)
    # combine the weighted losses (perceptual loss)
    loss = content_weight * c_loss + style_weight * s_loss
    # add tv loss (if applicable); TV needs the image itself — with the
    # defaults (tv_weight > 0, generated_image=None) the original crashed
    # inside total_variation_loss, so also require a non-None image here
    if tv_weight > 0 and generated_image is not None:
        tv_loss = total_variation_loss(generated_image, anisotropic)
        loss += tv_weight * tv_loss
    return {
        "perceptual": loss, "content": c_loss, "style": s_loss, "tv": tv_loss
    }