Skip to content

add timm-MobileNetV3 as an Encoder #355

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 5 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The main features of this library are:

- High level API (just two lines to create a neural network)
- 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 109 available encoders
- 115 available encoders
- All encoders have pre-trained weights for faster and better convergence

### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
Expand Down Expand Up @@ -337,6 +337,22 @@ The following is a list of supported encoders in the SMP. Select the appropriate
</div>
</details>

<details>
<summary style="margin-left: 25px;">MobileNetV3</summary>
<div style="margin-left: 25px;">

|Encoder |Weights |Params, M |
|--------------------------------|:------------------------------:|:------------------------------:|
|timm-mobilenetv3_large_075 |imagenet |1.78M |
|timm-mobilenetv3_large_100 |imagenet |2.97M |
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
|timm-mobilenetv3_small_075 |imagenet |0.57M |
|timm-mobilenetv3_small_100 |imagenet |0.93M |
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |

</div>
</details>


\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

Expand Down
19 changes: 19 additions & 0 deletions docs/encoders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,22 @@ VGG
+-------------+------------+-------------+
| vgg19\_bn | imagenet | 20M |
+-------------+------------+-------------+

MobileNetV3
~~~~~~~~~~~

+-----------------------------------+------------+-------------+
| Encoder                           | Weights    | Params, M   |
+===================================+============+=============+
| timm-mobilenetv3_large_075        | imagenet   | 1.78M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_large_100        | imagenet   | 2.97M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_large_minimal_100| imagenet   | 1.41M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_075        | imagenet   | 0.57M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_100        | imagenet   | 0.93M       |
+-----------------------------------+------------+-------------+
| timm-mobilenetv3_small_minimal_100| imagenet   | 0.43M       |
+-----------------------------------+------------+-------------+
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .timm_res2net import timm_res2net_encoders
from .timm_regnet import timm_regnet_encoders
from .timm_sknet import timm_sknet_encoders
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
try:
from .timm_gernet import timm_gernet_encoders
except ImportError as e:
Expand All @@ -43,6 +44,7 @@
encoders.update(timm_res2net_encoders)
encoders.update(timm_regnet_encoders)
encoders.update(timm_sknet_encoders)
encoders.update(timm_mobilenetv3_encoders)
encoders.update(timm_gernet_encoders)


Expand Down
164 changes: 164 additions & 0 deletions segmentation_models_pytorch/encoders/timm_mobilenetv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from timm import create_model
import torch.nn as nn
from ._base import EncoderMixin


def make_divisible(x, divisible_by=8):
    """Round ``x`` up to the nearest multiple of ``divisible_by``.

    Args:
        x: value to round (int or float), e.g. a channel count that has
            already been scaled by a width multiplier.
        divisible_by: the step the result must be a multiple of.

    Returns:
        The smallest multiple of ``divisible_by`` that is >= ``x``, as an int.
    """
    # Function-scope import keeps the module's top-level imports untouched.
    import math

    # math.ceil returns an exact int, so no third-party dependency or float
    # round trip is needed for the multiplication.
    return int(math.ceil(x / divisible_by) * divisible_by)


