Patch: optional velocity support for CenterPoint on nuScenes ("novelo").

Contents, as visible in the hunks below:
  * configs/centerpoint/..._nus_novelo.py: new config; common_heads has no
    'vel' entry, with_velocity=False, db_sampler bbox_code_size=7, and eight
    code_weights (same width as the 8-column anno_box built without velocity).
  * nuscenes_dataset.py: output_to_nusc_box() gains with_velocity (default
    True, i.e. the previous behaviour) and emits a zero velocity when False.
  * dbsampler.py: new bbox_code_size slices each info's 'box3d_lidar'.
  * loading.py: LoadPointsFromMultiSweeps takes time_dim instead of the
    hard-coded column 4 and asserts time_dim / use_dim against load_dim.
  * centerpoint_head.py / centerpoint.py: with_velocity is derived from
    "'vel' in common_heads"; targets and predictions drop the two velocity
    channels (10 -> 8 columns) when it is False.

NOTE(review): train_load_interval = 1000 in the new config looks like a
debugging subsample rather than a training setting; confirm before merge.
NOTE(review): the new asserts in loading.py validate constructor arguments;
they disappear under "python -O". Confirm that is acceptable.
NOTE(review): this copy of the patch had its whitespace collapsed. Line
breaks and indentation below were reconstructed from syntax and the hunk
line counts; check context lines against the target files before applying.

