From aa544d4a1def643c80dd40699887509ad75d6026 Mon Sep 17 00:00:00 2001 From: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Date: Wed, 16 Nov 2022 16:44:55 +0800 Subject: [PATCH] [Features] Support PV_RCNN modules (#1957) * add pvrcnn module code * add voxelsa * fix * fix comments * fix comments * fix comments * add stack sa * fix * fix comments * fix comments * fix * add ut * fix comments --- .../pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py | 353 ++++++++++++ mmdet3d/apis/inference.py | 1 - mmdet3d/models/detectors/__init__.py | 4 +- mmdet3d/models/detectors/pv_rcnn.py | 232 ++++++++ .../layers/pointnet_modules/__init__.py | 3 +- .../pointnet_modules/stack_point_sa_module.py | 198 +++++++ mmdet3d/models/middle_encoders/__init__.py | 4 +- .../models/middle_encoders/sparse_encoder.py | 20 +- .../middle_encoders/voxel_set_abstraction.py | 334 ++++++++++++ mmdet3d/models/roi_heads/__init__.py | 3 +- .../models/roi_heads/bbox_heads/__init__.py | 3 +- .../roi_heads/bbox_heads/pv_rcnn_bbox_head.py | 510 ++++++++++++++++++ .../models/roi_heads/mask_heads/__init__.py | 5 +- .../foreground_segmentation_head.py | 175 ++++++ mmdet3d/models/roi_heads/pv_rcnn_roi_head.py | 312 +++++++++++ .../roi_heads/roi_extractors/__init__.py | 3 +- .../batch_roigridpoint_extractor.py | 97 ++++ .../test_models/test_detectors/test_pvrcnn.py | 64 +++ 18 files changed, 2310 insertions(+), 11 deletions(-) create mode 100644 configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py create mode 100644 mmdet3d/models/detectors/pv_rcnn.py create mode 100644 mmdet3d/models/layers/pointnet_modules/stack_point_sa_module.py create mode 100644 mmdet3d/models/middle_encoders/voxel_set_abstraction.py create mode 100644 mmdet3d/models/roi_heads/bbox_heads/pv_rcnn_bbox_head.py create mode 100644 mmdet3d/models/roi_heads/mask_heads/foreground_segmentation_head.py create mode 100644 mmdet3d/models/roi_heads/pv_rcnn_roi_head.py create mode 100644 mmdet3d/models/roi_heads/roi_extractors/batch_roigridpoint_extractor.py create mode 100644 tests/test_models/test_detectors/test_pvrcnn.py diff --git a/configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py b/configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py new file mode 100644 index 000000000..f3cb14395 --- /dev/null +++ b/configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py @@ -0,0 +1,353 @@ +_base_ = [ + '../_base_/datasets/kitti-3d-3class.py', + '../_base_/schedules/cyclic-40e.py', '../_base_/default_runtime.py' +] + +voxel_size = [0.05, 0.05, 0.1] +point_cloud_range = [0, -40, -3, 70.4, 40, 1] + +data_root = 'data/kitti/' +class_names = ['Pedestrian', 'Cyclist', 'Car'] +metainfo = dict(CLASSES=class_names) +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'kitti_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)), + classes=class_names, + sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10), + points_loader=dict( + type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4)) + +train_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05]), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + 
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range) + ]), + dict(type='Pack3DDetInputs', keys=['points']) +] + +model = dict( + type='PointVoxelRCNN', + data_preprocessor=dict( + type='Det3DDataPreprocessor', + voxel=True, + voxel_layer=dict( + max_num_points=5, # max_points_per_voxel + point_cloud_range=point_cloud_range, + voxel_size=voxel_size, + max_voxels=(16000, 40000))), + voxel_encoder=dict(type='HardSimpleVFE'), + middle_encoder=dict( + type='SparseEncoder', + in_channels=4, + sparse_shape=[41, 1600, 1408], + order=('conv', 'norm', 'act'), + encoder_paddings=((0, 0, 0), ((1, 1, 1), 0, 0), ((1, 1, 1), 0, 0), + ((0, 1, 1), 0, 0)), + return_middle_feats=True), + points_encoder=dict( + type='VoxelSetAbstraction', + num_keypoints=2048, + fused_out_channel=128, + voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + voxel_sa_cfgs_list=[ + dict( + type='StackedSAModuleMSG', + in_channels=16, + scale_factor=1, + radius=(0.4, 0.8), + sample_nums=(16, 16), + mlp_channels=((16, 16), (16, 16)), + use_xyz=True), + dict( + type='StackedSAModuleMSG', + in_channels=32, + scale_factor=2, + radius=(0.8, 1.2), + sample_nums=(16, 32), + mlp_channels=((32, 32), (32, 32)), + use_xyz=True), + dict( + type='StackedSAModuleMSG', + in_channels=64, + scale_factor=4, + radius=(1.2, 2.4), + sample_nums=(16, 32), + mlp_channels=((64, 64), (64, 64)), + use_xyz=True), + dict( + type='StackedSAModuleMSG', + in_channels=64, + scale_factor=8, + radius=(2.4, 4.8), + sample_nums=(16, 32), + mlp_channels=((64, 64), (64, 64)), + use_xyz=True) + ], + rawpoints_sa_cfgs=dict( + type='StackedSAModuleMSG', + in_channels=1, + radius=(0.4, 0.8), + sample_nums=(16, 16), + mlp_channels=((16, 16), (16, 16)), + use_xyz=True), + bev_feat_channel=256, + bev_scale_factor=8), + backbone=dict( + type='SECOND', + in_channels=256, + layer_nums=[5, 5], + layer_strides=[1, 2], + out_channels=[128, 256]), + neck=dict( + type='SECONDFPN', + in_channels=[128, 256], + upsample_strides=[1, 2], + out_channels=[256, 256]), + rpn_head=dict( + type='PartA2RPNHead', + num_classes=3, + in_channels=512, + feat_channels=512, + use_direction_classifier=True, + dir_offset=0.78539, + anchor_generator=dict( + type='Anchor3DRangeGenerator', + ranges=[[0, -40.0, -0.6, 70.4, 40.0, -0.6], + [0, -40.0, -0.6, 70.4, 40.0, -0.6], + [0, -40.0, -1.78, 70.4, 40.0, -1.78]], + sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]], + rotations=[0, 1.57], + reshape_out=False), + diff_rad_by_sin=True, + assigner_per_size=True, + assign_per_class=True, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), + loss_dir=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=False, + loss_weight=0.2)), + roi_head=dict( + type='PVRCNNRoiHead', + num_classes=3, + semantic_head=dict( + type='ForegroundSegmentationHead', + 
in_channels=640, + extra_width=0.1, + loss_seg=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + reduction='sum', + gamma=2.0, + alpha=0.25, + activated=True, + loss_weight=1.0)), + bbox_roi_extractor=dict( + type='Batch3DRoIGridExtractor', + grid_size=6, + roi_layer=dict( + type='StackedSAModuleMSG', + in_channels=128, + radius=(0.8, 1.6), + sample_nums=(16, 16), + mlp_channels=((64, 64), (64, 64)), + use_xyz=True, + pool_mod='max'), + ), + bbox_head=dict( + type='PVRCNNBBoxHead', + in_channels=128, + grid_size=6, + num_classes=3, + class_agnostic=True, + shared_fc_channels=(256, 256), + reg_channels=(256, 256), + cls_channels=(256, 256), + dropout_ratio=0.3, + with_corner_loss=True, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=1.0), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=[ + dict( # for Pedestrian + type='Max3DIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.5, + neg_iou_thr=0.35, + min_pos_iou=0.35, + ignore_iof_thr=-1), + dict( # for Cyclist + type='Max3DIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.5, + neg_iou_thr=0.35, + min_pos_iou=0.35, + ignore_iof_thr=-1), + dict( # for Car + type='Max3DIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.6, + neg_iou_thr=0.45, + min_pos_iou=0.45, + ignore_iof_thr=-1) + ], + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=9000, + nms_post=512, + max_num=512, + nms_thr=0.8, + score_thr=0, + use_rotate_nms=True), + rcnn=dict( + assigner=[ + dict( # for Pedestrian + type='Max3DIoUAssigner', + iou_calculator=dict( + type='BboxOverlaps3D', coordinate='lidar'), + pos_iou_thr=0.55, + neg_iou_thr=0.55, + min_pos_iou=0.55, + ignore_iof_thr=-1), + dict( # for Cyclist + type='Max3DIoUAssigner', + iou_calculator=dict( + type='BboxOverlaps3D', coordinate='lidar'), + pos_iou_thr=0.55, + neg_iou_thr=0.55, + min_pos_iou=0.55, + ignore_iof_thr=-1), + dict( # for Car + type='Max3DIoUAssigner', + iou_calculator=dict( + type='BboxOverlaps3D', coordinate='lidar'), + pos_iou_thr=0.55, + neg_iou_thr=0.55, + min_pos_iou=0.55, + ignore_iof_thr=-1) + ], + sampler=dict( + type='IoUNegPiecewiseSampler', + num=128, + pos_fraction=0.5, + neg_piece_fractions=[0.8, 0.2], + neg_iou_piece_thrs=[0.55, 0.1], + neg_pos_ub=-1, + add_gt_as_proposals=False, + return_iou=True), + cls_pos_thr=0.75, + cls_neg_thr=0.25)), + test_cfg=dict( + rpn=dict( + nms_pre=1024, + nms_post=100, + max_num=100, + nms_thr=0.7, + score_thr=0, + use_rotate_nms=True), + rcnn=dict( + use_rotate_nms=True, + use_raw_score=True, + nms_thr=0.1, + score_thr=0.1))) +train_dataloader = dict( + batch_size=2, + num_workers=2, + dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo))) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +eval_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +lr = 0.001 +optim_wrapper = dict(optimizer=dict(lr=lr)) +param_scheduler = [ + # learning rate scheduler + # During the first 16 epochs, learning rate increases from 0 to lr * 10 + # during the next 24 epochs, learning rate decreases from lr * 10 to + # lr * 1e-4 + dict( + type='CosineAnnealingLR', + T_max=15, + eta_min=lr * 10, + begin=0, + end=15, + by_epoch=True, + 
convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=25, + eta_min=lr * 1e-4, + begin=15, + end=40, + by_epoch=True, + convert_to_iter_based=True), + # momentum scheduler + # During the first 16 epochs, momentum increases from 0 to 0.85 / 0.95 + # during the next 24 epochs, momentum increases from 0.85 / 0.95 to 1 + dict( + type='CosineAnnealingMomentum', + T_max=15, + eta_min=0.85 / 0.95, + begin=0, + end=15, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=25, + eta_min=1, + begin=15, + end=40, + by_epoch=True, + convert_to_iter_based=True) +] diff --git a/mmdet3d/apis/inference.py b/mmdet3d/apis/inference.py index 891a9896d..2273aa95a 100644 --- a/mmdet3d/apis/inference.py +++ b/mmdet3d/apis/inference.py @@ -67,7 +67,6 @@ def init_model(config: Union[str, Path, Config], if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') - dataset_meta = checkpoint['meta'].get('dataset_meta', None) # save the dataset_meta in the model for convenience if 'dataset_meta' in checkpoint.get('meta', {}): diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index 5619d86f1..c95e00ca0 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -14,6 +14,7 @@ from .mvx_two_stage import MVXTwoStageDetector from .parta2 import PartA2 from .point_rcnn import PointRCNN +from .pv_rcnn import PointVoxelRCNN from .sassd import SASSD from .single_stage_mono3d import SingleStageMono3DDetector from .smoke_mono3d import SMOKEMono3D @@ -26,5 +27,6 @@ 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet', 'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector', 'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D', - 'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM' + 'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM', + 'PointVoxelRCNN' ] diff --git a/mmdet3d/models/detectors/pv_rcnn.py b/mmdet3d/models/detectors/pv_rcnn.py new file mode 100644 index 000000000..ac03a6193 --- /dev/null +++ b/mmdet3d/models/detectors/pv_rcnn.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional + +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils import InstanceList +from .two_stage import TwoStage3DDetector + + +@MODELS.register_module() +class PointVoxelRCNN(TwoStage3DDetector): + r"""PointVoxelRCNN detector. + + Please refer to the `PointVoxelRCNN `_. + + Args: + voxel_encoder (dict): Point voxelization encoder layer. + middle_encoder (dict): Middle encoder layer + of points cloud modality. + backbone (dict): Backbone of extracting points features. + neck (dict, optional): Neck of extracting points features. + Defaults to None. + rpn_head (dict, optional): Config of RPN head. Defaults to None. + points_encoder (dict, optional): Points encoder to extract point-wise + features. Defaults to None. + roi_head (dict, optional): Config of ROI head. Defaults to None. + train_cfg (dict, optional): Train config of model. + Defaults to None. + test_cfg (dict, optional): Train config of model. + Defaults to None. + init_cfg (dict, optional): Initialize config of + model. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`Det3DDataPreprocessor`. Defaults to None. 
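+
+    Example:
+        A minimal usage sketch (not taken from this PR's tests): the detector
+        is normally built from the config file added in this PR, and the
+        snippet assumes the mmdet3d registries have already been initialised
+        (as the training/testing entry points do).
+
+        >>> from mmengine.config import Config
+        >>> from mmdet3d.registry import MODELS
+        >>> cfg = Config.fromfile(
+        ...     'configs/pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py')
+        >>> pv_rcnn = MODELS.build(cfg.model)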
+ """ + + def __init__(self, + voxel_encoder: dict, + middle_encoder: dict, + backbone: dict, + neck: Optional[dict] = None, + rpn_head: Optional[dict] = None, + points_encoder: Optional[dict] = None, + roi_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) + self.voxel_encoder = MODELS.build(voxel_encoder) + self.middle_encoder = MODELS.build(middle_encoder) + self.points_encoder = MODELS.build(points_encoder) + + def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + **kwargs) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs_dict (dict): The model input dict which include + 'points', 'voxels' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + - voxels (dict[torch.Tensor]): Voxels of the batch sample. + + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input samples. Each Det3DDataSample usually contain + 'pred_instances_3d'. And the ``pred_instances_3d`` usually + contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instance, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (Tensor): Contains a tensor with shape + (num_instances, C) where C >=7. + """ + feats_dict = self.extract_feat(batch_inputs_dict) + if self.with_rpn: + rpn_results_list = self.rpn_head.predict(feats_dict, + batch_data_samples) + else: + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + # extrack points feats by points_encoder + points_feats_dict = self.extract_points_feat(batch_inputs_dict, + feats_dict, + rpn_results_list) + + results_list_3d = self.roi_head.predict(points_feats_dict, + rpn_results_list, + batch_data_samples) + + # connvert to Det3DDataSample + results_list = self.add_pred_to_datasample(batch_data_samples, + results_list_3d) + + return results_list + + def extract_feat(self, batch_inputs_dict: dict) -> dict: + """Extract features from the input voxels. + + Args: + batch_inputs_dict (dict): The model input dict which include + 'points', 'voxels' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + - voxels (dict[torch.Tensor]): Voxels of the batch sample. + + Returns: + dict: We typically obtain a dict of features from the backbone + + neck, it includes: + + - spatial_feats (torch.Tensor): Spatial feats from middle + encoder. + - multi_scale_3d_feats (list[torch.Tensor]): Multi scale + middle feats from middle encoder. + - neck_feats (torch.Tensor): Neck feats from neck. 
+ """ + feats_dict = dict() + voxel_dict = batch_inputs_dict['voxels'] + voxel_features = self.voxel_encoder(voxel_dict['voxels'], + voxel_dict['num_points'], + voxel_dict['coors']) + batch_size = voxel_dict['coors'][-1, 0].item() + 1 + feats_dict['spatial_feats'], feats_dict[ + 'multi_scale_3d_feats'] = self.middle_encoder( + voxel_features, voxel_dict['coors'], batch_size) + x = self.backbone(feats_dict['spatial_feats']) + if self.with_neck: + neck_feats = self.neck(x) + feats_dict['neck_feats'] = neck_feats + return feats_dict + + def extract_points_feat(self, batch_inputs_dict: dict, feats_dict: dict, + rpn_results_list: InstanceList) -> dict: + """Extract point-wise features from the raw points and voxel features. + + Args: + batch_inputs_dict (dict): The model input dict which include + 'points', 'voxels' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + - voxels (dict[torch.Tensor]): Voxels of the batch sample. + feats_dict (dict): Contains features from the first stage. + rpn_results_list (List[:obj:`InstanceData`]): Detection results + of rpn head. + + Returns: + dict: Contain Point-wise features, include: + - keypoints (torch.Tensor): Sampled key points. + - keypoint_features (torch.Tensor): Gather key points features + from multi input. + - fusion_keypoint_features (torch.Tensor): Fusion + keypoint_features by point_feature_fusion_layer. + """ + return self.points_encoder(batch_inputs_dict, feats_dict, + rpn_results_list) + + def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + **kwargs): + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs_dict (dict): The model input dict which include + 'points', 'voxels' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + - voxels (dict[torch.Tensor]): Voxels of the batch sample. + + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + + Returns: + dict: A dictionary of loss components. + """ + feats_dict = self.extract_feat(batch_inputs_dict) + + losses = dict() + + # RPN forward and loss + if self.with_rpn: + proposal_cfg = self.train_cfg.get('rpn_proposal', + self.test_cfg.rpn) + rpn_data_samples = copy.deepcopy(batch_data_samples) + + rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( + feats_dict, + rpn_data_samples, + proposal_cfg=proposal_cfg, + **kwargs) + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in keys: + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + losses.update(rpn_losses) + else: + # TODO: Not support currently, should have a check at Fast R-CNN + assert batch_data_samples[0].get('proposals', None) is not None + # use pre-defined proposals in InstanceData for the second stage + # to extract ROI features. 
+ rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + points_feats_dict = self.extract_points_feat(batch_inputs_dict, + feats_dict, + rpn_results_list) + + roi_losses = self.roi_head.loss(points_feats_dict, rpn_results_list, + batch_data_samples) + losses.update(roi_losses) + + return losses diff --git a/mmdet3d/models/layers/pointnet_modules/__init__.py b/mmdet3d/models/layers/pointnet_modules/__init__.py index 99b08eb88..13d6e1d81 100644 --- a/mmdet3d/models/layers/pointnet_modules/__init__.py +++ b/mmdet3d/models/layers/pointnet_modules/__init__.py @@ -4,9 +4,10 @@ PAConvSAModule, PAConvSAModuleMSG) from .point_fp_module import PointFPModule from .point_sa_module import PointSAModule, PointSAModuleMSG +from .stack_point_sa_module import StackedSAModuleMSG __all__ = [ 'build_sa_module', 'PointSAModuleMSG', 'PointSAModule', 'PointFPModule', 'PAConvSAModule', 'PAConvSAModuleMSG', 'PAConvCUDASAModule', - 'PAConvCUDASAModuleMSG' + 'PAConvCUDASAModuleMSG', 'StackedSAModuleMSG' ] diff --git a/mmdet3d/models/layers/pointnet_modules/stack_point_sa_module.py b/mmdet3d/models/layers/pointnet_modules/stack_point_sa_module.py new file mode 100644 index 000000000..d20ee510a --- /dev/null +++ b/mmdet3d/models/layers/pointnet_modules/stack_point_sa_module.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops import ball_query, grouping_operation +from mmengine.model import BaseModule +from torch import Tensor + +from mmdet3d.registry import MODELS + + +class StackQueryAndGroup(BaseModule): + """Find nearby points in spherical space. + + Args: + radius (float): List of radius in each ball query. + sample_nums (int): Number of samples in each ball query. + use_xyz (bool): Whether to use xyz. Default: True. + init_cfg (dict, optional): Initialize config of + model. Defaults to None. + """ + + def __init__(self, + radius: float, + sample_nums: int, + use_xyz: bool = True, + init_cfg: dict = None): + super().__init__(init_cfg=init_cfg) + self.radius, self.sample_nums, self.use_xyz = \ + radius, sample_nums, use_xyz + + def forward(self, + xyz: torch.Tensor, + xyz_batch_cnt: torch.Tensor, + new_xyz: torch.Tensor, + new_xyz_batch_cnt: torch.Tensor, + features: torch.Tensor = None) -> Tuple[Tensor, Tensor]: + """Forward. + + Args: + xyz (Tensor): Tensor of the xyz coordinates + of the features shape with (N1 + N2 ..., 3). + xyz_batch_cnt: (Tensor): Stacked input xyz coordinates nums in + each batch, just like (N1, N2, ...). + new_xyz (Tensor): New coords of the outputs shape with + (M1 + M2 ..., 3). + new_xyz_batch_cnt: (Tensor): Stacked new xyz coordinates nums + in each batch, just like (M1, M2, ...). + features (Tensor, optional): Features of each point with shape + (N1 + N2 ..., C). C is features channel number. Default: None. + """ + assert xyz.shape[0] == xyz_batch_cnt.sum( + ), f'xyz: {str(xyz.shape)}, xyz_batch_cnt: str(new_xyz_batch_cnt)' + assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \ + 'new_xyz: str(new_xyz.shape), new_xyz_batch_cnt: ' \ + 'str(new_xyz_batch_cnt)' + + # idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...) 
+ idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums, + xyz, new_xyz, xyz_batch_cnt, + new_xyz_batch_cnt) + grouped_xyz = grouping_operation( + xyz, idx, xyz_batch_cnt, + new_xyz_batch_cnt) # (M1 + M2, 3, nsample) + grouped_xyz -= new_xyz.unsqueeze(-1) + + grouped_xyz[empty_ball_mask] = 0 + if features is not None: + grouped_features = grouping_operation( + features, idx, xyz_batch_cnt, + new_xyz_batch_cnt) # (M1 + M2, C, nsample) + grouped_features[empty_ball_mask] = 0 + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], + dim=1) # (M1 + M2 ..., C + 3, nsample) + else: + new_features = grouped_features + else: + assert self.use_xyz, 'Cannot have not features and not' \ + ' use xyz as a feature!' + new_features = grouped_xyz + return new_features, idx + + +@MODELS.register_module() +class StackedSAModuleMSG(BaseModule): + """Stack point set abstraction module. + + Args: + in_channels (int): Input channels. + radius (list[float]): List of radius in each ball query. + sample_nums (list[int]): Number of samples in each ball query. + mlp_channels (list[list[int]]): Specify mlp channels of the + pointnet before the global pooling for each scale to encode + point features. + use_xyz (bool): Whether to use xyz. Default: True. + pool_mod (str): Type of pooling method. + Default: 'max_pool'. + norm_cfg (dict): Type of normalization method. Defaults to + dict(type='BN2d', eps=1e-5, momentum=0.01). + init_cfg (dict, optional): Initialize config of + model. Defaults to None. + """ + + def __init__(self, + in_channels: int, + radius: List[float], + sample_nums: List[int], + mlp_channels: List[List[int]], + use_xyz: bool = True, + pool_mod='max', + norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.01), + init_cfg: dict = None, + **kwargs) -> None: + super(StackedSAModuleMSG, self).__init__(init_cfg=init_cfg) + assert len(radius) == len(sample_nums) == len(mlp_channels) + + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radius)): + cin = in_channels + if use_xyz: + cin += 3 + cur_radius = radius[i] + nsample = sample_nums[i] + mlp_spec = mlp_channels[i] + + self.groupers.append( + StackQueryAndGroup(cur_radius, nsample, use_xyz=use_xyz)) + + mlp = nn.Sequential() + for i in range(len(mlp_spec)): + cout = mlp_spec[i] + mlp.add_module( + f'layer{i}', + ConvModule( + cin, + cout, + kernel_size=(1, 1), + stride=(1, 1), + conv_cfg=dict(type='Conv2d'), + norm_cfg=norm_cfg, + bias=False)) + cin = cout + self.mlps.append(mlp) + self.pool_mod = pool_mod + + def forward(self, + xyz: Tensor, + xyz_batch_cnt: Tensor, + new_xyz: Tensor, + new_xyz_batch_cnt: Tensor, + features: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + """Forward. + + Args: + xyz (Tensor): Tensor of the xyz coordinates + of the features shape with (N1 + N2 ..., 3). + xyz_batch_cnt: (Tensor): Stacked input xyz coordinates nums in + each batch, just like (N1, N2, ...). + new_xyz (Tensor): New coords of the outputs shape with + (M1 + M2 ..., 3). + new_xyz_batch_cnt: (Tensor): Stacked new xyz coordinates nums + in each batch, just like (M1, M2, ...). + features (Tensor, optional): Features of each point with shape + (N1 + N2 ..., C). C is features channel number. Default: None. + + Returns: + Return new points coordinates and features: + - new_xyz (Tensor): Target points coordinates with shape + (N1 + N2 ..., 3). + - new_features (Tensor): Target points features with shape + (M1 + M2 ..., sum_k(mlps[k][-1])). 
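+
+        Example:
+            An illustrative sketch of the stacked input format (tensor sizes
+            are arbitrary; the underlying ``ball_query`` and
+            ``grouping_operation`` ops run on a CUDA device):
+
+            >>> import torch
+            >>> sa = StackedSAModuleMSG(
+            ...     in_channels=16,
+            ...     radius=(0.4, 0.8),
+            ...     sample_nums=(16, 16),
+            ...     mlp_channels=((16, 16), (16, 16))).cuda()
+            >>> xyz = torch.rand(1024, 3).cuda()  # two samples stacked
+            >>> xyz_batch_cnt = torch.tensor([512, 512]).int().cuda()
+            >>> new_xyz = torch.rand(256, 3).cuda()
+            >>> new_xyz_batch_cnt = torch.tensor([128, 128]).int().cuda()
+            >>> features = torch.rand(1024, 16).cuda()
+            >>> new_xyz, new_feats = sa(
+            ...     xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features)
+            >>> # new_feats: (256, 32), the two scales (16 + 16) concatenated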
+ """ + new_features_list = [] + for k in range(len(self.groupers)): + grouped_features, ball_idxs = self.groupers[k]( + xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, + features) # (M1 + M2, Cin, nsample) + grouped_features = grouped_features.permute(1, 0, + 2).unsqueeze(dim=0) + new_features = self.mlps[k](grouped_features) + # (M1 + M2 ..., Cout, nsample) + if self.pool_mod == 'max': + new_features = new_features.max(-1).values + elif self.pool_mod == 'avg': + new_features = new_features.mean(-1) + else: + raise NotImplementedError + new_features = new_features.squeeze(dim=0).permute(1, 0) + new_features_list.append(new_features) + + new_features = torch.cat(new_features_list, dim=1) + + return new_xyz, new_features diff --git a/mmdet3d/models/middle_encoders/__init__.py b/mmdet3d/models/middle_encoders/__init__.py index d7b443551..96f5d2019 100644 --- a/mmdet3d/models/middle_encoders/__init__.py +++ b/mmdet3d/models/middle_encoders/__init__.py @@ -2,7 +2,9 @@ from .pillar_scatter import PointPillarsScatter from .sparse_encoder import SparseEncoder, SparseEncoderSASSD from .sparse_unet import SparseUNet +from .voxel_set_abstraction import VoxelSetAbstraction __all__ = [ - 'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet' + 'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet', + 'VoxelSetAbstraction' ] diff --git a/mmdet3d/models/middle_encoders/sparse_encoder.py b/mmdet3d/models/middle_encoders/sparse_encoder.py index 61fc3cb31..e5331bebb 100644 --- a/mmdet3d/models/middle_encoders/sparse_encoder.py +++ b/mmdet3d/models/middle_encoders/sparse_encoder.py @@ -41,6 +41,8 @@ class SparseEncoder(nn.Module): Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)). block_type (str, optional): Type of the block to use. Defaults to 'conv_module'. + return_middle_feats (bool): Whether output middle features. + Default to False. """ def __init__(self, @@ -54,7 +56,8 @@ def __init__(self, 64)), encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)), - block_type='conv_module'): + block_type='conv_module', + return_middle_feats=False): super().__init__() assert block_type in ['conv_module', 'basicblock'] self.sparse_shape = sparse_shape @@ -66,6 +69,7 @@ def __init__(self, self.encoder_paddings = encoder_paddings self.stage_num = len(self.encoder_channels) self.fp16_enabled = False + self.return_middle_feats = return_middle_feats # Spconv init all weight on its own assert isinstance(order, tuple) and len(order) == 3 @@ -117,7 +121,14 @@ def forward(self, voxel_features, coors, batch_size): batch_size (int): Batch size. Returns: - dict: Backbone features. + torch.Tensor | tuple[torch.Tensor, list]: Return spatial features + include: + + - spatial_features (torch.Tensor): Spatial features are out from + the last layer. + - encode_features (List[SparseConvTensor], optional): Middle layer + output features. When self.return_middle_feats is True, the + module returns middle features. 
""" coors = coors.int() input_sp_tensor = SparseConvTensor(voxel_features, coors, @@ -137,7 +148,10 @@ def forward(self, voxel_features, coors, batch_size): N, C, D, H, W = spatial_features.shape spatial_features = spatial_features.view(N, C * D, H, W) - return spatial_features + if self.return_middle_feats: + return spatial_features, encode_features + else: + return spatial_features def make_encoder_layers(self, make_block, diff --git a/mmdet3d/models/middle_encoders/voxel_set_abstraction.py b/mmdet3d/models/middle_encoders/voxel_set_abstraction.py new file mode 100644 index 000000000..e2161fac7 --- /dev/null +++ b/mmdet3d/models/middle_encoders/voxel_set_abstraction.py @@ -0,0 +1,334 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import mmengine +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.ops.furthest_point_sample import furthest_point_sample +from mmengine.model import BaseModule + +from mmdet3d.registry import MODELS +from mmdet3d.utils import InstanceList + + +def bilinear_interpolate_torch(inputs, x, y): + """Bilinear interpolate for inputs.""" + x0 = torch.floor(x).long() + x1 = x0 + 1 + + y0 = torch.floor(y).long() + y1 = y0 + 1 + + x0 = torch.clamp(x0, 0, inputs.shape[1] - 1) + x1 = torch.clamp(x1, 0, inputs.shape[1] - 1) + y0 = torch.clamp(y0, 0, inputs.shape[0] - 1) + y1 = torch.clamp(y1, 0, inputs.shape[0] - 1) + + Ia = inputs[y0, x0] + Ib = inputs[y1, x0] + Ic = inputs[y0, x1] + Id = inputs[y1, x1] + + wa = (x1.type_as(x) - x) * (y1.type_as(y) - y) + wb = (x1.type_as(x) - x) * (y - y0.type_as(y)) + wc = (x - x0.type_as(x)) * (y1.type_as(y) - y) + wd = (x - x0.type_as(x)) * (y - y0.type_as(y)) + ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t( + torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd) + return ans + + +@MODELS.register_module() +class VoxelSetAbstraction(BaseModule): + """Voxel set abstraction module for PVRCNN and PVRCNN++. + + Args: + num_keypoints (int): The number of key points sampled from + raw points cloud. + fused_out_channel (int): Key points feature output channels + num after fused. Default to 128. + voxel_size (list[float]): Size of voxels. Defaults to + [0.05, 0.05, 0.1]. + point_cloud_range (list[float]): Point cloud range. Defaults to + [0, -40, -3, 70.4, 40, 1]. + voxel_sa_cfgs_list (List[dict or ConfigDict], optional): List of SA + module cfg. Used to gather key points features from multi-wise + voxel features. Default to None. + rawpoints_sa_cfgs (dict or ConfigDict, optional): SA module cfg. + Used to gather key points features from raw points. Default to + None. + bev_feat_channel (int): Bev features channels num. + Default to 256. + bev_scale_factor (int): Bev features scale factor. Default to 8. + voxel_center_as_source (bool): Whether used voxel centers as points + cloud key points. Defaults to False. + norm_cfg (dict[str]): Config of normalization layer. Default + used dict(type='BN1d', eps=1e-5, momentum=0.1). + bias (bool | str, optional): If specified as `auto`, it will be + decided by `norm_cfg`. `bias` will be set as True if + `norm_cfg` is None, otherwise False. Default: 'auto'. 
+ """ + + def __init__(self, + num_keypoints: int, + fused_out_channel: int = 128, + voxel_size: list = [0.05, 0.05, 0.1], + point_cloud_range: list = [0, -40, -3, 70.4, 40, 1], + voxel_sa_cfgs_list: Optional[list] = None, + rawpoints_sa_cfgs: Optional[dict] = None, + bev_feat_channel: int = 256, + bev_scale_factor: int = 8, + voxel_center_as_source: bool = False, + norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1), + bias: str = 'auto') -> None: + super().__init__() + self.num_keypoints = num_keypoints + self.fused_out_channel = fused_out_channel + self.voxel_size = voxel_size + self.point_cloud_range = point_cloud_range + self.voxel_center_as_source = voxel_center_as_source + + gathered_channel = 0 + + if rawpoints_sa_cfgs is not None: + self.rawpoints_sa_layer = MODELS.build(rawpoints_sa_cfgs) + gathered_channel += sum( + [x[-1] for x in rawpoints_sa_cfgs.mlp_channels]) + else: + self.rawpoints_sa_layer = None + + if voxel_sa_cfgs_list is not None: + self.voxel_sa_configs_list = voxel_sa_cfgs_list + self.voxel_sa_layers = nn.ModuleList() + for voxel_sa_config in voxel_sa_cfgs_list: + cur_layer = MODELS.build(voxel_sa_config) + self.voxel_sa_layers.append(cur_layer) + gathered_channel += sum( + [x[-1] for x in voxel_sa_config.mlp_channels]) + else: + self.voxel_sa_layers = None + + if bev_feat_channel is not None and bev_scale_factor is not None: + self.bev_cfg = mmengine.Config( + dict( + bev_feat_channels=bev_feat_channel, + bev_scale_factor=bev_scale_factor)) + gathered_channel += bev_feat_channel + else: + self.bev_cfg = None + self.point_feature_fusion_layer = nn.Sequential( + ConvModule( + gathered_channel, + fused_out_channel, + kernel_size=(1, 1), + stride=(1, 1), + conv_cfg=dict(type='Conv2d'), + norm_cfg=norm_cfg, + bias=bias)) + + def interpolate_from_bev_features(self, keypoints: torch.Tensor, + bev_features: torch.Tensor, + batch_size: int, + bev_scale_factor: int) -> torch.Tensor: + """Gather key points features from bev feature map by interpolate. + + Args: + keypoints (torch.Tensor): Sampled key points with shape + (N1 + N2 + ..., NDim). + bev_features (torch.Tensor): Bev feature map from the first + stage with shape (B, C, H, W). + batch_size (int): Input batch size. + bev_scale_factor (int): Bev feature map scale factor. + + Returns: + torch.Tensor: Key points features gather from bev feature + map with shape (N1 + N2 + ..., C) + """ + x_idxs = (keypoints[..., 0] - + self.point_cloud_range[0]) / self.voxel_size[0] + y_idxs = (keypoints[..., 1] - + self.point_cloud_range[1]) / self.voxel_size[1] + + x_idxs = x_idxs / bev_scale_factor + y_idxs = y_idxs / bev_scale_factor + + point_bev_features_list = [] + for k in range(batch_size): + cur_x_idxs = x_idxs[k, ...] + cur_y_idxs = y_idxs[k, ...] + cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C) + point_bev_features = bilinear_interpolate_torch( + cur_bev_features, cur_x_idxs, cur_y_idxs) + point_bev_features_list.append(point_bev_features) + + point_bev_features = torch.cat( + point_bev_features_list, dim=0) # (N1 + N2 + ..., C) + return point_bev_features.view(batch_size, keypoints.shape[1], -1) + + def get_voxel_centers(self, coors: torch.Tensor, + scale_factor: float) -> torch.Tensor: + """Get voxel centers coordinate. + + Args: + coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim), + where 1 represents the batch index. + scale_factor (float): Scale factor. + + Returns: + torch.Tensor: Voxel centers coordinate with shape (N, 3). 
+ """ + assert coors.shape[1] == 4 + voxel_centers = coors[:, [3, 2, 1]].float() # (xyz) + voxel_size = torch.tensor( + self.voxel_size, + device=voxel_centers.device).float() * scale_factor + pc_range = torch.tensor( + self.point_cloud_range[0:3], device=voxel_centers.device).float() + voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range + return voxel_centers + + def sample_key_points(self, points: List[torch.Tensor], + coors: torch.Tensor) -> torch.Tensor: + """Sample key points from raw points cloud. + + Args: + points (List[torch.Tensor]): Point cloud of each sample. + coors (torch.Tensor): Coordinates of voxels shape is Nx(1+NDim), + where 1 represents the batch index. + + Returns: + torch.Tensor: (B, M, 3) Key points of each sample. + M is num_keypoints. + """ + assert points is not None or coors is not None + if self.voxel_center_as_source: + _src_points = self.get_voxel_centers(coors=coors, scale_factor=1) + batch_size = coors[-1, 0].item() + 1 + src_points = [ + _src_points[coors[:, 0] == b] for b in range(batch_size) + ] + else: + src_points = [p[..., :3] for p in points] + + keypoints_list = [] + for points_to_sample in src_points: + num_points = points_to_sample.shape[0] + cur_pt_idxs = furthest_point_sample( + points_to_sample.unsqueeze(dim=0).contiguous(), + self.num_keypoints).long()[0] + + if num_points < self.num_keypoints: + times = int(self.num_keypoints / num_points) + 1 + non_empty = cur_pt_idxs[:num_points] + cur_pt_idxs = non_empty.repeat(times)[:self.num_keypoints] + + keypoints = points_to_sample[cur_pt_idxs] + + keypoints_list.append(keypoints) + keypoints = torch.stack(keypoints_list, dim=0) # (B, M, 3) + return keypoints + + def forward(self, batch_inputs_dict: dict, feats_dict: dict, + rpn_results_list: InstanceList) -> dict: + """Extract point-wise features from multi-input. + + Args: + batch_inputs_dict (dict): The model input dict which include + 'points', 'voxels' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + - voxels (dict[torch.Tensor]): Voxels of the batch sample. + feats_dict (dict): Contains features from the first + stage. + rpn_results_list (List[:obj:`InstanceData`]): Detection results + of rpn head. + + Returns: + dict: Contain Point-wise features, include: + - keypoints (torch.Tensor): Sampled key points. + - keypoint_features (torch.Tensor): Gathered key points + features from multi input. + - fusion_keypoint_features (torch.Tensor): Fusion + keypoint_features by point_feature_fusion_layer. 
+ """ + points = batch_inputs_dict['points'] + voxel_encode_features = feats_dict['multi_scale_3d_feats'] + bev_encode_features = feats_dict['spatial_feats'] + if self.voxel_center_as_source: + voxels_coors = batch_inputs_dict['voxels']['coors'] + else: + voxels_coors = None + keypoints = self.sample_key_points(points, voxels_coors) + + point_features_list = [] + batch_size = len(points) + + if self.bev_cfg is not None: + point_bev_features = self.interpolate_from_bev_features( + keypoints, bev_encode_features, batch_size, + self.bev_cfg.bev_scale_factor) + point_features_list.append(point_bev_features.contiguous()) + + batch_size, num_keypoints, _ = keypoints.shape + key_xyz = keypoints.view(-1, 3) + key_xyz_batch_cnt = key_xyz.new_zeros(batch_size).int().fill_( + num_keypoints) + + if self.rawpoints_sa_layer is not None: + batch_points = torch.cat(points, dim=0) + batch_cnt = [len(p) for p in points] + xyz = batch_points[:, :3].contiguous() + features = None + if batch_points.size(1) > 0: + features = batch_points[:, 3:].contiguous() + xyz_batch_cnt = xyz.new_tensor(batch_cnt, dtype=torch.int32) + + pooled_points, pooled_features = self.rawpoints_sa_layer( + xyz=xyz.contiguous(), + xyz_batch_cnt=xyz_batch_cnt, + new_xyz=key_xyz.contiguous(), + new_xyz_batch_cnt=key_xyz_batch_cnt, + features=features.contiguous(), + ) + + point_features_list.append(pooled_features.contiguous().view( + batch_size, num_keypoints, -1)) + if self.voxel_sa_layers is not None: + for k, voxel_sa_layer in enumerate(self.voxel_sa_layers): + cur_coords = voxel_encode_features[k].indices + xyz = self.get_voxel_centers( + coors=cur_coords, + scale_factor=self.voxel_sa_configs_list[k].scale_factor + ).contiguous() + xyz_batch_cnt = xyz.new_zeros(batch_size).int() + for bs_idx in range(batch_size): + xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum() + + pooled_points, pooled_features = voxel_sa_layer( + xyz=xyz.contiguous(), + xyz_batch_cnt=xyz_batch_cnt, + new_xyz=key_xyz.contiguous(), + new_xyz_batch_cnt=key_xyz_batch_cnt, + features=voxel_encode_features[k].features.contiguous(), + ) + point_features_list.append(pooled_features.contiguous().view( + batch_size, num_keypoints, -1)) + + point_features = torch.cat( + point_features_list, dim=-1).view(batch_size * num_keypoints, -1, + 1) + + fusion_point_features = self.point_feature_fusion_layer( + point_features.unsqueeze(dim=-1)).squeeze(dim=-1) + + batch_idxs = torch.arange( + batch_size * num_keypoints, device=keypoints.device + ) // num_keypoints # batch indexes of each key points + batch_keypoints_xyz = torch.cat( + (batch_idxs.to(key_xyz.dtype).unsqueeze(dim=-1), key_xyz), dim=-1) + + return dict( + keypoint_features=point_features.squeeze(dim=-1), + fusion_keypoint_features=fusion_point_features.squeeze(dim=-1), + keypoints=batch_keypoints_xyz) diff --git a/mmdet3d/models/roi_heads/__init__.py b/mmdet3d/models/roi_heads/__init__.py index e607570d7..0e90b1a75 100644 --- a/mmdet3d/models/roi_heads/__init__.py +++ b/mmdet3d/models/roi_heads/__init__.py @@ -5,10 +5,11 @@ from .mask_heads import PointwiseSemanticHead, PrimitiveHead from .part_aggregation_roi_head import PartAggregationROIHead from .point_rcnn_roi_head import PointRCNNRoIHead +from .pv_rcnn_roi_head import PVRCNNRoiHead from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor __all__ = [ 'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead', 'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor', - 'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead' + 
'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead', 'PVRCNNRoiHead'
 ]
diff --git a/mmdet3d/models/roi_heads/bbox_heads/__init__.py b/mmdet3d/models/roi_heads/bbox_heads/__init__.py
index 0a6ebeeeb..994465ed8 100644
--- a/mmdet3d/models/roi_heads/bbox_heads/__init__.py
+++ b/mmdet3d/models/roi_heads/bbox_heads/__init__.py
@@ -7,9 +7,10 @@
 from .h3d_bbox_head import H3DBboxHead
 from .parta2_bbox_head import PartA2BboxHead
 from .point_rcnn_bbox_head import PointRCNNBboxHead
+from .pv_rcnn_bbox_head import PVRCNNBBoxHead
 
 __all__ = [
     'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
     'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead',
-    'H3DBboxHead', 'PointRCNNBboxHead'
+    'H3DBboxHead', 'PointRCNNBboxHead', 'PVRCNNBBoxHead'
 ]
diff --git a/mmdet3d/models/roi_heads/bbox_heads/pv_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/pv_rcnn_bbox_head.py
new file mode 100644
index 000000000..0cc9b6bae
--- /dev/null
+++ b/mmdet3d/models/roi_heads/bbox_heads/pv_rcnn_bbox_head.py
@@ -0,0 +1,510 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from mmengine.structures import InstanceData
+from torch import nn as nn
+
+from mmdet3d.models.builder import build_loss
+from mmdet3d.models.layers import nms_bev, nms_normal_bev
+from mmdet3d.registry import MODELS, TASK_UTILS
+from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
+                                        rotation_3d_in_axis, xywhr2xyxyr)
+from mmdet3d.utils import InstanceList
+from mmdet.models.task_modules.samplers import SamplingResult
+from mmdet.models.utils import multi_apply
+
+
+@MODELS.register_module()
+class PVRCNNBBoxHead(BaseModule):
+    """PVRCNN BBox head.
+
+    Args:
+        in_channels (int): The number of input channels.
+        grid_size (int): The number of grid points in roi bbox.
+        num_classes (int): The number of classes.
+        class_agnostic (bool): Whether to generate class-agnostic predictions.
+            Defaults to True.
+        shared_fc_channels (tuple(int)): Out channels of each shared fc layer.
+            Defaults to (256, 256).
+        cls_channels (tuple(int)): Out channels of each classification layer.
+            Defaults to (256, 256).
+        reg_channels (tuple(int)): Out channels of each regression layer.
+            Defaults to (256, 256).
+        dropout_ratio (float): Ratio of dropout layer. Defaults to 0.3.
+        with_corner_loss (bool): Whether to use corner loss or not.
+            Defaults to True.
+        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for box head.
+            Defaults to dict(type='DeltaXYZWLHRBBoxCoder').
+        norm_cfg (dict): Type of normalization method.
+            Defaults to dict(type='BN2d', eps=1e-5, momentum=0.1).
+        loss_bbox (dict): Config dict of box regression loss.
+        loss_cls (dict): Config dict of classification loss.
+        init_cfg (dict, optional): Initialize config of
+            model.
+ """ + + def __init__( + self, + in_channels: int, + grid_size: int, + num_classes: int, + class_agnostic: bool = True, + shared_fc_channels: Tuple[int] = (256, 256), + cls_channels: Tuple[int] = (256, 256), + reg_channels: Tuple[int] = (256, 256), + dropout_ratio: float = 0.3, + with_corner_loss: bool = True, + bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'), + norm_cfg: dict = dict(type='BN2d', eps=1e-5, momentum=0.1), + loss_bbox: dict = dict( + type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), + loss_cls: dict = dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='none', + loss_weight=1.0), + init_cfg: Optional[dict] = dict( + type='Xavier', layer=['Conv2d', 'Conv1d'], distribution='uniform') + ) -> None: + super(PVRCNNBBoxHead, self).__init__(init_cfg=init_cfg) + self.init_cfg = init_cfg + self.num_classes = num_classes + self.with_corner_loss = with_corner_loss + self.class_agnostic = class_agnostic + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.loss_bbox = build_loss(loss_bbox) + self.loss_cls = build_loss(loss_cls) + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + + cls_out_channels = 1 if class_agnostic else num_classes + self.reg_out_channels = self.bbox_coder.code_size * cls_out_channels + if self.use_sigmoid_cls: + self.cls_out_channels = cls_out_channels + else: + self.cls_out_channels = cls_out_channels + 1 + + self.dropout_ratio = dropout_ratio + self.grid_size = grid_size + + # PVRCNNBBoxHead model in_channels is num of grid points in roi box. + in_channels *= (self.grid_size**3) + + self.in_channels = in_channels + + self.shared_fc_layer = self._make_fc_layers( + in_channels, shared_fc_channels, + range(len(shared_fc_channels) - 1), norm_cfg) + self.cls_layer = self._make_fc_layers( + shared_fc_channels[-1], + cls_channels, + range(1), + norm_cfg, + out_channels=self.cls_out_channels) + self.reg_layer = self._make_fc_layers( + shared_fc_channels[-1], + reg_channels, + range(1), + norm_cfg, + out_channels=self.reg_out_channels) + + def _make_fc_layers(self, + in_channels: int, + fc_channels: list, + dropout_indices: list, + norm_cfg: dict, + out_channels: Optional[int] = None) -> torch.nn.Module: + """Initial a full connection layer. + + Args: + in_channels (int): Module in channels. + fc_channels (list): Full connection layer channels. + dropout_indices (list): Dropout indices. + norm_cfg (dict): Type of normalization method. + out_channels (int, optional): Module out channels. + """ + fc_layers = [] + pre_channel = in_channels + for k in range(len(fc_channels)): + fc_layers.append( + ConvModule( + pre_channel, + fc_channels[k], + kernel_size=(1, 1), + stride=(1, 1), + norm_cfg=norm_cfg, + conv_cfg=dict(type='Conv2d'), + bias=False, + inplace=True)) + pre_channel = fc_channels[k] + if self.dropout_ratio >= 0 and k in dropout_indices: + fc_layers.append(nn.Dropout(self.dropout_ratio)) + if out_channels is not None: + fc_layers.append( + nn.Conv2d(fc_channels[-1], out_channels, 1, bias=True)) + fc_layers = nn.Sequential(*fc_layers) + return fc_layers + + def forward(self, feats: torch.Tensor) -> Tuple[torch.Tensor]: + """Forward pvrcnn bbox head. + + Args: + feats (torch.Tensor): Batch point-wise features. + + Returns: + tuple[torch.Tensor]: Score of class and bbox predictions. 
+ """ + # (B * N, 6, 6, 6, C) + rcnn_batch_size = feats.shape[0] + feats = feats.permute(0, 4, 1, 2, + 3).contiguous().view(rcnn_batch_size, -1, 1, 1) + # (BxN, C*6*6*6) + shared_feats = self.shared_fc_layer(feats) + cls_score = self.cls_layer(shared_feats).transpose( + 1, 2).contiguous().view(-1, self.cls_out_channels) # (B, 1) + bbox_pred = self.reg_layer(shared_feats).transpose( + 1, 2).contiguous().view(-1, self.reg_out_channels) # (B, C) + return cls_score, bbox_pred + + def loss(self, cls_score: torch.Tensor, bbox_pred: torch.Tensor, + rois: torch.Tensor, labels: torch.Tensor, + bbox_targets: torch.Tensor, pos_gt_bboxes: torch.Tensor, + reg_mask: torch.Tensor, label_weights: torch.Tensor, + bbox_weights: torch.Tensor) -> Dict: + """Coumputing losses. + + Args: + cls_score (torch.Tensor): Scores of each roi. + bbox_pred (torch.Tensor): Predictions of bboxes. + rois (torch.Tensor): Roi bboxes. + labels (torch.Tensor): Labels of class. + bbox_targets (torch.Tensor): Target of positive bboxes. + pos_gt_bboxes (torch.Tensor): Ground truths of positive bboxes. + reg_mask (torch.Tensor): Mask for positive bboxes. + label_weights (torch.Tensor): Weights of class loss. + bbox_weights (torch.Tensor): Weights of bbox loss. + + Returns: + dict: Computed losses. + + - loss_cls (torch.Tensor): Loss of classes. + - loss_bbox (torch.Tensor): Loss of bboxes. + - loss_corner (torch.Tensor): Loss of corners. + """ + losses = dict() + rcnn_batch_size = cls_score.shape[0] + + # calculate class loss + cls_flat = cls_score.view(-1) + loss_cls = self.loss_cls(cls_flat, labels, label_weights) + losses['loss_cls'] = loss_cls + + # calculate regression loss + code_size = self.bbox_coder.code_size + pos_inds = (reg_mask > 0) + if pos_inds.any() == 0: + # fake a part loss + losses['loss_bbox'] = 0 * bbox_pred.sum() + if self.with_corner_loss: + losses['loss_corner'] = 0 * bbox_pred.sum() + else: + pos_bbox_pred = bbox_pred.view(rcnn_batch_size, -1)[pos_inds] + bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat( + 1, pos_bbox_pred.shape[-1]) + loss_bbox = self.loss_bbox( + pos_bbox_pred.unsqueeze(dim=0), bbox_targets.unsqueeze(dim=0), + bbox_weights_flat.unsqueeze(dim=0)) + losses['loss_bbox'] = loss_bbox + + if self.with_corner_loss: + pos_roi_boxes3d = rois[..., 1:].view(-1, code_size)[pos_inds] + pos_roi_boxes3d = pos_roi_boxes3d.view(-1, code_size) + batch_anchors = pos_roi_boxes3d.clone().detach() + pos_rois_rotation = pos_roi_boxes3d[..., 6].view(-1) + roi_xyz = pos_roi_boxes3d[..., 0:3].view(-1, 3) + batch_anchors[..., 0:3] = 0 + # decode boxes + pred_boxes3d = self.bbox_coder.decode( + batch_anchors, + pos_bbox_pred.view(-1, code_size)).view(-1, code_size) + + pred_boxes3d[..., 0:3] = rotation_3d_in_axis( + pred_boxes3d[..., 0:3].unsqueeze(1), + pos_rois_rotation, + axis=2).squeeze(1) + + pred_boxes3d[:, 0:3] += roi_xyz + + # calculate corner loss + loss_corner = self.get_corner_loss_lidar( + pred_boxes3d, pos_gt_bboxes) + losses['loss_corner'] = loss_corner.mean() + + return losses + + def get_targets(self, + sampling_results: SamplingResult, + rcnn_train_cfg: dict, + concat: bool = True) -> Tuple[torch.Tensor]: + """Generate targets. + + Args: + sampling_results (list[:obj:`SamplingResult`]): + Sampled results from rois. + rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn. + concat (bool): Whether to concatenate targets between batches. + + Returns: + tuple[torch.Tensor]: Targets of boxes and class prediction. 
+ """ + pos_bboxes_list = [res.pos_bboxes for res in sampling_results] + pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] + iou_list = [res.iou for res in sampling_results] + targets = multi_apply( + self._get_target_single, + pos_bboxes_list, + pos_gt_bboxes_list, + iou_list, + cfg=rcnn_train_cfg) + + (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) = targets + + if concat: + label = torch.cat(label, 0) + bbox_targets = torch.cat(bbox_targets, 0) + pos_gt_bboxes = torch.cat(pos_gt_bboxes, 0) + reg_mask = torch.cat(reg_mask, 0) + + label_weights = torch.cat(label_weights, 0) + label_weights /= torch.clamp(label_weights.sum(), min=1.0) + + bbox_weights = torch.cat(bbox_weights, 0) + bbox_weights /= torch.clamp(bbox_weights.sum(), min=1.0) + + return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + + def _get_target_single(self, pos_bboxes: torch.Tensor, + pos_gt_bboxes: torch.Tensor, ious: torch.Tensor, + cfg: dict) -> Tuple[torch.Tensor]: + """Generate training targets for a single sample. + + Args: + pos_bboxes (torch.Tensor): Positive boxes with shape + (N, 7). + pos_gt_bboxes (torch.Tensor): Ground truth boxes with shape + (M, 7). + ious (torch.Tensor): IoU between `pos_bboxes` and `pos_gt_bboxes` + in shape (N, M). + cfg (dict): Training configs. + + Returns: + tuple[torch.Tensor]: Target for positive boxes. + (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + """ + cls_pos_mask = ious > cfg.cls_pos_thr + cls_neg_mask = ious < cfg.cls_neg_thr + interval_mask = (cls_pos_mask == 0) & (cls_neg_mask == 0) + + # iou regression target + label = (cls_pos_mask > 0).float() + label[interval_mask] = ious[interval_mask] * 2 - 0.5 + # label weights + label_weights = (label >= 0).float() + + # box regression target + reg_mask = pos_bboxes.new_zeros(ious.size(0)).long() + reg_mask[0:pos_gt_bboxes.size(0)] = 1 + bbox_weights = (reg_mask > 0).float() + if reg_mask.bool().any(): + pos_gt_bboxes_ct = pos_gt_bboxes.clone().detach() + roi_center = pos_bboxes[..., 0:3] + roi_ry = pos_bboxes[..., 6] % (2 * np.pi) + + # canonical transformation + pos_gt_bboxes_ct[..., 0:3] -= roi_center + pos_gt_bboxes_ct[..., 6] -= roi_ry + pos_gt_bboxes_ct[..., 0:3] = rotation_3d_in_axis( + pos_gt_bboxes_ct[..., 0:3].unsqueeze(1), -roi_ry, + axis=2).squeeze(1) + + # flip orientation if rois have opposite orientation + ry_label = pos_gt_bboxes_ct[..., 6] % (2 * np.pi) # 0 ~ 2pi + opposite_flag = (ry_label > np.pi * 0.5) & (ry_label < np.pi * 1.5) + ry_label[opposite_flag] = (ry_label[opposite_flag] + np.pi) % ( + 2 * np.pi) # (0 ~ pi/2, 3pi/2 ~ 2pi) + flag = ry_label > np.pi + ry_label[flag] = ry_label[flag] - np.pi * 2 # (-pi/2, pi/2) + ry_label = torch.clamp(ry_label, min=-np.pi / 2, max=np.pi / 2) + pos_gt_bboxes_ct[..., 6] = ry_label + + rois_anchor = pos_bboxes.clone().detach() + rois_anchor[:, 0:3] = 0 + rois_anchor[:, 6] = 0 + bbox_targets = self.bbox_coder.encode(rois_anchor, + pos_gt_bboxes_ct) + else: + # no fg bbox + bbox_targets = pos_gt_bboxes.new_empty((0, 7)) + + return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + + def get_corner_loss_lidar(self, + pred_bbox3d: torch.Tensor, + gt_bbox3d: torch.Tensor, + delta: float = 1.0) -> torch.Tensor: + """Calculate corner loss of given boxes. + + Args: + pred_bbox3d (torch.FloatTensor): Predicted boxes in shape (N, 7). + gt_bbox3d (torch.FloatTensor): Ground truth boxes in shape (N, 7). 
+ delta (float, optional): huber loss threshold. Defaults to 1.0 + + Returns: + torch.FloatTensor: Calculated corner loss in shape (N). + """ + assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0] + + # This is a little bit hack here because we assume the box for + # Part-A2 is in LiDAR coordinates + gt_boxes_structure = LiDARInstance3DBoxes(gt_bbox3d) + pred_box_corners = LiDARInstance3DBoxes(pred_bbox3d).corners + gt_box_corners = gt_boxes_structure.corners + + # This flip only changes the heading direction of GT boxes + gt_bbox3d_flip = gt_boxes_structure.clone() + gt_bbox3d_flip.tensor[:, 6] += np.pi + gt_box_corners_flip = gt_bbox3d_flip.corners + + corner_dist = torch.min( + torch.norm(pred_box_corners - gt_box_corners, dim=2), + torch.norm(pred_box_corners - gt_box_corners_flip, + dim=2)) # (N, 8) + # huber loss + abs_error = torch.abs(corner_dist) + corner_loss = torch.where(abs_error < delta, + 0.5 * abs_error**2 / delta, + abs_error - 0.5 * delta) + return corner_loss.mean(dim=1) + + def get_results(self, + rois: torch.Tensor, + cls_preds: torch.Tensor, + bbox_reg: torch.Tensor, + class_labels: torch.Tensor, + input_metas: List[dict], + test_cfg: dict = None) -> InstanceList: + """Generate bboxes from bbox head predictions. + + Args: + rois (torch.Tensor): Roi bounding boxes. + cls_preds (torch.Tensor): Scores of bounding boxes. + bbox_reg (torch.Tensor): Bounding boxes predictions + class_labels (torch.Tensor): Label of classes + input_metas (list[dict]): Point cloud meta info. + test_cfg (:obj:`ConfigDict`): Testing config. + + Returns: + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. + """ + roi_batch_id = rois[..., 0] + roi_boxes = rois[..., 1:] # boxes without batch id + batch_size = int(roi_batch_id.max().item() + 1) + + # decode boxes + roi_ry = roi_boxes[..., 6].view(-1) + roi_xyz = roi_boxes[..., 0:3].view(-1, 3) + local_roi_boxes = roi_boxes.clone().detach() + local_roi_boxes[..., 0:3] = 0 + batch_box_preds = self.bbox_coder.decode(local_roi_boxes, bbox_reg) + batch_box_preds[..., 0:3] = rotation_3d_in_axis( + batch_box_preds[..., 0:3].unsqueeze(1), roi_ry, axis=2).squeeze(1) + batch_box_preds[:, 0:3] += roi_xyz + + # post processing + result_list = [] + for batch_id in range(batch_size): + cls_preds = cls_preds[roi_batch_id == batch_id] + box_preds = batch_box_preds[roi_batch_id == batch_id] + label_preds = class_labels[batch_id] + + cls_preds = cls_preds.sigmoid() + cls_preds, _ = torch.max(cls_preds, dim=-1) + selected = self.class_agnostic_nms( + scores=cls_preds, + bbox_preds=box_preds, + input_meta=input_metas[batch_id], + nms_cfg=test_cfg) + + selected_bboxes = box_preds[selected] + selected_label_preds = label_preds[selected] + selected_scores = cls_preds[selected] + + results = InstanceData() + results.bboxes_3d = input_metas[batch_id]['box_type_3d']( + selected_bboxes, self.bbox_coder.code_size) + results.scores_3d = selected_scores + results.labels_3d = selected_label_preds + + result_list.append(results) + return result_list + + def class_agnostic_nms(self, scores: torch.Tensor, + bbox_preds: torch.Tensor, nms_cfg: dict, + input_meta: dict) -> Tuple[torch.Tensor]: + """Class agnostic NMS for box head. 
+ + Args: + scores (torch.Tensor): Object score of bounding boxes. + bbox_preds (torch.Tensor): Predicted bounding boxes. + nms_cfg (dict): NMS config dict. + input_meta (dict): Contain pcd and img's meta info. + + Returns: + tuple[torch.Tensor]: Bounding boxes, scores and labels. + """ + obj_scores = scores.clone() + if nms_cfg.use_rotate_nms: + nms_func = nms_bev + else: + nms_func = nms_normal_bev + + bbox = input_meta['box_type_3d']( + bbox_preds.clone(), + box_dim=bbox_preds.shape[-1], + with_yaw=True, + origin=(0.5, 0.5, 0.5)) + + if nms_cfg.score_thr is not None: + scores_mask = (obj_scores >= nms_cfg.score_thr) + obj_scores = obj_scores[scores_mask] + bbox = bbox[scores_mask] + selected = [] + if obj_scores.shape[0] > 0: + box_scores_nms, indices = torch.topk( + obj_scores, k=min(4096, obj_scores.shape[0])) + bbox_bev = bbox.bev[indices] + bbox_for_nms = xywhr2xyxyr(bbox_bev) + + keep = nms_func(bbox_for_nms, box_scores_nms, nms_cfg.nms_thr) + selected = indices[keep] + if nms_cfg.score_thr is not None: + original_idxs = scores_mask.nonzero().view(-1) + selected = original_idxs[selected] + return selected diff --git a/mmdet3d/models/roi_heads/mask_heads/__init__.py b/mmdet3d/models/roi_heads/mask_heads/__init__.py index 0aa11569a..68e754b4f 100644 --- a/mmdet3d/models/roi_heads/mask_heads/__init__.py +++ b/mmdet3d/models/roi_heads/mask_heads/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .foreground_segmentation_head import ForegroundSegmentationHead from .pointwise_semantic_head import PointwiseSemanticHead from .primitive_head import PrimitiveHead -__all__ = ['PointwiseSemanticHead', 'PrimitiveHead'] +__all__ = [ + 'PointwiseSemanticHead', 'PrimitiveHead', 'ForegroundSegmentationHead' +] diff --git a/mmdet3d/models/roi_heads/mask_heads/foreground_segmentation_head.py b/mmdet3d/models/roi_heads/mask_heads/foreground_segmentation_head.py new file mode 100644 index 000000000..65ea38fa0 --- /dev/null +++ b/mmdet3d/models/roi_heads/mask_heads/foreground_segmentation_head.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple + +import torch +from mmcv.cnn.bricks import build_norm_layer +from mmengine.model import BaseModule +from mmengine.structures import InstanceData +from torch import nn as nn + +from mmdet3d.models.builder import build_loss +from mmdet3d.registry import MODELS +from mmdet3d.utils import InstanceList +from mmdet.models.utils import multi_apply + + +@MODELS.register_module() +class ForegroundSegmentationHead(BaseModule): + """Foreground segmentation head. + + Args: + in_channels (int): The number of input channel. + mlp_channels (tuple[int]): Specify of mlp channels. Defaults + to (256, 256). + extra_width (float): Boxes enlarge width. Default used 0.1. + norm_cfg (dict): Type of normalization method. Defaults to + dict(type='BN1d', eps=1e-5, momentum=0.1). + init_cfg (dict, optional): Initialize config of + model. Defaults to None. + loss_seg (dict): Config of segmentation loss. 
Defaults to + dict(type='mmdet.FocalLoss') + """ + + def __init__( + self, + in_channels: int, + mlp_channels: Tuple[int] = (256, 256), + extra_width: float = 0.1, + norm_cfg: dict = dict(type='BN1d', eps=1e-5, momentum=0.1), + init_cfg: Optional[dict] = None, + loss_seg: dict = dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + reduction='sum', + gamma=2.0, + alpha=0.25, + activated=True, + loss_weight=1.0) + ) -> None: + super(ForegroundSegmentationHead, self).__init__(init_cfg=init_cfg) + self.extra_width = extra_width + self.num_classes = 1 + + self.in_channels = in_channels + self.use_sigmoid_cls = loss_seg.get('use_sigmoid', False) + + out_channels = 1 + if self.use_sigmoid_cls: + self.out_channels = out_channels + else: + self.out_channels = out_channels + 1 + + mlps_layers = [] + cin = in_channels + for mlp in mlp_channels: + mlps_layers.extend([ + nn.Linear(cin, mlp, bias=False), + build_norm_layer(norm_cfg, mlp)[1], + nn.ReLU() + ]) + cin = mlp + mlps_layers.append(nn.Linear(cin, self.out_channels, bias=True)) + + self.seg_cls_layer = nn.Sequential(*mlps_layers) + + self.loss_seg = build_loss(loss_seg) + + def forward(self, feats: torch.Tensor) -> dict: + """Forward head. + + Args: + feats (torch.Tensor): Point-wise features. + + Returns: + dict: Segment predictions. + """ + seg_preds = self.seg_cls_layer(feats) + return dict(seg_preds=seg_preds) + + def _get_targets_single(self, point_xyz: torch.Tensor, + gt_bboxes_3d: InstanceData, + gt_labels_3d: torch.Tensor) -> torch.Tensor: + """generate segmentation targets for a single sample. + + Args: + point_xyz (torch.Tensor): Coordinate of points. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth boxes in + shape (box_num, 7). + gt_labels_3d (torch.Tensor): Class labels of ground truths in + shape (box_num). + + Returns: + torch.Tensor: Points class labels. + """ + point_cls_labels_single = point_xyz.new_zeros( + point_xyz.shape[0]).long() + enlarged_gt_boxes = gt_bboxes_3d.enlarged_box(self.extra_width) + + box_idxs_of_pts = gt_bboxes_3d.points_in_boxes_part(point_xyz).long() + extend_box_idxs_of_pts = enlarged_gt_boxes.points_in_boxes_part( + point_xyz).long() + box_fg_flag = box_idxs_of_pts >= 0 + fg_flag = box_fg_flag.clone() + ignore_flag = fg_flag ^ (extend_box_idxs_of_pts >= 0) + point_cls_labels_single[ignore_flag] = -1 + gt_box_of_fg_points = gt_labels_3d[box_idxs_of_pts[fg_flag]] + point_cls_labels_single[ + fg_flag] = 1 if self.num_classes == 1 else\ + gt_box_of_fg_points.long() + return point_cls_labels_single, + + def get_targets(self, points_bxyz: torch.Tensor, + batch_gt_instances_3d: InstanceList) -> dict: + """Generate segmentation targets. + + Args: + points_bxyz (torch.Tensor): The coordinates of point in shape + (B, num_points, 3). + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and + ``labels_3d`` attributes. + + Returns: + dict: Prediction targets + - seg_targets (torch.Tensor): Segmentation targets. 
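+              Foreground points falling inside a ground-truth box are
+              labelled 1 (the box's class label when ``num_classes`` > 1),
+              points that only fall inside the enlarged box are set to -1
+              and ignored in the loss, and all remaining points are
+              background (0).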
+ """ + batch_size = len(batch_gt_instances_3d) + points_xyz_list = [] + gt_bboxes_3d = [] + gt_labels_3d = [] + for idx in range(batch_size): + coords_idx = points_bxyz[:, 0] == idx + points_xyz_list.append(points_bxyz[coords_idx][..., 1:]) + gt_bboxes_3d.append(batch_gt_instances_3d[idx].bboxes_3d) + gt_labels_3d.append(batch_gt_instances_3d[idx].labels_3d) + seg_targets, = multi_apply(self._get_targets_single, points_xyz_list, + gt_bboxes_3d, gt_labels_3d) + seg_targets = torch.cat(seg_targets, dim=0) + return dict(seg_targets=seg_targets) + + def loss(self, semantic_results: dict, + semantic_targets: dict) -> Dict[str, torch.Tensor]: + """Calculate point-wise segmentation losses. + + Args: + semantic_results (dict): Results from semantic head. + semantic_targets (dict): Targets of semantic results. + + Returns: + dict: Loss of segmentation. + + - loss_semantic (torch.Tensor): Segmentation prediction loss. + """ + seg_preds = semantic_results['seg_preds'] + seg_targets = semantic_targets['seg_targets'] + + positives = (seg_targets > 0) + + negative_cls_weights = (seg_targets == 0).float() + seg_weights = (negative_cls_weights + 1.0 * positives).float() + pos_normalizer = positives.sum(dim=0).float() + seg_weights /= torch.clamp(pos_normalizer, min=1.0) + + seg_preds = torch.sigmoid(seg_preds) + loss_seg = self.loss_seg(seg_preds, (~positives).long(), seg_weights) + return dict(loss_semantic=loss_seg) diff --git a/mmdet3d/models/roi_heads/pv_rcnn_roi_head.py b/mmdet3d/models/roi_heads/pv_rcnn_roi_head.py new file mode 100644 index 000000000..adc001b0b --- /dev/null +++ b/mmdet3d/models/roi_heads/pv_rcnn_roi_head.py @@ -0,0 +1,312 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from torch.nn import functional as F + +from mmdet3d.models.roi_heads.base_3droi_head import Base3DRoIHead +from mmdet3d.registry import MODELS +from mmdet3d.structures import bbox3d2roi +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils import InstanceList +from mmdet.models.task_modules import AssignResult +from mmdet.models.task_modules.samplers import SamplingResult + + +@MODELS.register_module() +class PVRCNNRoiHead(Base3DRoIHead): + """RoI head for PV-RCNN. + + Args: + num_classes (int): The number of classes. Defaults to 3. + semantic_head (dict, optional): Config of semantic head. + Defaults to None. + bbox_roi_extractor (dict, optional): Config of roi_extractor. + Defaults to None. + bbox_head (dict, optional): Config of bbox_head. Defaults to None. + train_cfg (dict, optional): Train config of model. + Defaults to None. + test_cfg (dict, optional): Train config of model. + Defaults to None. + init_cfg (dict, optional): Initialize config of + model. Defaults to None. 
+ """ + + def __init__(self, + num_classes: int = 3, + semantic_head: Optional[dict] = None, + bbox_roi_extractor: Optional[dict] = None, + bbox_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None): + super(PVRCNNRoiHead, self).__init__( + bbox_head=bbox_head, + bbox_roi_extractor=bbox_roi_extractor, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg) + self.num_classes = num_classes + self.semantic_head = MODELS.build(semantic_head) + + self.init_assigner_sampler() + + @property + def with_semantic(self): + """bool: whether the head has semantic branch""" + return hasattr(self, + 'semantic_head') and self.semantic_head is not None + + def loss(self, feats_dict: dict, rpn_results_list: InstanceList, + batch_data_samples: SampleList, **kwargs) -> dict: + """Training forward function of PVRCNNROIHead. + + Args: + feats_dict (dict): Contains point-wise features. + rpn_results_list (List[:obj:`InstanceData`]): Detection results + of rpn head. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + + Returns: + dict: losses from each head. + + - loss_semantic (torch.Tensor): loss of semantic head. + - loss_bbox (torch.Tensor): loss of bboxes. + - loss_cls (torch.Tensor): loss of object classification. + - loss_corner (torch.Tensor): loss of bboxes corners. + """ + losses = dict() + batch_gt_instances_3d = [] + batch_gt_instances_ignore = [] + for data_sample in batch_data_samples: + batch_gt_instances_3d.append(data_sample.gt_instances_3d) + if 'ignored_instances' in data_sample: + batch_gt_instances_ignore.append(data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + if self.with_semantic: + semantic_results = self._semantic_forward_train( + feats_dict['keypoint_features'], feats_dict['keypoints'], + batch_gt_instances_3d) + losses['loss_semantic'] = semantic_results['loss_semantic'] + + sample_results = self._assign_and_sample(rpn_results_list, + batch_gt_instances_3d) + if self.with_bbox: + bbox_results = self._bbox_forward_train( + semantic_results['seg_preds'], + feats_dict['fusion_keypoint_features'], + feats_dict['keypoints'], sample_results) + losses.update(bbox_results['loss_bbox']) + + return losses + + def predict(self, feats_dict: dict, rpn_results_list: InstanceList, + batch_data_samples: SampleList, **kwargs) -> SampleList: + """Perform forward propagation of the roi head and predict detection + results on the features of the upstream network. + + Args: + feats_dict (dict): Contains point-wise features. + rpn_results_list (List[:obj:`InstanceData`]): Detection results + of rpn head. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + + Returns: + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. + """ + assert self.with_bbox, 'Bbox head must be implemented.' + assert self.with_semantic, 'Semantic head must be implemented.' 
+ + batch_input_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + semantic_results = self.semantic_head(feats_dict['keypoint_features']) + point_features = feats_dict[ + 'fusion_keypoint_features'] * semantic_results[ + 'seg_preds'].sigmoid().max( + dim=-1, keepdim=True).values + rois = bbox3d2roi( + [res['bboxes_3d'].tensor for res in rpn_results_list]) + labels_3d = [res['labels_3d'] for res in rpn_results_list] + bbox_results = self._bbox_forward(point_features, + feats_dict['keypoints'], rois) + + results_list = self.bbox_head.get_results(rois, + bbox_results['bbox_scores'], + bbox_results['bbox_reg'], + labels_3d, batch_input_metas, + self.test_cfg) + return results_list + + def _bbox_forward_train(self, seg_preds: torch.Tensor, + keypoint_features: torch.Tensor, + keypoints: torch.Tensor, + sampling_results: SamplingResult) -> dict: + """Forward training function of roi_extractor and bbox_head. + + Args: + seg_preds (torch.Tensor): Point-wise semantic features. + keypoint_features (torch.Tensor): key points features + from points encoder. + keypoints (torch.Tensor): Coordinate of key points. + sampling_results (:obj:`SamplingResult`): Sampled results used + for training. + + Returns: + dict: Forward results including losses and predictions. + """ + rois = bbox3d2roi([res.bboxes for res in sampling_results]) + keypoint_features = keypoint_features * seg_preds.sigmoid().max( + dim=-1, keepdim=True).values + bbox_results = self._bbox_forward(keypoint_features, keypoints, rois) + + bbox_targets = self.bbox_head.get_targets(sampling_results, + self.train_cfg) + loss_bbox = self.bbox_head.loss(bbox_results['bbox_scores'], + bbox_results['bbox_reg'], rois, + *bbox_targets) + + bbox_results.update(loss_bbox=loss_bbox) + return bbox_results + + def _bbox_forward(self, keypoint_features: torch.Tensor, + keypoints: torch.Tensor, rois: torch.Tensor) -> dict: + """Forward function of roi_extractor and bbox_head used in both + training and testing. + + Args: + rois (Tensor): Roi boxes. + keypoint_features (torch.Tensor): key points features + from points encoder. + keypoints (torch.Tensor): Coordinate of key points. + rois (Tensor): Roi boxes. + + Returns: + dict: Contains predictions of bbox_head and + features of roi_extractor. + """ + pooled_keypoint_features = self.bbox_roi_extractor( + keypoint_features, keypoints[..., 1:], keypoints[..., 0].int(), + rois) + bbox_score, bbox_reg = self.bbox_head(pooled_keypoint_features) + + bbox_results = dict(bbox_scores=bbox_score, bbox_reg=bbox_reg) + return bbox_results + + def _assign_and_sample( + self, proposal_list: InstanceList, + batch_gt_instances_3d: InstanceList) -> List[SamplingResult]: + """Assign and sample proposals for training. + + Args: + proposal_list (list[:obj:`InstancesData`]): Proposals produced by + rpn head. + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and + ``labels_3d`` attributes. + + Returns: + list[:obj:`SamplingResult`]: Sampled results of each training + sample. 
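+
+        Note:
+            When ``bbox_assigner`` is a list, proposals are assigned class
+            by class and the per-class results are merged into a single
+            ``AssignResult`` before box sampling.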
+ """ + sampling_results = [] + # bbox assign + for batch_idx in range(len(proposal_list)): + cur_proposal_list = proposal_list[batch_idx] + cur_boxes = cur_proposal_list['bboxes_3d'] + cur_labels_3d = cur_proposal_list['labels_3d'] + cur_gt_instances_3d = batch_gt_instances_3d[batch_idx] + cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\ + bboxes_3d.tensor + cur_gt_bboxes = batch_gt_instances_3d[batch_idx].bboxes_3d.to( + cur_boxes.device) + cur_gt_labels = batch_gt_instances_3d[batch_idx].labels_3d + + batch_num_gts = 0 + # 0 is bg + batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0) + batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes)) + # -1 is bg + batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1) + + # each class may have its own assigner + if isinstance(self.bbox_assigner, list): + for i, assigner in enumerate(self.bbox_assigner): + gt_per_cls = (cur_gt_labels == i) + pred_per_cls = (cur_labels_3d == i) + cur_assign_res = assigner.assign( + cur_proposal_list[pred_per_cls], + cur_gt_instances_3d[gt_per_cls]) + # gather assign_results in different class into one result + batch_num_gts += cur_assign_res.num_gts + # gt inds (1-based) + gt_inds_arange_pad = gt_per_cls.nonzero( + as_tuple=False).view(-1) + 1 + # pad 0 for indice unassigned + gt_inds_arange_pad = F.pad( + gt_inds_arange_pad, (1, 0), mode='constant', value=0) + # pad -1 for indice ignore + gt_inds_arange_pad = F.pad( + gt_inds_arange_pad, (1, 0), mode='constant', value=-1) + # convert to 0~gt_num+2 for indices + gt_inds_arange_pad += 1 + # now 0 is bg, >1 is fg in batch_gt_indis + batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[ + cur_assign_res.gt_inds + 1] - 1 + batch_max_overlaps[ + pred_per_cls] = cur_assign_res.max_overlaps + batch_gt_labels[pred_per_cls] = cur_assign_res.labels + + assign_result = AssignResult(batch_num_gts, batch_gt_indis, + batch_max_overlaps, + batch_gt_labels) + else: # for single class + assign_result = self.bbox_assigner.assign( + cur_proposal_list, cur_gt_instances_3d) + # sample boxes + sampling_result = self.bbox_sampler.sample(assign_result, + cur_boxes.tensor, + cur_gt_bboxes, + cur_gt_labels) + sampling_results.append(sampling_result) + return sampling_results + + def _semantic_forward_train(self, keypoint_features: torch.Tensor, + keypoints: torch.Tensor, + batch_gt_instances_3d: InstanceList) -> dict: + """Train semantic head. + + Args: + keypoint_features (torch.Tensor): key points features + from points encoder. + keypoints (torch.Tensor): Coordinate of key points. + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and + ``labels_3d`` attributes. + + Returns: + dict: Segmentation results including losses + """ + semantic_results = self.semantic_head(keypoint_features) + semantic_targets = self.semantic_head.get_targets( + keypoints, batch_gt_instances_3d) + loss_semantic = self.semantic_head.loss(semantic_results, + semantic_targets) + semantic_results.update(loss_semantic) + return semantic_results diff --git a/mmdet3d/models/roi_heads/roi_extractors/__init__.py b/mmdet3d/models/roi_heads/roi_extractors/__init__.py index d2b4d03a5..f10e7179c 100644 --- a/mmdet3d/models/roi_heads/roi_extractors/__init__.py +++ b/mmdet3d/models/roi_heads/roi_extractors/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from mmdet.models.roi_heads.roi_extractors import SingleRoIExtractor +from .batch_roigridpoint_extractor import Batch3DRoIGridExtractor from .single_roiaware_extractor import Single3DRoIAwareExtractor from .single_roipoint_extractor import Single3DRoIPointExtractor __all__ = [ 'SingleRoIExtractor', 'Single3DRoIAwareExtractor', - 'Single3DRoIPointExtractor' + 'Single3DRoIPointExtractor', 'Batch3DRoIGridExtractor' ] diff --git a/mmdet3d/models/roi_heads/roi_extractors/batch_roigridpoint_extractor.py b/mmdet3d/models/roi_heads/roi_extractors/batch_roigridpoint_extractor.py new file mode 100644 index 000000000..6d4825f31 --- /dev/null +++ b/mmdet3d/models/roi_heads/roi_extractors/batch_roigridpoint_extractor.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmdet3d.registry import MODELS +from mmdet3d.structures.bbox_3d import rotation_3d_in_axis + + +@MODELS.register_module() +class Batch3DRoIGridExtractor(BaseModule): + """Grid point wise roi-aware Extractor. + + Args: + grid_size (int): The number of grid points in a roi bbox. + Defaults to 6. + roi_layer (dict, optional): Config of sa module to get + grid points features. Defaults to None. + init_cfg (dict, optional): Initialize config of + model. Defaults to None. + """ + + def __init__(self, + grid_size: int = 6, + roi_layer: dict = None, + init_cfg: dict = None) -> None: + super(Batch3DRoIGridExtractor, self).__init__(init_cfg=init_cfg) + self.roi_grid_pool_layer = MODELS.build(roi_layer) + self.grid_size = grid_size + + def forward(self, feats: torch.Tensor, coordinate: torch.Tensor, + batch_inds: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: + """Forward roi extractor to extract grid points feature. + + Args: + feats (torch.Tensor): Key points features. + coordinate (torch.Tensor): Key points coordinates. + batch_inds (torch.Tensor): Input batch indexes. + rois (torch.Tensor): Detection results of rpn head. + + Returns: + torch.Tensor: Grid points features. + """ + batch_size = int(batch_inds.max()) + 1 + + xyz = coordinate + xyz_batch_cnt = xyz.new_zeros(batch_size).int() + for k in range(batch_size): + xyz_batch_cnt[k] = (batch_inds == k).sum() + + rois_batch_inds = rois[:, 0].int() + # (N1+N2+..., 6x6x6, 3) + roi_grid = self.get_dense_grid_points(rois[:, 1:]) + + new_xyz = roi_grid.view(-1, 3) + new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int() + for k in range(batch_size): + new_xyz_batch_cnt[k] = ((rois_batch_inds == k).sum() * + roi_grid.size(1)) + pooled_points, pooled_features = self.roi_grid_pool_layer( + xyz=xyz.contiguous(), + xyz_batch_cnt=xyz_batch_cnt, + new_xyz=new_xyz.contiguous(), + new_xyz_batch_cnt=new_xyz_batch_cnt, + features=feats.contiguous()) # (M1 + M2 ..., C) + + pooled_features = pooled_features.view(-1, self.grid_size, + self.grid_size, self.grid_size, + pooled_features.shape[-1]) + # (BxN, 6, 6, 6, C) + return pooled_features + + def get_dense_grid_points(self, rois: torch.Tensor) -> torch.Tensor: + """Get dense grid points from rois. + + Args: + rois (torch.Tensor): Detection results of rpn head. + + Returns: + torch.Tensor: Grid points coordinates. 
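+
+        Note:
+            The RoI center is first shifted from the bottom of the box to
+            its geometric center; a ``grid_size**3`` lattice normalized to
+            [-0.5, 0.5] per axis is then scaled by the box dimensions,
+            rotated by the box yaw and translated to the box center.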
+ """ + rois_bbox = rois.clone() + rois_bbox[:, 2] += rois_bbox[:, 5] / 2 + faked_features = rois_bbox.new_ones( + (self.grid_size, self.grid_size, self.grid_size)) + dense_idx = faked_features.nonzero() + dense_idx = dense_idx.repeat(rois_bbox.size(0), 1, 1).float() + dense_idx = ((dense_idx + 0.5) / self.grid_size) + dense_idx[..., :3] -= 0.5 + + roi_ctr = rois_bbox[:, :3] + roi_dim = rois_bbox[:, 3:6] + roi_grid_points = dense_idx * roi_dim.view(-1, 1, 3) + roi_grid_points = rotation_3d_in_axis( + roi_grid_points, rois_bbox[:, 6], axis=2) + roi_grid_points += roi_ctr.view(-1, 1, 3) + + return roi_grid_points diff --git a/tests/test_models/test_detectors/test_pvrcnn.py b/tests/test_models/test_detectors/test_pvrcnn.py new file mode 100644 index 000000000..66efdeb07 --- /dev/null +++ b/tests/test_models/test_detectors/test_pvrcnn.py @@ -0,0 +1,64 @@ +import unittest + +import torch +from mmengine import DefaultScope + +from mmdet3d.registry import MODELS +from tests.utils.model_utils import (_create_detector_inputs, + _get_detector_cfg, _setup_seed) + + +class TestPVRCNN(unittest.TestCase): + + def test_pvrcnn(self): + import mmdet3d.models + + assert hasattr(mmdet3d.models, 'PointVoxelRCNN') + DefaultScope.get_instance('test_pvrcnn', scope_name='mmdet3d') + _setup_seed(0) + pvrcnn_cfg = _get_detector_cfg( + 'pvrcnn/pvrcnn_8xb2-80e_kitti-3d-3class.py') + model = MODELS.build(pvrcnn_cfg) + num_gt_instance = 2 + packed_inputs = _create_detector_inputs( + num_gt_instance=num_gt_instance) + + # TODO: Support aug data test + # aug_packed_inputs = [ + # _create_detector_inputs(num_gt_instance=num_gt_instance), + # _create_detector_inputs(num_gt_instance=num_gt_instance + 1) + # ] + # test_aug_test + # metainfo = { + # 'pcd_scale_factor': 1, + # 'pcd_horizontal_flip': 1, + # 'pcd_vertical_flip': 1, + # 'box_type_3d': LiDARInstance3DBoxes + # } + # for item in aug_packed_inputs: + # for batch_id in len(item['data_samples']): + # item['data_samples'][batch_id].set_metainfo(metainfo) + + if torch.cuda.is_available(): + model = model.cuda() + # test simple_test + with torch.no_grad(): + data = model.data_preprocessor(packed_inputs, True) + torch.cuda.empty_cache() + results = model.forward(**data, mode='predict') + self.assertEqual(len(results), 1) + self.assertIn('bboxes_3d', results[0].pred_instances_3d) + self.assertIn('scores_3d', results[0].pred_instances_3d) + self.assertIn('labels_3d', results[0].pred_instances_3d) + + # save the memory + with torch.no_grad(): + losses = model.forward(**data, mode='loss') + torch.cuda.empty_cache() + self.assertGreater(losses['loss_rpn_cls'][0], 0) + self.assertGreaterEqual(losses['loss_rpn_bbox'][0], 0) + self.assertGreaterEqual(losses['loss_rpn_dir'][0], 0) + self.assertGreater(losses['loss_semantic'], 0) + self.assertGreaterEqual(losses['loss_bbox'], 0) + self.assertGreaterEqual(losses['loss_cls'], 0) + self.assertGreaterEqual(losses['loss_corner'], 0)