-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathdice_jaccard.py
89 lines (74 loc) · 3.33 KB
/
dice_jaccard.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from .base import Loss, Mode, _reduce
from .functional import soft_dice_score, soft_jaccard_score
class DiceLoss(Loss):
    """
    Implementation of Dice loss for image segmentation task.
    It supports binary, multiclass and multilabel cases

    Args:
        mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
            'multilabel' - expects y_true of shape [N, C, H, W]
            'multiclass', 'binary' - expects y_true of shape [N, H, W]
        reduction (str): How the masked per-class losses are combined;
            forwarded unchanged to ``_reduce`` (default ``"mean"``).
        log_loss (bool): If True, loss computed as `-log(dice)`; otherwise `1 - dice`
        from_logits (bool): If True assumes input is raw logits and applies
            sigmoid ('binary', 'multilabel') or softmax over dim 1
            ('multiclass') before scoring.
        eps (float): Forwarded as ``eps`` to the score function
            (``IOU_FUNCTION``). NOTE(review): presumably an additive
            smoothing / numerical-stability term given the default of 1.0;
            confirm against the functional implementation.

    Shape:
        y_pred: [N, C, H, W]
        y_true: [N, C, H, W] or [N, H, W] depending on mode
    """

    # Score function used by forward(); subclasses swap it (see JaccardLoss).
    IOU_FUNCTION = soft_dice_score

    def __init__(self, mode="binary", reduction="mean", log_loss=False, from_logits=True, eps=1.0):
        super().__init__()
        self.mode = Mode(mode)  # raises an error if not valid
        self.reduction = reduction
        self.log_loss = log_loss
        self.from_logits = from_logits
        self.eps = eps

    def forward(self, y_pred, y_true):
        """
        Compute the loss for a batch.

        Args:
            y_pred: Predictions of shape [N, C, H, W] (raw logits when
                ``from_logits`` is True, probabilities otherwise).
            y_true: Targets; class indices of shape [N, H, W] for
                'multiclass', a mask reshapeable to [N, 1, H*W] for 'binary',
                or [N, C, H, W] for 'multilabel'.

        Returns:
            The loss produced by ``_reduce`` from the per-class losses, where
            classes with no positive target pixels contribute zero.
        """
        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            if self.mode in (Mode.BINARY, Mode.MULTILABEL):
                y_pred = y_pred.sigmoid()
            elif self.mode == Mode.MULTICLASS:
                y_pred = y_pred.softmax(dim=1)

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        # Reduce over batch (0) and flattened spatial (2) axes of the
        # [N, C, H*W] tensors built below, keeping the class axis (1).
        dims = (0, 2)

        # reshape() behaves like view() on contiguous tensors but, unlike
        # view(), also accepts non-contiguous ones (e.g. permuted/sliced masks).
        if self.mode == Mode.BINARY:
            y_true = y_true.reshape(bs, 1, -1)
            y_pred = y_pred.reshape(bs, 1, -1)
        elif self.mode == Mode.MULTICLASS:
            y_true = y_true.reshape(bs, -1)
            y_pred = y_pred.reshape(bs, num_classes, -1)
            # one_hot only accepts int64 index tensors; .long() is a no-op for
            # int64 and lets int32/uint8 label maps work too.
            y_true = torch.nn.functional.one_hot(y_true.long(), num_classes)  # N, H*W -> N, H*W, C
            y_true = y_true.permute(0, 2, 1)  # N, C, H*W
        elif self.mode == Mode.MULTILABEL:
            y_true = y_true.reshape(bs, num_classes, -1)
            y_pred = y_pred.reshape(bs, num_classes, -1)

        # Look the score function up on the class so it is not bound as a method.
        scores = type(self).IOU_FUNCTION(y_pred, y_true.type(y_pred.dtype), dims=dims, eps=self.eps)

        if self.log_loss:
            loss = -torch.log(scores)
        else:
            loss = 1 - scores

        # The loss is only defined for non-empty classes, so zero the
        # contribution of any channel that has no true pixels.
        mask = y_true.sum(dims) > 0
        loss = loss * mask.to(loss.dtype)

        return _reduce(loss, self.reduction)
class JaccardLoss(DiceLoss):
    """
    Jaccard (IoU) loss for image segmentation, covering binary, multiclass
    and multilabel targets.

    This class defines no constructor or forward of its own, so it inherits
    both from ``DiceLoss``: the same arguments, activation, reshaping and
    empty-class masking. The only change is that per-class scores come from
    ``soft_jaccard_score`` rather than ``soft_dice_score``.

    Args:
        mode (str): Target mode, one of {'binary', 'multiclass', 'multilabel'}.
            'multilabel' expects y_true shaped [N, C, H, W];
            'multiclass' and 'binary' expect y_true shaped [N, H, W].
        reduction (str): Forwarded unchanged to ``_reduce`` (default "mean").
        log_loss (bool): When True the loss is `-log(jaccard)`, otherwise `1 - jaccard`.
        from_logits (bool): When True the input is treated as raw logits.
        eps (float): Value forwarded as ``eps`` to ``soft_jaccard_score``.

    Shape:
        y_pred: [N, C, H, W]
        y_true: [N, C, H, W] or [N, H, W], depending on mode
    """

    # Overriding the score function is the whole difference from DiceLoss.
    IOU_FUNCTION = soft_jaccard_score