diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py
index f99b9b7cc..d0a2deff0 100644
--- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py
+++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py
@@ -11,6 +11,7 @@
 from mmrazor.models.mutables import DerivedMutable
 from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
                                                      MutableChannelContainer)
+from mmrazor.models.utils import get_module_device
 from .channel_unit import Channel, ChannelUnit
 
@@ -227,7 +228,7 @@ def get_module(model, name):
             module = get_module(model, channel.name)
             if type(module) in dynamicop_map:
                 new_module = dynamicop_map[type(module)].convert_from(
-                    module)
+                    module).to(get_module_device(module))
                 replace_op(model, channel.name, new_module)
                 channel.module = new_module
             else:
@@ -237,19 +238,22 @@ def get_module(model, name):
     def _register_channel_container(
             model: nn.Module, container_class: Type[MutableChannelContainer]):
         """register channel container for dynamic ops."""
+        device = get_module_device(model)
         for module in model.modules():
             if isinstance(module, DynamicChannelMixin):
                 in_channels = getattr(module,
                                       module.attr_mappings['in_channels'], 0)
                 if module.get_mutable_attr('in_channels') is None:
-                    module.register_mutable_attr('in_channels',
-                                                 container_class(in_channels))
+                    module.register_mutable_attr(
+                        'in_channels',
+                        container_class(in_channels).to(device))
                 out_channels = getattr(module,
                                        module.attr_mappings['out_channels'],
                                        0)
                 if module.get_mutable_attr('out_channels') is None:
-                    module.register_mutable_attr('out_channels',
-                                                 container_class(out_channels))
+                    module.register_mutable_attr(
+                        'out_channels',
+                        container_class(out_channels).to(device))
 
     def _register_mutable_channel(self, mutable_channel: BaseMutableChannel):
         # register mutable_channel
diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
index 7c2c39338..910992e1e 100644
--- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
@@ -99,6 +99,9 @@ def prepare_from_supernet(self, supernet: Module) -> None:
         1. parse the model and get MutableChannelUnits.
         2. call unit.prepare_for_pruning for each unit.
         """
+        from mmrazor.models.utils import get_module_device
+        device = get_module_device(supernet)
+
         self._name2module = dict(supernet.named_modules())
 
         if isinstance(self.parse_cfg,
@@ -115,10 +118,11 @@ def prepare_from_supernet(self, supernet: Module) -> None:
             units = self._prepare_from_predefined_model(supernet)
         else:
             raise NotImplementedError()
+        for i in range(len(units)):
+            units[i] = units[i].to(device)
+            units[i].prepare_for_pruning(supernet)
+            self._name2unit[units[i].name] = units[i]
 
-        for unit in units:
-            unit.prepare_for_pruning(supernet)
-            self._name2unit[unit.name] = unit
         self.units = ModuleList(units)
 
     @property
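The two files above make device handling explicit when a supernet is parsed: converted dynamic ops, registered channel containers, and the units themselves are all moved to the supernet's device instead of silently materializing on CPU. A minimal sketch of the workflow this fixes, assuming a CUDA machine and the default `ChannelMutator` config (the toy model below is illustrative, not from this PR):

```python
import torch
import torch.nn as nn

from mmrazor.models.mutators import ChannelMutator

# Toy supernet; any traceable conv net behaves the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1),
)
if torch.cuda.is_available():
    model = model.cuda()

# Before this patch the registered MutableChannelContainers stayed on CPU
# even for a CUDA model; with the .to(device) calls above, every unit and
# container now follows the supernet's device.
mutator = ChannelMutator()
mutator.prepare_from_supernet(model)
```
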
""" + from mmrazor.models.utils import get_module_device + device = get_module_device(supernet) + self._name2module = dict(supernet.named_modules()) if isinstance(self.parse_cfg, @@ -115,10 +118,11 @@ def prepare_from_supernet(self, supernet: Module) -> None: units = self._prepare_from_predefined_model(supernet) else: raise NotImplementedError() + for i in range(len(units)): + units[i] = units[i].to(device) + units[i].prepare_for_pruning(supernet) + self._name2unit[units[i].name] = units[i] - for unit in units: - unit.prepare_for_pruning(supernet) - self._name2unit[unit.name] = unit self.units = ModuleList(units) @property diff --git a/mmrazor/models/task_modules/tracer/channel_analyzer.py b/mmrazor/models/task_modules/tracer/channel_analyzer.py index 2ca55d0e5..9f754e97e 100644 --- a/mmrazor/models/task_modules/tracer/channel_analyzer.py +++ b/mmrazor/models/task_modules/tracer/channel_analyzer.py @@ -32,6 +32,7 @@ from mmrazor.structures.graph.module_graph import (FxTracerToGraphConverter, PathToGraphConverter) from mmrazor.structures.graph.pseudo_fx_graph import parse_torch_graph +from mmrazor.utils import print_log from ..demo_inputs import BaseDemoInput, DefaultDemoInput from .backward_tracer import BackwardTracer from .fx_tracer import MMFxTracer @@ -136,9 +137,9 @@ def _fx_trace(self, model): else: return self.tracer.trace(model) - def _find_mutable_units(self, model, units_config: Dict): + def _find_mutable_units(self, model: nn.Module, units_config: Dict): """Test the tracer result and filter unforwardable units.""" - model = copy.deepcopy(model) + model = copy.deepcopy(model).cpu() units: List[SequentialMutableChannelUnit] = [ SequentialMutableChannelUnit.init_from_cfg(model, cfg) for cfg in units_config.values() @@ -156,16 +157,17 @@ def _find_mutable_units(self, model, units_config: Dict): inputs['mode'] = mode template_output = model(**inputs) break - except Exception: - pass + except Exception as e: + print_log(f'Forward failed in {mode} mode as {e}') else: try: template_output = model(inputs) - except Exception: - pass + except Exception as e: + print_log(f'Forward failed in as {e}') if template_output is None: raise Exception( - 'Forward failed, there may be an error in demo input.') + 'Forward failed, there may be an error in demo input.', + f'{inputs}') mutable_units = find_mutable(model, mutable_units, units, inputs, template_output) mutable_unit_config = {} diff --git a/mmrazor/models/utils/expandable_utils/ops.py b/mmrazor/models/utils/expandable_utils/ops.py index fa4c41db9..f2bc2b046 100644 --- a/mmrazor/models/utils/expandable_utils/ops.py +++ b/mmrazor/models/utils/expandable_utils/ops.py @@ -4,6 +4,7 @@ from mmrazor.models.architectures import dynamic_ops from mmrazor.models.mutables import MutableChannelContainer +from mmrazor.models.utils import get_module_device class ExpandableMixin: @@ -65,24 +66,20 @@ def expanded_out_channel(self): @property def mutable_in_mask(self): """Return the mutable in mask.""" + device = get_module_device(self) if self.in_mutable is not None: - return self.in_mutable.current_mask + return self.in_mutable.current_mask.to(device) else: - if hasattr(self, 'weight'): - return self.weight.new_ones([self.expanded_in_channel]) - else: - return torch.ones([self.expanded_in_channel]) + return torch.ones([self.expanded_in_channel]).to(device) @property def mutable_out_mask(self): """Return the mutable out mask.""" + device = get_module_device(self) if self.out_mutable is not None: - return self.out_mutable.current_mask + return 
diff --git a/mmrazor/models/utils/expandable_utils/ops.py b/mmrazor/models/utils/expandable_utils/ops.py
index fa4c41db9..f2bc2b046 100644
--- a/mmrazor/models/utils/expandable_utils/ops.py
+++ b/mmrazor/models/utils/expandable_utils/ops.py
@@ -4,6 +4,7 @@
 
 from mmrazor.models.architectures import dynamic_ops
 from mmrazor.models.mutables import MutableChannelContainer
+from mmrazor.models.utils import get_module_device
 
 
 class ExpandableMixin:
@@ -65,24 +66,20 @@ def expanded_out_channel(self):
     @property
     def mutable_in_mask(self):
         """Return the mutable in mask."""
+        device = get_module_device(self)
         if self.in_mutable is not None:
-            return self.in_mutable.current_mask
+            return self.in_mutable.current_mask.to(device)
         else:
-            if hasattr(self, 'weight'):
-                return self.weight.new_ones([self.expanded_in_channel])
-            else:
-                return torch.ones([self.expanded_in_channel])
+            return torch.ones([self.expanded_in_channel]).to(device)
 
     @property
     def mutable_out_mask(self):
         """Return the mutable out mask."""
+        device = get_module_device(self)
         if self.out_mutable is not None:
-            return self.out_mutable.current_mask
+            return self.out_mutable.current_mask.to(device)
         else:
-            if hasattr(self, 'weight'):
-                return self.weight.new_ones([self.expanded_out_channel])
-            else:
-                return torch.ones([self.expanded_out_channel])
+            return torch.ones([self.expanded_out_channel]).to(device)
 
     @property
     def in_mutable(self) -> MutableChannelContainer:
@@ -152,7 +149,8 @@ def _get_expand_op_normal_conv(self, in_c, out_c, zero=False):
 
         module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
                            self.padding, self.dilation, self.groups, self.bias
-                           is not None, self.padding_mode)
+                           is not None,
+                           self.padding_mode).to(get_module_device(self))
         if zero:
             ExpandableMixin.zero_weight_(module)
 
@@ -169,7 +167,8 @@ def _get_expand_op_dw_conv(self, in_c, out_c, zero=False):
         assert in_c == out_c
         module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
                            self.padding, self.dilation, in_c, self.bias
-                           is not None, self.padding_mode)
+                           is not None,
+                           self.padding_mode).to(get_module_device(self))
         if zero:
             ExpandableMixin.zero_weight_(module)
 
@@ -194,7 +193,8 @@ def _original_out_channel(self):
         return self.out_features
 
     def get_expand_op(self, in_c, out_c, zero=False):
-        module = nn.Linear(in_c, out_c, self.bias is not None)
+        module = nn.Linear(in_c, out_c, self.bias
+                           is not None).to(get_module_device(self))
         if zero:
             ExpandableMixin.zero_weight_(module)
 
@@ -221,7 +221,8 @@ def _original_out_channel(self):
     def get_expand_op(self, in_c, out_c, zero=False):
         assert in_c == out_c
         module = nn.BatchNorm2d(in_c, self.eps, self.momentum, self.affine,
-                                self.track_running_stats)
+                                self.track_running_stats).to(
+                                    get_module_device(self))
        if zero:
             ExpandableMixin.zero_weight_(module)
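In `ops.py`, every freshly constructed replacement module (`nn.Conv2d`, `nn.Linear`, `nn.BatchNorm2d`) is now moved to the original module's device, since PyTorch constructors always allocate on CPU; masks returned by the mixin's properties are aligned the same way. A small standalone illustration of the pattern (the sizes are arbitrary):

```python
import torch
import torch.nn as nn

from mmrazor.models.utils import get_module_device

original = nn.Linear(8, 16)
if torch.cuda.is_available():
    original = original.cuda()

# Same pattern as get_expand_op(): build the expanded op, then move it to
# the source module's device before any weights are copied into it.
expanded = nn.Linear(12, 16, original.bias is not None).to(
    get_module_device(original))
assert get_module_device(expanded) == get_module_device(original)
```
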