Skip to content

Commit cf0fcb4

Browse files
committed
IMP: filter pretrained models, cleanup
1 parent b78f91c commit cf0fcb4

File tree

3 files changed

+47
-30
lines changed

3 files changed

+47
-30
lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@
3939

4040
def get_encoder(name, in_channels=3, depth=5, weights=None):
4141

42-
if name in timm_universal_encoders:
43-
encoder = TimmUniversalEncoder(encoder_name=name, in_channels=in_channels, depth=depth, pretrained=weights is not None)
42+
if name in timm_universal_encoders():
43+
pretrained = weights is not None
44+
if pretrained and name not in timm_universal_encoders(pretrained=True):
45+
raise KeyError("No pretrained weights for encoder `{}`.".format(name))
46+
encoder = TimmUniversalEncoder(encoder_name=name, in_channels=in_channels, depth=depth, pretrained=pretrained)
4447
global timm_setting
4548
timm_setting = encoder.formatted_settings
4649
return encoder
@@ -69,11 +72,11 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
6972

7073

7174
def get_encoder_names():
72-
return list(encoders.keys()) + timm_universal_encoders
75+
return list(encoders.keys()) + timm_universal_encoders()
7376

7477

7578
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
76-
if encoder_name in timm_universal_encoders:
79+
if encoder_name in timm_universal_encoders():
7780
return timm_setting
7881

7982
settings = encoders[encoder_name]["pretrained_settings"]
Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from ._base import EncoderMixin
2-
from timm import create_model, list_models
31
import torch.nn as nn
2+
from timm import create_model, list_models
3+
4+
from ._base import EncoderMixin
45

56

67
class TimmUniversalEncoder(nn.Module, EncoderMixin):
@@ -10,26 +11,27 @@ def __init__(self, encoder_name, in_channels, depth=5, pretrained=True, **kwargs
1011
self._in_channels = in_channels
1112

1213
self.encoder = create_model(model_name=encoder_name[len(timm_universal_prefix):],
13-
in_chans=in_channels,
14-
exportable=True, # onnx export
15-
features_only=True,
16-
pretrained=pretrained,
17-
out_indices=tuple(range(depth))) # FIXME need to handle a few special cases for specific models
14+
in_chans=in_channels,
15+
exportable=True, # onnx export
16+
features_only=True,
17+
pretrained=pretrained,
18+
out_indices=tuple(range(depth)) # FIXME need to handle a few special cases for specific models
19+
)
1820

1921
channels = self.encoder.feature_info.channels()
20-
self._out_channels = (in_channels,) + tuple(channels)
22+
self._out_channels = (in_channels,) + tuple(channels)
2123

2224
self.formatted_settings = {}
2325
self.formatted_settings["input_space"] = "RGB"
2426
self.formatted_settings["input_range"] = (0, 1)
2527
self.formatted_settings["mean"] = self.encoder.default_cfg['mean']
2628
self.formatted_settings["std"] = self.encoder.default_cfg['std']
2729

28-
2930
def forward(self, x):
3031
features = self.encoder(x)
3132
return [x] + features
3233

3334

3435
timm_universal_prefix = 'timm-u-'
35-
timm_universal_encoders = [f'{timm_universal_prefix}{i}' for i in list_models()]
36+
def timm_universal_encoders(**kwargs):
37+
return [f'{timm_universal_prefix}{i}' for i in list_models(**kwargs)]

tests/test_models.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,23 @@ def get_encoders():
1818
"resnext101_32x32d",
1919
"resnext101_32x48d",
2020
]
21-
encoders = smp.encoders.get_encoder_names()
21+
encoders = list(smp.encoders.encoders.keys())
2222
if IS_TRAVIS:
2323
encoders = [e for e in encoders if e not in travis_exclude_encoders]
2424
return encoders
2525

26+
def get_timm_u_encoders():
27+
timm_exclude_encoders = [
28+
'vit_*', 'tnt_*', 'pit_*',
29+
'*iabn*', 'tresnet*', # models using inplace abn
30+
'dla*', 'hrnet*', # hopefully fix at some point
31+
]
32+
33+
return smp.encoders.timm_universal_encoders(exclude_filters=timm_exclude_encoders)
34+
2635

2736
ENCODERS = get_encoders()
37+
ENCODERS_TIMM_U = get_timm_u_encoders()
2838
DEFAULT_ENCODER = "resnet18"
2939

3040

@@ -54,24 +64,26 @@ def _test_forward_backward(model, sample, test_shape=False):
5464
assert out.shape[2:] == sample.shape[2:]
5565

5666

57-
@pytest.mark.parametrize("encoder_name", ENCODERS)
67+
@pytest.mark.parametrize("encoder_name", ENCODERS + ENCODERS_TIMM_U)
5868
@pytest.mark.parametrize("encoder_depth", [3, 5])
5969
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
6070
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
61-
if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
62-
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
63-
model = model_class(
64-
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
65-
)
66-
sample = get_sample(model_class)
67-
model.eval()
68-
if encoder_depth == 5 and model_class != smp.PSPNet:
69-
test_shape = True
70-
else:
71-
test_shape = False
72-
73-
_test_forward(model, sample, test_shape)
74-
71+
try:
72+
if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
73+
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
74+
model = model_class(
75+
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
76+
)
77+
sample = get_sample(model_class)
78+
model.eval()
79+
if encoder_depth == 5 and model_class != smp.PSPNet:
80+
test_shape = True
81+
else:
82+
test_shape = False
83+
_test_forward(model, sample, test_shape)
84+
except Exception as e:
85+
print('\n\r{}-{}: Exception {}'.format(model_class.__name__, encoder_name, e))
86+
assert False, 'Exception {}'.format(e)
7587

7688
@pytest.mark.parametrize(
7789
"model_class",

0 commit comments

Comments
 (0)