[Feature] Distillation enhancement #463

Open · wants to merge 17 commits into base: main
5 changes: 3 additions & 2 deletions mmrazor/engine/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import DumpSubnetHook, EstimateResourcesHook
from .hooks import (DistillationLossDetachHook, DumpSubnetHook,
EstimateResourcesHook)
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
@@ -12,5 +13,5 @@
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop',
'AutoSlimGreedySearchLoop', 'SubnetValLoop'
'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'DistillationLossDetachHook'
]
6 changes: 5 additions & 1 deletion mmrazor/engine/hooks/__init__.py
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distillation_loss_detach_hook import DistillationLossDetachHook
from .dump_subnet_hook import DumpSubnetHook
from .estimate_resources_hook import EstimateResourcesHook
from .visualization_hook import RazorVisualizationHook

__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook']
__all__ = [
'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook',
'DistillationLossDetachHook'
]
25 changes: 25 additions & 0 deletions mmrazor/engine/hooks/distillation_loss_detach_hook.py
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmrazor.registry import HOOKS


@HOOKS.register_module()
class DistillationLossDetachHook(Hook):
    """Detach the distillation loss once ``runner.epoch`` reaches ``detach_epoch``."""

    priority = 'LOW'

    def __init__(self, detach_epoch) -> None:
        self.detach_epoch = detach_epoch

    def before_train_epoch(self, runner) -> None:
        if runner.epoch >= self.detach_epoch:
            model = runner.model
            # TODO: refactor after mmengine using model wrapper
            if is_model_wrapper(model):
                model = model.module
            assert hasattr(model, 'distill_loss_detach')

            runner.logger.info('Distillation stops now!')
            model.distill_loss_detach = True
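Note: a hook like this is typically enabled through the training config's `custom_hooks` list. The snippet below is a minimal sketch under that assumption; `detach_epoch=100` is an illustrative placeholder, not a value defined in this PR.

```python
# Hypothetical MMEngine-style config fragment enabling the new hook.
# `detach_epoch=100` is a placeholder value for illustration only.
custom_hooks = [
    dict(type='DistillationLossDetachHook', detach_epoch=100),
]
```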
@@ -49,9 +49,11 @@ def loss(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))
if not self.distill_loss_detach:
# Automatically compute distill losses based on
# `loss_forward_mappings`.
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))

return losses
@@ -67,9 +67,9 @@ def __init__(self,
self.set_module_inplace_false(teacher, 'self.teacher')

if teacher_ckpt:
# avoid loaded parameters being overwritten
self.teacher.init_weights()
_ = load_checkpoint(self.teacher, teacher_ckpt)
# avoid loaded parameters being overwritten
self.teacher._is_init = True
self.teacher_trainable = teacher_trainable
if not self.teacher_trainable:
for param in self.teacher.parameters():
@@ -89,6 +89,9 @@ def __init__(self,
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)

# may be modified by ``DistillationLossDetachHook``
self.distill_loss_detach = False

@property
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
@@ -135,10 +138,12 @@ def loss(
_ = self.student(
batch_inputs, data_samples, mode='loss')

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))
if not self.distill_loss_detach:
# Automatically compute distill losses based on
# `loss_forward_mappings`.
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()
losses.update(add_prefix(distill_losses, 'distill'))

return losses

4 changes: 3 additions & 1 deletion mmrazor/models/architectures/connectors/__init__.py
@@ -5,11 +5,13 @@
from .factor_transfer_connectors import Paraphraser, Translator
from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
from .mgd_connector import MGDConnector
from .norm_connector import NormConnector
from .ofd_connector import OFDTeacherConnector
from .torch_connector import TorchFunctionalConnector, TorchNNConnector

__all__ = [
'ConvModuleConnector', 'Translator', 'Paraphraser', 'BYOTConnector',
'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector'
'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector',
'NormConnector'
]
19 changes: 19 additions & 0 deletions mmrazor/models/architectures/connectors/norm_connector.py
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch
from mmcv.cnn import build_norm_layer

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class NormConnector(BaseConnector):
    """Connector that applies a single normalization layer built from ``norm_cfg``."""

    def __init__(self, in_channels, norm_cfg, init_cfg: Optional[Dict] = None):
        super(NormConnector, self).__init__(init_cfg)
        _, self.norm = build_norm_layer(norm_cfg, in_channels)

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        return self.norm(feature)
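As a rough usage sketch, the connector can be built from a config dict via the `MODELS` registry and applied to a recorded feature map. The channel count, norm type, and tensor shape below are illustrative placeholders, not values taken from this PR.

```python
import torch

from mmrazor.registry import MODELS

# Illustrative values only: 256 channels, BatchNorm.
connector_cfg = dict(
    type='NormConnector', in_channels=256, norm_cfg=dict(type='BN'))
connector = MODELS.build(connector_cfg)

feat = torch.rand(2, 256, 32, 32)  # placeholder student feature map
out = connector.forward_train(feat)  # normalized, same shape as the input
```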
30 changes: 24 additions & 6 deletions mmrazor/models/distillers/configurable_distiller.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from inspect import signature
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import torch
from mmengine.model import BaseModel
from torch import nn

@@ -139,15 +140,24 @@ def prepare_from_teacher(self, model: nn.Module) -> None:

def build_connectors(
self,
connectors: Optional[Dict[str, Dict]] = None,
connectors: Optional[Union[Dict[str, List], Dict[str, Dict]]] = None,
) -> nn.ModuleDict:
"""Initialize connectors."""

distill_connectors = nn.ModuleDict()
if connectors:
for connector_name, connector_cfg in connectors.items():
connector = MODELS.build(connector_cfg)
distill_connectors[connector_name] = connector
if isinstance(connector_cfg, dict):
connector = MODELS.build(connector_cfg)
distill_connectors[connector_name] = connector
else:
assert isinstance(connector_cfg, list)
module_list = []
for cfg in connector_cfg:
connector = MODELS.build(cfg)
module_list.append(connector)
distill_connectors[connector_name] = nn.Sequential(
*module_list)

return distill_connectors
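With this change, a value in the `connectors` dict may be either a single config dict, as before, or a list of config dicts that is assembled into an `nn.Sequential`. The sketch below illustrates both forms; the connector names and parameters are placeholders, and two `NormConnector`s are stacked only to keep the example self-contained.

```python
# Hypothetical `connectors` argument for the distiller; names and parameters
# are placeholders for illustration.
connectors = dict(
    # A list of configs is built into nn.Sequential(connector_0, connector_1).
    feat_connector=[
        dict(type='NormConnector', in_channels=256, norm_cfg=dict(type='BN')),
        dict(
            type='NormConnector',
            in_channels=256,
            norm_cfg=dict(type='GN', num_groups=32)),
    ],
    # A single config dict still works exactly as before.
    logit_connector=dict(
        type='NormConnector', in_channels=256, norm_cfg=dict(type='BN')),
)
```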

@@ -204,12 +214,20 @@ def compute_distill_losses(self) -> LossResults:
losses = dict()
for loss_name, forward_mappings in self.loss_forward_mappings.items():
forward_kwargs = dict()
is_empty = False
for forward_key, record in forward_mappings.items():
forward_var = self.get_record(**record)
try:
forward_var = self.get_record(**record)
except AssertionError:
is_empty = True
break
forward_kwargs[forward_key] = forward_var

loss_module = self.distill_losses[loss_name]
loss = loss_module(**forward_kwargs) # type: ignore
if not is_empty:
loss = loss_module(**forward_kwargs) # type: ignore
else:
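# get_record raised AssertionError, i.e. no data was recorded for this
# mapping, so fall back to a zero-valued placeholder loss.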
loss = torch.tensor(0.)
# add computed loss result.
losses[loss_name] = loss

4 changes: 3 additions & 1 deletion mmrazor/models/losses/__init__.py
@@ -8,6 +8,7 @@
from .decoupled_kd import DKDLoss
from .factor_transfer_loss import FTLoss
from .fbkd_loss import FBKDLoss
from .fgd_loss import FGDLoss
from .kd_soft_ce_loss import KDSoftCELoss
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
@@ -22,5 +23,6 @@
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss'
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss',
'FGDLoss'
]
6 changes: 1 addition & 5 deletions mmrazor/models/losses/cwd.py
@@ -17,11 +17,7 @@ class ChannelWiseDivergence(nn.Module):
loss_weight (float): Weight of loss. Defaults to 1.0.
"""

def __init__(
self,
tau=1.0,
loss_weight=1.0,
):
def __init__(self, tau=1.0, loss_weight=1.0):
super(ChannelWiseDivergence, self).__init__()
self.tau = tau
self.loss_weight = loss_weight