risk_curve.py

import json
import logging
import subprocess
import time
from utils import colorize, configure_logging
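# configure_logging and colorize are repo-local helpers from utils: configure_logging
# sets up the logging format, and colorize presumably wraps a string in ANSI color codes.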
configure_logging()


class NewRiskCurveExperiment:
    def __init__(self, loss_type, max_epochs, n_train_sample):
        self.output_filename = f"data/risk_curve_loss-{loss_type}_sample-{n_train_sample}_" \
                               f"epoch-{max_epochs}_{int(time.time())}.json"
        logging.info(f"output_filename: {self.output_filename}")

        self.loss_type = loss_type
        self.n_train_sample = n_train_sample
        self.max_epochs = max_epochs
        self.results = []
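
        # critical_n_units approximates the hidden-layer width at which the model's
        # parameter count reaches n_train_sample * 10 (the number of training targets),
        # i.e. the interpolation threshold around which the sweep below is densified.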
        critical_n_units = int((n_train_sample * 10 - 10) / float(28 * 28 + 10))
        logging.info(f"critical_n_units: {critical_n_units}")

        self.n_units_to_test = sorted(set(
            list(range(critical_n_units - 7, critical_n_units + 4)) +
            list(range(5, 55, 5)) + list(range(50, 105, 10)) + [120, 150, 200]
        ))
        logging.info(f"n_units_to_test: {self.n_units_to_test}")

    def run(self):
        for i, n_units in enumerate(self.n_units_to_test):
            total_params = (28 * 28 + 1) * n_units + (n_units + 1) * 10
            old_n_units = None if i == 0 else self.n_units_to_test[i - 1]

            # Train each MNIST model in a subprocess because a TensorFlow graph cannot be
            # removed dynamically; this way only one graph is kept in memory per iteration.
            args = [
                '--n-units', str(n_units),
                '--max-epochs', str(self.max_epochs),
                '--n-train-samples', str(self.n_train_sample),
                '--loss-type', str(self.loss_type),
            ]
            if old_n_units:  # and total_params < self.n_train_sample * 10:
                args.extend(['--old-n-units', str(old_n_units)])

            proc = subprocess.run(
                ['python', 'risk_curve_evaluate_model.py'] + args,
                encoding='utf-8', stdout=subprocess.PIPE
            )
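
            # The last line of the child process's stdout is expected to carry six
            # space-separated metrics after a leading tag: epoch, step, train_loss,
            # train_acc, eval_loss, eval_acc (format assumed from the parsing below).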
            output = proc.stdout.strip().split('\n')
            epoch, step, train_loss, train_acc, eval_loss, eval_acc = \
                list(map(float, output[-1].split()[1:]))

            result_dict = dict(
                n_epochs=int(epoch),
                step=int(step),
                n_units=n_units,
                old_n_units=old_n_units,
                total_params=total_params,
                train_loss=train_loss,
                train_acc=train_acc,
                eval_loss=eval_loss,
                eval_acc=eval_acc,
            )
            self.results.append(result_dict)
            logging.info(colorize(f"n_units={n_units} >>> {result_dict}", 'green'))

            # Save results to disk on every loop iteration.
            with open(self.output_filename, 'w') as fout:
                json.dump(self.results, fout)

        return self.results

    def plot(self):
        pass
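

# plot() above is left as a stub. A minimal plotting sketch, assuming matplotlib is
# installed and that `results` follows the schema produced by run(), might look like this:
def plot_risk_curve(results):
    import matplotlib.pyplot as plt
    n_units = [r['n_units'] for r in results]
    plt.plot(n_units, [r['train_loss'] for r in results], label='train loss')
    plt.plot(n_units, [r['eval_loss'] for r in results], label='eval loss')
    plt.xlabel('number of hidden units')
    plt.ylabel('loss')
    plt.legend()
    plt.show()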


if __name__ == '__main__':
    exp = NewRiskCurveExperiment(loss_type='mse', max_epochs=500, n_train_sample=4000)
    exp.run()