From 9b890caea0f908dd688bd3dda79eea0e69715054 Mon Sep 17 00:00:00 2001 From: Rishabh261998 Date: Sat, 6 Mar 2021 01:16:46 +0530 Subject: [PATCH 1/4] Added interactive legend to ppc_plot bokeh --- arviz/plots/backends/bokeh/ppcplot.py | 62 +++++++++++++++++++++------ 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index 11fe625992..8eeb5db638 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -4,6 +4,7 @@ from ....stats.density_utils import get_bins, histogram, kde from ...kdeplot import plot_kde from ...plot_utils import _scale_fig_size, vectorized_to_hex +from bokeh.models.annotations import Legend from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -86,6 +87,7 @@ def plot_ppc( var_name, sel, isel, obs_vals = obs_plotters[i] pp_var_name, _, _, pp_vals = pp_plotters[i] dtype = predictive_dataset[pp_var_name].dtype.kind + legend_it = [] # flatten non-specified dimensions obs_vals = obs_vals.flatten() @@ -111,14 +113,19 @@ def plot_ppc( pp_xs.append(bin_edges) if dtype == "f": - ax_i.multi_line(pp_xs, pp_densities, **plot_kwargs) + multi_line = ax_i.multi_line(pp_xs, pp_densities, **plot_kwargs) + legend_it.append(("{} predictive".format(group.capitalize()), [multi_line])) else: + all_steps = [] for x_s, y_s in zip(pp_xs, pp_densities): - ax_i.step(x_s, y_s, **plot_kwargs) + step = ax_i.step(x_s, y_s, **plot_kwargs) + all_steps.append(step) + legend_it.append(("{} predictive".format(group.capitalize()), all_steps)) if observed: + label = "Observed" if dtype == "f": - plot_kde( + _, glyph = plot_kde( obs_vals, plot_kwargs={"line_color": "black", "line_width": linewidth}, fill_kwargs={"alpha": 0}, @@ -126,20 +133,24 @@ def plot_ppc( backend="bokeh", backend_kwargs={}, show=False, + return_glyph=True, ) + legend_it.append((label, glyph)) else: bins = get_bins(obs_vals) _, hist, bin_edges = 
histogram(obs_vals, bins=bins) hist = np.concatenate((hist[:1], hist)) - ax_i.step( + step = ax_i.step( bin_edges, hist, line_color="black", line_width=linewidth, mode="center", ) + legend_it.append((label, [step])) if mean: + label = "{} predictive mean".format(group.capitalize()) if dtype == "f": rep = len(pp_densities) len_density = len(pp_densities[0]) @@ -150,19 +161,20 @@ def plot_ppc( new_x -= (new_x[1] - new_x[0]) / 2 for irep in range(rep): new_d[irep][bins[irep]] = pp_densities[irep] - ax_i.line( + line = ax_i.line( new_x, new_d.mean(0), color=color, line_dash="dashed", line_width=linewidth, ) + legend_it.append((label, [line])) else: vals = pp_vals.flatten() bins = get_bins(vals) _, hist, bin_edges = histogram(vals, bins=bins) hist = np.concatenate((hist[:1], hist)) - ax_i.step( + step = ax_i.step( bin_edges, hist, line_color=color, @@ -170,12 +182,14 @@ def plot_ppc( line_dash="dashed", mode="center", ) + legend_it.append((label, [step])) ax_i.yaxis.major_tick_line_color = None ax_i.yaxis.minor_tick_line_color = None ax_i.yaxis.major_label_text_font_size = "0pt" elif kind == "cumulative": if observed: + label = "Observed" if dtype == "f": glyph = ax_i.line( *_empirical_cdf(obs_vals), @@ -183,39 +197,45 @@ def plot_ppc( line_width=linewidth, ) glyph.level = "overlay" + legend_it.append((label, [glyph])) else: - ax_i.step( + step = ax_i.step( *_empirical_cdf(obs_vals), line_color="black", line_width=linewidth, mode="center", ) + legend_it.append((label, [step])) pp_densities = np.empty((2 * len(pp_sampled_vals), pp_sampled_vals[0].size)) for idx, vals in enumerate(pp_sampled_vals): vals = np.array([vals]).flatten() pp_x, pp_density = _empirical_cdf(vals) pp_densities[2 * idx] = pp_x pp_densities[2 * idx + 1] = pp_density - ax_i.multi_line( + multi_line = ax_i.multi_line( list(pp_densities[::2]), list(pp_densities[1::2]), line_alpha=alpha, line_color=color, line_width=linewidth, ) + legend_it.append(("{} predictive".format(group.capitalize()), 
[multi_line])) if mean: - ax_i.line( + label = "{} predictive mean".format(group.capitalize()) + line = ax_i.line( *_empirical_cdf(pp_vals.flatten()), color=color, line_dash="dashed", line_width=linewidth, ) + legend_it.append((label, [line])) elif kind == "scatter": if mean: + label = "{} predictive mean".format(group.capitalize()) if dtype == "f": - plot_kde( + _, glyph = plot_kde( pp_vals.flatten(), plot_kwargs={ "line_color": color, @@ -226,13 +246,15 @@ def plot_ppc( backend="bokeh", backend_kwargs={}, show=False, + return_glyph=True, ) + legend_it.append((label, glyph)) else: vals = pp_vals.flatten() bins = get_bins(vals) _, hist, bin_edges = histogram(vals, bins=bins) hist = np.concatenate((hist[:1], hist)) - ax_i.step( + step = ax_i.step( bin_edges, hist, color=color, @@ -240,6 +262,7 @@ def plot_ppc( line_dash="dashed", mode="center", ) + legend_it.append((label, [step])) jitter_scale = 0.1 y_rows = np.linspace(0, 0.1, num_pp_samples + 1) @@ -247,6 +270,7 @@ def plot_ppc( scale_high = jitter_scale * jitter if observed: + label = "Observed" obs_yvals = np.zeros_like(obs_vals, dtype=np.float64) if jitter: obs_yvals += np.random.uniform( @@ -260,18 +284,32 @@ def plot_ppc( line_alpha=alpha, ) glyph.level = "overlay" + legend_it.append((label, [glyph])) + all_scatter = [] for vals, y in zip(pp_sampled_vals, y_rows[1:]): vals = np.ravel(vals) yvals = np.full_like(vals, y, dtype=np.float64) if jitter: yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals)) - ax_i.scatter(vals, yvals, fill_color=color, size=markersize, fill_alpha=alpha) + scatter = ax_i.scatter( + vals, yvals, fill_color=color, size=markersize, fill_alpha=alpha + ) + all_scatter.append(scatter) + legend_it.append(("{} predictive".format(group.capitalize()), all_scatter)) ax_i.yaxis.major_tick_line_color = None ax_i.yaxis.minor_tick_line_color = None ax_i.yaxis.major_label_text_font_size = "0pt" + if legend: + legend = Legend( + items=legend_it, + location="center_right", + 
orientation="horizontal", + ) + ax_i.add_layout(legend, "above") + ax_i.legend.click_policy = "hide" ax_i.xaxis.axis_label = labeller.make_pp_label(var_name, pp_var_name, sel, isel) show_layout(axes, show) From 6ba3885b10eaf2c835b392a43712bb4b185b05e7 Mon Sep 17 00:00:00 2001 From: Rishabh261998 Date: Sat, 6 Mar 2021 01:40:54 +0530 Subject: [PATCH 2/4] Move bokeh Legend import next to the other third-party imports --- arviz/plots/backends/bokeh/ppcplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index 8eeb5db638..10306e4605 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -1,10 +1,10 @@ """Bokeh Posterior predictive plot.""" import numpy as np +from bokeh.models.annotations import Legend from ....stats.density_utils import get_bins, histogram, kde from ...kdeplot import plot_kde from ...plot_utils import _scale_fig_size, vectorized_to_hex -from bokeh.models.annotations import Legend from .. import show_layout from . 
import backend_kwarg_defaults, create_axes_grid From 5a587794ea3687a26c322a5fc936908787ef0700 Mon Sep 17 00:00:00 2001 From: Rishabh261998 Date: Sat, 6 Mar 2021 15:28:46 +0530 Subject: [PATCH 3/4] Updated CHANGELOG and positioned legend --- CHANGELOG.md | 1 + arviz/plots/backends/bokeh/ppcplot.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a32913cc48..4c43c62a6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Added `rope_color` and `ref_val_color` arguments to `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570)) * Improved retrieving or pointwise log likelihood in `from_cmdstanpy`, `from_cmdstan` and `from_pystan` ([1579](https://github.com/arviz-devs/arviz/pull/1579) and [1599](https://github.com/arviz-devs/arviz/pull/1599)) * Added interactive legend to bokeh `forestplot` ([1591](https://github.com/arviz-devs/arviz/pull/1591)) +* Added interactive legend to bokeh `ppcplot` ([1602](https://github.com/arviz-devs/arviz/pull/1602)) ### Maintenance and fixes * Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201)) diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index 10306e4605..c4443092e3 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -305,10 +305,12 @@ def plot_ppc( if legend: legend = Legend( items=legend_it, - location="center_right", - orientation="horizontal", + location="top_left", + orientation="vertical", ) - ax_i.add_layout(legend, "above") + ax_i.add_layout(legend) + if textsize is not None: + ax_i.legend.label_text_font_size = str(textsize) + "pt" ax_i.legend.click_policy = "hide" ax_i.xaxis.axis_label = labeller.make_pp_label(var_name, pp_var_name, sel, isel) From f19b925d841b913b2d4034bd59afe513c4deeed7 Mon Sep 17 00:00:00 2001 From: Rishabh261998 Date: Tue, 9 Mar 2021 22:17:08 +0530 Subject: [PATCH 4/4] Added 
test for textsize in ppc_plot bokeh --- arviz/plots/backends/bokeh/ppcplot.py | 2 +- arviz/tests/base_tests/test_plots_bokeh.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index c4443092e3..44e3e3a94c 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -310,7 +310,7 @@ def plot_ppc( ) ax_i.add_layout(legend) if textsize is not None: - ax_i.legend.label_text_font_size = str(textsize) + "pt" + ax_i.legend.label_text_font_size = f"{textsize}pt" ax_i.legend.click_policy = "hide" ax_i.xaxis.axis_label = labeller.make_pp_label(var_name, pp_var_name, sel, isel) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 9930761f7a..396c804f32 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -890,6 +890,17 @@ def test_plot_ppc(models, kind, alpha, observed): assert axes +def test_plot_ppc_textsize(models): + axes = plot_ppc( + models.model_1, + textsize=10, + random_seed=3, + backend="bokeh", + show=False, + ) + assert axes + + @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"]) @pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3]) def test_plot_ppc_multichain(kind, jitter):