-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathtrain.py
executable file
·91 lines (81 loc) · 3.96 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import time
import argparse
import os
import sys
if sys.version_info >= (3, 0):
import _pickle as cPickle
else:
import cPickle
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from data_loader import load_data
from parameters import DATASET, TRAINING, HYPERPARAMS
def train(epochs=HYPERPARAMS.epochs, random_state=HYPERPARAMS.random_state,
          kernel=HYPERPARAMS.kernel, decision_function=HYPERPARAMS.decision_function,
          gamma=HYPERPARAMS.gamma, train_model=True):
    """Train a new SVM classifier, or evaluate a previously saved one.

    With ``train_model=True``: fit an ``SVC`` on the training split using the
    given hyperparameters, pickle it to ``TRAINING.save_model_path`` when
    ``TRAINING.save_model`` is set, and score it on the validation split.

    With ``train_model=False``: unpickle the model stored at
    ``TRAINING.save_model_path`` and score it on the validation and test splits.

    Args:
        epochs: forwarded to ``SVC`` as ``max_iter``.
        random_state: forwarded to ``SVC`` as ``random_state``.
        kernel: forwarded to ``SVC`` as ``kernel``.
        decision_function: forwarded to ``SVC`` as ``decision_function_shape``.
        gamma: forwarded to ``SVC`` as ``gamma``.
        train_model: True to train, False to evaluate the saved model.

    Returns:
        The validation accuracy (training mode) or the test accuracy
        (evaluation mode), as returned by ``evaluate``.

    Raises:
        SystemExit: with status 1 in evaluation mode when the saved model
            file does not exist.
    """
    print( "loading dataset " + DATASET.name + "...")
    if train_model:
        data, validation = load_data(validation=True)
        return _train_and_validate(data, validation, epochs=epochs,
                                   random_state=random_state, kernel=kernel,
                                   decision_function=decision_function,
                                   gamma=gamma)
    # The training split is returned too, but is not needed for evaluation.
    _, validation, test = load_data(validation=True, test=True)
    return _evaluate_saved_model(validation, test)


def _train_and_validate(data, validation, epochs, random_state, kernel,
                        decision_function, gamma):
    """Fit an SVC on ``data``, optionally save it, return validation accuracy."""
    print( "building model...")
    model = SVC(random_state=random_state, max_iter=epochs, kernel=kernel,
                decision_function_shape=decision_function, gamma=gamma)
    print( "start training...")
    print( "--")
    print( "kernel: {}".format(kernel))
    print( "decision function: {} ".format(decision_function))
    print( "max epochs: {} ".format(epochs))
    print( "gamma: {} ".format(gamma))
    print( "--")
    print( "Training samples: {}".format(len(data['Y'])))
    print( "Validation samples: {}".format(len(validation['Y'])))
    print( "--")
    start_time = time.time()
    model.fit(data['X'], data['Y'])
    training_time = time.time() - start_time
    print( "training time = {0:.1f} sec".format(training_time))
    if TRAINING.save_model:
        print( "saving model...")
        with open(TRAINING.save_model_path, 'wb') as f:
            cPickle.dump(model, f)
    print( "evaluating...")
    validation_accuracy = evaluate(model, validation['X'], validation['Y'])
    print( " - validation accuracy = {0:.1f}".format(validation_accuracy*100))
    return validation_accuracy


def _evaluate_saved_model(validation, test):
    """Load the pickled model, score it on validation and test, return test accuracy."""
    print( "start evaluation...")
    print( "loading pretrained model...")
    model_path = TRAINING.save_model_path
    if not os.path.isfile(model_path):
        print( "Error: file '{}' not found".format(model_path))
        # Non-zero status so shells/CI see the failure; the bare `exit()`
        # builtin exited with status 0 and is absent under `python -S`.
        sys.exit(1)
    with open(model_path, 'rb') as f:
        # NOTE: unpickling can execute arbitrary code; only load model files
        # from a trusted source.
        model = cPickle.load(f)
    print( "--")
    print( "Validation samples: {}".format(len(validation['Y'])))
    print( "Test samples: {}".format(len(test['Y'])))
    print( "--")
    print( "evaluating...")
    start_time = time.time()
    validation_accuracy = evaluate(model, validation['X'], validation['Y'])
    print( " - validation accuracy = {0:.1f}".format(validation_accuracy*100))
    test_accuracy = evaluate(model, test['X'], test['Y'])
    print( " - test accuracy = {0:.1f}".format(test_accuracy*100))
    print( " - evaluation time = {0:.1f} sec".format(time.time() - start_time))
    return test_accuracy
def evaluate(model, X, Y):
    """Score a fitted classifier on a labelled dataset.

    Args:
        model: a fitted estimator exposing ``predict``.
        X: the samples to classify.
        Y: the ground-truth labels for ``X``.

    Returns:
        The accuracy of ``model.predict(X)`` against ``Y``, as computed by
        ``sklearn.metrics.accuracy_score``.
    """
    return accuracy_score(Y, model.predict(X))
# parse arg to see if we need to launch training now or not yet
parser = argparse.ArgumentParser()
parser.add_argument("-t", "--train", default="no", help="if 'yes', launch training from command line")
parser.add_argument("-e", "--evaluate", default="no", help="if 'yes', launch evaluation on test dataset")
parser.add_argument("-m", "--max_evals", help="Maximum number of evaluations during hyperparameters search")
args = parser.parse_args()
if args.train=="yes" or args.train=="Yes" or args.train=="YES":
train()
if args.evaluate=="yes" or args.evaluate=="Yes" or args.evaluate=="YES":
train(train_model=False)