# utils.py
# Forked from JonathanCollu/PolicyBased_DeepRL.
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt
class LearningCurvePlot:
    ''' Small wrapper around one matplotlib figure used to draw reward curves.

    The constructor creates a figure with a single axes whose x axis is
    labelled "Epoch" and whose y axis is labelled "Reward". Curves and
    horizontal reference lines are added one at a time, and `save` writes
    the figure to disk.
    '''

    def __init__(self, title=None):
        ''' title: optional string used as the axes title (skipped if None) '''
        self.fig, self.ax = plt.subplots()
        self.ax.set_xlabel("Epoch")
        self.ax.set_ylabel("Reward")
        if title is not None:
            self.ax.set_title(title)

    def add_curve(self, y, label=None):
        ''' y: vector of average reward results
        label: string to appear as label in plot legend (no entry if None)
        '''
        # Forward `label` only when one was supplied, so an unlabelled curve
        # is plotted exactly as a bare `plot(y)` call would plot it.
        plot_kwargs = {} if label is None else {"label": label}
        self.ax.plot(y, **plot_kwargs)

    def set_ylim(self, lower, upper):
        ''' lower, upper: bounds of the visible y range '''
        self.ax.set_ylim([lower, upper])

    def add_hline(self, height, label=None):
        ''' height: y value at which a dashed black horizontal line is drawn
        label: string to appear as label in plot legend (optional)
        '''
        self.ax.axhline(height, ls="--", c="k", label=label)

    def save(self, name="test.png"):
        ''' name: string for filename of saved figure '''
        # Drawing a legend with no labelled artists makes matplotlib emit a
        # warning, so only add one when there is something to show.
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()
        self.fig.savefig(name, dpi=300)
def smooth(y, window, poly=1):
    ''' Smooth a curve with a Savitzky-Golay filter.

    y: vector to be smoothed (filtering runs along the last axis)
    window: size of the smoothing window
    poly: order of the polynomial fitted inside each window

    Returns the filtered values as a numpy array with the same shape as y.
    If the data is shorter than `window`, the window is shrunk to the
    largest odd length that fits and `poly` is lowered to stay below it,
    so short curves are smoothed as far as possible instead of raising.
    Empty input is returned as an empty array.
    '''
    # Local import keeps this function self-contained; numpy is always
    # present because scipy depends on it.
    import numpy as np

    samples = np.asarray(y)
    if samples.size == 0:
        return samples.astype(float)

    # Length along the axis savgol_filter works on; 0-d input is left for
    # savgol_filter to reject as before.
    n = samples.shape[-1] if samples.ndim else 0
    if 0 < n < window:
        # An odd window is accepted by every scipy version.
        window = n if n % 2 else n - 1
        # savgol_filter requires polyorder < window_length.
        poly = min(poly, window - 1)
    return savgol_filter(samples, window, poly)