[Fix] fix bug in CombinedDataset and remove RepeatDataset (#1930)
Tau-J authored Jan 16, 2023
1 parent a30556a commit 525fdb7
Showing 2 changed files with 17 additions and 38 deletions.
4 changes: 2 additions & 2 deletions mmpose/datasets/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import build_dataset
-from .dataset_wrappers import CombinedDataset, RepeatDataset
+from .dataset_wrappers import CombinedDataset
 from .datasets import *  # noqa
 from .transforms import *  # noqa
 
-__all__ = ['build_dataset', 'RepeatDataset', 'CombinedDataset']
+__all__ = ['build_dataset', 'CombinedDataset']
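Note: with this commit, `RepeatDataset` disappears from mmpose's public API. Downstream configs that still need repeat-style sampling can presumably switch to MMEngine's own `RepeatDataset` wrapper (shipped in `mmengine.dataset`), which removes the need for a duplicate copy in mmpose. A minimal config sketch under that assumption; the inner dataset fields are illustrative placeholders, not part of this commit:

# Hypothetical config snippet: repeat a small dataset 5x per epoch.
# Assumes MMEngine's RepeatDataset is resolvable through the DATASETS
# registry; all inner dataset settings below are placeholders.
train_dataloader = dict(
    batch_size=32,
    dataset=dict(
        type='RepeatDataset',  # MMEngine's wrapper, not the removed mmpose one
        times=5,               # length becomes 5x the wrapped dataset
        dataset=dict(
            type='CocoDataset',      # placeholder sub-dataset
            data_root='data/coco/',  # placeholder path
            ann_file='annotations/person_keypoints_train2017.json',
            pipeline=[],             # placeholder pipeline
        ),
    ),
)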
51 changes: 15 additions & 36 deletions mmpose/datasets/dataset_wrappers.py
@@ -10,35 +10,6 @@
 from .datasets.utils import parse_pose_metainfo
 
 
-@DATASETS.register_module()
-class RepeatDataset:
-    """A wrapper of repeated dataset.
-
-    The length of repeated dataset will be `times` larger than the original
-    dataset. This is useful when the data loading time is long but the dataset
-    is small. Using RepeatDataset can reduce the data loading time between
-    epochs.
-
-    Args:
-        dataset (:obj:`Dataset`): The dataset to be repeated.
-        times (int): Repeat times.
-    """
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        """Get data."""
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        """Length after repetition."""
-        return self.times * self._ori_len
-
-
 @DATASETS.register_module()
 class CombinedDataset(BaseDataset):
     """A wrapper of combined dataset.
@@ -113,10 +84,7 @@ def prepare_data(self, idx: int) -> Any:
             Any: Depends on ``self.pipeline``.
         """
 
-        subset_idx, sample_idx = self._get_subset_index(idx)
-        # Get data sample from the subset
-        data_info = self.datasets[subset_idx].get_data_info(sample_idx)
-        data_info = self.datasets[subset_idx].pipeline(data_info)
+        data_info = self.get_data_info(idx)
 
         # Add metainfo items that are required in the pipeline and the model
         metainfo_keys = [
@@ -125,13 +93,24 @@ def prepare_data(self, idx: int) -> Any:
         ]
 
         for key in metainfo_keys:
-            assert key not in data_info, (
-                f'"{key}" is a reserved key for `metainfo`, but already '
-                'exists in the `data_info`.')
             data_info[key] = deepcopy(self._metainfo[key])
 
         return self.pipeline(data_info)
 
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``CombinedDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        subset_idx, sample_idx = self._get_subset_index(idx)
+        # Get data sample processed by ``subset.pipeline``
+        data_info = self.datasets[subset_idx][sample_idx]
+
+        return data_info
+
     def full_init(self):
         """Fully initialize all sub datasets."""

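Taken together, the fix in `prepare_data` is a one-line delegation: rather than pulling the raw annotation with `get_data_info(sample_idx)` and manually re-running the subset's pipeline, the wrapper now indexes the sub-dataset directly, so each subset applies its own pipeline exactly once through its own `__getitem__`. The reserved-key assertion is dropped, presumably because a subset's pipeline may legitimately emit those metainfo keys, which the combined dataset then overwrites. The index routing itself is easiest to see in isolation; below is a minimal, self-contained sketch of the idea, where `ToyCombinedDataset` is a hypothetical stand-in rather than mmpose's actual class:

from typing import Any, List, Tuple


class ToyCombinedDataset:
    """Toy illustration of CombinedDataset-style index routing.

    A global index is mapped to (subset index, local index) by walking
    the cumulative lengths of the sub-datasets, then the sample is
    fetched from the matching sub-dataset.
    """

    def __init__(self, datasets: List[Any]):
        self.datasets = datasets
        self._lens = [len(d) for d in datasets]

    def __len__(self) -> int:
        return sum(self._lens)

    def _get_subset_index(self, idx: int) -> Tuple[int, int]:
        if idx >= len(self) or idx < -len(self):
            raise ValueError(f'index {idx} is out of range')
        idx = idx % len(self)  # support negative indices
        subset_idx = 0
        while idx >= self._lens[subset_idx]:
            idx -= self._lens[subset_idx]
            subset_idx += 1
        return subset_idx, idx

    def __getitem__(self, idx: int) -> Any:
        subset_idx, sample_idx = self._get_subset_index(idx)
        # Delegate to the sub-dataset, mirroring the fixed behavior:
        # the subset's own __getitem__ (and hence its pipeline) runs once.
        return self.datasets[subset_idx][sample_idx]


# Two toy "datasets" of lengths 3 and 2: global indices 0-4.
combined = ToyCombinedDataset([list('abc'), list('de')])
assert len(combined) == 5
assert combined[3] == 'd'  # global index 3 -> subset 1, local index 0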
