Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add feature with_offset for rawframe dataset #48

Merged
merged 11 commits into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mmaction/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,10 @@ def __call__(self, results):
if results['frame_inds'].ndim != 1:
results['frame_inds'] = np.squeeze(results['frame_inds'])

offset = results.get('offset', 0)

for frame_idx in results['frame_inds']:
frame_idx += offset
if modality == 'RGB':
filepath = osp.join(directory, filename_tmpl.format(frame_idx))
img_bytes = self.file_client.get(filepath)
Expand Down
58 changes: 45 additions & 13 deletions mmaction/datasets/rawframe_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ class RawframeDataset(BaseDataset):
some/directory-5 295 3
some/directory-6 121 3

Example of a with_offset annotation file (clips from long videos), each
line indicates the directory to frames of a video, the index of the start
frame, total frames of the video clip and the label of a video clip, which
are split with a whitespace.


.. code-block:: txt

some/directory-1 12 163 3
some/directory-2 213 122 4
some/directory-3 100 258 5
some/directory-4 98 234 2
some/directory-5 0 295 3
some/directory-6 50 121 3


Args:
ann_file (str): Path to the annotation file.
Expand All @@ -52,6 +67,8 @@ class RawframeDataset(BaseDataset):
Default: False.
filename_tmpl (str): Template for each filename.
Default: 'img_{:05}.jpg'.
with_offset (bool): Determines whether the offset information is in
ann_file. Default: False.
multi_class (bool): Determines whether it is a multi-class
recognition dataset. Default: False.
num_classes (int): Number of classes in the dataset. Default: None.
Expand All @@ -65,36 +82,51 @@ def __init__(self,
data_prefix=None,
test_mode=False,
filename_tmpl='img_{:05}.jpg',
with_offset=False,
multi_class=False,
num_classes=None,
modality='RGB'):
self.filename_tmpl = filename_tmpl
self.with_offset = with_offset
super().__init__(ann_file, pipeline, data_prefix, test_mode,
multi_class, num_classes, modality)
self.filename_tmpl = filename_tmpl

def load_annotations(self):
"""Load annotation file to get video information."""
video_infos = []
with open(self.ann_file, 'r') as fin:
for line in fin:
line_split = line.strip().split()
video_info = {}
idx = 0
# idx for frame_dir
frame_dir = line_split[idx]
if self.data_prefix is not None:
frame_dir = osp.join(self.data_prefix, frame_dir)
video_info['frame_dir'] = frame_dir
idx += 1
kennymckormick marked this conversation as resolved.
Show resolved Hide resolved
if self.with_offset:
# idx for offset and total_frames
video_info['offset'] = int(line_split[idx])
video_info['total_frames'] = int(line_split[idx + 1])
idx += 2
else:
# idx for total_frames
video_info['total_frames'] = int(line_split[idx])
idx += 1
# idx for label[s]
label = [int(x) for x in line_split[idx:]]
dreamerlin marked this conversation as resolved.
Show resolved Hide resolved
assert len(label), f'missing label in line: {line}'
if self.multi_class:
assert self.num_classes is not None
(frame_dir, total_frames,
label) = (line_split[0], line_split[1], line_split[2:])
label = list(map(int, label))
onehot = torch.zeros(self.num_classes)
onehot[label] = 1.0
video_info['label'] = onehot
else:
frame_dir, total_frames, label = line_split
label = int(label)
if self.data_prefix is not None:
frame_dir = osp.join(self.data_prefix, frame_dir)
video_infos.append(
dict(
frame_dir=frame_dir,
total_frames=int(total_frames),
label=onehot if self.multi_class else label))
assert len(label) == 1
video_info['label'] = label[0]
video_infos.append(video_info)

return video_infos

def prepare_train_frames(self, idx):
Expand Down
2 changes: 2 additions & 0 deletions tests/data/frame_test_list_multi_label.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_imgs 5 1
test_imgs 5 3 5
2 changes: 2 additions & 0 deletions tests/data/frame_test_list_with_offset.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_imgs 2 5 127
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you defining a new dataset annotation format?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does each column mean?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add an optional format for ann_file, instead of 'frame_dir num_frame label[s]', it can be 'frame_dir start_idx num_frame label[s]'. My code is compatible with original format.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what num_frame in this context mean? the length of the video or the length of the segment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The length of the video clip (which is part of the entire untrimmed video).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need a doc page to describe all possible annotation formats we used. @dreamerlin

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can prepare a annotation_description.md in docs/ to describe all possible annotation formats.

test_imgs 2 5 127
36 changes: 36 additions & 0 deletions tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mmcv
import numpy as np
import pytest
import torch
from numpy.testing import assert_array_equal

from mmaction.datasets import (ActivityNetDataset, RawframeDataset,
Expand All @@ -22,6 +23,10 @@ def check_keys_contain(result_keys, target_keys):
def setup_class(cls):
cls.data_prefix = osp.join(osp.dirname(osp.dirname(__file__)), 'data')
cls.frame_ann_file = osp.join(cls.data_prefix, 'frame_test_list.txt')
cls.frame_ann_file_with_offset = osp.join(
cls.data_prefix, 'frame_test_list_with_offset.txt')
cls.frame_ann_file_multi_label = osp.join(
cls.data_prefix, 'frame_test_list_multi_label.txt')
cls.video_ann_file = osp.join(cls.data_prefix, 'video_test_list.txt')
cls.action_ann_file = osp.join(cls.data_prefix,
'action_test_anno.json')
Expand Down Expand Up @@ -55,6 +60,37 @@ def test_rawframe_dataset(self):
dict(frame_dir=frame_dir, total_frames=5, label=127)
] * 2

def test_rawframe_dataset_with_offset(self):
rawframe_dataset = RawframeDataset(
self.frame_ann_file_with_offset,
self.frame_pipeline,
self.data_prefix,
with_offset=True)
rawframe_infos = rawframe_dataset.video_infos
frame_dir = osp.join(self.data_prefix, 'test_imgs')
assert rawframe_infos == [
dict(frame_dir=frame_dir, offset=2, total_frames=5, label=127)
] * 2

def test_rawframe_dataset_multi_label(self):
rawframe_dataset = RawframeDataset(
self.frame_ann_file_multi_label,
self.frame_pipeline,
self.data_prefix,
multi_class=True,
num_classes=100)
rawframe_infos = rawframe_dataset.video_infos
frame_dir = osp.join(self.data_prefix, 'test_imgs')
label0 = torch.zeros(100)
label0[[1]] = 1.0
label1 = torch.zeros(100)
label1[[3, 5]] = 1.0
labels = [label0, label1]
for info, label in zip(rawframe_infos, labels):
assert info['frame_dir'] == frame_dir
assert info['total_frames'] == 5
assert torch.all(info['label'] == label)

def test_dataset_realpath(self):
dataset = RawframeDataset(self.frame_ann_file, self.frame_pipeline,
'.')
Expand Down
2 changes: 2 additions & 0 deletions tests/test_data/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ def setup_class(cls):
total_frames=cls.total_frames,
filename_tmpl=cls.filename_tmpl,
modality='RGB',
offset=0,
label=1)
cls.flow_frame_results = dict(
frame_dir=cls.img_dir,
total_frames=cls.total_frames,
filename_tmpl=cls.flow_filename_tmpl,
modality='Flow',
offset=0,
label=1)
cls.action_results = dict(
video_name='v_test1',
Expand Down