From 1a7863dabc942e61f92957a4e7074cbb4e4da84c Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Wed, 1 Feb 2023 11:29:30 +0800 Subject: [PATCH 1/2] WIP: support loss weight scheduler --- mmrazor/engine/__init__.py | 5 +- mmrazor/engine/hooks/__init__.py | 6 +- .../hooks/loss_weight_scheduler_hook.py | 131 +++++++++++ .../configurable/single_teacher_distill.py | 16 ++ mmrazor/models/task_modules/__init__.py | 1 + .../models/task_modules/scheduler/__init__.py | 12 + .../distill_loss_weight_scheduler.py | 208 ++++++++++++++++++ .../test_loss_weight_scheduler_hook.py | 108 +++++++++ .../test_loss_weight_scheduler.py | 205 +++++++++++++++++ 9 files changed, 689 insertions(+), 3 deletions(-) create mode 100644 mmrazor/engine/hooks/loss_weight_scheduler_hook.py create mode 100644 mmrazor/models/task_modules/scheduler/__init__.py create mode 100644 mmrazor/models/task_modules/scheduler/distill_loss_weight_scheduler.py create mode 100644 tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py create mode 100644 tests/test_models/test_task_modules/test_scheduler/test_loss_weight_scheduler.py diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index 7435fa822..56a25801d 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .hooks import DumpSubnetHook, EstimateResourcesHook +from .hooks import (DumpSubnetHook, EstimateResourcesHook, + LossWeightSchedulerHook) from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, @@ -12,5 +13,5 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop' + 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'LossWeightSchedulerHook' ] diff --git a/mmrazor/engine/hooks/__init__.py b/mmrazor/engine/hooks/__init__.py index d25c7c993..e270e497a 100644 --- a/mmrazor/engine/hooks/__init__.py +++ b/mmrazor/engine/hooks/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dump_subnet_hook import DumpSubnetHook from .estimate_resources_hook import EstimateResourcesHook +from .loss_weight_scheduler_hook import LossWeightSchedulerHook from .visualization_hook import RazorVisualizationHook -__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook'] +__all__ = [ + 'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook', + 'LossWeightSchedulerHook' +] diff --git a/mmrazor/engine/hooks/loss_weight_scheduler_hook.py b/mmrazor/engine/hooks/loss_weight_scheduler_hook.py new file mode 100644 index 000000000..5ebff814b --- /dev/null +++ b/mmrazor/engine/hooks/loss_weight_scheduler_hook.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy
+from typing import Optional, Sequence, Union
+
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+from mmengine.runner import BaseLoop
+
+from mmrazor.models.task_modules import LossWeightScheduler
+from mmrazor.registry import HOOKS, PARAM_SCHEDULERS
+
+DATA_BATCH = Optional[Union[dict, tuple, list]]
+
+
+@HOOKS.register_module()
+class LossWeightSchedulerHook(Hook):
+    """Hook that drives the loss weight schedulers of a distillation
+    algorithm and updates the current loss weight multiplier before each
+    train epoch and each train iter."""
+
+    priority = 'LOW'
+
+    def __init__(self) -> None:
+        self.milestones: list = []
+
+    def before_run(self, runner) -> None:
+        """Build the loss weight schedulers from configs and compute the
+        milestones at which the base value has to be updated."""
+
+        def build_loss_weight_scheduler(scheduler):
+            if not isinstance(scheduler, Sequence):
+                schedulers = [scheduler]
+            else:
+                schedulers = scheduler
+
+            loss_weight_schedulers = []
+            for scheduler in schedulers:
+                if isinstance(scheduler, LossWeightScheduler):
+                    loss_weight_schedulers.append(scheduler)
+                elif isinstance(scheduler, dict):
+                    _scheduler = copy.deepcopy(scheduler)
+
+                    # Set default end
+                    if isinstance(runner.train_loop, BaseLoop):
+                        default_end = runner.max_epochs if _scheduler.get(
+                            'by_epoch', True) else runner.max_iters
+                        _scheduler.setdefault('end', default_end)
+                        runner.logger.debug(
+                            f'The `end` of {_scheduler["type"]} is not set. '
+                            'Use the max epochs/iters of train loop as '
+                            'default.')
+
+                    loss_weight_schedulers.append(
+                        PARAM_SCHEDULERS.build(
+                            _scheduler,
+                            default_args=dict(
+                                epoch_length=len(runner.train_dataloader))))
+                else:
+                    raise TypeError(
+                        'scheduler should be a LossWeightScheduler object or '
+                        f'dict, but got {scheduler}')
+
+            return loss_weight_schedulers
+
+        model = runner.model.module if is_model_wrapper(
+            runner.model) else runner.model
+        assert hasattr(model, 'loss_weight_scheduler_manager')
+
+        if model.loss_weight_scheduler_manager is None:
+            # no specific loss weight scheduler
+            return
+
+        schedulers = model.loss_weight_scheduler_manager.schedulers
+        model.loss_weight_scheduler_manager.schedulers = \
+            build_loss_weight_scheduler(schedulers)
+
+        intervals = []
+        epoch_length = len(runner.train_dataloader)
+        for scheduler in model.loss_weight_scheduler_manager.schedulers:
+            if scheduler.by_epoch:
+                intervals.append((scheduler.begin * epoch_length,
+                                  scheduler.end * epoch_length))
+            else:
+                intervals.append((scheduler.begin, scheduler.end))
+        # Sort the intervals by `begin` and build the milestones from their
+        # `end` (epoch-based intervals have already been converted to
+        # iter-based above). Whenever the current iter hits a milestone,
+        # the base value has to be updated.
+        intervals = sorted(intervals, key=lambda x: x[0])
+        for begin, end in intervals:
+            if not self.milestones:
+                self.milestones.append(end)
+            elif end > self.milestones[-1]:
+                self.milestones.append(end)
+
+    def set_loss_weight_multiplier(self, runner, scheduler_manager, by_epoch):
+        """Update the current multiplier with the first scheduler whose
+        interval contains the current step."""
+        schedulers = scheduler_manager.schedulers
+        assert isinstance(schedulers, list)
+        for scheduler in schedulers:
+            if scheduler.by_epoch == by_epoch:
+                cur_iter = runner.iter
+                cur_epoch = runner.epoch
+                if cur_iter in self.milestones:
+                    # move to the next stage and modify the base value
+                    scheduler_manager.base_value = scheduler_manager.cur_value
+
+                base_value = scheduler_manager.base_value
+                cur_value = scheduler_manager.cur_value
+                multiplier = scheduler.get_multiplier(
+                    base_value, cur_value, cur_epoch if by_epoch else cur_iter)
+                if multiplier is not None:
+                    scheduler_manager.cur_value = multiplier
+                    break
+
+    def before_train_iter(self,
+                          runner,
+                          batch_idx: int,
+                          data_batch: DATA_BATCH = None,
+                          outputs: Optional[dict] = None) -> None:
+
+        model = runner.model.module if is_model_wrapper(
+            runner.model) else runner.model
+
+        if model.loss_weight_scheduler_manager is None:
+            # no specific loss weight scheduler
+            return
+
+        self.set_loss_weight_multiplier(
+            runner, model.loss_weight_scheduler_manager, by_epoch=False)
+
+    def before_train_epoch(self, runner) -> None:
+        model = runner.model.module if is_model_wrapper(
+            runner.model) else runner.model
+
+        if model.loss_weight_scheduler_manager is None:
+            # no specific loss weight scheduler
+            return
+
+        self.set_loss_weight_multiplier(
+            runner, model.loss_weight_scheduler_manager, by_epoch=True)
diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
index 44a8a3438..7be7a4bb4 100644
--- a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
+++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py
@@ -8,6 +8,7 @@
 from torch import nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
+from mmrazor.models.task_modules import LossWeightSchedulerManager
 from mmrazor.models.utils import add_prefix
 from mmrazor.registry import MODELS
 from ...base import BaseAlgorithm, LossResults
@@ -42,6 +43,7 @@ def __init__(self,
                  teacher_norm_eval: bool = True,
                  student_trainable: bool = True,
                  calculate_student_loss: bool = True,
+                 loss_weight_schedulers: Optional[List] = None,
                  **kwargs) -> None:
         super().__init__(**kwargs)
 
@@ -79,6 +81,13 @@ def __init__(self,
         self.distiller.prepare_from_student(self.student)
         self.distiller.prepare_from_teacher(self.teacher)
 
+        if loss_weight_schedulers is not None:
+            self.loss_weight_scheduler_manager: \
+                Optional[LossWeightSchedulerManager] = \
+                LossWeightSchedulerManager(loss_weight_schedulers)
+        else:
+            self.loss_weight_scheduler_manager = None
+
     @property
     def student(self) -> nn.Module:
         """Alias for ``architecture``."""
@@ -128,6 +137,13 @@ def loss(
         # Automatically compute distill losses based on `loss_forward_mappings`
         # The required data already exists in the recorders.
         distill_losses = self.distiller.compute_distill_losses()
+
+        if self.loss_weight_scheduler_manager is not None:
+            # scale each distillation loss by the scheduled loss weight
+            for name, value in distill_losses.items():
+                distill_losses[name] = \
+                    value * self.loss_weight_scheduler_manager.cur_value
+
         losses.update(add_prefix(distill_losses, 'distill'))
 
         return losses
diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py
index 931278b8a..409da5396 100644
--- a/mmrazor/models/task_modules/__init__.py
+++ b/mmrazor/models/task_modules/__init__.py
@@ -4,6 +4,7 @@
 from .estimators import ResourceEstimator
 from .predictor import *  # noqa: F401,F403
 from .recorder import *  # noqa: F401,F403
+from .scheduler import *  # noqa: F401,F403
 from .tracer import *  # noqa: F401,F403
 
 __all__ = ['ResourceEstimator']
diff --git a/mmrazor/models/task_modules/scheduler/__init__.py b/mmrazor/models/task_modules/scheduler/__init__.py
new file mode 100644
index 000000000..2526a2e0d
--- /dev/null
+++ b/mmrazor/models/task_modules/scheduler/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .distill_loss_weight_scheduler import (CosineAnnealingLossWeightScheduler,
+                                            LinearLossWeightScheduler,
+                                            LossWeightScheduler,
+                                            LossWeightSchedulerManager,
+                                            MultiStepLossWeightScheduler)
+
+__all__ = [
+    'CosineAnnealingLossWeightScheduler', 'LossWeightScheduler',
+    'MultiStepLossWeightScheduler', 'LinearLossWeightScheduler',
+    'LossWeightSchedulerManager'
+]
diff --git a/mmrazor/models/task_modules/scheduler/distill_loss_weight_scheduler.py b/mmrazor/models/task_modules/scheduler/distill_loss_weight_scheduler.py
new file mode 100644
index 000000000..49adf574e
--- /dev/null
+++ b/mmrazor/models/task_modules/scheduler/distill_loss_weight_scheduler.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from collections import Counter
+from typing import List, Optional
+
+from mmrazor.registry import PARAM_SCHEDULERS
+
+INF = int(1e9)
+
+
+class LossWeightScheduler:
+    """Base class for loss weight schedulers.
+
+    Args:
+        begin (int): Step at which the scheduler starts to take effect.
+            Defaults to 0.
+        end (int): Step at which the scheduler stops taking effect.
+            Defaults to INF.
+        by_epoch (bool): Whether ``begin`` and ``end`` are counted in
+            epochs. Defaults to True.
+    """
+
+    def __init__(self, begin: int = 0, end: int = INF, by_epoch: bool = True):
+
+        if end <= begin:
+            raise ValueError('end should be larger than begin, but got'
+                             ' begin={}, end={}'.format(begin, end))
+        self.begin = begin
+        self.end = end
+        self.by_epoch = by_epoch
+
+    def _get_multiplier(self, base_value, cur_value, cur_step):
+        raise NotImplementedError
+
+    def get_multiplier(self, base_value, cur_value, cur_step):
+        """Compute the multiplier using the chainable form of the scheduler.
+
+        Return None if ``cur_step`` is outside of ``[begin, end)``."""
+        if not self.begin <= cur_step < self.end:
+            return None
+        return self._get_multiplier(base_value, cur_value, cur_step)
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineAnnealingLossWeightScheduler(LossWeightScheduler):
+    """Anneal the loss weight multiplier from the base value to ``eta_min``
+    (or ``base_value * eta_min_ratio``) following a cosine curve."""
+
+    def __init__(self,
+                 eta_min: Optional[float] = None,
+                 begin: int = 0,
+                 end: int = INF,
+                 by_epoch: bool = True,
+                 eta_min_ratio: Optional[float] = None):
+        if eta_min is None and eta_min_ratio is None:
+            eta_min = 0.
+        assert (eta_min is None) ^ (eta_min_ratio is None), \
+            'Either `eta_min` or `eta_min_ratio` should be specified'
+        self.eta_min = eta_min
+        self.eta_min_ratio = eta_min_ratio
+        self.T_max = end - begin
+        super().__init__(begin, end, by_epoch)
+
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        begin = int(begin * epoch_length)
+        if end != INF:
+            end = int(end * epoch_length)
+        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
+
+    def _get_multiplier(self, base_value, cur_value, cur_iter):
+
+        def _get_eta_min():
+            if self.eta_min_ratio is None:
+                return self.eta_min
+            return base_value * self.eta_min_ratio
+
+        eta_min = _get_eta_min()
+        return eta_min + 0.5 * (base_value - eta_min) * (
+            1 + math.cos(math.pi * (cur_iter - self.begin) / self.T_max))
+
+
+@PARAM_SCHEDULERS.register_module()
+class MultiStepLossWeightScheduler(LossWeightScheduler):
+    """Decay the loss weight multiplier by ``gamma`` once the number of
+    steps reaches one of the ``milestones``."""
+
+    def __init__(self,
+                 milestones: List[int],
+                 gamma: float = 0.1,
+                 begin: int = 0,
+                 end: int = INF,
+                 by_epoch: bool = True):
+        super().__init__(begin, end, by_epoch)
+        # `milestones` are counted from `begin`, so shift them accordingly.
+        milestones = [value + self.begin for value in milestones]
+        self.milestones = Counter(milestones)
+        self.gamma = gamma
+
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              milestones,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        milestones = [i * epoch_length for i in milestones]
+        begin = int(begin * epoch_length)
+        if end != INF:
+            end = int(end * epoch_length)
+        return cls(
+            *args,
+            milestones=milestones,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
+    def _get_multiplier(self, base_value, cur_value, cur_iter):
+        if cur_iter not in self.milestones:
+            return cur_value
+        return cur_value * self.gamma**self.milestones[cur_iter]
+
+
+@PARAM_SCHEDULERS.register_module()
+class LinearLossWeightScheduler(LossWeightScheduler):
+    """Change the loss weight multiplier linearly from
+    ``base_value * start_factor`` to ``base_value * end_factor``."""
+
+    def __init__(self,
+                 start_factor: float = 1.0 / 3,
+                 end_factor: float = 1.0,
+                 begin: int = 0,
+                 end: int = INF,
+                 by_epoch: bool = True):
+        if start_factor > 1.0 or start_factor < 0:
+            raise ValueError(
+                'Starting multiplicative factor should be between 0 and 1.')
+
+        if end_factor > 1.0 or end_factor < 0:
+            raise ValueError(
+                'Ending multiplicative factor should be between 0 and 1.')
+
+        self.start_factor = start_factor
+        self.end_factor = end_factor
+        self.total_iters = end - begin - 1
+        super().__init__(begin, end, by_epoch)
+
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        begin = int(begin * epoch_length)
+        if end != INF:
+            end = int(end * epoch_length)
+        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
+
+    def _get_multiplier(self, base_value, cur_value, cur_iter):
+        start_eta = base_value * self.start_factor
+        end_eta = base_value * self.end_factor
+        return start_eta + (end_eta - start_eta) * (
+            cur_iter - self.begin) / self.total_iters
+
+
+class LossWeightSchedulerManager:
+    """Container holding the schedulers as well as the base and current
+    values of the loss weight multiplier."""
+
+    def __init__(self, schedulers):
+        self.schedulers = schedulers
+        self._base_value = 1.
+        self._cur_value = 1.
+
+    @property
+    def base_value(self):
+        return self._base_value
+
+    @base_value.setter
+    def base_value(self, value):
+        self._base_value = value
+
+    @property
+    def cur_value(self):
+        return self._cur_value
+
+    @cur_value.setter
+    def cur_value(self, value):
+        self._cur_value = value
diff --git a/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py b/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py
new file mode 100644
index 000000000..e99f00418
--- /dev/null
+++ b/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from unittest import TestCase
+from unittest.mock import Mock
+
+import torch
+
+from mmrazor.engine import LossWeightSchedulerHook
+from mmrazor.models.task_modules import (CosineAnnealingLossWeightScheduler,
+                                         LinearLossWeightScheduler,
+                                         LossWeightScheduler,
+                                         LossWeightSchedulerManager)
+
+
+class TestLossWeightSchedulerHook(TestCase):
+
+    def setUp(self):
+        self.hook = LossWeightSchedulerHook()
+        runner = Mock()
+        runner.model = Mock()
+
+        epochs = 10
+        epoch_length = 7
+        scheduler1 = LinearLossWeightScheduler(
+            start_factor=1 / 2, begin=0, end=5)
+
+        eta_min = 1e-10
+        scheduler2 = CosineAnnealingLossWeightScheduler.build_iter_from_epoch(
+            begin=5, end=epochs, eta_min=eta_min, epoch_length=epoch_length)
+        runner.model.loss_weight_scheduler_manager = \
+            LossWeightSchedulerManager([scheduler1, scheduler2])
+        runner.epoch = 0
+        runner.iter = 0
+        runner.train_dataloader = [torch.rand(1)] * epoch_length
+        self.runner = runner
+
+    def reset(self):
+        self.runner.model.loss_weight_scheduler_manager.cur_value = 1.
+        self.runner.model.loss_weight_scheduler_manager.base_value = 1.
+        self.runner.epoch = 0
+        self.runner.iter = 0
+
+    def test_before_run(self):
+        self.reset()
+        self.hook.before_run(self.runner)
+        self.assertEqual(self.hook.milestones, [35, 70])
+        schedulers = self.runner.model.loss_weight_scheduler_manager.schedulers
+        for scheduler in schedulers:
+            self.assertIsInstance(scheduler, LossWeightScheduler)
+
+    def test_before_train_epoch(self):
+        self.reset()
+        epochs = 10
+        epoch_length = 7
+        targets1 = [0.5, 0.625, 0.75, 0.875, 1.0]
+        targets = targets1 + [targets1[-1]] * 5
+        for epoch in range(epochs):
+            self.hook.before_train_epoch(self.runner)
+            self.assertAlmostEqual(
+                self.runner.model.loss_weight_scheduler_manager.cur_value,
+                targets[epoch])
+            self.runner.epoch += 1
+            self.runner.iter += epoch_length
+
+    def test_before_train_iter(self):
+        self.reset()
+        epochs = 10
+        epoch_length = 7
+        eta_min = 1e-10
+        targets2 = [
+            eta_min + (1.0 - eta_min) *
+            (1 + math.cos(math.pi * x / 5 / epoch_length)) / 2
+            for x in range(5 * epoch_length)
+        ]
+        targets = [1.0] * 5 * epoch_length + targets2
+        for iter in range(epochs * epoch_length):
+            self.hook.before_train_iter(self.runner, iter)
+            self.assertAlmostEqual(
+                self.runner.model.loss_weight_scheduler_manager.cur_value,
+                targets[iter])
+            self.runner.iter += 1
+            if iter > 0 and iter % epoch_length == 0:
+                self.runner.epoch += 1
+
+    def test_train(self):
+        self.reset()
+        epochs = 10
+        epoch_length = 7
+        targets1 = [0.5, 0.625, 0.75, 0.875, 1.0]
+        eta_min = 1e-10
+        targets2 = [
+            eta_min + (targets1[-1] - eta_min) *
+            (1 + math.cos(math.pi * x / 5 / epoch_length)) / 2
+            for x in range(5 * epoch_length)
+        ]
+        targets = []
+        for num in targets1:
+            targets += [num] * epoch_length
+        targets += targets2
+        for epoch in range(epochs):
+            self.hook.before_train_epoch(self.runner)
+            for iter in range(epoch_length):
+                self.hook.before_train_iter(self.runner, iter)
+                self.assertAlmostEqual(
+                    self.runner.model.loss_weight_scheduler_manager.cur_value,
+                    targets[epoch * epoch_length + iter])
+                self.runner.iter += 1
+            self.runner.epoch += 1
diff --git a/tests/test_models/test_task_modules/test_scheduler/test_loss_weight_scheduler.py b/tests/test_models/test_task_modules/test_scheduler/test_loss_weight_scheduler.py
new file mode 100644
index 000000000..d9927eff2
--- /dev/null
+++ b/tests/test_models/test_task_modules/test_scheduler/test_loss_weight_scheduler.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from unittest import TestCase
+
+from mmrazor.models.task_modules import (CosineAnnealingLossWeightScheduler,
+                                         LinearLossWeightScheduler,
+                                         LossWeightSchedulerManager,
+                                         MultiStepLossWeightScheduler)
+
+
+class TestLossWeightScheduler(TestCase):
+
+    def _test_scheduler_value(self, scheduler_manager, targets, epochs=10):
+        schedulers = scheduler_manager.schedulers
+        assert isinstance(schedulers, list)
+
+        intervals = [(scheduler.begin, scheduler.end)
+                     for scheduler in schedulers]
+        intervals = sorted(intervals, key=lambda x: x[0])
+        milestones = []
+        for begin, end in intervals:
+            if not milestones:
+                milestones.append(end)
+            elif end > milestones[-1]:
+                milestones.append(end)
+
+        for epoch in range(epochs):
+            for scheduler in schedulers:
+                if epoch in milestones:
+                    scheduler_manager.base_value = scheduler_manager.cur_value
+
+                base_value = scheduler_manager.base_value
+                cur_value = scheduler_manager.cur_value
+                multiplier = scheduler.get_multiplier(base_value, cur_value,
+                                                      epoch)
+                if multiplier is not None:
+                    scheduler_manager.cur_value = multiplier
+                    break
+            self.assertAlmostEqual(scheduler_manager.cur_value, targets[epoch])
+
+    def test_cos_anneal_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineAnnealingLossWeightScheduler(
+                begin=0, end=12, eta_min=0, eta_min_ratio=0.1)
+
+        eta_min = 0.
+        epochs = 12
+        targets = [
+            eta_min + (1. - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
+            for x in range(epochs)
+        ]
+        scheduler = CosineAnnealingLossWeightScheduler(
+            eta_min=0., begin=0, end=12)
+        scheduler_manager = LossWeightSchedulerManager([scheduler])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)
+
+    def test_multi_step_scheduler(self):
+        # loss weight = 1. if epoch < 2
+        # loss weight = 0.1 if 2 <= epoch < 5
+        # loss weight = 0.01 if 5 <= epoch < 9
+        # loss weight = 0.001 if epoch >= 9
+        epochs = 10
+        targets = [1.] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] * 3
+
+        scheduler = MultiStepLossWeightScheduler(
+            gamma=0.1, milestones=[2, 5, 9])
+        scheduler_manager = LossWeightSchedulerManager([scheduler])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)
+
+    def test_linear_scheduler(self):
+        with self.assertRaises(ValueError):
+            LinearLossWeightScheduler(start_factor=10, end=900)
+        with self.assertRaises(ValueError):
+            LinearLossWeightScheduler(start_factor=-1, end=900)
+        with self.assertRaises(ValueError):
+            LinearLossWeightScheduler(end_factor=1.001, end=900)
+        with self.assertRaises(ValueError):
+            LinearLossWeightScheduler(end_factor=-0.00001, end=900)
+        # loss weight = 0.5 if epoch == 0
+        # loss weight = 0.625 if epoch == 1
+        # loss weight = 0.75 if epoch == 2
+        # loss weight = 0.875 if epoch == 3
+        # loss weight = 1.0 if epoch >= 4
+        epochs = 10
+        start_factor = 1.0 / 2
+        iters = 4
+        interpolation = [
+            start_factor + i * (1 - start_factor) / iters for i in range(iters)
+        ]
+        targets = [x * 1. for x in interpolation] + [1.] * (
+            epochs - iters)
+        scheduler = LinearLossWeightScheduler(
+            start_factor=start_factor, end=iters + 1)
+        scheduler_manager = LossWeightSchedulerManager([scheduler])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)
+
+    def test_cos_anneal_scheduler_convert_iterbased(self):
+        epochs = 12
+        eta_min = 1e-10
+        epoch_length = 11
+        targets = [
+            eta_min + (1. - eta_min) *
+            (1 + math.cos(math.pi * x / epochs / epoch_length)) / 2
+            for x in range(epochs * epoch_length)
+        ]
+        scheduler = CosineAnnealingLossWeightScheduler.build_iter_from_epoch(
+            end=epochs, eta_min=eta_min, epoch_length=epoch_length)
+        scheduler_manager = LossWeightSchedulerManager([scheduler])
+        self._test_scheduler_value(scheduler_manager, targets,
+                                   epochs * epoch_length)
+
+    def test_multi_step_scheduler_convert_iterbased(self):
+        # loss weight = 1.0 if epoch < 2
+        # loss weight = 0.1 if 2 <= epoch < 5
+        # loss weight = 0.01 if 5 <= epoch < 9
+        # loss weight = 0.001 if epoch >= 9
+        epochs = 10
+        epoch_length = 7
+        targets = [1.] * 2 * epoch_length + [0.1] * 3 * epoch_length + [
+            0.01
+        ] * 4 * epoch_length + [0.001] * 3 * epoch_length
+        scheduler = MultiStepLossWeightScheduler.build_iter_from_epoch(
+            gamma=0.1, milestones=[2, 5, 9], epoch_length=epoch_length)
+        scheduler_manager = LossWeightSchedulerManager([scheduler])
+        self._test_scheduler_value(scheduler_manager, targets,
+                                   epochs * epoch_length)
+
+    def test_linear_scheduler_convert_iterbased(self):
+        epochs = 10
+        start_factor = 1.0 / 2
+        end = 5
+        epoch_length = 11
+
+        iters = end * epoch_length - 1
+        interpolation = [
+            start_factor + i * (1 - start_factor) / iters for i in range(iters)
+        ]
+        targets = [x * 1. for x in interpolation] + [1.] * (
+            epochs * epoch_length - iters)
+        scheduler = LinearLossWeightScheduler.build_iter_from_epoch(
+            start_factor=start_factor, end=end, epoch_length=epoch_length)
+        scheduler_manager = LossWeightSchedulerManager([scheduler])
+        self._test_scheduler_value(scheduler_manager, targets,
+                                   epochs * epoch_length)
+
+    def test_multi_scheduler_without_overlap_linear_multi_step(self):
+        # use Linear in the first 5 epochs and then use MultiStep
+        epochs = 12
+        targets = [0.5, 0.625, 0.75, 0.875
+                   ] + [1.0] * 4 + [0.1] * 3 + [0.01] * 1
+        scheduler1 = LinearLossWeightScheduler(
+            start_factor=1 / 2, begin=0, end=5)
+        scheduler2 = MultiStepLossWeightScheduler(
+            gamma=0.1, milestones=[3, 6], begin=5, end=12)
+        scheduler_manager = LossWeightSchedulerManager(
+            [scheduler1, scheduler2])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)
+
+    def test_multi_scheduler_without_overlap_linear_cosine(self):
+        # use Linear in the first 5 epochs and then use Cosine
+        epochs = 10
+        targets1 = [0.5, 0.625, 0.75, 0.875, 1.0]
+        scheduler1 = LinearLossWeightScheduler(
+            start_factor=1 / 2, begin=0, end=5)
+
+        eta_min = 1e-10
+        targets2 = [
+            eta_min + (targets1[-1] - eta_min) *
+            (1 + math.cos(math.pi * x / 5)) / 2 for x in range(5)
+        ]
+        scheduler2 = CosineAnnealingLossWeightScheduler(
+            begin=5, end=epochs, eta_min=eta_min)
+
+        targets = targets1 + targets2
+        scheduler_manager = LossWeightSchedulerManager(
+            [scheduler1, scheduler2])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)
+
+    def test_multi_scheduler_with_overlap(self):
+        # use Linear in the first 5 epochs together with MultiStep
+        epochs = 10
+        targets = [0.5, 0.625, 0.75, 0.875
+                   ] + [1.0] * 2 + [0.1] * 3 + [0.01] * 1
+        scheduler1 = LinearLossWeightScheduler(
+            start_factor=1 / 2, begin=0, end=5)
+        scheduler2 = MultiStepLossWeightScheduler(
+            gamma=0.1, milestones=[3, 6, 9])
+        scheduler_manager = LossWeightSchedulerManager(
+            [scheduler1, scheduler2])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)
+
+    def test_multi_scheduler_with_gap(self):
+        # use Linear in the first 5 epochs and the last 5 epochs use MultiStep
+        # no scheduler in the middle 5 epochs
+        epochs = 15
+        targets1 = [0.5, 0.625, 0.75, 0.875, 1.0]
+        scheduler1 = LinearLossWeightScheduler(
+            start_factor=1 / 2, begin=0, end=5)
+
+        scheduler2 = MultiStepLossWeightScheduler(
+            gamma=0., milestones=[0], begin=10, end=15)
+        targets2 = [0.] * 5
+        targets = targets1 + [targets1[-1]] * 5 + targets2
+        scheduler_manager = LossWeightSchedulerManager(
+            [scheduler1, scheduler2])
+        self._test_scheduler_value(scheduler_manager, targets, epochs)

From c577d4777c38706356e93a12fe8819a1a26c0cd6 Mon Sep 17 00:00:00 2001
From: HIT-cwh <2892770585@qq.com>
Date: Mon, 6 Feb 2023 16:35:17 +0800
Subject: [PATCH 2/2] fix pytest

---
 .../test_engine/test_hooks/test_loss_weight_scheduler_hook.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py b/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py
index e99f00418..2f22ad9af 100644
--- a/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py
+++ b/tests/test_engine/test_hooks/test_loss_weight_scheduler_hook.py
@@ -50,6 +50,7 @@ def test_before_run(self):
 
     def test_before_train_epoch(self):
         self.reset()
+        self.hook.before_run(self.runner)
         epochs = 10
         epoch_length = 7
         targets1 = [0.5, 0.625, 0.75, 0.875, 1.0]
@@ -64,6 +65,7 @@ def test_before_train_iter(self):
         self.reset()
+        self.hook.before_run(self.runner)
         epochs = 10
         epoch_length = 7
         eta_min = 1e-10
@@ -84,6 +86,7 @@ def test_train(self):
         self.reset()
+        self.hook.before_run(self.runner)
         epochs = 10
         epoch_length = 7
         targets1 = [0.5, 0.625, 0.75, 0.875, 1.0]
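
Usage note (editor's sketch, not part of the patches themselves): with both
patches applied, a distillation loss weight schedule is declared on the
algorithm and driven by the hook. The fragment below is a hypothetical
config; the two scheduler dicts mirror the unit tests (linear warm-up over
the first 5 epochs, then cosine annealing, assuming a 10-epoch schedule),
and the omitted `architecture`/`teacher`/`distiller` fields are whatever
the task already uses, unchanged by this PR.

    # Hypothetical config fragment (assumes max_epochs=10).
    loss_weight_schedulers = [
        # Warm the distill loss weight up from 0.5x to 1x over epochs [0, 5).
        dict(type='LinearLossWeightScheduler', start_factor=0.5, begin=0,
             end=5),
        # Then anneal it from 1x towards eta_min until epoch 10.
        dict(type='CosineAnnealingLossWeightScheduler', eta_min=1e-10,
             begin=5, end=10),
    ]

    model = dict(
        type='SingleTeacherDistill',
        # architecture=..., teacher=..., distiller=... are task-specific
        # and unchanged by this PR.
        loss_weight_schedulers=loss_weight_schedulers)

    # The hook builds the dict configs through PARAM_SCHEDULERS in
    # `before_run` and refreshes `cur_value` before every train epoch and
    # iter; `SingleTeacherDistill.loss` then scales each distill loss by it.
    custom_hooks = [dict(type='LossWeightSchedulerHook')]

Dicts carrying `convert_to_iter_based=True` should also be convertible via
`build_iter_from_epoch`, mirroring MMEngine's parameter schedulers, since
the hook passes `epoch_length` as a default build argument.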