#!/usr/bin/env python3
# vim: set fileencoding=utf-8 :
"""Plot statistics of saved per-layer states against the iteration number.

Reads every ``states.<iteration>.l<layer><c|h>.npz`` file in a directory
(first command-line argument, default ``.``), draws one row of subplots per
state tag and exports the figure to ``test_tight.tikz`` before showing it.
"""
from collections import defaultdict
import os
import re
import sys

import numpy as np

# Matches state file names; group 1 is the iteration number, group 2 the
# state tag (e.g. ``l1c``).
fp = re.compile(r'states\.(\d+)\.(l[12][ch])\.npz')

# State tags to plot, one subplot row each, in display order.
STATES = ('l1c', 'l1h', 'l2c', 'l2h')

# Statistic key prefixes read from each file, one subplot column each.
PREFIXES = ('', 'length_')


def read_file(state_file):
    """Read one state file.

    Args:
        state_file: path to a ``states.<iteration>.l<layer><c|h>.npz`` file.
            Only its base name has to match ``fp``, so it may include a
            directory.

    Returns:
        ``(iteration, keys, values)`` where ``keys`` are the archive's array
        names in sorted order and ``values`` the matching entries: the first
        four are returned as stored, the remaining ones are reduced to their
        mean.

    Raises:
        ValueError: if the base name does not match ``fp``.
    """
    match = fp.fullmatch(os.path.basename(state_file))
    if match is None:
        raise ValueError('Not a state file name: {!r}'.format(state_file))
    # ``with`` closes the underlying zip archive; ``npz[key]`` reads each
    # array eagerly, so the values stay valid after the archive is closed.
    with np.load(state_file) as npz:
        keys = sorted(npz.keys())
        # NOTE(review): assumes the first four sorted keys (the ``length_*``
        # statistics, which sort before ``max``/``mean``/...) are already
        # scalars while the remaining ones are arrays to average -- confirm
        # against whatever writes these files.
        values = [npz[key] if i < 4 else npz[key].mean()
                  for i, key in enumerate(keys)]
    return int(match.group(1)), keys, values


def _collect_state_files(plot_dir, states):
    """Return ``{state: [path, ...]}`` for ``states``, each sorted by iteration.

    Files in ``plot_dir`` whose names do not match ``fp``, or whose state tag
    is not in ``states``, are ignored. States without files map to ``[]``.
    """
    found = {state: [] for state in states}
    for name in os.listdir(plot_dir):
        match = fp.fullmatch(name)
        if match and match.group(2) in found:
            found[match.group(2)].append(
                (int(match.group(1)), os.path.join(plot_dir, name)))
    # Sort numerically by iteration (string order would put 1000 before 200).
    return {state: [path for _, path in sorted(entries)]
            for state, entries in found.items()}


def _minmax_yerr(dmean, dmin, dmax):
    """Return the asymmetric ``yerr`` pair spanning ``[dmin, dmax]``.

    ``Axes.errorbar`` draws from ``y - lower`` to ``y + upper``, so both
    entries are distances from the mean, not absolute bounds.
    """
    return [dmean - dmin, dmax - dmean]


def main():
    """Plot the statistics of all state files in ``sys.argv[1]`` (default ``.``).

    Exits with an error message if any state in ``STATES`` has no files.
    """
    # Imported here rather than at module level so the data-loading helpers
    # above can be imported without the plotting stack installed.
    import matplotlib.pyplot as plt
    from matplotlib2tikz import save as tikz_save

    plot_dir = sys.argv[1] if len(sys.argv) > 1 else '.'
    state_files = _collect_state_files(plot_dir, STATES)
    missing = [state for state, files in state_files.items() if not files]
    if missing:
        sys.exit('No state files for {} in {!r}'.format(
            ', '.join(missing), plot_dir))

    rp = re.compile(r'l(\d+)([ch])')
    state_names = ['Layer {}.{}'.format(*rp.match(s).groups()) for s in STATES]

    # to_plot[state][key] -> one value per file, in iteration order.
    to_plot = defaultdict(lambda: defaultdict(list))
    iterations = {}
    for state, files in state_files.items():
        iterations[state] = []
        for path in files:
            num, keys, values = read_file(path)
            iterations[state].append(num)
            for key, value in zip(keys, values):
                to_plot[state][key].append(value)

    # The x axis comes from the first state; the others are assumed to have
    # been saved at the same iterations.
    nums = np.array(iterations[STATES[0]])
    # Express iterations in units of the smallest saved iteration (floor
    # division); skipped when that is 0 to avoid dividing by zero.
    if nums.min() > 0:
        nums //= nums.min()

    _, axes = plt.subplots(nrows=len(STATES), ncols=len(PREFIXES),
                           sharex=True, sharey=False, squeeze=False)
    for i, state in enumerate(STATES):
        for j, prefix in enumerate(PREFIXES):
            axis = axes[i][j]
            stats = to_plot[state]
            dmean = np.array(stats[prefix + 'mean'])
            dstd = np.array(stats[prefix + 'std'])
            dmin = np.array(stats[prefix + 'min'])
            dmax = np.array(stats[prefix + 'max'])
            # Thick black bars: mean +/- std.
            axis.errorbar(x=nums, y=dmean, yerr=dstd, fmt='.k', lw=3)
            # Thin gray bars: the full min..max range.
            axis.errorbar(
                x=nums, y=dmean, yerr=_minmax_yerr(dmean, dmin, dmax),
                fmt='.k', ecolor='gray', lw=1
            )
            axis.set_title(state_names[i] + (', length' if j else ''))
            if i == len(STATES) - 1:
                axis.set_xlabel('Iterations')

    plt.tight_layout()
    # Export before show(): with a blocking backend the figure is destroyed
    # when its window closes, and tikz_save (which works on the current
    # figure) would then export a new, empty one.
    tikz_save('test_tight.tikz')
    plt.show()


if __name__ == '__main__':
    main()