-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlogger.py
145 lines (119 loc) · 4.67 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import
import matplotlib.pyplot as plt
import os
import sys
import numpy as np
# Names exported by a star-import of this module; plot_overlap is left out.
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
def savefig(fname, dpi=None):
    '''Save the current matplotlib figure to ``fname``.

    Args:
        fname: Destination handed unchanged to ``plt.savefig``.
        dpi: Resolution in dots per inch; ``None`` selects 150.
    '''
    # Identity test: ``== None`` would dispatch to a custom ``__eq__``.
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None):
    '''Plot selected columns of ``logger`` with ``plt.plot``.

    Every value is rounded to the nearest integer before it is plotted.

    Args:
        logger: Object exposing ``names``, ``numbers`` (column name -> list
            of values) and ``title``.
        names: Columns to draw; ``None`` draws every column in
            ``logger.names``.

    Returns:
        One legend label per requested column, formatted ``title(name)``.

    Raises:
        KeyError: If a requested column is not in ``logger.numbers``.
        ValueError: If a stored value cannot be parsed as a number.
    '''
    if names is None:
        names = logger.names
    numbers = logger.numbers
    for name in names:
        # float() replaces an earlier eval(): the values may be text read
        # from a log file, and eval() would execute whatever that text
        # contains. float() also accepts values that are already numeric.
        y = [round(float(value)) for value in numbers[name]]
        x = np.arange(len(y))
        plt.plot(x, y)
    return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header line with the column
    names, then one line per ``append`` call with each value formatted to
    six decimals. Every field, including the last, is followed by a tab.

    Attributes set by this class:
        file: Open file object, or ``None`` when ``fpath`` is ``None``.
        resume: The ``resume`` flag given to the constructor.
        title: Prefix used in legend labels ('' when no title was given).
        names: Column names, in file order.
        numbers: Maps each column name to the list of its values. Values
            loaded from an existing file are kept as the text read from
            that file; values passed to ``append`` are kept as given.
    '''

    def __init__(self, fpath, title=None, resume=False):
        '''Open (or reopen) the log file.

        Args:
            fpath: Path of the log file, or ``None`` for a logger that only
                keeps values in memory.
            title: Legend prefix; ``None`` means ''.
            resume: When true, load the existing header and rows from
                ``fpath`` and reopen the file for appending; otherwise the
                file is truncated.

        Raises:
            IOError/OSError: If the file cannot be opened.
            IndexError: If a resumed row has more fields than the header.
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so a logger without a file, or one whose
        # set_names() has not run yet, still carries these attributes.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            # The context manager closes the read handle even when parsing
            # raises part-way through.
            with open(fpath, 'r') as existing:
                self._load(existing)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    def _load(self, handle):
        '''Fill ``names`` and ``numbers`` from an open log file.'''
        header = handle.readline().rstrip()
        self.names = header.split('\t') if header else []
        self.numbers = {}
        for name in self.names:
            self.numbers[name] = []
        for line in handle:
            fields = line.rstrip().split('\t')
            if fields == ['']:
                # Blank line: there is nothing to store for any column.
                continue
            for index, value in enumerate(fields):
                self.numbers[self.names[index]].append(value)

    def set_names(self, names):
        '''Declare the column names and write the header line.

        When the logger was resumed and a header was loaded from the file,
        this call is a no-op: writing a second header into the middle of
        the log and discarding the loaded history would corrupt it.

        Args:
            names: Sequence of column-name strings.
        '''
        if self.resume and self.names:
            return
        self.names = names
        self.numbers = {}
        for name in self.names:
            self.numbers[name] = []
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in self.names))
            self.file.write('\n')
            self.file.flush()

    def append(self, numbers):
        '''Record one row of values, one per column.

        Args:
            numbers: Sequence with exactly ``len(self.names)`` numbers.

        Raises:
            AssertionError: If the number of values differs from the number
                of columns.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so a value that cannot be formatted
        # leaves both the file and the in-memory history untouched.
        row = ''.join("{0:.6f}".format(num) + '\t' for num in numbers)
        if self.file is not None:
            self.file.write(row)
            self.file.write('\n')
            self.file.flush()
        for index, num in enumerate(numbers):
            self.numbers[self.names[index]].append(num)

    def plot(self, names=None):
        '''Plot the selected columns with ``plt.plot`` and add a legend.

        Args:
            names: Columns to draw; ``None`` draws every column.

        Raises:
            KeyError: If a requested column does not exist.
            ValueError: If a stored value cannot be parsed as a number.
        '''
        if names is None:
            names = self.names
        for name in names:
            # Values loaded from a resumed file are text; convert them so
            # they are plotted as numbers rather than as string labels.
            y = np.asarray([float(value) for value in self.numbers[name]])
            x = np.arange(len(y))
            plt.plot(x, y)
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the log file if one was opened.'''
        if self.file is not None:
            self.file.close()
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log listed in ``paths``.

        Args:
            paths: Dictionary of ``{name: filepath}`` pairs; each name
                becomes the title of the logger loaded from that file.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # The monitor only reads the loaded values, so release the
            # file handle the logger opened instead of holding it open
            # for the monitor's whole lifetime.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Draw the selected columns of every loaded log in a new figure.

        Args:
            names: Columns to draw from each log; ``None`` draws all of a
                log's columns.
        '''
        plt.figure()
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text)
        plt.grid(True)
if __name__ == '__main__':
    # Example: single logger
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])
    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    # paths = {
    # 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
    # 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
    # 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    # }

    log_path = './checkpoint/'
    log_file = os.path.join(log_path, 'log.txt')
    paths = {'': log_file}
    field = ['Valid Acc.', 'Train Acc.']

    # Figure 1: the two accuracy columns, drawn through the monitor.
    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig(os.path.join(log_path, 'test.eps'))

    # Figure 2: every column of the same log. The file handle is not needed
    # for plotting, so it is closed straight after loading.
    title = ''
    logger = Logger(log_file, title=title, resume=True)
    logger.close()
    # Logger.plot draws into the current figure; open a fresh one so the
    # monitor's curves from figure 1 do not end up in log.eps as well.
    plt.figure()
    logger.plot()
    savefig(os.path.join(log_path, 'log.eps'))