From 525fdb78e369c6448f5224c8bd57b1e33a5f1898 Mon Sep 17 00:00:00 2001
From: Tau <674106399@qq.com>
Date: Mon, 16 Jan 2023 11:14:28 +0800
Subject: [PATCH] [Fix] fix bug in CombinedDataset and remove RepeatDataset
 (#1930)

---
 mmpose/datasets/__init__.py         |  4 +--
 mmpose/datasets/dataset_wrappers.py | 51 +++++++++--------------------
 2 files changed, 17 insertions(+), 38 deletions(-)

diff --git a/mmpose/datasets/__init__.py b/mmpose/datasets/__init__.py
index 042c1b7e28..001155172b 100644
--- a/mmpose/datasets/__init__.py
+++ b/mmpose/datasets/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import build_dataset
-from .dataset_wrappers import CombinedDataset, RepeatDataset
+from .dataset_wrappers import CombinedDataset
 from .datasets import *  # noqa
 from .transforms import *  # noqa
 
-__all__ = ['build_dataset', 'RepeatDataset', 'CombinedDataset']
+__all__ = ['build_dataset', 'CombinedDataset']
diff --git a/mmpose/datasets/dataset_wrappers.py b/mmpose/datasets/dataset_wrappers.py
index 2997615afd..3836100ed2 100644
--- a/mmpose/datasets/dataset_wrappers.py
+++ b/mmpose/datasets/dataset_wrappers.py
@@ -10,35 +10,6 @@
 from .datasets.utils import parse_pose_metainfo
 
 
-@DATASETS.register_module()
-class RepeatDataset:
-    """A wrapper of repeated dataset.
-
-    The length of repeated dataset will be `times` larger than the original
-    dataset. This is useful when the data loading time is long but the dataset
-    is small. Using RepeatDataset can reduce the data loading time between
-    epochs.
-
-    Args:
-        dataset (:obj:`Dataset`): The dataset to be repeated.
-        times (int): Repeat times.
-    """
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        """Get data."""
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        """Length after repetition."""
-        return self.times * self._ori_len
-
-
 @DATASETS.register_module()
 class CombinedDataset(BaseDataset):
     """A wrapper of combined dataset.
@@ -113,10 +84,7 @@ def prepare_data(self, idx: int) -> Any:
             Any: Depends on ``self.pipeline``.
         """
 
-        subset_idx, sample_idx = self._get_subset_index(idx)
-        # Get data sample from the subset
-        data_info = self.datasets[subset_idx].get_data_info(sample_idx)
-        data_info = self.datasets[subset_idx].pipeline(data_info)
+        data_info = self.get_data_info(idx)
 
         # Add metainfo items that are required in the pipeline and the model
         metainfo_keys = [
@@ -125,13 +93,24 @@ def prepare_data(self, idx: int) -> Any:
         ]
 
         for key in metainfo_keys:
-            assert key not in data_info, (
-                f'"{key}" is a reserved key for `metainfo`, but already '
-                'exists in the `data_info`.')
             data_info[key] = deepcopy(self._metainfo[key])
 
         return self.pipeline(data_info)
 
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``CombinedDataset``.
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        subset_idx, sample_idx = self._get_subset_index(idx)
+        # Get data sample processed by ``subset.pipeline``
+        data_info = self.datasets[subset_idx][sample_idx]
+
+        return data_info
+
     def full_init(self):
         """Fully initialize all sub datasets."""
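
Note on the change: before this patch, ``prepare_data`` fetched the raw
annotation and ran the subset pipeline by hand; the fix routes the lookup
through a new ``CombinedDataset.get_data_info``, which indexes the subset
directly so that the subset's own ``__getitem__`` (and therefore its
pipeline) runs first. The assert that rejected metainfo keys already
present in ``data_info`` is dropped; those keys are now simply overwritten
with the combined dataset's metainfo. The sketch below is a minimal,
self-contained illustration of the global-to-local index routing that
``_get_subset_index`` is expected to perform; ``TinyCombined`` and its
internals are hypothetical stand-ins, not mmpose code.

from bisect import bisect_right


class TinyCombined:
    """Illustrative stand-in for CombinedDataset's index routing.

    Maps a global sample index to a (subset_idx, sample_idx) pair via
    cumulative subset lengths, which is the behaviour that
    ``_get_subset_index`` is assumed to provide in the patched class.
    """

    def __init__(self, datasets):
        self.datasets = datasets
        # Cumulative end offsets, e.g. subset lengths [3, 2] -> [3, 5].
        self._offsets = []
        total = 0
        for dataset in datasets:
            total += len(dataset)
            self._offsets.append(total)

    def __len__(self):
        return self._offsets[-1] if self._offsets else 0

    def _get_subset_index(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f'index {idx} out of range for length {len(self)}')
        # First subset whose cumulative end offset exceeds idx.
        subset_idx = bisect_right(self._offsets, idx)
        start = self._offsets[subset_idx - 1] if subset_idx else 0
        return subset_idx, idx - start

    def get_data_info(self, idx):
        # Mirrors the patched method: indexing the subset directly invokes
        # its __getitem__, so the subset's own pipeline runs first.
        subset_idx, sample_idx = self._get_subset_index(idx)
        return self.datasets[subset_idx][sample_idx]


# Plain lists stand in for pipeline-equipped sub-datasets here.
combined = TinyCombined([['a0', 'a1', 'a2'], ['b0', 'b1']])
assert len(combined) == 5
assert combined._get_subset_index(4) == (1, 1)
assert combined.get_data_info(4) == 'b1'

If repeat-style oversampling is still needed after this removal, MMEngine's
own wrapper (``mmengine.dataset.RepeatDataset``) can presumably be used in
its place, since the deleted class duplicated that behaviour.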