dice loss #396
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master     #396      +/-   ##
==========================================
+ Coverage   86.20%   86.32%   +0.12%
==========================================
  Files          96       98       +2
  Lines        4906     4973      +67
  Branches      799      808       +9
==========================================
+ Hits         4229     4293      +64
- Misses        523      525       +2
- Partials      154      155       +1
Hi @xiexinch
mmseg/models/losses/dice_loss.py
Outdated
class DiceLoss(nn.Module):
    """DiceLoss.

    """
We may add a docstring here.
mmseg/models/losses/dice_loss.py
Outdated
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)

num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum((pred.pow(exponent) + target.pow(exponent)) * valid_mask, dim=1) + smooth
We may directly use the name `denominator`.
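To make the suggestion concrete, here is a minimal sketch (an assumption about the intended shape of the code, not the PR's final version) of the binary dice term with the full name `denominator`, taking inputs already flattened to `(N, H*W)`:

```python
import torch


def binary_dice(pred, target, valid_mask, smooth=1, exponent=2):
    # Dice: 1 - (2 * |pred ∩ target| + smooth) / (|pred|^e + |target|^e + smooth),
    # with invalid pixels masked out in both terms.
    numerator = torch.sum(pred * target * valid_mask, dim=1) * 2 + smooth
    denominator = torch.sum(
        (pred.pow(exponent) + target.pow(exponent)) * valid_mask, dim=1) + smooth
    return 1 - numerator / denominator


# Tiny example: a near-perfect prediction yields a loss close to 0.
pred = torch.tensor([[0.9, 0.1, 0.8]])
target = torch.tensor([[1.0, 0.0, 1.0]])
valid_mask = torch.ones_like(target)
loss = binary_dice(pred, target, valid_mask)
```

Since `2*p*t <= p**2 + t**2`, the ratio never exceeds 1, so the loss stays in `[0, 1)` per sample.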
mmseg/models/losses/dice_loss.py
Outdated
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        loss_weight (float, optional): Weight of the loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: -1.
Should be 255.
The dilated FCN model shouldn't be in this PR.
else:
    class_weight = None

pred = F.softmax(pred, dim=1)
In instance segmentation, dice loss may use sigmoid for activation. Suggest supporting both cases.
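A sketch of what supporting both activations could look like. The `use_sigmoid` flag is hypothetical here (named after the convention in related losses), not part of this PR's code:

```python
import torch
import torch.nn.functional as F


def activate(pred, use_sigmoid=False):
    # Softmax normalizes across the class dimension (mutually exclusive
    # classes, the semantic-segmentation case); sigmoid activates each
    # channel independently (instance/binary cases).
    if use_sigmoid:
        return pred.sigmoid()
    return F.softmax(pred, dim=1)


logits = torch.randn(2, 3, 4, 4)
soft = activate(logits)                   # channels sum to 1 at each pixel
sig = activate(logits, use_sigmoid=True)  # each channel independently in (0, 1)
```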
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):
Using a for loop might be inefficient. Some implementations process all classes in a batched manner.
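One way the batched alternative could look (a sketch under the assumption of `(N, C, H, W)` probabilities and `(N, H, W)` index labels, not the PR's final code): one-hot encode the target and reduce over the spatial dimension for all classes at once, with no per-class Python loop.

```python
import torch
import torch.nn.functional as F


def dice_loss_batched(pred, target, smooth=1, exponent=2):
    num_classes = pred.shape[1]
    # One-hot encode target to (N, C, H, W), then flatten spatial dims.
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    pred = pred.reshape(pred.shape[0], num_classes, -1)
    one_hot = one_hot.reshape(one_hot.shape[0], num_classes, -1)
    # Numerator/denominator computed for every class simultaneously.
    num = torch.sum(pred * one_hot, dim=2) * 2 + smooth
    den = torch.sum(pred.pow(exponent) + one_hot.pow(exponent), dim=2) + smooth
    return (1 - num / den).mean()


pred = F.softmax(torch.randn(2, 3, 8, 8), dim=1)
target = torch.randint(0, 3, (2, 8, 8))
loss = dice_loss_batched(pred, target)
```

This sketch omits `valid_mask` and per-class weights for brevity; both would broadcast over the class dimension the same way.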
@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
    assert pred.shape[0] == target.shape[0]
    pred = pred.contiguous().view(pred.shape[0], -1)
`.contiguous().view()` can be replaced by `reshape`?
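The equivalence behind this suggestion can be checked directly: `reshape` handles non-contiguous tensors by itself (copying only when necessary), so the explicit `.contiguous()` call is redundant.

```python
import torch

# permute() returns a non-contiguous view, the case where plain .view() fails.
x = torch.randn(2, 3, 4).permute(0, 2, 1)
a = x.contiguous().view(x.shape[0], -1)
b = x.reshape(x.shape[0], -1)
# Both spellings produce identical results.
assert torch.equal(a, b)
```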
* dice loss
* format code, add docstring and calculate denominator without valid_mask
* minor change
* restore
* ddim docs for issue open-mmlab#293
* space
No description provided.