Skip to content

Commit

Permalink
Poly discrete rotate (#281)
Browse files Browse the repository at this point in the history
* [Docs] Update the introduction of SASM (AAAI'22) (#184)

* [Docs] Update the introduction of SASM (AAAI'22)

* Update citation

* [Fix] 'RoIAlignRotated' object has no attribute 'output_size' (#213)

* Update rotate_single_level_roi_extractor.py

* Update rotate_single_level_roi_extractor.py

* [Fix]  Modify the use of rotated anchor inside flags (#197)

* rotated RPN bug

* Modify the use of rotated anchor inside flags

* Added PolyDiscreteRotate transform and fixed typos

* Added unit test for PolyDiscreteRotate

* Fixed linting of pipelines __init__

* Fused the 2 transforms

* [Enhance] Support mask in merge_results and huge_image_demo.py. (#280)

* Support masks merging

* Update error report

* Added mode argument

* Fixed redundancy and added value unit test

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: yangxue <yangxue0827@126.com>
Co-authored-by: Yue Zhou <592267829@qq.com>
Co-authored-by: GamblerZSY <86358394+GamblerZSY@users.noreply.github.com>
Co-authored-by: jbwang1997 <jbwang1997@gmail.com>
  • Loading branch information
6 people authored May 18, 2022
1 parent e077a9a commit 8b6a806
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 37 deletions.
6 changes: 5 additions & 1 deletion mmrotate/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,9 @@ def inference_detector_by_patches(model,
start += bs

results = merge_results(
results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
results,
windows[:, :2],
img_shape=(width, height),
iou_thr=merge_iou_thr,
device=device)
return results
128 changes: 109 additions & 19 deletions mmrotate/core/patch/merge_results.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops import nms_rotated
from mmcv.ops import nms, nms_rotated


def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
def translate_bboxes(bboxes, offset):
    """Translate bboxes by a 2-D offset according to their shape.

    If the bbox shape is (n, 5), the bboxes are regarded as horizontal bboxes
    in (x, y, x, y, score) format and both corner points are shifted. If the
    bbox shape is (n, 6), the bboxes are regarded as rotated bboxes in
    (x, y, w, h, theta, score) format and only the center point is shifted.

    The input array is not modified; a translated copy is returned, so the
    same patch results can safely be merged more than once.

    Args:
        bboxes (np.ndarray): The bboxes need to be translated. Its shape can
            only be (n, 5) and (n, 6).
        offset (np.ndarray): The offset to translate with shape being (2, ).

    Returns:
        np.ndarray: Translated bboxes with the same shape and dtype as
            ``bboxes``.

    Raises:
        TypeError: If ``bboxes`` is not of shape (n, 5) or (n, 6).
    """
    if bboxes.ndim != 2 or bboxes.shape[1] not in (5, 6):
        raise TypeError('Require the shape of `bboxes` to be (n, 5) or (n, 6),'
                        f' but get `bboxes` with shape being {bboxes.shape}.')

    # Work on a copy so the caller's per-patch detections are not shifted in
    # place (which would accumulate offsets on repeated calls).
    translated = bboxes.copy()
    if translated.shape[1] == 5:
        # (x1, y1, x2, y2): apply the same (dx, dy) to both corners.
        translated[:, :4] = translated[:, :4] + np.tile(offset, 2)
    else:
        # (cx, cy, w, h, theta): only the center moves.
        translated[:, :2] = translated[:, :2] + offset
    return translated


def map_masks(masks, offset, new_shape):
    """Map patch-level masks onto the canvas of the huge image.

    Each mask is pasted into an all-False boolean canvas of the huge image's
    size, with its top-left corner placed at ``offset``. Parts of a mask that
    would fall beyond the right or bottom border of the canvas are cropped.

    Args:
        masks (list[np.ndarray]): masks need to be mapped.
        offset (np.ndarray): The offset to translate with shape being (2, ),
            given as (x, y).
        new_shape (tuple): A tuple of the huge image's width and height.

    Returns:
        list[np.ndarray]: Mapped masks, each of shape (height, width) and
            boolean dtype. An empty input is returned unchanged.
    """
    if not masks:
        return masks

    new_width, new_height = new_shape
    x_start, y_start = offset
    mapped = []
    for mask in masks:
        ori_height, ori_width = mask.shape[:2]

        # Clip the pasted region to the right border of the canvas.
        x_end = x_start + ori_width
        if x_end > new_width:
            ori_width -= x_end - new_width
            x_end = new_width

        # Clip the pasted region to the bottom border of the canvas.
        y_end = y_start + ori_height
        if y_end > new_height:
            ori_height -= y_end - new_height
            y_end = new_height

        # The builtin `bool` is used as dtype: the `np.bool` alias was
        # deprecated in NumPy 1.20 and removed in NumPy 1.24.
        extended_mask = np.zeros((new_height, new_width), dtype=bool)
        extended_mask[y_start:y_end,
                      x_start:x_end] = mask[:ori_height, :ori_width]
        mapped.append(extended_mask)
    return mapped


def merge_results(results, offsets, img_shape, iou_thr=0.1, device='cpu'):
"""Merge patch results via nms.
Args:
results (list[np.ndarray]): A list of patches results.
results (list[np.ndarray] | list[tuple]): A list of patches results.
offsets (np.ndarray): Positions of the left top points of patches.
img_shape (tuple): A tuple of the huge image's width and height.
iou_thr (float): The IoU threshold of NMS.
device (str): The device to call nms.
Expand All @@ -18,20 +81,47 @@ def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
"""
assert len(results) == offsets.shape[0], 'The `results` should has the ' \
'same length with `offsets`.'
merged_results = []
for results_pre_cls in zip(*results):
tran_dets = []
for dets, offset in zip(results_pre_cls, offsets):
dets[:, :2] += offset
tran_dets.append(dets)
tran_dets = np.concatenate(tran_dets, axis=0)

if tran_dets.size == 0:
merged_results.append(tran_dets)
with_mask = isinstance(results[0], tuple)
num_patches = len(results)
num_classes = len(results[0][0]) if with_mask else len(results[0])

merged_bboxes = []
merged_masks = []
for cls in range(num_classes):
if with_mask:
dets_per_cls = [results[i][0][cls] for i in range(num_patches)]
masks_per_cls = [results[i][1][cls] for i in range(num_patches)]
else:
tran_dets = torch.from_numpy(tran_dets)
tran_dets = tran_dets.to(device)
nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1],
iou_thr)
merged_results.append(nms_dets.cpu().numpy())
return merged_results
dets_per_cls = [results[i][cls] for i in range(num_patches)]
masks_per_cls = None

dets_per_cls = [
translate_bboxes(dets_per_cls[i], offsets[i])
for i in range(num_patches)
]
dets_per_cls = np.concatenate(dets_per_cls, axis=0)
if with_mask:
masks_placeholder = []
for i, masks in enumerate(masks_per_cls):
translated = map_masks(masks, offsets[i], img_shape)
masks_placeholder.extend(translated)
masks_per_cls = masks_placeholder

if dets_per_cls.size == 0:
merged_bboxes.append(dets_per_cls)
if with_mask:
merged_masks.append(masks_per_cls)
else:
dets_per_cls = torch.from_numpy(dets_per_cls).to(device)
nms_func = nms if dets_per_cls.size(1) == 5 else nms_rotated
nms_dets, keeps = nms_func(dets_per_cls[:, :-1],
dets_per_cls[:, -1], iou_thr)
merged_bboxes.append(nms_dets.cpu().numpy())
if with_mask:
keeps = keeps.cpu().numpy()
merged_masks.append([masks_per_cls[i] for i in keeps])

if with_mask:
return merged_bboxes, merged_masks
else:
return merged_bboxes
34 changes: 28 additions & 6 deletions mmrotate/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import mmcv
import numpy as np
from mmdet.datasets.pipelines.transforms import RandomFlip, Resize

Expand Down Expand Up @@ -97,25 +98,42 @@ class PolyRandomRotate(object):
Reference: https://github.com/hukaixuan19970627/OrientedRepPoints_DOTA
Args:
rate (bool): (float, optional): The rotating probability.
rotate_ratio (float, optional): The rotating probability.
Default: 0.5.
angles_range(int, optional): The rotate angle defined by random
(-angles_range, +angles_range).
mode (str, optional) : Indicates whether the angle is chosen in a
random range (mode='range') or in a preset list of angles
(mode='value'). Defaults to 'range'.
angles_range (int | list[int], optional): The range of angles.
If mode='range', angles_range is an int and the angle is chosen
in (-angles_range, +angles_range).
If mode='value', angles_range is a non-empty list of int and the
angle is chosen from angles_range.
Defaults to 180 as default mode is 'range'.
auto_bound(bool, optional): whether to find the new width and height
bounds.
rect_classes (None|list, optional): Specifies classes that needs to
be rotated by a multiple of 90 degrees.
version (str, optional): Angle representations. Defaults to 'oc'.
version (str, optional): Angle representations. Defaults to 'le90'.
"""

def __init__(self,
rotate_ratio=0.5,
mode='range',
angles_range=180,
auto_bound=False,
rect_classes=None,
version='le90'):
self.rotate_ratio = rotate_ratio
self.auto_bound = auto_bound
assert mode in ['range', 'value'], \
f"mode is supposed to be 'range' or 'value', but got {mode}."
if mode == 'range':
assert isinstance(angles_range, int), \
"mode 'range' expects angle_range to be an int."
else:
assert mmcv.is_seq_of(angles_range, int) and len(angles_range), \
"mode 'value' expects angle_range as a non-empty list of int."
self.mode = mode
self.angles_range = angles_range
self.discrete_range = [90, 180, -90, -180]
self.rect_classes = rect_classes
Expand Down Expand Up @@ -177,9 +195,12 @@ def __call__(self, results):
results['rotate'] = False
angle = 0
else:
angle = 2 * self.angles_range * np.random.rand() - \
self.angles_range
results['rotate'] = True
if self.mode == 'range':
angle = self.angles_range * (2 * np.random.rand() - 1)
else:
i = np.random.randint(len(self.angles_range))
angle = self.angles_range[i]

class_labels = results['gt_labels']
for classid in class_labels:
Expand Down Expand Up @@ -237,6 +258,7 @@ def __call__(self, results):
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(rotate_ratio={self.rotate_ratio}, ' \
f'base_angles={self.base_angles}, ' \
f'angles_range={self.angles_range}, ' \
f'auto_bound={self.auto_bound})'
return repr_str
10 changes: 5 additions & 5 deletions mmrotate/models/dense_heads/oriented_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch
import torch.nn as nn
from mmcv.ops import batched_nms
from mmdet.core import unmap
from mmdet.core import anchor_inside_flags, unmap

from mmrotate.core import obb2xyxy, rotated_anchor_inside_flags
from mmrotate.core import obb2xyxy
from ..builder import ROTATED_HEADS
from .rotated_rpn_head import RotatedRPNHead

Expand Down Expand Up @@ -64,9 +64,9 @@ def _get_targets_single(self,
- num_total_pos (int): Number of positive samples in all images
- num_total_neg (int): Number of negative samples in all images
"""
inside_flags = rotated_anchor_inside_flags(
flat_anchors, valid_flags, img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 7
# assign gt and sample anchors
Expand Down
11 changes: 6 additions & 5 deletions mmrotate/models/dense_heads/rotated_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import torch.nn.functional as F
from mmcv.ops import batched_nms
from mmcv.runner import force_fp32
from mmdet.core import images_to_levels, multi_apply, unmap
from mmdet.core import (anchor_inside_flags, images_to_levels, multi_apply,
unmap)
from mmdet.models.dense_heads.anchor_head import AnchorHead

from mmrotate.core import obb2xyxy, rotated_anchor_inside_flags
from mmrotate.core import obb2xyxy
from ..builder import ROTATED_HEADS


Expand Down Expand Up @@ -86,9 +87,9 @@ def _get_targets_single(self,
num_total_pos (int): Number of positive samples in all images
num_total_neg (int): Number of negative samples in all images
"""
inside_flags = rotated_anchor_inside_flags(
flat_anchors, valid_flags, img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 7
# assign gt and sample anchors
Expand Down
14 changes: 13 additions & 1 deletion tests/test_data/test_pipelines/test_rtransforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,24 @@ def test_rotate():
"""Test rotation for rbboxes."""
results = construct_toy_data()

# test PolyRandomRotate
# test PolyRandomRotate with 'range' mode
transform = dict(
type='PolyRandomRotate',
mode='range',
rotate_ratio=1.0,
angles_range=180,
auto_bound=False,
version='oc')
rotate_module = build_from_cfg(transform, PIPELINES)
rotate_module(copy.deepcopy(results))

# test PolyRandomRotate with 'value' mode
transform = dict(
type='PolyRandomRotate',
mode='value',
rotate_ratio=1.0,
angles_range=[30],
auto_bound=False,
version='oc')
rotate_module = build_from_cfg(transform, PIPELINES)
rotate_module(copy.deepcopy(results))

0 comments on commit 8b6a806

Please # to comment.