Skip to content

Commit

Permalink
Added Interactive legend to ppc_plot bokeh (#1602)
Browse files Browse the repository at this point in the history
* Added interactive legend to ppc_plot bokeh

* Minor Change

* Updated CHANGELOG and positioned legend

* Added test for textsize in ppc_plot bokeh
  • Loading branch information
Rishabh261998 authored Mar 10, 2021
1 parent c4005b1 commit 5c0581f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Added `rope_color` and `ref_val_color` arguments to `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570))
* Improved retrieving of pointwise log likelihood in `from_cmdstanpy`, `from_cmdstan` and `from_pystan` ([1579](https://github.com/arviz-devs/arviz/pull/1579) and [1599](https://github.com/arviz-devs/arviz/pull/1599))
* Added interactive legend to bokeh `forestplot` ([1591](https://github.com/arviz-devs/arviz/pull/1591))
* Added interactive legend to bokeh `ppcplot` ([1602](https://github.com/arviz-devs/arviz/pull/1602))

### Maintenance and fixes
* Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
64 changes: 52 additions & 12 deletions arviz/plots/backends/bokeh/ppcplot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Bokeh Posterior predictive plot."""
import numpy as np
from bokeh.models.annotations import Legend

from ....stats.density_utils import get_bins, histogram, kde
from ...kdeplot import plot_kde
Expand Down Expand Up @@ -86,6 +87,7 @@ def plot_ppc(
var_name, sel, isel, obs_vals = obs_plotters[i]
pp_var_name, _, _, pp_vals = pp_plotters[i]
dtype = predictive_dataset[pp_var_name].dtype.kind
legend_it = []

# flatten non-specified dimensions
obs_vals = obs_vals.flatten()
Expand All @@ -111,35 +113,44 @@ def plot_ppc(
pp_xs.append(bin_edges)

if dtype == "f":
ax_i.multi_line(pp_xs, pp_densities, **plot_kwargs)
multi_line = ax_i.multi_line(pp_xs, pp_densities, **plot_kwargs)
legend_it.append(("{} predictive".format(group.capitalize()), [multi_line]))
else:
all_steps = []
for x_s, y_s in zip(pp_xs, pp_densities):
ax_i.step(x_s, y_s, **plot_kwargs)
step = ax_i.step(x_s, y_s, **plot_kwargs)
all_steps.append(step)
legend_it.append(("{} predictive".format(group.capitalize()), all_steps))

if observed:
label = "Observed"
if dtype == "f":
plot_kde(
_, glyph = plot_kde(
obs_vals,
plot_kwargs={"line_color": "black", "line_width": linewidth},
fill_kwargs={"alpha": 0},
ax=ax_i,
backend="bokeh",
backend_kwargs={},
show=False,
return_glyph=True,
)
legend_it.append((label, glyph))
else:
bins = get_bins(obs_vals)
_, hist, bin_edges = histogram(obs_vals, bins=bins)
hist = np.concatenate((hist[:1], hist))
ax_i.step(
step = ax_i.step(
bin_edges,
hist,
line_color="black",
line_width=linewidth,
mode="center",
)
legend_it.append((label, [step]))

if mean:
label = "{} predictive mean".format(group.capitalize())
if dtype == "f":
rep = len(pp_densities)
len_density = len(pp_densities[0])
Expand All @@ -150,72 +161,81 @@ def plot_ppc(
new_x -= (new_x[1] - new_x[0]) / 2
for irep in range(rep):
new_d[irep][bins[irep]] = pp_densities[irep]
ax_i.line(
line = ax_i.line(
new_x,
new_d.mean(0),
color=color,
line_dash="dashed",
line_width=linewidth,
)
legend_it.append((label, [line]))
else:
vals = pp_vals.flatten()
bins = get_bins(vals)
_, hist, bin_edges = histogram(vals, bins=bins)
hist = np.concatenate((hist[:1], hist))
ax_i.step(
step = ax_i.step(
bin_edges,
hist,
line_color=color,
line_width=linewidth,
line_dash="dashed",
mode="center",
)
legend_it.append((label, [step]))
ax_i.yaxis.major_tick_line_color = None
ax_i.yaxis.minor_tick_line_color = None
ax_i.yaxis.major_label_text_font_size = "0pt"

elif kind == "cumulative":
if observed:
label = "Observed"
if dtype == "f":
glyph = ax_i.line(
*_empirical_cdf(obs_vals),
line_color="black",
line_width=linewidth,
)
glyph.level = "overlay"
legend_it.append((label, [glyph]))

else:
ax_i.step(
step = ax_i.step(
*_empirical_cdf(obs_vals),
line_color="black",
line_width=linewidth,
mode="center",
)
legend_it.append((label, [step]))
pp_densities = np.empty((2 * len(pp_sampled_vals), pp_sampled_vals[0].size))
for idx, vals in enumerate(pp_sampled_vals):
vals = np.array([vals]).flatten()
pp_x, pp_density = _empirical_cdf(vals)
pp_densities[2 * idx] = pp_x
pp_densities[2 * idx + 1] = pp_density
ax_i.multi_line(
multi_line = ax_i.multi_line(
list(pp_densities[::2]),
list(pp_densities[1::2]),
line_alpha=alpha,
line_color=color,
line_width=linewidth,
)
legend_it.append(("{} predictive".format(group.capitalize()), [multi_line]))
if mean:
ax_i.line(
label = "{} predictive mean".format(group.capitalize())
line = ax_i.line(
*_empirical_cdf(pp_vals.flatten()),
color=color,
line_dash="dashed",
line_width=linewidth,
)
legend_it.append((label, [line]))

elif kind == "scatter":
if mean:
label = "{} predictive mean".format(group.capitalize())
if dtype == "f":
plot_kde(
_, glyph = plot_kde(
pp_vals.flatten(),
plot_kwargs={
"line_color": color,
Expand All @@ -226,27 +246,31 @@ def plot_ppc(
backend="bokeh",
backend_kwargs={},
show=False,
return_glyph=True,
)
legend_it.append((label, glyph))
else:
vals = pp_vals.flatten()
bins = get_bins(vals)
_, hist, bin_edges = histogram(vals, bins=bins)
hist = np.concatenate((hist[:1], hist))
ax_i.step(
step = ax_i.step(
bin_edges,
hist,
color=color,
line_width=linewidth,
line_dash="dashed",
mode="center",
)
legend_it.append((label, [step]))

jitter_scale = 0.1
y_rows = np.linspace(0, 0.1, num_pp_samples + 1)
scale_low = 0
scale_high = jitter_scale * jitter

if observed:
label = "Observed"
obs_yvals = np.zeros_like(obs_vals, dtype=np.float64)
if jitter:
obs_yvals += np.random.uniform(
Expand All @@ -260,18 +284,34 @@ def plot_ppc(
line_alpha=alpha,
)
glyph.level = "overlay"
legend_it.append((label, [glyph]))

all_scatter = []
for vals, y in zip(pp_sampled_vals, y_rows[1:]):
vals = np.ravel(vals)
yvals = np.full_like(vals, y, dtype=np.float64)
if jitter:
yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
ax_i.scatter(vals, yvals, fill_color=color, size=markersize, fill_alpha=alpha)
scatter = ax_i.scatter(
vals, yvals, fill_color=color, size=markersize, fill_alpha=alpha
)
all_scatter.append(scatter)

legend_it.append(("{} predictive".format(group.capitalize()), all_scatter))
ax_i.yaxis.major_tick_line_color = None
ax_i.yaxis.minor_tick_line_color = None
ax_i.yaxis.major_label_text_font_size = "0pt"

if legend:
legend = Legend(
items=legend_it,
location="top_left",
orientation="vertical",
)
ax_i.add_layout(legend)
if textsize is not None:
ax_i.legend.label_text_font_size = f"{textsize}pt"
ax_i.legend.click_policy = "hide"
ax_i.xaxis.axis_label = labeller.make_pp_label(var_name, pp_var_name, sel, isel)

show_layout(axes, show)
Expand Down
11 changes: 11 additions & 0 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,17 @@ def test_plot_ppc(models, kind, alpha, observed):
assert axes


def test_plot_ppc_textsize(models):
    """Smoke-test the bokeh ``plot_ppc`` backend with an explicit ``textsize``.

    Only checks that the call completes and returns a truthy ``axes`` object
    when ``textsize`` is supplied.
    """
    ppc_kwargs = {
        "textsize": 10,
        "random_seed": 3,
        "backend": "bokeh",
        "show": False,
    }
    axes = plot_ppc(models.model_1, **ppc_kwargs)
    assert axes


@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
def test_plot_ppc_multichain(kind, jitter):
Expand Down

0 comments on commit 5c0581f

Please sign in to comment.