Commit 3715bbc

LKJacky, fpshuang, huangpengsheng, gaoyang07, and humu789 authored
Refine pruning branch (#307)
* [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281) * init * TD: CRDLoss * complete UT * fix docstrings * fix ci * update * fix CI * DONE * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * maintain CRD dataset unique funcs as a mixin * add UT: CRD_ClsDataset * init * TODO: UT test formatting. * init * crd dataset wrapper * update docstring Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com> * [Improvement] Update estimator with api revision (#277) * update estimator usage and fix bugs * refactor api of estimator & add inner check methods * fix docstrings * update search loop and config * fix lint * update unittest * decouple mmdet dependency and fix lint Co-authored-by: humu789 <humu@pjlab.org.cn> * [Fix] Fix tracer (#273) * test image_classifier_loss_calculator * fix backward tracer * update SingleStageDetectorPseudoLoss * merge * [Feature] Add Dsnas Algorithm (#226) * [tmp] Update Dsnas * [tmp] refactor arch_loss & flops_loss * Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs. * Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage. 
* remove old evaluators * remove old evaluators * update param_scheduler config * merge dev-1.x into gy/estimator * add flops_loss in Dsnas using ResourcesEstimator * get resources before mutator.prepare_from_supernet * delete unness broadcast api from gml * broadcast spec_modules_resources when estimating * update early fix mechanism for Dsnas * fix merge * update units in estimator * minor change * fix data_preprocessor api * add flops_loss_coef * remove DsnasOptimWrapper * fix bn eps and data_preprocessor * fix bn weight decay bug * add betas for mutator optimizer * set diff_rank_seed=True for dsnas * fix start_factor of lr when warm up * remove .module in non-ddp mode * add GlobalAveragePoolingWithDropout * add UT for dsnas * remove unness channel adjustment for shufflenetv2 * update supernet configs * delete unness dropout * delete unness part with minor change on dsnas * minor change on the flag of search stage * update README and subnet configs * add UT for OneHotMutableOP * [Feature] Update train (#279) * support auto resume * add enable auto_scale_lr in train.py * support '--amp' option * [Fix] Fix darts metafile (#278) fix darts metafile * fix ci (#284) * fix ci for circle ci * fix bug in test_metafiles * add pr_stage_test for github ci * add multiple version * fix ut * fix lint * Temporarily skip dataset UT * update github ci * add github lint ci * install wheel * remove timm from requirements * install wheel when test on windows * fix error * fix bug * remove github windows ci * fix device error of arch_params when DsnasDDP * fix CRD dataset ut * fix scope error * rm test_cuda in workflows of github * [Doc] fix typos in en/usr_guides Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: pppppM <gjf_mail@126.com> Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com> Co-authored-by: SheffieldCao <1751899@tongji.edu.cn> * fix bug when python=3.6 * fix lint * fix bug when test using cpu only * 
refine ci * fix error in ci * try ci * update repr of Channel * fix error * mv init_from_predefined_model to MutableChannelUnit * move tests * update SquentialMutableChannel * update l1 mutable channel unit * add OneShotMutableChannel * candidate_mode -> choice_mode * update docstring * change ci Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com> Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com> Co-authored-by: humu789 <humu@pjlab.org.cn> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: pppppM <gjf_mail@126.com> Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>
1 parent 8330b62 commit 3715bbc

35 files changed

+354
-260
lines changed

.circleci/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ jobs:
119119
docker exec mmrazor pip install -e /mmdetection
120120
docker exec mmrazor pip install -e /mmclassification
121121
docker exec mmrazor pip install -e /mmsegmentation
122-
docker exec mmrazor pip install -r requirements/tests.txt
122+
docker exec mmrazor pip install -r requirements.txt
123123
- run:
124124
name: Build and install
125125
command: |

configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_supernet_8xb256_in1k.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
type='OneShotMutableChannelUnit',
4848
default_args=dict(
4949
candidate_choices=list(i / 12 for i in range(2, 13)),
50-
candidate_mode='ratio',
50+
choice_mode='ratio',
5151
divisor=8)),
5252
parse_cfg=dict(
5353
type='BackwardTracer',
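
Note on the renamed key: with `choice_mode='ratio'`, the `candidate_choices` expression in this config is read as a list of width ratios. The short sketch below only evaluates that expression in plain Python to show what the unit receives. How `divisor=8` rounds the resulting channel counts is not shown in this hunk and is not modelled here.

```python
# Evaluates the candidate_choices expression from the config hunk above.
candidate_choices = list(i / 12 for i in range(2, 13))

# With choice_mode='ratio', each entry is a fraction of the full width,
# running from 2/12 up to 12/12.
for ratio in candidate_choices:
    print(f"{ratio:.3f}")
```

The list has eleven entries and ends at the full width (`1.0`).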

mmrazor/models/algorithms/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@
1313
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
1414
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
1515
'ItePruneAlgorithm', 'DAFLDataFreeDistillation',
16-
'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP'
16+
'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP',
17+
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
18+
'DsnasDDP'
1719
]
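
Note on the exported names: the lines visible in this hunk list several names more than once. Python accepts duplicates in `__all__`, so this is harmless at import time, but it is easy to check for. A minimal sketch, assuming only the entries visible in the hunk (the earlier part of the list is not shown and is left out):

```python
from collections import Counter

# Only the __all__ entries visible in the hunk above, after the change.
visible_all = [
    'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
    'ItePruneAlgorithm', 'DAFLDataFreeDistillation',
    'OverhaulFeatureDistillation', 'Dsnas', 'DsnasDDP',
    'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
    'DsnasDDP',
]

# Names that occur more than once.
duplicates = sorted(n for n, count in Counter(visible_all).items() if count > 1)

# Order-preserving de-duplication (dict keys keep insertion order).
deduped = list(dict.fromkeys(visible_all))

print(duplicates)
print(deduped)
```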

mmrazor/models/architectures/dynamic_ops/bricks/dynamic_norm.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ def init_candidates(self, candidates: List):
152152
for num in candidates:
153153
self.candidate_bn[str(num)] = nn.BatchNorm2d(
154154
num, self.eps, self.momentum, self.affine,
155-
self.track_running_stats, self.weight.device,
156-
self.weight.dtype)
155+
self.track_running_stats)
157156

158157
def forward(self, input: Tensor) -> Tensor:
159158
"""Forward."""

mmrazor/models/mutables/__init__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from .base_mutable import BaseMutable
33
from .derived_mutable import DerivedMutable
44
from .mutable_channel import (BaseMutableChannel, MutableChannelContainer,
5-
SimpleMutableChannel, SquentialMutableChannel)
5+
OneShotMutableChannel, SimpleMutableChannel,
6+
SquentialMutableChannel)
67
from .mutable_channel.units import (ChannelUnitType, L1MutableChannelUnit,
78
MutableChannelUnit,
89
OneShotMutableChannelUnit,
@@ -22,5 +23,7 @@
2223
'BaseMutableChannel', 'MutableChannelContainer', 'ChannelUnitType',
2324
'SquentialMutableChannel', 'BaseMutable', 'DiffChoiceRoute',
2425
'DiffMutableModule', 'DerivedMutable', 'MutableValue',
25-
'OneShotMutableValue', 'OneHotMutableOP'
26+
'OneShotMutableValue', 'OneHotMutableOP', 'OneShotMutableChannel',
27+
'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel',
28+
'DerivedMutable', 'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP'
2629
]
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .base_mutable_channel import BaseMutableChannel
3-
from .units import (ChannelUnitType, L1MutableChannelUnit,
4-
MutableChannelUnit, OneShotMutableChannelUnit,
5-
SequentialMutableChannelUnit, SlimmableChannelUnit)
63
from .mutable_channel_container import MutableChannelContainer
4+
from .oneshot_mutalbe_channel import OneShotMutableChannel
75
from .sequential_mutable_channel import SquentialMutableChannel
86
from .simple_mutable_channel import SimpleMutableChannel
7+
from .units import (ChannelUnitType, L1MutableChannelUnit, MutableChannelUnit,
8+
OneShotMutableChannelUnit, SequentialMutableChannelUnit,
9+
SlimmableChannelUnit)
910

1011
__all__ = [
1112
'SimpleMutableChannel', 'L1MutableChannelUnit',
1213
'SequentialMutableChannelUnit', 'MutableChannelUnit',
13-
'OneShotMutableChannelUnit', 'SlimmableChannelUnit',
14-
'BaseMutableChannel', 'MutableChannelContainer', 'SquentialMutableChannel',
15-
'ChannelUnitType'
14+
'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel',
15+
'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType',
16+
'OneShotMutableChannel'
1617
]

mmrazor/models/mutables/mutable_channel/mutable_channel_container.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class MutableChannelContainer(BaseMutableChannel):
3030

3131
def __init__(self, num_channels: int, **kwargs):
3232
super().__init__(num_channels, **kwargs)
33-
self.mutable_channels: IndexDict[BaseMutableChannel] = IndexDict()
33+
self.mutable_channels = IndexDict()
3434

3535
# choice
3636

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import List, Union
3+
4+
from .sequential_mutable_channel import SquentialMutableChannel
5+
6+
7+
class OneShotMutableChannel(SquentialMutableChannel):
8+
"""OneShotMutableChannel is a subclass of SquentialMutableChannel. The
9+
difference is that a OneShotMutableChannel limits the candidates of the
10+
choice.
11+
12+
Args:
13+
num_channels (int): number of channels.
14+
candidate_choices (List[Union[float, int]], optional): A list of
15+
candidate width ratios. Each candidate indicates how many
16+
channels to be reserved. Defaults to [].
17+
choice_mode (str, optional): Mode of choices. Defaults to 'number'.
18+
"""
19+
20+
def __init__(self,
21+
num_channels: int,
22+
candidate_choices: List[Union[float, int]] = [],
23+
choice_mode='number',
24+
**kwargs):
25+
super().__init__(num_channels, choice_mode, **kwargs)
26+
self.candidate_choices = candidate_choices
27+
if candidate_choices == []:
28+
candidate_choices.append(num_channels if self.is_num_mode else 1.0)
29+
30+
@property
31+
def current_choice(self) -> Union[int, float]:
32+
"""Get current choice."""
33+
return super().current_choice
34+
35+
@current_choice.setter
36+
def current_choice(self, choice: Union[int, float]):
37+
"""Set current choice."""
38+
assert choice in self.candidate_choices
39+
SquentialMutableChannel.current_choice.fset( # type: ignore
40+
self, # type: ignore
41+
choice) # type: ignore
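
Note on `candidate_choices=[]`: the constructor above uses a list literal as the default and then appends to it when it is empty. In Python the default list is created once and shared between calls, so the first instance built without `candidate_choices` changes the default seen by later ones. The sketch below reproduces the pattern with two hypothetical stand-in classes (`SharedDefault` and `SafeDefault` are illustration names, not mmrazor classes) and shows one common alternative using `None`.

```python
from typing import List, Optional, Union

Choice = Union[int, float]


class SharedDefault:
    """Mirrors the pattern in the diff: a list literal as the default."""

    def __init__(self, num_channels: int,
                 candidate_choices: List[Choice] = []):
        self.candidate_choices = candidate_choices
        if candidate_choices == []:
            # Appends to the single shared default list.
            candidate_choices.append(num_channels)


class SafeDefault:
    """One possible alternative: None as the sentinel, copy the input."""

    def __init__(self, num_channels: int,
                 candidate_choices: Optional[List[Choice]] = None):
        if not candidate_choices:
            candidate_choices = [num_channels]
        self.candidate_choices = list(candidate_choices)


a = SharedDefault(8)
b = SharedDefault(16)
# b reuses the list that a already filled, so it never gets 16.
print(a.candidate_choices, b.candidate_choices)

c = SafeDefault(8)
d = SafeDefault(16)
print(c.candidate_choices, d.candidate_choices)
```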

mmrazor/models/mutables/mutable_channel/sequential_mutable_channel.py

+37-22
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import Callable
2+
from typing import Callable, Union
33

44
import torch
55

66
from mmrazor.registry import MODELS
77
from ..derived_mutable import DerivedMutable
8-
from .base_mutable_channel import BaseMutableChannel
8+
from .simple_mutable_channel import SimpleMutableChannel
99

1010
# TODO discuss later
1111

1212

1313
@MODELS.register_module()
14-
class SquentialMutableChannel(BaseMutableChannel):
14+
class SquentialMutableChannel(SimpleMutableChannel):
1515
"""SquentialMutableChannel defines a BaseMutableChannel which switch off
1616
channel mask from right to left sequentially, like '11111000'.
1717
@@ -22,21 +22,36 @@ class SquentialMutableChannel(BaseMutableChannel):
2222
num_channels (int): number of channels.
2323
"""
2424

25-
def __init__(self, num_channels: int, **kwargs):
25+
def __init__(self, num_channels: int, choice_mode='number', **kwargs):
2626

2727
super().__init__(num_channels, **kwargs)
28+
assert choice_mode in ['ratio', 'number']
29+
self.choice_mode = choice_mode
2830
self.mask = torch.ones([self.num_channels]).bool()
2931

3032
@property
31-
def current_choice(self) -> int:
33+
def is_num_mode(self):
34+
"""Get if the choice is number mode."""
35+
return self.choice_mode == 'number'
36+
37+
@property
38+
def current_choice(self) -> Union[int, float]:
3239
"""Get current choice."""
33-
return (self.mask == 1).sum().item()
40+
int_choice = (self.mask == 1).sum().item()
41+
if self.is_num_mode:
42+
return int_choice
43+
else:
44+
return self._num2ratio(int_choice)
3445

3546
@current_choice.setter
36-
def current_choice(self, choice: int):
47+
def current_choice(self, choice: Union[int, float]):
3748
"""Set choice."""
49+
if isinstance(choice, float):
50+
int_choice = self._ratio2num(choice)
51+
else:
52+
int_choice = choice
3853
mask = torch.zeros([self.num_channels], device=self.mask.device)
39-
mask[0:choice] = 1
54+
mask[0:int_choice] = 1
4055
self.mask = mask.bool()
4156

4257
@property
@@ -58,20 +73,6 @@ def dump_chosen(self):
5873
"""Dump chosen."""
5974
return self.current_choice
6075

61-
# def __mul__(self, other):
62-
# """multiplication."""
63-
# if isinstance(other, int):
64-
# return self.derive_expand_mutable(other)
65-
# else:
66-
# return None
67-
68-
# def __floordiv__(self, other):
69-
# """division."""
70-
# if isinstance(other, int):
71-
# return self.derive_divide_mutable(other)
72-
# else:
73-
# return None
74-
7576
def __rmul__(self, other) -> DerivedMutable:
7677
return self * other
7778

@@ -121,3 +122,17 @@ def __floordiv__(self, other) -> DerivedMutable:
121122
return self.derive_divide_mutable(*other)
122123

123124
raise TypeError(f'Unsupported type {type(other)} for div!')
125+
126+
def _num2ratio(self, choice: Union[int, float]) -> float:
127+
"""Convert the a number choice to a ratio choice."""
128+
if isinstance(choice, float):
129+
return choice
130+
else:
131+
return choice / self.num_channels
132+
133+
def _ratio2num(self, choice: Union[int, float]) -> int:
134+
"""Convert the a ratio choice to a number choice."""
135+
if isinstance(choice, int):
136+
return choice
137+
else:
138+
return max(1, int(self.num_channels * choice))
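
Note on the ratio/number conversion: the hunks above let `current_choice` be either a channel count (`int`) or a width ratio (`float`), and build a mask that keeps channels from the left. The sketch below restates that logic with plain Python lists instead of `torch` tensors so it runs without dependencies. The free functions are illustration names that mirror `_ratio2num`, `_num2ratio`, and the setter; they are not part of mmrazor.

```python
from typing import List, Union

Choice = Union[int, float]


def ratio2num(num_channels: int, choice: Choice) -> int:
    """Ratio -> channel count; ints pass through unchanged."""
    if isinstance(choice, int):
        return choice
    # int() truncates toward zero; at least one channel is always kept.
    return max(1, int(num_channels * choice))


def num2ratio(num_channels: int, choice: Choice) -> float:
    """Channel count -> ratio; floats pass through unchanged."""
    if isinstance(choice, float):
        return choice
    return choice / num_channels


def sequential_mask(num_channels: int, choice: Choice) -> List[bool]:
    """Keep the first `kept` channels and switch off the rest."""
    kept = ratio2num(num_channels, choice)
    return [i < kept for i in range(num_channels)]


mask = sequential_mask(8, 5)
mask_string = ''.join('1' if m else '0' for m in mask)
print(mask_string)
print(ratio2num(8, 0.5), ratio2num(8, 0.01), num2ratio(8, 4))
```

Because the conversion truncates, a ratio such as `0.99` on 8 channels maps to 7 channels, not 8. Dispatching on `isinstance(choice, float)` also means an integer `1` is read as one channel while `1.0` is read as the full width.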

mmrazor/models/mutables/mutable_channel/units/channel_unit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def is_mutable(self) -> bool:
113113

114114
def __repr__(self) -> str:
115115
return (f'{self.__class__.__name__}('
116-
f'{self.name}, index=({self.index}), '
116+
f'{self.name}, index={self.index}, '
117117
f'is_output_channel='
118118
f'{"true" if self.is_output_channel else "false"}, '
119119
f'expand_ratio={self.expand_ratio}'

mmrazor/models/mutables/mutable_channel/units/l1_mutable_channel_unit.py

+18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch.nn as nn
66

77
from mmrazor.registry import MODELS
8+
from ..simple_mutable_channel import SimpleMutableChannel
89
from .sequential_mutable_channel_unit import SequentialMutableChannelUnit
910

1011

@@ -25,6 +26,23 @@ def __init__(self,
2526
min_ratio=0.9) -> None:
2627
super().__init__(num_channels, choice_mode, divisor, min_value,
2728
min_ratio)
29+
self.mutable_channel = SimpleMutableChannel(num_channels)
30+
31+
# choices
32+
33+
@property
34+
def current_choice(self) -> Union[int, float]:
35+
num = self.mutable_channel.activated_channels
36+
if self.is_num_mode:
37+
return num
38+
else:
39+
return self._num2ratio(num)
40+
41+
@current_choice.setter
42+
def current_choice(self, choice: Union[int, float]):
43+
int_choice = self._get_valid_int_choice(choice)
44+
mask = self._generate_mask(int_choice).bool()
45+
self.mutable_channel.current_choice = mask
2846

2947
# private methods
3048

mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py

+71-7
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
"""This module defines MutableChannelUnit."""
33
import abc
4+
from collections import Set
45
from typing import Dict, List, Type, TypeVar
56

67
import torch.nn as nn
78

8-
import mmrazor.models.architectures.dynamic_ops as dynamic_ops
9+
from mmrazor.models.architectures import dynamic_ops
910
from mmrazor.models.architectures.dynamic_ops.mixins import DynamicChannelMixin
1011
from mmrazor.models.mutables import DerivedMutable
11-
from mmrazor.models.mutables.mutable_channel.base_mutable_channel import \
12-
BaseMutableChannel
13-
from ..mutable_channel_container import MutableChannelContainer
12+
from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
13+
MutableChannelContainer)
1414
from .channel_unit import Channel, ChannelUnit
1515

1616

1717
class MutableChannelUnit(ChannelUnit):
1818

1919
# init methods
2020
def __init__(self, num_channels: int, **kwargs) -> None:
21-
"""MutableChannelUnit inherits from ChannelUnit, which manages
22-
channels with channel-dependency.
21+
"""MutableChannelUnit inherits from ChannelUnit, which manages channels
22+
with channel-dependency.
2323
2424
Compared with ChannelUnit, MutableChannelUnit defines the core
2525
interfaces for pruning. By inheriting MutableChannelUnit,
@@ -44,6 +44,70 @@ def __init__(self, num_channels: int, **kwargs) -> None:
4444

4545
super().__init__(num_channels)
4646

47+
@classmethod
48+
def init_from_mutable_channel(cls, mutable_channel: BaseMutableChannel):
49+
unit = cls(mutable_channel.num_channels)
50+
return unit
51+
52+
@classmethod
53+
def init_from_predefined_model(cls, model: nn.Module):
54+
"""Initialize units using the model with pre-defined dynamicops and
55+
mutable-channels."""
56+
57+
def process_container(contanier: MutableChannelContainer,
58+
module,
59+
module_name,
60+
mutable2units,
61+
is_output=True):
62+
for index, mutable in contanier.mutable_channels.items():
63+
if isinstance(mutable, DerivedMutable):
64+
source_mutables: Set = \
65+
mutable._trace_source_mutables()
66+
source_channel_mutables = [
67+
mutable for mutable in source_mutables
68+
if isinstance(mutable, BaseMutableChannel)
69+
]
70+
assert len(source_channel_mutables) == 1, (
71+
'only support one mutable channel '
72+
'used in DerivedMutable')
73+
mutable = list(source_channel_mutables)[0]
74+
75+
if mutable not in mutable2units:
76+
mutable2units[mutable] = cls.init_from_mutable_channel(
77+
mutable)
78+
79+
unit: MutableChannelUnit = mutable2units[mutable]
80+
if is_output:
81+
unit.add_ouptut_related(
82+
Channel(
83+
module_name,
84+
module,
85+
index,
86+
is_output_channel=is_output))
87+
else:
88+
unit.add_input_related(
89+
Channel(
90+
module_name,
91+
module,
92+
index,
93+
is_output_channel=is_output))
94+
95+
mutable2units: Dict = {}
96+
for name, module in model.named_modules():
97+
if isinstance(module, DynamicChannelMixin):
98+
in_container: MutableChannelContainer = \
99+
module.get_mutable_attr(
100+
'in_channels')
101+
out_container: MutableChannelContainer = \
102+
module.get_mutable_attr(
103+
'out_channels')
104+
process_container(in_container, module, name, mutable2units,
105+
False)
106+
process_container(out_container, module, name, mutable2units,
107+
True)
108+
units = list(mutable2units.values())
109+
return units
110+
47111
# properties
48112

49113
@property
@@ -97,7 +161,7 @@ def prepare_for_pruning(self, model):
97161
98162
For example, we need to register mutables to dynamic-ops.
99163
"""
100-
raise not NotImplementedError
164+
raise NotImplementedError
101165

102166
# pruning: choice-related
103167
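
Note on `init_from_predefined_model`: the classmethod added in this file walks the model, looks up the mutable behind each layer's input and output channels, and creates one unit per distinct mutable, so layers that share a mutable end up in the same unit. The sketch below shows only that grouping idea with hypothetical stand-ins (`FakeMutable`, `FakeUnit`, `group_units`). It assumes one mutable per container and skips the `DerivedMutable` resolution done in the diff. It also imports from `typing`, since the `from collections import Set` line in the hunk relies on an alias that Python 3.10 removed.

```python
from typing import Dict, List, Tuple


class FakeMutable:
    """Hypothetical stand-in for a mutable channel (hashable by identity)."""

    def __init__(self, num_channels: int) -> None:
        self.num_channels = num_channels


class FakeUnit:
    """Hypothetical stand-in for a channel unit."""

    def __init__(self, num_channels: int) -> None:
        self.num_channels = num_channels
        self.input_related: List[str] = []
        self.output_related: List[str] = []


def group_units(
        layers: List[Tuple[str, FakeMutable, FakeMutable]]) -> List[FakeUnit]:
    """Create one unit per distinct mutable and record related layers."""
    mutable2unit: Dict[FakeMutable, FakeUnit] = {}
    for name, in_mutable, out_mutable in layers:
        for mutable, is_output in ((in_mutable, False), (out_mutable, True)):
            if mutable not in mutable2unit:
                mutable2unit[mutable] = FakeUnit(mutable.num_channels)
            unit = mutable2unit[mutable]
            if is_output:
                unit.output_related.append(name)
            else:
                unit.input_related.append(name)
    return list(mutable2unit.values())


# conv1's output and conv2's input share m_mid, so they land in one unit.
m_in, m_mid, m_out = FakeMutable(3), FakeMutable(16), FakeMutable(32)
units = group_units([("conv1", m_in, m_mid), ("conv2", m_mid, m_out)])
for unit in units:
    print(unit.num_channels, unit.input_related, unit.output_related)
```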
