diff --git a/configs/_base_/models/tsm_mobilenet_v2.py b/configs/_base_/models/tsm_mobilenet_v2.py
new file mode 100644
index 0000000000..0338e19e27
--- /dev/null
+++ b/configs/_base_/models/tsm_mobilenet_v2.py
@@ -0,0 +1,22 @@
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='MobileNetV2TSM',
+        shift_div=8,
+        num_segments=8,
+        is_shift=True,
+        pretrained='mmcls://mobilenet_v2'),
+    cls_head=dict(
+        type='TSMHead',
+        num_segments=8,
+        num_classes=400,
+        in_channels=1280,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.5,
+        init_std=0.001,
+        is_shift=True))
+# model training and testing settings
+train_cfg = None
+test_cfg = dict(average_clips='prob')
diff --git a/configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py
new file mode 100644
index 0000000000..63ed3f275a
--- /dev/null
+++ b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py
@@ -0,0 +1,12 @@
+# optimizer
+optimizer = dict(
+    type='SGD',
+    constructor='TSMOptimizerConstructor',
+    paramwise_cfg=dict(fc_lr5=True),
+    lr=0.01,  # this lr is used for 8 gpus
+    momentum=0.9,
+    weight_decay=0.00002)
+optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
+# learning policy
+lr_config = dict(policy='step', step=[40, 80])
+total_epochs = 100
diff --git a/configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py
new file mode 100644
index 0000000000..78612def95
--- /dev/null
+++ b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py
@@ -0,0 +1,12 @@
+# optimizer
+optimizer = dict(
+    type='SGD',
+    constructor='TSMOptimizerConstructor',
+    paramwise_cfg=dict(fc_lr5=True),
+    lr=0.01,  # this lr is used for 8 gpus
+    momentum=0.9,
+    weight_decay=0.00002)
+optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
+# learning policy
+lr_config = dict(policy='step', step=[20, 40])
+total_epochs = 50
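Both schedule files pin `lr=0.01` with the comment that it targets 8 GPUs. MMAction2 configs conventionally assume the learning rate scales linearly with the total batch size (here 8 GPUs x 8 `videos_per_gpu`). A small illustrative helper, not part of the patch:

```python
# Illustrative only: linear LR scaling implied by the
# "this lr is used for 8 gpus" comment (8 GPUs x videos_per_gpu=8 -> 64 clips).
def scaled_lr(base_lr=0.01, base_batch=8 * 8, num_gpus=8, videos_per_gpu=8):
    return base_lr * (num_gpus * videos_per_gpu) / base_batch


print(scaled_lr(num_gpus=4))   # 0.005 when training on 4 GPUs
print(scaled_lr(num_gpus=16))  # 0.02 when training on 16 GPUs
```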
diff --git a/configs/recognition/tsm/README.md b/configs/recognition/tsm/README.md
index 4ccc570537..6fb3cf18b7 100644
--- a/configs/recognition/tsm/README.md
+++ b/configs/recognition/tsm/README.md
@@ -40,6 +40,7 @@
 |[tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4| ResNet50| ImageNet |72.03|90.25|71.81|90.36|x|8931|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb_20200724-f00f1336.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200724_120023.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200724_120023.log.json)|
 |[tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4| ResNet50| ImageNet |70.70|89.90|x|x|x|10125|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb_20200816-b93fd297.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200815_210253.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200815_210253.log.json)|
 |[tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4|ResNet50| ImageNet |71.60|90.34|x|x|x|8358|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb_20200724-d8ad84d2.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/20200723_220442.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/20200723_220442.log.json)|
+|[tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb](/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py)|short-side 320|8|MobileNetV2| ImageNet |68.46|88.64|x|x|x|3385|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/tsm_mobilenetv2_dense_320p_1x1x8_100e_kinetics400_rgb_20210202-61135809.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log.json)|
 
 ### Something-Something V1
 
diff --git a/configs/recognition/tsm/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py
new file mode 100644
index 0000000000..647cb609ac
--- /dev/null
+++ b/configs/recognition/tsm/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py
@@ -0,0 +1,92 @@
+_base_ = [
+    '../../_base_/models/tsm_mobilenet_v2.py',
+    '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py',
+    '../../_base_/default_runtime.py'
+]
+
+# dataset settings
+dataset_type = 'RawframeDataset'
+data_root = 'data/kinetics400/rawframes_train'
+data_root_val = 'data/kinetics400/rawframes_val'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
+ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+train_pipeline = [
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1,
+        num_fixed_crops=13),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs', 'label'])
+]
+val_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+test_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=data_root,
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=data_root_val,
+        pipeline=val_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        pipeline=test_pipeline))
+evaluation = dict(
+    interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
+
+# optimizer
+optimizer = dict(
+    lr=0.01,  # this lr is used for 8 gpus
+)
+
+# runtime settings
+checkpoint_config = dict(interval=1)
+work_dir = './work_dirs/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb/'
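The `1x1x8` in the config name mirrors the sampling arguments above: `clip_len=1`, `frame_interval=1`, `num_clips=8`, i.e. eight single-frame segments per clip. A simplified sketch of TSN/TSM-style segment sampling (illustrative only; the real logic lives in mmaction2's `SampleFrames`):

```python
import numpy as np


def sample_segment_frames(total_frames, num_clips=8, test_mode=False):
    """Pick one frame index from each of `num_clips` equal segments."""
    seg_len = total_frames / num_clips
    base = np.arange(num_clips) * seg_len
    if test_mode:
        offsets = seg_len / 2  # deterministic: segment centers
    else:
        offsets = np.random.rand(num_clips) * seg_len  # random within segment
    return (base + offsets).astype(int)


print(sample_segment_frames(300, test_mode=True))
# -> [ 18  56  93 131 168 206 243 281]
```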
diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
new file mode 100644
index 0000000000..e0a3c4873b
--- /dev/null
+++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
@@ -0,0 +1,97 @@
+_base_ = [
+    '../../_base_/models/tsm_mobilenet_v2.py',
+    '../../_base_/schedules/sgd_tsm_mobilenet_v2_100e.py',
+    '../../_base_/default_runtime.py'
+]
+
+# dataset settings
+dataset_type = 'VideoDataset'
+data_root = 'data/kinetics400/videos_train'
+data_root_val = 'data/kinetics400/videos_val'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
+ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+train_pipeline = [
+    dict(type='DecordInit'),
+    dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1,
+        num_fixed_crops=13),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs', 'label'])
+]
+val_pipeline = [
+    dict(type='DecordInit'),
+    dict(
+        type='DenseSampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+test_pipeline = [
+    dict(type='DecordInit'),
+    dict(
+        type='DenseSampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=8,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=data_root,
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=data_root_val,
+        pipeline=val_pipeline),
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        pipeline=test_pipeline))
+evaluation = dict(
+    interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
+
+# optimizer
+optimizer = dict(
+    lr=0.01,  # this lr is used for 8 gpus
+)
+
+# runtime settings
+checkpoint_config = dict(interval=5)
+work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_100e_kinetics400_rgb/'  # noqa
diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py
new file mode 100644
index 0000000000..a66c772a9b
--- /dev/null
+++ b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py
@@ -0,0 +1,34 @@
+_base_ = ['../../_base_/models/tsm_mobilenet_v2.py']
+
+# dataset settings
+dataset_type = 'VideoDataset'
+data_root_val = 'data/kinetics400/videos_val'
+ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+test_pipeline = [
+    dict(type='DecordInit'),
+    dict(
+        type='DenseSampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+
+data = dict(
+    videos_per_gpu=4,
+    workers_per_gpu=4,
+    test=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=data_root_val,
+        pipeline=test_pipeline))
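The inference config defines only a test dataloader, so it pairs with the high-level recognizer API. A minimal usage sketch, assuming a downloaded checkpoint (the checkpoint path and demo files are illustrative, and the `inference_recognizer` call shown follows the 0.x-era signature that takes a label file):

```python
from mmaction.apis import inference_recognizer, init_recognizer

config = ('configs/recognition/tsm/'
          'tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py')
checkpoint = 'checkpoints/tsm_mobilenetv2_dense_kinetics400.pth'  # illustrative path

model = init_recognizer(config, checkpoint, device='cuda:0')
results = inference_recognizer(model, 'demo/demo.mp4', 'demo/label_map_k400.txt')
for label, score in results:
    print(f'{label}: {score:.3f}')
```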
diff --git a/docs/changelog.md b/docs/changelog.md
index 96df4048a3..0411e77c28 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,5 +1,19 @@
 ## Changelog
 
+### Master
+
+**Highlights**
+
+- Support TSM-MobileNetV2
+
+**New Features**
+
+- Support TSM-MobileNetV2 ([#415](https://github.com/open-mmlab/mmaction2/pull/415))
+
+**ModelZoo**
+
+- Add TSM-MobileNetV2 for Kinetics400 ([#415](https://github.com/open-mmlab/mmaction2/pull/415))
+
 ### 0.11.0 (31/01/2021)
 
 **Highlights**
diff --git a/mmaction/models/__init__.py b/mmaction/models/__init__.py
index 16958b2558..8e3f67e1ae 100644
--- a/mmaction/models/__init__.py
+++ b/mmaction/models/__init__.py
@@ -1,6 +1,7 @@
-from .backbones import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
-                        ResNet3dLayer, ResNet3dSlowFast, ResNet3dSlowOnly,
-                        ResNetAudio, ResNetTIN, ResNetTSM)
+from .backbones import (C3D, X3D, MobileNetV2, MobileNetV2TSM, ResNet,
+                        ResNet2Plus1d, ResNet3d, ResNet3dCSN, ResNet3dLayer,
+                        ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
+                        ResNetTIN, ResNetTSM)
 from .builder import (DETECTORS, build_backbone, build_detector, build_head,
                       build_localizer, build_loss, build_model, build_neck,
                       build_recognizer)
@@ -29,5 +30,5 @@
     'TPN', 'TPNHead', 'build_loss', 'build_neck', 'AudioRecognizer',
     'AudioTSNHead', 'X3D', 'X3DHead', 'ResNet3dLayer', 'DETECTORS',
     'SingleRoIExtractor3D', 'BBoxHeadAVA', 'ResNetAudio', 'build_detector',
-    'ConvAudio', 'AVARoIHead'
+    'ConvAudio', 'AVARoIHead', 'MobileNetV2', 'MobileNetV2TSM'
 ]
diff --git a/mmaction/models/backbones/__init__.py b/mmaction/models/backbones/__init__.py
index 6fcf52a06b..576b42e08b 100644
--- a/mmaction/models/backbones/__init__.py
+++ b/mmaction/models/backbones/__init__.py
@@ -1,4 +1,6 @@
 from .c3d import C3D
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v2_tsm import MobileNetV2TSM
 from .resnet import ResNet
 from .resnet2plus1d import ResNet2Plus1d
 from .resnet3d import ResNet3d, ResNet3dLayer
@@ -13,5 +15,5 @@
 __all__ = [
     'C3D', 'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d',
     'ResNet3dSlowFast', 'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN', 'X3D',
-    'ResNetAudio', 'ResNet3dLayer'
+    'ResNetAudio', 'ResNet3dLayer', 'MobileNetV2TSM', 'MobileNetV2'
 ]
diff --git a/mmaction/models/backbones/mobilenet_v2.py b/mmaction/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000..5a093fa1fa
--- /dev/null
+++ b/mmaction/models/backbones/mobilenet_v2.py
@@ -0,0 +1,297 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+
+
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+    """Make divisible function.
+
+    This function rounds the channel number to the nearest value that can
+    be divided by the divisor.
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int, optional): The minimum value of the output channel.
+            Default: None, which means the minimum value equals the divisor.
+        min_ratio (float, optional): The minimum ratio of the rounded channel
+            number to the original channel number. Default: 0.9.
+    Returns:
+        int: The modified output channel number.
+    """
+
+    if min_value is None:
+        min_value = divisor
+    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than (1-min_ratio).
+    if new_value < min_ratio * value:
+        new_value += divisor
+    return new_value
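A few hand-checked values make the rounding behaviour concrete; these mirror the `channel * widen_factor` calls made further down (illustrative, not part of the patch):

```python
make_divisible(32, 8)         # -> 32  (already a multiple of 8)
make_divisible(24 * 0.5, 8)   # -> 16  (12 rounds up to the nearest multiple)
make_divisible(24 * 0.75, 8)  # -> 24  (16 would fall below 90% of 18, so add 8)
make_divisible(320 * 1.5, 8)  # -> 480
```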
+
+
+class InvertedResidual(nn.Module):
+    """InvertedResidual block for MobileNetV2.
+
+    Args:
+        in_channels (int): The input channels of the InvertedResidual block.
+        out_channels (int): The output channels of the InvertedResidual block.
+        stride (int): Stride of the middle (first) 3x3 convolution.
+        expand_ratio (int): Adjusts the number of channels of the hidden layer
+            in InvertedResidual by this amount.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 expand_ratio,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+            f'But received {stride}.'
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and in_channels == out_channels
+        hidden_dim = int(round(in_channels * expand_ratio))
+
+        layers = []
+        if expand_ratio != 1:
+            layers.append(
+                ConvModule(
+                    in_channels=in_channels,
+                    out_channels=hidden_dim,
+                    kernel_size=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        layers.extend([
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                stride=stride,
+                padding=1,
+                groups=hidden_dim,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            if self.use_res_connect:
+                return x + self.conv(x)
+            else:
+                return self.conv(x)
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
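A quick sanity-check sketch of the block (assuming torch and this backbone module are importable): the identity shortcut is only used when `stride == 1` and the channel count is unchanged.

```python
import torch

from mmaction.models.backbones.mobilenet_v2 import InvertedResidual

# expand (1x1) -> depthwise (3x3) -> linear project (1x1), plus identity.
block = InvertedResidual(32, 32, stride=1, expand_ratio=6)
x = torch.randn(2, 32, 56, 56)
assert block.use_res_connect
assert block(x).shape == (2, 32, 56, 56)

# Downsampling block: no residual connection.
down = InvertedResidual(32, 64, stride=2, expand_ratio=6)
assert not down.use_res_connect
assert down(x).shape == (2, 64, 28, 28)
```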
+
+
+@BACKBONES.register_module()
+class MobileNetV2(nn.Module):
+    """MobileNetV2 backbone.
+
+    Args:
+        pretrained (str | None): Name of pretrained model. Default: None.
+        widen_factor (float): Width multiplier, which multiplies the number of
+            channels in each layer by this amount. Default: 1.0.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (7, ).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Default: -1, which means not freezing any parameters.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: dict(type='Conv').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN2d', requires_grad=True).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6', inplace=True).
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    # Parameters to build layers. 4 parameters are needed to construct a
+    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
+    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
+                     [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
+                     [6, 320, 1, 1]]
+
+    def __init__(self,
+                 pretrained=None,
+                 widen_factor=1.,
+                 out_indices=(7, ),
+                 frozen_stages=-1,
+                 conv_cfg=dict(type='Conv'),
+                 norm_cfg=dict(type='BN2d', requires_grad=True),
+                 act_cfg=dict(type='ReLU6', inplace=True),
+                 norm_eval=False,
+                 with_cp=False):
+        super().__init__()
+        self.pretrained = pretrained
+        self.widen_factor = widen_factor
+        self.out_indices = out_indices
+        for index in out_indices:
+            if index not in range(0, 8):
+                raise ValueError('the item in out_indices must be in '
+                                 f'range(0, 8). But received {index}')
+
+        if frozen_stages not in range(-1, 8):
+            raise ValueError('frozen_stages must be in range(-1, 8). '
+                             f'But received {frozen_stages}')
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        self.in_channels = make_divisible(32 * widen_factor, 8)
+
+        self.conv1 = ConvModule(
+            in_channels=3,
+            out_channels=self.in_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.layers = []
+
+        for i, layer_cfg in enumerate(self.arch_settings):
+            expand_ratio, channel, num_blocks, stride = layer_cfg
+            out_channels = make_divisible(channel * widen_factor, 8)
+            inverted_res_layer = self.make_layer(
+                out_channels=out_channels,
+                num_blocks=num_blocks,
+                stride=stride,
+                expand_ratio=expand_ratio)
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, inverted_res_layer)
+            self.layers.append(layer_name)
+
+        if widen_factor > 1.0:
+            self.out_channel = int(1280 * widen_factor)
+        else:
+            self.out_channel = 1280
+
+        layer = ConvModule(
+            in_channels=self.in_channels,
+            out_channels=self.out_channel,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.add_module('conv2', layer)
+        self.layers.append('conv2')
+
+    def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
+        """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+        Args:
+            out_channels (int): Number of output channels of the block.
+            num_blocks (int): Number of blocks.
+            stride (int): Stride of the first block.
+            expand_ratio (int): Expand the number of channels of the
+                hidden layer in InvertedResidual by this ratio.
+        """
+        layers = []
+        for i in range(num_blocks):
+            if i >= 1:
+                stride = 1
+            layers.append(
+                InvertedResidual(
+                    self.in_channels,
+                    out_channels,
+                    stride,
+                    expand_ratio=expand_ratio,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    with_cp=self.with_cp))
+            self.in_channels = out_channels
+
+        return nn.Sequential(*layers)
+
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
+        elif self.pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        x = self.conv1(x)
+
+        outs = []
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+            for i in range(1, self.frozen_stages + 1):
+                layer = getattr(self, f'layer{i}')
+                layer.eval()
+                for param in layer.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True):
+        super(MobileNetV2, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
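For reference, the eight entries of `self.layers` (`layer1` ... `layer7`, then `conv2`) are what `out_indices` selects from. A short usage sketch with the stage shapes implied by `arch_settings` for `widen_factor=1.0` (consistent with the unit tests added below):

```python
import torch

from mmaction.models import MobileNetV2

# out index -> (channels, spatial stride) for widen_factor=1.0:
# 0:(16,/2) 1:(24,/4) 2:(32,/8) 3:(64,/16) 4:(96,/16) 5:(160,/32) 6:(320,/32) 7:(1280,/32)
backbone = MobileNetV2(out_indices=(3, 7))
backbone.init_weights()
feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])
# [torch.Size([1, 64, 14, 14]), torch.Size([1, 1280, 7, 7])]
```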
diff --git a/mmaction/models/backbones/mobilenet_v2_tsm.py b/mmaction/models/backbones/mobilenet_v2_tsm.py
new file mode 100644
index 0000000000..dd37bdbb7c
--- /dev/null
+++ b/mmaction/models/backbones/mobilenet_v2_tsm.py
@@ -0,0 +1,40 @@
+from ..registry import BACKBONES
+from .mobilenet_v2 import InvertedResidual, MobileNetV2
+from .resnet_tsm import TemporalShift
+
+
+@BACKBONES.register_module()
+class MobileNetV2TSM(MobileNetV2):
+    """MobileNetV2 backbone for TSM.
+
+    Args:
+        num_segments (int): Number of frame segments. Default: 8.
+        is_shift (bool): Whether to insert temporal shift in residual blocks.
+            Default: True.
+        shift_div (int): Number of div for shift. Default: 8.
+        **kwargs (keyword arguments, optional): Arguments for MobileNetV2.
+    """
+
+    def __init__(self, num_segments=8, is_shift=True, shift_div=8, **kwargs):
+        super().__init__(**kwargs)
+        self.num_segments = num_segments
+        self.is_shift = is_shift
+        self.shift_div = shift_div
+
+    def make_temporal_shift(self):
+        """Make temporal shift for some layers."""
+        for m in self.modules():
+            if isinstance(m, InvertedResidual) and \
+                    len(m.conv) == 3 and m.use_res_connect:
+                m.conv[0] = TemporalShift(
+                    m.conv[0],
+                    num_segments=self.num_segments,
+                    shift_div=self.shift_div,
+                )
+
+    def init_weights(self):
+        """Initiate the parameters either from existing checkpoint or from
+        scratch."""
+        super().init_weights()
+        if self.is_shift:
+            self.make_temporal_shift()
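Because the shift is folded into the batch dimension, the backbone expects `(batch * num_segments, C, H, W)` inputs, and `init_weights()` is what actually wraps the eligible expansion convolutions in `TemporalShift`. A minimal sketch:

```python
import torch

from mmaction.models import MobileNetV2TSM

backbone = MobileNetV2TSM(num_segments=8, shift_div=8)
backbone.init_weights()  # initializes weights and inserts TemporalShift wrappers

# One clip of 8 segments, folded into the batch dimension.
x = torch.randn(1 * 8, 3, 224, 224)
feat = backbone(x)
print(feat.shape)  # torch.Size([8, 1280, 7, 7])
```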
diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py
index 76d97b2f4d..ba2b3ee67f 100644
--- a/tests/test_models/test_backbones.py
+++ b/tests/test_models/test_backbones.py
@@ -5,9 +5,9 @@
 import torch.nn as nn
 from mmcv.utils import _BatchNorm
 
-from mmaction.models import (C3D, X3D, ResNet2Plus1d, ResNet3dCSN,
-                             ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
-                             ResNetTIN, ResNetTSM)
+from mmaction.models import (C3D, X3D, MobileNetV2TSM, ResNet2Plus1d,
+                             ResNet3dCSN, ResNet3dSlowFast, ResNet3dSlowOnly,
+                             ResNetAudio, ResNetTIN, ResNetTSM)
 from mmaction.models.backbones.resnet_tsm import NL3DWrapper
 from .base import check_norm_state, generate_backbone_demo_inputs
 
@@ -322,6 +322,45 @@
     assert feat.shape == torch.Size([8, 2048, 1, 1])
 
+
+def test_mobilenetv2_tsm_backbone():
+    """Test mobilenetv2_tsm backbone."""
+    from mmaction.models.backbones.resnet_tsm import TemporalShift
+    from mmaction.models.backbones.mobilenet_v2 import InvertedResidual
+    from mmcv.cnn import ConvModule
+
+    input_shape = (8, 3, 64, 64)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # mobilenetv2_tsm with widen_factor = 1.0
+    mobilenetv2_tsm = MobileNetV2TSM()
+    mobilenetv2_tsm.init_weights()
+    for cur_module in mobilenetv2_tsm.modules():
+        if isinstance(cur_module, InvertedResidual) and \
+                len(cur_module.conv) == 3 and \
+                cur_module.use_res_connect:
+            assert isinstance(cur_module.conv[0], TemporalShift)
+            assert cur_module.conv[0].num_segments == \
+                mobilenetv2_tsm.num_segments
+            assert cur_module.conv[0].shift_div == mobilenetv2_tsm.shift_div
+            assert isinstance(cur_module.conv[0].net, ConvModule)
+
+    # TSM-MobileNetV2 with widen_factor = 1.0 forward
+    feat = mobilenetv2_tsm(imgs)
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
+
+    # mobilenetv2_tsm with widen_factor = 0.5 forward
+    mobilenetv2_tsm_05 = MobileNetV2TSM(widen_factor=0.5)
+    mobilenetv2_tsm_05.init_weights()
+    feat = mobilenetv2_tsm_05(imgs)
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
+
+    # mobilenetv2_tsm with widen_factor = 1.5 forward
+    mobilenetv2_tsm_15 = MobileNetV2TSM(widen_factor=1.5)
+    mobilenetv2_tsm_15.init_weights()
+    feat = mobilenetv2_tsm_15(imgs)
+    assert feat.shape == torch.Size([8, 1920, 2, 2])
+
 
 def test_slowfast_backbone():
     """Test SlowFast backbone."""
     with pytest.raises(TypeError):
diff --git a/tests/test_models/test_common_modules/test_mobilenet_v2.py b/tests/test_models/test_common_modules/test_mobilenet_v2.py
new file mode 100644
index 0000000000..cdb4cd9ec4
--- /dev/null
+++ b/tests/test_models/test_common_modules/test_mobilenet_v2.py
@@ -0,0 +1,204 @@
+import pytest
+import torch
+from mmcv.utils import _BatchNorm
+
+from mmaction.models import MobileNetV2
+from ..base import check_norm_state, generate_backbone_demo_inputs
+
+
+def test_mobilenetv2_backbone():
+    """Test MobileNetV2.
+
+    Modified from mmclassification.
+    """
+    from torch.nn.modules import GroupNorm
+    from mmaction.models.backbones.mobilenet_v2 import InvertedResidual
+
+    def is_norm(modules):
+        """Check if is one of the norms."""
+        if isinstance(modules, (GroupNorm, _BatchNorm)):
+            return True
+        return False
+
+    def is_block(modules):
+        """Check if is a MobileNetV2 building block."""
+        if isinstance(modules, (InvertedResidual, )):
+            return True
+        return False
+
+    with pytest.raises(TypeError):
+        # pretrained must be a string path
+        model = MobileNetV2(pretrained=0)
+        model.init_weights()
+
+    with pytest.raises(ValueError):
+        # frozen_stages must be in range(-1, 8)
+        MobileNetV2(frozen_stages=8)
+
+    with pytest.raises(ValueError):
+        # out_indices must be in range(0, 8)
+        MobileNetV2(out_indices=[8])
+
+    input_shape = (1, 3, 224, 224)
+    imgs = generate_backbone_demo_inputs(input_shape)
+
+    # Test MobileNetV2 with first stage frozen
+    frozen_stages = 1
+    model = MobileNetV2(frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+
+    for mod in model.conv1.modules():
+        for param in mod.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # Test MobileNetV2 with norm_eval=True
+    model = MobileNetV2(norm_eval=True)
+    model.init_weights()
+    model.train()
+
+    assert check_norm_state(model.modules(), False)
+
+    # Test MobileNetV2 forward with widen_factor=1.0, pretrained
+    model = MobileNetV2(
+        widen_factor=1.0,
+        out_indices=range(0, 8),
+        pretrained='mmcls://mobilenet_v2')
+    model.init_weights()
+    model.train()
+
+    assert check_norm_state(model.modules(), True)
+
+    feat = model(imgs)
+    assert len(feat) == 8
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+    assert feat[7].shape == torch.Size((1, 1280, 7, 7))
+
+    # Test MobileNetV2 forward with widen_factor=0.5
+    model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7))
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 8, 112, 112))
+    assert feat[1].shape == torch.Size((1, 16, 56, 56))
+    assert feat[2].shape == torch.Size((1, 16, 28, 28))
+    assert feat[3].shape == torch.Size((1, 32, 14, 14))
+    assert feat[4].shape == torch.Size((1, 48, 14, 14))
+    assert feat[5].shape == torch.Size((1, 80, 7, 7))
+    assert feat[6].shape == torch.Size((1, 160, 7, 7))
+
+    # Test MobileNetV2 forward with widen_factor=2.0
+    model = MobileNetV2(widen_factor=2.0)
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert feat.shape == torch.Size((1, 2560, 7, 7))
+
+    # Test MobileNetV2 forward with out_indices=None
+    model = MobileNetV2(widen_factor=1.0)
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert feat.shape == torch.Size((1, 1280, 7, 7))
+
+    # Test MobileNetV2 forward with dict(type='ReLU')
+    model = MobileNetV2(
+        widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7))
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with BatchNorm (default norm_cfg) forward
+    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with GroupNorm forward
+    model = MobileNetV2(
+        widen_factor=1.0,
+        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
+        out_indices=range(0, 7))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, GroupNorm)
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with layers 1, 3, 5 out forward
+    model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 32, 28, 28))
+    assert feat[2].shape == torch.Size((1, 96, 14, 14))
+
+    # Test MobileNetV2 with checkpoint forward
+    model = MobileNetV2(
+        widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))