From 8b6a806f59cd83657aa1635984c02ba40dc6a48e Mon Sep 17 00:00:00 2001 From: remi-or <83456801+remi-or@users.noreply.github.com> Date: Wed, 18 May 2022 07:11:24 -0400 Subject: [PATCH] Poly discrete rotate (#281) * [Docs] Update the introduction of SASM (AAAI'22) (#184) * [Docs] Update the introduction of SASM (AAAI'22) * Update citation * [Fix] 'RoIAlignRotated' object has no attribute 'output_size' (#213) * Update rotate_single_level_roi_extractor.py * Update rotate_single_level_roi_extractor.py * [Fix] Modify the use of rotated anchor inside flags (#197) * rotated RPN bug * Modify the use of rotated anchor inside flags * Added PolyDiscreteRotate transform and fixed typos * Added unit test for PolyDiscreteRotate * Fixed linting of pipelines __init__ * Fused the 2 transforms * [Enhance] Support mask in merge_results and huge_image_demo.py. (#280) * Support masks mergeing * Update error report * Added mode argument * Fixed redundancy and added value unit test Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Co-authored-by: yangxue Co-authored-by: Yue Zhou <592267829@qq.com> Co-authored-by: GamblerZSY <86358394+GamblerZSY@users.noreply.github.com> Co-authored-by: jbwang1997 --- mmrotate/apis/inference.py | 6 +- mmrotate/core/patch/merge_results.py | 128 +++++++++++++++--- mmrotate/datasets/pipelines/transforms.py | 34 ++++- .../models/dense_heads/oriented_rpn_head.py | 10 +- .../models/dense_heads/rotated_rpn_head.py | 11 +- .../test_pipelines/test_rtransforms.py | 14 +- 6 files changed, 166 insertions(+), 37 deletions(-) diff --git a/mmrotate/apis/inference.py b/mmrotate/apis/inference.py index 4da151803..fabaa8c0c 100644 --- a/mmrotate/apis/inference.py +++ b/mmrotate/apis/inference.py @@ -85,5 +85,9 @@ def inference_detector_by_patches(model, start += bs results = merge_results( - results, windows[:, :2], iou_thr=merge_iou_thr, device=device) + results, + windows[:, :2], + img_shape=(width, height), + iou_thr=merge_iou_thr, + 
device=device) return results diff --git a/mmrotate/core/patch/merge_results.py b/mmrotate/core/patch/merge_results.py index 5ef34c125..abf523d95 100644 --- a/mmrotate/core/patch/merge_results.py +++ b/mmrotate/core/patch/merge_results.py @@ -1,15 +1,78 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np import torch -from mmcv.ops import nms_rotated +from mmcv.ops import nms, nms_rotated -def merge_results(results, offsets, iou_thr=0.1, device='cpu'): +def translate_bboxes(bboxes, offset): + """Translate bboxes according to its shape. + + If the bbox shape is (n, 5), the bboxes are regarded as horizontal bboxes + and in (x, y, x, y, score) format. If the bbox shape is (n, 6), the bboxes + are regarded as rotated bboxes and in (x, y, w, h, theta, score) format. + + Args: + bboxes (np.ndarray): The bboxes need to be translated. Its shape can + only be (n, 5) and (n, 6). + offset (np.ndarray): The offset to translate with shape being (2, ). + + Returns: + np.ndarray: Translated bboxes. + """ + if bboxes.shape[1] == 5: + bboxes[:, :4] = bboxes[:, :4] + np.tile(offset, 2) + elif bboxes.shape[1] == 6: + bboxes[:, :2] = bboxes[:, :2] + offset + else: + raise TypeError('Require the shape of `bboxes` to be (n, 5) or (n, 6),' + f' but get `bboxes` with shape being {bboxes.shape}.') + return bboxes + + +def map_masks(masks, offset, new_shape): + """Map masks to the huge image. + + Args: + masks (list[np.ndarray]): masks need to be mapped. + offset (np.ndarray): The offset to translate with shape being (2, ). + new_shape (tuple): A tuple of the huge image's width and height. + + Returns: + list[np.ndarray]: Mapped masks. 
+ """ + if not masks: + return masks + + new_width, new_height = new_shape + x_start, y_start = offset + mapped = [] + for mask in masks: + ori_height, ori_width = mask.shape[:2] + + x_end = x_start + ori_width + if x_end > new_width: + ori_width -= x_end - new_width + x_end = new_width + + y_end = y_start + ori_height + if y_end > new_height: + ori_height -= y_end - new_height + y_end = new_height + + extended_mask = np.zeros((new_height, new_width), dtype=bool) + extended_mask[y_start:y_end, + x_start:x_end] = mask[:ori_height, :ori_width] + mapped.append(extended_mask) + return mapped + + +def merge_results(results, offsets, img_shape, iou_thr=0.1, device='cpu'): """Merge patch results via nms. Args: - results (list[np.ndarray]): A list of patches results. + results (list[np.ndarray] | list[tuple]): A list of patches results. offsets (np.ndarray): Positions of the left top points of patches. + img_shape (tuple): A tuple of the huge image's width and height. iou_thr (float): The IoU threshold of NMS. device (str): The device to call nms. @@ -18,20 +81,47 @@ def merge_results(results, offsets, iou_thr=0.1, device='cpu'): """ assert len(results) == offsets.shape[0], 'The `results` should has the ' \ 'same length with `offsets`.'
- merged_results = [] - for results_pre_cls in zip(*results): - tran_dets = [] - for dets, offset in zip(results_pre_cls, offsets): - dets[:, :2] += offset - tran_dets.append(dets) - tran_dets = np.concatenate(tran_dets, axis=0) - - if tran_dets.size == 0: - merged_results.append(tran_dets) + with_mask = isinstance(results[0], tuple) + num_patches = len(results) + num_classes = len(results[0][0]) if with_mask else len(results[0]) + + merged_bboxes = [] + merged_masks = [] + for cls in range(num_classes): + if with_mask: + dets_per_cls = [results[i][0][cls] for i in range(num_patches)] + masks_per_cls = [results[i][1][cls] for i in range(num_patches)] else: - tran_dets = torch.from_numpy(tran_dets) - tran_dets = tran_dets.to(device) - nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1], - iou_thr) - merged_results.append(nms_dets.cpu().numpy()) - return merged_results + dets_per_cls = [results[i][cls] for i in range(num_patches)] + masks_per_cls = None + + dets_per_cls = [ + translate_bboxes(dets_per_cls[i], offsets[i]) + for i in range(num_patches) + ] + dets_per_cls = np.concatenate(dets_per_cls, axis=0) + if with_mask: + masks_placeholder = [] + for i, masks in enumerate(masks_per_cls): + translated = map_masks(masks, offsets[i], img_shape) + masks_placeholder.extend(translated) + masks_per_cls = masks_placeholder + + if dets_per_cls.size == 0: + merged_bboxes.append(dets_per_cls) + if with_mask: + merged_masks.append(masks_per_cls) + else: + dets_per_cls = torch.from_numpy(dets_per_cls).to(device) + nms_func = nms if dets_per_cls.size(1) == 5 else nms_rotated + nms_dets, keeps = nms_func(dets_per_cls[:, :-1], + dets_per_cls[:, -1], iou_thr) + merged_bboxes.append(nms_dets.cpu().numpy()) + if with_mask: + keeps = keeps.cpu().numpy() + merged_masks.append([masks_per_cls[i] for i in keeps]) + + if with_mask: + return merged_bboxes, merged_masks + else: + return merged_bboxes diff --git a/mmrotate/datasets/pipelines/transforms.py 
b/mmrotate/datasets/pipelines/transforms.py index 3a4510844..93a78e4dd 100644 --- a/mmrotate/datasets/pipelines/transforms.py +++ b/mmrotate/datasets/pipelines/transforms.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import cv2 +import mmcv import numpy as np from mmdet.datasets.pipelines.transforms import RandomFlip, Resize @@ -97,25 +98,42 @@ class PolyRandomRotate(object): Reference: https://github.com/hukaixuan19970627/OrientedRepPoints_DOTA Args: - rate (bool): (float, optional): The rotating probability. + rotate_ratio (float, optional): The rotating probability. Default: 0.5. - angles_range(int, optional): The rotate angle defined by random - (-angles_range, +angles_range). + mode (str, optional): Indicates whether the angle is chosen in a + random range (mode='range') or in a preset list of angles + (mode='value'). Defaults to 'range'. + angles_range (int|list[int], optional): The range of angles. + If mode='range', angles_range is an int and the angle is chosen + in (-angles_range, +angles_range). + If mode='value', angles_range is a non-empty list of int and the + angle is chosen in angles_range. + Defaults to 180 as default mode is 'range'. auto_bound(bool, optional): whether to find the new width and height bounds. rect_classes (None|list, optional): Specifies classes that needs to be rotated by a multiple of 90 degrees. - version (str, optional): Angle representations. Defaults to 'oc'. + version (str, optional): Angle representations. Defaults to 'le90'. """ def __init__(self, rotate_ratio=0.5, + mode='range', angles_range=180, auto_bound=False, rect_classes=None, version='le90'): self.rotate_ratio = rotate_ratio self.auto_bound = auto_bound + assert mode in ['range', 'value'], \ + f"mode is supposed to be 'range' or 'value', but got {mode}." + if mode == 'range': + assert isinstance(angles_range, int), \ + "mode 'range' expects angle_range to be an int."
+ else: + assert mmcv.is_seq_of(angles_range, int) and len(angles_range), \ + "mode 'value' expects angle_range as a non-empty list of int." + self.mode = mode self.angles_range = angles_range self.discrete_range = [90, 180, -90, -180] self.rect_classes = rect_classes @@ -177,9 +195,12 @@ def __call__(self, results): results['rotate'] = False angle = 0 else: - angle = 2 * self.angles_range * np.random.rand() - \ - self.angles_range results['rotate'] = True + if self.mode == 'range': + angle = self.angles_range * (2 * np.random.rand() - 1) + else: + i = np.random.randint(len(self.angles_range)) + angle = self.angles_range[i] class_labels = results['gt_labels'] for classid in class_labels: @@ -237,6 +258,7 @@ def __call__(self, results): def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(rotate_ratio={self.rotate_ratio}, ' \ + f'mode={self.mode}, ' \ f'angles_range={self.angles_range}, ' \ f'auto_bound={self.auto_bound})' return repr_str diff --git a/mmrotate/models/dense_heads/oriented_rpn_head.py b/mmrotate/models/dense_heads/oriented_rpn_head.py index 953d333af..a3367b6ca 100644 --- a/mmrotate/models/dense_heads/oriented_rpn_head.py +++ b/mmrotate/models/dense_heads/oriented_rpn_head.py @@ -4,9 +4,9 @@ import torch import torch.nn as nn from mmcv.ops import batched_nms -from mmdet.core import unmap +from mmdet.core import anchor_inside_flags, unmap -from mmrotate.core import obb2xyxy, rotated_anchor_inside_flags +from mmrotate.core import obb2xyxy from ..builder import ROTATED_HEADS from .rotated_rpn_head import RotatedRPNHead @@ -64,9 +64,9 @@ def _get_targets_single(self, - num_total_pos (int): Number of positive samples in all images - num_total_neg (int): Number of negative samples in all images """ - inside_flags = rotated_anchor_inside_flags( - flat_anchors, valid_flags, img_meta['img_shape'][:2], - self.train_cfg.allowed_border) + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], +
self.train_cfg.allowed_border) if not inside_flags.any(): return (None, ) * 7 # assign gt and sample anchors diff --git a/mmrotate/models/dense_heads/rotated_rpn_head.py b/mmrotate/models/dense_heads/rotated_rpn_head.py index 632acbb1e..a7c903f90 100644 --- a/mmrotate/models/dense_heads/rotated_rpn_head.py +++ b/mmrotate/models/dense_heads/rotated_rpn_head.py @@ -6,10 +6,11 @@ import torch.nn.functional as F from mmcv.ops import batched_nms from mmcv.runner import force_fp32 -from mmdet.core import images_to_levels, multi_apply, unmap +from mmdet.core import (anchor_inside_flags, images_to_levels, multi_apply, + unmap) from mmdet.models.dense_heads.anchor_head import AnchorHead -from mmrotate.core import obb2xyxy, rotated_anchor_inside_flags +from mmrotate.core import obb2xyxy from ..builder import ROTATED_HEADS @@ -86,9 +87,9 @@ def _get_targets_single(self, num_total_pos (int): Number of positive samples in all images num_total_neg (int): Number of negative samples in all images """ - inside_flags = rotated_anchor_inside_flags( - flat_anchors, valid_flags, img_meta['img_shape'][:2], - self.train_cfg.allowed_border) + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg.allowed_border) if not inside_flags.any(): return (None, ) * 7 # assign gt and sample anchors diff --git a/tests/test_data/test_pipelines/test_rtransforms.py b/tests/test_data/test_pipelines/test_rtransforms.py index 145618f35..24e898aa8 100644 --- a/tests/test_data/test_pipelines/test_rtransforms.py +++ b/tests/test_data/test_pipelines/test_rtransforms.py @@ -102,12 +102,24 @@ def test_rotate(): """Test rotation for rbboxes.""" results = construct_toy_data() - # test PolyRandomRotate + # test PolyRandomRotate with 'range' mode transform = dict( type='PolyRandomRotate', + mode='range', rotate_ratio=1.0, angles_range=180, auto_bound=False, version='oc') rotate_module = build_from_cfg(transform, PIPELINES) rotate_module(copy.deepcopy(results)) + 
+ # test PolyRandomRotate with 'value' mode + transform = dict( + type='PolyRandomRotate', + mode='value', + rotate_ratio=1.0, + angles_range=[30], + auto_bound=False, + version='oc') + rotate_module = build_from_cfg(transform, PIPELINES) + rotate_module(copy.deepcopy(results))