From 30d955c36ca8ce900dd57caed74d90b0d7b02955 Mon Sep 17 00:00:00 2001 From: Jingwei Zhang Date: Tue, 15 Nov 2022 10:54:37 +0800 Subject: [PATCH] [Fix] fix instance statistics when only detecting a single class (#2003) --- ...75_second_secfpn_8xb4-cyclic-20e_nus-3d.py | 1 - mmdet3d/datasets/det3d_dataset.py | 30 +++++++++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/configs/centerpoint/centerpoint_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py b/configs/centerpoint/centerpoint_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py index 3cbddc356..0a3143472 100644 --- a/configs/centerpoint/centerpoint_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py +++ b/configs/centerpoint/centerpoint_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py @@ -89,7 +89,6 @@ dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectNameFilter', classes=class_names), dict(type='PointShuffle'), - dict(type='DefaultFormatBundle3D', class_names=class_names), dict( type='Pack3DDetInputs', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) diff --git a/mmdet3d/datasets/det3d_dataset.py b/mmdet3d/datasets/det3d_dataset.py index 54a2903a3..7d17968ea 100644 --- a/mmdet3d/datasets/det3d_dataset.py +++ b/mmdet3d/datasets/det3d_dataset.py @@ -255,8 +255,9 @@ def parse_ann_info(self, info: dict) -> Optional[dict]: ann_info['instances'] = info['instances'] for label in ann_info['gt_labels_3d']: - cat_name = self.metainfo['classes'][label] - self.num_ins_per_cat[cat_name] += 1 + if label != -1: + cat_name = self.metainfo['classes'][label] + self.num_ins_per_cat[cat_name] += 1 return ann_info @@ -336,12 +337,16 @@ def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor): """ ori_num_per_cat = dict() for label in old_labels: - cat_name = self.metainfo['classes'][label] - ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1 + if label != -1: + cat_name = self.metainfo['classes'][label] + ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, + 0) + 1 new_num_per_cat = dict() for label in new_labels: - cat_name = self.metainfo['classes'][label] - new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1 + if label != -1: + cat_name = self.metainfo['classes'][label] + new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, + 0) + 1 content_show = [['category', 'new number', 'ori number']] for cat_name, num in ori_num_per_cat.items(): new_num = new_num_per_cat.get(cat_name, 0) @@ -387,9 +392,16 @@ def prepare_data(self, index: int) -> Optional[dict]: return None if self.show_ins_var: - self._show_ins_var( - ori_input_dict['ann_info']['gt_labels_3d'], - example['data_samples'].gt_instances_3d.labels_3d) + if 'ann_info' in ori_input_dict: + self._show_ins_var( + ori_input_dict['ann_info']['gt_labels_3d'], + example['data_samples'].gt_instances_3d.labels_3d) + else: + print_log( + "'ann_info' is not in the input dict. It's probably that " + 'the data is not in training mode', + 'current', + level=30) return example