Commit 40a3987

[feat] Add the option to save a figure in plot setting params (#351)
* [feat] Add the option to save a figure in plot setting params. Since non-GUI environments need to avoid matplotlib's show method, this adds a savefig option so that users can complete the whole operation inside AutoPyTorch.
* [doc] Add a comment about non-GUI environments to the plot_perf_over_time method.
* [test] Add a test to check the priority of show and savefig. Because plt.savefig and plt.show do not work at the same time by matplotlib's design, we need to check that show is not called when a figname is specified. We could raise an error instead, but plotting typically happens at the end of an optimization, so rather than raising an error the behavior is only verified by tests.
1 parent 8f9e9f6 commit 40a3987
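
For reference, a minimal sketch of how the new option might be used, assuming an already-fitted AutoPyTorch task object named api (as in the bundled example); the metric name and figure name here are illustrative only. When figname is set, saving takes priority and plt.show() is not called:

    from autoPyTorch.utils.results_visualizer import PlotSettingParams

    params = PlotSettingParams(
        xlabel='Runtime',
        ylabel='Accuracy',
        figname='perf_over_time.png',              # when set, the figure is saved ...
        savefig_kwargs={'bbox_inches': 'tight'},   # extra kwargs forwarded to plt.savefig
        show=False,                                # ... and plt.show() is not called
    )
    api.plot_perf_over_time(metric_name='accuracy', plot_setting_params=params)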

File tree

5 files changed, +89 -51 lines


autoPyTorch/api/base_task.py (+3)

@@ -1513,6 +1513,9 @@ def plot_perf_over_time(
                 The settings of a pair of color and label for each plot.
             args, kwargs (Any):
                 Arguments for the ax.plot.
+
+        Note:
+            You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment.
         """
 
         if not hasattr(metrics, metric_name):

autoPyTorch/utils/results_visualizer.py (+36, -12)

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, NamedTuple, Optional, Tuple
 
 import matplotlib.pyplot as plt
 
@@ -71,8 +71,7 @@ def extract_dicts(
     return colors, labels
 
 
-@dataclass(frozen=True)
-class PlotSettingParams:
+class PlotSettingParams(NamedTuple):
     """
     Parameters for the plot environment.
 
@@ -93,12 +92,28 @@ class PlotSettingParams:
             The range of x axis.
         ylim (Tuple[float, float]):
             The range of y axis.
+        grid (bool):
+            Whether to have grid lines.
+            If users would like to define lines in detail,
+            they need to deactivate it.
         legend (bool):
             Whether to have legend in the figure.
-        legend_loc (str):
-            The location of the legend.
+        legend_kwargs (Dict[str, Any]):
+            The kwargs for ax.legend.
+            Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html
+        title (Optional[str]):
+            The title of the figure.
+        title_kwargs (Dict[str, Any]):
+            The kwargs for ax.set_title except title label.
+            Ref: https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.set_title.html
         show (bool):
             Whether to show the plot.
+            If figname is not None, the save will be prioritized.
+        figname (Optional[str]):
+            Name of a figure to save. If None, no figure will be saved.
+        savefig_kwargs (Dict[str, Any]):
+            The kwargs for plt.savefig except filename.
+            Ref: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html
         args, kwargs (Any):
             Arguments for the ax.plot.
     """
@@ -108,12 +123,16 @@ class PlotSettingParams:
     xlabel: Optional[str] = None
     ylabel: Optional[str] = None
    title: Optional[str] = None
+    title_kwargs: Dict[str, Any] = {}
     xlim: Optional[Tuple[float, float]] = None
     ylim: Optional[Tuple[float, float]] = None
+    grid: bool = True
     legend: bool = True
-    legend_loc: str = 'best'
+    legend_kwargs: Dict[str, Any] = {}
     show: bool = False
+    figname: Optional[str] = None
     figsize: Optional[Tuple[int, int]] = None
+    savefig_kwargs: Dict[str, Any] = {}
 
 
 class ScaleChoices(Enum):
@@ -201,17 +220,22 @@ def _set_plot_args(
 
         ax.set_xscale(plot_setting_params.xscale)
         ax.set_yscale(plot_setting_params.yscale)
-        if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log':
-            ax.grid(True, which='minor', color='gray', linestyle=':')
 
-        ax.grid(True, which='major', color='black')
+        if plot_setting_params.grid:
+            if plot_setting_params.xscale == 'log' or plot_setting_params.yscale == 'log':
+                ax.grid(True, which='minor', color='gray', linestyle=':')
+
+            ax.grid(True, which='major', color='black')
 
         if plot_setting_params.legend:
-            ax.legend(loc=plot_setting_params.legend_loc)
+            ax.legend(**plot_setting_params.legend_kwargs)
 
         if plot_setting_params.title is not None:
-            ax.set_title(plot_setting_params.title)
-        if plot_setting_params.show:
+            ax.set_title(plot_setting_params.title, **plot_setting_params.title_kwargs)
+
+        if plot_setting_params.figname is not None:
+            plt.savefig(plot_setting_params.figname, **plot_setting_params.savefig_kwargs)
+        elif plot_setting_params.show:
             plt.show()
 
     @staticmethod

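Since legend_loc is replaced by legend_kwargs, and ax.set_title now accepts title_kwargs, an existing setting migrates roughly as below; a sketch only, with illustrative kwargs values:

    from autoPyTorch.utils.results_visualizer import PlotSettingParams

    # Before this commit (no longer supported):
    #   PlotSettingParams(legend=True, legend_loc='upper right', title='Toy Example')

    # After this commit, matplotlib kwargs are passed through directly.
    params = PlotSettingParams(
        legend=True,
        legend_kwargs={'loc': 'upper right'},  # forwarded to ax.legend(**legend_kwargs)
        title='Toy Example',
        title_kwargs={'fontsize': 12},         # forwarded to ax.set_title(title, **title_kwargs)
        grid=False,                            # new flag: skip the default grid lines
    )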

examples/40_advanced/example_plot_over_time.py (+5, -6)

@@ -62,21 +62,20 @@
     xlabel='Runtime',
     ylabel='Accuracy',
     title='Toy Example',
-    show=False  # If you would like to show, make it True
+    figname='example_plot_over_time.png',
+    savefig_kwargs={'bbox_inches': 'tight'},
+    show=False  # If you would like to show, make it True and set figname=None
 )
 
 ############################################################################
 # Plot with the Specified Setting Parameters
 # ==========================================
-_, ax = plt.subplots()
+# _, ax = plt.subplots() <=== You can feed it to post-process the figure.
 
+# You might need to run `export DISPLAY=:0.0` if you are using non-GUI based environment.
 api.plot_perf_over_time(
-    ax=ax,  # You do not have to provide.
     metric_name=metric_name,
     plot_setting_params=params,
     marker='*',
     markersize=10
 )
-
-# plt.show() might cause issue depending on environments
-plt.savefig('example_plot_over_time.png')
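
As a side note, not part of this commit: when exporting DISPLAY is not an option, a common matplotlib workaround for headless runs is to select a non-interactive backend before importing pyplot, so that only saving is attempted:

    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: savefig works, no window is opened
    import matplotlib.pyplot as plt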

test/test_utils/test_results_manager.py (+8, -22)

@@ -165,11 +165,9 @@ def test_extract_results_from_run_history():
         time=1.0,
         status=StatusType.CAPPED,
     )
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError):
         SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history)
 
-    assert excinfo._excinfo[0] == ValueError
-
 
 def test_raise_error_in_update_and_sort_by_time():
     cs = ConfigurationSpace()
@@ -179,7 +177,7 @@ def test_raise_error_in_update_and_sort_by_time():
     sr = SearchResults(metric=accuracy, scoring_functions=[], run_history=RunHistory())
     er = EnsembleResults(metric=accuracy, ensemble_performance_history=[])
 
-    with pytest.raises(RuntimeError) as excinfo:
+    with pytest.raises(RuntimeError):
         sr._update(
             config=config,
             run_key=RunKey(config_id=0, instance_id=0, seed=0),
@@ -189,19 +187,13 @@ def test_raise_error_in_update_and_sort_by_time():
             )
         )
 
-    assert excinfo._excinfo[0] == RuntimeError
-
-    with pytest.raises(RuntimeError) as excinfo:
+    with pytest.raises(RuntimeError):
         sr._sort_by_endtime()
 
-    assert excinfo._excinfo[0] == RuntimeError
-
-    with pytest.raises(RuntimeError) as excinfo:
+    with pytest.raises(RuntimeError):
         er._update(data={})
 
-    assert excinfo._excinfo[0] == RuntimeError
-
-    with pytest.raises(RuntimeError) as excinfo:
+    with pytest.raises(RuntimeError):
         er._sort_by_endtime()
 
 
@@ -244,11 +236,9 @@ def test_raise_error_in_get_start_time():
         status=StatusType.CAPPED,
     )
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError):
         get_start_time(run_history)
 
-    assert excinfo._excinfo[0] == ValueError
-
 
 def test_search_results_sort_by_endtime():
     run_history = RunHistory()
@@ -364,11 +354,9 @@ def test_metric_results(metric, scores, ensemble_ends_later):
 def test_search_results_sprint_statistics():
     api = BaseTask()
     for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']:
-        with pytest.raises(RuntimeError) as excinfo:
+        with pytest.raises(RuntimeError):
             getattr(api, method)()
 
-        assert excinfo._excinfo[0] == RuntimeError
-
     run_history_data = json.load(open(os.path.join(os.path.dirname(__file__),
                                                    'runhistory.json'),
                                  mode='r'))['data']
@@ -420,11 +408,9 @@ def test_check_run_history(run_history):
     manager = ResultsManager()
     manager.run_history = run_history
 
-    with pytest.raises(RuntimeError) as excinfo:
+    with pytest.raises(RuntimeError):
         manager._check_run_history()
 
-    assert excinfo._excinfo[0] == RuntimeError
-
 
 @pytest.mark.parametrize('include_traditional', (True, False))
 @pytest.mark.parametrize('metric', (accuracy, log_loss))

test/test_utils/test_results_visualizer.py (+37, -11)

@@ -55,15 +55,46 @@ def test_extract_dicts(cl_settings, with_ensemble):
 
 @pytest.mark.parametrize('params', (
     PlotSettingParams(show=True),
-    PlotSettingParams(show=False)
+    PlotSettingParams(show=False),
+    PlotSettingParams(show=True, figname='dummy')
 ))
 def test_plt_show_in_set_plot_args(params):  # TODO
     plt.show = MagicMock()
+    plt.savefig = MagicMock()
     _, ax = plt.subplots(nrows=1, ncols=1)
     viz = ResultsVisualizer()
 
     viz._set_plot_args(ax, params)
-    assert plt.show._mock_called == params.show
+    # if figname is not None, show will not be called. (due to the matplotlib design)
+    assert plt.show._mock_called == (params.figname is None and params.show)
+    plt.close()
+
+
+@pytest.mark.parametrize('params', (
+    PlotSettingParams(),
+    PlotSettingParams(figname='fig')
+))
+def test_plt_savefig_in_set_plot_args(params):  # TODO
+    plt.savefig = MagicMock()
+    _, ax = plt.subplots(nrows=1, ncols=1)
+    viz = ResultsVisualizer()
+
+    viz._set_plot_args(ax, params)
+    assert plt.savefig._mock_called == (params.figname is not None)
+    plt.close()
+
+
+@pytest.mark.parametrize('params', (
+    PlotSettingParams(grid=True),
+    PlotSettingParams(grid=False)
+))
+def test_ax_grid_in_set_plot_args(params):  # TODO
+    _, ax = plt.subplots(nrows=1, ncols=1)
+    ax.grid = MagicMock()
+    viz = ResultsVisualizer()
+
+    viz._set_plot_args(ax, params)
+    assert ax.grid._mock_called == params.grid
     plt.close()
 
 
@@ -77,10 +108,9 @@ def test_raise_value_error_in_set_plot_args(params):  # TODO
     _, ax = plt.subplots(nrows=1, ncols=1)
     viz = ResultsVisualizer()
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError):
         viz._set_plot_args(ax, params)
 
-    assert excinfo._excinfo[0] == ValueError
     plt.close()
 
 
@@ -119,13 +149,11 @@ def test_raise_error_in_plot_perf_over_time_in_base_task(metric_name):
     api = BaseTask()
 
     if metric_name == 'unknown':
-        with pytest.raises(ValueError) as excinfo:
+        with pytest.raises(ValueError):
             api.plot_perf_over_time(metric_name)
-        assert excinfo._excinfo[0] == ValueError
     else:
-        with pytest.raises(RuntimeError) as excinfo:
+        with pytest.raises(RuntimeError):
             api.plot_perf_over_time(metric_name)
-        assert excinfo._excinfo[0] == RuntimeError
 
 
 @pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy'))
@@ -175,16 +203,14 @@ def test_raise_error_get_perf_and_time(params):
     results = np.linspace(-1, 1, 10)
     cum_times = np.linspace(0, 1, 10)
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError):
         _get_perf_and_time(
             cum_results=results,
             cum_times=cum_times,
             plot_setting_params=params,
            worst_val=np.inf
         )
 
-    assert excinfo._excinfo[0] == ValueError
-
 
 @pytest.mark.parametrize('params', (
     PlotSettingParams(n_points=20, xscale='linear', yscale='linear'),
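
A small side note on the mock checks above: _mock_called is a private MagicMock attribute; the equivalent public API is the called flag (or assert_called_once()). A tiny sketch of the same check in isolation:

    from unittest.mock import MagicMock

    show = MagicMock()
    show()  # simulate one call to the patched plt.show

    assert show.called         # public counterpart of the private _mock_called flag
    show.assert_called_once()  # stricter: exactly one call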
