[Feature] Distillation enhancement #463

Open

wants to merge 17 commits into base: main

Changes from 1 commit
kl divergence enhancement
HIT-cwh committed Feb 23, 2023
commit 45d47c5098831b9efc24b7086aa09d7b21111c1a
4 changes: 1 addition & 3 deletions mmrazor/models/losses/__init__.py
@@ -16,8 +16,6 @@
 from .mgd_loss import MGDLoss
 from .ofd_loss import OFDLoss
 from .pkd_loss import PKDLoss
-from .ppyoloe_distill_loss import (DistributionFocalLoss, MainKDLoss,
-                                   QualityFocalLoss)
 from .relational_kd import AngleWiseRKD, DistanceWiseRKD
 from .weighted_soft_label_distillation import WSLD
 
@@ -26,5 +24,5 @@
     'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
     'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
     'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss',
-    'FGDLoss', 'QualityFocalLoss', 'DistributionFocalLoss', 'MainKDLoss'
+    'FGDLoss'
 ]
45 changes: 40 additions & 5 deletions mmrazor/models/losses/kl_divergence.py
@@ -2,9 +2,26 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from mmrazor.models.losses.utils import weighted_loss
 from mmrazor.registry import MODELS
 
 
+@weighted_loss
+def kl_div(preds_S, preds_T, tau):
+    """Calculate the KL divergence between `preds_S` and `preds_T`.
+
+    Args:
+        preds_S (torch.Tensor): The student model prediction with shape (N, C).
+        preds_T (torch.Tensor): The teacher model prediction with shape (N, C).
+        tau (float): Temperature coefficient.
+    """
+    softmax_pred_T = F.softmax(preds_T / tau, dim=1)
+    logsoftmax_preds_S = F.log_softmax(preds_S / tau, dim=1)
+    loss = (tau**2) * F.kl_div(
+        logsoftmax_preds_S, softmax_pred_T, reduction='none')
+    return loss
+
+
 @MODELS.register_module()
 class KLDivergence(nn.Module):
     """A measure of how one probability distribution Q is different from a
@@ -45,22 +62,40 @@ def __init__(
             f'but gets {reduction}.'
         self.reduction = reduction
 
-    def forward(self, preds_S, preds_T):
+    def forward(self,
+                preds_S,
+                preds_T,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None):
         """Forward computation.
 
         Args:
             preds_S (torch.Tensor): The student model prediction with
                 shape (N, C, H, W) or shape (N, C).
             preds_T (torch.Tensor): The teacher model prediction with
                 shape (N, C, H, W) or shape (N, C).
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean", "sum" and "batchmean".
 
         Return:
            torch.Tensor: The calculated loss value.
         """
+        assert reduction_override in (None, 'none', 'mean', 'sum', 'batchmean')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
         if self.teacher_detach:
             preds_T = preds_T.detach()
-        softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)
-        logsoftmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1)
-        loss = (self.tau**2) * F.kl_div(
-            logsoftmax_preds_S, softmax_pred_T, reduction=self.reduction)
+        loss = kl_div(
+            preds_S,
+            preds_T,
+            weight,
+            tau=self.tau,
+            reduction=reduction,
+            avg_factor=avg_factor)
         return self.loss_weight * loss
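
For context, a minimal usage sketch of the enhanced loss. It assumes this branch is installed and that KLDivergence keeps its existing constructor arguments (tau, reduction, loss_weight); the weight mask and avg_factor value below are purely illustrative, not values from this PR.

import torch

from mmrazor.models.losses import KLDivergence

student_logits = torch.randn(4, 10)  # (N, C) student predictions
teacher_logits = torch.randn(4, 10)  # (N, C) teacher predictions

loss_fn = KLDivergence(tau=4, reduction='batchmean', loss_weight=1.0)

# Behaves as before this commit when no extra arguments are given.
loss = loss_fn(student_logits, teacher_logits)

# New in this commit: per-prediction weights, a custom averaging factor
# and a one-off reduction override.
sample_mask = torch.tensor([[1.], [0.], [1.], [1.]])  # mask out sample 1
masked_loss = loss_fn(
    student_logits,
    teacher_logits,
    weight=sample_mask,        # broadcast against the (N, C) elementwise loss
    avg_factor=3,              # e.g. the number of unmasked samples
    reduction_override='mean')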
124 changes: 0 additions & 124 deletions mmrazor/models/losses/ppyoloe_distill_loss.py

This file was deleted.

129 changes: 129 additions & 0 deletions mmrazor/models/losses/utils.py
@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    # special case for batchmean in kl_div
    if reduction == 'batchmean':
        return loss.sum() / loss.size()[0]

    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()


def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights.
            Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. Defaults to None.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
            # i.e., all labels of an image belong to ignore index.
            eps = torch.finfo(torch.float32).eps
            loss = loss.sum() / (avg_factor + eps)
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss


def weighted_loss(loss_func: Callable) -> Callable:
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                reduction: str = 'mean',
                avg_factor: Optional[int] = None,
                **kwargs) -> Tensor:
        """
        Args:
            pred (Tensor): The prediction.
            target (Tensor): Target bboxes.
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            reduction (str, optional): Options are "none", "mean" and "sum".
                Defaults to 'mean'.
            avg_factor (Optional[int], optional): Average factor that is used
                to average the loss. Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        # get element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
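
As a sanity check on the refactor, a small sketch (assuming this branch is installed, so that the decorated kl_div introduced in the diff above is importable from mmrazor.models.losses.kl_divergence) comparing the new decorated path against the computation that KLDivergence.forward previously did inline:

import torch
import torch.nn.functional as F

from mmrazor.models.losses.kl_divergence import kl_div  # decorated elementwise loss

torch.manual_seed(0)
preds_S = torch.randn(8, 5)
preds_T = torch.randn(8, 5)
tau = 2.0

# New path: elementwise kl_div, then 'batchmean' reduction via utils.reduce_loss.
loss_new = kl_div(preds_S, preds_T, tau=tau, reduction='batchmean')

# Old path: the computation previously inlined in KLDivergence.forward.
loss_old = (tau**2) * F.kl_div(
    F.log_softmax(preds_S / tau, dim=1),
    F.softmax(preds_T / tau, dim=1),
    reduction='batchmean')

assert torch.allclose(loss_new, loss_old)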