Skip to content

Commit

Permalink
[Fix] fix instance statistics when only detecting a single class (ope…
Browse files Browse the repository at this point in the history
  • Loading branch information
JingweiZhang12 authored and ZwwWayne committed Dec 3, 2022
1 parent cc09580 commit 30d955c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
Expand Down
30 changes: 21 additions & 9 deletions mmdet3d/datasets/det3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,9 @@ def parse_ann_info(self, info: dict) -> Optional[dict]:
ann_info['instances'] = info['instances']

for label in ann_info['gt_labels_3d']:
cat_name = self.metainfo['classes'][label]
self.num_ins_per_cat[cat_name] += 1
if label != -1:
cat_name = self.metainfo['classes'][label]
self.num_ins_per_cat[cat_name] += 1

return ann_info

Expand Down Expand Up @@ -336,12 +337,16 @@ def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
"""
ori_num_per_cat = dict()
for label in old_labels:
cat_name = self.metainfo['classes'][label]
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1
if label != -1:
cat_name = self.metainfo['classes'][label]
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name,
0) + 1
new_num_per_cat = dict()
for label in new_labels:
cat_name = self.metainfo['classes'][label]
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1
if label != -1:
cat_name = self.metainfo['classes'][label]
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name,
0) + 1
content_show = [['category', 'new number', 'ori number']]
for cat_name, num in ori_num_per_cat.items():
new_num = new_num_per_cat.get(cat_name, 0)
Expand Down Expand Up @@ -387,9 +392,16 @@ def prepare_data(self, index: int) -> Optional[dict]:
return None

if self.show_ins_var:
self._show_ins_var(
ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)
if 'ann_info' in ori_input_dict:
self._show_ins_var(
ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)
else:
print_log(
"'ann_info' is not in the input dict. It's probably that "
'the data is not in training mode',
'current',
level=30)

return example

Expand Down

0 comments on commit 30d955c

Please # to comment.