diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 26a2cbbbb..f4e2303a1 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -354,10 +354,25 @@ def get_targets_single(self, # Semantic information of primitive center point_sem = points.new_zeros([num_points, 3 + self.num_dims + 1]) + # Generate pts_semantic_mask and pts_instance_mask when they are None + if pts_semantic_mask is None or pts_instance_mask is None: + points2box_mask = gt_bboxes_3d.points_in_boxes(points) + assignment = points2box_mask.argmax(1) + background_mask = points2box_mask.max(1)[0] == 0 + + if pts_semantic_mask is None: + pts_semantic_mask = gt_labels_3d[assignment] + pts_semantic_mask[background_mask] = self.num_classes + + if pts_instance_mask is None: + pts_instance_mask = assignment + pts_instance_mask[background_mask] = gt_labels_3d.shape[0] + instance_flag = torch.nonzero( pts_semantic_mask != self.num_classes).squeeze(1) instance_labels = pts_instance_mask[instance_flag].unique() + with_yaw = gt_bboxes_3d.with_yaw for i, i_instance in enumerate(instance_labels): indices = instance_flag[pts_instance_mask[instance_flag] == i_instance] @@ -366,8 +381,6 @@ def get_targets_single(self, # Bbox Corners cur_corners = gt_bboxes_3d.corners[i] - xmin, ymin, zmin = cur_corners.min(0)[0] - xmax, ymax, zmax = cur_corners.max(0)[0] plane_lower_temp = points.new_tensor( [0, 0, 1, -cur_corners[7, -1]]) @@ -392,10 +405,10 @@ def get_targets_single(self, point2plane_dist, selected = self.match_point2plane( plane_lower, coords) - # Get lower four lines + # Get bottom four lines if self.primitive_mode == 'line': point2line_matching = self.match_point2line( - coords[selected], xmin, xmax, ymin, ymax) + coords[selected], cur_corners, with_yaw, mode='bottom') point_mask, point_offset, point_sem = \ self._assign_primitive_line_targets(point_mask, @@ -406,7 
+419,9 @@ def get_targets_single(self, cur_cls_label, point2line_matching, cur_corners, - [1, 1, 0, 0]) + [1, 1, 0, 0], + with_yaw, + mode='bottom') # Set the surface labels here if self.primitive_mode == 'z' and \ @@ -421,16 +436,18 @@ def get_targets_single(self, coords[selected], indices[selected], cur_cls_label, - cur_corners) + cur_corners, + with_yaw, + mode='bottom') # Get the boundary points here point2plane_dist, selected = self.match_point2plane( plane_upper, coords) - # Get upper four lines + # Get top four lines if self.primitive_mode == 'line': point2line_matching = self.match_point2line( - coords[selected], xmin, xmax, ymin, ymax) + coords[selected], cur_corners, with_yaw, mode='top') point_mask, point_offset, point_sem = \ self._assign_primitive_line_targets(point_mask, @@ -441,7 +458,9 @@ def get_targets_single(self, cur_cls_label, point2line_matching, cur_corners, - [1, 1, 0, 0]) + [1, 1, 0, 0], + with_yaw, + mode='top') if self.primitive_mode == 'z' and \ selected.sum() > self.train_cfg['num_point'] and \ @@ -455,7 +474,9 @@ def get_targets_single(self, coords[selected], indices[selected], cur_cls_label, - cur_corners) + cur_corners, + with_yaw, + mode='top') # Get left two lines plane_left_temp = self._get_plane_fomulation( @@ -480,20 +501,16 @@ def get_targets_single(self, point2plane_dist, selected = self.match_point2plane( plane_left, coords) - # Get upper four lines + # Get left four lines if self.primitive_mode == 'line': - _, _, line_sel1, line_sel2 = self.match_point2line( - coords[selected], xmin, xmax, ymin, ymax) + point2line_matching = self.match_point2line( + coords[selected], cur_corners, with_yaw, mode='left') point_mask, point_offset, point_sem = \ - self._assign_primitive_line_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - [line_sel1, line_sel2], - cur_corners, - [2, 2]) + self._assign_primitive_line_targets( + point_mask, point_offset, point_sem, + coords[selected], 
indices[selected], cur_cls_label, + point2line_matching[2:], cur_corners, [2, 2], + with_yaw, mode='left') if self.primitive_mode == 'xy' and \ selected.sum() > self.train_cfg['num_point'] and \ @@ -501,32 +518,26 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_surface_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets( + point_mask, point_offset, point_sem, + coords[selected], indices[selected], cur_cls_label, + cur_corners, with_yaw, mode='left') # Get the boundary points here point2plane_dist, selected = self.match_point2plane( plane_right, coords) + # Get right four lines if self.primitive_mode == 'line': - _, _, line_sel1, line_sel2 = self.match_point2line( - coords[selected], xmin, xmax, ymin, ymax) + point2line_matching = self.match_point2line( + coords[selected], cur_corners, with_yaw, mode='right') point_mask, point_offset, point_sem = \ - self._assign_primitive_line_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - [line_sel1, line_sel2], - cur_corners, - [2, 2]) + self._assign_primitive_line_targets( + point_mask, point_offset, point_sem, + coords[selected], indices[selected], cur_cls_label, + point2line_matching[2:], cur_corners, [2, 2], + with_yaw, mode='right') if self.primitive_mode == 'xy' and \ selected.sum() > self.train_cfg['num_point'] and \ @@ -534,13 +545,10 @@ def get_targets_single(self, self.train_cfg['var_thresh']: point_mask, point_offset, point_sem = \ - self._assign_primitive_surface_targets(point_mask, - point_offset, - point_sem, - coords[selected], - indices[selected], - cur_cls_label, - cur_corners) + self._assign_primitive_surface_targets( + point_mask, point_offset, point_sem, + coords[selected], indices[selected], cur_cls_label, + cur_corners, with_yaw, mode='right') plane_front_temp 
def point2line_dist(self, points, pts_a, pts_b):
    """Calculate the perpendicular distance from points to a line.

    The line is the infinite line passing through ``pts_a`` and ``pts_b``.

    Args:
        points (torch.Tensor): Points of input, one point per row with
            three coordinates each.
        pts_a (torch.Tensor): First point on the line (three values).
        pts_b (torch.Tensor): Second point on the line (three values).

    Returns:
        torch.Tensor: Distance between each point and the line, one value
            per row of ``points``.
    """
    line_a2b = pts_b - pts_a
    line_a2pts = points - pts_a
    # Length of the projection of every a->point vector onto a->b.
    length = (line_a2pts * line_a2b.view(1, 3)).sum(1) / line_a2b.norm()
    # Pythagoras: dist^2 = |a->point|^2 - projection^2. Floating point
    # rounding can push this slightly below zero for points that lie on
    # (or extremely close to) the line; sqrt would then yield NaN and the
    # caller's `< line_thresh` comparison would reject exactly the points
    # that match best. Clamp to zero before taking the root.
    squared_dist = line_a2pts.norm(dim=1)**2 - length**2
    return squared_dist.clamp(min=0).sqrt()


