gridgen.py (forked from fxia22/stn.pytorch)
# functions/gridgen.py
import torch
from torch.autograd import Function
import numpy as np


class AffineGridGenFunction(Function):
    def __init__(self, height, width, lr=1):
        super(AffineGridGenFunction, self).__init__()
        self.lr = lr
        self.height, self.width = height, width
        # Base grid of homogeneous coordinates: channel 0 holds the row (y)
        # coordinate, channel 1 the column (x) coordinate, both in [-1, 1),
        # and channel 2 is the constant 1 for the affine translation term.
        self.grid = np.zeros([self.height, self.width, 3], dtype=np.float32)
        self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / self.height), 0), repeats=self.width, axis=0).T, 0)
        self.grid[:, :, 1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / self.width), 0), repeats=self.height, axis=0), 0)
        self.grid[:, :, 2] = np.ones([self.height, self.width])
        self.grid = torch.from_numpy(self.grid.astype(np.float32))
    def forward(self, input1):
        # input1: batch of 2x3 affine matrices, shape (N, 2, 3).
        self.input1 = input1
        self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
        for i in range(input1.size(0)):
            self.batchgrid[i] = self.grid
        if input1.is_cuda:
            self.batchgrid = self.batchgrid.cuda()
        # Apply the affine transform to every grid point in one batched matmul:
        # (N, H*W, 3) x (N, 3, 2) -> (N, H*W, 2), reshaped to (N, H, W, 2).
        output = torch.bmm(self.batchgrid.view(-1, self.height * self.width, 3), torch.transpose(input1, 1, 2)).view(-1, self.height, self.width, 2)
        return output
    def backward(self, grad_output):
        # grad_output: (N, H, W, 2). The gradient w.r.t. the affine matrices is
        # (N, 2, H*W) x (N, H*W, 3) -> (N, 2, 3).
        grad_input1 = torch.zeros(self.input1.size())
        if grad_output.is_cuda:
            self.batchgrid = self.batchgrid.cuda()
            grad_input1 = grad_input1.cuda()
        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height * self.width, 2), 1, 2), self.batchgrid.view(-1, self.height * self.width, 3))
        return grad_input1
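

# Usage sketch: a minimal example of driving AffineGridGenFunction directly with
# a batch of identity affine matrices. It assumes the legacy autograd Function
# style used above, where forward()/backward() are called on an instance; the
# helper name _demo_affine_gridgen is illustrative only.
def _demo_affine_gridgen(height=4, width=4, batch=2):
    # Identity affine parameters theta, shape (batch, 2, 3).
    theta = torch.FloatTensor([[1, 0, 0], [0, 1, 0]]).unsqueeze(0).repeat(batch, 1, 1)
    gridgen = AffineGridGenFunction(height, width)
    grid = gridgen.forward(theta)                     # (batch, height, width, 2)
    grad = gridgen.backward(torch.ones(grid.size()))  # (batch, 2, 3)
    return grid, grad
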
class CylinderGridGenFunction(Function):
    def __init__(self, height, width, lr=1):
        super(CylinderGridGenFunction, self).__init__()
        self.lr = lr
        self.height, self.width = height, width
        # Same base grid as AffineGridGenFunction: row and column coordinates
        # in [-1, 1) plus a constant homogeneous channel.
        self.grid = np.zeros([self.height, self.width, 3], dtype=np.float32)
        self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / self.height), 0), repeats=self.width, axis=0).T, 0)
        self.grid[:, :, 1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0 / self.width), 0), repeats=self.height, axis=0), 0)
        self.grid[:, :, 2] = np.ones([self.height, self.width])
        self.grid = torch.from_numpy(self.grid.astype(np.float32))
    def forward(self, input1):
        # input1: one scalar per batch element, shape (N, 1); it is squashed to
        # [0, 1] via (1 + cos(input1)) / 2 and used as a horizontal cut point.
        self.input1 = (1 + torch.cos(input1)) / 2
        output = torch.zeros(torch.Size([input1.size(0), self.height, self.width, 2]))
        if not self.input1.is_cuda:
            for i in range(self.input1.size(0)):
                x = self.input1[i][0]
                # Column index at which the horizontal shift wraps around.
                low = int(np.ceil(self.width * self.input1[i][0]))
                output[i, :, :, 1] = torch.zeros(self.grid[:, :, 1].size())
                if low <= self.width and low > 0:
                    output[i, :, :low, 1].fill_(2 * (1 - x))
                if low < self.width and low >= 0:
                    output[i, :, low:, 1].fill_(2 * (-x))
                # Add the base grid: row coordinates pass through unchanged.
                output[i, :, :, 1] = output[i, :, :, 1] + self.grid[:, :, 1]
                output[i, :, :, 0] = self.grid[:, :, 0]
        else:
            raise NotImplementedError('CUDA forward is not implemented')
        return output
    def backward(self, grad_output):
        grad_input1 = torch.zeros(self.input1.size())
        if not grad_output.is_cuda:
            for i in range(self.input1.size(0)):
                # Gradient of the column shift w.r.t. the input, summed over all
                # grid positions.
                grad_input1[i] = -torch.sum(torch.sum(grad_output[i, :, :, 1], 1)) * torch.sin(self.input1[i]) / 2
        else:
            raise NotImplementedError('CUDA backward is not implemented')
        return grad_input1 * self.lr
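

# Usage sketch: a hedged, CPU-only check of both grid generators. It assumes
# CylinderGridGenFunction takes one angle-like scalar per batch element
# (shape (N, 1)) and returns an (N, height, width, 2) sampling grid; the values
# below are illustrative only.
if __name__ == '__main__':
    grid, grad = _demo_affine_gridgen()
    print('affine grid:', grid.size(), 'grad:', grad.size())

    cyl = CylinderGridGenFunction(4, 4)
    angle = torch.FloatTensor([[0.5]])                    # (N=1, 1)
    cyl_grid = cyl.forward(angle)                         # (1, 4, 4, 2)
    cyl_grad = cyl.backward(torch.ones(cyl_grid.size()))  # (1, 1)
    print('cylinder grid:', cyl_grid.size(), 'grad:', cyl_grad.size())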