From 4dd07558c1cd32d7be81c2bbbe064a5fc6cec8d3 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Wed, 2 Dec 2020 18:28:00 +0800 Subject: [PATCH 01/21] [Feature] add mobilenetv2-tsm, first commit. --- ...lenetv2_video_1x1x8_50e_kinetics400_rgb.py | 127 +++++++++++++ ...2_video_dense_1x1x8_50e_kinetics400_rgb.py | 130 +++++++++++++ mmaction/models/__init__.py | 23 +-- mmaction/models/backbones/__init__.py | 4 +- mmaction/models/backbones/mobilenetv2.py | 160 ++++++++++++++++ .../backbones/mobilenetv2_torchvision.py | 176 ++++++++++++++++++ mmaction/models/backbones/mobilenetv2_tsm.py | 30 +++ tests/test_models/test_backbone.py | 56 +++++- 8 files changed, 691 insertions(+), 15 deletions(-) create mode 100644 configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py create mode 100644 configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py create mode 100644 mmaction/models/backbones/mobilenetv2.py create mode 100644 mmaction/models/backbones/mobilenetv2_torchvision.py create mode 100644 mmaction/models/backbones/mobilenetv2_tsm.py diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py new file mode 100644 index 0000000000..89a8477dcb --- /dev/null +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py @@ -0,0 +1,127 @@ +# model settings +model = dict( + type='Recognizer2D', + backbone=dict(type='MobileNetV2TSM', shift_div=8), + cls_head=dict( + type='TSMHead', + num_segments=8, + num_classes=400, + in_channels=1280, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.5, + init_std=0.001, + is_shift=True)) +# model training and testing settings +train_cfg = None +test_cfg = dict(average_clips='prob') +# dataset settings +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) +train_pipeline = [ + dict(type='DecordInit'), + dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1, + num_fixed_crops=13), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs', 'label']) +] +val_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=8, + frame_interval=8, + num_clips=10, + # 
twice_sample=True, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + # dict(type='CenterCrop', crop_size=224), + dict(type='ThreeCrop', crop_size=256), # it is used for accurate setting + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +data = dict( + videos_per_gpu=4, + workers_per_gpu=8, + train=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=data_root, + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=data_root_val, + pipeline=val_pipeline), + test=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=data_root_val, + pipeline=test_pipeline)) +# optimizer +optimizer = dict( + type='SGD', + constructor='TSMOptimizerConstructor', + paramwise_cfg=dict(fc_lr5=True), + lr=0.02, # this lr is used for 8 gpus + momentum=0.9, + weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) +# learning policy +lr_config = dict(policy='step', step=[20, 40]) +total_epochs = 50 +checkpoint_config = dict(interval=5) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5)) +log_config = dict( + interval=20, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook'), + ]) +# runtime settings +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb/' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py new file mode 100644 index 0000000000..7201e1f1fd --- /dev/null +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py @@ -0,0 +1,130 @@ +# model settings +model = dict( + type='Recognizer2D', + backbone=dict( + type='MobileNetV2TSM', + shift_div=8, + num_segments=8, + is_shift=True, + pretrained=True), + cls_head=dict( + type='TSMHead', + num_segments=8, + num_classes=400, + in_channels=1280, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.5, + init_std=0.001, + is_shift=True)) +# model training and testing settings +train_cfg = None +test_cfg = dict(average_clips='prob') +# dataset settings +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) +train_pipeline = [ + dict(type='DecordInit'), + dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1, + num_fixed_crops=13), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + 
dict(type='ToTensor', keys=['imgs', 'label']) +] +val_pipeline = [ + dict(type='DecordInit'), + dict( + type='DenseSampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='DenseSampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='ThreeCrop', crop_size=224), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +data = dict( + videos_per_gpu=4, + workers_per_gpu=8, + train=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=data_root, + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=data_root_val, + pipeline=val_pipeline), + test=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=data_root_val, + pipeline=test_pipeline)) +# optimizer +optimizer = dict( + type='SGD', + constructor='TSMOptimizerConstructor', + paramwise_cfg=dict(fc_lr5=True), + lr=0.035, # this lr is used for 7 gpus + momentum=0.9, + weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) +# learning policy +lr_config = dict(policy='step', step=[40, 80]) +total_epochs = 100 +checkpoint_config = dict(interval=5) +evaluation = dict( + interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5)) +log_config = dict( + interval=100, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook'), + ]) +# runtime settings +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_50e_kinetics400_rgb/' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/mmaction/models/__init__.py b/mmaction/models/__init__.py index a2ef283058..9c3317f83a 100644 --- a/mmaction/models/__init__.py +++ b/mmaction/models/__init__.py @@ -1,6 +1,6 @@ -from .backbones import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN, - ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio, - ResNetTIN, ResNetTSM) +from .backbones import (C3D, X3D, MobileNetV2, MobileNetV2TSM, ResNet, + ResNet2Plus1d, ResNet3d, ResNet3dCSN, ResNet3dSlowFast, + ResNet3dSlowOnly, ResNetAudio, ResNetTIN, ResNetTSM) from .builder import (build_backbone, build_head, build_localizer, build_loss, build_model, build_neck, build_recognizer) from .common import Conv2plus1d, ConvAudio @@ -18,12 +18,13 @@ __all__ = [ 'BACKBONES', 'HEADS', 'RECOGNIZERS', 'build_recognizer', 'build_head', 'build_backbone', 'recognizer2d', 'recognizer3d', 'C3D', 'ResNet', - 'ResNet3d', 'ResNet2Plus1d', 'I3DHead', 'TSNHead', 'TSMHead', 'BaseHead', - 'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss', 'NLLLoss', 'HVULoss', - 'ResNetTSM', 'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d', - 'ResNet3dSlowOnly', 'BCELossWithLogits', 'LOCALIZERS', 'build_localizer', - 'PEM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', - 'build_model', 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 
'ResNetTIN', - 'TPN', 'TPNHead', 'build_loss', 'build_neck', 'AudioRecognizer', - 'AudioTSNHead', 'X3D', 'X3DHead', 'ResNetAudio', 'ConvAudio' + 'MobileNetV2', 'ResNet3d', 'ResNet2Plus1d', 'I3DHead', 'TSNHead', + 'TSMHead', 'BaseHead', 'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss', + 'NLLLoss', 'HVULoss', 'ResNetTSM', 'MobileNetV2TSM', 'ResNet3dSlowFast', + 'SlowFastHead', 'Conv2plus1d', 'ResNet3dSlowOnly', 'BCELossWithLogits', + 'LOCALIZERS', 'build_localizer', 'PEM', 'TEM', + 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', 'build_model', + 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN', 'TPN', 'TPNHead', + 'build_loss', 'build_neck', 'AudioRecognizer', 'AudioTSNHead', 'X3D', + 'X3DHead', 'ResNetAudio', 'ConvAudio' ] diff --git a/mmaction/models/backbones/__init__.py b/mmaction/models/backbones/__init__.py index cf51746446..70a77bd036 100644 --- a/mmaction/models/backbones/__init__.py +++ b/mmaction/models/backbones/__init__.py @@ -1,4 +1,6 @@ from .c3d import C3D +from .mobilenetv2 import MobileNetV2 +from .mobilenetv2_tsm import MobileNetV2TSM from .resnet import ResNet from .resnet2plus1d import ResNet2Plus1d from .resnet3d import ResNet3d @@ -13,5 +15,5 @@ __all__ = [ 'C3D', 'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d', 'ResNet3dSlowFast', 'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN', 'X3D', - 'ResNetAudio' + 'ResNetAudio', 'MobileNetV2TSM', 'MobileNetV2' ] diff --git a/mmaction/models/backbones/mobilenetv2.py b/mmaction/models/backbones/mobilenetv2.py new file mode 100644 index 0000000000..9adbbec3e4 --- /dev/null +++ b/mmaction/models/backbones/mobilenetv2.py @@ -0,0 +1,160 @@ +import math + +import torch.nn as nn + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True)) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True)) + + +def make_divisible(x, divisible_by=8): + import numpy as np + return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) + + +class InvertedResidual(nn.Module): + + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d( + hidden_dim, + hidden_dim, + 3, + stride, + 1, + groups=hidden_dim, + bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d( + hidden_dim, + hidden_dim, + 3, + stride, + 1, + groups=hidden_dim, + bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + + def __init__(self, + width_mult=1., + inverted_residual_setting=None, + round_nearest=8, + pretrained=False): + super(MobileNetV2, self).__init__() + self.pretrained = pretrained + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + if inverted_residual_setting is None: + interverted_residual_setting = [ + # t/expand_ratio, c/output_channels, n/num_of_blocks, s/stride + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # input_channel = make_divisible(input_channel * width_mult) + self.last_channel = make_divisible( + last_channel * + width_mult, round_nearest) if width_mult > 1.0 else last_channel + + self.features = [conv_bn(3, input_channel, 2)] + + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = make_divisible(c * width_mult, + round_nearest) if t > 1 else c + + for i in range(n): + if i == 0: + self.features.append( + block( + input_channel, output_channel, s, expand_ratio=t)) + else: + self.features.append( + block( + input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + def forward(self, x): + x = self.features(x) + return x + + def init_weights(self): + if self.pretrained: + try: + from torch.hub import load_state_dict_from_url + except ImportError: + from torch.utils.model_zoo import \ + load_url as load_state_dict_from_url + state_dict = load_state_dict_from_url( + 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', # noqa + progress=True) + del state_dict['classifier.weight'] + del state_dict['classifier.bias'] + self.load_state_dict(state_dict) + else: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() diff --git a/mmaction/models/backbones/mobilenetv2_torchvision.py b/mmaction/models/backbones/mobilenetv2_torchvision.py new file mode 100644 index 0000000000..f34168f508 --- /dev/null +++ b/mmaction/models/backbones/mobilenetv2_torchvision.py @@ -0,0 +1,176 @@ +from torch import nn + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import \ + load_url as load_state_dict_from_url + +model_urls = { + 'mobilenet_v2': + 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', +} + + +def _make_divisible(v, divisor, min_value=None): + """This function is taken from the original tf repo. It ensures that all + layers have a channel number that is divisible by 8 It can be seen here: ht + tps://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet + . + + /mobilenet.py # noqa. + + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + + def __init__(self, + in_planes, + out_planes, + kernel_size=3, + stride=1, + groups=1): + padding = (kernel_size - 1) // 2 + super(ConvBNReLU, self).__init__( + nn.Conv2d( + in_planes, + out_planes, + kernel_size, + stride, + padding, + groups=groups, + bias=False), nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True)) + + +class InvertedResidual(nn.Module): + + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU( + hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + + def __init__(self, + pretrained=False, + width_mult=1.0, + inverted_residual_setting=None, + round_nearest=8, + block=None): + """MobileNet V2 main class. 
+ + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels + in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to + be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for + mobilenet + """ + super(MobileNetV2, self).__init__() + + if block is None: + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element + # assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len( + inverted_residual_setting[0]) != 4: + raise ValueError('inverted_residual_setting should be non-empty ' + 'or a 4-element list, got {}'.format( + inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, + round_nearest) + self.last_channel = _make_divisible( + last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append( + block( + input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append( + ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + def forward(self, x): + return self.features(x) + + def init_weights(self): + # weight initialization + if self.pretrained: + state_dict = load_state_dict_from_url(model_urls['mobilenet_v2']) + self.load_state_dict(state_dict) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) diff --git a/mmaction/models/backbones/mobilenetv2_tsm.py b/mmaction/models/backbones/mobilenetv2_tsm.py new file mode 100644 index 0000000000..7fd0ee0a82 --- /dev/null +++ b/mmaction/models/backbones/mobilenetv2_tsm.py @@ -0,0 +1,30 @@ +from ..registry import BACKBONES +from .mobilenetv2 import InvertedResidual, MobileNetV2 +from .resnet_tsm import TemporalShift + + +@BACKBONES.register_module() +class MobileNetV2TSM(MobileNetV2): + + def __init__(self, num_segments=8, is_shift=True, shift_div=8, **kwargs): + super().__init__(**kwargs) + self.num_segments = num_segments + self.is_shift = is_shift + self.shift_div = shift_div + + def make_temporal_shift(self): + for m in self.modules(): + if isinstance(m, InvertedResidual) and \ + len(m.conv) == 8 and m.use_res_connect: + m.conv[0] = TemporalShift( + m.conv[0], + num_segments=self.num_segments, + shift_div=self.shift_div, + ) + + def init_weights(self): + """Initiate the parameters either from existing checkpoint or from + scratch.""" + super().init_weights() + if self.is_shift: + self.make_temporal_shift() diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py 
index f6931f7911..026b578e98 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -6,9 +6,10 @@
 import torch.nn as nn
 from mmcv.utils import _BatchNorm
 
-from mmaction.models import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d,
-                             ResNet3dCSN, ResNet3dSlowFast, ResNet3dSlowOnly,
-                             ResNetAudio, ResNetTIN, ResNetTSM)
+from mmaction.models import (C3D, X3D, MobileNetV2, MobileNetV2TSM, ResNet,
+                             ResNet2Plus1d, ResNet3d, ResNet3dCSN,
+                             ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
+                             ResNetTIN, ResNetTSM)
 from mmaction.models.backbones.resnet_tsm import NL3DWrapper
 
 
@@ -127,6 +128,26 @@ def test_resnet_backbone():
     assert feat.shape == torch.Size([1, 2048, 2, 2])
 
 
+def test_mobilenetv2_backbone():
+    """Test MobileNetV2 backbone."""
+    with pytest.raises(KeyError):
+        MobileNetV2()
+
+    with pytest.raises(TypeError):
+        # pretrain is a bool
+        mobilenetv2 = MobileNetV2(50, pretrained=True)
+        mobilenetv2.init_weights()
+
+    input_shape = (1, 3, 64, 64)
+    imgs = _demo_inputs(input_shape)
+
+    # resnet with depth 18 inference
+    mobilenetv2 = MobileNetV2()
+    mobilenetv2.init_weights()
+    feat = mobilenetv2(imgs)
+    assert feat.shape == torch.Size([1, 1280, 2, 2])
+
+
 def test_x3d_backbone():
     """Test resnet3d backbone."""
     with pytest.raises(AssertionError):
@@ -726,6 +747,35 @@ def test_resnet_tsm_backbone():
     assert feat.shape == torch.Size([8, 2048, 1, 1])
 
 
+def test_mobilenetv2_tsm_backbone():
+    """Test mobilenetv2_tsm backbone."""
+    with pytest.raises(NotImplementedError):
+        # shift_place must be block or blockres
+        mobilenetv2_tsm = MobileNetV2TSM()
+        mobilenetv2_tsm.init_weights()
+
+    from mmaction.models.backbones.resnet_tsm import TemporalShift
+
+    input_shape = (8, 3, 64, 64)
+    imgs = _demo_inputs(input_shape)
+
+    # resnet_tsm with depth 50
+    mobilenetv2_tsm = MobileNetV2TSM(50)
+    mobilenetv2_tsm.init_weights()
+    for layer_name in mobilenetv2_tsm.res_layers:
+        layer = getattr(mobilenetv2_tsm, layer_name)
+        blocks = list(layer.children())
+        for block in blocks:
+            assert isinstance(block.conv1.conv, TemporalShift)
+            assert block.conv1.num_segments == mobilenetv2_tsm.num_segments
+            assert block.conv1.conv.shift_div == mobilenetv2_tsm.shift_div
+            assert isinstance(block.conv1.conv.net, nn.Conv2d)
+
+    # TSM-MobileNetV2 forward
+    feat = mobilenetv2_tsm(imgs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+
 def test_slowfast_backbone():
     """Test SlowFast backbone."""
     with pytest.raises(TypeError):

From bf6af41fa0d6b1c09e08dd15fe585818b6b08cf3 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 3 Dec 2020 15:53:57 +0800
Subject: [PATCH 02/21] 1. remove torchvision mobilenetv2. 2. update unittest.
 3. refactor mobilenetv2 with torchvision version.

---
 mmaction/models/backbones/mobilenetv2.py      |  63 +++++--
 .../backbones/mobilenetv2_torchvision.py      | 176 ------------------
 tests/test_models/test_backbone.py            |  68 ++++---
 3 files changed, 87 insertions(+), 220 deletions(-)
 delete mode 100644 mmaction/models/backbones/mobilenetv2_torchvision.py

diff --git a/mmaction/models/backbones/mobilenetv2.py b/mmaction/models/backbones/mobilenetv2.py
index 9adbbec3e4..8939ad8bc0 100644
--- a/mmaction/models/backbones/mobilenetv2.py
+++ b/mmaction/models/backbones/mobilenetv2.py
@@ -82,14 +82,34 @@ def __init__(self,
                  width_mult=1.,
                  inverted_residual_setting=None,
                  round_nearest=8,
+                 block=None,
                  pretrained=False):
+        """MobileNet V2 main class.
+
+        Args:
+            width_mult (float): Width multiplier - adjusts number of channels
+                in each layer by this amount.
+ inverted_residual_setting: Network structure. + round_nearest (int): Round the number of channels in each layer to + be a multiple of this number. Set to 1 to turn off rounding. + block (nn.Module): Module specifying inverted residual building + block for mobilenet. + pretrained (bool): whether to load pretrained checkpoints. + """ super(MobileNetV2, self).__init__() self.pretrained = pretrained - block = InvertedResidual + + if abs(width_mult - 1.0) > 1e-5 and pretrained: + raise ValueError('MobileNetV2 only supports one pretrained model ' + 'with `width_mult=1.0`.') + + if block is None: + block = InvertedResidual input_channel = 32 last_channel = 1280 + if inverted_residual_setting is None: - interverted_residual_setting = [ + inverted_residual_setting = [ # t/expand_ratio, c/output_channels, n/num_of_blocks, s/stride [1, 16, 1, 1], [6, 24, 2, 2], @@ -100,36 +120,41 @@ def __init__(self, [6, 320, 1, 1], ] - # input_channel = make_divisible(input_channel * width_mult) - self.last_channel = make_divisible( - last_channel * - width_mult, round_nearest) if width_mult > 1.0 else last_channel + # only check the first element + # assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len( + inverted_residual_setting[0]) != 4: + raise ValueError('inverted_residual_setting should be non-empty ' + 'or a 4-element list, got {}'.format( + inverted_residual_setting)) + input_channel = make_divisible(input_channel * width_mult, + round_nearest) + self.last_channel = make_divisible(last_channel * max(1.0, width_mult), + round_nearest) + + # first layer self.features = [conv_bn(3, input_channel, 2)] # building inverted residual blocks - for t, c, n, s in interverted_residual_setting: - output_channel = make_divisible(c * width_mult, - round_nearest) if t > 1 else c + for t, c, n, s in inverted_residual_setting: + output_channel = make_divisible(c * width_mult, round_nearest) for i in range(n): - if i == 0: - self.features.append( - block( - input_channel, output_channel, s, expand_ratio=t)) - else: - self.features.append( - block( - input_channel, output_channel, 1, expand_ratio=t)) + stride = s if i == 0 else 1 + self.features.append( + block( + input_channel, output_channel, stride, expand_ratio=t)) input_channel = output_channel + # building last several layers self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential self.features = nn.Sequential(*self.features) def forward(self, x): - x = self.features(x) - return x + return self.features(x) def init_weights(self): if self.pretrained: diff --git a/mmaction/models/backbones/mobilenetv2_torchvision.py b/mmaction/models/backbones/mobilenetv2_torchvision.py deleted file mode 100644 index f34168f508..0000000000 --- a/mmaction/models/backbones/mobilenetv2_torchvision.py +++ /dev/null @@ -1,176 +0,0 @@ -from torch import nn - -try: - from torch.hub import load_state_dict_from_url -except ImportError: - from torch.utils.model_zoo import \ - load_url as load_state_dict_from_url - -model_urls = { - 'mobilenet_v2': - 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', -} - - -def _make_divisible(v, divisor, min_value=None): - """This function is taken from the original tf repo. It ensures that all - layers have a channel number that is divisible by 8 It can be seen here: ht - tps://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet - . - - /mobilenet.py # noqa. 
- - :param v: - :param divisor: - :param min_value: - :return: - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class ConvBNReLU(nn.Sequential): - - def __init__(self, - in_planes, - out_planes, - kernel_size=3, - stride=1, - groups=1): - padding = (kernel_size - 1) // 2 - super(ConvBNReLU, self).__init__( - nn.Conv2d( - in_planes, - out_planes, - kernel_size, - stride, - padding, - groups=groups, - bias=False), nn.BatchNorm2d(out_planes), - nn.ReLU6(inplace=True)) - - -class InvertedResidual(nn.Module): - - def __init__(self, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = self.stride == 1 and inp == oup - - layers = [] - if expand_ratio != 1: - # pw - layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU( - hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - ]) - self.conv = nn.Sequential(*layers) - - def forward(self, x): - if self.use_res_connect: - return x + self.conv(x) - else: - return self.conv(x) - - -class MobileNetV2(nn.Module): - - def __init__(self, - pretrained=False, - width_mult=1.0, - inverted_residual_setting=None, - round_nearest=8, - block=None): - """MobileNet V2 main class. - - Args: - num_classes (int): Number of classes - width_mult (float): Width multiplier - adjusts number of channels - in each layer by this amount - inverted_residual_setting: Network structure - round_nearest (int): Round the number of channels in each layer to - be a multiple of this number - Set to 1 to turn off rounding - block: Module specifying inverted residual building block for - mobilenet - """ - super(MobileNetV2, self).__init__() - - if block is None: - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - - if inverted_residual_setting is None: - inverted_residual_setting = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # only check the first element - # assuming user knows t,c,n,s are required - if len(inverted_residual_setting) == 0 or len( - inverted_residual_setting[0]) != 4: - raise ValueError('inverted_residual_setting should be non-empty ' - 'or a 4-element list, got {}'.format( - inverted_residual_setting)) - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, - round_nearest) - self.last_channel = _make_divisible( - last_channel * max(1.0, width_mult), round_nearest) - features = [ConvBNReLU(3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in inverted_residual_setting: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): - stride = s if i == 0 else 1 - features.append( - block( - input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append( - ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) - # make it nn.Sequential - self.features = nn.Sequential(*features) - - def forward(self, x): - return self.features(x) - - def init_weights(self): - # weight initialization - if self.pretrained: - 
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'])
-            self.load_state_dict(state_dict)
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out')
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.ones_(m.weight)
-                nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0, 0.01)
-                nn.init.zeros_(m.bias)
diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py
index 026b578e98..d8d7530e30 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -130,23 +130,32 @@ def test_resnet_backbone():
 
 def test_mobilenetv2_backbone():
     """Test MobileNetV2 backbone."""
-    with pytest.raises(KeyError):
-        MobileNetV2()
-
     with pytest.raises(TypeError):
         # pretrain is a bool
-        mobilenetv2 = MobileNetV2(50, pretrained=True)
+        mobilenetv2 = MobileNetV2(pretrained='')
         mobilenetv2.init_weights()
 
     input_shape = (1, 3, 64, 64)
     imgs = _demo_inputs(input_shape)
 
-    # resnet with depth 18 inference
-    mobilenetv2 = MobileNetV2()
+    # mobilenetv2 with width_mult = 1.0, pretrained
+    mobilenetv2 = MobileNetV2(pretrained=True)
     mobilenetv2.init_weights()
     feat = mobilenetv2(imgs)
     assert feat.shape == torch.Size([1, 1280, 2, 2])
 
+    # mobilenetv2 with width_mult = 0.5
+    mobilenetv2 = MobileNetV2(width_mult=0.5)
+    mobilenetv2.init_weights()
+    feat = mobilenetv2(imgs)
+    assert feat.shape == torch.Size([1, 1280, 2, 2])
+
+    # mobilenetv2 with width_mult = 1.5
+    mobilenetv2 = MobileNetV2(width_mult=1.5)
+    mobilenetv2.init_weights()
+    feat = mobilenetv2(imgs)
+    assert feat.shape == torch.Size([1, 1920, 2, 2])
+
 
 def test_x3d_backbone():
     """Test resnet3d backbone."""
@@ -749,31 +758,40 @@ def test_resnet_tsm_backbone():
 
 def test_mobilenetv2_tsm_backbone():
     """Test mobilenetv2_tsm backbone."""
-    with pytest.raises(NotImplementedError):
-        # shift_place must be block or blockres
-        mobilenetv2_tsm = MobileNetV2TSM()
-        mobilenetv2_tsm.init_weights()
-
     from mmaction.models.backbones.resnet_tsm import TemporalShift
+    from mmaction.models.backbones.mobilenetv2 import InvertedResidual
 
-    input_shape = (8, 3, 64, 64)
+    input_shape = (4, 3, 64, 64)
     imgs = _demo_inputs(input_shape)
 
-    # resnet_tsm with depth 50
-    mobilenetv2_tsm = MobileNetV2TSM(50)
+    # mobilenetv2_tsm with width_mult = 1.0
+    mobilenetv2_tsm = MobileNetV2TSM()
     mobilenetv2_tsm.init_weights()
-    for layer_name in mobilenetv2_tsm.res_layers:
-        layer = getattr(mobilenetv2_tsm, layer_name)
-        blocks = list(layer.children())
-        for block in blocks:
-            assert isinstance(block.conv1.conv, TemporalShift)
-            assert block.conv1.num_segments == mobilenetv2_tsm.num_segments
-            assert block.conv1.conv.shift_div == mobilenetv2_tsm.shift_div
-            assert isinstance(block.conv1.conv.net, nn.Conv2d)
-
-    # TSM-MobileNetV2 forward
+    for cur_module in mobilenetv2_tsm.modules():
+        if isinstance(cur_module, InvertedResidual) and \
+                len(cur_module.conv) == 8 and \
+                cur_module.use_res_connect:
+            assert isinstance(cur_module.conv[0], TemporalShift)
+            assert cur_module.conv[0].num_segments == \
+                mobilenetv2_tsm.num_segments
+            assert cur_module.conv[0].shift_div == mobilenetv2_tsm.shift_div
+            assert isinstance(cur_module.conv[0].net, nn.Conv2d)
+
+    # TSM-MobileNetV2 with width_mult = 1.0 forward
     feat = mobilenetv2_tsm(imgs)
-    assert feat.shape == torch.Size([8, 2048, 2, 2])
+    assert feat.shape == torch.Size([4, 1280, 2, 2])
+
+    # mobilenetv2 with width_mult = 0.5 forward
+    mobilenetv2_tsm_05 = MobileNetV2TSM(width_mult=0.5)
+    mobilenetv2_tsm_05.init_weights()
+    feat = mobilenetv2_tsm_05(imgs)
+    assert feat.shape == torch.Size([4, 1280, 2, 2])
+
+    # mobilenetv2 with width_mult = 1.5 forward
+    mobilenetv2_tsm_15 = MobileNetV2TSM(width_mult=1.5)
+    mobilenetv2_tsm_15.init_weights()
+    feat = mobilenetv2_tsm_15(imgs)
+    assert feat.shape == torch.Size([4, 1920, 2, 2])
 
 
 def test_slowfast_backbone():

From d465376ca457a8536da7bdfbde0d4e5717e63d38 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 3 Dec 2020 16:24:34 +0800
Subject: [PATCH 03/21] fix unittest bug and update annotations.

---
 mmaction/models/backbones/mobilenetv2.py     | 35 +++++++++++++++++---
 mmaction/models/backbones/mobilenetv2_tsm.py | 10 ++++++
 tests/test_models/test_backbone.py           |  5 ---
 3 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/mmaction/models/backbones/mobilenetv2.py b/mmaction/models/backbones/mobilenetv2.py
index 8939ad8bc0..1fb74712e1 100644
--- a/mmaction/models/backbones/mobilenetv2.py
+++ b/mmaction/models/backbones/mobilenetv2.py
@@ -15,14 +15,38 @@ def conv_1x1_bn(inp, oup):
         nn.ReLU6(inplace=True))
 
 
-def make_divisible(x, divisible_by=8):
-    import numpy as np
-    return int(np.ceil(x * 1. / divisible_by) * divisible_by)
+def make_divisible(v, divisor, min_value=None):
+    """This function is taken from the original tf repo.
+
+    It ensures that all layers have a channel number that is divisible by 8.
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa
+
+    Args:
+        v (float): original number of channels.
+        divisor (float): Round the number of channels in each layer to
+            be a multiple of this number. Set to 1 to turn off rounding.
+        min_value (int): minimal value to return
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
 
 
 class InvertedResidual(nn.Module):
 
     def __init__(self, inp, oup, stride, expand_ratio):
+        """Inverted Residual Module from MobileNetV2.
+
+        Args:
+            inp (int): number of input channels.
+            oup (int): number of output channels.
+            stride (int): stride for depthwise convolution.
+            expand_ratio (int): expand ratio for hidden layers.
+        """
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]
@@ -89,7 +113,7 @@ def __init__(self,
         Args:
             width_mult (float): Width multiplier - adjusts number of channels
                 in each layer by this amount.
-            inverted_residual_setting: Network structure.
+            inverted_residual_setting (list): Network structure.
             round_nearest (int): Round the number of channels in each layer to
                 be a multiple of this number. Set to 1 to turn off rounding.
             block (nn.Module): Module specifying inverted residual building
@@ -157,6 +181,9 @@ def forward(self, x):
         return self.features(x)
 
     def init_weights(self):
+        """Initiate the parameters either from existing checkpoint or from
+        scratch."""
+
         if self.pretrained:
diff --git a/mmaction/models/backbones/mobilenetv2_tsm.py b/mmaction/models/backbones/mobilenetv2_tsm.py
index 7fd0ee0a82..c22c3b7e4b 100644
--- a/mmaction/models/backbones/mobilenetv2_tsm.py
+++ b/mmaction/models/backbones/mobilenetv2_tsm.py
@@ -5,6 +5,15 @@
 
 @BACKBONES.register_module()
 class MobileNetV2TSM(MobileNetV2):
+    """MobileNetV2 backbone for TSM.
+
+    Args:
+        num_segments (int): Number of frame segments. Default: 8.
+        is_shift (bool): Whether to make temporal shift in residual layers.
+            Default: True.
+        shift_div (int): Number of div for shift. Default: 8.
+        **kwargs (keyword arguments, optional): Arguments for MobileNetV2.
+    """
 
     def __init__(self, num_segments=8, is_shift=True, shift_div=8, **kwargs):
         super().__init__(**kwargs)
@@ -13,6 +22,7 @@ def __init__(self, num_segments=8, is_shift=True, shift_div=8, **kwargs):
         self.shift_div = shift_div
 
     def make_temporal_shift(self):
+        """Make temporal shift for some layers."""
        for m in self.modules():
             if isinstance(m, InvertedResidual) and \
                     len(m.conv) == 8 and m.use_res_connect:
diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py
index d8d7530e30..ea06a99799 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -130,11 +130,6 @@ def test_resnet_backbone():
 
 def test_mobilenetv2_backbone():
     """Test MobileNetV2 backbone."""
-    with pytest.raises(TypeError):
-        # pretrain is a bool
-        mobilenetv2 = MobileNetV2(pretrained='')
-        mobilenetv2.init_weights()
-
 
From b11155deeadcf1be620f418ac53cb01e7340fe51 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 3 Dec 2020 16:36:18 +0800
Subject: [PATCH 04/21] change default config.

---
 ...mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py} | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
 rename configs/recognition/tsm/{tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py => tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py} (96%)

diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
similarity index 96%
rename from configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py
rename to configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
index 7201e1f1fd..e8bc2ab376 100644
--- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py
+++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
@@ -110,13 +110,13 @@
  weight_decay=0.0001)
 optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
 # learning policy
-lr_config = dict(policy='step', step=[40, 80])
-total_epochs = 100
+lr_config = dict(policy='step', step=[20, 40])
+total_epochs = 50
 checkpoint_config = dict(interval=5)
 evaluation = dict(
-    interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
+    interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
 log_config = dict(
-    interval=100,
+    interval=20,
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook'),
     ])

From 6ba7e569169d206ad3acc469e468f81a0c425554 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 3 Dec 2020 16:41:28 +0800
Subject: [PATCH 05/21] fix default configs

---
 ...bilenetv2_video_1x1x8_50e_kinetics400_rgb.py | 17 ++++++++++-------
 ...v2_video_dense_1x1x8_100e_kinetics400_rgb.py |  6 +++---
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py
index 89a8477dcb..d3cc39b1a0 100644
--- a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py
+++ b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py
@@ -1,7 +1,12 @@
 # model settings
 model = dict(
     type='Recognizer2D',
- backbone=dict(type='MobileNetV2TSM', shift_div=8), + backbone=dict( + type='MobileNetV2TSM', + shift_div=8, + num_segments=8, + is_shift=True, + pretrained=True), cls_head=dict( type='TSMHead', num_segments=8, @@ -67,12 +72,10 @@ clip_len=8, frame_interval=8, num_clips=10, - # twice_sample=True, test_mode=True), dict(type='DecordDecode'), dict(type='Resize', scale=(-1, 256)), - # dict(type='CenterCrop', crop_size=224), - dict(type='ThreeCrop', crop_size=256), # it is used for accurate setting + dict(type='CenterCrop', crop_size=224), dict(type='Flip', flip_ratio=0), dict(type='Normalize', **img_norm_cfg), dict(type='FormatShape', input_format='NCHW'), @@ -80,8 +83,8 @@ dict(type='ToTensor', keys=['imgs']) ] data = dict( - videos_per_gpu=4, - workers_per_gpu=8, + videos_per_gpu=8, + workers_per_gpu=4, train=dict( type=dataset_type, ann_file=ann_file_train, @@ -111,7 +114,7 @@ total_epochs = 50 checkpoint_config = dict(interval=5) evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5)) + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) log_config = dict( interval=20, hooks=[ diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index e8bc2ab376..7ad0cea806 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -83,8 +83,8 @@ dict(type='ToTensor', keys=['imgs']) ] data = dict( - videos_per_gpu=4, - workers_per_gpu=8, + videos_per_gpu=8, + workers_per_gpu=4, train=dict( type=dataset_type, ann_file=ann_file_train, @@ -105,7 +105,7 @@ type='SGD', constructor='TSMOptimizerConstructor', paramwise_cfg=dict(fc_lr5=True), - lr=0.035, # this lr is used for 7 gpus + lr=0.02, # this lr is used for 8 gpus momentum=0.9, weight_decay=0.0001) optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) From 57a16a853e1caa4ea5b254090b6d31e3bba93d4a Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 3 Dec 2020 17:12:24 +0800 Subject: [PATCH 06/21] fix unittest and config bug. 
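
Besides switching the dense test pipeline back to CenterCrop, the TSM
unittests now feed batches whose leading dimension equals num_segments.
A minimal sanity-check sketch (illustrative only; it simply mirrors the
assertions updated in tests/test_models/test_backbone.py):

    import torch
    from mmaction.models import MobileNetV2TSM

    model = MobileNetV2TSM(num_segments=8)  # width_mult defaults to 1.0
    model.init_weights()  # also inserts the TemporalShift wrappers
    imgs = torch.randn(8, 3, 64, 64)  # one batch = num_segments frames
    feat = model(imgs)
    # a 64px input is downsampled 32x, and the final channel dim is 1280
    assert feat.shape == torch.Size([8, 1280, 2, 2])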
---
 ..._mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py | 2 +-
 tests/test_models/test_backbone.py                        | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
index 7ad0cea806..fddc2e2073 100644
--- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
+++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py
@@ -75,7 +75,7 @@
         test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
-    dict(type='ThreeCrop', crop_size=224),
+    dict(type='CenterCrop', crop_size=224),
     dict(type='Flip', flip_ratio=0),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='FormatShape', input_format='NCHW'),
diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py
index ea06a99799..6479f65fd8 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -756,7 +756,7 @@ def test_mobilenetv2_tsm_backbone():
     from mmaction.models.backbones.resnet_tsm import TemporalShift
     from mmaction.models.backbones.mobilenetv2 import InvertedResidual
 
-    input_shape = (4, 3, 64, 64)
+    input_shape = (8, 3, 64, 64)
     imgs = _demo_inputs(input_shape)
 
     # mobilenetv2_tsm with width_mult = 1.0
@@ -774,19 +774,19 @@ def test_mobilenetv2_tsm_backbone():
 
     # TSM-MobileNetV2 with width_mult = 1.0 forward
     feat = mobilenetv2_tsm(imgs)
-    assert feat.shape == torch.Size([4, 1280, 2, 2])
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
 
     # mobilenetv2 with width_mult = 0.5 forward
     mobilenetv2_tsm_05 = MobileNetV2TSM(width_mult=0.5)
     mobilenetv2_tsm_05.init_weights()
     feat = mobilenetv2_tsm_05(imgs)
-    assert feat.shape == torch.Size([4, 1280, 2, 2])
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
 
     # mobilenetv2 with width_mult = 1.5 forward
     mobilenetv2_tsm_15 = MobileNetV2TSM(width_mult=1.5)
     mobilenetv2_tsm_15.init_weights()
     feat = mobilenetv2_tsm_15(imgs)
-    assert feat.shape == torch.Size([4, 1920, 2, 2])
+    assert feat.shape == torch.Size([8, 1920, 2, 2])

From 986cbd151a3564415222a7201dadcdf987b37a47 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 3 Dec 2020 18:12:56 +0800
Subject: [PATCH 07/21] improve unittest coverage

---
 mmaction/models/backbones/mobilenetv2.py | 21 ++++-----------------
 tests/test_models/test_backbone.py       | 10 ++++++++++
 2 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/mmaction/models/backbones/mobilenetv2.py b/mmaction/models/backbones/mobilenetv2.py
index 1fb74712e1..319e646647 100644
--- a/mmaction/models/backbones/mobilenetv2.py
+++ b/mmaction/models/backbones/mobilenetv2.py
@@ -1,6 +1,5 @@
-import math
-
 import torch.nn as nn
+from mmcv.cnn import constant_init, kaiming_init
 
 
 def conv_bn(inp, oup, stride):
@@ -185,11 +184,7 @@ def init_weights(self):
         scratch."""
 
         if self.pretrained:
-            try:
-                from torch.hub import load_state_dict_from_url
-            except ImportError:
-                from torch.utils.model_zoo import \
-                    load_url as load_state_dict_from_url
+            from torch.hub import load_state_dict_from_url
             state_dict = load_state_dict_from_url(
                 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1',  # noqa
                 progress=True)
@@ -199,14 +194,6 @@ def init_weights(self):
         else:
             for m in self.modules():
                 if isinstance(m, nn.Conv2d):
-                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                    m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - if m.bias is not None: - m.bias.data.zero_() + kaiming_init(m) elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(1) - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() + constant_init(m, 1) diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 6479f65fd8..5450f16969 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -130,6 +130,16 @@ def test_resnet_backbone(): def test_mobilenetv2_backbone(): """Test MobileNetV2 backbone.""" + with pytest.raises(ValueError): + # MobileNetV2 only supports one pretrained model + MobileNetV2(width_mult=0.5, pretrained=True) + + with pytest.raises(ValueError): + # In MobileNetV2, inverted_residual_setting must be None or a list. + # The input list should at least have one element. Each Element should + # be a list with exact 4 ints. + MobileNetV2(inverted_residual_setting=[]) + input_shape = (1, 3, 64, 64) imgs = _demo_inputs(input_shape) From cc90d425e4181c37f47daeca767ede6fd2483a3e Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 17 Dec 2020 13:48:16 +0800 Subject: [PATCH 08/21] update training configs --- ...lenetv2_video_1x1x8_50e_kinetics400_rgb.py | 2 +- ..._video_dense_1x1x8_100e_kinetics400_rgb.py | 8 +- ...2_video_dense_1x1x8_50e_kinetics400_rgb.py | 130 ++++++++++++++++++ 3 files changed, 135 insertions(+), 5 deletions(-) create mode 100644 configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py index d3cc39b1a0..5bdb7cc35d 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py @@ -105,7 +105,7 @@ type='SGD', constructor='TSMOptimizerConstructor', paramwise_cfg=dict(fc_lr5=True), - lr=0.02, # this lr is used for 8 gpus + lr=0.0025, # this lr is used for 4 gpus momentum=0.9, weight_decay=0.0001) optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index fddc2e2073..5b96096613 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -105,13 +105,13 @@ type='SGD', constructor='TSMOptimizerConstructor', paramwise_cfg=dict(fc_lr5=True), - lr=0.02, # this lr is used for 8 gpus + lr=0.0025, # this lr is used for 4 gpus momentum=0.9, weight_decay=0.0001) optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) # learning policy -lr_config = dict(policy='step', step=[20, 40]) -total_epochs = 50 +lr_config = dict(policy='step', step=[40, 80]) +total_epochs = 100 checkpoint_config = dict(interval=5) evaluation = dict( interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) @@ -124,7 +124,7 @@ # runtime settings dist_params = dict(backend='nccl') log_level = 'INFO' -work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_50e_kinetics400_rgb/' +work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_100e_kinetics400_rgb/' # noqa load_from = None resume_from = None workflow = [('train', 1)] diff --git 
a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py new file mode 100644 index 0000000000..8500568fbb --- /dev/null +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py @@ -0,0 +1,130 @@ +# model settings +model = dict( + type='Recognizer2D', + backbone=dict( + type='MobileNetV2TSM', + shift_div=8, + num_segments=8, + is_shift=True, + pretrained=True), + cls_head=dict( + type='TSMHead', + num_segments=8, + num_classes=400, + in_channels=1280, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.5, + init_std=0.001, + is_shift=True)) +# model training and testing settings +train_cfg = None +test_cfg = dict(average_clips='prob') +# dataset settings +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) +train_pipeline = [ + dict(type='DecordInit'), + dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1, + num_fixed_crops=13), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs', 'label']) +] +val_pipeline = [ + dict(type='DecordInit'), + dict( + type='DenseSampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='DenseSampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 224)), + dict(type='ThreeCrop', crop_size=224), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +data = dict( + videos_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=data_root, + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=data_root_val, + pipeline=val_pipeline), + test=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=data_root_val, + pipeline=test_pipeline)) +# optimizer +optimizer = dict( + type='SGD', + constructor='TSMOptimizerConstructor', + paramwise_cfg=dict(fc_lr5=True), + lr=0.0025, # this lr is used for 4 gpus + momentum=0.9, + weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=20, 
norm_type=2)) +# learning policy +lr_config = dict(policy='step', step=[20, 40]) +total_epochs = 50 +checkpoint_config = dict(interval=5) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) +log_config = dict( + interval=20, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook'), + ]) +# runtime settings +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_50e_kinetics400_rgb/' +load_from = None +resume_from = None +workflow = [('train', 1)] From c4b29018c7fe14a169bb75d48ebd13add5c05180 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 17 Dec 2020 17:42:42 +0800 Subject: [PATCH 09/21] refactor mobilenet_v2 with mmclassification --- ...lenetv2_video_1x1x8_50e_kinetics400_rgb.py | 2 +- ..._video_dense_1x1x8_100e_kinetics400_rgb.py | 2 +- ...2_video_dense_1x1x8_50e_kinetics400_rgb.py | 4 +- mmaction/models/backbones/__init__.py | 4 +- mmaction/models/backbones/mobilenet_v2.py | 298 ++++++++++++++++++ ...mobilenetv2_tsm.py => mobilenet_v2_tsm.py} | 4 +- mmaction/models/backbones/mobilenetv2.py | 199 ------------ tests/test_models/test_backbone.py | 222 +++++++++++-- 8 files changed, 502 insertions(+), 233 deletions(-) create mode 100644 mmaction/models/backbones/mobilenet_v2.py rename mmaction/models/backbones/{mobilenetv2_tsm.py => mobilenet_v2_tsm.py} (91%) delete mode 100644 mmaction/models/backbones/mobilenetv2.py diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py index 5bdb7cc35d..aa3d82e7cb 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py @@ -6,7 +6,7 @@ shift_div=8, num_segments=8, is_shift=True, - pretrained=True), + pretrained='mmcls://mobilenet_v2'), cls_head=dict( type='TSMHead', num_segments=8, diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index 5b96096613..97554b7b40 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -6,7 +6,7 @@ shift_div=8, num_segments=8, is_shift=True, - pretrained=True), + pretrained='mmcls://mobilenet_v2'), cls_head=dict( type='TSMHead', num_segments=8, diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py index 8500568fbb..bd87ec86da 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py @@ -6,7 +6,7 @@ shift_div=8, num_segments=8, is_shift=True, - pretrained=True), + pretrained='mmcls://mobilenet_v2'), cls_head=dict( type='TSMHead', num_segments=8, @@ -83,7 +83,7 @@ dict(type='ToTensor', keys=['imgs']) ] data = dict( - videos_per_gpu=4, + videos_per_gpu=8, workers_per_gpu=4, train=dict( type=dataset_type, diff --git a/mmaction/models/backbones/__init__.py b/mmaction/models/backbones/__init__.py index 70a77bd036..55151b1112 100644 --- a/mmaction/models/backbones/__init__.py +++ b/mmaction/models/backbones/__init__.py @@ -1,6 +1,6 @@ from .c3d import C3D -from .mobilenetv2 import MobileNetV2 
-from .mobilenetv2_tsm import MobileNetV2TSM
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v2_tsm import MobileNetV2TSM
 from .resnet import ResNet
 from .resnet2plus1d import ResNet2Plus1d
 from .resnet3d import ResNet3d
diff --git a/mmaction/models/backbones/mobilenet_v2.py b/mmaction/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000..e63110b650
--- /dev/null
+++ b/mmaction/models/backbones/mobilenet_v2.py
@@ -0,0 +1,298 @@
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+
+
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+    """Make divisible function.
+
+    This function rounds the channel number to the nearest value that is
+    divisible by the divisor.
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int, optional): The minimum value of the output channel.
+            Default: None, which means the minimum value equals the divisor.
+        min_ratio (float, optional): The minimum ratio of the rounded channel
+            number to the original channel number. Default: 0.9.
+    Returns:
+        int: The modified output channel number.
+    """
+
+    if min_value is None:
+        min_value = divisor
+    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than (1-min_ratio).
+    if new_value < min_ratio * value:
+        new_value += divisor
+    return new_value
+
+
+class InvertedResidual(nn.Module):
+    """InvertedResidual block for MobileNetV2.
+
+    Args:
+        in_channels (int): The input channels of the InvertedResidual block.
+        out_channels (int): The output channels of the InvertedResidual block.
+        stride (int): Stride of the middle (first) 3x3 convolution.
+        expand_ratio (int): Adjusts the number of channels of the hidden
+            layer in InvertedResidual by this amount.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 expand_ratio,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+            f'But received {stride}.'
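+        # The layers assembled below follow the MobileNetV2
+        # inverted-residual layout: an optional 1x1 expansion ConvModule
+        # (added only when expand_ratio != 1), a 3x3 depthwise ConvModule
+        # (groups=hidden_dim) that carries the stride, and a linear 1x1
+        # projection ConvModule (act_cfg=None). The identity shortcut is
+        # used only when stride == 1 and in_channels == out_channels, e.g.
+        # in_channels=32 with expand_ratio=6 expands to
+        # hidden_dim = int(round(32 * 6)) = 192 hidden channels.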
+ self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=1, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@BACKBONES.register_module() +class MobileNetV2(nn.Module): + """MobileNetV2 backbone. + + Args: + pretrained (str | None): Name of pretrained model. Default: None. + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + # Parameters to build layers. 4 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks, stride. + arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], + [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], + [6, 320, 1, 1]] + + def __init__(self, + pretrained=None, + widen_factor=1., + out_indices=(7, ), + frozen_stages=-1, + conv_cfg=dict(type='Conv'), + norm_cfg=dict(type='BN2d', requires_grad=True), + act_cfg=dict(type='ReLU6', inplace=True), + norm_eval=False, + with_cp=False): + super(MobileNetV2, self).__init__() + self.pretrained = pretrained + self.widen_factor = widen_factor + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 8): + raise ValueError('the item in out_indices must in ' + f'range(0, 8). But received {index}') + + if frozen_stages not in range(-1, 8): + raise ValueError('frozen_stages must be in range(-1, 8). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks, stride = layer_cfg + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + if widen_factor > 1.0: + self.out_channel = int(1280 * widen_factor) + else: + self.out_channel = 1280 + + layer = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channel, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.add_module('conv2', layer) + self.layers.append('conv2') + + def make_layer(self, out_channels, num_blocks, stride, expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. Default: 6. + """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def init_weights(self): + if isinstance(self.pretrained, str): + logger = logging.getLogger() + load_checkpoint(self, self.pretrained, strict=False, logger=logger) + elif self.pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmaction/models/backbones/mobilenetv2_tsm.py b/mmaction/models/backbones/mobilenet_v2_tsm.py similarity index 91% rename from mmaction/models/backbones/mobilenetv2_tsm.py rename to mmaction/models/backbones/mobilenet_v2_tsm.py index c22c3b7e4b..dd37bdbb7c 100644 --- 
a/mmaction/models/backbones/mobilenetv2_tsm.py +++ b/mmaction/models/backbones/mobilenet_v2_tsm.py @@ -1,5 +1,5 @@ from ..registry import BACKBONES -from .mobilenetv2 import InvertedResidual, MobileNetV2 +from .mobilenet_v2 import InvertedResidual, MobileNetV2 from .resnet_tsm import TemporalShift @@ -25,7 +25,7 @@ def make_temporal_shift(self): """Make temporal shift for some layers.""" for m in self.modules(): if isinstance(m, InvertedResidual) and \ - len(m.conv) == 8 and m.use_res_connect: + len(m.conv) == 3 and m.use_res_connect: m.conv[0] = TemporalShift( m.conv[0], num_segments=self.num_segments, diff --git a/mmaction/models/backbones/mobilenetv2.py b/mmaction/models/backbones/mobilenetv2.py deleted file mode 100644 index 319e646647..0000000000 --- a/mmaction/models/backbones/mobilenetv2.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch.nn as nn -from mmcv.cnn import constant_init, kaiming_init - - -def conv_bn(inp, oup, stride): - return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), - nn.ReLU6(inplace=True)) - - -def conv_1x1_bn(inp, oup): - return nn.Sequential( - nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), - nn.ReLU6(inplace=True)) - - -def make_divisible(v, divisor, min_value=None): - """This function is taken from the original tf repo. - - It ensures that all layers have a channel number that is divisible by 8. - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa - Args: - v (float): original number of channels. - divisor (float): Round the number of channels in each layer to - be a multiple of this number. Set to 1 to turn off rounding. - min_value (int): minimal value to return - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class InvertedResidual(nn.Module): - - def __init__(self, inp, oup, stride, expand_ratio): - """Inverted Residual Mobule from MobilNetV2. - - Args: - inp (int): number of input channels. - oup (int): number of output channels. - stride (int): stride for depthwise convolution. - expand_ratio (int): expand ratio for hidden layers. - """ - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] - - hidden_dim = int(inp * expand_ratio) - self.use_res_connect = self.stride == 1 and inp == oup - - if expand_ratio == 1: - self.conv = nn.Sequential( - # dw - nn.Conv2d( - hidden_dim, - hidden_dim, - 3, - stride, - 1, - groups=hidden_dim, - bias=False), - nn.BatchNorm2d(hidden_dim), - nn.ReLU6(inplace=True), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - ) - else: - self.conv = nn.Sequential( - # pw - nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), - nn.BatchNorm2d(hidden_dim), - nn.ReLU6(inplace=True), - # dw - nn.Conv2d( - hidden_dim, - hidden_dim, - 3, - stride, - 1, - groups=hidden_dim, - bias=False), - nn.BatchNorm2d(hidden_dim), - nn.ReLU6(inplace=True), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - ) - - def forward(self, x): - if self.use_res_connect: - return x + self.conv(x) - else: - return self.conv(x) - - -class MobileNetV2(nn.Module): - - def __init__(self, - width_mult=1., - inverted_residual_setting=None, - round_nearest=8, - block=None, - pretrained=False): - """MobileNet V2 main class. 
- - Args: - width_mult (float): Width multiplier - adjusts number of channels - in each layer by this amount. - inverted_residual_setting (list): Network structure. - round_nearest (int): Round the number of channels in each layer to - be a multiple of this number. Set to 1 to turn off rounding. - block (nn.Module): Module specifying inverted residual building - block for mobilenet. - pretrained (bool): whether to load pretrained checkpoints. - """ - super(MobileNetV2, self).__init__() - self.pretrained = pretrained - - if abs(width_mult - 1.0) > 1e-5 and pretrained: - raise ValueError('MobileNetV2 only supports one pretrained model ' - 'with `width_mult=1.0`.') - - if block is None: - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - - if inverted_residual_setting is None: - inverted_residual_setting = [ - # t/expand_ratio, c/output_channels, n/num_of_blocks, s/stride - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # only check the first element - # assuming user knows t,c,n,s are required - if len(inverted_residual_setting) == 0 or len( - inverted_residual_setting[0]) != 4: - raise ValueError('inverted_residual_setting should be non-empty ' - 'or a 4-element list, got {}'.format( - inverted_residual_setting)) - - input_channel = make_divisible(input_channel * width_mult, - round_nearest) - self.last_channel = make_divisible(last_channel * max(1.0, width_mult), - round_nearest) - - # first layer - self.features = [conv_bn(3, input_channel, 2)] - - # building inverted residual blocks - for t, c, n, s in inverted_residual_setting: - output_channel = make_divisible(c * width_mult, round_nearest) - - for i in range(n): - stride = s if i == 0 else 1 - self.features.append( - block( - input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - - # building last several layers - self.features.append(conv_1x1_bn(input_channel, self.last_channel)) - - # make it nn.Sequential - self.features = nn.Sequential(*self.features) - - def forward(self, x): - return self.features(x) - - def init_weights(self): - """Initiate the parameters either from existing checkpoint or from - scratch.""" - - if self.pretrained: - from torch.hub import load_state_dict_from_url - state_dict = load_state_dict_from_url( - 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', # noqa - progress=True) - del state_dict['classifier.weight'] - del state_dict['classifier.bias'] - self.load_state_dict(state_dict) - else: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m) - elif isinstance(m, nn.BatchNorm2d): - constant_init(m, 1) diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py index 5450f16969..d2892accb5 100644 --- a/tests/test_models/test_backbone.py +++ b/tests/test_models/test_backbone.py @@ -129,37 +129,207 @@ def test_resnet_backbone(): def test_mobilenetv2_backbone(): - """Test MobileNetV2 backbone.""" - with pytest.raises(ValueError): - # MobileNetV2 only supports one pretrained model - MobileNetV2(width_mult=0.5, pretrained=True) + """Test MobileNetV2. + + Modified from mmclassification. 
+ """ + from torch.nn.modules import GroupNorm + from mmaction.models.backbones.mobilenet_v2 import InvertedResidual + + def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (GroupNorm, _BatchNorm)): + return True + return False + + def is_block(modules): + """Check if is ResNet building block.""" + if isinstance(modules, (InvertedResidual, )): + return True + return False + + with pytest.raises(TypeError): + # pretrained must be a string path + model = MobileNetV2() + model.init_weights(pretrained=0) with pytest.raises(ValueError): - # In MobileNetV2, inverted_residual_setting must be None or a list. - # The input list should at least have one element. Each Element should - # be a list with exact 4 ints. - MobileNetV2(inverted_residual_setting=[]) + # frozen_stages must in range(1, 8) + MobileNetV2(frozen_stages=8) - input_shape = (1, 3, 64, 64) - imgs = _demo_inputs(input_shape) + with pytest.raises(ValueError): + # tout_indices in range(-1, 8) + MobileNetV2(out_indices=[8]) - # mobilenetv2 with width_mult = 1.0, pretrained - mobilenetv2 = MobileNetV2(pretrained=True) - mobilenetv2.init_weights() - feat = mobilenetv2(imgs) - assert feat.shape == torch.Size([1, 1280, 2, 2]) + # Test MobileNetV2 with first stage frozen + frozen_stages = 1 + model = MobileNetV2(frozen_stages=frozen_stages) + model.init_weights() + model.train() - # mobilenetv2 with width_mult = 0.5 - mobilenetv2 = MobileNetV2(width_mult=0.5) - mobilenetv2.init_weights() - feat = mobilenetv2(imgs) - assert feat.shape == torch.Size([1, 1280, 2, 2]) + for mod in model.conv1.modules(): + for param in mod.parameters(): + assert param.requires_grad is False + for i in range(1, frozen_stages + 1): + layer = getattr(model, f'layer{i}') + for mod in layer.modules(): + if isinstance(mod, _BatchNorm): + assert mod.training is False + for param in layer.parameters(): + assert param.requires_grad is False - # mobilenetv2 with width_mult = 1.5 - mobilenetv2 = MobileNetV2(width_mult=1.5) - mobilenetv2.init_weights() - feat = mobilenetv2(imgs) - assert feat.shape == torch.Size([1, 1920, 2, 2]) + # Test MobileNetV2 with norm_eval=True + model = MobileNetV2(norm_eval=True) + model.init_weights() + model.train() + + assert check_norm_state(model.modules(), False) + + # Test MobileNetV2 forward with widen_factor=1.0, pretrained + model = MobileNetV2( + widen_factor=1.0, + out_indices=range(0, 8), + pretrained='mmcls://mobilenet_v2') + model.init_weights() + model.train() + + assert check_norm_state(model.modules(), True) + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 8 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + assert feat[7].shape == torch.Size((1, 1280, 7, 7)) + + # Test MobileNetV2 forward with widen_factor=0.5 + model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7)) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 8, 112, 112)) + assert feat[1].shape == torch.Size((1, 16, 56, 56)) + assert feat[2].shape == torch.Size((1, 16, 28, 28)) + assert feat[3].shape == torch.Size((1, 32, 14, 14)) + assert feat[4].shape == 
torch.Size((1, 48, 14, 14)) + assert feat[5].shape == torch.Size((1, 80, 7, 7)) + assert feat[6].shape == torch.Size((1, 160, 7, 7)) + + # Test MobileNetV2 forward with widen_factor=2.0 + model = MobileNetV2(widen_factor=2.0) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert feat.shape == torch.Size((1, 2560, 7, 7)) + + # Test MobileNetV2 forward with out_indices=None + model = MobileNetV2(widen_factor=1.0) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert feat.shape == torch.Size((1, 1280, 7, 7)) + + # Test MobileNetV2 forward with dict(type='ReLU') + model = MobileNetV2( + widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7)) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + + # Test MobileNetV2 with GroupNorm forward + model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7)) + for m in model.modules(): + if is_norm(m): + assert isinstance(m, _BatchNorm) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + + # Test MobileNetV2 with BatchNorm forward + model = MobileNetV2( + widen_factor=1.0, + norm_cfg=dict(type='GN', num_groups=2, requires_grad=True), + out_indices=range(0, 7)) + for m in model.modules(): + if is_norm(m): + assert isinstance(m, GroupNorm) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + + # Test MobileNetV2 with layers 1, 3, 5 out forward + model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4)) + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 3 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 32, 28, 28)) + assert feat[2].shape == torch.Size((1, 96, 14, 14)) + + # Test MobileNetV2 with checkpoint forward + model = MobileNetV2( + widen_factor=1.0, with_cp=True, out_indices=range(0, 7)) + for m in model.modules(): + if is_block(m): + assert m.with_cp + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + 
assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) def test_x3d_backbone(): @@ -764,7 +934,7 @@ def test_resnet_tsm_backbone(): def test_mobilenetv2_tsm_backbone(): """Test mobilenetv2_tsm backbone.""" from mmaction.models.backbones.resnet_tsm import TemporalShift - from mmaction.models.backbones.mobilenetv2 import InvertedResidual + from mmaction.models.backbones.mobilenet_v2 import InvertedResidual input_shape = (8, 3, 64, 64) imgs = _demo_inputs(input_shape) From 385aa479340f481cb4e6d3846ddf1a2ae74ddefc Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 17 Dec 2020 17:48:41 +0800 Subject: [PATCH 10/21] add changelog --- docs/changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.md b/docs/changelog.md index 7445e93b34..f59414645f 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,6 +5,7 @@ **Highlights** **New Features** +- Support TSM-MobileNetV2 **Improvements** - Add FAQ documents for easy troubleshooting. ([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439)) From 3952cbf8a9dfa9dd7eccbe7f0b97f1004f6afc7b Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 17 Dec 2020 17:54:23 +0800 Subject: [PATCH 11/21] update changelog --- docs/changelog.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index f59414645f..b84500ce86 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,7 +5,7 @@ **Highlights** **New Features** -- Support TSM-MobileNetV2 +- Support TSM-MobileNetV2. ([#415](https://github.com/open-mmlab/mmaction2/pull/415)) **Improvements** - Add FAQ documents for easy troubleshooting. 
([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439))

From 3a91ed67a1718d9a8220d9113df5980eaf8521dd Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 17 Dec 2020 18:19:29 +0800
Subject: [PATCH 12/21] fix unittest

---
 tests/test_models/test_backbone.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py
index d2892accb5..5859208f17 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -944,7 +944,7 @@ def test_mobilenetv2_tsm_backbone():
     mobilenetv2_tsm.init_weights()
     for cur_module in mobilenetv2_tsm.modules():
         if isinstance(cur_module, InvertedResidual) and \
-                len(cur_module.conv) == 8 and \
+                len(cur_module.conv) == 3 and \
                 cur_module.use_res_connect:
             assert isinstance(cur_module.conv[0], TemporalShift)
             assert cur_module.conv[0].num_segments == \

From 4b4c8ac0c3bb444d273c09c3fa2aab0287406c3c Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 17 Dec 2020 20:01:53 +0800
Subject: [PATCH 13/21] fix unittest

---
 tests/test_models/test_backbone.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py
index 5859208f17..a45df47802 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -935,6 +935,7 @@
 def test_mobilenetv2_tsm_backbone():
     """Test mobilenetv2_tsm backbone."""
     from mmaction.models.backbones.resnet_tsm import TemporalShift
     from mmaction.models.backbones.mobilenet_v2 import InvertedResidual
+    from mmcv.cnn import ConvModule

     input_shape = (8, 3, 64, 64)
     imgs = _demo_inputs(input_shape)
@@ -950,20 +951,20 @@
         assert cur_module.conv[0].num_segments == \
             mobilenetv2_tsm.num_segments
         assert cur_module.conv[0].shift_div == mobilenetv2_tsm.shift_div
-        assert isinstance(cur_module.conv[0].net, nn.Conv2d)
+        assert isinstance(cur_module.conv[0].net, ConvModule)

-    # TSM-MobileNetV2 with width_mult = 1.0 forword
+    # TSM-MobileNetV2 with widen_factor = 1.0 forward
     feat = mobilenetv2_tsm(imgs)
     assert feat.shape == torch.Size([8, 1280, 2, 2])

-    # mobilenetv2 with width_mult = 0.5 forword
-    mobilenetv2_tsm_05 = MobileNetV2TSM(width_mult=0.5)
+    # mobilenetv2 with widen_factor = 0.5 forward
+    mobilenetv2_tsm_05 = MobileNetV2TSM(widen_factor=0.5)
     mobilenetv2_tsm_05.init_weights()
     feat = mobilenetv2_tsm_05(imgs)
     assert feat.shape == torch.Size([8, 1280, 2, 2])

-    # mobilenetv2 with width_mult = 1.5 forword
-    mobilenetv2_tsm_15 = MobileNetV2TSM(width_mult=1.5)
+    # mobilenetv2 with widen_factor = 1.5 forward
+    mobilenetv2_tsm_15 = MobileNetV2TSM(widen_factor=1.5)
     mobilenetv2_tsm_15.init_weights()
     feat = mobilenetv2_tsm_15(imgs)
     assert feat.shape == torch.Size([8, 1920, 2, 2])

From a37daa168c31f7273735d40cf77334a29bbe3d87 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 17 Dec 2020 20:42:25 +0800
Subject: [PATCH 14/21] improve Codecov

---
 tests/test_models/test_backbone.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_models/test_backbone.py b/tests/test_models/test_backbone.py
index a45df47802..a085a05b62 100644
--- a/tests/test_models/test_backbone.py
+++ b/tests/test_models/test_backbone.py
@@ -150,8 +150,8 @@ def test_mobilenetv2_backbone():

     with pytest.raises(TypeError):
         # pretrained must be a string path
-        model = MobileNetV2()
-        
model.init_weights(pretrained=0) + model = MobileNetV2(pretrained=0) + model.init_weights() with pytest.raises(ValueError): # frozen_stages must in range(1, 8) From f36412456767001ab19949bb02be04322ded5540 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Wed, 30 Dec 2020 17:20:25 +0800 Subject: [PATCH 15/21] update default training configs --- .../tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py | 4 ++-- .../tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py | 4 ++-- .../tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py index aa3d82e7cb..18a8aa540c 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py @@ -105,9 +105,9 @@ type='SGD', constructor='TSMOptimizerConstructor', paramwise_cfg=dict(fc_lr5=True), - lr=0.0025, # this lr is used for 4 gpus + lr=0.01, # this lr is used for 8 gpus momentum=0.9, - weight_decay=0.0001) + weight_decay=0.00004) optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) # learning policy lr_config = dict(policy='step', step=[20, 40]) diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index 97554b7b40..c2e24fb7a7 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -105,9 +105,9 @@ type='SGD', constructor='TSMOptimizerConstructor', paramwise_cfg=dict(fc_lr5=True), - lr=0.0025, # this lr is used for 4 gpus + lr=0.01, # this lr is used for 8 gpus momentum=0.9, - weight_decay=0.0001) + weight_decay=0.00004) optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) # learning policy lr_config = dict(policy='step', step=[40, 80]) diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py index bd87ec86da..bce92c4f2a 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py @@ -105,9 +105,9 @@ type='SGD', constructor='TSMOptimizerConstructor', paramwise_cfg=dict(fc_lr5=True), - lr=0.0025, # this lr is used for 4 gpus + lr=0.01, # this lr is used for 8 gpus momentum=0.9, - weight_decay=0.0001) + weight_decay=0.00004) optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) # learning policy lr_config = dict(policy='step', step=[20, 40]) From 3cb20e780991c066771848d8218e788ffefba270 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 14 Jan 2021 09:08:04 +0800 Subject: [PATCH 16/21] add inference config --- ...erence_dense_1x1x8_100e_kinetics400_rgb.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py new file mode 100644 index 0000000000..6b06c97074 --- /dev/null +++ 
b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py @@ -0,0 +1,68 @@ +# model settings +model = dict( + type='Recognizer2D', + backbone=dict( + type='MobileNetV2TSM', + shift_div=8, + num_segments=8, + is_shift=True, + pretrained='mmcls://mobilenet_v2'), + cls_head=dict( + type='TSMHead', + num_segments=8, + num_classes=400, + in_channels=1280, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.5, + init_std=0.001, + is_shift=True)) +# model training and testing settings +train_cfg = None +test_cfg = dict(average_clips='prob') +# dataset settings +dataset_type = 'VideoDataset' +data_root_val = 'data/kinetics400/videos_val' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='DenseSampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='ThreeCrop', crop_size=256), + dict(type='Flip', flip_ratio=0), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +data = dict( + videos_per_gpu=4, + workers_per_gpu=4, + test=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=data_root_val, + pipeline=test_pipeline)) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) +log_config = dict( + interval=20, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook'), + ]) +# runtime settings +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_100e_kinetics400_rgb/' # noqa +load_from = None +resume_from = None +workflow = [('train', 1)] From 577ccf17f407c425d386db65ab01b93eb3e9d82d Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 28 Jan 2021 14:40:19 +0800 Subject: [PATCH 17/21] fix --- docs/changelog.md | 2 +- mmaction/models/backbones/mobilenet_v2.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 58fad064e2..2dd151f7a9 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -7,6 +7,7 @@ **New Features** - Support [imgaug](https://imgaug.readthedocs.io/en/latest/index.html) for augmentations in the data pipeline ([#492](https://github.com/open-mmlab/mmaction2/pull/492)) +- Support TSM-MobileNetV2. ([#415](https://github.com/open-mmlab/mmaction2/pull/415)) **Improvements** @@ -32,7 +33,6 @@ - Support precise BN ([#501](https://github.com/open-mmlab/mmaction2/pull/501/)) - Support Spatio-Temporal Action Detection (AVA) ([#351](https://github.com/open-mmlab/mmaction2/pull/351)) - Support to return feature maps in `inference_recognizer` ([#458](https://github.com/open-mmlab/mmaction2/pull/458)) -- Support TSM-MobileNetV2. 
([#415](https://github.com/open-mmlab/mmaction2/pull/415)) **Improvements** diff --git a/mmaction/models/backbones/mobilenet_v2.py b/mmaction/models/backbones/mobilenet_v2.py index e63110b650..5a093fa1fa 100644 --- a/mmaction/models/backbones/mobilenet_v2.py +++ b/mmaction/models/backbones/mobilenet_v2.py @@ -1,11 +1,10 @@ -import logging - import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import ConvModule, constant_init, kaiming_init from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm +from ...utils import get_root_logger from ..builder import BACKBONES @@ -160,7 +159,7 @@ def __init__(self, act_cfg=dict(type='ReLU6', inplace=True), norm_eval=False, with_cp=False): - super(MobileNetV2, self).__init__() + super().__init__() self.pretrained = pretrained self.widen_factor = widen_factor self.out_indices = out_indices @@ -253,7 +252,7 @@ def make_layer(self, out_channels, num_blocks, stride, expand_ratio): def init_weights(self): if isinstance(self.pretrained, str): - logger = logging.getLogger() + logger = get_root_logger() load_checkpoint(self, self.pretrained, strict=False, logger=logger) elif self.pretrained is None: for m in self.modules(): From 3a1568f044f1cfe6fa9d923fa4b91bb698220c3a Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Thu, 28 Jan 2021 14:50:38 +0800 Subject: [PATCH 18/21] refactor unittest --- tests/test_models/test_backbones.py | 211 +----------------- .../test_common_modules/test_mobilenet_v2.py | 204 +++++++++++++++++ 2 files changed, 207 insertions(+), 208 deletions(-) create mode 100644 tests/test_models/test_common_modules/test_mobilenet_v2.py diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index a5209cf6e2..30959d2b32 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -5,218 +5,13 @@ import torch.nn as nn from mmcv.utils import _BatchNorm -from mmaction.models import (C3D, X3D, MobileNetV2, MobileNetV2TSM, - ResNet2Plus1d, ResNet3dCSN, ResNet3dSlowFast, - ResNet3dSlowOnly, ResNetAudio, ResNetTIN, - ResNetTSM) +from mmaction.models import (C3D, X3D, MobileNetV2TSM, ResNet2Plus1d, + ResNet3dCSN, ResNet3dSlowFast, ResNet3dSlowOnly, + ResNetAudio, ResNetTIN, ResNetTSM) from mmaction.models.backbones.resnet_tsm import NL3DWrapper from .base import check_norm_state, generate_backbone_demo_inputs -def test_mobilenetv2_backbone(): - """Test MobileNetV2. - - Modified from mmclassification. 
- """ - from torch.nn.modules import GroupNorm - from mmaction.models.backbones.mobilenet_v2 import InvertedResidual - - def is_norm(modules): - """Check if is one of the norms.""" - if isinstance(modules, (GroupNorm, _BatchNorm)): - return True - return False - - def is_block(modules): - """Check if is ResNet building block.""" - if isinstance(modules, (InvertedResidual, )): - return True - return False - - with pytest.raises(TypeError): - # pretrained must be a string path - model = MobileNetV2(pretrained=0) - model.init_weights() - - with pytest.raises(ValueError): - # frozen_stages must in range(1, 8) - MobileNetV2(frozen_stages=8) - - with pytest.raises(ValueError): - # tout_indices in range(-1, 8) - MobileNetV2(out_indices=[8]) - - # Test MobileNetV2 with first stage frozen - frozen_stages = 1 - model = MobileNetV2(frozen_stages=frozen_stages) - model.init_weights() - model.train() - - for mod in model.conv1.modules(): - for param in mod.parameters(): - assert param.requires_grad is False - for i in range(1, frozen_stages + 1): - layer = getattr(model, f'layer{i}') - for mod in layer.modules(): - if isinstance(mod, _BatchNorm): - assert mod.training is False - for param in layer.parameters(): - assert param.requires_grad is False - - # Test MobileNetV2 with norm_eval=True - model = MobileNetV2(norm_eval=True) - model.init_weights() - model.train() - - assert check_norm_state(model.modules(), False) - - # Test MobileNetV2 forward with widen_factor=1.0, pretrained - model = MobileNetV2( - widen_factor=1.0, - out_indices=range(0, 8), - pretrained='mmcls://mobilenet_v2') - model.init_weights() - model.train() - - assert check_norm_state(model.modules(), True) - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 8 - assert feat[0].shape == torch.Size((1, 16, 112, 112)) - assert feat[1].shape == torch.Size((1, 24, 56, 56)) - assert feat[2].shape == torch.Size((1, 32, 28, 28)) - assert feat[3].shape == torch.Size((1, 64, 14, 14)) - assert feat[4].shape == torch.Size((1, 96, 14, 14)) - assert feat[5].shape == torch.Size((1, 160, 7, 7)) - assert feat[6].shape == torch.Size((1, 320, 7, 7)) - assert feat[7].shape == torch.Size((1, 1280, 7, 7)) - - # Test MobileNetV2 forward with widen_factor=0.5 - model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7)) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 7 - assert feat[0].shape == torch.Size((1, 8, 112, 112)) - assert feat[1].shape == torch.Size((1, 16, 56, 56)) - assert feat[2].shape == torch.Size((1, 16, 28, 28)) - assert feat[3].shape == torch.Size((1, 32, 14, 14)) - assert feat[4].shape == torch.Size((1, 48, 14, 14)) - assert feat[5].shape == torch.Size((1, 80, 7, 7)) - assert feat[6].shape == torch.Size((1, 160, 7, 7)) - - # Test MobileNetV2 forward with widen_factor=2.0 - model = MobileNetV2(widen_factor=2.0) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert feat.shape == torch.Size((1, 2560, 7, 7)) - - # Test MobileNetV2 forward with out_indices=None - model = MobileNetV2(widen_factor=1.0) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert feat.shape == torch.Size((1, 1280, 7, 7)) - - # Test MobileNetV2 forward with dict(type='ReLU') - model = MobileNetV2( - widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7)) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = 
model(imgs) - assert len(feat) == 7 - assert feat[0].shape == torch.Size((1, 16, 112, 112)) - assert feat[1].shape == torch.Size((1, 24, 56, 56)) - assert feat[2].shape == torch.Size((1, 32, 28, 28)) - assert feat[3].shape == torch.Size((1, 64, 14, 14)) - assert feat[4].shape == torch.Size((1, 96, 14, 14)) - assert feat[5].shape == torch.Size((1, 160, 7, 7)) - assert feat[6].shape == torch.Size((1, 320, 7, 7)) - - # Test MobileNetV2 with GroupNorm forward - model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7)) - for m in model.modules(): - if is_norm(m): - assert isinstance(m, _BatchNorm) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 7 - assert feat[0].shape == torch.Size((1, 16, 112, 112)) - assert feat[1].shape == torch.Size((1, 24, 56, 56)) - assert feat[2].shape == torch.Size((1, 32, 28, 28)) - assert feat[3].shape == torch.Size((1, 64, 14, 14)) - assert feat[4].shape == torch.Size((1, 96, 14, 14)) - assert feat[5].shape == torch.Size((1, 160, 7, 7)) - assert feat[6].shape == torch.Size((1, 320, 7, 7)) - - # Test MobileNetV2 with BatchNorm forward - model = MobileNetV2( - widen_factor=1.0, - norm_cfg=dict(type='GN', num_groups=2, requires_grad=True), - out_indices=range(0, 7)) - for m in model.modules(): - if is_norm(m): - assert isinstance(m, GroupNorm) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 7 - assert feat[0].shape == torch.Size((1, 16, 112, 112)) - assert feat[1].shape == torch.Size((1, 24, 56, 56)) - assert feat[2].shape == torch.Size((1, 32, 28, 28)) - assert feat[3].shape == torch.Size((1, 64, 14, 14)) - assert feat[4].shape == torch.Size((1, 96, 14, 14)) - assert feat[5].shape == torch.Size((1, 160, 7, 7)) - assert feat[6].shape == torch.Size((1, 320, 7, 7)) - - # Test MobileNetV2 with layers 1, 3, 5 out forward - model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4)) - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 3 - assert feat[0].shape == torch.Size((1, 16, 112, 112)) - assert feat[1].shape == torch.Size((1, 32, 28, 28)) - assert feat[2].shape == torch.Size((1, 96, 14, 14)) - - # Test MobileNetV2 with checkpoint forward - model = MobileNetV2( - widen_factor=1.0, with_cp=True, out_indices=range(0, 7)) - for m in model.modules(): - if is_block(m): - assert m.with_cp - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 224, 224) - feat = model(imgs) - assert len(feat) == 7 - assert feat[0].shape == torch.Size((1, 16, 112, 112)) - assert feat[1].shape == torch.Size((1, 24, 56, 56)) - assert feat[2].shape == torch.Size((1, 32, 28, 28)) - assert feat[3].shape == torch.Size((1, 64, 14, 14)) - assert feat[4].shape == torch.Size((1, 96, 14, 14)) - assert feat[5].shape == torch.Size((1, 160, 7, 7)) - assert feat[6].shape == torch.Size((1, 320, 7, 7)) - - def test_x3d_backbone(): """Test resnet3d backbone.""" with pytest.raises(AssertionError): diff --git a/tests/test_models/test_common_modules/test_mobilenet_v2.py b/tests/test_models/test_common_modules/test_mobilenet_v2.py new file mode 100644 index 0000000000..cdb4cd9ec4 --- /dev/null +++ b/tests/test_models/test_common_modules/test_mobilenet_v2.py @@ -0,0 +1,204 @@ +import pytest +import torch +from mmcv.utils import _BatchNorm + +from mmaction.models import MobileNetV2 +from ..base import check_norm_state, generate_backbone_demo_inputs + + +def test_mobilenetv2_backbone(): + 
"""Test MobileNetV2. + + Modified from mmclassification. + """ + from torch.nn.modules import GroupNorm + from mmaction.models.backbones.mobilenet_v2 import InvertedResidual + + def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (GroupNorm, _BatchNorm)): + return True + return False + + def is_block(modules): + """Check if is ResNet building block.""" + if isinstance(modules, (InvertedResidual, )): + return True + return False + + with pytest.raises(TypeError): + # pretrained must be a string path + model = MobileNetV2(pretrained=0) + model.init_weights() + + with pytest.raises(ValueError): + # frozen_stages must in range(1, 8) + MobileNetV2(frozen_stages=8) + + with pytest.raises(ValueError): + # tout_indices in range(-1, 8) + MobileNetV2(out_indices=[8]) + + input_shape = (1, 3, 224, 224) + imgs = generate_backbone_demo_inputs(input_shape) + + # Test MobileNetV2 with first stage frozen + frozen_stages = 1 + model = MobileNetV2(frozen_stages=frozen_stages) + model.init_weights() + model.train() + + for mod in model.conv1.modules(): + for param in mod.parameters(): + assert param.requires_grad is False + for i in range(1, frozen_stages + 1): + layer = getattr(model, f'layer{i}') + for mod in layer.modules(): + if isinstance(mod, _BatchNorm): + assert mod.training is False + for param in layer.parameters(): + assert param.requires_grad is False + + # Test MobileNetV2 with norm_eval=True + model = MobileNetV2(norm_eval=True) + model.init_weights() + model.train() + + assert check_norm_state(model.modules(), False) + + # Test MobileNetV2 forward with widen_factor=1.0, pretrained + model = MobileNetV2( + widen_factor=1.0, + out_indices=range(0, 8), + pretrained='mmcls://mobilenet_v2') + model.init_weights() + model.train() + + assert check_norm_state(model.modules(), True) + + feat = model(imgs) + assert len(feat) == 8 + assert feat[0].shape == torch.Size((1, 16, 112, 112)) + assert feat[1].shape == torch.Size((1, 24, 56, 56)) + assert feat[2].shape == torch.Size((1, 32, 28, 28)) + assert feat[3].shape == torch.Size((1, 64, 14, 14)) + assert feat[4].shape == torch.Size((1, 96, 14, 14)) + assert feat[5].shape == torch.Size((1, 160, 7, 7)) + assert feat[6].shape == torch.Size((1, 320, 7, 7)) + assert feat[7].shape == torch.Size((1, 1280, 7, 7)) + + # Test MobileNetV2 forward with widen_factor=0.5 + model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7)) + model.init_weights() + model.train() + + feat = model(imgs) + assert len(feat) == 7 + assert feat[0].shape == torch.Size((1, 8, 112, 112)) + assert feat[1].shape == torch.Size((1, 16, 56, 56)) + assert feat[2].shape == torch.Size((1, 16, 28, 28)) + assert feat[3].shape == torch.Size((1, 32, 14, 14)) + assert feat[4].shape == torch.Size((1, 48, 14, 14)) + assert feat[5].shape == torch.Size((1, 80, 7, 7)) + assert feat[6].shape == torch.Size((1, 160, 7, 7)) + + # Test MobileNetV2 forward with widen_factor=2.0 + model = MobileNetV2(widen_factor=2.0) + model.init_weights() + model.train() + + feat = model(imgs) + assert feat.shape == torch.Size((1, 2560, 7, 7)) + + # Test MobileNetV2 forward with out_indices=None + model = MobileNetV2(widen_factor=1.0) + model.init_weights() + model.train() + + feat = model(imgs) + assert feat.shape == torch.Size((1, 1280, 7, 7)) + + # Test MobileNetV2 forward with dict(type='ReLU') + model = MobileNetV2( + widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7)) + model.init_weights() + model.train() + + feat = model(imgs) + assert len(feat) == 7 + assert 
feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with BatchNorm forward
+    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with GroupNorm forward
+    model = MobileNetV2(
+        widen_factor=1.0,
+        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
+        out_indices=range(0, 7))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, GroupNorm)
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with layers 1, 3, 5 out forward
+    model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 32, 28, 28))
+    assert feat[2].shape == torch.Size((1, 96, 14, 14))
+
+    # Test MobileNetV2 with checkpoint forward
+    model = MobileNetV2(
+        widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))

From 3f1f241ee75d3301d6b3da039c819019ce27b8d3 Mon Sep 17 00:00:00 2001
From: irvingzhang0512
Date: Thu, 28 Jan 2021 15:49:10 +0800
Subject: [PATCH 19/21] refactor config

---
 configs/_base_/models/tsm_mobilenet_v2.py     | 22 +++++++
 .../schedules/sgd_tsm_mobilenet_v2_100e.py    | 12 ++++
 .../schedules/sgd_tsm_mobilenet_v2_50e.py     | 12 ++++
 ...lenetv2_video_1x1x8_50e_kinetics400_rgb.py | 57 ++++---------------
 ..._video_dense_1x1x8_100e_kinetics400_rgb.py | 57 ++++---------------
 ...2_video_dense_1x1x8_50e_kinetics400_rgb.py | 57 ++++---------------
 ...erence_dense_1x1x8_100e_kinetics400_rgb.py | 38 ++-----------
 7 files changed, 86 insertions(+), 169 deletions(-)
 create mode 100644 configs/_base_/models/tsm_mobilenet_v2.py
 create mode 100644 
configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py create mode 100644 configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py diff --git a/configs/_base_/models/tsm_mobilenet_v2.py b/configs/_base_/models/tsm_mobilenet_v2.py new file mode 100644 index 0000000000..0338e19e27 --- /dev/null +++ b/configs/_base_/models/tsm_mobilenet_v2.py @@ -0,0 +1,22 @@ +# model settings +model = dict( + type='Recognizer2D', + backbone=dict( + type='MobileNetV2TSM', + shift_div=8, + num_segments=8, + is_shift=True, + pretrained='mmcls://mobilenet_v2'), + cls_head=dict( + type='TSMHead', + num_segments=8, + num_classes=400, + in_channels=1280, + spatial_type='avg', + consensus=dict(type='AvgConsensus', dim=1), + dropout_ratio=0.5, + init_std=0.001, + is_shift=True)) +# model training and testing settings +train_cfg = None +test_cfg = dict(average_clips='prob') diff --git a/configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py new file mode 100644 index 0000000000..63ed3f275a --- /dev/null +++ b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_100e.py @@ -0,0 +1,12 @@ +# optimizer +optimizer = dict( + type='SGD', + constructor='TSMOptimizerConstructor', + paramwise_cfg=dict(fc_lr5=True), + lr=0.01, # this lr is used for 8 gpus + momentum=0.9, + weight_decay=0.00002) +optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) +# learning policy +lr_config = dict(policy='step', step=[40, 80]) +total_epochs = 100 diff --git a/configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py new file mode 100644 index 0000000000..78612def95 --- /dev/null +++ b/configs/_base_/schedules/sgd_tsm_mobilenet_v2_50e.py @@ -0,0 +1,12 @@ +# optimizer +optimizer = dict( + type='SGD', + constructor='TSMOptimizerConstructor', + paramwise_cfg=dict(fc_lr5=True), + lr=0.01, # this lr is used for 8 gpus + momentum=0.9, + weight_decay=0.00002) +optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) +# learning policy +lr_config = dict(policy='step', step=[20, 40]) +total_epochs = 50 diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py index 18a8aa540c..0b00fec0be 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py @@ -1,25 +1,9 @@ -# model settings -model = dict( - type='Recognizer2D', - backbone=dict( - type='MobileNetV2TSM', - shift_div=8, - num_segments=8, - is_shift=True, - pretrained='mmcls://mobilenet_v2'), - cls_head=dict( - type='TSMHead', - num_segments=8, - num_classes=400, - in_channels=1280, - spatial_type='avg', - consensus=dict(type='AvgConsensus', dim=1), - dropout_ratio=0.5, - init_std=0.001, - is_shift=True)) -# model training and testing settings -train_cfg = None -test_cfg = dict(average_clips='prob') +_base_ = [ + '../../_base_/models/tsm_mobilenet_v2.py', + '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', + '../../_base_/default_runtime.py' +] + # dataset settings dataset_type = 'VideoDataset' data_root = 'data/kinetics400/videos_train' @@ -100,31 +84,14 @@ ann_file=ann_file_test, data_prefix=data_root_val, pipeline=test_pipeline)) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) + # optimizer optimizer = dict( - type='SGD', - constructor='TSMOptimizerConstructor', - paramwise_cfg=dict(fc_lr5=True), lr=0.01, # this lr is used for 8 
gpus - momentum=0.9, - weight_decay=0.00004) -optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) -# learning policy -lr_config = dict(policy='step', step=[20, 40]) -total_epochs = 50 -checkpoint_config = dict(interval=5) -evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) -log_config = dict( - interval=20, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook'), - ]) +) + # runtime settings -dist_params = dict(backend='nccl') -log_level = 'INFO' +checkpoint_config = dict(interval=5) work_dir = './work_dirs/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb/' -load_from = None -resume_from = None -workflow = [('train', 1)] diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index c2e24fb7a7..14c564faf5 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -1,25 +1,9 @@ -# model settings -model = dict( - type='Recognizer2D', - backbone=dict( - type='MobileNetV2TSM', - shift_div=8, - num_segments=8, - is_shift=True, - pretrained='mmcls://mobilenet_v2'), - cls_head=dict( - type='TSMHead', - num_segments=8, - num_classes=400, - in_channels=1280, - spatial_type='avg', - consensus=dict(type='AvgConsensus', dim=1), - dropout_ratio=0.5, - init_std=0.001, - is_shift=True)) -# model training and testing settings -train_cfg = None -test_cfg = dict(average_clips='prob') +_base_ = [ + '../../_base_/models/tsm_mobilenet_v2.py', + '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', + '../../_base_/default_runtime.py' +] + # dataset settings dataset_type = 'VideoDataset' data_root = 'data/kinetics400/videos_train' @@ -100,31 +84,14 @@ ann_file=ann_file_test, data_prefix=data_root_val, pipeline=test_pipeline)) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) + # optimizer optimizer = dict( - type='SGD', - constructor='TSMOptimizerConstructor', - paramwise_cfg=dict(fc_lr5=True), lr=0.01, # this lr is used for 8 gpus - momentum=0.9, - weight_decay=0.00004) -optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) -# learning policy -lr_config = dict(policy='step', step=[40, 80]) -total_epochs = 100 -checkpoint_config = dict(interval=5) -evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) -log_config = dict( - interval=20, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook'), - ]) +) + # runtime settings -dist_params = dict(backend='nccl') -log_level = 'INFO' +checkpoint_config = dict(interval=5) work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_100e_kinetics400_rgb/' # noqa -load_from = None -resume_from = None -workflow = [('train', 1)] diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py index bce92c4f2a..0ff68aa18f 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py @@ -1,25 +1,9 @@ -# model settings -model = dict( - type='Recognizer2D', - backbone=dict( - type='MobileNetV2TSM', - shift_div=8, - num_segments=8, - is_shift=True, - pretrained='mmcls://mobilenet_v2'), - cls_head=dict( - type='TSMHead', - num_segments=8, - 
num_classes=400, - in_channels=1280, - spatial_type='avg', - consensus=dict(type='AvgConsensus', dim=1), - dropout_ratio=0.5, - init_std=0.001, - is_shift=True)) -# model training and testing settings -train_cfg = None -test_cfg = dict(average_clips='prob') +_base_ = [ + '../../_base_/models/tsm_mobilenet_v2.py', + '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', + '../../_base_/default_runtime.py' +] + # dataset settings dataset_type = 'VideoDataset' data_root = 'data/kinetics400/videos_train' @@ -100,31 +84,14 @@ ann_file=ann_file_test, data_prefix=data_root_val, pipeline=test_pipeline)) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) + # optimizer optimizer = dict( - type='SGD', - constructor='TSMOptimizerConstructor', - paramwise_cfg=dict(fc_lr5=True), lr=0.01, # this lr is used for 8 gpus - momentum=0.9, - weight_decay=0.00004) -optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2)) -# learning policy -lr_config = dict(policy='step', step=[20, 40]) -total_epochs = 50 -checkpoint_config = dict(interval=5) -evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) -log_config = dict( - interval=20, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook'), - ]) +) + # runtime settings -dist_params = dict(backend='nccl') -log_level = 'INFO' +checkpoint_config = dict(interval=5) work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_50e_kinetics400_rgb/' -load_from = None -resume_from = None -workflow = [('train', 1)] diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py index 6b06c97074..4fdd3967f6 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py @@ -1,25 +1,5 @@ -# model settings -model = dict( - type='Recognizer2D', - backbone=dict( - type='MobileNetV2TSM', - shift_div=8, - num_segments=8, - is_shift=True, - pretrained='mmcls://mobilenet_v2'), - cls_head=dict( - type='TSMHead', - num_segments=8, - num_classes=400, - in_channels=1280, - spatial_type='avg', - consensus=dict(type='AvgConsensus', dim=1), - dropout_ratio=0.5, - init_std=0.001, - is_shift=True)) -# model training and testing settings -train_cfg = None -test_cfg = dict(average_clips='prob') +_base_ = ['../../_base_/models/tsm_mobilenet_v2.py'] + # dataset settings dataset_type = 'VideoDataset' data_root_val = 'data/kinetics400/videos_val' @@ -43,6 +23,7 @@ dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), dict(type='ToTensor', keys=['imgs']) ] + data = dict( videos_per_gpu=4, workers_per_gpu=4, @@ -51,18 +32,7 @@ ann_file=ann_file_test, data_prefix=data_root_val, pipeline=test_pipeline)) -evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) -log_config = dict( - interval=20, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook'), - ]) + # runtime settings dist_params = dict(backend='nccl') log_level = 'INFO' -work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_100e_kinetics400_rgb/' # noqa -load_from = None -resume_from = None -workflow = [('train', 1)] From ebcb9105e4bac49ccfdec02de05b8f9ab12d710a Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Tue, 2 Feb 2021 23:08:17 +0800 Subject: [PATCH 20/21] update config and model link --- 
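This commit publishes the trained MobileNetV2-TSM checkpoint and adds a rawframe-based dense config built on the _base_ files introduced above. Because the refactored configs now override only `optimizer.lr`, note that mmcv's Config merges dict options key-by-key, so the base schedule's `type`, `constructor`, `paramwise_cfg`, `momentum` and `weight_decay` all survive the override. A minimal sketch of that merge, assuming mmcv's standard `Config.fromfile` API (the config path is the file added by this commit):

    from mmcv import Config

    # _base_ files are loaded first; the child's keys are then merged
    # on top dict-by-dict, not by wholesale replacement.
    cfg = Config.fromfile(
        'configs/recognition/tsm/'
        'tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py')

    # The child sets only lr (restating the 8-GPU value); the merged
    # optimizer still carries the base SGD settings.
    print(cfg.optimizer)
    # per the base schedule file: {'type': 'SGD',
    #   'constructor': 'TSMOptimizerConstructor',
    #   'paramwise_cfg': {'fc_lr5': True}, 'lr': 0.01,
    #   'momentum': 0.9, 'weight_decay': 2e-05}
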
configs/recognition/tsm/README.md | 1 + ...enetv2_dense_1x1x8_100e_kinetics400_rgb.py | 92 +++++++++++++++++++ ..._video_dense_1x1x8_100e_kinetics400_rgb.py | 2 +- 3 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py diff --git a/configs/recognition/tsm/README.md b/configs/recognition/tsm/README.md index 4ccc570537..c619a226f6 100644 --- a/configs/recognition/tsm/README.md +++ b/configs/recognition/tsm/README.md @@ -40,6 +40,7 @@ |[tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4| ResNet50| ImageNet |72.03|90.25|71.81|90.36|x|8931|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb_20200724-f00f1336.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200724_120023.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200724_120023.log.json)| |[tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4| ResNet50| ImageNet |70.70|89.90|x|x|x|10125|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb_20200816-b93fd297.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200815_210253.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200815_210253.log.json)| |[tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4|ResNet50| ImageNet |71.60|90.34|x|x|x|8358|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb_20200724-d8ad84d2.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/20200723_220442.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/20200723_220442.log.json)| +|[tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb](/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py)|short-side 320|8|MobileNetV2| MobileNetV2 |68.46|88.64|x|x|x|3385|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/tsm_mobilenetv2_dense_320p_1x1x8_100e_kinetics400_rgb_20210202-61135809.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log.json)| ### Something-Something V1 diff --git a/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py new file mode 100644 index 0000000000..0069a84c86 --- /dev/null +++ b/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py @@ -0,0 +1,92 @@ +_base_ = [ + 
'../../_base_/models/tsm_mobilenet_v2.py', + '../../_base_/schedules/sgd_tsm_mobilenet_v2_100e.py', + '../../_base_/default_runtime.py' +] + +# dataset settings +dataset_type = 'RawframeDataset' +data_root = 'data/kinetics400/rawframes_train' +data_root_val = 'data/kinetics400/rawframes_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) +train_pipeline = [ + dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8), + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(-1, 256)), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, 0.875, 0.75, 0.66), + random_crop=False, + max_wh_scale_gap=1, + num_fixed_crops=13), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs', 'label']) +] +val_pipeline = [ + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='ThreeCrop', crop_size=256), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +test_pipeline = [ + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=8, + test_mode=True), + dict(type='RawFrameDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='ThreeCrop', crop_size=256), + dict(type='Normalize', **img_norm_cfg), + dict(type='FormatShape', input_format='NCHW'), + dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), + dict(type='ToTensor', keys=['imgs']) +] +data = dict( + videos_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=data_root, + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=data_root_val, + pipeline=val_pipeline), + test=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=data_root_val, + pipeline=test_pipeline)) +evaluation = dict( + interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) + +# optimizer +optimizer = dict( + lr=0.01, # this lr is used for 8 gpus +) + +# runtime settings +checkpoint_config = dict(interval=1) +work_dir = './work_dirs/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/' diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index 14c564faf5..a3680ed14e 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -1,6 +1,6 @@ _base_ = [ '../../_base_/models/tsm_mobilenet_v2.py', - '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', + '../../_base_/schedules/sgd_tsm_mobilenet_v2_100e.py', '../../_base_/default_runtime.py' ] From fd0574826cd7d3a8c24765b87bc90379c0029c83 Mon Sep 17 00:00:00 2001 From: irvingzhang0512 Date: Wed, 3 Feb 2021 00:32:17 +0800 Subject: [PATCH 21/21] update config & docs --- 
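This commit renames the trained rawframe config to the plain tsm_mobilenetv2_1x1x8_50e name (50-epoch schedule) and switches its test pipeline to the efficient setting (CenterCrop 224), while the video dense 100e config keeps the accurate setting (ThreeCrop 256, with the three crops' scores averaged under `average_clips='prob'`). A short sketch comparing the two composed test pipelines after the rename, assuming mmcv's Config API and the final file paths from this series:

    from mmcv import Config

    efficient = Config.fromfile(
        'configs/recognition/tsm/'
        'tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py')
    accurate = Config.fromfile(
        'configs/recognition/tsm/'
        'tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py')

    # Efficient testing center-crops a single 224x224 view per frame;
    # the accurate setting takes three 256x256 crops per frame instead.
    print([t['type'] for t in efficient.test_pipeline])
    print([t['type'] for t in accurate.test_pipeline])
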
configs/recognition/tsm/README.md | 2 +- ..._mobilenetv2_1x1x8_50e_kinetics400_rgb.py} | 4 +- ...lenetv2_video_1x1x8_50e_kinetics400_rgb.py | 97 ------------------- ..._video_dense_1x1x8_100e_kinetics400_rgb.py | 2 +- ...2_video_dense_1x1x8_50e_kinetics400_rgb.py | 97 ------------------- ...erence_dense_1x1x8_100e_kinetics400_rgb.py | 6 +- 6 files changed, 5 insertions(+), 203 deletions(-) rename configs/recognition/tsm/{tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py => tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py} (96%) delete mode 100644 configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py delete mode 100644 configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py diff --git a/configs/recognition/tsm/README.md b/configs/recognition/tsm/README.md index c619a226f6..6fb3cf18b7 100644 --- a/configs/recognition/tsm/README.md +++ b/configs/recognition/tsm/README.md @@ -40,7 +40,7 @@ |[tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4| ResNet50| ImageNet |72.03|90.25|71.81|90.36|x|8931|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb_20200724-f00f1336.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200724_120023.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_embedded_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200724_120023.log.json)| |[tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4| ResNet50| ImageNet |70.70|89.90|x|x|x|10125|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb_20200816-b93fd297.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200815_210253.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_gaussian_r50_1x1x8_50e_kinetics400_rgb/20200815_210253.log.json)| |[tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb](/configs/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb.py)|short-side 320|8x4|ResNet50| ImageNet |71.60|90.34|x|x|x|8358|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb_20200724-d8ad84d2.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/20200723_220442.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_nl_dot_product_r50_1x1x8_50e_kinetics400_rgb/20200723_220442.log.json)| -|[tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb](/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py)|short-side 320|8|MobileNetV2| MobileNetV2 
|68.46|88.64|x|x|x|3385|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/tsm_mobilenetv2_dense_320p_1x1x8_100e_kinetics400_rgb_20210202-61135809.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log.json)| +|[tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb](/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py)|short-side 320|8|MobileNetV2| ImageNet |68.46|88.64|x|x|x|3385|[ckpt](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/tsm_mobilenetv2_dense_320p_1x1x8_100e_kinetics400_rgb_20210202-61135809.pth)|[log](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log)|[json](https://download.openmmlab.com/mmaction/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb/20210129_024936.log.json)| ### Something-Something V1 diff --git a/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py similarity index 96% rename from configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py rename to configs/recognition/tsm/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py index 0069a84c86..647cb609ac 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py @@ -1,6 +1,6 @@ _base_ = [ '../../_base_/models/tsm_mobilenet_v2.py', - '../../_base_/schedules/sgd_tsm_mobilenet_v2_100e.py', + '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', '../../_base_/default_runtime.py' ] @@ -40,7 +40,7 @@ test_mode=True), dict(type='RawFrameDecode'), dict(type='Resize', scale=(-1, 256)), - dict(type='ThreeCrop', crop_size=256), + dict(type='CenterCrop', crop_size=224), dict(type='Normalize', **img_norm_cfg), dict(type='FormatShape', input_format='NCHW'), dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py deleted file mode 100644 index 0b00fec0be..0000000000 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb.py +++ /dev/null @@ -1,97 +0,0 @@ -_base_ = [ - '../../_base_/models/tsm_mobilenet_v2.py', - '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', - '../../_base_/default_runtime.py' -] - -# dataset settings -dataset_type = 'VideoDataset' -data_root = 'data/kinetics400/videos_train' -data_root_val = 'data/kinetics400/videos_val' -ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' -ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' -ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) -train_pipeline = [ - dict(type='DecordInit'), - dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8), - dict(type='DecordDecode'), - dict(type='Resize', scale=(-1, 256)), - dict( - type='MultiScaleCrop', - input_size=224, - scales=(1, 0.875, 0.75, 0.66), - random_crop=False, - max_wh_scale_gap=1, - num_fixed_crops=13), - dict(type='Resize', 
scale=(224, 224), keep_ratio=False), - dict(type='Flip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='FormatShape', input_format='NCHW'), - dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), - dict(type='ToTensor', keys=['imgs', 'label']) -] -val_pipeline = [ - dict(type='DecordInit'), - dict( - type='SampleFrames', - clip_len=1, - frame_interval=1, - num_clips=8, - test_mode=True), - dict(type='DecordDecode'), - dict(type='Resize', scale=(-1, 256)), - dict(type='CenterCrop', crop_size=224), - dict(type='Flip', flip_ratio=0), - dict(type='Normalize', **img_norm_cfg), - dict(type='FormatShape', input_format='NCHW'), - dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), - dict(type='ToTensor', keys=['imgs']) -] -test_pipeline = [ - dict(type='DecordInit'), - dict( - type='SampleFrames', - clip_len=8, - frame_interval=8, - num_clips=10, - test_mode=True), - dict(type='DecordDecode'), - dict(type='Resize', scale=(-1, 256)), - dict(type='CenterCrop', crop_size=224), - dict(type='Flip', flip_ratio=0), - dict(type='Normalize', **img_norm_cfg), - dict(type='FormatShape', input_format='NCHW'), - dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), - dict(type='ToTensor', keys=['imgs']) -] -data = dict( - videos_per_gpu=8, - workers_per_gpu=4, - train=dict( - type=dataset_type, - ann_file=ann_file_train, - data_prefix=data_root, - pipeline=train_pipeline), - val=dict( - type=dataset_type, - ann_file=ann_file_val, - data_prefix=data_root_val, - pipeline=val_pipeline), - test=dict( - type=dataset_type, - ann_file=ann_file_test, - data_prefix=data_root_val, - pipeline=test_pipeline)) -evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) - -# optimizer -optimizer = dict( - lr=0.01, # this lr is used for 8 gpus -) - -# runtime settings -checkpoint_config = dict(interval=5) -work_dir = './work_dirs/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb/' diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py index a3680ed14e..e0a3c4873b 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_100e_kinetics400_rgb.py @@ -59,7 +59,7 @@ test_mode=True), dict(type='DecordDecode'), dict(type='Resize', scale=(-1, 256)), - dict(type='CenterCrop', crop_size=224), + dict(type='ThreeCrop', crop_size=256), dict(type='Flip', flip_ratio=0), dict(type='Normalize', **img_norm_cfg), dict(type='FormatShape', input_format='NCHW'), diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py deleted file mode 100644 index 0ff68aa18f..0000000000 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_dense_1x1x8_50e_kinetics400_rgb.py +++ /dev/null @@ -1,97 +0,0 @@ -_base_ = [ - '../../_base_/models/tsm_mobilenet_v2.py', - '../../_base_/schedules/sgd_tsm_mobilenet_v2_50e.py', - '../../_base_/default_runtime.py' -] - -# dataset settings -dataset_type = 'VideoDataset' -data_root = 'data/kinetics400/videos_train' -data_root_val = 'data/kinetics400/videos_val' -ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' -ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' -ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], 
std=[58.395, 57.12, 57.375], to_bgr=False) -train_pipeline = [ - dict(type='DecordInit'), - dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8), - dict(type='DecordDecode'), - dict(type='Resize', scale=(-1, 256)), - dict( - type='MultiScaleCrop', - input_size=224, - scales=(1, 0.875, 0.75, 0.66), - random_crop=False, - max_wh_scale_gap=1, - num_fixed_crops=13), - dict(type='Resize', scale=(224, 224), keep_ratio=False), - dict(type='Flip', flip_ratio=0.5), - dict(type='Normalize', **img_norm_cfg), - dict(type='FormatShape', input_format='NCHW'), - dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), - dict(type='ToTensor', keys=['imgs', 'label']) -] -val_pipeline = [ - dict(type='DecordInit'), - dict( - type='DenseSampleFrames', - clip_len=1, - frame_interval=1, - num_clips=8, - test_mode=True), - dict(type='DecordDecode'), - dict(type='Resize', scale=(-1, 256)), - dict(type='CenterCrop', crop_size=224), - dict(type='Flip', flip_ratio=0), - dict(type='Normalize', **img_norm_cfg), - dict(type='FormatShape', input_format='NCHW'), - dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), - dict(type='ToTensor', keys=['imgs']) -] -test_pipeline = [ - dict(type='DecordInit'), - dict( - type='DenseSampleFrames', - clip_len=1, - frame_interval=1, - num_clips=8, - test_mode=True), - dict(type='DecordDecode'), - dict(type='Resize', scale=(-1, 224)), - dict(type='ThreeCrop', crop_size=224), - dict(type='Flip', flip_ratio=0), - dict(type='Normalize', **img_norm_cfg), - dict(type='FormatShape', input_format='NCHW'), - dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), - dict(type='ToTensor', keys=['imgs']) -] -data = dict( - videos_per_gpu=8, - workers_per_gpu=4, - train=dict( - type=dataset_type, - ann_file=ann_file_train, - data_prefix=data_root, - pipeline=train_pipeline), - val=dict( - type=dataset_type, - ann_file=ann_file_val, - data_prefix=data_root_val, - pipeline=val_pipeline), - test=dict( - type=dataset_type, - ann_file=ann_file_test, - data_prefix=data_root_val, - pipeline=test_pipeline)) -evaluation = dict( - interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy']) - -# optimizer -optimizer = dict( - lr=0.01, # this lr is used for 8 gpus -) - -# runtime settings -checkpoint_config = dict(interval=5) -work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_50e_kinetics400_rgb/' diff --git a/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py index 4fdd3967f6..a66c772a9b 100644 --- a/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py +++ b/configs/recognition/tsm/tsm_mobilenetv2_video_inference_dense_1x1x8_100e_kinetics400_rgb.py @@ -16,7 +16,7 @@ test_mode=True), dict(type='DecordDecode'), dict(type='Resize', scale=(-1, 256)), - dict(type='ThreeCrop', crop_size=256), + dict(type='CenterCrop', crop_size=224), dict(type='Flip', flip_ratio=0), dict(type='Normalize', **img_norm_cfg), dict(type='FormatShape', input_format='NCHW'), @@ -32,7 +32,3 @@ ann_file=ann_file_test, data_prefix=data_root_val, pipeline=test_pipeline)) - -# runtime settings -dist_params = dict(backend='nccl') -log_level = 'INFO'
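After this final commit the inference config inherits only the _base_ model definition and keeps just the test data settings; the train-only runtime keys are gone. A minimal smoke-test sketch for the composed model, assuming the v0.x-era MMAction2 APIs (`Config.fromfile`, `mmaction.models.build_model`); weights here are randomly initialised (no `init_weights()` call), so this checks shapes only, not accuracy:

    import torch
    from mmcv import Config
    from mmaction.models import build_model

    cfg = Config.fromfile(
        'configs/recognition/tsm/'
        'tsm_mobilenetv2_1x1x8_50e_kinetics400_rgb.py')
    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.eval()

    # Recognizer2D folds the 8-segment axis into the batch before the
    # MobileNetV2-TSM backbone; TSMHead + AvgConsensus restore it.
    imgs = torch.randn(1, 8, 3, 224, 224)
    with torch.no_grad():
        scores = model(imgs, return_loss=False)  # (1, 400) class scores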