
Commit a725b51

Authored by humu789, LKJacky, liukai, gaoyang07, and kitecats; committed by humu789.
Merge dev-1.x into quantize (open-mmlab#430)
* Fix a bug in make_divisible. (open-mmlab#333) fix bug in make_divisible Co-authored-by: liukai <liukai@pjlab.org.cn>
* [Fix] Fix counter mapping bug (open-mmlab#331) * fix counter mapping bug * move judgment into get_counter_type & update UT
* [Docs]Add MMYOLO projects link (open-mmlab#334)
* [Doc] fix typos in en/usr_guides (open-mmlab#299) * Update README.md * Update README_zh-CN.md Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
* [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (open-mmlab#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest
* [Feature] Add deit-base (open-mmlab#332) * WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme
* [Feature]Feature map visualization (open-mmlab#293) * WIP: vis * WIP: add visualization * WIP: add visualization hook * WIP: support razor visualizer * WIP * WIP: wrap draw_featmap * support feature map visualization * add a demo image for visualization * fix typos * change eps to 1e-6 * add pytest for visualization * fix vis hook * fix arguments' name * fix img path * support draw inference results * add visualization doc * fix figure url * move files Co-authored-by: weihan cao <HIT-cwh>
* [Feature] Add kd examples (open-mmlab#305) * support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete
* [Doc] add documents about pruning. (open-mmlab#313) * init * update user guide * update images * update * update How to prune your model * update how_to_use_config_tool_of_pruning.md * update doc * move location * update * update * update * add mutablechannels.md * add references Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: jacky <jacky@xx.com>
* [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (open-mmlab#304) * add pkd * add pytest for pkd * fix cfg * WIP: support fcos3d * WIP: support fcos3d pkd * support mmdet3d * fix cfgs * change eps to 1e-6 and add some comments * fix docstring * fix cfg * add assert * add type hint * WIP: add readme and metafile * fix readme * update metafiles and readme * fix metafile * fix pipeline figure
* [Refactor] Refactor Mutables and Mutators (open-mmlab#324) * refactor mutables * update load fix subnet * add DumpChosen Typehint * adapt UTs * fix lint * Add GroupMixin to ChannelMutator (temporarily) * fix type hints * add GroupMixin doc-string * modified by comments * fix type hits * update subnet format * fix channel group bugs and add UTs * fix doc string * fix comments * refactor diff module forward * fix error in channel mutator doc * fix comments Co-authored-by: liukai <liukai@pjlab.org.cn>
* [Fix] Update readme (open-mmlab#341) * update kl readme * update dsnas readme * fix url
* Bump version to 1.0.0rc1 (open-mmlab#338) update version
* [Feature] Add Autoformer algorithm (open-mmlab#315) * update candidates * update subnet_sampler_loop * update candidate * add readme * rename variable * rename variable * clean * update * add doc string * Revert "[Improvement] Support for candidate multiple dimensional search constraints."
* [Improvement] Update Candidate with multi-dim search constraints. (open-mmlab#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase
* [Feature] Autoformer architecture and dynamicOPs (open-mmlab#327) * add DynamicSequential * dynamiclayernorm * add dynamic_pathchembed * add DynamicMultiheadAttention and DynamicRelativePosition2D * add channel-level dynamicOP * add autoformer algo * clean notes * adapt channel_mutator * vit fly * fix import * mutable init * remove annotation * add DynamicInputResizer * add unittest for mutables * add OneShotMutableChannelUnit_VIT * clean code * reset unit for vit * remove attr * add autoformer backbone UT * add valuemutator UT * clean code * add autoformer algo UT * update classifier UT * fix test error * ignore * make lint * update * fix lint * mutable_attrs * fix test * fix error * remove DynamicInputResizer * fix test ci * remove InputResizer * rename variables * modify type * Continued improvements of ChannelUnit * fix lint * fix lint * remove OneShotMutableChannelUnit * adjust derived type * combination mixins * clean code * fix sample subnet * search loop fly * more annotations * avoid counter warning and modify batch_augment cfg by gy * restore * source_value_mutables restriction * simply arch_setting api * update * clean * fix ut
* [Feature] Add performance predictor (open-mmlab#306) * add predictor with 4 handlers * [Improvement] Update Candidate with multi-dim search constraints. (open-mmlab#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * update metric_predictor: 1. update MetricPredictor; 2. add predictor config for searching; 3. add predictor in evolution_search_loop. * add UT for predictor * add MLPHandler * patch optional.txt for predictors * patch test_evolution_search_loop * refactor apis of predictor and handlers * fix ut and remove predictor_cfg in predictor * adapt new mutable & mutator design * fix ut * remove unness assert after rebase * move predictor-build in __init__ & simplify estimator-build Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
* [Feature] Add DCFF (open-mmlab#295) * add ChannelGroup (open-mmlab#250) * rebase new dev-1.x * modification for adding config_template * add docstring to channel_group.py * add docstring to mutable_channel_group.py * rm channel_group_cfg from Graph2ChannelGroups * change choice type of SequentialChannelGroup from float to int * add a warning about group-wise conv * restore __init__ of dynamic op * in_channel_mutable -> mutable_in_channel * rm abstractproperty * add a comment about VT * rm registry for ChannelGroup * MUTABLECHANNELGROUP -> ChannelGroupType * refine docstring of IndexDict * update docstring * update docstring * is_prunable -> is_mutable * update docstring * fix error in pre-commit * update unittest * add return type * unify init_xxx apit * add unitest about init of MutableChannelGroup * update according to reviews * sequential_channel_group -> sequential_mutable_channel_group Co-authored-by: liukai <liukai@pjlab.org.cn> * Add BaseChannelMutator and refactor Autoslim (open-mmlab#289) * add BaseChannelMutator * add autoslim * tmp * make SequentialMutableChannelGroup accpeted both of num and ratio as choice. and supports divisior * update OneShotMutableChannelGroup * pass supernet training of autoslim * refine autoslim * fix bug in OneShotMutableChannelGroup * refactor make_divisible * fix spell error: channl -> channel * init_using_backward_tracer -> init_from_backward_tracer init_from_fx_tracer -> init_from_fx_tracer * refine SequentialMutableChannelGroup * let mutator support models with dynamicop * support define search space in model * tracer_cfg -> parse_cfg * refine * using -> from * update docstring * update docstring Co-authored-by: liukai <liukai@pjlab.org.cn> * tmpsave * migrate ut * tmpsave2 * add loss collector * refactor slimmable and add l1-norm (open-mmlab#291) * refactor slimmable and add l1-norm * make l1-norm support convnd * update get_channel_groups * add l1-norm_resnet34_8xb32_in1k.py * add pretrained to resnet34-l1 * remove old channel mutator * BaseChannelMutator -> ChannelMutator * update according to reviews * add readme to l1-norm * MBV2_slimmable -> MBV2_slimmable_config Co-authored-by: liukai <liukai@pjlab.org.cn> * update config * fix md & pytorch support <1.9.0 in batchnorm init * Clean old codes. (open-mmlab#296) * remove old dynamic ops * move dynamic ops * clean old mutable_channels * rm OneShotMutableChannel * rm MutableChannel * refine * refine * use SquentialMutableChannel to replace OneshotMutableChannel * refactor dynamicops folder * let SquentialMutableChannel support float Co-authored-by: liukai <liukai@pjlab.org.cn> * fix ci * ci fix py3.6.x & add mmpose * ci fix py3.6.9 in utils/index_dict.py * fix mmpose * minimum_version_cpu=3.7 * fix ci 3.7.13 * fix pruning &meta ci * support python3.6.9 * fix py3.6 import caused by circular import patch in py3.7 * fix py3.6.9 * Add channel-flow (open-mmlab#301) * base_channel_mutator -> channel_mutator * init * update docstring * allow omitting redundant configs for channel * add register_mutable_channel_to_a_module to MutableChannelContainer * update according to reviews 1 * update according to reviews 2 * update according to reviews 3 * remove old docstring * fix error * using->from * update according to reviews * support self-define input channel number * update docstring * chanenl -> channel_elem Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: jacky <jacky@xx.com> * support >=3.7 * support py3.6.9 * Rename: ChannelGroup -> ChannelUnit (open-mmlab#302) * refine repr of MutableChannelGroup * rename folder name * ChannelGroup -> ChannelUnit * filename in units folder * channel_group -> channel_unit * groups -> units * group -> unit * update * get_mutable_channel_groups -> get_mutable_channel_units * fix bug * refine docstring * fix ci * fix bug in tracer Co-authored-by: liukai <liukai@pjlab.org.cn> * update new channel config format * update pruning refactor * update merged pruning * update commit * fix dynamic_conv_mixin * update comments: readme&dynamic_conv_mixins.py * update readme * move kl softmax channel pooling to op by comments * fix comments: fix redundant & split README.md * dcff in ItePruneAlgorithm * partial dynamic params for fuseconv * add step_freq & prune_time check * update comments * update comments * update comments * fix ut * fix gpu ut & revise step_freq in ItePruneAlgorithm * update readme * revise ItePruneAlgorithm * fix docs * fix dynamic_conv attr * fix ci Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com> Co-authored-by: jacky <jacky@xx.com>
* [Fix] Fix optional requirements (open-mmlab#357) * fix optional requirements * fix dcff ut * fix import with get_placeholder * supplement the previous commit
* [Fix] Fix configs of wrn models and ofd. (open-mmlab#361) * 1.revise the configs of wrn22, wrn24, and wrn40. 2.revise the data_preprocessor of ofd_backbone_resnet50_resnet18_8xb16_cifar10 * 1.Add README for vanilla-wrm. * 1.Revise readme of wrn Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
* [Fix] Fix bug on mmrazor visualization, mismatch argument in define and use. (open-mmlab#356) fix bug on mmrazor visualization, mismatch argument in define and use. Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
* fix bug in benchmark_test (open-mmlab#364) fix bug in configs Co-authored-by: Your Name <you@example.com>
* [FIX] Fix wrn configs (open-mmlab#368) * fix wrn configs * fix wrn configs * update online wrn model weight
* [Fix] fix bug on pkd config. Wrong import filename. (open-mmlab#373)
* [CI] Update ci to torch1.13 (open-mmlab#380) update ci to torch1.13
* [Feature] Add BigNAS algorithm (open-mmlab#219) * add calibrate-bn-statistics * add test calibrate-bn-statistics * fix mixins * fix mixins * fix mixin tests * remove slimmable channel mutable and refactor dynamic op * refact dynamic batch norm * add progressive dynamic conv2d * add center crop dynamic conv2d * refactor dynamic directory * refactor dynamic sequential * rename length to depth in dynamic sequential * add test for derived mutable * refactor dynamic op * refactor api of dynamic op * add derive mutable mixin * addbignas algorithm * refactor bignas structure * add input resizer * add input resizer to bignas * move input resizer from algorithm into classifier * remove compnents * add attentive mobilenet * delete json file * nearly(less 0.2) align inference accuracy with gml * move mutate seperated in bignas mobilenet backbone * add zero_init_residual * add set_dropout * set dropout in bignas algorithm * fix registry * add subnet yaml and nearly align inference accuracy with gml * add rsb config for bignas * remove base in config * add gml bignas config * convert to iter based * bignas forward and backward fly * fix merge conflict * fix dynamicseq bug * fix bug and refactor bignas * arrange configs of bignas * fix typo * refactor attentive_mobilenet * fix channel mismatch due to registion of DerivedMutable * update bignas & fix se channel mismatch * add AutoAugmentV2 & remove unness configs * fix lint * recover channel assertion in channel unit * fix a group bug * fix comments * add docstring * add norm in dynamic_embed * fix search loop & other minor changes * fix se expansion * minor change * add ut for bignas & attentive_mobilenet * fix ut * update bignas readme * rm unness ut & supplement get_placeholder * fix lint * fix ut * add subnet deployment in downstream tasks. * minor change * update ofa backbone * minor fix * Continued improvements of searchable backbone * minor change * drop ratio in backbone * fix comments * fix ci test * fix test * add dynamic shortcut UT * modify strategy to fit bignas * fix test * fix bug in neck * fix error * fix error * fix yaml * save subnet ckpt * merge autoslim_val/test_loop into subnet_val_loop * move calibrate_bn_mixin to utils * fix bugs and add docstring * clean code * fix register bug * clean code * update Co-authored-by: wangshiguang <wangshiguang@sensetime.com> Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny <aptsunny@tongji.edu.cn> Co-authored-by: sunyue1 <sunyue1@sensetime.com>
* [Bug] Fix ckpt (open-mmlab#372) fix ckpt
* [Feature] Add tools to convert distill ckpt to student-only ckpt. (open-mmlab#381) * [Feature] Add tools to convert distill ckpt to student-only ckpt. * fix bug. * add --model-only to only save model. * Make changes accroding to PR review.
* Enhance the Abilities of the Tracer for Pruning. (open-mmlab#371) * tmp * add new mmdet models * add docstring * pass test and pre-commit * rm razor tracer * update fx tracer, now it can automatically wrap methods and functions. * update tracer passed models * add warning for torch <1.12.0 fix bug for python3.6 update placeholder to support placeholder.XXX * fix bug * update docs * fix lint * fix parse_cfg in configs * restore mutablechannel * test ite prune algorithm when using dist * add get_model_from_path to MMModelLibrrary * add mm models to DefaultModelLibrary * add uts * fix bug * fix bug * add uts * add uts * add uts * add uts * fix bug * restore ite_prune_algorithm * update doc * PruneTracer -> ChannelAnalyzer * prune_tracer -> channel_analyzer * add test for fxtracer * fix bug * fix bug * PruneTracer -> ChannelAnalyzer refine * CustomFxTracer -> MMFxTracer * fix bug when test with torch<1.12 * update print log * fix lint * rm unuseful code Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: jacky <jacky@xx.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: liukai <your_email@abc.example>
* fix bug in placer holder (open-mmlab#395) * fix bug in placer holder * remove redundent comment Co-authored-by: liukai <your_email@abc.example>
* Add get_prune_config and a demo config_pruning (open-mmlab#389) * update tools and test * add demo * disable test doc * add switch for test tools and test_doc * fix bug * update doc * update tools name * mv get_channel_units Co-authored-by: liukai <your_email@abc.example>
* [Improvement] Adapt OFA series with SearchableMobileNetV3 (open-mmlab#385) * fix mutable bug in AttentiveMobileNetV3 * remove unness code * update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names * unify the sampling usage in sandwich_rule-based NAS * use alias to export subnet * update OFA configs * fix attr bug * fix comments * update convert_supernet2subnet.py * correct the way to dump DerivedMutable * fix convert index bug * update OFA configs & models * fix dynamic2static * generalize convert_ofa_ckpt.py * update input_resizer * update README.md * fix ut * update export_fix_subnet * update _dynamic_to_static * update fix_subnet UT & minor fix bugs * fix ut * add new autoaug compared to attentivenas * clean * fix act * fix act_cfg * update fix_subnet * fix lint * add docstring Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
* [Fix]Dcff Deploy Revision (open-mmlab#383) * dcff deploy revision * tempsave * update fix_subnet * update mutator load * export/load_fix_subnet revision for mutator * update fix_subnet with dev-1.x * update comments * update docs * update registry
* [Fix] Fix commands in README to adapt branch 1.x (open-mmlab#400) * update commands in README for 1.x * fix commands Co-authored-by: gaoyang07 <1546308416@qq.com>
* Set requires_grad to False if the teacher is not trainable (open-mmlab#398)
* add choice and mask of units to checkpoint (open-mmlab#397) * add choice and mask of units to checkpoint * update * fix bug * remove device operation * fix bug * fix circle ci error * fix error in numpy for circle ci * fix bug in requirements * restore * add a note * a new solution * save mutable_channel.mask as float for dist training * refine * mv meta file test Co-authored-by: liukai <your_email@abc.example> Co-authored-by: jacky <jacky@xx.com>
* [Bug]Fix fpn teacher distill (open-mmlab#388) fix fpn distill
* [CodeCamp open-mmlab#122] Support KD algorithm MGD for detection. (open-mmlab#377) * [Feature] Support KD algorithm MGD for detection. * use connector to beauty mgd. * fix typo, add unitest. * fix mgd loss unitest. * fix mgd connector unitest. * add model pth and log file. * add mAP.
* update l1 config (open-mmlab#405) * add l1 config * update l1 config Co-authored-by: jacky <jacky@xx.com>
* [Feature] Add greedy search for AutoSlim (open-mmlab#336) * WIP: add greedysearch * fix greedy search and add bn_training_mode to autoslim * fix cfg files * fix autoslim configs * fix bugs when converting dynamic bn to static bn * change to test loop * refactor greedy search * rebase and fix greedysearch * fix lint * fix and delete useless codes * fix pytest * fix pytest and add bn_training_mode * fix lint * add reference to AutoSlimGreedySearchLoop's docstring * sort candidate_choices * fix save subnet * delete useless codes in channel container * change files' name: convert greedy_search_loop to autoslim_greedy_search_loop
* [Fix] Fix metafile (open-mmlab#422) * fix ckpt path in metafile and readme * fix darts file path * fix docstring in ConfigurableDistiller * fix darts * fix error * add darts of mmrazor version * delete py36 Co-authored-by: liukai <your_email@abc.example>
* update bignas cfg (open-mmlab#412) * check attentivenas training * update ckpt link * update supernet log Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
* Bump version to 1.0.0rc2 (open-mmlab#423) bump version to 1.0.0rc2 Co-authored-by: liukai <your_email@abc.example>
* fix lint * fix ci * add tmp docstring for passed ci * add tmp docstring for passed ci * fix ci * add get_placeholder for quant * add skip for unittest * fix package placeholder bug * add version judgement in __init__ * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
1 parent f47d49a commit a725b51

41 files changed: +838, -305 lines

.github/workflows/build.yml (+38)

@@ -31,6 +31,44 @@ jobs:
         python-version: [3.7]
         torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
         include:
+          - torch: 1.6.0
+            torch_version: 1.6
+            torchvision: 0.7.0
+          - torch: 1.7.0
+            torch_version: 1.7
+            torchvision: 0.8.1
+          - torch: 1.7.0
+            torch_version: 1.7
+            torchvision: 0.8.1
+            python-version: 3.8
+          - torch: 1.8.0
+            torch_version: 1.8
+            torchvision: 0.9.0
+          - torch: 1.8.0
+            torch_version: 1.8
+            torchvision: 0.9.0
+            python-version: 3.8
+          - torch: 1.9.0
+            torch_version: 1.9
+            torchvision: 0.10.0
+          - torch: 1.9.0
+            torch_version: 1.9
+            torchvision: 0.10.0
+            python-version: 3.8
+          - torch: 1.10.0
+            torch_version: 1.10
+            torchvision: 0.11.0
+          - torch: 1.10.0
+            torch_version: 1.10
+            torchvision: 0.11.0
+            python-version: 3.8
+          - torch: 1.11.0
+            torch_version: 1.11
+            torchvision: 0.12.0
+          - torch: 1.11.0
+            torch_version: 1.11
+            torchvision: 0.12.0
+            python-version: 3.8
           - torch: 1.12.0
             torch_version: 1.12
             torchvision: 0.13.0

configs/pruning/mmpose/dcff/fix_subnet.json (+4)

@@ -54,7 +54,11 @@
             "min_value":1,
             "min_ratio":0.9
         },
+<<<<<<< HEAD
         "choice":0.59375
+=======
+        "choice":0.59374
+>>>>>>> 985a611e (Merge dev-1.x into quantize (#430))
     },
     "backbone.layer2.1.conv1_(0, 128)_128":{
         "init_args":{

configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py (+4)

@@ -1,7 +1,11 @@
 _base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py']
 
 # model settings
+<<<<<<< HEAD
 _base_.model = dict(
+=======
+model_cfg = dict(
+>>>>>>> 985a611e (Merge dev-1.x into quantize (#430))
     _scope_='mmrazor',
     type='sub_model',
     cfg=_base_.architecture,

mmrazor/engine/__init__.py (+4, -5)

@@ -4,15 +4,14 @@
 from .optimizers import SeparateOptimWrapperConstructor
 from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
-                     GreedySamplerTrainLoop, SelfDistillValLoop,
-                     SingleTeacherDistillValLoop, SlimmableValLoop,
-                     SubnetValLoop)
+                     GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop,
+                     SelfDistillValLoop, SingleTeacherDistillValLoop,
+                     SlimmableValLoop, SubnetValLoop)
 
 __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop',
-    'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'StopDistillHook',
-    'DMCPSubnetHook'
+    'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', 'QATEpochBasedLoop'
 ]

mmrazor/engine/runner/__init__.py (+2, -2)

@@ -13,6 +13,6 @@
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
-    'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'PTQLoop',
-    'QATEpochBasedLoop'
+    'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop',
+    'PTQLoop'
 ]

mmrazor/engine/runner/iteprune_val_loop.py (-1)

@@ -52,7 +52,6 @@ def _save_fix_subnet(self):
             file.write(fix_subnet)
         torch.save({'state_dict': static_model.state_dict()},
                    osp.join(self.runner.work_dir, weight_name))
-
         self.runner.logger.info(
             'export finished and '
             f'{subnet_name}, '

mmrazor/engine/runner/quantization_loops.py (+12, -3)

@@ -4,9 +4,18 @@
 import torch
 from mmengine.evaluator import Evaluator
 from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop
-from torch.ao.quantization import (disable_observer, enable_fake_quant,
-                                   enable_observer)
-from torch.nn.intrinsic.qat import freeze_bn_stats
+
+try:
+    from torch.ao.quantization import (disable_observer, enable_fake_quant,
+                                       enable_observer)
+    from torch.nn.intrinsic.qat import freeze_bn_stats
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    disable_observer = get_placeholder('torch>=1.13')
+    enable_fake_quant = get_placeholder('torch>=1.13')
+    enable_observer = get_placeholder('torch>=1.13')
+    freeze_bn_stats = get_placeholder('torch>=1.13')
+
 from torch.utils.data import DataLoader
 
 from mmrazor.registry import LOOPS
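
This try/except guard is the pattern the merge applies across the quantization code: on torch versions older than 1.13 the real symbols are replaced by placeholders, so importing mmrazor itself still succeeds. A minimal sketch of such a placeholder helper, assuming it only needs to fail loudly on first use (an illustration of the pattern, not necessarily mmrazor's exact implementation):

def get_placeholder(string: str) -> type:
    """Return a stand-in class for a symbol that failed to import."""

    class PlaceHolder:

        def __init__(self, *args, **kwargs) -> None:
            # Defer the failure from import time to first use, so the rest
            # of the package keeps working on older torch versions.
            raise ImportError(
                f'`{string}` is required to use this feature, '
                'please install or upgrade it first.')

    return PlaceHolder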

mmrazor/models/algorithms/nas/autoslim.py (+2)

@@ -75,6 +75,8 @@ def __init__(self,
         self._optim_wrapper_count_status_reinitialized = False
         self.norm_training = norm_training
 
+        self.bn_training_mode = bn_training_mode
+
     def _build_mutator(self,
                        mutator: VALID_MUTATOR_TYPE = None) -> ChannelMutator:
         """Build mutator."""

mmrazor/models/algorithms/pruning/ite_prune_algorithm.py (+4)

@@ -10,6 +10,7 @@
 from mmrazor.models.mutables import MutableChannelUnit
 from mmrazor.models.mutators import ChannelMutator
 from mmrazor.registry import MODELS
+from mmrazor.utils import ValidFixMutable
 from ..base import BaseAlgorithm
 
 LossResults = Dict[str, torch.Tensor]

@@ -97,6 +98,8 @@ class ItePruneAlgorithm(BaseAlgorithm):
         mutator_cfg (Union[Dict, ChannelMutator], optional): The config
             of a mutator. Defaults to dict( type='ChannelMutator',
             channel_unit_cfg=dict( type='SequentialMutableChannelUnit')).
+        fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or
+            loaded dict or built :obj:`FixSubnet`. Defaults to None.
         data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
             Defaults to None.
         target_pruning_ratio (dict, optional): The prune-target. The template

@@ -118,6 +121,7 @@ def __init__(self,
                      type='ChannelMutator',
                      channel_unit_cfg=dict(
                          type='SequentialMutableChannelUnit')),
+                 fix_subnet: Optional[ValidFixMutable] = None,
                  data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
                  target_pruning_ratio: Optional[Dict[str, float]] = None,
                  step_freq=1,
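
These hunks only show the new `fix_subnet` argument in the docstring and signature. A hedged config sketch of how it could be supplied, following the docstring above; the subnet path and the inherited `architecture` config are hypothetical illustrations:

# Hypothetical pruning config; per the docstring, `fix_subnet` accepts a
# yaml/json file path, a loaded dict, or a built FixSubnet.
model = dict(
    _scope_='mmrazor',
    type='ItePruneAlgorithm',
    architecture=_base_.architecture,  # assumed to be defined in a base config
    mutator_cfg=dict(
        type='ChannelMutator',
        channel_unit_cfg=dict(type='SequentialMutableChannelUnit')),
    fix_subnet='path/to/fix_subnet.json',  # hypothetical path
    target_pruning_ratio=None)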

mmrazor/models/algorithms/quantization/mm_architecture.py (+7, -2)

@@ -7,12 +7,17 @@
 from mmengine.runner import load_checkpoint
 from mmengine.structures import BaseDataElement
 from torch import nn
-from torch.ao.quantization import FakeQuantizeBase
 
-from mmrazor.models.task_modules import build_graphmodule
+from mmrazor.models.task_modules.tracer import build_graphmodule
 from mmrazor.registry import MODEL_WRAPPERS, MODELS
 from ..base import BaseAlgorithm
 
+try:
+    from torch.ao.quantization import FakeQuantizeBase
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    FakeQuantizeBase = get_placeholder('torch>=1.13')
+
 LossResults = Dict[str, torch.Tensor]
 TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
 PredictResults = List[BaseDataElement]

mmrazor/models/fake_quants/base.py (+5, -1)

@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from torch.ao.quantization import FakeQuantize
+try:
+    from torch.ao.quantization import FakeQuantize
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    FakeQuantize = get_placeholder('torch>=1.13')
 
 BaseFakeQuantize = FakeQuantize

mmrazor/models/fake_quants/torch_fake_quants.py (+6, -2)

@@ -2,10 +2,14 @@
 import inspect
 from typing import List
 
-import torch.ao.quantization.fake_quantize as torch_fake_quant_src
-
 from mmrazor.registry import MODELS
 
+try:
+    import torch.ao.quantization.fake_quantize as torch_fake_quant_src
+except ImportError:
+    from mmrazor.utils import get_package_placeholder
+    torch_fake_quant_src = get_package_placeholder('torch>=1.13')
+
 
 def register_torch_fake_quants() -> List[str]:
     """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the

mmrazor/models/losses/__init__.py (-1)

@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ab_loss import ABLoss
-from .adaround_loss import AdaRoundLoss
 from .at_loss import ATLoss
 from .crd_loss import CRDLoss
 from .cross_entropy_loss import CrossEntropyLoss

mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py (+3, -1)

@@ -4,11 +4,13 @@
 
 from mmrazor.models.mutables import OneShotMutableChannelUnit
 from mmrazor.registry import MODELS
+from ..group_mixin import DynamicSampleMixin
 from .channel_mutator import ChannelMutator, ChannelUnitType
 
 
 @MODELS.register_module()
-class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit]):
+class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit],
+                            DynamicSampleMixin):
     """OneShotChannelMutator based on ChannelMutator. It use
     OneShotMutableChannelUnit by default.
 
mmrazor/models/mutators/group_mixin.py

+68
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
from mmrazor.models.mutables.mutable_module import MutableModule
99
from .base_mutator import MUTABLE_TYPE
1010

11+
if sys.version_info < (3, 8):
12+
from typing_extensions import Protocol
13+
else:
14+
from typing import Protocol
15+
1116

1217
class GroupMixin():
1318
"""A mixin for :class:`BaseMutator`, which can group mutables by
@@ -259,3 +264,66 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
259264
f'When a mutable is set alias attribute :{alias_key},'
260265
f'the corresponding module name {mutable_name} should '
261266
f'not be used in `custom_group` {custom_group}.')
267+
268+
269+
class MutatorProtocol(Protocol): # pragma: no cover
270+
271+
@property
272+
def mutable_class_type(self) -> Type[BaseMutable]:
273+
...
274+
275+
@property
276+
def search_groups(self) -> Dict:
277+
...
278+
279+
280+
class OneShotSampleMixin:
281+
"""Sample mixin for one-shot mutators."""
282+
283+
def sample_choices(self: MutatorProtocol) -> Dict:
284+
"""Sample choices for each group in search_groups."""
285+
random_choices = dict()
286+
for group_id, modules in self.search_groups.items():
287+
random_choices[group_id] = modules[0].sample_choice()
288+
289+
return random_choices
290+
291+
def set_choices(self: MutatorProtocol, choices: Dict) -> None:
292+
"""Set choices for each group in search_groups."""
293+
for group_id, modules in self.search_groups.items():
294+
choice = choices[group_id]
295+
for module in modules:
296+
module.current_choice = choice
297+
298+
299+
class DynamicSampleMixin(OneShotSampleMixin):
300+
301+
def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict:
302+
"""Sample choices for each group in search_groups."""
303+
random_choices = dict()
304+
for group_id, modules in self.search_groups.items():
305+
if kind == 'max':
306+
random_choices[group_id] = modules[0].max_choice
307+
elif kind == 'min':
308+
random_choices[group_id] = modules[0].min_choice
309+
else:
310+
random_choices[group_id] = modules[0].sample_choice()
311+
return random_choices
312+
313+
@property
314+
def max_choice(self: MutatorProtocol) -> Dict:
315+
"""Get max choices for each group in search_groups."""
316+
max_choice = dict()
317+
for group_id, modules in self.search_groups.items():
318+
max_choice[group_id] = modules[0].max_choice
319+
320+
return max_choice
321+
322+
@property
323+
def min_choice(self: MutatorProtocol) -> Dict:
324+
"""Get min choices for each group in search_groups."""
325+
min_choice = dict()
326+
for group_id, modules in self.search_groups.items():
327+
min_choice[group_id] = modules[0].min_choice
328+
329+
return min_choice
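
A usage sketch of the new mixin through `OneShotChannelMutator` (see the one_shot_channel_mutator.py hunk above); `supernet` is an assumed, already-built model containing mutable channel units:

from mmrazor.models.mutators import OneShotChannelMutator

mutator = OneShotChannelMutator(
    channel_unit_cfg=dict(type='OneShotMutableChannelUnit'))
mutator.prepare_from_supernet(supernet)  # `supernet` assumed built elsewhere

max_choices = mutator.sample_choices(kind='max')   # largest choice per group
min_choices = mutator.sample_choices(kind='min')   # smallest choice per group
rand_choices = mutator.sample_choices()            # random choice per group

mutator.set_choices(rand_choices)  # activate the sampled subnet

This max/min/random sampling is what sandwich-rule training in the BigNAS and Autoformer entries of the changelog relies on.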
mmrazor/models/mutators/value_mutators/__init__.py (new file, +5)

@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dynamic_value_mutator import DynamicValueMutator
+from .value_mutator import ValueMutator
+
+__all__ = ['ValueMutator', 'DynamicValueMutator']
mmrazor/models/mutators/value_mutators/dynamic_value_mutator.py (new file, +14)

@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmrazor.models.mutables import OneShotMutableValue
+from mmrazor.registry import MODELS
+from ..group_mixin import DynamicSampleMixin
+from .value_mutator import ValueMutator
+
+
+@MODELS.register_module()
+class DynamicValueMutator(ValueMutator, DynamicSampleMixin):
+    """Dynamic value mutator with type as `OneShotMutableValue`."""
+
+    @property
+    def mutable_class_type(self):
+        return OneShotMutableValue
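
A matching sketch for the value side: because `mutable_class_type` is `OneShotMutableValue`, the search groups here contain only value mutables (e.g. mutable depth or kernel size), and the inherited `DynamicSampleMixin` API applies unchanged. The `supernet` and the import path are again assumptions for illustration:

from mmrazor.models.mutators import DynamicValueMutator

value_mutator = DynamicValueMutator()
value_mutator.prepare_from_supernet(supernet)  # groups OneShotMutableValue mutables

# Same sampling API as the channel mutator above:
value_mutator.set_choices(value_mutator.sample_choices(kind='min'))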
