Skip to content

AVA+Context #471

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 26 commits into from
Dec 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c5f00a8
resolve comments
Oct 16, 2020
05575c1
update changelog
Oct 16, 2020
eb09070
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 19, 2020
81302a4
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 22, 2020
43a649a
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 27, 2020
755809e
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Oct 29, 2020
d478c9d
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 2, 2020
08bbc06
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 7, 2020
ff958e6
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 8, 2020
d0e192d
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 11, 2020
a52c536
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 11, 2020
81a2029
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 16, 2020
e03d2a9
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 17, 2020
2a9b57f
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 27, 2020
28001ff
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Nov 30, 2020
46cc5dd
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 1, 2020
667818a
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 18, 2020
34398a8
Merge branch 'master' of https://github.com/open-mmlab/mmaction2
Dec 18, 2020
c5c25ed
add with_global option
Dec 21, 2020
2b68957
f bug
Dec 21, 2020
921d605
+ contiguous
Dec 21, 2020
70084c9
update
Dec 21, 2020
75504f6
add config w. context
Dec 22, 2020
6779d5f
test with_global
Dec 22, 2020
2b11402
update README
Dec 22, 2020
c3df36e
update changelog
Dec 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/detection/ava/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@
| [slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet101 | 8x8 | 8x2 | short-side 256 | 24.6 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201217-1c9b4117.pth) |
| [slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | OmniSource | ResNet101 | 8x8 | 8x2 | short-side 256 | 25.9 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201217-16378594.pth) |
| [slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 24.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217-6e7c704d.pth) |
| [slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222-f4d209c9.pth) |
| [slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.5 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217-ae225e97.pth) |

- Notes:

1. The **gpus** indicates the number of gpu we used to get the checkpoint.
According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the batch size if you use different GPUs or videos per GPU,
e.g., lr=0.01 for 4 GPUs x 2 video/gpu and lr=0.08 for 16 GPUs x 4 video/gpu.
2. **Context** indicates that using both RoI feature and global pooled feature for classification, which leads to around 1% mAP improvement in general.

For more details on data preparation, you can refer to AVA in [Data Preparation](/docs/data_preparation.md).

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# model setting
model = dict(
type='FastRCNN',
backbone=dict(
type='ResNet3dSlowFast',
pretrained=None,
resample_rate=8,
speed_ratio=8,
channel_ratio=8,
slow_pathway=dict(
type='resnet3d',
depth=50,
pretrained=None,
lateral=True,
conv1_kernel=(1, 7, 7),
dilations=(1, 1, 1, 1),
conv1_stride_t=1,
pool1_stride_t=1,
inflate=(0, 0, 1, 1),
spatial_strides=(1, 2, 2, 1)),
fast_pathway=dict(
type='resnet3d',
depth=50,
pretrained=None,
lateral=False,
base_channels=8,
conv1_kernel=(5, 7, 7),
conv1_stride_t=1,
pool1_stride_t=1,
spatial_strides=(1, 2, 2, 1))),
roi_head=dict(
type='AVARoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor3D',
roi_layer_type='RoIAlign',
output_size=8,
with_temporal_pool=True,
with_global=True),
bbox_head=dict(
type='BBoxHeadAVA',
in_channels=4608,
num_classes=81,
multilabel=True,
dropout_ratio=0.5)))

train_cfg = dict(
rcnn=dict(
assigner=dict(
type='MaxIoUAssignerAVA',
pos_iou_thr=0.9,
neg_iou_thr=0.9,
min_pos_iou=0.9),
sampler=dict(
type='RandomSampler',
num=32,
pos_fraction=1,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=1.0,
debug=False))
test_cfg = dict(rcnn=dict(action_thr=0.00))

dataset_type = 'AVADataset'
data_root = 'data/ava/rawframes'
anno_root = 'data/ava/annotations'

ann_file_train = f'{anno_root}/ava_train_v2.1.csv'
ann_file_val = f'{anno_root}/ava_val_v2.1.csv'

exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'
exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'

label_file = f'{anno_root}/ava_action_list_v2.1_for_activitynet_2018.pbtxt'

proposal_file_train = (f'{anno_root}/ava_dense_proposals_train.FAIR.'
'recall_93.9.pkl')
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)

train_pipeline = [
dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
dict(type='RawFrameDecode'),
dict(type='RandomRescale', scale_range=(256, 320)),
dict(type='RandomCrop', size=256),
dict(type='Flip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW', collapse=True),
# Rename is needed to use mmdet detectors
dict(type='Rename', mapping=dict(imgs='img')),
dict(type='ToTensor', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
dict(
type='ToDataContainer',
fields=[
dict(key=['proposals', 'gt_bboxes', 'gt_labels'], stack=False)
]),
dict(
type='Collect',
keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],
meta_keys=['scores', 'entity_ids'])
]
# The testing is w/o. any cropping / flipping
val_pipeline = [
dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
dict(type='RawFrameDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='Normalize', **img_norm_cfg),
dict(type='FormatShape', input_format='NCTHW', collapse=True),
# Rename is needed to use mmdet detectors
dict(type='Rename', mapping=dict(imgs='img')),
dict(type='ToTensor', keys=['img', 'proposals']),
dict(type='ToDataContainer', fields=[dict(key='proposals', stack=False)]),
dict(
type='Collect',
keys=['img', 'proposals'],
meta_keys=['scores', 'img_shape'],
nested=True)
]

data = dict(
videos_per_gpu=9,
workers_per_gpu=4,
val_dataloader=dict(videos_per_gpu=1),
test_dataloader=dict(videos_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=ann_file_train,
exclude_file=exclude_file_train,
pipeline=train_pipeline,
label_file=label_file,
proposal_file=proposal_file_train,
person_det_score_thr=0.9,
data_prefix=data_root),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
exclude_file=exclude_file_val,
pipeline=val_pipeline,
label_file=label_file,
proposal_file=proposal_file_val,
person_det_score_thr=0.9,
data_prefix=data_root))
data['test'] = data['val']

