Skip to content

Commit

Permalink
Merge 436d1c2 into 956a8cc
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Aug 31, 2021
2 parents 956a8cc + 436d1c2 commit 0cc3483
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Improve error messages in `stats.compare()`, and `var_name` parameter. ([1616](https://github.com/arviz-devs/arviz/pull/1616))
* Added ability to plot HDI contours to `plot_kde` with the new `hdi_probs` parameter. ([1665](https://github.com/arviz-devs/arviz/pull/1665))
* Add dtype parsing and setting in all Stan converters ([1632](https://github.com/arviz-devs/arviz/pull/1632))
* Add option to specify colors for each element in ppc_plot ([1769](https://github.com/arviz-devs/arviz/pull/1769))

### Maintenance and fixes
* Fix conversion for numpyro models with ImproperUniform latent sites ([1713](https://github.com/arviz-devs/arviz/pull/1713))
Expand Down
40 changes: 25 additions & 15 deletions arviz/plots/backends/bokeh/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def plot_ppc(
pp_sample_ix,
kind,
alpha,
color,
colors,
textsize,
mean,
observed,
Expand All @@ -48,7 +48,7 @@ def plot_ppc(
**backend_kwargs,
}

color = vectorized_to_hex(color)
colors = vectorized_to_hex(colors)

(figsize, *_, linewidth, markersize) = _scale_fig_size(figsize, textsize, rows, cols)
if ax is None:
Expand Down Expand Up @@ -95,7 +95,11 @@ def plot_ppc(
pp_sampled_vals = pp_vals[pp_sample_ix]

if kind == "kde":
plot_kwargs = {"line_color": color, "line_alpha": alpha, "line_width": 0.5 * linewidth}
plot_kwargs = {
"line_color": colors[0],
"line_alpha": alpha,
"line_width": 0.5 * linewidth,
}

pp_densities = []
pp_xs = []
Expand Down Expand Up @@ -127,7 +131,7 @@ def plot_ppc(
if dtype == "f":
_, glyph = plot_kde(
obs_vals,
plot_kwargs={"line_color": "black", "line_width": linewidth},
plot_kwargs={"line_color": colors[1], "line_width": linewidth},
fill_kwargs={"alpha": 0},
ax=ax_i,
backend="bokeh",
Expand All @@ -143,7 +147,7 @@ def plot_ppc(
step = ax_i.step(
bin_edges,
hist,
line_color="black",
line_color=colors[1],
line_width=linewidth,
mode="center",
)
Expand All @@ -164,7 +168,7 @@ def plot_ppc(
line = ax_i.line(
new_x,
new_d.mean(0),
color=color,
color=colors[2],
line_dash="dashed",
line_width=linewidth,
)
Expand All @@ -177,7 +181,7 @@ def plot_ppc(
step = ax_i.step(
bin_edges,
hist,
line_color=color,
line_color=colors[2],
line_width=linewidth,
line_dash="dashed",
mode="center",
Expand All @@ -193,7 +197,7 @@ def plot_ppc(
if dtype == "f":
glyph = ax_i.line(
*_empirical_cdf(obs_vals),
line_color="black",
line_color=colors[1],
line_width=linewidth,
)
glyph.level = "overlay"
Expand All @@ -202,7 +206,7 @@ def plot_ppc(
else:
step = ax_i.step(
*_empirical_cdf(obs_vals),
line_color="black",
line_color=colors[1],
line_width=linewidth,
mode="center",
)
Expand All @@ -217,15 +221,15 @@ def plot_ppc(
list(pp_densities[::2]),
list(pp_densities[1::2]),
line_alpha=alpha,
line_color=color,
line_color=colors[0],
line_width=linewidth,
)
legend_it.append(("{} predictive".format(group.capitalize()), [multi_line]))
if mean:
label = "{} predictive mean".format(group.capitalize())
line = ax_i.line(
*_empirical_cdf(pp_vals.flatten()),
color=color,
color=colors[2],
line_dash="dashed",
line_width=linewidth,
)
Expand All @@ -238,7 +242,7 @@ def plot_ppc(
_, glyph = plot_kde(
pp_vals.flatten(),
plot_kwargs={
"line_color": color,
"line_color": colors[2],
"line_dash": "dashed",
"line_width": linewidth,
},
Expand All @@ -257,7 +261,7 @@ def plot_ppc(
step = ax_i.step(
bin_edges,
hist,
color=color,
color=colors[2],
line_width=linewidth,
line_dash="dashed",
mode="center",
Expand All @@ -279,7 +283,8 @@ def plot_ppc(
glyph = ax_i.circle(
obs_vals,
obs_yvals,
fill_color="black",
line_color=colors[1],
fill_color=colors[1],
size=markersize,
line_alpha=alpha,
)
Expand All @@ -293,7 +298,12 @@ def plot_ppc(
if jitter:
yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
scatter = ax_i.scatter(
vals, yvals, fill_color=color, size=markersize, fill_alpha=alpha
vals,
yvals,
line_color=colors[0],
fill_color=colors[0],
size=markersize,
fill_alpha=alpha,
)
all_scatter.append(scatter)

Expand Down
40 changes: 23 additions & 17 deletions arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def plot_ppc(
pp_sample_ix,
kind,
alpha,
color,
colors,
textsize,
mean,
observed,
Expand Down Expand Up @@ -122,16 +122,16 @@ def plot_ppc(
pp_sampled_vals = pp_vals[pp_sample_ix]

if kind == "kde":
plot_kwargs = {"color": color, "alpha": alpha, "linewidth": 0.5 * linewidth}
plot_kwargs = {"color": colors[0], "alpha": alpha, "linewidth": 0.5 * linewidth}
if dtype == "i":
plot_kwargs["drawstyle"] = "steps-pre"
ax_i.plot([], color=color, label="{} predictive".format(group.capitalize()))
ax_i.plot([], color=colors[0], label="{} predictive".format(group.capitalize()))
if observed:
if dtype == "f":
plot_kde(
obs_vals,
label="Observed",
plot_kwargs={"color": "k", "linewidth": linewidth, "zorder": 3},
plot_kwargs={"color": colors[1], "linewidth": linewidth, "zorder": 3},
fill_kwargs={"alpha": 0},
ax=ax_i,
legend=legend,
Expand All @@ -144,7 +144,7 @@ def plot_ppc(
bin_edges,
hist,
label="Observed",
color="k",
color=colors[1],
linewidth=linewidth,
zorder=3,
drawstyle=plot_kwargs["drawstyle"],
Expand Down Expand Up @@ -192,7 +192,7 @@ def plot_ppc(
ax_i.plot(
new_x,
new_d.mean(0),
color=color,
color=colors[2],
linestyle="--",
linewidth=linewidth * 1.5,
zorder=2,
Expand All @@ -206,7 +206,7 @@ def plot_ppc(
ax_i.plot(
bin_edges,
hist,
color=color,
color=colors[2],
linewidth=linewidth * 1.5,
label=label,
zorder=2,
Expand All @@ -221,7 +221,7 @@ def plot_ppc(
if observed:
ax_i.plot(
*_empirical_cdf(obs_vals),
color="k",
color=colors[1],
linewidth=linewidth,
label="Observed",
drawstyle=drawstyle,
Expand All @@ -248,15 +248,15 @@ def plot_ppc(
ax_i.plot(
*pp_densities,
alpha=alpha,
color=color,
color=colors[0],
drawstyle=drawstyle,
linewidth=linewidth
)
ax_i.plot([], color=color, label="Posterior predictive")
ax_i.plot([], color=colors[0], label="Posterior predictive")
if mean:
ax_i.plot(
*_empirical_cdf(pp_vals.flatten()),
color=color,
color=colors[2],
linestyle="--",
linewidth=linewidth * 1.5,
drawstyle=drawstyle,
Expand All @@ -270,7 +270,7 @@ def plot_ppc(
plot_kde(
pp_vals.flatten(),
plot_kwargs={
"color": color,
"color": colors[2],
"linestyle": "--",
"linewidth": linewidth * 1.5,
"zorder": 3,
Expand All @@ -287,7 +287,7 @@ def plot_ppc(
ax_i.plot(
bin_edges,
hist,
color=color,
color=colors[2],
linewidth=linewidth * 1.5,
label="Posterior predictive mean",
zorder=3,
Expand All @@ -312,7 +312,7 @@ def plot_ppc(
obs_vals,
obs_yvals,
"o",
color="k",
color=colors[1],
markersize=markersize,
alpha=alpha,
label="Observed",
Expand All @@ -324,7 +324,7 @@ def plot_ppc(
pp_sampled_vals,
ax_i,
kind=kind,
color=color,
color=colors[0],
height=y_rows.mean() * 0.5,
markersize=markersize,
)
Expand All @@ -336,10 +336,16 @@ def plot_ppc(
if jitter:
yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
ax_i.plot(
vals, yvals, "o", zorder=2, color=color, markersize=markersize, alpha=alpha
vals,
yvals,
"o",
zorder=2,
color=colors[0],
markersize=markersize,
alpha=alpha,
)

ax_i.plot([], color=color, marker="o", label="Posterior predictive")
ax_i.plot([], color=colors[0], marker="o", label="Posterior predictive")

ax_i.set_yticks([])

Expand Down
17 changes: 15 additions & 2 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Posterior/Prior predictive plot."""
import logging
import warnings
from numbers import Integral

import numpy as np
Expand All @@ -19,7 +20,8 @@ def plot_ppc(
alpha=None,
mean=True,
observed=True,
color="C0",
color=None,
colors=None,
grid=None,
figsize=None,
textsize=None,
Expand Down Expand Up @@ -60,6 +62,10 @@ def plot_ppc(
Whether or not to plot the observed data.
color: str
Valid matplotlib color. Defaults to C0
color: list
List with valid matplotlib colors corresponding to the posterior/prior predictive
distribution, observed data and mean of the posterior/prior predictive distribution.
Defaults to ["C0", "k", "C1"]
grid : tuple
Number of rows and columns. Defaults to None, the rows and columns are
automatically inferred.
Expand Down Expand Up @@ -208,6 +214,13 @@ def plot_ppc(
if kind.lower() not in ("kde", "cumulative", "scatter"):
raise TypeError("`kind` argument must be either `kde`, `cumulative`, or `scatter`")

if colors is None:
colors = ["C0", "k", "C1"]

if color is not None:
warnings.warn("color has been deprecated in favor of colors", FutureWarning)
colors[0] = color

if data_pairs is None:
data_pairs = {}

Expand Down Expand Up @@ -308,7 +321,7 @@ def plot_ppc(
pp_sample_ix=pp_sample_ix,
kind=kind,
alpha=alpha,
color=color,
colors=colors,
jitter=jitter,
textsize=textsize,
mean=mean,
Expand Down

0 comments on commit 0cc3483

Please # to comment.