@@ -18,13 +18,23 @@ def get_encoders():
18
18
"resnext101_32x32d" ,
19
19
"resnext101_32x48d" ,
20
20
]
21
- encoders = smp .encoders .get_encoder_names ( )
21
+ encoders = list ( smp .encoders .encoders . keys () )
22
22
if IS_TRAVIS :
23
23
encoders = [e for e in encoders if e not in travis_exclude_encoders ]
24
24
return encoders
25
25
26
+ def get_timm_u_encoders ():
27
+ timm_exclude_encoders = [
28
+ 'vit_*' , 'tnt_*' , 'pit_*' ,
29
+ '*iabn*' , 'tresnet*' , # models using inplace abn
30
+ 'dla*' , 'hrnet*' , # hopefully fix at some point
31
+ ]
32
+
33
+ return smp .encoders .timm_universal_encoders (exclude_filters = timm_exclude_encoders )
34
+
26
35
27
36
ENCODERS = get_encoders ()
37
+ ENCODERS_TIMM_U = get_timm_u_encoders ()
28
38
DEFAULT_ENCODER = "resnet18"
29
39
30
40
@@ -54,24 +64,26 @@ def _test_forward_backward(model, sample, test_shape=False):
54
64
assert out .shape [2 :] == sample .shape [2 :]
55
65
56
66
57
- @pytest .mark .parametrize ("encoder_name" , ENCODERS )
67
+ @pytest .mark .parametrize ("encoder_name" , ENCODERS + ENCODERS_TIMM_U )
58
68
@pytest .mark .parametrize ("encoder_depth" , [3 , 5 ])
59
69
@pytest .mark .parametrize ("model_class" , [smp .FPN , smp .PSPNet , smp .Linknet , smp .Unet , smp .UnetPlusPlus ])
60
70
def test_forward (model_class , encoder_name , encoder_depth , ** kwargs ):
61
- if model_class is smp .Unet or model_class is smp .UnetPlusPlus or model_class is smp .MAnet :
62
- kwargs ["decoder_channels" ] = (16 , 16 , 16 , 16 , 16 )[- encoder_depth :]
63
- model = model_class (
64
- encoder_name , encoder_depth = encoder_depth , encoder_weights = None , ** kwargs
65
- )
66
- sample = get_sample (model_class )
67
- model .eval ()
68
- if encoder_depth == 5 and model_class != smp .PSPNet :
69
- test_shape = True
70
- else :
71
- test_shape = False
72
-
73
- _test_forward (model , sample , test_shape )
74
-
71
+ try :
72
+ if model_class is smp .Unet or model_class is smp .UnetPlusPlus or model_class is smp .MAnet :
73
+ kwargs ["decoder_channels" ] = (16 , 16 , 16 , 16 , 16 )[- encoder_depth :]
74
+ model = model_class (
75
+ encoder_name , encoder_depth = encoder_depth , encoder_weights = None , ** kwargs
76
+ )
77
+ sample = get_sample (model_class )
78
+ model .eval ()
79
+ if encoder_depth == 5 and model_class != smp .PSPNet :
80
+ test_shape = True
81
+ else :
82
+ test_shape = False
83
+ _test_forward (model , sample , test_shape )
84
+ except Exception as e :
85
+ print ('\n \r {}-{}: Exception {}' .format (model_class .__name__ , encoder_name , e ))
86
+ assert False , 'Exception {}' .format (e )
75
87
76
88
@pytest .mark .parametrize (
77
89
"model_class" ,
0 commit comments