[Feature] Support MMCls backbones for TSN (#679)
* resolve comments

* update changelog

* enable TSNxMMCls Backbone

* add rn101 config

* install mmcls

* add a unittest

* fix config

* Update README.md

remove backbones from other sources for now

* Update changelog.md

Co-authored-by: Jintao Lin <528557675@qq.com>
kennymckormick and dreamerlin authored Mar 9, 2021
1 parent 1396c3f commit d8f8746
Showing 6 changed files with 170 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
@@ -63,6 +63,8 @@ jobs:
        run: pip install mmcv-full==1.2.6 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
      - name: Install MMDet
        run: pip install git+https://github.com/open-mmlab/mmdetection/
      - name: Install MMCls
        run: pip install git+https://github.com/open-mmlab/mmclassification/
      - name: Install unittest dependencies
        run: pip install -r requirements/tests.txt -r requirements/optional.txt
      - name: Build and install
@@ -144,6 +146,7 @@ jobs:
        run: |
          pip install mmcv-full==1.2.6 -f https://download.openmmlab.com/mmcv/dist/${{matrix.mmcv}}/index.html
          pip install -q git+https://github.com/open-mmlab/mmdetection/
          pip install -q git+https://github.com/open-mmlab/mmclassification/
          pip install -r requirements.txt
      - name: Build and install
        run: |
26 changes: 18 additions & 8 deletions configs/recognition/tsn/README.md
@@ -54,6 +54,16 @@

Here, we use [1: 1] to indicate that we combine the RGB and flow scores with coefficients 1: 1 to get the two-stream prediction (without applying softmax).
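
The [1: 1] fusion described above amounts to a weighted sum of the raw class scores from the two streams. The following is a minimal sketch of that idea, not MMAction2's own evaluation code; `fuse_two_stream`, `predict` and the toy scores are hypothetical names and values chosen for illustration.

```python
def fuse_two_stream(rgb_scores, flow_scores, w_rgb=1.0, w_flow=1.0):
    """Weighted sum of per-class scores from the RGB and flow streams.

    Scores are used as-is (no softmax), following the [1: 1] note above.
    """
    if len(rgb_scores) != len(flow_scores):
        raise ValueError('both streams must score the same classes')
    return [w_rgb * r + w_flow * f for r, f in zip(rgb_scores, flow_scores)]


def predict(scores):
    """Index of the highest fused score."""
    return max(range(len(scores)), key=scores.__getitem__)


# Toy scores for a 3-class problem (made-up numbers).
rgb = [2.0, 0.5, 1.0]
flow = [0.5, 1.0, 3.0]
fused = fuse_two_stream(rgb, flow)  # 1:1 coefficients
top1 = predict(fused)
```

Other coefficient pairs can be tried by changing `w_rgb` and `w_flow`; only their ratio affects the predicted class.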

### Using backbones from 3rd-party in TSN

It's possible and convenient to use a 3rd-party backbone for TSN under the framework of MMAction2. Here we provide some examples for:

- [x] Backbones from MMClassification

| config | resolution | gpus | backbone | pretrain | top1 acc | top5 acc | ckpt | log | json |
| :----------------------------------------------------------: | :------------: | :--: | :----------------------------------------------------------: | :------: | :------: | :------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py) | short-side 320 | 8x2 | ResNeXt101-32x4d [[MMCls](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext)] | ImageNet | 73.43 | 91.01 | [ckpt](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb-16a8b561.pth) | [log](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.log) | [json](https://download.openmmlab.com/mmaction/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.json) |

### Kinetics-400 Data Benchmark (8-gpus, ResNet50, ImageNet pretrain; 3 segments)

In data benchmark, we compare:
@@ -162,14 +172,14 @@ Notes:

For more details on data preparation, you can refer to

* [preparing_ucf101](/tools/data/ucf101/README.md)
* [preparing_kinetics](/tools/data/kinetics/README.md)
* [preparing_sthv1](/tools/data/sthv1/README.md)
* [preparing_sthv2](/tools/data/sthv2/README.md)
* [preparing_mit](/tools/data/mit/README.md)
* [preparing_mmit](/tools/data/mmit/README.md)
* [preparing_hvu](/tools/data/hvu/README.md)
* [preparing_hmdb51](/tools/data/hmdb51/README.md)
- [preparing_ucf101](/tools/data/ucf101/README.md)
- [preparing_kinetics](/tools/data/kinetics/README.md)
- [preparing_sthv1](/tools/data/sthv1/README.md)
- [preparing_sthv2](/tools/data/sthv2/README.md)
- [preparing_mit](/tools/data/mit/README.md)
- [preparing_mmit](/tools/data/mmit/README.md)
- [preparing_hvu](/tools/data/hvu/README.md)
- [preparing_hmdb51](/tools/data/hmdb51/README.md)

## Train

109 changes: 109 additions & 0 deletions configs/recognition/tsn/custom_backbones/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb.py
@@ -0,0 +1,109 @@
_base_ = [
    '../../../_base_/schedules/sgd_100e.py',
    '../../../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='mmcls::ResNeXt',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        groups=32,
        width_per_group=4,
        style='pytorch'),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01),
    # model training and testing settings
    train_cfg=None,
    test_cfg=dict(average_clips=None))

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train_320p'
data_root_val = 'data/kinetics400/rawframes_val_320p'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes_320p.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes_320p.txt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
train_pipeline = [
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
data = dict(
    videos_per_gpu=16,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))

# runtime settings
work_dir = './work_dirs/tsn_rn101_32x4d_320p_1x1x3_100e_kinetics400_rgb/'
load_from = ('https://download.openmmlab.com/mmclassification/v0/resnext/'
             'resnext101_32x4d_batch256_imagenet_20200708-87f2d1c9.pth')
optimizer = dict(
    type='SGD',
    lr=0.005,  # this lr is used for 8 gpus
    momentum=0.9,
    weight_decay=0.0001)
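
The `lr=0.005` comment above ties the learning rate to an 8-GPU run, and the config sets `videos_per_gpu=16`. When the total batch size differs, one common heuristic is the linear scaling rule, which changes the learning rate in proportion to the total batch size. The sketch below assumes that rule is appropriate for this config, which the commit itself does not state; `scale_lr` is a hypothetical helper, not an MMAction2 function.

```python
def scale_lr(base_lr, base_gpus, base_videos_per_gpu, gpus, videos_per_gpu):
    """Scale a learning rate linearly with the total batch size.

    Assumption: the linear scaling rule holds for this setup.
    """
    base_batch = base_gpus * base_videos_per_gpu
    new_batch = gpus * videos_per_gpu
    if base_batch <= 0 or new_batch <= 0:
        raise ValueError('batch sizes must be positive')
    return base_lr * new_batch / base_batch


# Reference point taken from the config: lr=0.005 for 8 GPUs x 16 videos.
lr_same = scale_lr(0.005, 8, 16, gpus=8, videos_per_gpu=16)
# Hypothetical smaller run: 4 GPUs with the same per-GPU batch.
lr_4gpu = scale_lr(0.005, 8, 16, gpus=4, videos_per_gpu=16)
```

Treat the scaled value as a starting point; whether it reproduces the reported accuracy would need to be checked by training.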
2 changes: 2 additions & 0 deletions docs/changelog.md
@@ -7,6 +7,7 @@
**New Features**

- Support LFB ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Support using backbones from MMCls for TSN ([#679](https://github.com/open-mmlab/mmaction2/pull/679))

**Improvements**

@@ -18,6 +19,7 @@

- Add LFB for AVA2.1 ([#553](https://github.com/open-mmlab/mmaction2/pull/553))
- Add slowonly_nl_embedded_gaussian_r50_4x16x1_150e_kinetics400_rgb ([#690](https://github.com/open-mmlab/mmaction2/pull/690))
- Add TSN with ResNeXt-101-32x4d backbone ([#679](https://github.com/open-mmlab/mmaction2/pull/679))

### 0.12.0 (28/02/2021)

12 changes: 11 additions & 1 deletion mmaction/models/recognizers/base.py
@@ -33,7 +33,17 @@ def __init__(self,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        self.backbone = builder.build_backbone(backbone)
        # The backbones in mmcls can be used by TSN
        if backbone['type'].startswith('mmcls::'):
            try:
                import mmcls.models.builder as mmcls_builder
            except (ImportError, ModuleNotFoundError):
                raise ImportError('Please install mmcls to use this backbone.')
            backbone['type'] = backbone['type'][7:]
            self.backbone = mmcls_builder.build_backbone(backbone)
        else:
            self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)
        self.cls_head = builder.build_head(cls_head)
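
This hunk is listed as 11 additions and 1 deletion: the first, unconditional `self.backbone = builder.build_backbone(backbone)` line is the one removed, and the `if`/`else` dispatch on the `mmcls::` prefix replaces it. A stand-alone sketch of that prefix handling follows. `split_backbone_scope` is a hypothetical helper, not part of MMAction2 or MMClassification, and unlike the committed code it works on a copy, so the caller's config dict keeps its original `type`.

```python
MMCLS_PREFIX = 'mmcls::'


def split_backbone_scope(backbone_cfg):
    """Return (scope, cfg), where cfg has the prefix removed from 'type'.

    scope is 'mmcls' for prefixed types and 'mmaction' otherwise.
    The input dict is not modified.
    """
    cfg = dict(backbone_cfg)  # shallow copy so the caller's dict is untouched
    if cfg['type'].startswith(MMCLS_PREFIX):
        # Same effect as the slice [7:] in the commit, without the magic number.
        cfg['type'] = cfg['type'][len(MMCLS_PREFIX):]
        return 'mmcls', cfg
    return 'mmaction', cfg


original = dict(type='mmcls::ResNeXt', depth=101)
scope, stripped = split_backbone_scope(original)
```

Keeping the input untouched matters if the same config dict is reused, for example to build the recognizer a second time: after the in-place slice in the commit, the second build would no longer see the `mmcls::` prefix.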
Expand Down
27 changes: 27 additions & 0 deletions tests/test_models/test_recognizers/test_recognizer2d.py
@@ -30,6 +30,33 @@ def test_tsn():
    for one_img in img_list:
        recognizer(one_img, gradcam=True)

    mmcls_backbone = dict(
        type='mmcls::ResNeXt',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        groups=32,
        width_per_group=4,
        style='pytorch')
    config.model['backbone'] = mmcls_backbone

    recognizer = build_recognizer(config.model)

    input_shape = (1, 3, 3, 32, 32)
    demo_inputs = generate_recognizer_demo_inputs(input_shape)

    imgs = demo_inputs['imgs']
    gt_labels = demo_inputs['gt_labels']

    losses = recognizer(imgs, gt_labels)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        img_list = [img[None, :] for img in imgs]
        for one_img in img_list:
            recognizer(one_img, None, return_loss=False)
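
The 5-D `input_shape = (1, 3, 3, 32, 32)` used in this test can be read as (batch, segments, channels, height, width). A 2D backbone takes 4-D input, so the segment axis is folded into the batch axis before the backbone runs. The sketch below only does that shape arithmetic; `fold_segments` is a hypothetical helper, and the axis interpretation is an assumption about how 2D recognizers are fed rather than something this commit states.

```python
def fold_segments(shape):
    """Fold the segment axis of an (N, S, C, H, W) shape into the batch axis.

    Returns the 4-D shape seen by a 2D backbone and the number of segments.
    """
    if len(shape) != 5:
        raise ValueError('expected a 5-D shape (N, S, C, H, W)')
    batch, num_segs, channels, height, width = shape
    return (batch * num_segs, channels, height, width), num_segs


# The demo input from the test: 1 video, 3 segments, 3-channel 32x32 frames.
backbone_shape, num_segs = fold_segments((1, 3, 3, 32, 32))
```

Under this reading, the per-segment outputs are later regrouped by `num_segs` so that the consensus in the head can average them per video.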


def test_tsm():
    config = get_recognizer_cfg('tsm/tsm_r50_1x1x8_50e_kinetics400_rgb.py')