Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Feature] Refactor Waymo dataset_converter/dataset/evaluator #2836

Merged
merged 5 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,4 @@ data/sunrgbd/OFFICIAL_SUNRGBD/
# Waymo evaluation
mmdet3d/evaluation/functional/waymo_utils/compute_detection_metrics_main
mmdet3d/evaluation/functional/waymo_utils/compute_detection_let_metrics_main
mmdet3d/evaluation/functional/waymo_utils/compute_segmentation_metrics_main
17 changes: 9 additions & 8 deletions configs/_base_/datasets/waymoD5-3d-3class.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
Expand All @@ -100,7 +103,10 @@
load_dim=6,
use_dim=5,
backend_args=backend_args),
dict(type='Pack3DDetInputs', keys=['points']),
dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]

train_dataloader = dict(
Expand Down Expand Up @@ -164,12 +170,7 @@
backend_args=backend_args))

val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False)
type='WaymoMetric', waymo_bin_file='./data/waymo/waymo_format/gt.bin')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
Expand Down
15 changes: 7 additions & 8 deletions configs/_base_/datasets/waymoD5-3d-car.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]
test_pipeline = [
dict(
Expand All @@ -86,7 +87,10 @@
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
Expand Down Expand Up @@ -161,12 +165,7 @@
backend_args=backend_args))