def match_point2line(self, points, corners, with_yaw, mode='bottom'):
    """Match points to the corresponding lines of a bounding box.

    Args:
        points (torch.Tensor): Points of input.
        corners (torch.Tensor): Eight corners of a bounding box.
        with_yaw (bool): Whether the bounding box is with rotation.
        mode (str, optional): Specify which lines should be matched,
            available modes are ('bottom', 'top', 'left', 'right').
            Only consulted when ``with_yaw`` is true.
            Defaults to 'bottom'.

    Returns:
        list[torch.Tensor]: Four boolean flags of matching
            correspondence, one per candidate line. For 'left' and
            'right' the two lines are listed twice, so the last two
            entries repeat the first two.
    """
    line_thresh = self.train_cfg['line_thresh']
    if with_yaw:
        # Corner index pairs spanning each candidate line.
        # NOTE(review): presumably follows the corner ordering of the
        # box structure that produced `corners` - confirm.
        corners_pair = {
            'bottom': [[0, 3], [4, 7], [0, 4], [3, 7]],
            'top': [[1, 2], [5, 6], [1, 5], [2, 6]],
            'left': [[0, 1], [3, 2], [0, 1], [3, 2]],
            'right': [[4, 5], [7, 6], [4, 5], [7, 6]]
        }
        selected_list = [
            self.point2line_dist(points, corners[start], corners[end]) <
            line_thresh for start, end in corners_pair[mode]
        ]
    else:
        # Axis-aligned box: a point matches a line when its x (or y)
        # coordinate is close to the box extent along that axis.
        xmin, ymin, _ = corners.min(0)[0]
        xmax, ymax, _ = corners.max(0)[0]
        selected_list = [
            torch.abs(points[:, 0] - xmin) < line_thresh,
            torch.abs(points[:, 0] - xmax) < line_thresh,
            torch.abs(points[:, 1] - ymin) < line_thresh,
            torch.abs(points[:, 1] - ymax) < line_thresh,
        ]
    return selected_list
