-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Support rawframe inference in demo and polish scripts and docs #59
Changes from 3 commits
459d645
a488464
a9a9d9b
4191f26
800f16a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# model settings | ||
model = dict( | ||
type='Recognizer2D', | ||
backbone=dict( | ||
type='ResNet', | ||
pretrained='torchvision://resnet50', | ||
depth=50, | ||
norm_eval=False), | ||
cls_head=dict( | ||
type='TSNHead', | ||
num_classes=400, | ||
in_channels=2048, | ||
spatial_type='avg', | ||
consensus=dict(type='AvgConsensus', dim=1), | ||
dropout_ratio=0.4, | ||
init_std=0.01)) | ||
# model training and testing settings | ||
test_cfg = dict(average_clips=None) | ||
# dataset settings | ||
dataset_type = 'RawframeDataset' | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False) | ||
test_pipeline = [ | ||
dict( | ||
type='SampleFrames', | ||
clip_len=1, | ||
frame_interval=1, | ||
num_clips=25, | ||
test_mode=True), | ||
dict(type='FrameSelector'), | ||
dict(type='Resize', scale=(-1, 256)), | ||
dict(type='ThreeCrop', crop_size=256), | ||
dict(type='Flip', flip_ratio=0), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='FormatShape', input_format='NCHW'), | ||
dict(type='Collect', keys=['imgs'], meta_keys=[]), | ||
dict(type='ToTensor', keys=['imgs']) | ||
] | ||
data = dict( | ||
videos_per_gpu=1, | ||
workers_per_gpu=2, | ||
test=dict( | ||
type=dataset_type, | ||
ann_file=None, | ||
data_prefix=None, | ||
pipeline=test_pipeline)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,21 +108,30 @@ Examples: | |
python demo/demo.py configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.p checkpoints/tsn.pth demo/demo.mp4 | ||
``` | ||
|
||
### High-level APIs for testing a video. | ||
### High-level APIs for testing a video and rawframes. | ||
|
||
Here is an example of building the model and test a given video. | ||
Here is an example of building the model and testing a given video. | ||
|
||
```python | ||
import torch | ||
|
||
from mmaction.apis import init_recognizer, inference_recognizer | ||
|
||
config_file = 'configs/recognition/tsn/tsn_r50_video_inference_1x1x3_100e_kinetics400_rgb.py' | ||
# download the checkpoint from model zoo and put it in `checkpoints/` | ||
checkpoint_file = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth' | ||
|
||
# assign the desired device. | ||
device = 'cuda:0' # or 'cpu' | ||
device = torch.device(device) | ||
|
||
# build the model from a config file and a checkpoint file | ||
model = init_recognizer(config_file, checkpoint_file, device='cpu') | ||
model = init_recognizer(config_file, checkpoint_file, device=device) | ||
|
||
# test a single video and show the result: | ||
# The real path for the video to get in this scripts will be `osp.path(data_prefix, video)`. | ||
# This detail can be found in `rawframe_dataset.py` and `video_dataset.py`. | ||
# `data_prefix` is set in config files and we set it None for our provided inference configs. | ||
video = 'demo/demo.mp4' | ||
labels = 'demo/label_map.txt' | ||
results = inference_recognizer(model, video, labels) | ||
|
@@ -133,6 +142,38 @@ for result in results: | |
print(f'{result[0]}: ', result[1]) | ||
``` | ||
|
||
Here is an example of building the model and testing with a given rawframes directory. | ||
|
||
```python | ||
import torch | ||
|
||
from mmaction.apis import init_recognizer, inference_recognizer | ||
|
||
config_file = 'configs/recognition/tsn/tsn_r50_inference_1x1x3_100e_kinetics400_rgb.py' | ||
# download the checkpoint from model zoo and put it in `checkpoints/` | ||
checkpoint_file = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth' | ||
|
||
# assign the desired device. | ||
device = 'cuda:0' # or 'cpu' | ||
device = torch.device(device) | ||
|
||
# build the model from a config file and a checkpoint file | ||
model = init_recognizer(config_file, checkpoint_file, device=device, use_frames=True) | ||
|
||
# test rawframe directory of a single video and show the result: | ||
# The real path for the rawframe directory to get in this scripts will be `osp.path(data_prefix, video)`. | ||
# This detail can be found in `rawframe_dataset.py` and `video_dataset.py`. | ||
# `data_prefix` is set in config files and we set it None for our provided inference configs. | ||
video = 'SOME_DIR_PATH/' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the structure of this dir, "SOME_DIR_PATH/video_name/img_xxxx.jpg" or "SOME_DIR_PATH/img_xxxx.jpg"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a default, we set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this info clear to reader? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add the description in a comment to make this info clear to readers. |
||
labels = 'demo/label_map.txt' | ||
results = inference_recognizer(model, video, labels, use_frames=True) | ||
|
||
# show the results | ||
print(f'The top-5 labels with corresponding scores are:') | ||
for result in results: | ||
print(f'{result[0]}: ', result[1]) | ||
``` | ||
|
||
A notebook demo can be found in [demo/demo.ipynb](/demo/demo.ipynb) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
real path for the video
is vague, because user is not sure what thevideo
mean. does it mean the folder name of frames when using rawframe?it will be good if you provide examples, rather than giving abstract description