Fix bug when deploying a pruned model to CUDA (#495)
Co-authored-by: liukai <your_email@abc.example>
LKJacky and liukai authored Apr 10, 2023
1 parent 90c5435 commit 6c06849
Showing 4 changed files with 40 additions and 29 deletions.
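All four files address the same failure mode: when the supernet is already on a GPU, the pruning structures (converted dynamic ops, channel containers, unit masks, expanded layers) were still built on the CPU, and the first forward of the pruned model raised a device-mismatch error. A minimal sketch of that mismatch in plain PyTorch (illustrative only, not mmrazor code, and assuming a CUDA device is available):

```python
# Illustrative sketch of the device mismatch fixed by this commit (not mmrazor code).
# A mask created on the CPU cannot be combined with weights that live on the GPU.
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 3).cuda()     # model already deployed to CUDA
mask = torch.ones(16)                 # pruning mask built on the CPU (old behaviour)

try:
    masked = conv.weight * mask.view(-1, 1, 1, 1)
except RuntimeError as e:
    print('mismatch:', e)             # "Expected all tensors to be on the same device ..."

# The remedy applied throughout the commit: move the new tensor to the module's device.
masked = conv.weight * mask.to(conv.weight.device).view(-1, 1, 1, 1)
print(masked.shape)                   # torch.Size([16, 8, 3, 3])
```

The diffs below apply that remedy at each place a tensor or module is created: it is moved to the device of the module it belongs to.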
```diff
@@ -11,6 +11,7 @@
 from mmrazor.models.mutables import DerivedMutable
 from mmrazor.models.mutables.mutable_channel import (BaseMutableChannel,
                                                      MutableChannelContainer)
+from mmrazor.models.utils import get_module_device
 from .channel_unit import Channel, ChannelUnit
@@ -227,7 +228,7 @@ def get_module(model, name):
             module = get_module(model, channel.name)
             if type(module) in dynamicop_map:
                 new_module = dynamicop_map[type(module)].convert_from(
-                    module)
+                    module).to(get_module_device(module))
                 replace_op(model, channel.name, new_module)
                 channel.module = new_module
             else:
@@ -237,19 +238,22 @@ def get_module(model, name):
     def _register_channel_container(
             model: nn.Module, container_class: Type[MutableChannelContainer]):
         """register channel container for dynamic ops."""
+        device = get_module_device(model)
         for module in model.modules():
             if isinstance(module, DynamicChannelMixin):
                 in_channels = getattr(module,
                                       module.attr_mappings['in_channels'], 0)
                 if module.get_mutable_attr('in_channels') is None:
-                    module.register_mutable_attr('in_channels',
-                                                 container_class(in_channels))
+                    module.register_mutable_attr(
+                        'in_channels',
+                        container_class(in_channels).to(device))
                 out_channels = getattr(module,
                                        module.attr_mappings['out_channels'], 0)
                 if module.get_mutable_attr('out_channels') is None:
 
-                    module.register_mutable_attr('out_channels',
-                                                 container_class(out_channels))
+                    module.register_mutable_attr(
+                        'out_channels',
+                        container_class(out_channels).to(device))
 
     def _register_mutable_channel(self, mutable_channel: BaseMutableChannel):
         # register mutable_channel
```
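The hunks above route every converted dynamic op and channel container through get_module_device, imported from mmrazor.models.utils. Its implementation is not part of this diff; a rough stand-in with the behaviour the call sites rely on (an assumption, not the library source) could look like this:

```python
# Rough stand-in for mmrazor.models.utils.get_module_device (assumption, not the
# library source): take the device of the first parameter or buffer, defaulting
# to CPU for parameter-less modules.
import torch
import torch.nn as nn

def get_module_device(module: nn.Module) -> torch.device:
    for param in module.parameters():
        return param.device
    for buffer in module.buffers():
        return buffer.device
    return torch.device('cpu')

# Usage mirroring the hunk above: the replacement op is created on the same
# device as the module it replaces.
old = nn.Conv2d(3, 8, 3)
new = nn.Conv2d(3, 8, 3).to(get_module_device(old))
print(get_module_device(new))  # cpu here; cuda:0 if `old` had been on a GPU
```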
mmrazor/models/mutators/channel_mutator/channel_mutator.py (7 additions, 3 deletions)

```diff
@@ -99,6 +99,9 @@ def prepare_from_supernet(self, supernet: Module) -> None:
         1. parse the model and get MutableChannelUnits.
         2. call unit.prepare_for_pruning for each unit.
         """
+        from mmrazor.models.utils import get_module_device
+        device = get_module_device(supernet)
+
         self._name2module = dict(supernet.named_modules())
 
         if isinstance(self.parse_cfg,
@@ -115,10 +118,11 @@ def prepare_from_supernet(self, supernet: Module) -> None:
             units = self._prepare_from_predefined_model(supernet)
         else:
             raise NotImplementedError()
+        for i in range(len(units)):
+            units[i] = units[i].to(device)
+            units[i].prepare_for_pruning(supernet)
+            self._name2unit[units[i].name] = units[i]
 
-        for unit in units:
-            unit.prepare_for_pruning(supernet)
-            self._name2unit[unit.name] = unit
         self.units = ModuleList(units)
 
     @property
```
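With this change, prepare_from_supernet moves each unit onto the supernet's device before unit.prepare_for_pruning runs, so the masks a unit registers live on the same device as the weights they gate. A generic sketch of that pattern in plain PyTorch (ToyUnit and prepare_units are illustrative names, not mmrazor classes):

```python
# Generic sketch of the pattern used above (illustrative, not mmrazor's classes):
# units that hold mask buffers are moved to the supernet's device before they
# attach anything to it.
import torch
import torch.nn as nn

class ToyUnit(nn.Module):
    """Stand-in for a channel unit: it just owns a per-channel mask buffer."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.register_buffer('mask', torch.ones(num_channels))

def prepare_units(supernet: nn.Module, units: list) -> list:
    device = next(supernet.parameters()).device
    return [unit.to(device) for unit in units]  # mirrors `units[i] = units[i].to(device)`

supernet = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
units = prepare_units(supernet, [ToyUnit(8)])
print(units[0].mask.device)  # matches the supernet's device (cpu here, cuda if deployed)
```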
mmrazor/models/task_modules/tracer/channel_analyzer.py (9 additions, 7 deletions)

```diff
@@ -32,6 +32,7 @@
 from mmrazor.structures.graph.module_graph import (FxTracerToGraphConverter,
                                                    PathToGraphConverter)
 from mmrazor.structures.graph.pseudo_fx_graph import parse_torch_graph
+from mmrazor.utils import print_log
 from ..demo_inputs import BaseDemoInput, DefaultDemoInput
 from .backward_tracer import BackwardTracer
 from .fx_tracer import MMFxTracer
@@ -136,9 +137,9 @@ def _fx_trace(self, model):
         else:
             return self.tracer.trace(model)
 
-    def _find_mutable_units(self, model, units_config: Dict):
+    def _find_mutable_units(self, model: nn.Module, units_config: Dict):
         """Test the tracer result and filter unforwardable units."""
-        model = copy.deepcopy(model)
+        model = copy.deepcopy(model).cpu()
         units: List[SequentialMutableChannelUnit] = [
             SequentialMutableChannelUnit.init_from_cfg(model, cfg)
             for cfg in units_config.values()
@@ -156,16 +157,17 @@ def _find_mutable_units(self, model, units_config: Dict):
                 inputs['mode'] = mode
                 template_output = model(**inputs)
                 break
-            except Exception:
-                pass
+            except Exception as e:
+                print_log(f'Forward failed in {mode} mode as {e}')
         else:
             try:
                 template_output = model(inputs)
-            except Exception:
-                pass
+            except Exception as e:
+                print_log(f'Forward failed in as {e}')
         if template_output is None:
             raise Exception(
-                'Forward failed, there may be an error in demo input.')
+                'Forward failed, there may be an error in demo input.',
+                f'{inputs}')
         mutable_units = find_mutable(model, mutable_units, units, inputs,
                                      template_output)
         mutable_unit_config = {}
```
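Two tracer-side adjustments: _find_mutable_units now forwards the demo input on a CPU deepcopy of the model, so tracing no longer depends on where the model is deployed, and a failed forward in any mode is logged via print_log instead of being silently ignored. A hedged sketch of that control flow (plain Python logging; trial_forward and the mode list are illustrative, not mmrazor's API):

```python
# Hedged sketch of the tracer-side change (illustrative, not mmrazor's API):
# forward the demo input on a CPU copy and log each failed forward mode.
import copy
import logging
import torch
import torch.nn as nn

def trial_forward(model: nn.Module, demo_input: torch.Tensor):
    model = copy.deepcopy(model).cpu()          # trace on CPU regardless of deployment
    template_output = None
    for mode in ('tensor', 'predict', 'loss'):  # modes tried for mmengine-style models
        try:
            # A real mmengine BaseModel forward would also receive mode=mode.
            template_output = model(demo_input)
            break
        except Exception as e:                  # noqa: BLE001
            logging.warning('Forward failed in %s mode as %s', mode, e)
    if template_output is None:
        raise RuntimeError(
            f'Forward failed, there may be an error in demo input: {tuple(demo_input.shape)}')
    return template_output

out = trial_forward(nn.Conv2d(3, 8, 3), torch.rand(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 8, 30, 30])
```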
mmrazor/models/utils/expandable_utils/ops.py (15 additions, 14 deletions)

```diff
@@ -4,6 +4,7 @@
 
 from mmrazor.models.architectures import dynamic_ops
 from mmrazor.models.mutables import MutableChannelContainer
+from mmrazor.models.utils import get_module_device
 
 
 class ExpandableMixin:
@@ -65,24 +66,20 @@ def expanded_out_channel(self):
     @property
     def mutable_in_mask(self):
         """Return the mutable in mask."""
+        device = get_module_device(self)
         if self.in_mutable is not None:
-            return self.in_mutable.current_mask
+            return self.in_mutable.current_mask.to(device)
         else:
-            if hasattr(self, 'weight'):
-                return self.weight.new_ones([self.expanded_in_channel])
-            else:
-                return torch.ones([self.expanded_in_channel])
+            return torch.ones([self.expanded_in_channel]).to(device)
 
     @property
     def mutable_out_mask(self):
         """Return the mutable out mask."""
+        device = get_module_device(self)
         if self.out_mutable is not None:
-            return self.out_mutable.current_mask
+            return self.out_mutable.current_mask.to(device)
         else:
-            if hasattr(self, 'weight'):
-                return self.weight.new_ones([self.expanded_out_channel])
-            else:
-                return torch.ones([self.expanded_out_channel])
+            return torch.ones([self.expanded_out_channel]).to(device)
 
     @property
     def in_mutable(self) -> MutableChannelContainer:
@@ -152,7 +149,8 @@ def _get_expand_op_normal_conv(self, in_c, out_c, zero=False):
 
         module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
                            self.padding, self.dilation, self.groups, self.bias
-                           is not None, self.padding_mode)
+                           is not None,
+                           self.padding_mode).to(get_module_device(self))
         if zero:
             ExpandableMixin.zero_weight_(module)
 
@@ -169,7 +167,8 @@ def _get_expand_op_dw_conv(self, in_c, out_c, zero=False):
         assert in_c == out_c
         module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
                            self.padding, self.dilation, in_c, self.bias
-                           is not None, self.padding_mode)
+                           is not None,
+                           self.padding_mode).to(get_module_device(self))
         if zero:
             ExpandableMixin.zero_weight_(module)
 
@@ -194,7 +193,8 @@ def _original_out_channel(self):
         return self.out_features
 
     def get_expand_op(self, in_c, out_c, zero=False):
-        module = nn.Linear(in_c, out_c, self.bias is not None)
+        module = nn.Linear(in_c, out_c, self.bias
+                           is not None).to(get_module_device(self))
        if zero:
             ExpandableMixin.zero_weight_(module)
 
@@ -221,7 +221,8 @@ def _original_out_channel(self):
     def get_expand_op(self, in_c, out_c, zero=False):
         assert in_c == out_c
         module = nn.BatchNorm2d(in_c, self.eps, self.momentum, self.affine,
-                                self.track_running_stats)
+                                self.track_running_stats).to(
+                                    get_module_device(self))
         if zero:
             ExpandableMixin.zero_weight_(module)
 
```
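Every expand op (normal conv, depth-wise conv, linear, batch norm) and every mask returned by ExpandableMixin is now created directly on the device of the layer it replaces or describes. A minimal sketch of the same idea for a linear layer (expand_linear is an illustrative helper, not the ExpandableMixin implementation):

```python
# Minimal sketch of building a wider replacement layer on the original layer's
# device (illustrative; not the ExpandableMixin implementation).
import torch
import torch.nn as nn

def expand_linear(old: nn.Linear, new_out: int) -> nn.Linear:
    device = next(old.parameters()).device
    new = nn.Linear(old.in_features, new_out, old.bias is not None).to(device)
    with torch.no_grad():
        new.weight[:old.out_features].copy_(old.weight)   # keep the original rows
        if old.bias is not None:
            new.bias[:old.out_features].copy_(old.bias)
    return new

wide = expand_linear(nn.Linear(16, 32), new_out=64)
print(wide.weight.shape, wide.weight.device)  # torch.Size([64, 16]) cpu
```

Constructing the replacement with `.to(device)` up front keeps a partially rebuilt model forwardable without a later blanket `model.to(device)` pass.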
