-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_percentiles.py
100 lines (77 loc) · 3.32 KB
/
plot_percentiles.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import collections
import matplotlib.pyplot as plt
import json
# Y-axis label of each subplot column, listed in the same order as the
# entries of ``log_key_templates`` below.
ylabels = [
    'Weights',
    'Biases',
    'Weight Gradients',
    'Bias Gradients',
]

# Log-key template of each subplot column. ``{layer}`` and ``{statistic}``
# are left as placeholders (hence the doubled braces) to be filled in later
# with ``str.format``; only the parameter name and the data/grad part are
# substituted here. Resulting order: W/data, b/data, W/grad, b/grad.
log_key_templates = [
    'predictor/{{layer}}/{param}/{kind}/{{statistic}}'.format(
        param=param, kind=kind)
    for kind in ('data', 'grad')
    for param in ('W', 'b')
]
def parse_args():
    """Parse the command-line options of this script.

    Options:
        --log: string, defaults to ``'example/result/log'``.
        --out: string, defaults to ``'plot.png'``.
        --layers: one or more strings, defaults to the five layer names
            ``conv1``, ``conv2``, ``conv3``, ``fc1`` and ``fc2``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    default_layers = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']
    # Single-valued string options, registered in the original order.
    string_options = (
        ('--log', 'example/result/log'),
        ('--out', 'plot.png'),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value in string_options:
        arg_parser.add_argument(flag, type=str, default=default_value)
    arg_parser.add_argument('--layers', type=str, nargs='+',
                            default=default_layers)
    return arg_parser.parse_args()
def load_log(filename, keys=None):
    """Read a JSON log file and group its measurements by key.

    The file is parsed with ``json.load`` and the result is iterated entry
    by entry; every entry is indexed by key, so the file is expected to hold
    a JSON array of objects.

    Args:
        filename: Path of the JSON file to read.
        keys: Optional iterable of keys to collect. When given, every entry
            must contain each of them (a missing key raises ``KeyError``).
            When ``None``, all keys found in each entry are collected.

    Returns:
        A ``collections.defaultdict(list)`` mapping each key to the list of
        its values, in the order the entries appear in the file.
    """
    grouped = collections.defaultdict(list)
    with open(filename) as log_file:
        entries = json.load(log_file)

    for entry in entries:
        # Either the caller's selection or everything this entry holds.
        wanted = entry.keys() if keys is None else keys
        for name in wanted:
            grouped[name].append(entry[name])
    return grouped
def _statistic_series(log, key_template, layer, statistic):
    """Return the logged series for one layer/statistic of a key template."""
    return log[key_template.format(layer=layer, statistic=statistic)]


def plot_percentile_log(filename, log, layer_names, color='green', dpi=100):
    """Plot per-layer parameter statistics from a log and save the figure.

    The figure is a grid with one row per entry of ``layer_names`` and one
    column per entry of the module-level ``log_key_templates``. Each subplot
    draws, against the index of the logged measurements:

    * a translucent band between the ``min`` and ``max`` series,
    * a line for the ``percentile/3`` series (treated as the median), and
    * translucent bands between ``percentile/p`` and ``percentile/(6 - p)``
      for ``p`` in 0, 1, 2; overlapping bands therefore appear darker.

    Args:
        filename: Path the figure is written to.
        log: Mapping from log key to a list of values, e.g. the result of
            ``load_log``. Keys are built from ``log_key_templates``.
        layer_names: Sequence of layer names; one subplot row per name.
        color: Matplotlib colour used for every band and line.
        dpi: Figure resolution. The figure size is chosen so that each
            subplot gets 1024 pixels per side at this resolution (before
            the ``bbox_inches='tight'`` cropping applied when saving).
    """
    # Percentile indices used below: 3 is plotted as the median and
    # (0, 6), (1, 5), (2, 4) are the pairs filled between.
    median_index = 3
    last_index = 6
    n_bands = 3

    n_rows = len(layer_names)
    n_cols = len(log_key_templates)
    figsize = (1024 * n_cols / dpi, 1024 * n_rows / dpi)

    # squeeze=False makes ``axes`` a 2-D array for every grid shape,
    # including a single row, a single column and a 1x1 grid (where
    # pyplot would otherwise return a bare Axes object).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, dpi=dpi,
                             squeeze=False)
    try:
        for row, layer in enumerate(layer_names):
            for col, key_template in enumerate(log_key_templates):
                ax = axes[row, col]

                # Min, Max
                pmin = _statistic_series(log, key_template, layer, 'min')
                pmax = _statistic_series(log, key_template, layer, 'max')
                ax.fill_between(range(len(pmin)), pmin, pmax,
                                facecolor=color, alpha=0.2, linewidth=0)

                # Median
                z = _statistic_series(
                    log, key_template, layer,
                    'percentile/{}'.format(median_index))
                ax.plot(range(len(z)), z, color=color, alpha=0.2)

                # Fill between each pair of percentiles around the median.
                for p in range(n_bands):
                    s = _statistic_series(
                        log, key_template, layer,
                        'percentile/{}'.format(p))
                    ns = _statistic_series(
                        log, key_template, layer,
                        'percentile/{}'.format(last_index - p))
                    ax.fill_between(range(len(s)), s, ns, facecolor=color,
                                    alpha=0.2, linewidth=0)

                ax.set_xlabel('Epochs')
                ax.set_ylabel(ylabels[col])
                ax.set_title(layer)

        # Save this figure explicitly rather than pyplot's current figure.
        fig.savefig(filename, bbox_inches='tight', dpi=dpi)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: read the command-line options, load every key of
    # the log file, and plot the statistics of the requested layers.
    cli_args = parse_args()
    plot_percentile_log(cli_args.out, load_log(cli_args.log),
                        cli_args.layers)