AVA+Context (#471)
* resolve comments

* update changelog

* add with_global option

* fix bug

* add contiguous

* update

* add config w. context

* test with_global

* update README

* update changelog
kennymckormick authored Dec 24, 2020
1 parent cb02092 commit dc0c82b
Showing 5 changed files with 210 additions and 3 deletions.
2 changes: 2 additions & 0 deletions configs/detection/ava/README.md
@@ -38,13 +38,15 @@
| [slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet101 | 8x8 | 8x2 | short-side 256 | 24.6 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb_20201217-1c9b4117.pth) |
| [slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb](/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py) | RGB | OmniSource | ResNet101 | 8x8 | 8x2 | short-side 256 | 25.9 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201127.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201217-16378594.pth) |
| [slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 24.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201217-6e7c704d.pth) |
| [slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb](/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.4 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb_20201222-f4d209c9.pth) |
| [slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb](/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py) | RGB | Kinetics-400 | ResNet50 | 32x2 | 8x2 | short-side 256 | 25.5 | [log](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.log) | [json](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217.json) | [ckpt](https://download.openmmlab.com/mmaction/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb_20201217-ae225e97.pth) |

- Notes:

1. The **gpus** column indicates the number of GPUs we used to obtain the checkpoint.
According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you may set the learning rate proportional to the total batch size if you use a different number of GPUs or videos per GPU,
e.g., lr=0.01 for 4 GPUs x 2 videos/GPU and lr=0.08 for 16 GPUs x 4 videos/GPU (see the sketch below this list).
2. **Context** indicates that both the RoI feature and the globally pooled feature are used for classification, which generally brings around 1% mAP improvement.
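
A minimal sketch of the Linear Scaling Rule above (a hypothetical helper, not part of MMAction2; it assumes the reference setting lr=0.01 at 4 GPUs x 2 videos/GPU, i.e., a total batch of 8):

    # lr grows linearly with the total batch size (videos per step)
    def scaled_lr(num_gpus, videos_per_gpu, base_lr=0.01, base_batch=8):
        return base_lr * (num_gpus * videos_per_gpu) / base_batch

    assert scaled_lr(4, 2) == 0.01   # reference setting from the note above
    assert scaled_lr(16, 4) == 0.08  # 64 videos per step -> 8x the base lr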

For more details on data preparation, you can refer to AVA in [Data Preparation](/docs/data_preparation.md).

175 changes: 175 additions & 0 deletions configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py
@@ -0,0 +1,175 @@
# model setting
model = dict(
    type='FastRCNN',
    backbone=dict(
        type='ResNet3dSlowFast',
        pretrained=None,
        resample_rate=8,
        speed_ratio=8,
        channel_ratio=8,
        slow_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=True,
            conv1_kernel=(1, 7, 7),
            dilations=(1, 1, 1, 1),
            conv1_stride_t=1,
            pool1_stride_t=1,
            inflate=(0, 0, 1, 1),
            spatial_strides=(1, 2, 2, 1)),
        fast_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=False,
            base_channels=8,
            conv1_kernel=(5, 7, 7),
            conv1_stride_t=1,
            pool1_stride_t=1,
            spatial_strides=(1, 2, 2, 1))),
    roi_head=dict(
        type='AVARoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor3D',
            roi_layer_type='RoIAlign',
            output_size=8,
            with_temporal_pool=True,
            with_global=True),
        bbox_head=dict(
            type='BBoxHeadAVA',
            in_channels=4608,
            num_classes=81,
            multilabel=True,
            dropout_ratio=0.5)))
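# Note on `in_channels=4608` above (an inference, not a comment from the
# original file): SlowFast-R50 outputs 2048 (slow) + 256 (fast) = 2304 fused
# channels, and `with_global=True` concatenates a same-sized global feature
# to each RoI feature, doubling the head input to 4608.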

train_cfg = dict(
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssignerAVA',
            pos_iou_thr=0.9,
            neg_iou_thr=0.9,
            min_pos_iou=0.9),
        sampler=dict(
            type='RandomSampler',
            num=32,
            pos_fraction=1,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=1.0,
        debug=False))
test_cfg = dict(rcnn=dict(action_thr=0.00))
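# action_thr=0.00 keeps all predicted action scores at test time, so no
# detection is filtered out before the AVA mAP evaluation.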

dataset_type = 'AVADataset'
data_root = 'data/ava/rawframes'
anno_root = 'data/ava/annotations'

ann_file_train = f'{anno_root}/ava_train_v2.1.csv'
ann_file_val = f'{anno_root}/ava_val_v2.1.csv'

exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'
exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'

label_file = f'{anno_root}/ava_action_list_v2.1_for_activitynet_2018.pbtxt'

proposal_file_train = (f'{anno_root}/ava_dense_proposals_train.FAIR.'
                       'recall_93.9.pkl')
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
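# Standard ImageNet mean/std, given in RGB order (to_bgr=False keeps RGB).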

train_pipeline = [
    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
    dict(type='RawFrameDecode'),
    dict(type='RandomRescale', scale_range=(256, 320)),
    dict(type='RandomCrop', size=256),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
    dict(
        type='ToDataContainer',
        fields=[
            dict(key=['proposals', 'gt_bboxes', 'gt_labels'], stack=False)
        ]),
    dict(
        type='Collect',
        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],
        meta_keys=['scores', 'entity_ids'])
]
# Testing is done without any cropping / flipping
val_pipeline = [
    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals']),
    dict(type='ToDataContainer', fields=[dict(key='proposals', stack=False)]),
    dict(
        type='Collect',
        keys=['img', 'proposals'],
        meta_keys=['scores', 'img_shape'],
        nested=True)
]

data = dict(
    videos_per_gpu=9,
    workers_per_gpu=4,
    val_dataloader=dict(videos_per_gpu=1),
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        exclude_file=exclude_file_train,
        pipeline=train_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_train,
        person_det_score_thr=0.9,
        data_prefix=data_root),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        exclude_file=exclude_file_val,
        pipeline=val_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_val,
        person_det_score_thr=0.9,
        data_prefix=data_root))
data['test'] = data['val']

optimizer = dict(type='SGD', lr=0.1125, momentum=0.9, weight_decay=0.00001)
# this lr is used for 8 gpus
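# If you change the total batch size, rescale the lr proportionally
# (see the Linear Scaling Rule note in the README).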

optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy

lr_config = dict(
    policy='step',
    step=[10, 15],
    warmup='linear',
    warmup_by_epoch=True,
    warmup_iters=5,
    warmup_ratio=0.1)
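# With warmup_by_epoch=True, warmup_iters counts epochs: linear warmup over
# the first 5 epochs, then step decay at epochs 10 and 15.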
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1)
log_config = dict(
    interval=20, hooks=[
        dict(type='TextLoggerHook'),
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = ('./work_dirs/ava/'
            'slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
load_from = ('https://download.openmmlab.com/mmaction/recognition/slowfast/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb_20200704-bcde7ed7.pth')
resume_from = None
find_unused_parameters = False
1 change: 1 addition & 0 deletions docs/changelog.md
@@ -16,6 +16,7 @@
- Speed up confusion matrix calculation ([#465](https://github.com/open-mmlab/mmaction2/pull/465))
- Use title case in modelzoo statistics ([#456](https://github.com/open-mmlab/mmaction2/pull/456))
- Add FAQ documents for easy troubleshooting. ([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439))
- Support Spatio-Temporal Action Detection with context ([#471](https://github.com/open-mmlab/mmaction2/pull/471))

**Bug and Typo Fixes**

22 changes: 19 additions & 3 deletions mmaction/models/roi_extractors/single_straight3d.py
@@ -32,6 +32,8 @@ class SingleRoIExtractor3D(nn.Module):
            Default: True.
        with_temporal_pool (bool): if True, avgpool the temporal dim.
            Default: True.
+        with_global (bool): if True, concatenate the RoI feature with the
+            global feature. Default: False.

    Note that sampling_ratio, pool_mode, aligned only apply when roi_layer_type
    is set as RoIAlign.
@@ -44,7 +46,8 @@ def __init__(self,
                 sampling_ratio=0,
                 pool_mode='avg',
                 aligned=True,
-                with_temporal_pool=True):
+                with_temporal_pool=True,
+                with_global=False):
        super().__init__()
        self.roi_layer_type = roi_layer_type
        assert self.roi_layer_type in ['RoIPool', 'RoIAlign']
@@ -57,6 +60,8 @@ def __init__(self,
        self.aligned = aligned

        self.with_temporal_pool = with_temporal_pool
+        self.with_global = with_global
+
        if self.roi_layer_type == 'RoIPool':
            self.roi_layer = RoIPool(self.output_size, self.spatial_scale)
        else:
@@ -66,10 +71,12 @@ def __init__(self,
                sampling_ratio=self.sampling_ratio,
                pool_mode=self.pool_mode,
                aligned=self.aligned)
+        self.global_pool = nn.AdaptiveAvgPool2d(self.output_size)

    def init_weights(self):
        pass

+    # The shape of feat is N, C, T, H, W
    def forward(self, feat, rois):
        if not isinstance(feat, tuple):
            feat = (feat, )
@@ -78,10 +85,19 @@ def forward(self, feat, rois):
        if self.with_temporal_pool:
            feat = [torch.mean(x, 2, keepdim=True) for x in feat]
        feat = torch.cat(feat, axis=1)
+
        roi_feats = []
        for t in range(feat.size(2)):
-            frame_feat = feat[:, :, t, :, :].contiguous()
-            roi_feats.append(self.roi_layer(frame_feat, rois))
+            frame_feat = feat[:, :, t].contiguous()
+            roi_feat = self.roi_layer(frame_feat, rois)
+            if self.with_global:
+                global_feat = self.global_pool(frame_feat.contiguous())
+                inds = rois[:, 0].type(torch.int64)
+                global_feat = global_feat[inds]
+                roi_feat = torch.cat([roi_feat, global_feat], dim=1)
+                roi_feat = roi_feat.contiguous()
+            roi_feats.append(roi_feat)
+
        return torch.stack(roi_feats, dim=2)
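
A minimal standalone sketch of the context-concatenation step above (an illustration under assumed shapes, not the library code): each RoI's first column is its clip index within the batch, which is used to gather the matching clip's globally pooled feature before concatenation.

import torch
import torch.nn as nn

N, C, H, W = 2, 64, 16, 16          # assumed: 2 clips, 64-channel features
output_size = 8
frame_feat = torch.randn(N, C, H, W)
# rois format: (clip_index, x1, y1, x2, y2); both boxes here belong to clip 0
rois = torch.tensor([[0., 1., 1., 8., 8.], [0., 2., 2., 9., 9.]])

global_pool = nn.AdaptiveAvgPool2d(output_size)
global_feat = global_pool(frame_feat)   # (N, C, 8, 8), one per clip
inds = rois[:, 0].type(torch.int64)     # clip index of each RoI
global_feat = global_feat[inds]         # (num_rois, C, 8, 8)
# An RoI feature of shape (num_rois, C, 8, 8) is concatenated with this,
# giving (num_rois, 2 * C, 8, 8) -- which is why the head sees doubled
# channels when with_global=True.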


13 changes: 13 additions & 0 deletions tests/test_models/test_roi_extractor.py
@@ -38,3 +38,16 @@ def test_single_roi_extractor3d():
    feat = (torch.randn([4, 64, 8, 16, 16]), torch.randn([4, 32, 16, 16, 16]))
    with pytest.raises(AssertionError):
        extracted = roi_extractor(feat, rois)

    feat = torch.randn([4, 64, 8, 16, 16])
    roi_extractor = SingleRoIExtractor3D(
        roi_layer_type='RoIAlign',
        featmap_stride=16,
        output_size=8,
        sampling_ratio=0,
        pool_mode='avg',
        aligned=True,
        with_temporal_pool=True,
        with_global=True)
    extracted = roi_extractor(feat, rois)
    assert extracted.shape == (4, 128, 1, 8, 8)
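    # Shape check (informal note): the 64 input channels are doubled to 128 by
    # the with_global concatenation, and with_temporal_pool=True averages the
    # temporal dim to 1 -- hence (4, 128, 1, 8, 8).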
