1
1
# Copyright (c) OpenMMLab. All rights reserved.
2
2
"""This module defines MutableChannelUnit."""
3
3
import abc
4
+ from collections import Set
4
5
from typing import Dict , List , Type , TypeVar
5
6
6
7
import torch .nn as nn
7
8
8
- import mmrazor .models .architectures . dynamic_ops as dynamic_ops
9
+ from mmrazor .models .architectures import dynamic_ops
9
10
from mmrazor .models .architectures .dynamic_ops .mixins import DynamicChannelMixin
10
11
from mmrazor .models .mutables import DerivedMutable
11
- from mmrazor .models .mutables .mutable_channel .base_mutable_channel import \
12
- BaseMutableChannel
13
- from ..mutable_channel_container import MutableChannelContainer
12
+ from mmrazor .models .mutables .mutable_channel import (BaseMutableChannel ,
13
+ MutableChannelContainer )
14
14
from .channel_unit import Channel , ChannelUnit
15
15
16
16
17
17
class MutableChannelUnit (ChannelUnit ):
18
18
19
19
# init methods
20
20
def __init__ (self , num_channels : int , ** kwargs ) -> None :
21
- """MutableChannelUnit inherits from ChannelUnit, which manages
22
- channels with channel-dependency.
21
+ """MutableChannelUnit inherits from ChannelUnit, which manages channels
22
+ with channel-dependency.
23
23
24
24
Compared with ChannelUnit, MutableChannelUnit defines the core
25
25
interfaces for pruning. By inheriting MutableChannelUnit,
@@ -44,6 +44,70 @@ def __init__(self, num_channels: int, **kwargs) -> None:
44
44
45
45
super ().__init__ (num_channels )
46
46
47
+ @classmethod
48
+ def init_from_mutable_channel (cls , mutable_channel : BaseMutableChannel ):
49
+ unit = cls (mutable_channel .num_channels )
50
+ return unit
51
+
52
+ @classmethod
53
+ def init_from_predefined_model (cls , model : nn .Module ):
54
+ """Initialize units using the model with pre-defined dynamicops and
55
+ mutable-channels."""
56
+
57
+ def process_container (contanier : MutableChannelContainer ,
58
+ module ,
59
+ module_name ,
60
+ mutable2units ,
61
+ is_output = True ):
62
+ for index , mutable in contanier .mutable_channels .items ():
63
+ if isinstance (mutable , DerivedMutable ):
64
+ source_mutables : Set = \
65
+ mutable ._trace_source_mutables ()
66
+ source_channel_mutables = [
67
+ mutable for mutable in source_mutables
68
+ if isinstance (mutable , BaseMutableChannel )
69
+ ]
70
+ assert len (source_channel_mutables ) == 1 , (
71
+ 'only support one mutable channel '
72
+ 'used in DerivedMutable' )
73
+ mutable = list (source_channel_mutables )[0 ]
74
+
75
+ if mutable not in mutable2units :
76
+ mutable2units [mutable ] = cls .init_from_mutable_channel (
77
+ mutable )
78
+
79
+ unit : MutableChannelUnit = mutable2units [mutable ]
80
+ if is_output :
81
+ unit .add_ouptut_related (
82
+ Channel (
83
+ module_name ,
84
+ module ,
85
+ index ,
86
+ is_output_channel = is_output ))
87
+ else :
88
+ unit .add_input_related (
89
+ Channel (
90
+ module_name ,
91
+ module ,
92
+ index ,
93
+ is_output_channel = is_output ))
94
+
95
+ mutable2units : Dict = {}
96
+ for name , module in model .named_modules ():
97
+ if isinstance (module , DynamicChannelMixin ):
98
+ in_container : MutableChannelContainer = \
99
+ module .get_mutable_attr (
100
+ 'in_channels' )
101
+ out_container : MutableChannelContainer = \
102
+ module .get_mutable_attr (
103
+ 'out_channels' )
104
+ process_container (in_container , module , name , mutable2units ,
105
+ False )
106
+ process_container (out_container , module , name , mutable2units ,
107
+ True )
108
+ units = list (mutable2units .values ())
109
+ return units
110
+
47
111
# properties
48
112
49
113
@property
@@ -97,7 +161,7 @@ def prepare_for_pruning(self, model):
97
161
98
162
For example, we need to register mutables to dynamic-ops.
99
163
"""
100
- raise not NotImplementedError
164
+ raise NotImplementedError
101
165
102
166
# pruning: choice-related
103
167
0 commit comments