val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
convert_kitti_format=False,
backend_args=backend_args)
type='WaymoMetric', waymo_bin_file='./data/waymo/waymo_format/gt.bin')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
Expand Down
17 changes: 8 additions & 9 deletions mmdet3d/datasets/det3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,15 @@ def __init__(self,
ori_label = self.METAINFO['classes'].index(name)
self.label_mapping[ori_label] = label_idx

self.num_ins_per_cat = {name: 0 for name in metainfo['classes']}
self.num_ins_per_cat = [0] * len(metainfo['classes'])
ZCMax marked this conversation as resolved.
Show resolved Hide resolved
else:
self.label_mapping = {
i: i
for i in range(len(self.METAINFO['classes']))
}
self.label_mapping[-1] = -1

self.num_ins_per_cat = {
name: 0
for name in self.METAINFO['classes']
}
self.num_ins_per_cat = [0] * len(self.METAINFO['classes'])

super().__init__(
ann_file=ann_file,
Expand All @@ -146,9 +143,12 @@ def __init__(self,

# show statistics of this dataset
print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current')
print_log(
f'The length of {"test" if self.test_mode else "training"} dataset: {len(self)}', # noqa: E501
'current')
content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items():
for label, num in enumerate(self.num_ins_per_cat):
cat_name = self.metainfo['classes'][label]
content_show.append([cat_name, num])
table = AsciiTable(content_show)
print_log(
Expand Down Expand Up @@ -256,8 +256,7 @@ def parse_ann_info(self, info: dict) -> Union[dict, None]:

for label in ann_info['gt_labels_3d']:
if label != -1:
cat_name = self.metainfo['classes'][label]
self.num_ins_per_cat[cat_name] += 1
self.num_ins_per_cat[label] += 1

return ann_info

Expand Down
117 changes: 80 additions & 37 deletions mmdet3d/datasets/waymo_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Callable, List, Union

import numpy as np
from mmengine import print_log
from mmengine.fileio import load

from mmdet3d.registry import DATASETS
from mmdet3d.structures import CameraInstance3DBoxes
from mmdet3d.structures import CameraInstance3DBoxes, LiDARInstance3DBoxes
from .det3d_dataset import Det3DDataset
from .kitti_dataset import KittiDataset

Expand Down Expand Up @@ -163,13 +165,10 @@
centers_2d = np.zeros((0, 2), dtype=np.float32)
depths = np.zeros((0), dtype=np.float32)

# in waymo, lidar2cam = R0_rect @ Tr_velo_to_cam
# convert gt_bboxes_3d to velodyne coordinates with `lidar2cam`
lidar2cam = np.array(info['images'][self.default_cam_key]['lidar2cam'])
gt_bboxes_3d = CameraInstance3DBoxes(
ann_info['gt_bboxes_3d']).convert_to(self.box_mode_3d,
np.linalg.inv(lidar2cam))
ann_info['gt_bboxes_3d'] = gt_bboxes_3d
if self.load_type == 'frame_based':
gt_bboxes_3d = LiDARInstance3DBoxes(ann_info['gt_bboxes_3d'])

Check warning on line 169 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L169

Added line #L169 was not covered by tests
else:
gt_bboxes_3d = CameraInstance3DBoxes(ann_info['gt_bboxes_3d'])

Check warning on line 171 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L171

Added line #L171 was not covered by tests

anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d,
Expand All @@ -182,9 +181,58 @@
return anns_results

def load_data_list(self) -> List[dict]:
"""Add the load interval."""
data_list = super().load_data_list()
data_list = data_list[::self.load_interval]
"""Add the load interval.

Returns:
list[dict]: A list of annotation.
""" # noqa: E501
# `self.ann_file` denotes the absolute annotation file path if
# `self.root=None` or relative path if `self.root=/path/to/data/`.
annotations = load(self.ann_file)

Check warning on line 191 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L191

Added line #L191 was not covered by tests
if not isinstance(annotations, dict):
raise TypeError(f'The annotations loaded from annotation file '

Check warning on line 193 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L193

Added line #L193 was not covered by tests
f'should be a dict, but got {type(annotations)}!')
if 'data_list' not in annotations or 'metainfo' not in annotations:
raise ValueError('Annotation must have data_list and metainfo '

Check warning on line 196 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L196

Added line #L196 was not covered by tests
'keys')
metainfo = annotations['metainfo']
raw_data_list = annotations['data_list']
raw_data_list = raw_data_list[::self.load_interval]

Check warning on line 200 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L198-L200

Added lines #L198 - L200 were not covered by tests
if self.load_interval > 1:
print_log(

Check warning on line 202 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L202

Added line #L202 was not covered by tests
f'Sample size will be reduced to 1/{self.load_interval} of'
' the original data sample',
logger='current')

# Meta information load from annotation file will not influence the
# existed meta information load from `BaseDataset.METAINFO` and
# `metainfo` arguments defined in constructor.
for k, v in metainfo.items():
self._metainfo.setdefault(k, v)

Check warning on line 211 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L211

Added line #L211 was not covered by tests

# load and parse data_infos.
data_list = []

Check warning on line 214 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L214

Added line #L214 was not covered by tests
for raw_data_info in raw_data_list:
# parse raw data information to target format
data_info = self.parse_data_info(raw_data_info)

Check warning on line 217 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L217

Added line #L217 was not covered by tests
if isinstance(data_info, dict):
# For image tasks, `data_info` should information if single
# image, such as dict(img_path='xxx', width=360, ...)
data_list.append(data_info)

Check warning on line 221 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L221

Added line #L221 was not covered by tests
elif isinstance(data_info, list):
# For video tasks, `data_info` could contain image
# information of multiple frames, such as
# [dict(video_path='xxx', timestamps=...),
# dict(video_path='xxx', timestamps=...)]
for item in data_info:
if not isinstance(item, dict):
raise TypeError('data_info must be list of dict, but '

Check warning on line 229 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L229

Added line #L229 was not covered by tests
f'got {type(item)}')
data_list.extend(data_info)

Check warning on line 231 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L231

Added line #L231 was not covered by tests
else:
raise TypeError('data_info should be a dict or list of dict, '

Check warning on line 233 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L233

Added line #L233 was not covered by tests
f'but got {type(data_info)}')

return data_list

def parse_data_info(self, info: dict) -> Union[dict, List[dict]]:
Expand All @@ -203,44 +251,39 @@
info['images'][self.default_cam_key]
info['images'] = new_image_info
info['instances'] = info['cam_instances'][self.default_cam_key]
return super().parse_data_info(info)
return Det3DDataset.parse_data_info(self, info)

Check warning on line 254 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L254

Added line #L254 was not covered by tests
else:
# in the mono3d, the instances is from cam sync.
# Convert frame-based infos to multi-view image-based
data_list = []
if self.modality['use_lidar']:
info['lidar_points']['lidar_path'] = \
osp.join(
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])

if self.modality['use_camera']:
for cam_key, img_info in info['images'].items():
if 'img_path' in img_info:
cam_prefix = self.data_prefix.get(cam_key, '')
img_info['img_path'] = osp.join(
cam_prefix, img_info['img_path'])

for (cam_key, img_info) in info['images'].items():
camera_info = dict()
camera_info['sample_idx'] = info['sample_idx']
camera_info['timestamp'] = info['timestamp']
camera_info['context_name'] = info['context_name']

Check warning on line 263 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L261-L263

Added lines #L261 - L263 were not covered by tests
sunjiahao1999 marked this conversation as resolved.
Show resolved Hide resolved
camera_info['images'] = dict()
camera_info['images'][cam_key] = img_info
if 'cam_instances' in info \
and cam_key in info['cam_instances']:
camera_info['instances'] = info['cam_instances'][cam_key]
if 'img_path' in img_info:
cam_prefix = self.data_prefix.get(cam_key, '')
camera_info['images'][cam_key]['img_path'] = osp.join(

Check warning on line 268 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L267-L268

Added lines #L267 - L268 were not covered by tests
cam_prefix, img_info['img_path'])
if 'lidar2cam' in img_info:
camera_info['lidar2cam'] = np.array(img_info['lidar2cam'])

Check warning on line 271 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L271

Added line #L271 was not covered by tests
if 'cam2img' in img_info:
camera_info['cam2img'] = np.array(img_info['cam2img'])

Check warning on line 273 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L273

Added line #L273 was not covered by tests
if 'lidar2img' in img_info:
camera_info['lidar2img'] = np.array(img_info['lidar2img'])

Check warning on line 275 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L275

Added line #L275 was not covered by tests
else:
camera_info['instances'] = []
camera_info['ego2global'] = info['ego2global']
if 'image_sweeps' in info:
camera_info['image_sweeps'] = info['image_sweeps']

# TODO check if need to modify the sample id
# TODO check when will use it except for evaluation.
camera_info['sample_idx'] = info['sample_idx']
camera_info['lidar2img'] = camera_info[

Check warning on line 277 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L277

Added line #L277 was not covered by tests
'cam2img'] @ camera_info['lidar2cam']

if not self.test_mode:
# used in training
camera_info['instances'] = info['cam_instances'][cam_key]

Check warning on line 282 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L282

Added line #L282 was not covered by tests
camera_info['ann_info'] = self.parse_ann_info(camera_info)
if self.test_mode and self.load_eval_anns:
info['eval_ann_info'] = self.parse_ann_info(info)
camera_info['instances'] = info['cam_instances'][cam_key]
camera_info['eval_ann_info'] = self.parse_ann_info(

Check warning on line 286 in mmdet3d/datasets/waymo_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/datasets/waymo_dataset.py#L285-L286

Added lines #L285 - L286 were not covered by tests
camera_info)
data_list.append(camera_info)
return data_list
4 changes: 2 additions & 2 deletions mmdet3d/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def __init__(self,
'needs to be excluded.')
self.vis_task = vis_task

if wait_time == -1:
if show and wait_time == -1:
print_log(
'Manual control mode, press [Right] to next sample.',
logger='current')
else:
elif show:
print_log(
'Autoplay mode, press [SPACE] to pause.', logger='current')
self.wait_time = wait_time
Expand Down
Loading
Loading