Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Feature] Support TSM-MobileNetV2 #415

Merged
merged 26 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4dd0755
[Feature] add mobilenetv2-tsm, first commit.
irvingzhang0512 Dec 2, 2020
bf6af41
1. remove torchvision mobilentv2.
irvingzhang0512 Dec 3, 2020
d465376
fix unittest bug and update annotations.
irvingzhang0512 Dec 3, 2020
b11155d
chanage default config.
irvingzhang0512 Dec 3, 2020
6ba7e56
fix default configs
irvingzhang0512 Dec 3, 2020
57a16a8
fix unittest and config bug.
irvingzhang0512 Dec 3, 2020
986cbd1
improve unittest coverage
irvingzhang0512 Dec 3, 2020
cc90d42
update training configs
irvingzhang0512 Dec 17, 2020
c4b2901
refactor mobilenet_v2 with mmclassification
irvingzhang0512 Dec 17, 2020
02b8acf
Merge branch 'master' into mobilenetv2-tsm
irvingzhang0512 Dec 17, 2020
385aa47
add changelog
irvingzhang0512 Dec 17, 2020
3952cbf
update changelog
irvingzhang0512 Dec 17, 2020
a360968
Merge branch 'master' into mobilenetv2-tsm
irvingzhang0512 Dec 17, 2020
3a91ed6
fix unittest
irvingzhang0512 Dec 17, 2020
4b4c8ac
fix unittest
irvingzhang0512 Dec 17, 2020
a37daa1
improve Codecov
irvingzhang0512 Dec 17, 2020
f364124
update default training configs
irvingzhang0512 Dec 30, 2020
882ed59
Merge branch 'master' into mobilenetv2-tsm
irvingzhang0512 Dec 30, 2020
3cb20e7
add inference config
irvingzhang0512 Jan 14, 2021
e247e79
Merge branch 'master' into mobilenetv2-tsm
irvingzhang0512 Jan 28, 2021
577ccf1
fix
irvingzhang0512 Jan 28, 2021
3a1568f
refactor unittest
irvingzhang0512 Jan 28, 2021
3f1f241
refactor config
irvingzhang0512 Jan 28, 2021
400de06
Merge branch 'master' into mobilenetv2-tsm
dreamerlin Feb 2, 2021
ebcb910
update config and model link
dreamerlin Feb 2, 2021
fd05748
update config & docs
irvingzhang0512 Feb 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='MobileNetV2TSM',
shift_div=8,
num_segments=8,
is_shift=True,
pretrained=True),
cls_head=dict(
type='TSMHead',
num_segments=8,
num_classes=400,
in_channels=1280,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips='prob')
# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1,
num_fixed_crops=13),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit'),
dict(
type='SampleFrames',
clip_len=8,
frame_interval=8,
num_clips=10,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD',
constructor='TSMOptimizerConstructor',
paramwise_cfg=dict(fc_lr5=True),
lr=0.02, # this lr is used for 8 gpus
momentum=0.9,
weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[20, 40])
total_epochs = 50
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsm_mobilenetv2_video_1x1x8_50e_kinetics400_rgb/'
load_from = None
resume_from = None
workflow = [('train', 1)]
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# model settings
model = dict(
type='Recognizer2D',
backbone=dict(
type='MobileNetV2TSM',
shift_div=8,
num_segments=8,
is_shift=True,
pretrained=True),
cls_head=dict(
type='TSMHead',
num_segments=8,
num_classes=400,
in_channels=1280,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips='prob')
# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
dict(type='DecordInit'),
dict(type='DenseSampleFrames', clip_len=1, frame_interval=1, num_clips=8),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1,
num_fixed_crops=13),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
dict(type='DecordInit'),
dict(
type='DenseSampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
dict(type='DecordInit'),
dict(
type='DenseSampleFrames',
clip_len=1,
frame_interval=1,
num_clips=8,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='Flip', flip_ratio=0),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCHW'),
dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['imgs'])
]
data = dict(
videos_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD',
constructor='TSMOptimizerConstructor',
paramwise_cfg=dict(fc_lr5=True),
lr=0.02, # this lr is used for 8 gpus
momentum=0.9,
weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=20, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[20, 40])
total_epochs = 50
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsm_mobilenetv2_dense_video_1x1x8_50e_kinetics400_rgb/'
load_from = None
resume_from = None
workflow = [('train', 1)]
23 changes: 12 additions & 11 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .backbones import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
ResNetTIN, ResNetTSM)
from .backbones import (C3D, X3D, MobileNetV2, MobileNetV2TSM, ResNet,
ResNet2Plus1d, ResNet3d, ResNet3dCSN, ResNet3dSlowFast,
ResNet3dSlowOnly, ResNetAudio, ResNetTIN, ResNetTSM)
from .builder import (build_backbone, build_head, build_localizer, build_loss,
build_model, build_neck, build_recognizer)
from .common import Conv2plus1d, ConvAudio
Expand All @@ -18,12 +18,13 @@
__all__ = [
'BACKBONES', 'HEADS', 'RECOGNIZERS', 'build_recognizer', 'build_head',
'build_backbone', 'recognizer2d', 'recognizer3d', 'C3D', 'ResNet',
'ResNet3d', 'ResNet2Plus1d', 'I3DHead', 'TSNHead', 'TSMHead', 'BaseHead',
'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss', 'NLLLoss', 'HVULoss',
'ResNetTSM', 'ResNet3dSlowFast', 'SlowFastHead', 'Conv2plus1d',
'ResNet3dSlowOnly', 'BCELossWithLogits', 'LOCALIZERS', 'build_localizer',
'PEM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss',
'build_model', 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN',
'TPN', 'TPNHead', 'build_loss', 'build_neck', 'AudioRecognizer',
'AudioTSNHead', 'X3D', 'X3DHead', 'ResNetAudio', 'ConvAudio'
'MobileNetV2', 'ResNet3d', 'ResNet2Plus1d', 'I3DHead', 'TSNHead',
'TSMHead', 'BaseHead', 'BaseRecognizer', 'LOSSES', 'CrossEntropyLoss',
'NLLLoss', 'HVULoss', 'ResNetTSM', 'MobileNetV2TSM', 'ResNet3dSlowFast',
'SlowFastHead', 'Conv2plus1d', 'ResNet3dSlowOnly', 'BCELossWithLogits',
'LOCALIZERS', 'build_localizer', 'PEM', 'TEM',
'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss', 'build_model',
'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN', 'TPN', 'TPNHead',
'build_loss', 'build_neck', 'AudioRecognizer', 'AudioTSNHead', 'X3D',
'X3DHead', 'ResNetAudio', 'ConvAudio'
]
4 changes: 3 additions & 1 deletion mmaction/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .c3d import C3D
from .mobilenetv2 import MobileNetV2
from .mobilenetv2_tsm import MobileNetV2TSM
from .resnet import ResNet
from .resnet2plus1d import ResNet2Plus1d
from .resnet3d import ResNet3d
Expand All @@ -13,5 +15,5 @@
__all__ = [
'C3D', 'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d',
'ResNet3dSlowFast', 'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN', 'X3D',
'ResNetAudio'
'ResNetAudio', 'MobileNetV2TSM', 'MobileNetV2'
]
Loading