# train_model.py
import sys
sys.path.insert(1,'jobman')
sys.path.insert(1,'coco-caption')
import numpy
import os, sys, socket
import time
import logging
from config import config
from jobman import DD, expand
import common
import numpy as np
import model_attention
import model_lstmdd
import model_mtle
# Configure the root logger with the library defaults, then create the
# module-level logger used below to report uncaught training errors.
logging.basicConfig()
logger = logging.getLogger(__name__)
def set_config(conf, args, add_new_key=False):
    """Recursively copy the values in `args` into `conf`, in place.

    Any key named ``'jobman'`` is skipped. A value that is a ``DD`` is treated
    as a nested section and merged into ``conf[key]``; every other value is
    passed through ``convert_from_string`` before being stored.

    Args:
        conf: Mapping to update in place.
        args: Mapping holding the new values (possibly nested ``DD`` objects).
        add_new_key: If true, keys missing from `conf` are created instead of
            being rejected. The flag is honoured inside nested sections too.

    Raises:
        KeyError: If a non-section key of `args` is missing from `conf` and
            `add_new_key` is false. A missing nested section is not created;
            the ``conf[key]`` lookup fails however `conf` reports a missing
            key.
    """
    for key in args:
        if key == 'jobman':
            continue
        value = args[key]
        if isinstance(value, DD):
            # Forward the flag: it used to be dropped here, so new keys inside
            # a nested section were rejected even when add_new_key was set.
            set_config(conf[key], value, add_new_key)
        elif add_new_key or key in conf:
            # `in` replaces the Python-2-only `has_key`.
            conf[key] = convert_from_string(value)
        else:
            raise KeyError(key)
def convert_from_string(x):
    """
    Interpret `x` as a Python expression and return the resulting object.

    `x` is handed to `eval` with empty global and local namespaces. If the
    evaluation raises any `Exception` (for example `x` is not valid Python,
    refers to an unknown name, or is not something `eval` accepts), `x`
    itself is returned unchanged.
    """
    # NOTE(review): an empty globals dict does not hide the builtins from
    # `eval`, so this must only ever be fed trusted input.
    try:
        converted = eval(x, {}, {})
    except Exception:
        converted = x
    return converted
def _remove_files_in(directory):
    """Best-effort removal of the non-hidden, non-directory entries of `directory`.

    Mirrors what ``rm <directory>/*`` did, without going through a shell:
    hidden entries and sub-directories are left alone, symlinks are removed,
    and a failure on one entry is logged rather than aborting the run.
    """
    for name in os.listdir(directory):
        if name.startswith('.'):
            continue  # the shell glob `*` never matched hidden entries
        path = os.path.join(directory, name)
        if os.path.isdir(path) and not os.path.islink(path):
            continue  # plain `rm` (no -r) refuses directories
        try:
            os.remove(path)
        except OSError as err:
            logger.warning('could not remove %s: %s', path, err)


def train_from_scratch(config, state, channel):
    """Set up the output directory and config, then train the selected model.

    In order, the function: seeds NumPy's global RNG from
    ``config.random_seed``; resolves the save directory (``'current'`` means
    ``'./'``) and creates it via ``common.create_dir_if_not_exist``; when a
    reload is requested, merges the configuration stored in the source
    directory into `config`; optionally erases the files in the save
    directory; saves `config` with ``common.dump_pkl``; mirrors every
    top-level entry of `config` onto `state`; and finally calls the
    ``train_from_scratch`` of the module selected by ``config.model``.

    Args:
        config: Experiment configuration, read with both attribute and item
            access and updated in place.
        state: Object that receives each top-level entry of `config` as an
            attribute and is handed to the model's trainer.
        channel: Passed through unchanged to the model's trainer.

    Raises:
        AssertionError: If a reload is requested with identical save and
            source directories.
        NotImplementedError: If ``config.model`` is not ``'attention'``,
            ``'lstmdd'`` or ``'mtle'``.
    """
    # Model options.
    # NOTE: `config[config.model]` is looked up afresh each time on purpose;
    # merging an old config below may change `config.model`.
    save_model_dir = config[config.model].save_model_dir
    np.random.seed(int(config.random_seed))
    if save_model_dir == 'current':
        config[config.model].save_model_dir = './'
        save_model_dir = './'
        # to facilitate the use of cluster for multiple jobs
        save_path = './model_config.pkl'
    else:
        # run locally, save locally
        save_path = os.path.join(save_model_dir, 'model_config.pkl')
    print('current save dir  %s' % (save_model_dir,))
    common.create_dir_if_not_exist(save_model_dir)

    if config[config.model].reload_:
        print('preparing reload')
        save_dir_backup = config[config.model].save_model_dir
        from_dir_backup = config[config.model].from_dir
        # Never start retraining in the same folder. Raised explicitly (same
        # exception type as before) so the check survives `python -O`.
        if save_dir_backup == from_dir_backup:
            raise AssertionError(
                'save_model_dir and from_dir must differ when reloading '
                '(both are %r)' % (save_dir_backup,))
        print('save dir  %s' % (save_dir_backup,))
        print('from_dir  %s' % (from_dir_backup,))
        print('setting current model config with the old one')
        if config[config.model].mode == 'train':
            model_config_old = common.load_pkl(
                os.path.join(from_dir_backup, 'model_config.pkl'))
            set_config(config, model_config_old)
        # The merge may have overwritten these with the old run's values;
        # when nothing was merged the assignments change nothing.
        config[config.model].save_model_dir = save_dir_backup
        config[config.model].from_dir = from_dir_backup
        config[config.model].reload_ = True

    if config.erase_history:
        print('erasing everything in  %s' % (save_model_dir,))
        _remove_files_in(save_model_dir)

    print('saving model config into %s' % save_path)
    common.dump_pkl(config, save_path)

    # Also copy back from config into state.
    for key in config:
        setattr(state, key, config[key])

    model_type = config.model
    print('Model Type: %s' % model_type)
    print('Host: %s' % socket.gethostname())
    print('Command: %s' % ' '.join(sys.argv))
    if model_type == 'attention':
        model_attention.train_from_scratch(state, channel)
    elif model_type == 'lstmdd':
        model_lstmdd.train_from_scratch(state, channel)
    elif model_type == 'mtle':
        model_mtle.train_from_scratch(state, channel)
    else:
        raise NotImplementedError('unknown model type: %r' % (model_type,))
def main(state, channel=None):
    """Apply the overrides in `state` to the global `config`, then train.

    `channel` is forwarded unchanged to `train_from_scratch`.
    """
    set_config(conf=config, args=state)
    train_from_scratch(config=config, state=state, channel=channel)
def parse_cli_args(argv):
    """Turn ``key=value`` command-line tokens into a dict of strings.

    Only the first ``=`` separates key from value, so a value may itself
    contain ``=``. A key given more than once keeps its last value.

    Args:
        argv: Iterable of command-line tokens (without the program name).

    Returns:
        Dict mapping each key to its (still unparsed) string value.

    Raises:
        ValueError: If a token contains no ``=``.
    """
    parsed = {}
    for token in argv:
        key, sep, value = token.partition('=')
        if not sep:
            raise ValueError('expected key=value, got %r' % (token,))
        parsed[key] = value
    return parsed


if __name__ == '__main__':
    try:
        args = parse_cli_args(sys.argv[1:])
    except ValueError:
        print('args must be like a=X b.c=X')
        sys.exit(1)
    state = expand(args)
    try:
        main(state)
    except Exception as e:
        # Top-level boundary: log the full traceback, then exit non-zero so
        # a failed run is not reported as a success to the calling shell.
        logger.exception(e)
        sys.exit(1)