Revert "[Feature] Add connector components and FitNet" #211

Merged
1 commit merged on Jul 28, 2022
48 changes: 0 additions & 48 deletions configs/distill/mmcls/fitnet/README.md

This file was deleted.


Binary file removed docs/en/imgs/model_zoo/fitnet/pipeline.png
Binary file not shown.
@@ -5,7 +5,6 @@
from mmcv.runner import load_checkpoint
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.utils import add_prefix
@@ -19,7 +18,6 @@ class SingleTeacherDistill(BaseAlgorithm):
only use one teacher.

Args:
distiller (dict): The config dict for built distiller.
teacher (dict | BaseModel): The config dict for teacher model or built
teacher model.
teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None.
@@ -28,10 +26,6 @@ class SingleTeacherDistill(BaseAlgorithm):
teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
mode, namely, freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to True.
student_trainable (bool): Whether the student is trainable. Defaults
to True.
calculate_student_loss (bool): Whether to calculate student loss
(original task loss) to update student model. Defaults to True.
"""

def __init__(self,
@@ -40,9 +34,7 @@ def __init__(self,
teacher_ckpt: Optional[str] = None,
teacher_trainable: bool = False,
teacher_norm_eval: bool = True,
student_trainable: bool = True,
calculate_student_loss: bool = True,
**kwargs) -> None:
**kwargs):
super().__init__(**kwargs)

self.distiller = MODELS.build(distiller)
@@ -63,21 +55,13 @@ def __init__(self,
self.teacher_trainable = teacher_trainable
self.teacher_norm_eval = teacher_norm_eval

# The student model will not calculate gradients and update parameters
# in some pretraining process.
self.student_trainable = student_trainable

# The student loss will not be updated into ``losses`` in some
# pretraining process.
self.calculate_student_loss = calculate_student_loss

# In ``ConfigurableDistiller``, the recorder manager is just
# constructed, but not really initialized yet.
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)

@property
def student(self) -> nn.Module:
def student(self):
"""Alias for ``architecture``."""
return self.architecture

@@ -102,25 +86,16 @@ def loss(
else:
with self.distiller.teacher_recorders, self.distiller.deliveries:
with torch.no_grad():

_ = self.teacher(batch_inputs, data_samples, mode='loss')

# If the `override_data` of a delivery is True, the delivery will
# override the origin data with the recorded data.
self.distiller.set_deliveries_override(True)
# Original task loss will not be used during some pretraining process.
if self.calculate_student_loss:
with self.distiller.student_recorders, self.distiller.deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
else:
with self.distiller.student_recorders, self.distiller.deliveries:
if self.student_trainable:
_ = self.student(batch_inputs, data_samples, mode='loss')
else:
with torch.no_grad():
_ = self.student(
batch_inputs, data_samples, mode='loss')
with self.distiller.student_recorders, self.distiller.deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))

# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
@@ -129,7 +104,7 @@ def loss(

return losses

def train(self, mode: bool = True) -> None:
def train(self, mode=True):
"""Set distiller's forward mode."""
super().train(mode)
if mode and self.teacher_norm_eval:
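The hunk is truncated here, but per the ``teacher_norm_eval`` docstring above, the rest of ``train()`` puts the teacher's BatchNorm layers back into eval mode so their running stats stay frozen during distillation. A minimal standalone sketch of that behaviour, assuming a plain ``nn.Module`` (the helper name ``freeze_norm_stats`` is hypothetical, not part of mmrazor):

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

def freeze_norm_stats(model: nn.Module) -> None:
    """Switch every BatchNorm variant to eval mode so running mean/var stay frozen."""
    for m in model.modules():
        if isinstance(m, _BatchNorm):
            m.eval()  # stops running-stat updates; affine params can still receive grads
```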
1 change: 0 additions & 1 deletion mmrazor/models/architectures/__init__.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .connectors import * # noqa: F401,F403
from .dynamic_op import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
5 changes: 0 additions & 5 deletions mmrazor/models/architectures/connectors/__init__.py

This file was deleted.

41 changes: 0 additions & 41 deletions mmrazor/models/architectures/connectors/base_connector.py

This file was deleted.
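The deleted connector files are not shown above. Purely as background, here is an illustrative sketch, not the contents of the removed ``base_connector.py`` (the class name ``ConvConnector`` is hypothetical): a FitNet-style connector is a small regressor, typically a 1x1 convolution, that projects the student's hint-layer features to the teacher's channel count so a hint (MSE) loss can be computed between the two feature maps.

```python
import torch
from torch import nn

class ConvConnector(nn.Module):
    """1x1-conv regressor aligning student feature channels with the teacher's."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, student_feat: torch.Tensor) -> torch.Tensor:
        # Project the student hint features into the teacher's channel space.
        return self.proj(student_feat)
```

A distiller would then compare ``connector(student_feat)`` against the teacher's feature map when computing the distillation loss.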
