From c5f00a8d7091b46b0ecd175f722d98ea8bea485e Mon Sep 17 00:00:00 2001
From: HaodongDuan
Date: Fri, 16 Oct 2020 17:35:47 +0800
Subject: [PATCH 01/10] resolve comments

---
 tools/data/hvu/generate_sub_file_list.py | 43 +++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 tools/data/hvu/generate_sub_file_list.py

diff --git a/tools/data/hvu/generate_sub_file_list.py b/tools/data/hvu/generate_sub_file_list.py
new file mode 100644
index 0000000000..77c7bed651
--- /dev/null
+++ b/tools/data/hvu/generate_sub_file_list.py
@@ -0,0 +1,43 @@
+import argparse
+import os.path as osp
+
+import mmcv
+
+
+def main(annotation_file, category):
+    assert category in [
+        'action', 'attribute', 'concept', 'event', 'object', 'scene'
+    ]
+
+    data = mmcv.load(annotation_file)
+    basename = osp.basename(annotation_file)
+    dirname = osp.dirname(annotation_file)
+    basename = basename.replace('hvu', f'hvu_{category}')
+
+    target_file = osp.join(dirname, basename)
+
+    # Keep only items tagged with the chosen category and flatten their
+    # label dict to that category's tag list.
+    result = []
+    for item in data:
+        label = item['label']
+        if category in label:
+            item['label'] = label[category]
+            result.append(item)
+
+    mmcv.dump(result, target_file)
+
+
+if __name__ == '__main__':
+    description = 'Helper script for generating HVU per-category file list.'
+    p = argparse.ArgumentParser(description=description)
+    p.add_argument(
+        'annotation_file',
+        type=str,
+        help='The annotation file that contains tags of all categories.')
+    p.add_argument(
+        'category',
+        type=str,
+        choices=['action', 'attribute', 'concept', 'event', 'object', 'scene'],
+        help='The tag category to generate the file list for.')
+    main(**vars(p.parse_args()))

From 05575c18bf7dbba7e6e9f9a0f1c4f973668d2c44 Mon Sep 17 00:00:00 2001
From: HaodongDuan
Date: Fri, 16 Oct 2020 17:37:19 +0800
Subject: [PATCH 02/10] update changelog

---
 docs/changelog.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/changelog.md b/docs/changelog.md
index d0fe20a249..3cff48664d 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -4,6 +4,7 @@

 **Improvements**
 - Set default values of 'average_clips' in each config file so that there is no need to set it explicitly during testing in most cases ([#232](https://github.com/open-mmlab/mmaction2/pull/232))
+- Extend HVU datatools to generate individual file list for each tag category ([#258](https://github.com/open-mmlab/mmaction2/pull/258))

 **Bug Fixes**
 - Fix the potential bug for default value in dataset_setting ([#245](https://github.com/open-mmlab/mmaction2/pull/245))
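For reference, here is a toy run of the helper's logic. The annotation schema is taken from the script itself (each item maps `label` to a dict keyed by tag category); the `filename` key and the tag ids are hypothetical. Given an input named `hvu_train.json`, the script would write the filtered list next to it as `hvu_action_train.json`.

```python
# Minimal sketch of what generate_sub_file_list.py produces for one category.
data = [
    {'filename': 'a.mp4', 'label': {'action': [2, 5], 'scene': [11]}},
    {'filename': 'b.mp4', 'label': {'scene': [7]}},
]
category = 'action'
# Keep only items tagged with the chosen category and flatten the label
# dict to that category's tags, mirroring the loop in main():
result = [dict(item, label=item['label'][category])
          for item in data if category in item['label']]
assert result == [{'filename': 'a.mp4', 'label': [2, 5]}]
```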
From c5c25edce27aa22ac4f8b51c9300d39393170b75 Mon Sep 17 00:00:00 2001
From: kenny
Date: Mon, 21 Dec 2020 13:43:33 +0800
Subject: [PATCH 03/10] add with_global option

---
 .../roi_extractors/single_straight3d.py | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/mmaction/models/roi_extractors/single_straight3d.py b/mmaction/models/roi_extractors/single_straight3d.py
index 09d40b1dcf..4fe2184b2b 100644
--- a/mmaction/models/roi_extractors/single_straight3d.py
+++ b/mmaction/models/roi_extractors/single_straight3d.py
@@ -28,6 +28,8 @@ class SingleRoIExtractor3D(nn.Module):
             Default: True.
         with_temporal_pool (bool): if True, avgpool the temporal dim.
             Default: True.
+        with_global (bool): if True, concatenate the RoI feature with global
+            feature. Default: False.

     Note that sampling_ratio, pool_mode, aligned only apply when
     roi_layer_type is set as RoIAlign.
@@ -40,7 +42,8 @@ def __init__(self,
                  sampling_ratio=0,
                  pool_mode='avg',
                  aligned=True,
-                 with_temporal_pool=True):
+                 with_temporal_pool=True,
+                 with_global=False):
         super().__init__()
         self.roi_layer_type = roi_layer_type
         assert self.roi_layer_type in ['RoIPool', 'RoIAlign']
@@ -53,6 +56,8 @@ def __init__(self,
         self.aligned = aligned

         self.with_temporal_pool = with_temporal_pool
+        self.with_global = with_global
+
         if self.roi_layer_type == 'RoIPool':
            self.roi_layer = RoIPool(self.output_size, self.spatial_scale)
         else:
@@ -62,10 +67,12 @@ def __init__(self,
                 sampling_ratio=self.sampling_ratio,
                 pool_mode=self.pool_mode,
                 aligned=self.aligned)
+        self.global_pool = nn.AdaptiveAvgPool2d(self.output_size)

     def init_weights(self):
         pass

+    # The shape of feat is N, C, T, H, W
     def forward(self, feat, rois):
         if not isinstance(feat, tuple):
             feat = (feat, )
@@ -74,10 +81,17 @@ def forward(self, feat, rois):
         if self.with_temporal_pool:
             feat = [torch.mean(x, 2, keepdim=True) for x in feat]
         feat = torch.cat(feat, axis=1)
+
         roi_feats = []
         for t in range(feat.size(2)):
-            frame_feat = feat[:, :, t, :, :].contiguous()
-            roi_feats.append(self.roi_layer(frame_feat, rois))
+            frame_feat = feat[:, :, t].contiguous()
+            roi_feat = self.roi_layer(frame_feat, rois)
+            if self.with_global:
+                global_feat = self.global_pool(feat[:, :, t].contiguous())
+                inds = rois[:, 0].type(torch.int64)
+                global_feat = global_feat[inds]
+                roi_feat = torch.cat([roi_feat, global_feat], dim=1)
+
         return torch.stack(roi_feats, dim=2)

From 2b6895741e20447488f5f790fd3ffc4e2293d456 Mon Sep 17 00:00:00 2001
From: kenny
Date: Mon, 21 Dec 2020 13:56:44 +0800
Subject: [PATCH 04/10] fix bug

---
 mmaction/models/roi_extractors/single_straight3d.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mmaction/models/roi_extractors/single_straight3d.py b/mmaction/models/roi_extractors/single_straight3d.py
index 4fe2184b2b..7bf138601f 100644
--- a/mmaction/models/roi_extractors/single_straight3d.py
+++ b/mmaction/models/roi_extractors/single_straight3d.py
@@ -91,6 +91,7 @@ def forward(self, feat, rois):
                 inds = rois[:, 0].type(torch.int64)
                 global_feat = global_feat[inds]
                 roi_feat = torch.cat([roi_feat, global_feat], dim=1)
+            roi_feats.append(roi_feat)

         return torch.stack(roi_feats, dim=2)

From 921d60551d76220ebaa2c8d2f714c5e039e8d7cd Mon Sep 17 00:00:00 2001
From: kenny
Date: Mon, 21 Dec 2020 14:06:01 +0800
Subject: [PATCH 05/10] + contiguous

---
 mmaction/models/roi_extractors/single_straight3d.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mmaction/models/roi_extractors/single_straight3d.py b/mmaction/models/roi_extractors/single_straight3d.py
index 7bf138601f..1a21059e5c 100644
--- a/mmaction/models/roi_extractors/single_straight3d.py
+++ b/mmaction/models/roi_extractors/single_straight3d.py
@@ -91,6 +91,7 @@ def forward(self, feat, rois):
                 inds = rois[:, 0].type(torch.int64)
                 global_feat = global_feat[inds]
                 roi_feat = torch.cat([roi_feat, global_feat], dim=1)
+            roi_feat = roi_feat.contiguous()
             roi_feats.append(roi_feat)

         return torch.stack(roi_feats, dim=2)
From 70084c91fd632d30d9fd5328ba56f1c58659a8ff Mon Sep 17 00:00:00 2001
From: kenny
Date: Mon, 21 Dec 2020 14:24:52 +0800
Subject: [PATCH 06/10] update

---
 mmaction/models/roi_extractors/single_straight3d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmaction/models/roi_extractors/single_straight3d.py b/mmaction/models/roi_extractors/single_straight3d.py
index 1a21059e5c..3d6b292193 100644
--- a/mmaction/models/roi_extractors/single_straight3d.py
+++ b/mmaction/models/roi_extractors/single_straight3d.py
@@ -87,7 +87,7 @@ def forward(self, feat, rois):
             frame_feat = feat[:, :, t].contiguous()
             roi_feat = self.roi_layer(frame_feat, rois)
             if self.with_global:
-                global_feat = self.global_pool(feat[:, :, t].contiguous())
+                global_feat = self.global_pool(frame_feat.contiguous())
                 inds = rois[:, 0].type(torch.int64)
                 global_feat = global_feat[inds]
                 roi_feat = torch.cat([roi_feat, global_feat], dim=1)
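To make the effect of `with_global` concrete, here is a self-contained sketch of the per-frame computation that patches 03-06 converge on. It substitutes `torchvision.ops.roi_align` for mmcv's `RoIAlign` purely so the snippet runs standalone; the pooling, batch-index gather, and channel concat match the patched `forward`, and the box coordinates, sizes, and stride below are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torchvision.ops import roi_align  # stand-in for mmcv's RoIAlign

N, C, H, W = 2, 64, 16, 16
frame_feat = torch.randn(N, C, H, W)          # one temporal slice of feat
# rois are (batch_index, x1, y1, x2, y2) in input-image coordinates
rois = torch.tensor([[0., 0., 0., 63., 63.],
                     [1., 32., 32., 127., 127.]])
output_size, spatial_scale = 8, 1 / 16        # e.g. a stride-16 feature map

roi_feat = roi_align(frame_feat, rois, output_size, spatial_scale,
                     aligned=True)            # (num_rois, C, 8, 8)
# with_global: pool the whole frame to the same spatial size, pick the
# frame belonging to each roi via its batch index, and concat on channels
global_feat = nn.AdaptiveAvgPool2d(output_size)(frame_feat)
global_feat = global_feat[rois[:, 0].long()]  # (num_rois, C, 8, 8)
fused = torch.cat([roi_feat, global_feat], dim=1)
assert fused.shape == (2, 2 * C, output_size, output_size)
```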
From 75504f6849e50007f3f570e0bdda91e7811013f6 Mon Sep 17 00:00:00 2001
From: kenny
Date: Tue, 22 Dec 2020 13:02:09 +0800
Subject: [PATCH 07/10] add config with context

---
 configs/detection/ava/README.md | 1 +
 ...etics_pretrained_r50_4x16x1_20e_ava_rgb.py | 175 ++++++++++++++++++
 2 files changed, 176 insertions(+)
 create mode 100644 configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py

diff --git a/configs/detection/ava/README.md b/configs/detection/ava/README.md
index 5267186d5b..17b63df4c9 100644
--- a/configs/detection/ava/README.md
+++ b/configs/detection/ava/README.md
@@ -38,6 +38,7 @@
 | [slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet101 | 8x8 | 8x2 | short-side 256 | 24.6 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201217-1c9b4117.pth) |
 | [slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | OmniSource | ResNet101 | 8x8 | 8x2 | short-side 256 | 25.9 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201217-16378594.pth) |
 | [slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 24.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217-6e7c704d.pth) |
+| [slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | | | |
 | [slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.5 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217-ae225e97.pth) |

 - Notes:

diff --git a/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py b/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py
new file mode 100644
index 0000000000..9274ee47e0
--- /dev/null
+++ b/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py
@@ -0,0 +1,175 @@
+# model setting
+model = dict(
+    type='FastRCNN',
+    backbone=dict(
+        type='ResNet3dSlowFast',
+        pretrained=None,
+        resample_rate=8,
+        speed_ratio=8,
+        channel_ratio=8,
+        slow_pathway=dict(
+            type='resnet3d',
+            depth=50,
+            pretrained=None,
+            lateral=True,
+            conv1_kernel=(1, 7, 7),
+            dilations=(1, 1, 1, 1),
+            conv1_stride_t=1,
+            pool1_stride_t=1,
+            inflate=(0, 0, 1, 1),
+            spatial_strides=(1, 2, 2, 1)),
+        fast_pathway=dict(
+            type='resnet3d',
+            depth=50,
+            pretrained=None,
+            lateral=False,
+            base_channels=8,
+            conv1_kernel=(5, 7, 7),
+            conv1_stride_t=1,
+            pool1_stride_t=1,
+            spatial_strides=(1, 2, 2, 1))),
+    roi_head=dict(
+        type='AVARoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor3D',
+            roi_layer_type='RoIAlign',
+            output_size=8,
+            with_temporal_pool=True,
+            with_global=True),
+        bbox_head=dict(
+            type='BBoxHeadAVA',
+            in_channels=4608,
+            num_classes=81,
+            multilabel=True,
+            dropout_ratio=0.5)))
+
+train_cfg = dict(
+    rcnn=dict(
+        assigner=dict(
+            type='MaxIoUAssignerAVA',
+            pos_iou_thr=0.9,
+            neg_iou_thr=0.9,
+            min_pos_iou=0.9),
+        sampler=dict(
+            type='RandomSampler',
+            num=32,
+            pos_fraction=1,
+            neg_pos_ub=-1,
+            add_gt_as_proposals=True),
+        pos_weight=1.0,
+        debug=False))
+test_cfg = dict(rcnn=dict(action_thr=0.00))
+
+dataset_type = 'AVADataset'
+data_root = 'data/ava/rawframes'
+anno_root = 'data/ava/annotations'
+
+ann_file_train = f'{anno_root}/ava_train_v2.1.csv'
+ann_file_val = f'{anno_root}/ava_val_v2.1.csv'
+
+exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'
+exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'
+
+label_file = f'{anno_root}/ava_action_list_v2.1_for_activitynet_2018.pbtxt'
+
+proposal_file_train = (f'{anno_root}/ava_dense_proposals_train.FAIR.'
+                       'recall_93.9.pkl')
+proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+
+train_pipeline = [
+    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
+    dict(type='RawFrameDecode'),
+    dict(type='RandomRescale', scale_range=(256, 320)),
+    dict(type='RandomCrop', size=256),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCTHW', collapse=True),
+    # Rename is needed to use mmdet detectors
+    dict(type='Rename', mapping=dict(imgs='img')),
+    dict(type='ToTensor', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
+    dict(
+        type='ToDataContainer',
+        fields=[
+            dict(key=['proposals', 'gt_bboxes', 'gt_labels'], stack=False)
+        ]),
+    dict(
+        type='Collect',
+        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],
+        meta_keys=['scores', 'entity_ids'])
+]
+# Testing is done without any cropping / flipping
+val_pipeline = [
+    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
+    dict(type='RawFrameDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCTHW', collapse=True),
+    # Rename is needed to use mmdet detectors
+    dict(type='Rename', mapping=dict(imgs='img')),
+    dict(type='ToTensor', keys=['img', 'proposals']),
+    dict(type='ToDataContainer', fields=[dict(key='proposals', stack=False)]),
+    dict(
+        type='Collect',
+        keys=['img', 'proposals'],
+        meta_keys=['scores', 'img_shape'],
+        nested=True)
+]
+
+data = dict(
+    videos_per_gpu=9,
+    workers_per_gpu=4,
+    val_dataloader=dict(videos_per_gpu=1),
+    test_dataloader=dict(videos_per_gpu=1),
+    train=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        exclude_file=exclude_file_train,
+        pipeline=train_pipeline,
+        label_file=label_file,
+        proposal_file=proposal_file_train,
+        person_det_score_thr=0.9,
+        data_prefix=data_root),
+    val=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        exclude_file=exclude_file_val,
+        pipeline=val_pipeline,
+        label_file=label_file,
+        proposal_file=proposal_file_val,
+        person_det_score_thr=0.9,
+        data_prefix=data_root))
+data['test'] = data['val']
+
+optimizer = dict(type='SGD', lr=0.1125, momentum=0.9, weight_decay=0.00001)
+# this lr is used for 8 gpus
+
+optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
+# learning policy
+
+lr_config = dict(
+    policy='step',
+    step=[10, 15],
+    warmup='linear',
+    warmup_by_epoch=True,
+    warmup_iters=5,
+    warmup_ratio=0.1)
+total_epochs = 20
+checkpoint_config = dict(interval=1)
+workflow = [('train', 1)]
+evaluation = dict(interval=1)
+log_config = dict(
+    interval=20, hooks=[
+        dict(type='TextLoggerHook'),
+    ])
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+work_dir = ('./work_dirs/ava/'
+            'slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
+load_from = ('https://download.openmmlab.com/mmaction/recognition/slowfast/'
+             'slowfast_r50_4x16x1_256e_kinetics400_rgb/'
+             'slowfast_r50_4x16x1_256e_kinetics400_rgb_20200704-bcde7ed7.pth')
+resume_from = None
+find_unused_parameters = False
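A quick sanity check on `in_channels=4608` in the config above, derived from values the config itself sets: the ResNet-50 slow pathway ends with 2048 channels, the fast pathway contributes 2048 / `channel_ratio` = 256, and `with_global=True` doubles the concatenated feature.

```python
# Channel bookkeeping for the context config (values from the config above).
slow_channels = 2048                          # ResNet-50 slow pathway output
fast_channels = slow_channels // 8            # channel_ratio=8 in the backbone
backbone_out = slow_channels + fast_channels  # 2304 without context
assert backbone_out * 2 == 4608               # doubled by the with_global concat
```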
From 6779d5f8e2c32fc2c0b4faf3bafb68187a2244bf Mon Sep 17 00:00:00 2001
From: kenny
Date: Tue, 22 Dec 2020 19:08:20 +0800
Subject: [PATCH 08/10] test with_global

---
 tests/test_models/test_roi_extractor.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/test_models/test_roi_extractor.py b/tests/test_models/test_roi_extractor.py
index d37655b9a1..414b7fdaa9 100644
--- a/tests/test_models/test_roi_extractor.py
+++ b/tests/test_models/test_roi_extractor.py
@@ -38,3 +38,16 @@ def test_single_roi_extractor3d():
     feat = (torch.randn([4, 64, 8, 16, 16]), torch.randn([4, 32, 16, 16, 16]))
     with pytest.raises(AssertionError):
         extracted = roi_extractor(feat, rois)
+
+    feat = torch.randn([4, 64, 8, 16, 16])
+    roi_extractor = SingleRoIExtractor3D(
+        roi_layer_type='RoIAlign',
+        featmap_stride=16,
+        output_size=8,
+        sampling_ratio=0,
+        pool_mode='avg',
+        aligned=True,
+        with_temporal_pool=True,
+        with_global=True)
+    extracted = roi_extractor(feat, rois)
+    assert extracted.shape == (4, 128, 1, 8, 8)
From 2b11402736eb6b6f48fd49dbd08e8716af1033d4 Mon Sep 17 00:00:00 2001
From: kenny
Date: Tue, 22 Dec 2020 19:17:27 +0800
Subject: [PATCH 09/10] update README

---
 configs/detection/ava/README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/configs/detection/ava/README.md b/configs/detection/ava/README.md
index 17b63df4c9..7a7b75c470 100644
--- a/configs/detection/ava/README.md
+++ b/configs/detection/ava/README.md
@@ -38,7 +38,7 @@
 | [slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet101 | 8x8 | 8x2 | short-side 256 | 24.6 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201217-1c9b4117.pth) |
 | [slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | OmniSource | ResNet101 | 8x8 | 8x2 | short-side 256 | 25.9 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201217-16378594.pth) |
 | [slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 24.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217-6e7c704d.pth) |
-| [slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | | | |
+| [slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222-f4d209c9.pth) |
 | [slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.5 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217-ae225e97.pth) |

 - Notes:
@@ -46,6 +46,7 @@

 1. The **gpus** indicates the number of gpu we used to get the checkpoint. According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu.
+2. **Context** indicates that both the RoI feature and the global pooled feature are used for classification, which leads to around 1% mAP improvement in general.

 For more details on data preparation, you can refer to AVA in [Data Preparation](/docs/data_preparation.md).
ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | | | | +| [slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222-f4d209c9.pth) | | [slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.5 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217-ae225e97.pth) | - Notes: @@ -46,6 +46,7 @@ 1. The **gpus** indicates the number of gpu we used to get the checkpoint. According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU, e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu. +2. **Context** indicates that using both RoI feature and global pooled feature for classification, which leads to around 1% mAP improvement in general. For more details on data preparation, you can refer to AVA in [Data Preparation](/docs/data_preparation.md). From c3df36e0c9f5699884b473de231e6c8624dc2079 Mon Sep 17 00:00:00 2001 From: kenny Date: Tue, 22 Dec 2020 19:21:44 +0800 Subject: [PATCH 10/10] update changelog --- docs/changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.md b/docs/changelog.md index 3e56586252..b75f622435 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -15,6 +15,7 @@ - Add markdown lint in pre-commit hook ([#255](https://github.com/open-mmlab/mmaction2/pull/225)) - Use title case in modelzoo statistics. ([#456](https://github.com/open-mmlab/mmaction2/pull/456)) - Add FAQ documents for easy troubleshooting. ([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439)) +- Support Spatio-Temporal Action Detection with context ([#471](https://github.com/open-mmlab/mmaction2/pull/471)) **Bug and Typo Fixes**