class MobileNetV3Encoder(nn.Module, EncoderMixin):
    """Multi-scale feature encoder built from a timm MobileNetV3 backbone.

    The backbone is created with ``timm.create_model`` and its stem and
    blocks are regrouped into six stages (an identity stage followed by five
    backbone stages). ``forward`` returns the output of each stage it runs.

    Args:
        model: timm model name. It must contain ``'small'`` or ``'large'``;
            this selects both the channel layout and how blocks are grouped
            into stages.
        width_mult: multiplier applied to the base channel counts before they
            are rounded up to a multiple of 8 (e.g. ``0.75`` or ``1.0``).
        depth: number of backbone stages executed by ``forward``. Only six
            stages exist, so values above 5 make ``forward`` fail with an
            ``IndexError``.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``model`` names neither a small nor a large variant.
    """

    def __init__(self, model, width_mult, depth=5, **kwargs):
        super().__init__()
        self._depth = depth

        model_name = str(model)
        if 'small' in model_name:
            self.mode = 'small'
            base_channels = (16, 16, 24, 48, 576)
        elif 'large' in model_name:
            self.mode = 'large'
            base_channels = (16, 24, 40, 112, 960)
        else:
            # Report the offending model name instead of a placeholder so the
            # error is actionable.
            raise ValueError(
                'MobileNetV3 mode should be small or large, got {}'.format(model))

        # Scale the base widths, round each up to a multiple of 8, and
        # prepend the 3 input channels emitted by the identity stage.
        scaled_channels = tuple(
            make_divisible(channels * width_mult) for channels in base_channels
        )
        self._out_channels = (3,) + scaled_channels
        self._in_channels = 3

        # minimal models replace hardswish with relu
        model = create_model(model_name=model,
                             scriptable=True,  # torch.jit scriptable
                             exportable=True,  # onnx export
                             features_only=True)
        self.conv_stem = model.conv_stem
        self.bn1 = model.bn1
        self.act1 = model.act1
        self.blocks = model.blocks

    def get_stages(self):
        """Return the six stage modules, in execution order.

        Stage 0 is an identity; the remaining five group the stem and the
        backbone blocks, with the grouping depending on ``self.mode``.

        Raises:
            ValueError: if ``self.mode`` is neither ``'small'`` nor ``'large'``.
        """
        if self.mode == 'small':
            return [
                nn.Identity(),
                nn.Sequential(self.conv_stem, self.bn1, self.act1),
                self.blocks[0],
                self.blocks[1],
                self.blocks[2:4],
                self.blocks[4:],
            ]
        if self.mode == 'large':
            return [
                nn.Identity(),
                nn.Sequential(self.conv_stem, self.bn1, self.act1, self.blocks[0]),
                self.blocks[1],
                self.blocks[2],
                self.blocks[3:5],
                self.blocks[5:],
            ]
        # Previously the exception was constructed but never raised, so an
        # unexpected mode made this method return None.
        raise ValueError('MobileNetV3 mode should be small or large, got {}'.format(self.mode))

    def forward(self, x):
        """Run ``x`` through the first ``depth + 1`` stages.

        Returns:
            A list with the output of every executed stage; the first entry
            is the input itself (identity stage).
        """
        stages = self.get_stages()

        features = []
        for stage in stages[:self._depth + 1]:
            x = stage(x)
            features.append(x)

        # Preserve the original failure for depth values beyond the number
        # of available stages rather than silently returning fewer features.
        if len(features) != self._depth + 1:
            raise IndexError('list index out of range')

        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load backbone weights, discarding head and classifier tensors.

        The ``conv_head`` and ``classifier`` entries have no counterpart in
        this encoder, so they are removed before delegating to
        ``nn.Module.load_state_dict``. Missing entries are tolerated, which
        lets the encoder reload its own ``state_dict()``.

        Note:
            ``state_dict`` is modified in place.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        unused_keys = (
            'conv_head.weight',
            'conv_head.bias',
            'classifier.weight',
            'classifier.bias',
        )
        for key in unused_keys:
            state_dict.pop(key, None)
        return super().load_state_dict(state_dict, **kwargs)


# Release folder of the pytorch-image-models repository that hosts the
# checkpoints referenced below.
_weights_release_root = (
    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/'
)

# Maps each timm model name to its available weight sources and their URLs.
# Every checkpoint file is named "<model name>-<digest>.pth".
mobilenetv3_weights = {
    name: {'imagenet': _weights_release_root + name + '-' + digest + '.pth'}
    for name, digest in (
        ('tf_mobilenetv3_large_075', '150ee8b0'),
        ('tf_mobilenetv3_large_100', '427764d5'),
        ('tf_mobilenetv3_large_minimal_100', '8596ae28'),
        ('tf_mobilenetv3_small_075', 'da427f52'),
        ('tf_mobilenetv3_small_100', '37f49e2b'),
        ('tf_mobilenetv3_small_minimal_100', '922a7843'),
    )
}

# Per-model, per-weight-source settings: the checkpoint URL plus the input
# range, channel-wise mean/std and colour space stored alongside it.
# Built with a comprehension so no loop variables are left behind in the
# module namespace.
pretrained_settings = {
    model_name: {
        source_name: {
            "url": source_url,
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'input_space': 'RGB',
        }
        for source_name, source_url in sources.items()
    }
    for model_name, sources in mobilenetv3_weights.items()
}


# Encoder registry entries. Each public name "timm-<variant>" points at the
# timm model "tf_<variant>", its pretrained settings, and the constructor
# parameters handed to MobileNetV3Encoder.
timm_mobilenetv3_encoders = {
    'timm-' + variant: {
        'encoder': MobileNetV3Encoder,
        'pretrained_settings': pretrained_settings['tf_' + variant],
        'params': {
            'model': 'tf_' + variant,
            'width_mult': multiplier,
        },
    }
    for variant, multiplier in (
        ('mobilenetv3_large_075', 0.75),
        ('mobilenetv3_large_100', 1.0),
        ('mobilenetv3_large_minimal_100', 1.0),
        ('mobilenetv3_small_075', 0.75),
        ('mobilenetv3_small_100', 1.0),
        ('mobilenetv3_small_minimal_100', 1.0),
    )
}