Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Feature] Support distillation loss weight scheduler #444

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mmrazor/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import DumpSubnetHook, EstimateResourcesHook
from .hooks import (DumpSubnetHook, EstimateResourcesHook,
LossWeightSchedulerHook)
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
Expand All @@ -12,5 +13,5 @@
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop',
'AutoSlimGreedySearchLoop', 'SubnetValLoop'
'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'LossWeightSchedulerHook'
]
6 changes: 5 additions & 1 deletion mmrazor/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dump_subnet_hook import DumpSubnetHook
from .estimate_resources_hook import EstimateResourcesHook
from .loss_weight_scheduler_hook import LossWeightSchedulerHook
from .visualization_hook import RazorVisualizationHook

__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook']
__all__ = [
'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook',
'LossWeightSchedulerHook'
]
131 changes: 131 additions & 0 deletions mmrazor/engine/hooks/loss_weight_scheduler_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional, Sequence, Union

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import BaseLoop

from mmrazor.models.task_modules import LossWeightScheduler
from mmrazor.registry import HOOKS, PARAM_SCHEDULERS

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class LossWeightSchedulerHook(Hook):

priority = 'LOW'
milestones: list = []

def before_run(self, runner) -> None:

def build_loss_weight_scheduler(scheduler):
if not isinstance(scheduler, Sequence):
schedulers = [scheduler]
else:
schedulers = scheduler

loss_weight_schedulers = []
for scheduler in schedulers:
if isinstance(scheduler, LossWeightScheduler):
loss_weight_schedulers.append(scheduler)
elif isinstance(scheduler, dict):
_scheduler = copy.deepcopy(scheduler)

# Set default end
if isinstance(runner.train_loop, BaseLoop):
default_end = runner.max_epochs if _scheduler.get(
'by_epoch', True) else runner.max_iters
_scheduler.setdefault('end', default_end)
runner.logger.debug(
f'The `end` of {_scheduler["type"]} is not set. '
'Use the max epochs/iters of train loop as '
'default.')

loss_weight_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(
epoch_length=len(runner.train_dataloader))))
else:
raise TypeError(
'scheduler should be a LossWeightScheduler object or '
f'dict, but got {scheduler}')

return loss_weight_schedulers

model = runner.model.module if is_model_wrapper(
runner.model) else runner.model
assert hasattr(model, 'loss_weight_scheduler_manager')

if model.loss_weight_scheduler_manager is None:
# no specific loss weight scheduler
return

schedulers = model.loss_weight_scheduler_manager.schedulers
model.loss_weight_scheduler_manager.schedulers = \
build_loss_weight_scheduler(schedulers)

intervals = []
epoch_length = len(runner.train_dataloader)
for scheduler in model.loss_weight_scheduler_manager.schedulers:
if scheduler.by_epoch:
intervals.append((scheduler.begin * epoch_length,
scheduler.end * epoch_length))
else:
intervals.append((scheduler.begin, scheduler.end))
# 按照begin排序,按照end构建milestone(如果是by_epoch需要转化成iterbased)
# 如果当前iter在milestone里,则需要改base value
intervals = sorted(intervals, key=lambda x: x[0])
for begin, end in intervals:
if not self.milestones:
self.milestones.append(end)
elif end > self.milestones[-1]:
self.milestones.append(end)

def set_loss_weight_multiplier(self, runner, scheduler_manager, by_epoch):
schedulers = scheduler_manager.schedulers
assert isinstance(schedulers, list)
for scheduler in schedulers:
if scheduler.by_epoch == by_epoch:
cur_iter = runner.iter
cur_epoch = runner.epoch
if cur_iter in self.milestones:
# move to the next stage and modify the base value
scheduler_manager.base_value = scheduler_manager.cur_value

base_value = scheduler_manager.base_value
cur_value = scheduler_manager.cur_value
multiplier = scheduler.get_multiplier(
base_value, cur_value, cur_epoch if by_epoch else cur_iter)
if multiplier is not None:
scheduler_manager.cur_value = multiplier
break

def before_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:

model = runner.model.module if is_model_wrapper(
runner.model) else runner.model

if model.loss_weight_scheduler_manager is None:
# no specific loss weight scheduler
return

self.set_loss_weight_multiplier(
runner, model.loss_weight_scheduler_manager, by_epoch=False)

def before_train_epoch(self, runner) -> None:
model = runner.model.module if is_model_wrapper(
runner.model) else runner.model

if model.loss_weight_scheduler_manager is None:
# no specific loss weight scheduler
return

self.set_loss_weight_multiplier(
runner, model.loss_weight_scheduler_manager, by_epoch=True)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.task_modules import LossWeightSchedulerManager
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self,
teacher_norm_eval: bool = True,
student_trainable: bool = True,
calculate_student_loss: bool = True,
loss_weight_schedulers: Optional[List] = None,
**kwargs) -> None:
super().__init__(**kwargs)

Expand Down Expand Up @@ -79,6 +81,13 @@ def __init__(self,
self.distiller.prepare_from_student(self.student)
self.distiller.prepare_from_teacher(self.teacher)

if loss_weight_schedulers is not None:
self.loss_weight_scheduler_manager: \
Optional[LossWeightSchedulerManager] = \
LossWeightSchedulerManager(loss_weight_schedulers)
else:
self.loss_weight_scheduler_manager = None

@property
def student(self) -> nn.Module:
"""Alias for ``architecture``."""
Expand Down Expand Up @@ -128,6 +137,13 @@ def loss(
# Automatically compute distill losses based on `loss_forward_mappings`
# The required data already exists in the recorders.
distill_losses = self.distiller.compute_distill_losses()

if self.loss_weight_scheduler_manager is not None:
# distillation loss weight schedule
for name, value in distill_losses.items():
distill_losses[name] = \
value * self.loss_weight_scheduler_manager.cur_value

losses.update(add_prefix(distill_losses, 'distill'))

return losses
Expand Down
1 change: 1 addition & 0 deletions mmrazor/models/task_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .estimators import ResourceEstimator
from .predictor import * # noqa: F401,F403
from .recorder import * # noqa: F401,F403
from .scheduler import * # noqa: F401,F403
from .tracer import * # noqa: F401,F403

__all__ = ['ResourceEstimator']
12 changes: 12 additions & 0 deletions mmrazor/models/task_modules/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distill_loss_weight_scheduler import (CosineAnnealingLossWeightScheduler,
LinearLossWeightScheduler,
LossWeightScheduler,
LossWeightSchedulerManager,
MultiStepLossWeightScheduler)

__all__ = [
'CosineAnnealingLossWeightScheduler', 'LossWeightScheduler',
'MultiStepLossWeightScheduler', 'LinearLossWeightScheduler',
'LossWeightSchedulerManager'
]
Loading