From f0175bcfc758fa5a67003917476def63fd36debe Mon Sep 17 00:00:00 2001
From: Jingwei Zhang
Date: Mon, 17 Oct 2022 11:42:54 +0800
Subject: [PATCH] [Enhance] Show instance statistics before and after through pipeline (#1863)

* add instance statistics before and after through pipeline
* add docstring
* support showing cat-wise instance statistics
* show all statistics of the dataset
* small fix
* polish code
* show table
* small fix
* rename some variables
---
 mmdet3d/datasets/det3d_dataset.py        | 72 +++++++++++++++++++++++-
 mmdet3d/datasets/transforms/dbsampler.py |  4 +-
 2 files changed, 71 insertions(+), 5 deletions(-)

diff --git a/mmdet3d/datasets/det3d_dataset.py b/mmdet3d/datasets/det3d_dataset.py
index 8a77ab3d7..8901aa1e4 100644
--- a/mmdet3d/datasets/det3d_dataset.py
+++ b/mmdet3d/datasets/det3d_dataset.py
@@ -5,7 +5,10 @@
 
 import mmengine
 import numpy as np
+import torch
 from mmengine.dataset import BaseDataset
+from mmengine.logging import print_log
+from terminaltables import AsciiTable
 
 from mmdet3d.datasets import DATASETS
 from mmdet3d.structures import get_box_type
@@ -58,6 +61,9 @@ class Det3DDataset(BaseDataset):
             which can be used in Evaluator. Defaults to True.
         file_client_args (dict, optional): Configuration of file client.
             Defaults to dict(backend='disk').
+        show_ins_var (bool, optional): For debug purpose. Whether to show
+            variation of the number of instances before and after through
+            pipeline. Defaults to False.
     """
 
     def __init__(self,
@@ -73,6 +79,7 @@ def __init__(self,
                  test_mode: bool = False,
                  load_eval_anns=True,
                  file_client_args: dict = dict(backend='disk'),
+                 show_ins_var: bool = False,
                  **kwargs) -> None:
         # init file client
         self.file_client = mmengine.FileClient(**file_client_args)
@@ -105,6 +112,8 @@ def __init__(self,
             for label_idx, name in enumerate(metainfo['CLASSES']):
                 ori_label = self.METAINFO['CLASSES'].index(name)
                 self.label_mapping[ori_label] = label_idx
+
+            self.num_ins_per_cat = {name: 0 for name in metainfo['CLASSES']}
         else:
             self.label_mapping = {
                 i: i
@@ -112,6 +121,11 @@ def __init__(self,
             }
             self.label_mapping[-1] = -1
 
+            self.num_ins_per_cat = {
+                name: 0
+                for name in self.METAINFO['CLASSES']
+            }
+
         super().__init__(
             ann_file=ann_file,
             metainfo=metainfo,
@@ -125,7 +139,22 @@ def __init__(self,
         self.metainfo['box_type_3d'] = box_type_3d
         self.metainfo['label_mapping'] = self.label_mapping
 
-    def _remove_dontcare(self, ann_info: dict) -> dict:
+        # used for showing variation of the number of instances before and
+        # after through the pipeline
+        self.show_ins_var = show_ins_var
+
+        # show statistics of this dataset
+        print_log('-' * 30, 'current')
+        print_log(f'The length of the dataset: {len(self)}', 'current')
+        content_show = [['category', 'number']]
+        for cat_name, num in self.num_ins_per_cat.items():
+            content_show.append([cat_name, num])
+        table = AsciiTable(content_show)
+        print_log(
+            f'The number of instances per category in the dataset:\n{table.table}',  # noqa: E501
+            'current')
+
+    def _remove_dontcare(self, ann_info):
         """Remove annotations that do not need to be cared.
 
         -1 indicate dontcare in MMDet3d.
@@ -223,6 +252,11 @@ def parse_ann_info(self, info: dict) -> Optional[dict]:
                 ann_info[mapped_ann_name] = temp_anns
 
             ann_info['instances'] = info['instances']
+
+        for label in ann_info['gt_labels_3d']:
+            cat_name = self.metainfo['CLASSES'][label]
+            self.num_ins_per_cat[cat_name] += 1
+
         return ann_info
 
     def parse_data_info(self, info: dict) -> dict:
@@ -291,6 +325,31 @@ def parse_data_info(self, info: dict) -> dict:
 
         return info
 
+    def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
+        """Show variation of the number of instances before and after through
+        the pipeline.
+
+        Args:
+            old_labels (np.ndarray): The labels before through the pipeline.
+            new_labels (torch.Tensor): The labels after through the pipeline.
+        """
+        ori_num_per_cat = dict()
+        for label in old_labels:
+            cat_name = self.metainfo['CLASSES'][label]
+            ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1
+        new_num_per_cat = dict()
+        for label in new_labels:
+            cat_name = self.metainfo['CLASSES'][label]
+            new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1
+        content_show = [['category', 'new number', 'ori number']]
+        for cat_name, num in ori_num_per_cat.items():
+            new_num = new_num_per_cat.get(cat_name, 0)
+            content_show.append([cat_name, new_num, num])
+        table = AsciiTable(content_show)
+        print_log(
+            'The number of instances per category after and before '
+            f'through pipeline:\n{table.table}', 'current')
+
     def prepare_data(self, index: int) -> Optional[dict]:
         """Data preparation for both training and testing stage.
@@ -302,10 +361,10 @@ def prepare_data(self, index: int) -> Optional[dict]:
         Returns:
             dict | None: Data dict of the corresponding index.
         """
-        input_dict = self.get_data_info(index)
+        ori_input_dict = self.get_data_info(index)
 
         # deepcopy here to avoid inplace modification in pipeline.
-        input_dict = copy.deepcopy(input_dict)
+        input_dict = copy.deepcopy(ori_input_dict)
 
         # box_type_3d (str): 3D box type.
         input_dict['box_type_3d'] = self.box_type_3d
@@ -318,12 +377,19 @@ def prepare_data(self, index: int) -> Optional[dict]:
             return None
 
         example = self.pipeline(input_dict)
+
         if not self.test_mode and self.filter_empty_gt:
             # after pipeline drop the example with empty annotations
             # return None to random another in `__getitem__`
             if example is None or len(
                     example['data_samples'].gt_instances_3d.labels_3d) == 0:
                 return None
+
+        if self.show_ins_var:
+            self._show_ins_var(
+                ori_input_dict['ann_info']['gt_labels_3d'],
+                example['data_samples'].gt_instances_3d.labels_3d)
+
         return example
 
     def get_cat_ids(self, idx: int) -> List[int]:
diff --git a/mmdet3d/datasets/transforms/dbsampler.py b/mmdet3d/datasets/transforms/dbsampler.py
index 4f24a666e..44e8ae05b 100644
--- a/mmdet3d/datasets/transforms/dbsampler.py
+++ b/mmdet3d/datasets/transforms/dbsampler.py
@@ -133,12 +133,12 @@ def __init__(
         from mmengine.logging import MMLogger
         logger: MMLogger = MMLogger.get_current_instance()
         for k, v in db_infos.items():
-            logger.info(f'load {len(v)} {k} database infos')
+            logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
 
         for prep_func, val in prepare.items():
             db_infos = getattr(self, prep_func)(db_infos, val)
         logger.info('After filter database:')
         for k, v in db_infos.items():
-            logger.info(f'load {len(v)} {k} database infos')
+            logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
 
         self.db_infos = db_infos
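
Usage sketch (not part of the diff above): a minimal, hypothetical config fragment showing how the new `show_ins_var` flag could be switched on for a dataset that inherits from `Det3DDataset`. The dataset type, data paths, pipeline, and class list below are illustrative placeholders, not values prescribed by this change.

    # Hypothetical training-dataset config; any dataset built on Det3DDataset
    # (e.g. KittiDataset) forwards `show_ins_var` to the __init__ added here.
    train_pipeline = [
        dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
        dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
        dict(type='Pack3DDetInputs', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
    ]
    train_dataloader = dict(
        batch_size=4,
        dataset=dict(
            type='KittiDataset',
            data_root='data/kitti/',           # placeholder path
            ann_file='kitti_infos_train.pkl',  # placeholder info file
            metainfo=dict(CLASSES=['Pedestrian', 'Cyclist', 'Car']),
            pipeline=train_pipeline,
            # print an AsciiTable comparing per-category instance counts
            # before and after the pipeline for every prepared sample
            show_ins_var=True))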