Skip to content

Commit ecf0208

Browse files
authored
fix module hook bug (#667)
* fix bugs * modify module register * modify var name * modify unittest
1 parent 4711717 commit ecf0208

File tree

6 files changed

+46
-20
lines changed

6 files changed

+46
-20
lines changed

configs/recognition/tsm/tsm_r50_gpu_normalize_1x1x8_50e_kinetics400_rgb.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
module_hooks = [
77
dict(
88
type='GPUNormalize',
9+
hooked_module='backbone',
910
hook_pos='forward_pre',
1011
input_format='NCHW',
1112
mean=[123.675, 116.28, 103.53],

mmaction/utils/module_hooks.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@
77
def register_module_hooks(Module, module_hooks_list):
88
handles = []
99
for module_hook_cfg in module_hooks_list:
10+
hooked_module_name = module_hook_cfg.pop('hooked_module', 'backbone')
11+
if not hasattr(Module, hooked_module_name):
12+
raise ValueError(
13+
f'{Module.__class__} has no {hooked_module_name}!')
14+
hooked_module = getattr(Module, hooked_module_name)
1015
hook_pos = module_hook_cfg.pop('hook_pos', 'forward_pre')
16+
1117
if hook_pos == 'forward_pre':
12-
handle = Module.register_forward_pre_hook(
18+
handle = hooked_module.register_forward_pre_hook(
1319
build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
1420
elif hook_pos == 'forward':
15-
handle = Module.register_forward_hook(
21+
handle = hooked_module.register_forward_hook(
1622
build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
1723
elif hook_pos == 'backward':
18-
handle = Module.register_backward_hook(
24+
handle = hooked_module.register_backward_hook(
1925
build_from_cfg(module_hook_cfg, MODULE_HOOKS).hook_func())
2026
else:
2127
raise ValueError(

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ line_length = 79
1919
multi_line_output = 0
2020
known_standard_library = pkg_resources,setuptools
2121
known_first_party = mmaction
22-
known_third_party = cv2,joblib,matplotlib,mmcv,numpy,pandas,pytest,scipy,seaborn,titlecase,torch,torchvision,tqdm
22+
known_third_party = cv2,joblib,matplotlib,mmcv,numpy,pandas,pytest,scipy,seaborn,titlecase,torch,tqdm
2323
no_lines_before = STDLIB,LOCALFOLDER
2424
default_section = THIRDPARTY

tests/test_utils/test_module_hooks.py

+31-14
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import copy
2+
import os.path as osp
23

4+
import mmcv
35
import numpy as np
46
import pytest
57
import torch
6-
import torchvision.models as models
78

9+
from mmaction.models import build_recognizer
810
from mmaction.utils import register_module_hooks
911
from mmaction.utils.module_hooks import GPUNormalize
1012

@@ -13,40 +15,56 @@ def test_register_module_hooks():
1315
_module_hooks = [
1416
dict(
1517
type='GPUNormalize',
18+
hooked_module='backbone',
1619
hook_pos='forward_pre',
1720
input_format='NCHW',
1821
mean=[123.675, 116.28, 103.53],
1922
std=[58.395, 57.12, 57.375])
2023
]
2124

25+
repo_dpath = osp.dirname(osp.dirname(osp.dirname(__file__)))
26+
config_fpath = osp.join(repo_dpath, 'configs/_base_/models/tsm_r50.py')
27+
config = mmcv.Config.fromfile(config_fpath)
28+
config.model['backbone']['pretrained'] = None
29+
2230
# case 1
2331
module_hooks = copy.deepcopy(_module_hooks)
2432
module_hooks[0]['hook_pos'] = 'forward_pre'
25-
resnet = models.resnet50()
26-
handles = register_module_hooks(resnet, module_hooks)
27-
assert resnet._forward_pre_hooks[
33+
recognizer = build_recognizer(config.model)
34+
handles = register_module_hooks(recognizer, module_hooks)
35+
assert recognizer.backbone._forward_pre_hooks[
2836
handles[0].id].__name__ == 'normalize_hook'
2937

3038
# case 2
3139
module_hooks = copy.deepcopy(_module_hooks)
3240
module_hooks[0]['hook_pos'] = 'forward'
33-
resnet = models.resnet50()
34-
handles = register_module_hooks(resnet, module_hooks)
35-
assert resnet._forward_hooks[handles[0].id].__name__ == 'normalize_hook'
41+
recognizer = build_recognizer(config.model)
42+
handles = register_module_hooks(recognizer, module_hooks)
43+
assert recognizer.backbone._forward_hooks[
44+
handles[0].id].__name__ == 'normalize_hook'
3645

3746
# case 3
3847
module_hooks = copy.deepcopy(_module_hooks)
48+
module_hooks[0]['hooked_module'] = 'cls_head'
3949
module_hooks[0]['hook_pos'] = 'backward'
40-
resnet = models.resnet50()
41-
handles = register_module_hooks(resnet, module_hooks)
42-
assert resnet._backward_hooks[handles[0].id].__name__ == 'normalize_hook'
50+
recognizer = build_recognizer(config.model)
51+
handles = register_module_hooks(recognizer, module_hooks)
52+
assert recognizer.cls_head._backward_hooks[
53+
handles[0].id].__name__ == 'normalize_hook'
4354

4455
# case 4
4556
module_hooks = copy.deepcopy(_module_hooks)
4657
module_hooks[0]['hook_pos'] = '_other_pos'
47-
resnet = models.resnet50()
58+
recognizer = build_recognizer(config.model)
59+
with pytest.raises(ValueError):
60+
handles = register_module_hooks(recognizer, module_hooks)
61+
62+
# case 5
63+
module_hooks = copy.deepcopy(_module_hooks)
64+
module_hooks[0]['hooked_module'] = '_other_module'
65+
recognizer = build_recognizer(config.model)
4866
with pytest.raises(ValueError):
49-
handles = register_module_hooks(resnet, module_hooks)
67+
handles = register_module_hooks(recognizer, module_hooks)
5068

5169

5270
def test_gpu_normalize():
@@ -72,9 +90,8 @@ def check_normalize(origin_imgs, result_imgs, norm_cfg):
7290
assert gpu_normalize._mean.shape == (1, 3, 1, 1)
7391
imgs = np.random.randint(256, size=(2, 240, 320, 3), dtype=np.uint8)
7492
_input = (torch.tensor(imgs).permute(0, 3, 1, 2), )
75-
resnet = models.resnet50()
7693
normalize_hook = gpu_normalize.hook_func()
77-
_input = normalize_hook(resnet, _input)
94+
_input = normalize_hook(torch.nn.Module, _input)
7895
result_imgs = np.array(_input[0].permute(0, 2, 3, 1))
7996
check_normalize(imgs, result_imgs, gpu_normalize_cfg)
8097

tools/test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def main():
178178
model = build_model(
179179
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
180180

181-
register_module_hooks(model.backbone, cfg.module_hooks)
181+
if len(cfg.module_hooks) > 0:
182+
register_module_hooks(model, cfg.module_hooks)
182183

183184
fp16_cfg = cfg.get('fp16', None)
184185
if fp16_cfg is not None:

tools/train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ def main():
145145
train_cfg=cfg.get('train_cfg'),
146146
test_cfg=cfg.get('test_cfg'))
147147

148-
register_module_hooks(model.backbone, cfg.module_hooks)
148+
if len(cfg.module_hooks) > 0:
149+
register_module_hooks(model, cfg.module_hooks)
149150

150151
if cfg.omnisource:
151152
# If omnisource flag is set, cfg.data.train should be a list

0 commit comments

Comments
 (0)