diff --git a/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus_novelo.py b/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus_novelo.py
new file mode 100644
index 0000000000..2484e2839c
--- /dev/null
+++ b/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus_novelo.py
@@ -0,0 +1,186 @@
+_base_ = [
+    '../_base_/datasets/nus-3d.py',
+    '../_base_/models/centerpoint_02pillar_second_secfpn_nus.py',
+    '../_base_/schedules/cyclic_20e.py', '../_base_/default_runtime.py'
+]
+
+# If point cloud range is changed, the models should also change their point
+# cloud range accordingly
+point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
+# For nuScenes we usually do 10-class detection
+class_names = [
+    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
+    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
+]
+
+model = dict(
+    pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
+    pts_voxel_encoder=dict(point_cloud_range=point_cloud_range),
+    pts_bbox_head=dict(
+        bbox_coder=dict(pc_range=point_cloud_range[:2], code_size=7),
+        common_heads=dict(
+            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), _delete_=True)),
+    # model training and testing settings
+    train_cfg=dict(
+        pts=dict(
+            point_cloud_range=point_cloud_range,
+            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
+    test_cfg=dict(pts=dict(pc_range=point_cloud_range[:2])))
+
+dataset_type = 'NuScenesDataset'
+data_root = 'data/nuscenes/'
+file_client_args = dict(backend='disk')
+
+db_sampler = dict(
+    data_root=data_root,
+    info_path=data_root + 'nuscenes_dbinfos_train.pkl',
+    rate=1.0,
+    bbox_code_size=7,
+    prepare=dict(
+        filter_by_difficulty=[-1],
+        filter_by_min_points=dict(
+            car=5,
+            truck=5,
+            bus=5,
+            trailer=5,
+            construction_vehicle=5,
+            traffic_cone=5,
+            barrier=5,
+            motorcycle=5,
+            bicycle=5,
+            pedestrian=5)),
+    classes=class_names,
+    sample_groups=dict(
+        car=2,
+        truck=3,
+        construction_vehicle=7,
+        bus=4,
+        trailer=6,
+        barrier=2,
+        motorcycle=6,
+        bicycle=6,
+        pedestrian=2,
+        traffic_cone=2),
+    points_loader=dict(
+        type='LoadPointsFromFile',
+        coord_type='LIDAR',
+        load_dim=5,
+        use_dim=[0, 1, 2, 3, 4],
+        file_client_args=file_client_args))
+
+train_pipeline = [
+    dict(
+        type='LoadPointsFromFile',
+        coord_type='LIDAR',
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args),
+    dict(
+        type='LoadPointsFromMultiSweeps',
+        sweeps_num=9,
+        use_dim=[0, 1, 2, 3, 4],
+        file_client_args=file_client_args,
+        pad_empty_sweeps=True,
+        remove_close=True),
+    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
+    dict(type='ObjectSample', db_sampler=db_sampler),
+    dict(
+        type='GlobalRotScaleTrans',
+        rot_range=[-0.3925, 0.3925],
+        scale_ratio_range=[0.95, 1.05],
+        translation_std=[0, 0, 0]),
+    dict(
+        type='RandomFlip3D',
+        sync_2d=False,
+        flip_ratio_bev_horizontal=0.5,
+        flip_ratio_bev_vertical=0.5),
+    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='ObjectNameFilter', classes=class_names),
+    dict(type='PointShuffle'),
+    dict(type='DefaultFormatBundle3D', class_names=class_names),
+    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+]
+test_pipeline = [
+    dict(
+        type='LoadPointsFromFile',
+        coord_type='LIDAR',
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args),
+    dict(
+        type='LoadPointsFromMultiSweeps',
+        sweeps_num=9,
+        use_dim=[0, 1, 2, 3, 4],
+        file_client_args=file_client_args,
+        pad_empty_sweeps=True,
+        remove_close=True),
+    dict(
+        type='MultiScaleFlipAug3D',
+        img_scale=(1333, 800),
+        pts_scale_ratio=1,
+        flip=False,
+        transforms=[
+            dict(
+                type='GlobalRotScaleTrans',
+                rot_range=[0, 0],
+                scale_ratio_range=[1., 1.],
+                translation_std=[0, 0, 0]),
+            dict(type='RandomFlip3D'),
+            dict(
+                type='DefaultFormatBundle3D',
+                class_names=class_names,
+                with_label=False),
+            dict(type='Collect3D', keys=['points'])
+        ])
+]
+# construct a pipeline for data and gt loading in show function
+# please keep its loading function consistent with test_pipeline (e.g. client)
+eval_pipeline = [
+    dict(
+        type='LoadPointsFromFile',
+        coord_type='LIDAR',
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args),
+    dict(
+        type='LoadPointsFromMultiSweeps',
+        sweeps_num=9,
+        use_dim=[0, 1, 2, 3, 4],
+        file_client_args=file_client_args,
+        pad_empty_sweeps=True,
+        remove_close=True),
+    dict(
+        type='DefaultFormatBundle3D',
+        class_names=class_names,
+        with_label=False),
+    dict(type='Collect3D', keys=['points'])
+]
+
+train_load_interval = 1000
+
+with_velocity = False
+data = dict(
+    train=dict(
+        type='CBGSDataset',
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            ann_file=data_root + 'nuscenes_infos_train.pkl',
+            pipeline=train_pipeline,
+            load_interval=train_load_interval,
+            classes=class_names,
+            with_velocity=with_velocity,
+            test_mode=False,
+            use_valid_flag=True,
+            # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
+            # and box_type_3d='Depth' in sunrgbd and scannet dataset.
+            box_type_3d='LiDAR')),
+    val=dict(
+        pipeline=test_pipeline,
+        classes=class_names,
+        with_velocity=with_velocity),
+    test=dict(
+        pipeline=test_pipeline,
+        classes=class_names,
+        with_velocity=with_velocity))
diff --git a/mmdet3d/datasets/nuscenes_dataset.py b/mmdet3d/datasets/nuscenes_dataset.py
index 1ca826571e..47d6e15ed9 100644
--- a/mmdet3d/datasets/nuscenes_dataset.py
+++ b/mmdet3d/datasets/nuscenes_dataset.py
@@ -316,7 +316,7 @@ def _format_bbox(self, results, jsonfile_prefix=None):
         print('Start to convert detection format...')
         for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
             annos = []
-            boxes = output_to_nusc_box(det)
+            boxes = output_to_nusc_box(det, self.with_velocity)
             sample_token = self.data_infos[sample_id]['token']
             boxes = lidar_nusc_box_to_global(self.data_infos[sample_id], boxes,
                                              mapped_class_names,
@@ -573,7 +573,7 @@ def show(self, results, out_dir, show=False, pipeline=None):
                         file_name, show)
 
 
-def output_to_nusc_box(detection):
+def output_to_nusc_box(detection, with_velocity=True):
     """Convert the output to the box class in the nuScenes.
 
     Args:
@@ -600,7 +600,10 @@ def output_to_nusc_box(detection):
     box_list = []
     for i in range(len(box3d)):
         quat = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw[i])
-        velocity = (*box3d.tensor[i, 7:9], 0.0)
+        if with_velocity:
+            velocity = (*box3d.tensor[i, 7:9], 0.0)
+        else:
+            velocity = (0, 0, 0)
         # velo_val = np.linalg.norm(box3d[i, 7:9])
         # velo_ori = box3d[i, 6]
         # velocity = (
diff --git a/mmdet3d/datasets/pipelines/dbsampler.py b/mmdet3d/datasets/pipelines/dbsampler.py
index ef82c88e29..12101c2c60 100644
--- a/mmdet3d/datasets/pipelines/dbsampler.py
+++ b/mmdet3d/datasets/pipelines/dbsampler.py
@@ -89,6 +89,8 @@ class DataBaseSampler(object):
         prepare (dict): Name of preparation functions and the input value.
         sample_groups (dict): Sampled classes and numbers.
         classes (list[str], optional): List of classes. Default: None.
+        bbox_code_size (int, optional): Number of leading entries kept from
+            each info's 'box3d_lidar'; None keeps all. Default: None.
         points_loader(dict, optional): Config of points loader. Default:
             dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
     """
