From f23be6bbec6dd0b0882c0cc5380de771b0a2aac3 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 11 Nov 2022 13:03:37 -0300 Subject: [PATCH 1/2] add observed_rug argument to plot_ppc --- arviz/plots/backends/bokeh/ppcplot.py | 28 ++++++++++++++++++++++ arviz/plots/backends/matplotlib/ppcplot.py | 10 ++++++++ arviz/plots/ppcplot.py | 5 ++++ 3 files changed, 43 insertions(+) diff --git a/arviz/plots/backends/bokeh/ppcplot.py b/arviz/plots/backends/bokeh/ppcplot.py index 1c2eb85177..9534d291f7 100644 --- a/arviz/plots/backends/bokeh/ppcplot.py +++ b/arviz/plots/backends/bokeh/ppcplot.py @@ -1,11 +1,15 @@ """Bokeh Posterior predictive plot.""" import numpy as np from bokeh.models.annotations import Legend +from bokeh.models.glyphs import Scatter +from bokeh.models import ColumnDataSource + from ....stats.density_utils import get_bins, histogram, kde from ...kdeplot import plot_kde from ...plot_utils import _scale_fig_size, vectorized_to_hex + from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid @@ -27,6 +31,7 @@ def plot_ppc( textsize, mean, observed, + observed_rug, jitter, total_pp_samples, legend, # pylint: disable=unused-argument @@ -97,6 +102,7 @@ def plot_ppc( obs_vals = obs_vals.flatten() pp_vals = pp_vals.reshape(total_pp_samples, -1) pp_sampled_vals = pp_vals[pp_sample_ix] + cds_rug = ColumnDataSource({"_": np.array(obs_vals)}) if kind == "kde": plot_kwargs = { @@ -144,6 +150,16 @@ def plot_ppc( return_glyph=True, ) legend_it.append((label, glyph)) + if observed_rug: + glyph = Scatter( + x="_", + y=0.0, + marker="dash", + angle=np.pi / 2, + line_color=colors[1], + line_width=linewidth, + ) + ax_i.add_glyph(cds_rug, glyph) else: bins = get_bins(obs_vals) _, hist, bin_edges = histogram(obs_vals, bins=bins) @@ -215,6 +231,18 @@ def plot_ppc( mode="center", ) legend_it.append((label, [step])) + + if observed_rug: + glyph = Scatter( + x="_", + y=0.0, + marker="dash", + angle=np.pi / 2, + line_color=colors[1], + line_width=linewidth, 
+ ) + ax_i.add_glyph(cds_rug, glyph) + pp_densities = np.empty((2 * len(pp_sampled_vals), pp_sampled_vals[0].size)) for idx, vals in enumerate(pp_sampled_vals): vals = np.array([vals]).flatten() diff --git a/arviz/plots/backends/matplotlib/ppcplot.py b/arviz/plots/backends/matplotlib/ppcplot.py index e1a919a600..72604776db 100644 --- a/arviz/plots/backends/matplotlib/ppcplot.py +++ b/arviz/plots/backends/matplotlib/ppcplot.py @@ -31,6 +31,7 @@ def plot_ppc( textsize, mean, observed, + observed_rug, jitter, total_pp_samples, legend, @@ -135,6 +136,7 @@ def plot_ppc( if dtype == "f": plot_kde( obs_vals, + rug=observed_rug, label="Observed", plot_kwargs={"color": colors[1], "linewidth": linewidth, "zorder": 3}, fill_kwargs={"alpha": 0}, @@ -232,6 +234,14 @@ def plot_ppc( drawstyle=drawstyle, zorder=3, ) + if observed_rug: + ax_i.plot( + obs_vals, + np.zeros_like(obs_vals) - 0.1, + ls="", + marker="|", + color=colors[1], + ) if animated: animate, init = _set_animation( pp_sampled_vals, diff --git a/arviz/plots/ppcplot.py b/arviz/plots/ppcplot.py index ae1f25d9a5..4fbe2d9dbe 100644 --- a/arviz/plots/ppcplot.py +++ b/arviz/plots/ppcplot.py @@ -20,6 +20,7 @@ def plot_ppc( alpha=None, mean=True, observed=True, + observed_rug=False, color=None, colors=None, grid=None, @@ -62,6 +63,9 @@ def plot_ppc( Defaults to ``True``. observed: bool, default True Whether or not to plot the observed data. + observed_rug: bool, default False + Whether or not to plot a rug plot for the observed data. Only valid if `observed` is + `True` and for kind `kde` or `cumulative`. color: str Valid matplotlib ``color``. Defaults to ``C0``. 
color: list @@ -339,6 +343,7 @@ def plot_ppc( textsize=textsize, mean=mean, observed=observed, + observed_rug=observed_rug, total_pp_samples=total_pp_samples, legend=legend, labeller=labeller, From d7873365495ba21ab43bc2f5295b32e1475efbba Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 11 Nov 2022 13:53:57 -0300 Subject: [PATCH 2/2] add tests and update changelog --- CHANGELOG.md | 2 +- arviz/tests/base_tests/test_plots_bokeh.py | 4 +++- arviz/tests/base_tests/test_plots_matplotlib.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8587da1923..072f8c00ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### New features - Adds Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152)) - +- Adds rug plot for observed variables to `plot_ppc`. ([2161](https://github.com/arviz-devs/arviz/pull/2161)) ### Maintenance and fixes - Fix dimension ordering for `plot_trace` with divergences ([2151](https://github.com/arviz-devs/arviz/pull/2151)) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 7557ea6584..46dced9def 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -878,12 +878,14 @@ def test_plot_violin_discrete(discrete_model): @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"]) @pytest.mark.parametrize("alpha", [None, 0.2, 1]) @pytest.mark.parametrize("observed", [True, False]) -def test_plot_ppc(models, kind, alpha, observed): +@pytest.mark.parametrize("observed_rug", [False, True]) +def test_plot_ppc(models, kind, alpha, observed, observed_rug): axes = plot_ppc( models.model_1, kind=kind, alpha=alpha, observed=observed, + observed_rug=observed_rug, random_seed=3, backend="bokeh", show=False, diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py 
b/arviz/tests/base_tests/test_plots_matplotlib.py index e2f66dd7d0..439ee5a5f0 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -711,7 +711,8 @@ def test_plot_pair_shared(sharex, sharey, marginals): @pytest.mark.parametrize("alpha", [None, 0.2, 1]) @pytest.mark.parametrize("animated", [False, True]) @pytest.mark.parametrize("observed", [True, False]) -def test_plot_ppc(models, kind, alpha, animated, observed): +@pytest.mark.parametrize("observed_rug", [False, True]) +def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug): if animation and not animation.writers.is_available("ffmpeg"): pytest.skip("matplotlib animations within ArviZ require ffmpeg") animation_kwargs = {"blit": False} @@ -720,6 +721,7 @@ def test_plot_ppc(models, kind, alpha, animated, observed): kind=kind, alpha=alpha, observed=observed, + observed_rug=observed_rug, animated=animated, animation_kwargs=animation_kwargs, random_seed=3,