@@ -757,10 +800,18 @@ def get_primitive_center(self, pred_flag, center): center = center + offset * selected.unsqueeze(-1) return center, pred_indices - def _assign_primitive_line_targets(self, point_mask, point_offset, - point_sem, coords, indices, cls_label, - point2line_matching, corners, - center_axises): + def _assign_primitive_line_targets(self, + point_mask, + point_offset, + point_sem, + coords, + indices, + cls_label, + point2line_matching, + corners, + center_axises, + with_yaw, + mode='bottom'): """Generate targets of line primitive. Args: @@ -778,16 +829,35 @@ def _assign_primitive_line_targets(self, point_mask, point_offset, corners (torch.Tensor): Corners of the ground truth bounding box. center_axises (list[int]): Indicate in which axis the line center should be refined. + with_yaw (Bool): Whether the boundind box is with rotation. + mode (str, optional): Specify which line should be matched, + available mode are ('bottom', 'top', 'left', 'right'). + Defaults to 'bottom'. Returns: Tuple: Targets of the line primitive. 
def _assign_primitive_surface_targets(self,
                                      point_mask,
                                      point_offset,
                                      point_sem,
                                      coords,
                                      indices,
                                      cls_label,
                                      corners,
                                      with_yaw,
                                      mode='bottom'):
    """Generate targets for primitive z and primitive xy.

    The three target tensors are updated in place at ``indices`` and
    returned.

    Args:
        point_mask (torch.Tensor): Per-point mask; selected entries are
            set to 1.
        point_offset (torch.Tensor): Per-point offset target; selected
            entries receive ``center - coords``.
        point_sem (torch.Tensor): Per-point semantic target; selected
            entries receive the surface center, its size term(s) and
            the class label.
        coords (torch.Tensor): Coordinates of the selected points.
        indices (torch.Tensor): Indices of the selected points.
        cls_label (int): Class label of the ground truth bounding box.
        corners (torch.Tensor): Corners of the ground truth bounding box.
        with_yaw (bool): Whether the bounding box is with rotation.
        mode (str, optional): Specify which surface should be matched,
            available modes are ('bottom', 'top', 'left', 'right',
            'front', 'back').
            Defaults to 'bottom'.

    Returns:
        Tuple: Targets of the center primitive, as
            ``(point_mask, point_offset, point_sem)``.
    """
    point_mask[indices] = 1.0

    # Two corner indices used to describe the requested surface.
    # NOTE(review): presumably follows the corner ordering of the box
    # structure that produced `corners` - confirm.
    first, second = {
        'bottom': (0, 7),
        'top': (1, 6),
        'left': (0, 1),
        'right': (4, 5),
        'front': (0, 1),
        'back': (3, 2),
    }[mode]

    primitive = self.primitive_mode
    if primitive == 'z':
        # Horizontal surface: x/y center from the box, z from the points;
        # the size terms are the two horizontal extents.
        if with_yaw:
            center = (corners[first] + corners[second]) / 2.0
            center[2] = coords[:, 2].mean()
            extent = [(corners[4] - corners[0]).norm(),
                      (corners[3] - corners[0]).norm()]
        else:
            box_x, box_y = corners[:, 0], corners[:, 1]
            center = point_mask.new_tensor(
                [box_x.mean(), box_y.mean(), coords[:, 2].mean()])
            extent = [box_x.max() - box_x.min(),
                      box_y.max() - box_y.min()]
    elif primitive == 'xy':
        # Vertical surface: x/y center from the points, z from the box;
        # the size term is the vertical extent.
        box_z = corners[:, 2]
        if with_yaw:
            center = coords.mean(0)
            center[2] = (box_z[first] + box_z[second]) / 2.0
            extent = [box_z[second] - box_z[first]]
        else:
            center = point_mask.new_tensor(
                [coords[:, 0].mean(), coords[:, 1].mean(), box_z.mean()])
            extent = [box_z.max() - box_z.min()]

    point_sem[indices] = point_sem.new_tensor(
        [center[0], center[1], center[2], *extent, cls_label])
    point_offset[indices] = center - coords
    return point_mask, point_offset, point_sem