-
Notifications
You must be signed in to change notification settings - Fork 57
/
bbox.py
92 lines (76 loc) · 4.17 KB
/
bbox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from typing import List
import torch
from torch import Tensor
class BBox(object):
    """Axis-aligned bounding box plus static helpers for tensors of boxes.

    An instance stores one box as four scalars. The static helpers work on
    tensors whose last dimension holds four coordinates in corner form,
    ``(left, top, right, bottom)``, unless stated to be in centre form,
    ``(center_x, center_y, width, height)``.
    """

    def __init__(self, left: float, top: float, right: float, bottom: float):
        super().__init__()
        self.left = left
        self.top = top
        self.right = right
        self.bottom = bottom

    def __repr__(self) -> str:
        return 'BBox[l={:.1f}, t={:.1f}, r={:.1f}, b={:.1f}]'.format(
            self.left, self.top, self.right, self.bottom)

    def tolist(self) -> List[float]:
        """Return the box as ``[left, top, right, bottom]``."""
        return [self.left, self.top, self.right, self.bottom]

    @staticmethod
    def to_center_base(bboxes: Tensor) -> Tensor:
        """Convert corner-form boxes ``(..., 4)`` to centre form ``(..., 4)``."""
        return torch.stack([
            (bboxes[..., 0] + bboxes[..., 2]) / 2,
            (bboxes[..., 1] + bboxes[..., 3]) / 2,
            bboxes[..., 2] - bboxes[..., 0],
            bboxes[..., 3] - bboxes[..., 1]
        ], dim=-1)

    @staticmethod
    def from_center_base(center_based_bboxes: Tensor) -> Tensor:
        """Convert centre-form boxes ``(..., 4)`` back to corner form ``(..., 4)``."""
        return torch.stack([
            center_based_bboxes[..., 0] - center_based_bboxes[..., 2] / 2,
            center_based_bboxes[..., 1] - center_based_bboxes[..., 3] / 2,
            center_based_bboxes[..., 0] + center_based_bboxes[..., 2] / 2,
            center_based_bboxes[..., 1] + center_based_bboxes[..., 3] / 2
        ], dim=-1)

    @staticmethod
    def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor:
        """Encode ``dst_bboxes`` relative to ``src_bboxes``.

        Both inputs are corner-form boxes ``(..., 4)`` whose shapes broadcast.

        Returns:
            Deltas ``(dx, dy, dw, dh)`` of shape ``(..., 4)``.
            Centre offsets are scaled by the source width/height.
            Sizes are expressed as log-ratios.
            :meth:`apply_transformer` inverts this encoding.

        Note:
            Source widths/heights must be non-zero (they are divisors).
            The dst/src size ratios must be positive for the log to be finite.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = BBox.to_center_base(dst_bboxes)
        transformers = torch.stack([
            (center_based_dst_bboxes[..., 0] - center_based_src_bboxes[..., 0]) / center_based_src_bboxes[..., 2],
            (center_based_dst_bboxes[..., 1] - center_based_src_bboxes[..., 1]) / center_based_src_bboxes[..., 3],
            torch.log(center_based_dst_bboxes[..., 2] / center_based_src_bboxes[..., 2]),
            torch.log(center_based_dst_bboxes[..., 3] / center_based_src_bboxes[..., 3])
        ], dim=-1)
        return transformers

    @staticmethod
    def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor:
        """Decode deltas produced by :meth:`calc_transformer` into corner-form boxes."""
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = torch.stack([
            transformers[..., 0] * center_based_src_bboxes[..., 2] + center_based_src_bboxes[..., 0],
            transformers[..., 1] * center_based_src_bboxes[..., 3] + center_based_src_bboxes[..., 1],
            torch.exp(transformers[..., 2]) * center_based_src_bboxes[..., 2],
            torch.exp(transformers[..., 3]) * center_based_src_bboxes[..., 3]
        ], dim=-1)
        dst_bboxes = BBox.from_center_base(center_based_dst_bboxes)
        return dst_bboxes

    @staticmethod
    def iou(source: Tensor, other: Tensor) -> Tensor:
        """Pairwise intersection-over-union between two sets of corner-form boxes.

        Args:
            source: boxes of shape ``(..., N, 4)``.
            other: boxes of shape ``(..., M, 4)``; leading dimensions must
                broadcast against those of ``source``.

        Returns:
            Tensor of shape ``(..., N, M)``. Element ``[..., i, j]`` is the
            IoU of ``source[..., i, :]`` and ``other[..., j, :]``.
            A 2-D input is treated as a batch of one, so two 2-D inputs give
            shape ``(1, N, M)``. Pairs whose union area is zero (e.g. two
            zero-area boxes) evaluate to ``0 / 0`` and come out as NaN.
        """
        # Promote unbatched (N, 4) inputs to (1, N, 4). The former
        # repeat-based implementation returned a leading batch dim of 1 for
        # them, and callers may depend on that output shape.
        if source.dim() == 2:
            source = source.unsqueeze(dim=0)
        if other.dim() == 2:
            other = other.unsqueeze(dim=0)

        # (..., N, 1, 4) against (..., 1, M, 4): broadcasting pairs every
        # source box with every other box without materialising two
        # (..., N, M, 4) copies the way .repeat() did.
        source = source.unsqueeze(dim=-2)
        other = other.unsqueeze(dim=-3)

        source_area = (source[..., 2] - source[..., 0]) * (source[..., 3] - source[..., 1])
        other_area = (other[..., 2] - other[..., 0]) * (other[..., 3] - other[..., 1])

        intersection_left = torch.max(source[..., 0], other[..., 0])
        intersection_top = torch.max(source[..., 1], other[..., 1])
        intersection_right = torch.min(source[..., 2], other[..., 2])
        intersection_bottom = torch.min(source[..., 3], other[..., 3])
        # Disjoint boxes give a negative extent; clamp so the overlap is 0.
        intersection_width = torch.clamp(intersection_right - intersection_left, min=0)
        intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0)
        intersection_area = intersection_width * intersection_height

        return intersection_area / (source_area + other_area - intersection_area)

    @staticmethod
    def inside(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Return a bool mask ``(...)`` of boxes lying entirely within the bounds.

        All four bounds are inclusive.
        """
        return ((bboxes[..., 0] >= left) & (bboxes[..., 1] >= top) &
                (bboxes[..., 2] <= right) & (bboxes[..., 3] <= bottom))

    @staticmethod
    def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Clamp box coordinates to the given bounds.

        Modifies ``bboxes`` in place and returns the same tensor object.
        x coordinates (indices 0 and 2) are clamped to ``[left, right]``.
        y coordinates (indices 1 and 3) are clamped to ``[top, bottom]``.
        """
        bboxes[..., [0, 2]] = bboxes[..., [0, 2]].clamp(min=left, max=right)
        bboxes[..., [1, 3]] = bboxes[..., [1, 3]].clamp(min=top, max=bottom)
        return bboxes