optimizer = dict(type='SGD', lr=0.1125, momentum=0.9, weight_decay=0.00001)
# this lr is used for 8 gpus

optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy

lr_config = dict(
policy='step',
step=[10, 15],
warmup='linear',
warmup_by_epoch=True,
warmup_iters=5,
warmup_ratio=0.1)
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1)
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = ('./work_dirs/ava/'
'slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
load_from = ('https://download.openmmlab.com/mmaction/recognition/slowfast/'
'slowfast_r50_4x16x1_256e_kinetics400_rgb/'
'slowfast_r50_4x16x1_256e_kinetics400_rgb_20200704-bcde7ed7.pth')
resume_from = None
find_unused_parameters = False
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- Add markdown lint in pre-commit hook ([#255](https://github.com/open-mmlab/mmaction2/pull/225))
- Use title case in modelzoo statistics. ([#456](https://github.com/open-mmlab/mmaction2/pull/456))
- Add FAQ documents for easy troubleshooting. ([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439))
- Support Spatio-Temporal Action Detection with context ([#471](https://github.com/open-mmlab/mmaction2/pull/471))

**Bug and Typo Fixes**

Expand Down
22 changes: 19 additions & 3 deletions mmaction/models/roi_extractors/single_straight3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class SingleRoIExtractor3D(nn.Module):
Default: True.
with_temporal_pool (bool): if True, avgpool the temporal dim.
Default: True.
with_global (bool): if True, concatenate the RoI feature with global
feature. Default: False.

Note that sampling_ratio, pool_mode, aligned only apply when roi_layer_type
is set as RoIAlign.
Expand All @@ -40,7 +42,8 @@ def __init__(self,
sampling_ratio=0,
pool_mode='avg',
aligned=True,
with_temporal_pool=True):
with_temporal_pool=True,
with_global=False):
super().__init__()
self.roi_layer_type = roi_layer_type
assert self.roi_layer_type in ['RoIPool', 'RoIAlign']
Expand All @@ -53,6 +56,8 @@ def __init__(self,
self.aligned = aligned

self.with_temporal_pool = with_temporal_pool
self.with_global = with_global

if self.roi_layer_type == 'RoIPool':
self.roi_layer = RoIPool(self.output_size, self.spatial_scale)
else:
Expand All @@ -62,10 +67,12 @@ def __init__(self,
sampling_ratio=self.sampling_ratio,
pool_mode=self.pool_mode,
aligned=self.aligned)
self.global_pool = nn.AdaptiveAvgPool2d(self.output_size)

def init_weights(self):
pass

# The shape of feat is N, C, T, H, W
def forward(self, feat, rois):
if not isinstance(feat, tuple):
feat = (feat, )
Expand All @@ -74,10 +81,19 @@ def forward(self, feat, rois):
if self.with_temporal_pool:
feat = [torch.mean(x, 2, keepdim=True) for x in feat]
feat = torch.cat(feat, axis=1)

roi_feats = []
for t in range(feat.size(2)):
frame_feat = feat[:, :, t, :, :].contiguous()
roi_feats.append(self.roi_layer(frame_feat, rois))
frame_feat = feat[:, :, t].contiguous()
roi_feat = self.roi_layer(frame_feat, rois)
if self.with_global:
global_feat = self.global_pool(frame_feat.contiguous())
inds = rois[:, 0].type(torch.int64)
global_feat = global_feat[inds]
roi_feat = torch.cat([roi_feat, global_feat], dim=1)
roi_feat = roi_feat.contiguous()
roi_feats.append(roi_feat)

return torch.stack(roi_feats, dim=2)


Expand Down
13 changes: 13 additions & 0 deletions tests/test_models/test_roi_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,16 @@ def test_single_roi_extractor3d():
feat = (torch.randn([4, 64, 8, 16, 16]), torch.randn([4, 32, 16, 16, 16]))
with pytest.raises(AssertionError):
extracted = roi_extractor(feat, rois)

feat = torch.randn([4, 64, 8, 16, 16])
roi_extractor = SingleRoIExtractor3D(
roi_layer_type='RoIAlign',
featmap_stride=16,
output_size=8,
sampling_ratio=0,
pool_mode='avg',
aligned=True,
with_temporal_pool=True,
with_global=True)
extracted = roi_extractor(feat, rois)
assert extracted.shape == (4, 128, 1, 8, 8)