-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_hourglass_model.py
38 lines (26 loc) · 1.19 KB
/
train_hourglass_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
import os

import tensorflow as tf

from core.batcher import Batcher
from core.hourglass_model import HourglassModel
from core.simple_predictor import SimplePredictor
from core.simple_trainer import SimpleTrainer

def run():
train_batcher = Batcher(path_to_data='./data/images', path_to_csv='./data/train.csv', crop_type='random_crop',
batch_size=28)
valid_batcher = Batcher(path_to_data='./data/images', path_to_csv='./data/valid.csv', crop_type='random_crop',
batch_size=28)
model = HourglassModel()
model.build_model()
sess = tf.Session()
trainer = SimpleTrainer(model, train_batcher, valid_batcher, sess, 100, 10)
trainer.train(enable_phase=True)
test_batcher = Batcher(path_to_data='./data/images', path_to_csv='./data/test.csv', batch_size=4,
batcher_type='test', norm_type='range')
predictor = SimplePredictor(model, test_batcher, sess)
print('Start testing...')
mean_pos, mean_qua = predictor.test(enable_phase=True)
print('Mean pose error: {}, mean quaternion error: {}'.format(mean_pos, mean_qua))
def parse_args():
pass
if __name__ == '__main__':
run()