diff --git a/surfinpy/chemical_potential_plot.py b/surfinpy/chemical_potential_plot.py index bca4d53..d1355cc 100644 --- a/surfinpy/chemical_potential_plot.py +++ b/surfinpy/chemical_potential_plot.py @@ -31,26 +31,31 @@ def __init__(self, x, y, z, labels, ticks, xlabel, ylabel): self.xlabel = xlabel self.ylabel = ylabel - def plot_phase(self, temperature=None, xlabel=None, ylabel=None, output="phase.png", colourmap="viridis", - set_style="default", figsize=None): + def plot_phase(self, temperature=None, xlabel=None, ylabel=None, output="phase.png", + colourmap="viridis", set_style="default", figsize=None, show_fig=True): """Plots a simple phase diagram as a function of chemical potential. Parameters ---------- - temperature : int (optional) + temperature: int (optional) Temperature. Default=None xlabel: str (optional) Set a custom x-axis label. Default=None ylabel: str (optional) Set a custom y-axis label. Default=None - output : str (optional) + output: str (optional) Output filename. Default='phase.png' - colourmap : str (optional) + If output is set to `None`, no output is generated. + colourmap: str (optional) Colourmap for the plot. Default='viridis' - figsize : tuple (optional) + figsize: tuple (optional) Set a custom figure size. Default=None + show_fig: bool (optional) + Automatically display a figure. Default=True + If set to False the plot is returned as an object. """ - plt.style.use(set_style) + if set_style: + plt.style.use(set_style) levels = ut.get_levels(self.z) ticky = ut.get_ticks(self.ticks) temperature_label = str(temperature) + " K" @@ -75,8 +80,12 @@ def plot_phase(self, temperature=None, xlabel=None, ylabel=None, output="phase.p cbar = fig.colorbar(CM, ticks=ticky, pad=0.1) cbar.ax.set_yticklabels(self.labels) plt.tight_layout() - plt.savefig(output, dpi=600) - plt.show() + if output: + plt.savefig(output, dpi=600) + if show_fig: + plt.show() + else: + return ax def plot_mu_p(self, temperature, output="phase.png", colourmap="viridis", set_style="default"): @@ -92,7 +101,8 @@ def plot_mu_p(self, temperature, output="phase.png", colourmap="viridis", colourmap : str colourmap for the plot """ - plt.style.use(set_style) + if set_style: + plt.style.use(set_style) p1 = ut.pressure(self.x, temperature) p2 = ut.pressure(self.y, temperature) temperature_label = str(temperature) + " K" @@ -140,7 +150,8 @@ def plot_pressure(self, temperature, output="phase.png", colourmap="viridis", colourmap : str colourmap for the plot """ - plt.style.use(set_style) + if set_style: + plt.style.use(set_style) p1 = ut.pressure(self.x, temperature) p2 = ut.pressure(self.y, temperature) temperature_label = str(temperature) + " K" diff --git a/surfinpy/pvt_plot.py b/surfinpy/pvt_plot.py index 98bba67..c5463ac 100644 --- a/surfinpy/pvt_plot.py +++ b/surfinpy/pvt_plot.py @@ -30,7 +30,8 @@ def plot(self, output="Phase.png", colourmap="RdBu", set_style="ggplot", atmospheric_conditions : list location of bars showing atmospheric conditions """ - plt.style.use(set_style) + if set_style: + plt.style.use(set_style) fig = plt.figure() ax = fig.add_subplot(111) ax.contourf(self.x, self.y, self.z, cmap=colourmap)