-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsimple_predictor.py
52 lines (40 loc) · 1.56 KB
/
simple_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
from utils.metrics import calc_pose_error, calc_qua_error
from utils.image_processing import inf_preprocess
class SimplePredictor:
    """Evaluate a model, or run it on a single image, through a session.

    The model object must expose ``x`` (input placeholder), ``prediction``
    (output op) and, when ``enable_phase`` is used, ``phase``. Each
    prediction row is split at index 3: the first three components are
    scored with ``calc_pose_error`` and the rest with ``calc_qua_error``.
    """

    def __init__(self, model, test_batcher, sess):
        """Store the model, the test batcher and the session to run it in.

        Args:
            model: object exposing ``x``, ``prediction`` and optionally
                ``phase``.
            test_batcher: object whose ``next_batch()`` returns an
                indexable batch (``batch[0]`` inputs, ``batch[1]`` targets)
                and raises ``StopIteration`` when exhausted.
            sess: session object providing ``run(fetches, feed_dict=...)``.
        """
        self.model = model
        self.test_batcher = test_batcher
        self.sess = sess

    def _feed_dict(self, inputs, enable_phase):
        """Build the feed dict shared by evaluation and single prediction."""
        feed_dict = {self.model.x: inputs}
        if enable_phase:
            # The phase input is always fed False here.
            # NOTE(review): presumably a "training" flag - confirm on model.
            feed_dict[self.model.phase] = False
        return feed_dict

    @staticmethod
    def _split_pose(pred):
        """Split one prediction into its first three and remaining values.

        A leading batch axis (as produced for a single-image batch) is
        dropped first; slicing a 2-D ``(1, D)`` array with ``[:3]`` would
        otherwise cut along the batch axis instead of the components.
        """
        pred = np.asarray(pred)
        if pred.ndim > 1:
            pred = pred[0]
        return pred[:3], pred[3:]

    def test(self, enable_phase=False):
        """Run the model over every test batch and average the errors.

        Args:
            enable_phase: when true, ``model.phase`` is fed ``False``.

        Returns:
            Tuple ``(mean_pose_error, mean_qua_error)``: the accumulated
            ``calc_pose_error`` / ``calc_qua_error`` values divided by the
            number of evaluated samples.

        Raises:
            ZeroDivisionError: if the batcher yields no batches at all.
        """
        pose_error_sum = 0
        qua_error_sum = 0
        count = 0
        while True:
            # Only next_batch() signals exhaustion; keeping the rest out of
            # the try stops a stray StopIteration from ending the run early.
            try:
                batch = self.test_batcher.next_batch()
            except StopIteration:
                break
            targets = batch[1]
            [prediction] = self.sess.run(
                [self.model.prediction],
                feed_dict=self._feed_dict(batch[0], enable_phase),
            )
            pose_error_sum += calc_pose_error(targets[:, :3], prediction[:, :3])
            qua_error_sum += calc_qua_error(targets[:, 3:], prediction[:, 3:])
            # Count the rows actually scored so a short final batch does not
            # inflate the denominator.
            # NOTE(review): assumes the calc_* helpers return errors summed
            # over the batch - confirm in utils.metrics.
            count += len(prediction)
        mean_qua_error = qua_error_sum / count
        mean_pose_error = pose_error_sum / count
        return mean_pose_error, mean_qua_error

    def sample_predict(self, path_to_img, enable_phase=False):
        """Predict for the single image at ``path_to_img``.

        The image is loaded with ``inf_preprocess`` and given a leading
        batch axis before being fed to the model.

        Args:
            path_to_img: path passed to ``inf_preprocess``.
            enable_phase: when true, ``model.phase`` is fed ``False``.

        Returns:
            Tuple ``(pred[:3], pred[3:])`` for that one image, with the
            batch axis removed.
        """
        data = np.expand_dims(inf_preprocess(path_to_img), axis=0)
        [pred] = self.sess.run(
            [self.model.prediction],
            feed_dict=self._feed_dict(data, enable_phase),
        )
        return self._split_pose(pred)