Support rawframe inference in demo and polish scripts and docs #59

Merged
merged 5 commits into from
Jul 26, 2020
@@ -0,0 +1,46 @@
# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='ResNet',
        pretrained='torchvision://resnet50',
        depth=50,
        norm_eval=False),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01))
# model training and testing settings
test_cfg = dict(average_clips=None)
# dataset settings
dataset_type = 'RawframeDataset'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='FrameSelector'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='Collect', keys=['imgs'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
data = dict(
    videos_per_gpu=1,
    workers_per_gpu=2,
    test=dict(
        type=dataset_type,
        ann_file=None,
        data_prefix=None,
        pipeline=test_pipeline))
18 changes: 14 additions & 4 deletions demo/demo.py
@@ -9,8 +9,13 @@ def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('video', help='video file')
    parser.add_argument('video', help='video file or rawframes directory')
    parser.add_argument('label', help='label file')
    parser.add_argument(
        '--use-frames',
        default=False,
        action='store_true',
        help='whether to use rawframes as input')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    args = parser.parse_args()
@@ -22,9 +27,14 @@ def main():
    # assign the desired device.
    device = torch.device(args.device)
    # build the recognizer from a config file and checkpoint file
    model = init_recognizer(args.config, args.checkpoint, device=device)
    # test a single video
    results = inference_recognizer(model, args.video, args.label)
    model = init_recognizer(
        args.config,
        args.checkpoint,
        device=device,
        use_frames=args.use_frames)
    # test a single video or rawframes of a single video
    results = inference_recognizer(
        model, args.video, args.label, use_frames=args.use_frames)

    print('The top-5 labels with corresponding scores are:')
    for result in results:
54 changes: 51 additions & 3 deletions docs/getting_started.md
@@ -108,19 +108,25 @@ Examples:
python demo/demo.py configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py checkpoints/tsn.pth demo/demo.mp4
```

### High-level APIs for testing a video.
### High-level APIs for testing a video and rawframes.

Here is an example of building the model and test a given video.
Here is an example of building the model and testing a given video.

```python
import torch

from mmaction.apis import init_recognizer, inference_recognizer

config_file = 'configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'

# assign the desired device.
device = 'cuda:0' # or 'cpu'
device = torch.device(device)

# build the model from a config file and a checkpoint file
model = init_recognizer(config_file, checkpoint_file, device='cpu')
model = init_recognizer(config_file, checkpoint_file, device=device)

# test a single video and show the result:
video = 'demo/demo.mp4'
@@ -133,6 +139,48 @@ for result in results:
    print(f'{result[0]}: ', result[1])
```

Here is an example of building the model and testing with a given rawframes directory.

```python
import torch

from mmaction.apis import init_recognizer, inference_recognizer

config_file = 'configs/recognition/tsn/tsn_r50_inference_1x1x3_100e_kinetics400_rgb.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'

# assign the desired device.
device = 'cuda:0' # or 'cpu'
device = torch.device(device)

# build the model from a config file and a checkpoint file
model = init_recognizer(config_file, checkpoint_file, device=device, use_frames=True)

# test rawframe directory of a single video and show the result:
video = 'SOME_DIR_PATH/'
Contributor: what's the structure of this dir, "SOME_DIR_PATH/video_name/img_xxxx.jpg" or "SOME_DIR_PATH/img_xxxx.jpg"?

Collaborator (Author):
  • When filename is "SOME_DIR_PATH/video_name/img_xxxx.jpg" and data_prefix is None in the config file, video = SOME_DIR_PATH/video_name.
  • When filename is "SOME_DIR_PATH/video_name/img_xxxx.jpg" and data_prefix is "SOME_DIR_PATH" in the config file, video = video_name.
  • When filename is "SOME_DIR_PATH/img_xxxx.jpg" and data_prefix is None in the config file, video = SOME_DIR_PATH.

Collaborator (Author): As a default, we set data_prefix=None in the config file.

Contributor: is this info clear to the reader?

Collaborator (Author): Add the description in a comment to make this info clear to readers.


labels = 'demo/label_map.txt'
results = inference_recognizer(model, video, labels, use_frames=True)

# show the results
print(f'The top-5 labels with corresponding scores are:')
for result in results:
    print(f'{result[0]}: ', result[1])
```

**Note**: We define `data_prefix` in config files and set it to None by default for our provided inference configs.
If `data_prefix` is not None, the path of the video file (or rawframe directory) to load will be `osp.join(data_prefix, video)`.
Here, `video` is the param in the demo scripts above.
This detail can be found in `rawframe_dataset.py` and `video_dataset.py` (a short sketch of this lookup is shown after the examples below). For example,

* When video (rawframes) path is `SOME_DIR_PATH/VIDEO.mp4` (`SOME_DIR_PATH/VIDEO_NAME/img_xxxxx.jpg`), and `data_prefix` is None in the config file,
the param `video` should be `SOME_DIR_PATH/VIDEO.mp4` (`SOME_DIR_PATH/VIDEO_NAME`).

* When video (rawframes) path is `SOME_DIR_PATH/VIDEO.mp4` (`SOME_DIR_PATH/VIDEO_NAME/img_xxxxx.jpg`), and `data_prefix` is `SOME_DIR_PATH` in the config file,
the param `video` should be `VIDEO.mp4` (`VIDEO_NAME`).

* When rawframes path is `VIDEO_NAME/img_xxxxx.jpg`, and `data_prefix` is None in the config file, the param `video` should be `VIDEO_NAME`.
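
Below is a minimal sketch of this lookup, assuming the behaviour described in the note above; the helper `resolve_path` is hypothetical and not part of the MMAction2 API, it only illustrates how `data_prefix` and the `video` param are combined (the actual logic lives in `rawframe_dataset.py` and `video_dataset.py`):

```python
import os.path as osp


def resolve_path(video, data_prefix=None):
    """Hypothetical helper: when ``data_prefix`` is None the ``video``
    argument is used as-is, otherwise the two are joined with ``osp.join``."""
    if data_prefix is None:
        return video
    return osp.join(data_prefix, video)


# data_prefix is None (the default in the provided inference configs):
print(resolve_path('SOME_DIR_PATH/VIDEO_NAME'))  # SOME_DIR_PATH/VIDEO_NAME
# data_prefix set to 'SOME_DIR_PATH' in the config:
print(resolve_path('VIDEO_NAME', data_prefix='SOME_DIR_PATH'))  # SOME_DIR_PATH/VIDEO_NAME
```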

A notebook demo can be found in [demo/demo.ipynb](/demo/demo.ipynb)


48 changes: 42 additions & 6 deletions mmaction/apis/inference.py
@@ -1,3 +1,5 @@
import os
import os.path as osp
from operator import itemgetter

import mmcv
@@ -9,16 +11,20 @@
from ..models import build_recognizer


def init_recognizer(config, checkpoint=None, device='cuda:0'):
def init_recognizer(config,
                    checkpoint=None,
                    device='cuda:0',
                    use_frames=False):
"""Initialize a recognizer from config file.

Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
config (str | :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights. Default: None.
device (str or :obj:`torch.device`): the desired device of returned
device (str | :obj:`torch.device`): The desired device of returned
tensor. Default: 'cuda:0'.
use_frames (bool): Whether to use rawframes as input. Default:False.

Returns:
nn.Module: The constructed recognizer.
@@ -28,6 +34,13 @@ def init_recognizer(config, checkpoint=None, device='cuda:0'):
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if ((use_frames and config.dataset_type != 'RawframeDataset')
            or (not use_frames and config.dataset_type != 'VideoDataset')):
        input_type = 'rawframes' if use_frames else 'video'
        raise RuntimeError('input data type should be consistent with the '
                           f'dataset type in config, but got input type '
                           f"'{input_type}' and dataset type "
                           f"'{config.dataset_type}'")
    # pretrained model is unnecessary since we directly load checkpoint later
    config.model.backbone.pretrained = None
    model = build_recognizer(config.model, test_cfg=config.test_cfg)
@@ -39,17 +52,30 @@
    return model


def inference_recognizer(model, video_path, label_path):
def inference_recognizer(model, video_path, label_path, use_frames=False):
    """Inference a video with the recognizer.

    Args:
        model (nn.Module): The loaded recognizer.
        video_path (str): The video file path.
        video_path (str): The video file path or the rawframes directory path.
            If ``use_frames`` is set to True, it should be rawframes directory
            path. Otherwise, it should be video file path.
        label_path (str): The label file path.
        use_frames (bool): Whether to use rawframes as input. Default: False.

    Returns:
        dict[tuple(str, float)]: Top-5 recognition result dict.
    """
    if not osp.exists(video_path):
        raise RuntimeError(f"'{video_path}' is missing")

    if osp.isfile(video_path) and use_frames:
        raise RuntimeError(
            f"'{video_path}' is a video file, not a rawframe directory")
    elif osp.isdir(video_path) and not use_frames:
        raise RuntimeError(
            f"'{video_path}' is a rawframe directory, not a video file")

    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # construct label map
@@ -59,7 +85,17 @@ def inference_recognizer(model, video_path, label_path):
    test_pipeline = cfg.data.test.pipeline
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(filename=video_path, label=label, modality='RGB')
    if use_frames:
        filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
        modality = cfg.data.test.get('modality', 'RGB')
        data = dict(
            frame_dir=video_path,
            total_frames=len(os.listdir(video_path)),
            label=-1,
            filename_tmpl=filename_tmpl,
            modality=modality)
    else:
        data = dict(filename=video_path, label=-1, modality='RGB')
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda: