Skip to content

Commit 36da448

Browse files
author
liukai
committed
add args min_channel
1 parent 062ffe7 commit 36da448

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

projects/group_fisher/modules/group_fisher_channel_mutator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ def __init__(self,
3737
demo_input=(1, 3, 224, 224),
3838
tracer_type='FxTracer'),
3939
min_ratio=0.0,
40+
min_channel=0,
4041
**kwargs) -> None:
4142
super().__init__(channel_unit_cfg, parse_cfg, **kwargs)
4243
self.mutable_units: List[GroupFisherChannelUnit]
4344
self.min_ratio = min_ratio
45+
self.min_channel = min_channel
4446

4547
def start_record_info(self) -> None:
4648
"""Start recording the related information."""
@@ -64,7 +66,7 @@ def try_prune(self) -> None:
6466
min_unit = self.mutable_units[0]
6567
for unit in self.mutable_units:
6668
if unit.mutable_channel.activated_channels > max(
67-
20, (unit.num_channels * self.min_ratio)):
69+
self.min_channel, (unit.num_channels * self.min_ratio), 0):
6870
imp = unit.importance()
6971
if imp.isnan().any():
7072
if dist.get_rank() == 0:

0 commit comments

Comments
 (0)