import os
from functools import lru_cache

import numpy as np
import torch
import torchvision.transforms as transforms


@lru_cache(maxsize=None)
def meshgrid(B, H, W, dtype, device, normalized=False):
"""
Create mesh-grid given batch size, height and width dimensions. From https://github.com/TRI-ML/KP2D.
Parameters
----------
B: int
Batch size
H: int
Grid Height
    W: int
        Grid Width
    dtype: torch.dtype
        Tensor dtype
    device: str
        Tensor device
    normalized: bool
        If True, use normalized image coordinates in [-1, 1]; otherwise an integer pixel grid.
Returns
-------
xs: torch.Tensor
Batched mesh-grid x-coordinates (BHW).
ys: torch.Tensor
Batched mesh-grid y-coordinates (BHW).
"""
if normalized:
xs = torch.linspace(-1, 1, W, device=device, dtype=dtype)
ys = torch.linspace(-1, 1, H, device=device, dtype=dtype)
else:
xs = torch.linspace(0, W-1, W, device=device, dtype=dtype)
ys = torch.linspace(0, H-1, H, device=device, dtype=dtype)
    ys, xs = torch.meshgrid([ys, xs])  # 'ij' indexing: ys varies along rows (H), xs along columns (W)
return xs.repeat([B, 1, 1]), ys.repeat([B, 1, 1])
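

# Illustrative usage sketch (added for documentation, not part of the original KP2D code):
# on an integer grid, xs indexes columns (width) and ys indexes rows (height).
def _demo_meshgrid():
    xs, ys = meshgrid(2, 4, 5, torch.float32, 'cpu')
    assert xs.shape == ys.shape == (2, 4, 5)  # BHW
    assert xs[0, 0, 3] == 3 and ys[0, 2, 0] == 2  # xs varies along W, ys along H

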
@lru_cache(maxsize=None)
def image_grid(B, H, W, dtype, device, ones=True, normalized=False):
"""
Create an image mesh grid with shape B3HW given image shape BHW. From https://github.com/TRI-ML/KP2D.
Parameters
----------
B: int
Batch size
H: int
Grid Height
    W: int
        Grid Width
    dtype: torch.dtype
        Tensor dtype
    device: str
        Tensor device
    ones: bool
        If True, stack an extra channel of ones to get (x, y, 1) coordinates
    normalized: bool
        If True, use normalized image coordinates in [-1, 1]; otherwise an integer pixel grid.
Returns
-------
grid: torch.Tensor
Mesh-grid for the corresponding image shape (B3HW)
"""
xs, ys = meshgrid(B, H, W, dtype, device, normalized=normalized)
coords = [xs, ys]
if ones:
coords.append(torch.ones_like(xs)) # BHW
grid = torch.stack(coords, dim=1) # B3HW
return grid
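

# Illustrative usage sketch (a minimal check added for documentation, not original code):
# image_grid stacks the x/y meshgrid, plus an optional channel of ones, into a B3HW
# tensor, which is convenient for homogeneous pixel coordinates.
def _demo_image_grid():
    grid = image_grid(1, 4, 5, torch.float32, 'cpu', ones=True, normalized=False)
    assert grid.shape == (1, 3, 4, 5)  # channels are (x, y, 1)
    assert torch.all(grid[0, 2] == 1)  # last channel is all ones

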
def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
"""
    Casts the sample's image to a tensor. From https://github.com/TRI-ML/KP2D.
Parameters
----------
sample : dict
Input sample
tensor_type : str
Type of tensor we are casting to
Returns
-------
sample : dict
Sample with keys cast as tensors
"""
transform = transforms.ToTensor()
sample['image'] = transform(sample['image']).type(tensor_type)
return sample
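

# Minimal input/output sketch (added for illustration; the dummy HWC uint8 numpy
# image below is a hypothetical stand-in for a real sample):
def _demo_to_tensor_sample():
    sample = {'image': np.zeros((8, 8, 3), dtype=np.uint8)}
    sample = to_tensor_sample(sample)
    assert sample['image'].shape == (3, 8, 8)      # ToTensor returns CHW
    assert sample['image'].dtype == torch.float32  # values scaled to [0, 1]

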
def warp_keypoints(keypoints, H):
"""Warp keypoints given a homography
Parameters
----------
    keypoints: numpy.ndarray (N,2)
        Array of keypoint (x, y) coordinates.
H: numpy.ndarray (3,3)
Homography.
Returns
-------
    warped_keypoints: numpy.ndarray (N,2)
        Array of warped keypoint (x, y) coordinates.
"""
num_points = keypoints.shape[0]
homogeneous_points = np.concatenate([keypoints, np.ones((num_points, 1))], axis=1)
warped_points = np.dot(homogeneous_points, np.transpose(H))
return warped_points[:, :2] / warped_points[:, 2:]
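

# Minimal sketch (added for illustration): a pure-translation homography should shift
# every (x, y) keypoint by (tx, ty) = (2, 3).
def _demo_warp_keypoints():
    keypoints = np.array([[0., 0.], [10., 5.]])
    H = np.array([[1., 0., 2.],
                  [0., 1., 3.],
                  [0., 0., 1.]])
    assert np.allclose(warp_keypoints(keypoints, H), [[2., 3.], [12., 8.]])

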
def prepare_dirs(config):
    """Create the checkpoint directory from the config if it does not already exist."""
    for path in [config.ckpt_dir]:
        os.makedirs(path, exist_ok=True)
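

# Simple self-check when the module is run directly (added for illustration). The
# SimpleNamespace below is a hypothetical stand-in for the project's real config
# object; prepare_dirs only reads its ckpt_dir attribute.
if __name__ == "__main__":
    import tempfile
    from types import SimpleNamespace

    _demo_meshgrid()
    _demo_image_grid()
    _demo_to_tensor_sample()
    _demo_warp_keypoints()
    prepare_dirs(SimpleNamespace(ckpt_dir=os.path.join(tempfile.gettempdir(), 'kp2d_utils_demo')))
    print('utils.py self-checks passed')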