import torch
import torchvision.transforms as tf
def alpha_prediction_loss(predAlpha, trueAlpha):
    """
    Both inputs are expected to have shape BxSxS, as this function operates on
    a batch of single-channel images with values in the range 0 - 1.
    """
    eps = torch.tensor(1e-6).float()
    squareEps = eps.pow(2)
    # sqrt((pred - true)^2 + eps^2) approximates the absolute difference while
    # remaining differentiable at zero
    difference = predAlpha - trueAlpha
    squaredDifference = torch.pow(difference, 2) + squareEps
    rootDiff = torch.sqrt(squaredDifference)
    # Sum per image, normalise by each image's total true alpha, then average
    # over the batch
    sumRootDiff = rootDiff.sum(dim=[1, 2])
    sumTrueAlpha = trueAlpha.sum(dim=[1, 2]) + eps
    totalLoss = sumRootDiff / sumTrueAlpha
    avgTotalLoss = totalLoss.mean()
    return avgTotalLoss
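# A minimal usage sketch (illustrative only; the batch size and image size are
# assumptions): the loss takes predicted and ground-truth mattes of shape
# BxSxS and returns a scalar tensor.
#
#   pred = torch.rand(4, 320, 320)
#   true = torch.rand(4, 320, 320)
#   loss = alpha_prediction_loss(pred, true)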
def compositional_loss(predAlpha, trueAlpha, compositeImage):
    def show(xf):
        # Debugging helper (not called by default): display each image in the batch
        for idx in range(xf.size(0)):
            f = tf.ToPILImage()(xf[idx])
            f.show()
    eps = torch.tensor(1e-6).float()
    squareEps = torch.tensor(1e-6).pow(2).float()
    trimaps = compositeImage[:, 3] * 255
    compositeImage = compositeImage[:, 0:3]  # Removes the trimap stored in the fourth channel
    # When using only the trimap's unknown region to calculate the
    # compositional loss, the model (either the encoder-decoder or the
    # refinement stage, whichever the loss is applied to) seems to get
    # confused and outputs only black images. When the full alpha mask is used
    # instead, the model trains as expected. The trimap-based variant is kept
    # below for reference:
    # blackMask = torch.zeros_like(trueAlpha)
    # unknownTrueMask = torch.where(trimaps == 127, trueAlpha, blackMask)
    # unknownPredictedMask = torch.where(trimaps == 127, predAlpha, blackMask)
    # unknownTrueForeground = compositeImage * unknownTrueMask.unsqueeze(1)
    # unknownPredictedForeground = compositeImage * unknownPredictedMask.unsqueeze(1)
    # difference = unknownPredictedForeground - unknownTrueForeground
    # squaredDifference = torch.pow(difference, 2) + squareEps
    # rootDiff = torch.sqrt(squaredDifference)
    # sumRootDiff = rootDiff.sum(dim=[2, 3])
    # sumTrueUnknownForeground = unknownTrueForeground.sum(dim=[2, 3]) + eps
    # totalLoss = sumRootDiff / sumTrueUnknownForeground
    # avgLoss = totalLoss.mean(dim=1).mean()  # average over the RGB channels and across the batch
    trueForeground = compositeImage * trueAlpha.unsqueeze(1)
    predictedForeground = compositeImage * predAlpha.unsqueeze(1)
    difference = predictedForeground - trueForeground
    squaredDifference = torch.pow(difference, 2) + squareEps
    rootDiff = torch.sqrt(squaredDifference)
    sumTrueForeground = trueForeground.sum(dim=[2, 3]) + eps
    totalLoss = rootDiff.sum(dim=[2, 3]) / sumTrueForeground
    avgLoss = totalLoss.mean()  # Averages over both the RGB channels and the batch
    return avgLoss
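# Illustrative call (shapes are assumptions): compositeImage carries RGB in
# channels 0-2 and the trimap, rescaled to 0 - 1, in channel 3, so it is
# Bx4xSxS while the alpha mattes are BxSxS.
#
#   composite = torch.rand(4, 4, 320, 320)
#   loss = compositional_loss(pred, true, composite)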
def sum_absolute_difference(trueAlpha, predAlpha):
    """
    Calculates the sum of absolute differences between ground-truth alphas and
    predictions in a batch. As the calculation is done over a batch, the mean
    is used to reduce the per-image sums to a single value.
    """
    difference = torch.abs(predAlpha - trueAlpha)
    avgDiff = difference.sum(dim=[1, 2]).mean()
    return avgDiff
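# Quick sanity check (illustrative): identical inputs give zero,
#   sum_absolute_difference(true, true)  # tensor(0.)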
def mean_squared_error(trueAlpha, predAlpha, compositeImage):
    trimaps = compositeImage[:, 3] * 255
    blackMask = torch.zeros_like(trueAlpha)
    # Mask the squared error to the trimap's unknown (value 127) region so the
    # numerator and the denominator cover the same pixels
    squaredDiff = torch.where(trimaps == 127, torch.pow(predAlpha - trueAlpha, 2), blackMask)
    unknownRegions = torch.where(trimaps == 127, trueAlpha, blackMask)
    mse = squaredDiff.sum() / (unknownRegions.sum() + 1e-6)
    return mse
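if __name__ == "__main__":
    # Smoke test with random tensors: a sketch, not part of the original
    # training pipeline; the shapes and the {0, 127, 255} trimap levels are
    # assumptions based on how the losses above index their inputs.
    B, S = 2, 64
    true = torch.rand(B, S, S)
    pred = torch.rand(B, S, S)
    # The fourth channel holds a trimap with levels {0, 127, 255} rescaled to 0 - 1
    trimap = torch.tensor([0.0, 127.0, 255.0])[torch.randint(0, 3, (B, S, S))] / 255.0
    composite = torch.cat([torch.rand(B, 3, S, S), trimap.unsqueeze(1)], dim=1)
    print("alpha prediction loss:", alpha_prediction_loss(pred, true).item())
    print("compositional loss:", compositional_loss(pred, true, composite).item())
    print("sum of absolute differences:", sum_absolute_difference(true, pred).item())
    print("mse over unknown region:", mean_squared_error(true, pred, composite).item())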