diff --git a/configs/distill/mmcls/fitnet/README.md b/configs/distill/mmcls/fitnet/README.md
deleted file mode 100644
index 23cfe1d2d..000000000
--- a/configs/distill/mmcls/fitnet/README.md
+++ /dev/null
@@ -1,48 +0,0 @@
-# FitNets
-
-> [FitNets: Hints for Thin Deep Nets](https://arxiv.org/abs/1412.6550)
-
-
-
-## Abstract
-
-While depth tends to improve network performances, it also makes gradient-based
-training more difficult since deeper networks tend to be more non-linear. The recently
-proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute
-models, and it has shown that a student network could imitate the soft output of a larger
-teacher network or ensemble of networks. In this paper, we extend this idea to allow the
-training of a student that is deeper and thinner than the teacher, using not only the outputs
-but also the intermediate representations learned by the teacher as hints to improve the
-training process and final performance of the student. Because the student intermediate hidden
-layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters
-are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This
-allows one to train deeper students that can generalize better or run faster, a trade-off that is
-controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with
-almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network.
-
-![pipeline](/docs/en/imgs/model_zoo/fitnet/pipeline.png)
-
-## Results and models
-
-### Classification
-
-| Location          | Dataset  | Teacher | Student | Acc   | Acc(T) | Acc(S) | Config | Download |
-| :---------------: | :------: | :-----: | :-----: | :---: | :----: | :----: | :----: | :------- |
-| backbone & logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.85 | 76.55 | 69.90 | [config](./fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](<>) \| [log](<>) |
-
-## Citation
-
-```latex
-@inproceedings{DBLP:journals/corr/RomeroBKCGB14,
-  author    = {Adriana Romero and Nicolas Ballas and Samira Ebrahimi Kahou and Antoine Chassang and Carlo Gatta and Yoshua Bengio},
-  editor    = {Yoshua Bengio and Yann LeCun},
-  title     = {FitNets: Hints for Thin Deep Nets},
-  booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
-               San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
-  year      = {2015},
-  url       = {http://arxiv.org/abs/1412.6550},
-  timestamp = {Thu, 25 Jul 2019 14:25:38 +0200},
-  biburl    = {https://dblp.org/rec/journals/corr/RomeroBKCGB14.bib},
-  bibsource = {dblp computer science bibliography, https://dblp.org}
-}
-```
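The abstract in the deleted README is the core of FitNets: the student's intermediate ("guided") layer is regressed onto the teacher's ("hint") layer through a small adapter, and that mismatch is penalized alongside the usual logit distillation. A minimal PyTorch sketch of the hint loss idea (the shapes and the 1x1-conv regressor are illustrative assumptions, not the mmrazor implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative shapes: a thinner student layer vs. a wider teacher "hint" layer.
student_hint = torch.randn(2, 256, 14, 14)
teacher_hint = torch.randn(2, 1024, 14, 14)

# The "additional parameters" mentioned in the abstract: a regressor mapping the
# student hidden layer to the teacher hidden layer's width.
regressor = nn.Conv2d(256, 1024, kernel_size=1)

hint_loss = F.mse_loss(regressor(student_hint), teacher_hint)
hint_loss.backward()  # here only the regressor gets gradients; in a real setup the student does too
```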
diff --git a/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py b/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py
deleted file mode 100644
index b46300e73..000000000
--- a/configs/distill/mmcls/fitnet/fitnet_backbone_logits_resnet50_resnet18_8xb32_in1k.py
+++ /dev/null
@@ -1,71 +0,0 @@
-_base_ = [
-    'mmcls::_base_/datasets/imagenet_bs32.py',
-    'mmcls::_base_/schedules/imagenet_bs256.py',
-    'mmcls::_base_/default_runtime.py'
-]
-
-model = dict(
-    _scope_='mmrazor',
-    type='SingleTeacherDistill',
-    data_preprocessor=dict(
-        type='ImgDataPreprocessor',
-        # RGB format normalization parameters
-        mean=[123.675, 116.28, 103.53],
-        std=[58.395, 57.12, 57.375],
-        # convert image from BGR to RGB
-        bgr_to_rgb=True),
-    architecture=dict(
-        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
-    teacher=dict(
-        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
-    teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
-    distiller=dict(
-        type='ConfigurableDistiller',
-        student_recorders=dict(
-            bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
-            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'),
-            fc=dict(type='ModuleOutputs', source='head.fc')),
-        teacher_recorders=dict(
-            bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.2.relu'),
-            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.5.relu'),
-            fc=dict(type='ModuleOutputs', source='head.fc')),
-        distill_losses=dict(
-            loss_s4=dict(type='L2Loss', loss_weight=10),
-            loss_s3=dict(type='L2Loss', loss_weight=10),
-            loss_kl=dict(
-                type='KLDivergence', tau=6, loss_weight=10, reduction='mean')),
-        connectors=dict(
-            loss_s4_sfeat=dict(
-                type='ConvBNReLUConnector',
-                in_channel=512,
-                out_channel=2048,
-                norm_cfg=dict(type='BN')),
-            loss_s3_sfeat=dict(
-                type='ConvBNReLUConnector',
-                in_channel=256,
-                out_channel=1024,
-                norm_cfg=dict(type='BN'))),
-        loss_forward_mappings=dict(
-            loss_s4=dict(
-                s_feature=dict(
-                    from_student=True,
-                    recorder='bb_s4',
-                    record_idx=1,
-                    connector='loss_s4_sfeat'),
-                t_feature=dict(
-                    from_student=False, recorder='bb_s4', record_idx=2)),
-            loss_s3=dict(
-                s_feature=dict(
-                    from_student=True,
-                    recorder='bb_s3',
-                    record_idx=1,
-                    connector='loss_s3_sfeat'),
-                t_feature=dict(
-                    from_student=False, recorder='bb_s3', record_idx=2)),
-            loss_kl=dict(
-                preds_S=dict(from_student=True, recorder='fc'),
-                preds_T=dict(from_student=False, recorder='fc')))))
-
-find_unused_parameters = True
-
-val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/docs/en/imgs/model_zoo/fitnet/pipeline.png b/docs/en/imgs/model_zoo/fitnet/pipeline.png
deleted file mode 100644
index 19662f882..000000000
Binary files a/docs/en/imgs/model_zoo/fitnet/pipeline.png and /dev/null differ
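In the deleted config above, the recorders point at ReLU modules inside the two backbones, and `record_idx` selects which recorded output to use, since a shared ReLU module fires more than once per forward pass of a residual block. Conceptually, a `ModuleOutputs` recorder behaves like a forward hook; a rough stand-alone sketch (using a torchvision ResNet-18 purely as a stand-in for the mmcls backbone):

```python
import torch
from torchvision.models import resnet18  # illustrative stand-in for the mmcls ResNet-18

model = resnet18()
records = {'bb_s4': []}

# Roughly what a ModuleOutputs recorder with source='backbone.layer4.1.relu' does:
# capture every output of that named module during a forward pass.
dict(model.named_modules())['layer4.1.relu'].register_forward_hook(
    lambda module, inputs, output: records['bb_s4'].append(output))

_ = model(torch.randn(1, 3, 224, 224))
# The ReLU is reused inside the residual block, so several outputs are recorded;
# record_idx in the config picks one of them (the last one is the block output).
print(len(records['bb_s4']), records['bb_s4'][-1].shape)
```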
diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
index 931cbc169..ef3c44246 100644
--- a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
+++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
@@ -5,7 +5,6 @@
 from mmcv.runner import load_checkpoint
 from mmengine import BaseDataElement
 from mmengine.model import BaseModel
-from torch import nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmrazor.models.utils import add_prefix
@@ -19,7 +18,6 @@ class SingleTeacherDistill(BaseAlgorithm):
     only use one teacher.
 
     Args:
-        distiller (dict): The config dict for built distiller.
        teacher (dict | BaseModel): The config dict for teacher model or
            built teacher model.
        teacher_ckpt (str): The path of teacher's checkpoint. Defaults to
            None.
@@ -28,10 +26,6 @@ class SingleTeacherDistill(BaseAlgorithm):
        teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
            mode, namely, freeze running stats (mean and var). Note: Effect on
            Batch Norm and its variants only. Defaults to True.
-        student_trainable (bool): Whether the student is trainable. Defaults
-            to True.
-        calculate_student_loss (bool): Whether to calculate student loss
-            (original task loss) to update student model. Defaults to True.
    """

    def __init__(self,
@@ -40,9 +34,7 @@ def __init__(self,
                 teacher_ckpt: Optional[str] = None,
                 teacher_trainable: bool = False,
                 teacher_norm_eval: bool = True,
-                 student_trainable: bool = True,
-                 calculate_student_loss: bool = True,
-                 **kwargs) -> None:
+                 **kwargs):
        super().__init__(**kwargs)

        self.distiller = MODELS.build(distiller)
@@ -63,21 +55,13 @@ def __init__(self,
        self.teacher_trainable = teacher_trainable
        self.teacher_norm_eval = teacher_norm_eval

-        # The student model will not calculate gradients and update parameters
-        # in some pretraining process.
-        self.student_trainable = student_trainable
-
-        # The student loss will not be updated into ``losses`` in some
-        # pretraining process.
-        self.calculate_student_loss = calculate_student_loss
-
        # In ``ConfigurableDistller``, the recorder manager is just
        # constructed, but not really initialized yet.
        self.distiller.prepare_from_student(self.student)
        self.distiller.prepare_from_teacher(self.teacher)

    @property
-    def student(self) -> nn.Module:
+    def student(self):
        """Alias for ``architecture``."""
        return self.architecture

@@ -102,25 +86,16 @@ def loss(
        else:
            with self.distiller.teacher_recorders, self.distiller.deliveries:
                with torch.no_grad():
+                    _ = self.teacher(batch_inputs, data_samples, mode='loss')

        # If the `override_data` of a delivery is True, the delivery will
        # override the origin data with the recorded data.
        self.distiller.set_deliveries_override(True)
-        # Original task loss will not be used during some pretraining process.
-        if self.calculate_student_loss:
-            with self.distiller.student_recorders, self.distiller.deliveries:
-                student_losses = self.student(
-                    batch_inputs, data_samples, mode='loss')
-            losses.update(add_prefix(student_losses, 'student'))
-        else:
-            with self.distiller.student_recorders, self.distiller.deliveries:
-                if self.student_trainable:
-                    _ = self.student(batch_inputs, data_samples, mode='loss')
-                else:
-                    with torch.no_grad():
-                        _ = self.student(
-                            batch_inputs, data_samples, mode='loss')
+        with self.distiller.student_recorders, self.distiller.deliveries:
+            student_losses = self.student(
+                batch_inputs, data_samples, mode='loss')
+        losses.update(add_prefix(student_losses, 'student'))

        # Automatically compute distill losses based on `loss_forward_mappings`
        # The required data already exists in the recorders.
@@ -129,7 +104,7 @@ def loss(

        return losses

-    def train(self, mode: bool = True) -> None:
+    def train(self, mode=True):
        """Set distiller's forward mode."""
        super().train(mode)
        if mode and self.teacher_norm_eval:
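The `loss()` hunk above keeps the familiar two-pass pattern: the teacher runs under `torch.no_grad()` only to populate its recorders, then the student runs normally and its task loss is added under the `'student'` prefix. A self-contained sketch of that pattern with toy modules (not the mmrazor classes; the temperature 6 mirrors the config's `tau`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

teacher, student = nn.Linear(16, 10), nn.Linear(16, 10)
inputs, labels = torch.randn(8, 16), torch.randint(0, 10, (8,))

with torch.no_grad():                      # teacher forward: record outputs, no gradients
    teacher_logits = teacher(inputs)

student_logits = student(inputs)           # student forward: task loss plus recorded outputs
losses = {'student.loss': F.cross_entropy(student_logits, labels)}
losses['distill.loss_kl'] = F.kl_div(      # distill loss computed from the two records
    F.log_softmax(student_logits / 6, dim=1),
    F.softmax(teacher_logits / 6, dim=1),
    reduction='batchmean')
```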
diff --git a/mmrazor/models/architectures/__init__.py b/mmrazor/models/architectures/__init__.py
index 317e1fde7..f267930f3 100644
--- a/mmrazor/models/architectures/__init__.py
+++ b/mmrazor/models/architectures/__init__.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .backbones import *  # noqa: F401,F403
-from .connectors import *  # noqa: F401,F403
 from .dynamic_op import *  # noqa: F401,F403
 from .heads import *  # noqa: F401,F403
diff --git a/mmrazor/models/architectures/connectors/__init__.py b/mmrazor/models/architectures/connectors/__init__.py
deleted file mode 100644
index 28673e8ee..000000000
--- a/mmrazor/models/architectures/connectors/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .general_connector import (ConvBNConnector, ConvBNReLUConnector,
-                                SingleConvConnector)
-
-__all__ = ['ConvBNConnector', 'ConvBNReLUConnector', 'SingleConvConnector']
diff --git a/mmrazor/models/architectures/connectors/base_connector.py b/mmrazor/models/architectures/connectors/base_connector.py
deleted file mode 100644
index 4322efa86..000000000
--- a/mmrazor/models/architectures/connectors/base_connector.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from abc import ABCMeta, abstractmethod
-from typing import Dict, Optional
-
-import torch
-from mmcv.runner import BaseModule
-
-
-class BaseConnector(BaseModule, metaclass=ABCMeta):
-    """Base class of connectors.
-
-    Connector is mainly used for distillation, it usually converts the channel
-    number of input feature to align features of student and teacher.
-
-    All subclasses should implement the following APIs:
-
-    - ``forward_train()``
-
-    Args:
-        init_cfg (dict, optional): The config to control the initialization.
-    """
-
-    def __init__(self, init_cfg: Optional[Dict] = None) -> None:
-        super().__init__(init_cfg=init_cfg)
-
-    def forward(self, feature: torch.Tensor) -> None:
-        """Forward computation.
-
-        Args:
-            feature (torch.Tensor): Input feature.
-        """
-        return self.forward_train(feature)
-
-    @abstractmethod
-    def forward_train(self, feature) -> torch.Tensor:
-        """Abstract train computation.
-
-        Args:
-            feature (torch.Tensor): Input feature.
-        """
-        pass
diff --git a/mmrazor/models/architectures/connectors/general_connector.py b/mmrazor/models/architectures/connectors/general_connector.py
deleted file mode 100644
index 156468cb8..000000000
--- a/mmrazor/models/architectures/connectors/general_connector.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Optional
-
-import numpy as np
-import torch
-import torch.nn as nn
-from mmcv.cnn import build_conv_layer, build_norm_layer
-
-from mmrazor.registry import MODELS
-from .base_connector import BaseConnector
-
-
-@MODELS.register_module()
-class SingleConvConnector(BaseConnector):
-    """General connector which only contains a conv layer.
-
-    Args:
-        in_channel (int): The input channel of the connector.
-        out_channel (int): The output channel of the connector.
-        conv_cfg (dict, optional): The config to control the convolution.
-        init_cfg (dict, optional): The config to control the initialization.
-    """
-
-    def __init__(
-        self,
-        in_channel: int,
-        out_channel: int,
-        conv_cfg: Optional[Dict] = None,
-        init_cfg: Optional[Dict] = None,
-    ) -> None:
-        super().__init__(init_cfg)
-        self.conv = build_conv_layer(
-            conv_cfg, in_channel, out_channel, kernel_size=1, stride=1)
-
-    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
-        """Forward computation.
-
-        Args:
-            feature (torch.Tensor): Input feature.
-        """
-        return self.conv(feature)
-
-    def init_weights(self) -> None:
-        """Init parameters.
-
-        In the subclass of ``BaseModule``, `init_weights` will be called
-        automatically.
- """ - with torch.no_grad(): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - device = m.weight.device - in_channels, _, k1, k2 = m.weight.shape - m.weight[:] = torch.randn( - m.weight.shape, device=device) / np.sqrt( - k1 * k2 * in_channels) * 1e-4 - if hasattr(m, 'bias') and m.bias is not None: - nn.init.zeros_(m.bias) - else: - continue - - -@MODELS.register_module() -class ConvBNConnector(BaseConnector): - """General connector which contains a conv layer with BN. - - Args: - in_channel (int): The input channels of the connector. - out_channel (int): The output channels of the connector. - norm_cfg (dict): The config to control the normalization. - conv_cfg (dict, optional): The config to control the convolution. - init_cfg (dict, optional): The config to control the initialization. - """ - - def __init__( - self, - in_channel: int, - out_channel: int, - norm_cfg: Dict, - conv_cfg: Optional[Dict] = None, - init_cfg: Optional[Dict] = None, - ) -> None: - super().__init__(init_cfg) - self.conv = build_conv_layer( - conv_cfg, - in_channel, - out_channel, - kernel_size=1, - stride=1, - padding=0, - bias=False) - _, self.bn = build_norm_layer(norm_cfg, out_channel) - - def forward_train(self, feature: torch.Tensor) -> torch.Tensor: - """Forward computation. - - Args: - feature (torch.Tensor): Input feature. - """ - return self.bn(self.conv(feature)) - - -@MODELS.register_module() -class ConvBNReLUConnector(BaseConnector): - """General connector which contains a conv layer with BN and ReLU. - - Args: - in_channel (int): The input channels of the connector. - out_channel (int): The output channels of the connector. - norm_cfg (dict): The config to control the normalization. - conv_cfg (dict, optional): The config to control the convolution. - init_cfg (dict, optional): The config to control the initialization. - """ - - def __init__( - self, - in_channel: int, - out_channel: int, - norm_cfg: Dict, - conv_cfg: Optional[Dict] = None, - init_cfg: Optional[Dict] = None, - ) -> None: - super().__init__(init_cfg) - self.conv = build_conv_layer( - conv_cfg, in_channel, out_channel, kernel_size=1) - _, self.bn = build_norm_layer(norm_cfg, out_channel) - self.relu = nn.ReLU(inplace=True) - - def forward_train(self, feature: torch.Tensor) -> torch.Tensor: - """Forward computation. - - Args: - feature (torch.Tensor): Input feature. - """ - return self.relu(self.bn(self.conv(feature))) diff --git a/mmrazor/models/distillers/base_distiller.py b/mmrazor/models/distillers/base_distiller.py index 4cf575e90..317d033f8 100644 --- a/mmrazor/models/distillers/base_distiller.py +++ b/mmrazor/models/distillers/base_distiller.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod -from typing import Dict, Optional from mmengine.model import BaseModule @@ -8,13 +7,9 @@ class BaseDistiller(BaseModule, ABC): - """Base class for distiller. + """Base class for distiller.""" - Args: - init_cfg (dict, optional): Config for distiller. Default to None. 
- """ - - def __init__(self, init_cfg: Optional[Dict] = None) -> None: + def __init__(self, init_cfg=None): super().__init__(init_cfg) @abstractmethod diff --git a/mmrazor/models/distillers/configurable_distiller.py b/mmrazor/models/distillers/configurable_distiller.py index 621dfa624..a794417ed 100644 --- a/mmrazor/models/distillers/configurable_distiller.py +++ b/mmrazor/models/distillers/configurable_distiller.py @@ -38,9 +38,6 @@ class ConfigurableDistiller(BaseDistiller): distill_deliveries (dict, optional): Config for multiple deliveries. A distill algorithm may have more than one delivery. Defaults to None. - connectors (dict, optional): Config for multiple connectors. A - distillation model may have more than one connector. Defaults to - None. distill_losses: (Dict[str, Dict], optional): Config for multiple distill losses. A distill algorithm may have more than one distill loss. Defaults to None. @@ -67,45 +64,33 @@ class ConfigurableDistiller(BaseDistiller): `student_recorders``; otherwise, it means the recorder is in ``teacher_recorders``. - A connector can be called according to its `connector_name`, so that a - input can use a different connector in different loss. - Examples: >>> distill_losses = dict( - ... loss_neck=dict(type='L2Loss', loss_weight=5)) + ... loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)) >>> student_recorders = dict( - ... feat = dict(type='ModuleOutputs', sources=['neck.gap'])) + ... fc = dict(type='ModuleOutputs', sources=['head.fc'])) >>> teacher_recorders = dict( - ... feat = dict(type='ModuleOutputs', sources=['neck.gap'])) - - >>> connectors = dict( - ... loss_neck_sfeat = dict( - ... type='SingleConvConnector', in_channel=32, out_channel=64), - ... loss_neck_tfeat = dict( - ... type='SingleConvConnector', in_channel=32, out_channel=64)) + ... fc = dict(type='ModuleOutputs', sources=['head.fc'])) >>> loss_forward_mappings = dict( - ... loss_neck=dict( - ... s_feature=dict(from_recorder='feat', from_student=True, - ... connector='loss_neck_sfeat'), - ... t_feature=dict(from_recorder='feat', from_student=False, - ... connector='loss_neck_tfeat'))) + ... loss_kl=dict( + ... preds_S=dict(from_recorder='fc', from_student=True), + ... preds_T=dict(from_recorder='fc', from_student=False))) """ def __init__(self, student_recorders: Optional[Dict[str, Dict]] = None, teacher_recorders: Optional[Dict[str, Dict]] = None, distill_deliveries: Optional[Dict[str, Dict]] = None, - connectors: Optional[Dict[str, Dict]] = None, distill_losses: Optional[Dict[str, Dict]] = None, loss_forward_mappings: Optional[Dict[str, Dict]] = None, **kwargs): super().__init__(**kwargs) # The recorder manager is just constructed, but not really initialized # yet. Recorder manager initialization needs to input the corresponding - # model. + # model. self.student_recorders = RecorderManager(student_recorders) self.teacher_recorders = RecorderManager(teacher_recorders) @@ -113,10 +98,8 @@ def __init__(self, self.distill_losses = self.build_distill_losses(distill_losses) - self.connectors = self.build_connectors(connectors) - if loss_forward_mappings: - # Check if loss_forward_mappings is in the correct format. 
+            # Check if loss_forward_mappings is in the correct format
             self._check_loss_forward_mappings(self.distill_losses,
                                               loss_forward_mappings,
                                               self.student_recorders,
@@ -125,7 +108,7 @@ def __init__(self,
         else:
             self.loss_forward_mappings = dict()
 
-    def set_deliveries_override(self, override: bool) -> None:
+    def set_deliveries_override(self, override: bool):
         """Set the `override_data` of all deliveries."""
         self.deliveries.override_data = override
 
@@ -137,23 +120,6 @@
     def prepare_from_teacher(self, model: nn.Module) -> None:
         """Initialize teacher recorders."""
         self.teacher_recorders.initialize(model)
 
-    def build_connectors(
-        self,
-        connectors: Optional[Dict[str, Dict]] = None,
-    ) -> nn.ModuleDict:
-        """Initialize connectors."""
-
-        distill_connecotrs = nn.ModuleDict()
-        if connectors:
-            for connector_name, connector_cfg in connectors.items():
-                assert connector_name not in distill_connecotrs, \
-                    f'{connector_name} is already in "distill_connecotrs".'
-
-                connector = MODELS.build(connector_cfg)
-                distill_connecotrs[connector_name] = connector
-
-        return distill_connecotrs
-
     def build_distill_losses(
         self,
         losses: Optional[Dict[str, Dict]] = None,
     ) -> nn.ModuleDict:
@@ -182,8 +148,7 @@ def get_record(self,
                    recorder: str,
                    from_student: bool,
                    record_idx: int = 0,
-                   data_idx: Optional[int] = None,
-                   connector: Optional[str] = None) -> List:
+                   data_idx: Optional[int] = None) -> List:
         """According to each item in ``record_infos``, get the corresponding
         record in ``recorder_manager``."""
 
@@ -191,12 +156,8 @@ def get_record(self,
             recorder_ = self.student_recorders.get_recorder(recorder)
         else:
             recorder_ = self.teacher_recorders.get_recorder(recorder)
-        record_data = recorder_.get_record_data(record_idx, data_idx)
-
-        if connector:
-            record_data = self.connectors[connector](record_data)
-
-        return record_data
+        return recorder_.get_record_data(record_idx, data_idx)
 
     def compute_distill_losses(self) -> LossResults:
         """Compute distill losses automatically."""
@@ -204,8 +165,8 @@ def compute_distill_losses(self) -> LossResults:
         losses = dict()
         for loss_name, forward_mappings in self.loss_forward_mappings.items():
             forward_kwargs = dict()
-            for forward_key, record in forward_mappings.items():
-                forward_var = self.get_record(**record)
+            for forward_key, record_info in forward_mappings.items():
+                forward_var = self.get_record(**record_info)
                 forward_kwargs[forward_key] = forward_var
 
             loss_module = self.distill_losses[loss_name]
@@ -272,8 +233,3 @@ def _check_loss_forward_mappings(
                     assert recorder in teacher_recorders.recorders, \
                         f'For {forward_key}, "{recorder}" must be in \
                         `teacher_recorders`.'
-
-                if 'connector' in record_info:
-                    connector: str = record_info['connector']
-                    assert connector in self.connectors, \
-                        f'{connector} must be in "connectors".'
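For reference, the `get_record` branch removed above is what let a `loss_forward_mappings` entry name a connector: the recorded student tensor was passed through that connector before reaching the loss. A stripped-down sketch of that behaviour (hypothetical stand-in names, not the mmrazor API):

```python
import torch
import torch.nn as nn

records = {'bb_s4': torch.randn(2, 512, 7, 7)}                  # pretend recorder output
connectors = nn.ModuleDict(
    {'loss_s4_sfeat': nn.Conv2d(512, 2048, kernel_size=1)})    # pretend connector

def get_record(recorder: str, connector: str = None) -> torch.Tensor:
    record_data = records[recorder]
    if connector:  # mirrors the branch deleted in the hunk above
        record_data = connectors[connector](record_data)
    return record_data

s_feature = get_record('bb_s4', connector='loss_s4_sfeat')
print(s_feature.shape)  # torch.Size([2, 2048, 7, 7])
```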
diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py
index db0bd7af6..3d3e97e52 100644
--- a/mmrazor/models/losses/__init__.py
+++ b/mmrazor/models/losses/__init__.py
@@ -1,11 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cwd import ChannelWiseDivergence
 from .kl_divergence import KLDivergence
-from .l2_loss import L2Loss
 from .relational_kd import AngleWiseRKD, DistanceWiseRKD
 from .weighted_soft_label_distillation import WSLD
 
 __all__ = [
     'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
-    'WSLD', 'L2Loss'
+    'WSLD'
 ]
diff --git a/mmrazor/models/losses/l2_loss.py b/mmrazor/models/losses/l2_loss.py
deleted file mode 100644
index 8b373ed38..000000000
--- a/mmrazor/models/losses/l2_loss.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.nn as nn
-
-from mmrazor.registry import MODELS
-
-
-@MODELS.register_module()
-class L2Loss(nn.Module):
-    """Calculate the two-norm loss between the two features.
-
-    Args:
-        loss_weight (float): Weight of loss. Defaults to 1.0.
-        normalize (bool): Whether to normalize the feature. Defaults to True.
-        mult (float): Multiplier for feature normalization. Defaults to 1.0.
-        div_element (bool): Whether to divide the loss by the number of
-            feature elements. Defaults to False.
-    """
-
-    def __init__(
-        self,
-        loss_weight: float = 1.0,
-        normalize: bool = True,
-        mult: float = 1.0,
-        div_element: bool = False,
-    ) -> None:
-        super().__init__()
-        self.loss_weight = loss_weight
-        self.normalize = normalize
-        self.mult = mult
-        self.div_element = div_element
-
-    def forward(
-        self,
-        s_feature: torch.Tensor,
-        t_feature: torch.Tensor,
-    ) -> torch.Tensor:
-        """Forward computation.
-
-        Args:
-            s_feature (torch.Tensor): The student model feature with
-                shape (N, C, H, W) or shape (N, C).
-            t_feature (torch.Tensor): The teacher model feature with
-                shape (N, C, H, W) or shape (N, C).
-        """
-        if self.normalize:
-            s_feature = self.normalize_feature(s_feature)
-            t_feature = self.normalize_feature(t_feature)
-
-        loss = torch.sum(torch.pow(torch.sub(s_feature, t_feature), 2))
-
-        if self.div_element:
-            loss = loss / s_feature.numel()
-        else:
-            loss = loss / s_feature.size(0)
-
-        return self.loss_weight * loss
-
-    def normalize_feature(self, feature: torch.Tensor) -> torch.Tensor:
-        """Normalize the input feature.
-
-        Args:
-            feature (torch.Tensor): The student model feature with
-                shape (N, C, H, W) or shape (N, C).
-        """
-        feature = feature.view(feature.size(0), -1)
-        return feature / feature.norm(2, dim=1, keepdim=True) * self.mult
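With the settings used in the fitnet config (`normalize=True`, `div_element=False`, `loss_weight=10`), the deleted `L2Loss` comes down to the plain-torch computation below, written out as a sketch for clarity:

```python
import torch

def l2_distill_loss(s_feature: torch.Tensor, t_feature: torch.Tensor,
                    loss_weight: float = 10., mult: float = 1.) -> torch.Tensor:
    # Flatten to (N, -1) and L2-normalize each sample, as normalize_feature() does.
    s = s_feature.flatten(1)
    s = s / s.norm(2, dim=1, keepdim=True) * mult
    t = t_feature.flatten(1)
    t = t / t.norm(2, dim=1, keepdim=True) * mult
    # Sum of squared differences, averaged over the batch (div_element=False).
    return loss_weight * torch.sum((s - t) ** 2) / s_feature.size(0)

loss = l2_distill_loss(torch.randn(5, 2, 3, 3), torch.randn(5, 2, 3, 3))
```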
diff --git a/tests/test_models/test_connectors/test_connectors.py b/tests/test_models/test_connectors/test_connectors.py
deleted file mode 100644
index d04e14854..000000000
--- a/tests/test_models/test_connectors/test_connectors.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
-
-import torch
-
-from mmrazor.models import (ConvBNConnector, ConvBNReLUConnector,
-                            SingleConvConnector)
-
-
-class TestConnector(TestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        cls.s_feat = torch.randn(1, 1, 5, 5)
-        cls.t_feat = torch.randn(1, 3, 5, 5)
-
-    def test_singleconv_connector(self):
-        singleconv_connector_cfg = dict(in_channel=1, out_channel=3)
-        singleconv_connector = SingleConvConnector(**singleconv_connector_cfg)
-
-        output = singleconv_connector.forward_train(self.s_feat)
-        assert output.size() == self.t_feat.size()
-
-    def test_bn_connector(self):
-        bn_connector_cfg = dict(
-            in_channel=1, out_channel=3, norm_cfg=dict(type='BN'))
-        bn_connector = ConvBNConnector(**bn_connector_cfg)
-
-        output = bn_connector.forward_train(self.s_feat)
-        assert output.size() == self.t_feat.size()
-
-    def test_relu_connector(self):
-        relu_connector_cfg = dict(
-            in_channel=1, out_channel=3, norm_cfg=dict(type='BN'))
-        relu_connector = ConvBNReLUConnector(**relu_connector_cfg)
-
-        output = relu_connector.forward_train(self.s_feat)
-        assert output.size() == self.t_feat.size()
diff --git a/tests/test_models/test_losses/test_general_losses.py b/tests/test_models/test_losses/test_general_losses.py
deleted file mode 100644
index 70b4c2c75..000000000
--- a/tests/test_models/test_losses/test_general_losses.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
-
-import torch
-
-from mmrazor.models import L2Loss
-
-
-class TestLosses(TestCase):
-
-    @classmethod
-    def setUpClass(cls):
-        cls.feats_1d = torch.randn(5, 6)
-        cls.feats_3d = torch.randn(5, 2, 3, 3)
-
-    def normal_test_1d(self, loss_instance):
-        loss_1d = loss_instance.forward(self.feats_1d, self.feats_1d)
-        self.assertTrue(loss_1d.numel() == 1)
-
-    def normal_test_3d(self, loss_instance):
-        loss_3d = loss_instance.forward(self.feats_3d, self.feats_3d)
-        self.assertTrue(loss_3d.numel() == 1)
-
-    def test_l2_loss(self):
-        l2_loss_cfg = dict(loss_weight=10)
-        l2_loss = L2Loss(**l2_loss_cfg)
-        self.normal_test_1d(l2_loss)
-        self.normal_test_3d(l2_loss)