1
1
import copy
2
+ import os .path as osp
2
3
4
+ import mmcv
3
5
import numpy as np
4
6
import pytest
5
7
import torch
6
- import torchvision .models as models
7
8
9
+ from mmaction .models import build_recognizer
8
10
from mmaction .utils import register_module_hooks
9
11
from mmaction .utils .module_hooks import GPUNormalize
10
12
@@ -13,40 +15,56 @@ def test_register_module_hooks():
13
15
_module_hooks = [
14
16
dict (
15
17
type = 'GPUNormalize' ,
18
+ hooked_module = 'backbone' ,
16
19
hook_pos = 'forward_pre' ,
17
20
input_format = 'NCHW' ,
18
21
mean = [123.675 , 116.28 , 103.53 ],
19
22
std = [58.395 , 57.12 , 57.375 ])
20
23
]
21
24
25
+ repo_dpath = osp .dirname (osp .dirname (osp .dirname (__file__ )))
26
+ config_fpath = osp .join (repo_dpath , 'configs/_base_/models/tsm_r50.py' )
27
+ config = mmcv .Config .fromfile (config_fpath )
28
+ config .model ['backbone' ]['pretrained' ] = None
29
+
22
30
# case 1
23
31
module_hooks = copy .deepcopy (_module_hooks )
24
32
module_hooks [0 ]['hook_pos' ] = 'forward_pre'
25
- resnet = models . resnet50 ( )
26
- handles = register_module_hooks (resnet , module_hooks )
27
- assert resnet ._forward_pre_hooks [
33
+ recognizer = build_recognizer ( config . model )
34
+ handles = register_module_hooks (recognizer , module_hooks )
35
+ assert recognizer . backbone ._forward_pre_hooks [
28
36
handles [0 ].id ].__name__ == 'normalize_hook'
29
37
30
38
# case 2
31
39
module_hooks = copy .deepcopy (_module_hooks )
32
40
module_hooks [0 ]['hook_pos' ] = 'forward'
33
- resnet = models .resnet50 ()
34
- handles = register_module_hooks (resnet , module_hooks )
35
- assert resnet ._forward_hooks [handles [0 ].id ].__name__ == 'normalize_hook'
41
+ recognizer = build_recognizer (config .model )
42
+ handles = register_module_hooks (recognizer , module_hooks )
43
+ assert recognizer .backbone ._forward_hooks [
44
+ handles [0 ].id ].__name__ == 'normalize_hook'
36
45
37
46
# case 3
38
47
module_hooks = copy .deepcopy (_module_hooks )
48
+ module_hooks [0 ]['hooked_module' ] = 'cls_head'
39
49
module_hooks [0 ]['hook_pos' ] = 'backward'
40
- resnet = models .resnet50 ()
41
- handles = register_module_hooks (resnet , module_hooks )
42
- assert resnet ._backward_hooks [handles [0 ].id ].__name__ == 'normalize_hook'
50
+ recognizer = build_recognizer (config .model )
51
+ handles = register_module_hooks (recognizer , module_hooks )
52
+ assert recognizer .cls_head ._backward_hooks [
53
+ handles [0 ].id ].__name__ == 'normalize_hook'
43
54
44
55
# case 4
45
56
module_hooks = copy .deepcopy (_module_hooks )
46
57
module_hooks [0 ]['hook_pos' ] = '_other_pos'
47
- resnet = models .resnet50 ()
58
+ recognizer = build_recognizer (config .model )
59
+ with pytest .raises (ValueError ):
60
+ handles = register_module_hooks (recognizer , module_hooks )
61
+
62
+ # case 5
63
+ module_hooks = copy .deepcopy (_module_hooks )
64
+ module_hooks [0 ]['hooked_module' ] = '_other_module'
65
+ recognizer = build_recognizer (config .model )
48
66
with pytest .raises (ValueError ):
49
- handles = register_module_hooks (resnet , module_hooks )
67
+ handles = register_module_hooks (recognizer , module_hooks )
50
68
51
69
52
70
def test_gpu_normalize ():
@@ -72,9 +90,8 @@ def check_normalize(origin_imgs, result_imgs, norm_cfg):
72
90
assert gpu_normalize ._mean .shape == (1 , 3 , 1 , 1 )
73
91
imgs = np .random .randint (256 , size = (2 , 240 , 320 , 3 ), dtype = np .uint8 )
74
92
_input = (torch .tensor (imgs ).permute (0 , 3 , 1 , 2 ), )
75
- resnet = models .resnet50 ()
76
93
normalize_hook = gpu_normalize .hook_func ()
77
- _input = normalize_hook (resnet , _input )
94
+ _input = normalize_hook (torch . nn . Module , _input )
78
95
result_imgs = np .array (_input [0 ].permute (0 , 2 , 3 , 1 ))
79
96
check_normalize (imgs , result_imgs , gpu_normalize_cfg )
80
97
0 commit comments