Skip to content

Commit

Permalink
[Fix] Fix timm backbone initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 committed May 26, 2023
1 parent 3f69540 commit 394cb73
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion mmaction/models/recognizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,19 @@ def with_cls_head(self) -> bool:

def init_weights(self) -> None:
"""Initialize the model network weights."""
super().init_weights()
if self.backbone_from in ['torchvision', 'timm']:
warnings.warn('We do not initialize weights for backbones in '
f'{self.backbone_from}, since the weights for '
f'backbones in {self.backbone_from} are initialized '
'in their __init__ functions.')

def fake_init():
pass

# avoid repeated initialization
self.backbone.init_weights = fake_init
super().init_weights()

def loss(self, inputs: torch.Tensor, data_samples: SampleList,
**kwargs) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Expand Down

0 comments on commit 394cb73

Please # to comment.