diff --git a/mmaction/datasets/pipelines/loading.py b/mmaction/datasets/pipelines/loading.py
index 79a6656c9c..7365eb7054 100644
--- a/mmaction/datasets/pipelines/loading.py
+++ b/mmaction/datasets/pipelines/loading.py
@@ -99,6 +99,8 @@ class SampleFrames:
         start_index (None): This argument is deprecated and moved to dataset
             class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc),
             see this: https://github.com/open-mmlab/mmaction2/pull/89.
+        keep_tail_frames (bool): Whether to keep tail frames when sampling.
+            Default: False.
     """
 
     def __init__(self,
@@ -109,7 +111,8 @@ def __init__(self,
                  twice_sample=False,
                  out_of_bound_opt='loop',
                  test_mode=False,
-                 start_index=None):
+                 start_index=None,
+                 keep_tail_frames=False):
 
         self.clip_len = clip_len
         self.frame_interval = frame_interval
@@ -118,6 +121,7 @@ def __init__(self,
         self.twice_sample = twice_sample
         self.out_of_bound_opt = out_of_bound_opt
         self.test_mode = test_mode
+        self.keep_tail_frames = keep_tail_frames
 
         assert self.out_of_bound_opt in ['loop', 'repeat_last']
         if start_index is not None:
@@ -140,21 +144,32 @@ def _get_train_clips(self, num_frames):
             np.ndarray: Sampled frame indices in train mode.
         """
         ori_clip_len = self.clip_len * self.frame_interval
-        avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips
-        if avg_interval > 0:
-            base_offsets = np.arange(self.num_clips) * avg_interval
-            clip_offsets = base_offsets + np.random.randint(
-                avg_interval, size=self.num_clips)
-        elif num_frames > max(self.num_clips, ori_clip_len):
-            clip_offsets = np.sort(
-                np.random.randint(
-                    num_frames - ori_clip_len + 1, size=self.num_clips))
-        elif avg_interval == 0:
-            ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
-            clip_offsets = np.around(np.arange(self.num_clips) * ratio)
+        if self.keep_tail_frames:
+            avg_interval = (num_frames - ori_clip_len + 1) / float(
+                self.num_clips)
+            if num_frames > ori_clip_len - 1:
+                base_offsets = np.arange(self.num_clips) * avg_interval
+                clip_offsets = (base_offsets + np.random.uniform(
+                    0, avg_interval, self.num_clips)).astype(np.int)
+            else:
+                clip_offsets = np.zeros((self.num_clips, ), dtype=np.int)
         else:
-            clip_offsets = np.zeros((self.num_clips, ), dtype=np.int)
+            avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips
+
+            if avg_interval > 0:
+                base_offsets = np.arange(self.num_clips) * avg_interval
+                clip_offsets = base_offsets + np.random.randint(
+                    avg_interval, size=self.num_clips)
+            elif num_frames > max(self.num_clips, ori_clip_len):
+                clip_offsets = np.sort(
+                    np.random.randint(
+                        num_frames - ori_clip_len + 1, size=self.num_clips))
+            elif avg_interval == 0:
+                ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
+                clip_offsets = np.around(np.arange(self.num_clips) * ratio)
+            else:
+                clip_offsets = np.zeros((self.num_clips, ), dtype=np.int)
 
         return clip_offsets
 
     def _get_test_clips(self, num_frames):
@@ -333,21 +348,11 @@ class DenseSampleFrames(SampleFrames):
     """
 
     def __init__(self,
-                 clip_len,
-                 frame_interval=1,
-                 num_clips=1,
+                 *args,
                  sample_range=64,
                  num_sample_positions=10,
-                 temporal_jitter=False,
-                 out_of_bound_opt='loop',
-                 test_mode=False):
-        super().__init__(
-            clip_len,
-            frame_interval,
-            num_clips,
-            temporal_jitter,
-            out_of_bound_opt=out_of_bound_opt,
-            test_mode=test_mode)
+                 **kwargs):
+        super().__init__(*args, **kwargs)
         self.sample_range = sample_range
         self.num_sample_positions = num_sample_positions
 
diff --git a/tests/test_data/test_pipelines/test_loadings/test_sampling.py b/tests/test_data/test_pipelines/test_loadings/test_sampling.py
index 4e47424f74..2cd7a60116 100644
--- a/tests/test_data/test_pipelines/test_loadings/test_sampling.py
+++ b/tests/test_data/test_pipelines/test_loadings/test_sampling.py
@@ -26,6 +26,15 @@ def test_sample_frames(self):
             clip_len=3, frame_interval=1, num_clips=5, start_index=1)
         SampleFrames(**config)
 
+        # Sample Frame with tail Frames
+        video_result = copy.deepcopy(self.video_results)
+        frame_result = copy.deepcopy(self.frame_results)
+        config = dict(
+            clip_len=3, frame_interval=1, num_clips=5, keep_tail_frames=True)
+        sample_frames = SampleFrames(**config)
+        sample_frames(video_result)
+        sample_frames(frame_result)
+
         # Sample Frame with no temporal_jitter
         # clip_len=3, frame_interval=1, num_clips=5
         video_result = copy.deepcopy(self.video_results)
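For reference, here is a minimal standalone sketch of what the new `keep_tail_frames` branch of `_get_train_clips` computes. The helper name `tail_aware_offsets`, its `seed` argument, and the example numbers are illustrative only and are not part of the patch; the sketch also uses the builtin `int` instead of the `np.int` alias that appears in the diff.

    import numpy as np

    def tail_aware_offsets(num_frames, num_clips, clip_len, frame_interval=1,
                           seed=None):
        """Sketch of the ``keep_tail_frames`` branch of ``_get_train_clips``."""
        rng = np.random.RandomState(seed)
        ori_clip_len = clip_len * frame_interval
        # Fractional interval instead of integer floor division, so the last
        # base offset lands near the end of the video and tail frames stay
        # reachable.
        avg_interval = (num_frames - ori_clip_len + 1) / float(num_clips)
        if num_frames > ori_clip_len - 1:
            base_offsets = np.arange(num_clips) * avg_interval
            offsets = (base_offsets +
                       rng.uniform(0, avg_interval, num_clips)).astype(int)
        else:
            offsets = np.zeros((num_clips, ), dtype=int)
        return offsets

    # e.g. 8 clips of length 1 from a 100-frame video: offsets spread across
    # the full range, including the last frames.
    print(tail_aware_offsets(num_frames=100, num_clips=8, clip_len=1, seed=0))

With the original integer `avg_interval` path, floor division can leave the last few frames of a video unreachable as clip offsets; the fractional interval plus uniform jitter covers the whole range, which is what the added test exercises with `keep_tail_frames=True`.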