Skip to content

Commit

Permalink
[Enhance] Show instance statistics before and after through pipeline (#…
Browse files Browse the repository at this point in the history
…1863)

* add instance statistics before and after through pipeline

* add docstring

* support showing cat-wise instance statistics

* show all statistics of the dataset

* small fix

* polish code

* show table

* small fix

* rename some varibles
  • Loading branch information
JingweiZhang12 authored Oct 17, 2022
1 parent 42199e7 commit f0175bc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
72 changes: 69 additions & 3 deletions mmdet3d/datasets/det3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

import mmengine
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.logging import print_log
from terminaltables import AsciiTable

from mmdet3d.datasets import DATASETS
from mmdet3d.structures import get_box_type
Expand Down Expand Up @@ -58,6 +61,9 @@ class Det3DDataset(BaseDataset):
which can be used in Evaluator. Defaults to True.
file_client_args (dict, optional): Configuration of file client.
Defaults to dict(backend='disk').
show_ins_var (bool, optional): For debug purpose. Whether to show
variation of the number of instances before and after through
pipeline. Defaults to False.
"""

def __init__(self,
Expand All @@ -73,6 +79,7 @@ def __init__(self,
test_mode: bool = False,
load_eval_anns=True,
file_client_args: dict = dict(backend='disk'),
show_ins_var: bool = False,
**kwargs) -> None:
# init file client
self.file_client = mmengine.FileClient(**file_client_args)
Expand Down Expand Up @@ -105,13 +112,20 @@ def __init__(self,
for label_idx, name in enumerate(metainfo['CLASSES']):
ori_label = self.METAINFO['CLASSES'].index(name)
self.label_mapping[ori_label] = label_idx

self.num_ins_per_cat = {name: 0 for name in metainfo['CLASSES']}
else:
self.label_mapping = {
i: i
for i in range(len(self.METAINFO['CLASSES']))
}
self.label_mapping[-1] = -1

self.num_ins_per_cat = {
name: 0
for name in self.METAINFO['CLASSES']
}

super().__init__(
ann_file=ann_file,
metainfo=metainfo,
Expand All @@ -125,7 +139,22 @@ def __init__(self,
self.metainfo['box_type_3d'] = box_type_3d
self.metainfo['label_mapping'] = self.label_mapping

def _remove_dontcare(self, ann_info: dict) -> dict:
# used for showing variation of the number of instances before and
# after through the pipeline
self.show_ins_var = show_ins_var

# show statistics of this dataset
print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current')
content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items():
content_show.append([cat_name, num])
table = AsciiTable(content_show)
print_log(
f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501
'current')

def _remove_dontcare(self, ann_info):
"""Remove annotations that do not need to be cared.
-1 indicate dontcare in MMDet3d.
Expand Down Expand Up @@ -223,6 +252,11 @@ def parse_ann_info(self, info: dict) -> Optional[dict]:

ann_info[mapped_ann_name] = temp_anns
ann_info['instances'] = info['instances']

for label in ann_info['gt_labels_3d']:
cat_name = self.metainfo['CLASSES'][label]
self.num_ins_per_cat[cat_name] += 1

return ann_info

def parse_data_info(self, info: dict) -> dict:
Expand Down Expand Up @@ -291,6 +325,31 @@ def parse_data_info(self, info: dict) -> dict:

return info

def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
"""Show variation of the number of instances before and after through
the pipeline.
Args:
old_labels (np.ndarray): The labels before through the pipeline.
new_labels (torch.Tensor): The labels after through the pipeline.
"""
ori_num_per_cat = dict()
for label in old_labels:
cat_name = self.metainfo['CLASSES'][label]
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1
new_num_per_cat = dict()
for label in new_labels:
cat_name = self.metainfo['CLASSES'][label]
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1
content_show = [['category', 'new number', 'ori number']]
for cat_name, num in ori_num_per_cat.items():
new_num = new_num_per_cat.get(cat_name, 0)
content_show.append([cat_name, new_num, num])
table = AsciiTable(content_show)
print_log(
'The number of instances per category after and before '
f'through pipeline:\n{table.table}', 'current')

def prepare_data(self, index: int) -> Optional[dict]:
"""Data preparation for both training and testing stage.
Expand All @@ -302,10 +361,10 @@ def prepare_data(self, index: int) -> Optional[dict]:
Returns:
dict | None: Data dict of the corresponding index.
"""
input_dict = self.get_data_info(index)
ori_input_dict = self.get_data_info(index)

# deepcopy here to avoid inplace modification in pipeline.
input_dict = copy.deepcopy(input_dict)
input_dict = copy.deepcopy(ori_input_dict)

# box_type_3d (str): 3D box type.
input_dict['box_type_3d'] = self.box_type_3d
Expand All @@ -318,12 +377,19 @@ def prepare_data(self, index: int) -> Optional[dict]:
return None

example = self.pipeline(input_dict)

if not self.test_mode and self.filter_empty_gt:
# after pipeline drop the example with empty annotations
# return None to random another in `__getitem__`
if example is None or len(
example['data_samples'].gt_instances_3d.labels_3d) == 0:
return None

if self.show_ins_var:
self._show_ins_var(
ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)

return example

def get_cat_ids(self, idx: int) -> List[int]:
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/datasets/transforms/dbsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ def __init__(
from mmengine.logging import MMLogger
logger: MMLogger = MMLogger.get_current_instance()
for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos')
logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
for prep_func, val in prepare.items():
db_infos = getattr(self, prep_func)(db_infos, val)
logger.info('After filter database:')
for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos')
logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')

self.db_infos = db_infos

Expand Down

0 comments on commit f0175bc

Please # to comment.