@@ -100,6 +102,7 @@ def __init__(self,
                  prepare,
                  sample_groups,
                  classes=None,
+                 bbox_code_size=None,
                  points_loader=dict(
                      type='LoadPointsFromFile',
                      coord_type='LIDAR',
@@ -143,6 +146,13 @@ def __init__(self,
 
         self.db_infos = db_infos
 
+        self.bbox_code_size = bbox_code_size
+        if bbox_code_size is not None:
+            for k, info_cls in self.db_infos.items():
+                for info in info_cls:
+                    info['box3d_lidar'] = info['box3d_lidar'][:self.
+                                                              bbox_code_size]
+
         # load sample groups
         # TODO: more elegant way to load sample groups
         self.sample_groups = []
@@ -150,6 +160,7 @@ def __init__(self,
             self.sample_groups.append({name: int(num)})
 
         self.group_db_infos = self.db_infos  # just use db_infos
+
         self.sample_classes = []
         self.sample_max_nums = []
         for group_info in self.sample_groups:
diff --git a/mmdet3d/datasets/pipelines/loading.py b/mmdet3d/datasets/pipelines/loading.py
index bbdcb8ed2c..ffbfb40bbc 100644
--- a/mmdet3d/datasets/pipelines/loading.py
+++ b/mmdet3d/datasets/pipelines/loading.py
@@ -108,6 +108,8 @@ class LoadPointsFromMultiSweeps(object):
             Defaults to 5.
         use_dim (list[int], optional): Which dimension to use.
             Defaults to [0, 1, 2, 4].
+        time_dim (int, optional): Index of the dimension that stores the
+            timestamp of each point. Defaults to 4.
         file_client_args (dict, optional): Config dict of file clients,
             refer to
             https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
@@ -125,6 +127,7 @@ def __init__(self,
                  sweeps_num=10,
                  load_dim=5,
                  use_dim=[0, 1, 2, 4],
+                 time_dim=4,
                  file_client_args=dict(backend='disk'),
                  pad_empty_sweeps=False,
                  remove_close=False,
@@ -132,11 +135,16 @@ def __init__(self,
         self.load_dim = load_dim
         self.sweeps_num = sweeps_num
         self.use_dim = use_dim
+        self.time_dim = time_dim
+        assert time_dim < load_dim, \
+            f'Expect the timestamp dimension < {load_dim}, got {time_dim}'
         self.file_client_args = file_client_args.copy()
         self.file_client = None
         self.pad_empty_sweeps = pad_empty_sweeps
         self.remove_close = remove_close
         self.test_mode = test_mode
+        assert max(use_dim) < load_dim, \
+            f'Expect all used dimensions < {load_dim}, got {use_dim}'
 
     def _load_points(self, pts_filename):
         """Private function to load point clouds data.
@@ -197,7 +205,7 @@ def __call__(self, results):
                 cloud arrays.
         """
         points = results['points']
-        points.tensor[:, 4] = 0
+        points.tensor[:, self.time_dim] = 0
         sweep_points_list = [points]
         ts = results['timestamp']
         if self.pad_empty_sweeps and len(results['sweeps']) == 0:
@@ -224,7 +232,7 @@ def __call__(self, results):
                 points_sweep[:, :3] = points_sweep[:, :3] @ sweep[
                     'sensor2lidar_rotation'].T
                 points_sweep[:, :3] += sweep['sensor2lidar_translation']
-                points_sweep[:, 4] = ts - sweep_ts
+                points_sweep[:, self.time_dim] = ts - sweep_ts
                 points_sweep = points.new_point(points_sweep)
                 sweep_points_list.append(points_sweep)
 
diff --git a/mmdet3d/models/dense_heads/centerpoint_head.py b/mmdet3d/models/dense_heads/centerpoint_head.py
index 2cf758bd09..16ed35b3d3 100644
--- a/mmdet3d/models/dense_heads/centerpoint_head.py
+++ b/mmdet3d/models/dense_heads/centerpoint_head.py
@@ -327,6 +327,8 @@ def __init__(self,
                 in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
             self.task_heads.append(builder.build_head(separate_head))
 
+        self.with_velocity = 'vel' in common_heads.keys()
+
     def forward_single(self, x):
         """Forward function for CenterPoint.
 
@@ -490,8 +492,12 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
                 (len(self.class_names[idx]), feature_map_size[1],
                  feature_map_size[0]))
 
-            anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
-                                              dtype=torch.float32)
+            if self.with_velocity:
+                anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
+                                                  dtype=torch.float32)
+            else:
+                anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
+                                                  dtype=torch.float32)
 
             ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
             mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
@@ -548,19 +554,27 @@ def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
                     ind[new_idx] = y * feature_map_size[0] + x
                     mask[new_idx] = 1
                     # TODO: support other outdoor dataset
-                    vx, vy = task_boxes[idx][k][7:]
                     rot = task_boxes[idx][k][6]
                     box_dim = task_boxes[idx][k][3:6]
                     if self.norm_bbox:
                         box_dim = box_dim.log()
-                    anno_box[new_idx] = torch.cat([
-                        center - torch.tensor([x, y], device=device),
-                        z.unsqueeze(0), box_dim,
-                        torch.sin(rot).unsqueeze(0),
-                        torch.cos(rot).unsqueeze(0),
-                        vx.unsqueeze(0),
-                        vy.unsqueeze(0)
-                    ])
+                    if self.with_velocity:
+                        vx, vy = task_boxes[idx][k][7:]
+                        anno_box[new_idx] = torch.cat([
+                            center - torch.tensor([x, y], device=device),
+                            z.unsqueeze(0), box_dim,
+                            torch.sin(rot).unsqueeze(0),
+                            torch.cos(rot).unsqueeze(0),
+                            vx.unsqueeze(0),
+                            vy.unsqueeze(0)
+                        ])
+                    else:
+                        anno_box[new_idx] = torch.cat([
+                            center - torch.tensor([x, y], device=device),
+                            z.unsqueeze(0), box_dim,
+                            torch.sin(rot).unsqueeze(0),
+                            torch.cos(rot).unsqueeze(0)
+                        ])
 
             heatmaps.append(heatmap)
             anno_boxes.append(anno_box)
@@ -594,11 +608,17 @@ def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
                 avg_factor=max(num_pos, 1))
             target_box = anno_boxes[task_id]
             # reconstruct the anno_box from multiple reg heads
-            preds_dict[0]['anno_box'] = torch.cat(
-                (preds_dict[0]['reg'], preds_dict[0]['height'],
-                 preds_dict[0]['dim'], preds_dict[0]['rot'],
-                 preds_dict[0]['vel']),
-                dim=1)
+            if self.with_velocity:
+                preds_dict[0]['anno_box'] = torch.cat(
+                    (preds_dict[0]['reg'], preds_dict[0]['height'],
+                     preds_dict[0]['dim'], preds_dict[0]['rot'],
+                     preds_dict[0]['vel']),
+                    dim=1)
+            else:
+                preds_dict[0]['anno_box'] = torch.cat(
+                    (preds_dict[0]['reg'], preds_dict[0]['height'],
+                     preds_dict[0]['dim'], preds_dict[0]['rot']),
+                    dim=1)
 
             # Regression loss for dimension, offset, height, rotation
             ind = inds[task_id]
diff --git a/mmdet3d/models/detectors/centerpoint.py b/mmdet3d/models/detectors/centerpoint.py
index 290af5bedc..fd13783f70 100644
--- a/mmdet3d/models/detectors/centerpoint.py
+++ b/mmdet3d/models/detectors/centerpoint.py
@@ -33,6 +33,12 @@ def __init__(self,
                              pts_bbox_head, img_roi_head, img_rpn_head,
                              train_cfg, test_cfg, pretrained, init_cfg)
 
+    @property
+    def with_velocity(self):
+        """bool: Whether the head predicts velocity."""
+        return self.pts_bbox_head is not None and \
+            self.pts_bbox_head.with_velocity
+
     def extract_pts_feat(self, pts, img_feats, img_metas):
         """Extract features of points."""
         if not self.with_pts_bbox: