diff --git a/configs/recognition/tsn/tsn_r50_inference_1x1x3_100e_kinetics400_rgb.py b/configs/recognition/tsn/tsn_r50_inference_1x1x3_100e_kinetics400_rgb.py
new file mode 100644
index 0000000000..206da6b1cb
--- /dev/null
+++ b/configs/recognition/tsn/tsn_r50_inference_1x1x3_100e_kinetics400_rgb.py
@@ -0,0 +1,46 @@
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='ResNet',
+        pretrained='torchvision://resnet50',
+        depth=50,
+        norm_eval=False),
+    cls_head=dict(
+        type='TSNHead',
+        num_classes=400,
+        in_channels=2048,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.4,
+        init_std=0.01))
+# model training and testing settings
+test_cfg = dict(average_clips=None)
+# dataset settings
+dataset_type = 'RawframeDataset'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
+test_pipeline = [
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=25,
+        test_mode=True),
+    dict(type='FrameSelector'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='Flip', flip_ratio=0),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='Collect', keys=['imgs'], meta_keys=[]),
+    dict(type='ToTensor', keys=['imgs'])
+]
+data = dict(
+    videos_per_gpu=1,
+    workers_per_gpu=2,
+    test=dict(
+        type=dataset_type,
+        ann_file=None,
+        data_prefix=None,
+        pipeline=test_pipeline))
diff --git a/demo/demo.py b/demo/demo.py
index 08c6e3d7f3..9d29115134
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -9,8 +9,13 @@ def parse_args():
     parser = argparse.ArgumentParser(description='MMAction2 demo')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument('video', help='video file')
+    parser.add_argument('video', help='video file or rawframes directory')
     parser.add_argument('label', help='label file')
+    parser.add_argument(
+        '--use-frames',
+        default=False,
+        action='store_true',
+        help='whether to use rawframes as input')
     parser.add_argument(
         '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
     args = parser.parse_args()
@@ -22,9 +27,14 @@ def main():
     # assign the desired device.
     device = torch.device(args.device)
     # build the recognizer from a config file and checkpoint file
-    model = init_recognizer(args.config, args.checkpoint, device=device)
-    # test a single video
-    results = inference_recognizer(model, args.video, args.label)
+    model = init_recognizer(
+        args.config,
+        args.checkpoint,
+        device=device,
+        use_frames=args.use_frames)
+    # test a single video or rawframes of a single video
+    results = inference_recognizer(
+        model, args.video, args.label, use_frames=args.use_frames)
 
     print('The top-5 labels with corresponding scores are:')
     for result in results:
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 111ea6545d..c03b6b7b55
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -108,19 +108,25 @@ Examples:
 python demo/demo.py configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py checkpoints/tsn.pth demo/demo.mp4
 ```
 
-### High-level APIs for testing a video.
+### High-level APIs for testing a video and rawframes.
 
-Here is an example of building the model and test a given video.
+Here is an example of building the model and testing a given video.
 ```python
+import torch
+
 from mmaction.apis import init_recognizer, inference_recognizer
 
 config_file = 'configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py'
 # download the checkpoint from model zoo and put it in `checkpoints/`
 checkpoint_file = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'
 
+# assign the desired device.
+device = 'cuda:0'  # or 'cpu'
+device = torch.device(device)
+
 # build the model from a config file and a checkpoint file
-model = init_recognizer(config_file, checkpoint_file, device='cpu')
+model = init_recognizer(config_file, checkpoint_file, device=device)
 
 # test a single video and show the result:
 video = 'demo/demo.mp4'
@@ -133,6 +139,48 @@ for result in results:
     print(f'{result[0]}: ', result[1])
 ```
 
+Here is an example of building the model and testing a given rawframes directory.
+
+```python
+import torch
+
+from mmaction.apis import init_recognizer, inference_recognizer
+
+config_file = 'configs/recognition/tsn/tsn_r50_inference_1x1x3_100e_kinetics400_rgb.py'
+# download the checkpoint from model zoo and put it in `checkpoints/`
+checkpoint_file = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'
+
+# assign the desired device.
+device = 'cuda:0'  # or 'cpu'
+device = torch.device(device)
+
+# build the model from a config file and a checkpoint file
+model = init_recognizer(config_file, checkpoint_file, device=device, use_frames=True)
+
+# test the rawframe directory of a single video and show the result:
+video = 'SOME_DIR_PATH/'
+labels = 'demo/label_map.txt'
+results = inference_recognizer(model, video, labels, use_frames=True)
+
+# show the results
+print('The top-5 labels with corresponding scores are:')
+for result in results:
+    print(f'{result[0]}: ', result[1])
+```
+
+**Note**: We define `data_prefix` in the config files and set it to None by default in our provided inference configs.
+If `data_prefix` is not None, the path of the video file (or rawframe directory) to load is `osp.join(data_prefix, video)`,
+where `video` is the parameter passed to the demo scripts above.
+This detail can be found in `rawframe_dataset.py` and `video_dataset.py`. For example,
+
+* When the video (rawframes) path is `SOME_DIR_PATH/VIDEO.mp4` (`SOME_DIR_PATH/VIDEO_NAME/img_xxxxx.jpg`) and `data_prefix` is None in the config file,
+the param `video` should be `SOME_DIR_PATH/VIDEO.mp4` (`SOME_DIR_PATH/VIDEO_NAME`).
+
+* When the video (rawframes) path is `SOME_DIR_PATH/VIDEO.mp4` (`SOME_DIR_PATH/VIDEO_NAME/img_xxxxx.jpg`) and `data_prefix` is `SOME_DIR_PATH` in the config file,
+the param `video` should be `VIDEO.mp4` (`VIDEO_NAME`).
+
+* When the rawframes path is `VIDEO_NAME/img_xxxxx.jpg` and `data_prefix` is None in the config file, the param `video` should be `VIDEO_NAME`.
+
 A notebook demo can be found in [demo/demo.ipynb](/demo/demo.ipynb)
diff --git a/mmaction/apis/inference.py b/mmaction/apis/inference.py
index bf1bb916f4..b9ea7d7fad
--- a/mmaction/apis/inference.py
+++ b/mmaction/apis/inference.py
@@ -1,3 +1,5 @@
+import os
+import os.path as osp
 from operator import itemgetter
 
 import mmcv
@@ -9,16 +11,20 @@ from ..models import build_recognizer
 
 
-def init_recognizer(config, checkpoint=None, device='cuda:0'):
+def init_recognizer(config,
+                    checkpoint=None,
+                    device='cuda:0',
+                    use_frames=False):
     """Initialize a recognizer from config file.
 
     Args:
-        config (str or :obj:`mmcv.Config`): Config file path or the config
+        config (str | :obj:`mmcv.Config`): Config file path or the config
             object.
         checkpoint (str, optional): Checkpoint path. If left as None, the
             model will not load any weights. Default: None.
-        device (str or :obj:`torch.device`): the desired device of returned
+        device (str | :obj:`torch.device`): The desired device of returned
             tensor. Default: 'cuda:0'.
+        use_frames (bool): Whether to use rawframes as input. Default: False.
 
     Returns:
         nn.Module: The constructed recognizer.
@@ -28,6 +34,13 @@ def init_recognizer(config, checkpoint=None, device='cuda:0'):
     elif not isinstance(config, mmcv.Config):
         raise TypeError('config must be a filename or Config object, '
                         f'but got {type(config)}')
+    if ((use_frames and config.dataset_type != 'RawframeDataset')
+            or (not use_frames and config.dataset_type != 'VideoDataset')):
+        input_type = 'rawframes' if use_frames else 'video'
+        raise RuntimeError('input data type should be consistent with the '
+                           'dataset type in config, but got input type '
+                           f"'{input_type}' and dataset type "
+                           f"'{config.dataset_type}'")
     # pretrained model is unnecessary since we directly load checkpoint later
     config.model.backbone.pretrained = None
     model = build_recognizer(config.model, test_cfg=config.test_cfg)
@@ -39,17 +52,30 @@ def init_recognizer(config, checkpoint=None, device='cuda:0'):
     return model
 
 
-def inference_recognizer(model, video_path, label_path):
+def inference_recognizer(model, video_path, label_path, use_frames=False):
     """Inference a video with the detector.
 
     Args:
         model (nn.Module): The loaded recognizer.
-        video_path (str): The video file path.
+        video_path (str): The video file path or the rawframes directory
+            path. If ``use_frames`` is set to True, it should be a rawframes
+            directory path. Otherwise, it should be a video file path.
         label_path (str): The label file path.
+        use_frames (bool): Whether to use rawframes as input. Default: False.
 
     Returns:
         dict[tuple(str, float)]: Top-5 recognition result dict.
     """
+    if not osp.exists(video_path):
+        raise RuntimeError(f"'{video_path}' is missing")
+
+    if osp.isfile(video_path) and use_frames:
+        raise RuntimeError(
+            f"'{video_path}' is a video file, not a rawframe directory")
+    elif osp.isdir(video_path) and not use_frames:
+        raise RuntimeError(
+            f"'{video_path}' is a rawframe directory, not a video file")
+
     cfg = model.cfg
     device = next(model.parameters()).device  # model device
     # construct label map
@@ -59,7 +85,17 @@ def inference_recognizer(model, video_path, label_path):
     test_pipeline = cfg.data.test.pipeline
     test_pipeline = Compose(test_pipeline)
     # prepare data
-    data = dict(filename=video_path, label=label, modality='RGB')
+    if use_frames:
+        filename_tmpl = cfg.data.test.get('filename_tmpl', 'img_{:05}.jpg')
+        modality = cfg.data.test.get('modality', 'RGB')
+        data = dict(
+            frame_dir=video_path,
+            total_frames=len(os.listdir(video_path)),
+            label=-1,
+            filename_tmpl=filename_tmpl,
+            modality=modality)
+    else:
+        data = dict(filename=video_path, label=-1, modality='RGB')
     data = test_pipeline(data)
     data = collate([data], samples_per_gpu=1)
     if next(model.parameters()).is_cuda:
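One caveat in the rawframes branch above: `total_frames` is taken as `len(os.listdir(video_path))`, so any extra file in the frame directory (e.g. optical-flow images or hidden files) would inflate the frame count. Below is a minimal sketch of a stricter counter that only tallies frames matching `filename_tmpl`; the helper `count_frames` and the example path are hypothetical illustrations, not part of this PR, and it assumes frames are numbered consecutively from 1 as in the default template.

```python
import os.path as osp


def count_frames(frame_dir, filename_tmpl='img_{:05}.jpg'):
    """Count consecutive frames that match the naming template.

    Hypothetical helper: unlike len(os.listdir(frame_dir)), this ignores
    extra files (e.g. flow_x_00001.jpg) that would otherwise inflate
    total_frames. Assumes frames start at index 1 with no gaps.
    """
    total = 0
    while osp.exists(osp.join(frame_dir, filename_tmpl.format(total + 1))):
        total += 1
    return total


# e.g. a directory holding img_00001.jpg ... img_00300.jpg -> 300
print(count_frames('SOME_DIR_PATH/VIDEO_NAME'))
```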