Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Initial prototype of plot_lm #1727

Merged
merged 11 commits into from
Jul 30, 2021
2 changes: 2 additions & 0 deletions arviz/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .jointplot import plot_joint
from .kdeplot import plot_kde
from .khatplot import plot_khat
from .lmplot import plot_lm
from .loopitplot import plot_loo_pit
from .mcseplot import plot_mcse
from .pairplot import plot_pair
Expand All @@ -38,6 +39,7 @@
"plot_joint",
"plot_kde",
"plot_khat",
"plot_lm",
"plot_loo_pit",
"plot_mcse",
"plot_pair",
Expand Down
170 changes: 170 additions & 0 deletions arviz/plots/backends/bokeh/lmplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Bokeh linear regression plot."""
import numpy as np
from bokeh.models.annotations import Legend

from ...hdiplot import plot_hdi

from ...plot_utils import _scale_fig_size
from .. import show_layout
from . import backend_kwarg_defaults, create_axes_grid


def plot_lm(
x,
y,
y_model,
y_hat,
num_samples,
kind_pp,
kind_model,
length_plotters,
xjitter,
rows,
cols,
y_kwargs,
y_hat_plot_kwargs,
y_hat_fill_kwargs,
y_model_plot_kwargs,
y_model_fill_kwargs,
backend_kwargs,
show,
figsize,
textsize,
axes,
legend,
grid, # pylint: disable=unused-argument
):
"""Bokeh linreg plot."""
if backend_kwargs is None:
backend_kwargs = {}

backend_kwargs = {
**backend_kwarg_defaults(),
**backend_kwargs,
}

figsize, *_ = _scale_fig_size(figsize, textsize, rows, cols)
if axes is None:
axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

if y_kwargs is None:
y_kwargs = {}
y_kwargs.setdefault("fill_color", "red")
y_kwargs.setdefault("line_width", 0)
y_kwargs.setdefault("size", 3)

if y_hat_plot_kwargs is None:
y_hat_plot_kwargs = {}
y_hat_plot_kwargs.setdefault("fill_color", "orange")
y_hat_plot_kwargs.setdefault("line_width", 0)

if y_hat_fill_kwargs is None:
y_hat_fill_kwargs = {}
y_hat_fill_kwargs.setdefault("color", "orange")

if y_model_plot_kwargs is None:
y_model_plot_kwargs = {}
y_model_plot_kwargs.setdefault("line_color", "black")
y_model_plot_kwargs.setdefault("line_alpha", 0.5)
y_model_plot_kwargs.setdefault("line_width", 0.5)

if y_model_fill_kwargs is None:
y_model_fill_kwargs = {}
y_model_fill_kwargs.setdefault("color", "black")
y_model_fill_kwargs.setdefault("alpha", 0.5)

for i, ax_i in enumerate((item for item in axes.flatten() if item is not None)):

_, _, _, y_plotters = y[i]
_, _, _, x_plotters = x[i]
legend_it = []
observed_legend = ax_i.circle(x_plotters, y_plotters, **y_kwargs)
legend_it.append(("Observed", [observed_legend]))

if y_hat is not None:
_, _, _, y_hat_plotters = y_hat[i]
if kind_pp == "samples":
posterior_legend = []
for j in range(num_samples):
if xjitter is True:
jitter_scale = x_plotters[1] - x_plotters[0]
scale_high = jitter_scale * 0.2
x_plotters_jitter = x_plotters + np.random.uniform(
low=-scale_high, high=scale_high, size=len(x_plotters)
)
posterior_circle = ax_i.circle(
x_plotters_jitter,
y_hat_plotters[..., j],
alpha=0.2,
**y_hat_plot_kwargs,
)
else:
posterior_circle = ax_i.circle(
x_plotters, y_hat_plotters[..., j], alpha=0.2, **y_hat_plot_kwargs
)
posterior_legend.append(posterior_circle)
legend_it.append(("Posterior predictive samples", posterior_legend))

else:
plot_hdi(
x_plotters,
y_hat_plotters,
ax=ax_i,
backend="bokeh",
fill_kwargs=y_hat_fill_kwargs,
show=False,
)

if y_model is not None:
_, _, _, y_model_plotters = y_model[i]
if kind_model == "lines":

model_legend = ax_i.multi_line(
[np.tile(x_plotters, (num_samples, 1))],
[np.transpose(y_model_plotters)],
**y_model_plot_kwargs,
)
legend_it.append(("Uncertainty in mean", [model_legend]))

y_model_mean = np.mean(y_model_plotters, axis=1)
x_plotters_edge = [min(x_plotters), max(x_plotters)]
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
mean_legend = ax_i.line(
x_plotters_edge, y_model_mean_edge, line_color="yellow", line_width=2
)
legend_it.append(("Mean", [mean_legend]))

else:
plot_hdi(
x_plotters,
y_model_plotters,
fill_kwargs=y_model_fill_kwargs,
ax=ax_i,
backend="bokeh",
show=False,
)

y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
x_plotters_edge = [min(x_plotters), max(x_plotters)]
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
mean_legend = ax_i.line(
x_plotters_edge,
y_model_mean_edge,
line_color="yellow",
line_width=2,
)
legend_it.append(("Mean", [mean_legend]))

if legend:
legend = Legend(
items=legend_it,
location="top_left",
orientation="vertical",
)
ax_i.add_layout(legend)
if textsize is not None:
ax_i.legend.label_text_font_size = f"{textsize}pt"
ax_i.legend.click_policy = "hide"

show_layout(axes, show)
return axes
138 changes: 138 additions & 0 deletions arviz/plots/backends/matplotlib/lmplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Matplotlib plot linear regression figure."""
import matplotlib.pyplot as plt
import numpy as np

from ...plot_utils import _scale_fig_size
from ...hdiplot import plot_hdi
from . import create_axes_grid, matplotlib_kwarg_dealiaser, backend_show, backend_kwarg_defaults


def plot_lm(
x,
y,
y_model,
y_hat,
num_samples,
kind_pp,
kind_model,
xjitter,
length_plotters,
rows,
cols,
y_kwargs,
y_hat_plot_kwargs,
y_hat_fill_kwargs,
y_model_plot_kwargs,
y_model_fill_kwargs,
backend_kwargs,
show,
figsize,
textsize,
axes,
legend,
grid,
):
"""Matplotlib Linear Regression."""
if backend_kwargs is None:
backend_kwargs = {}

backend_kwargs = {
**backend_kwarg_defaults(),
**backend_kwargs,
}

figsize, _, _, xt_labelsize, _, _ = _scale_fig_size(figsize, textsize, rows, cols)
backend_kwargs.setdefault("figsize", figsize)
backend_kwargs.setdefault("squeeze", False)

if axes is None:
_, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):

# All the kwargs are defined here beforehand
y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot")
y_kwargs.setdefault("color", "C3")
y_kwargs.setdefault("marker", ".")
y_kwargs.setdefault("markersize", 15)
y_kwargs.setdefault("linewidth", 0)
y_kwargs.setdefault("zorder", 10)
y_kwargs.setdefault("label", "observed_data")

y_hat_plot_kwargs = matplotlib_kwarg_dealiaser(y_hat_plot_kwargs, "plot")
y_hat_plot_kwargs.setdefault("color", "C1")
y_hat_plot_kwargs.setdefault("alpha", 0.3)
y_hat_plot_kwargs.setdefault("markersize", 10)
y_hat_plot_kwargs.setdefault("marker", ".")
y_hat_plot_kwargs.setdefault("linewidth", 0)

y_hat_fill_kwargs = matplotlib_kwarg_dealiaser(y_hat_fill_kwargs, "fill_between")
y_hat_fill_kwargs.setdefault("color", "C1")

y_model_plot_kwargs = matplotlib_kwarg_dealiaser(y_model_plot_kwargs, "plot")
y_model_plot_kwargs.setdefault("color", "k")
y_model_plot_kwargs.setdefault("alpha", 0.5)
y_model_plot_kwargs.setdefault("linewidth", 0.5)
y_model_plot_kwargs.setdefault("zorder", 9)

y_model_fill_kwargs = matplotlib_kwarg_dealiaser(y_model_fill_kwargs, "fill_between")
y_model_fill_kwargs.setdefault("color", "k")
y_model_fill_kwargs.setdefault("linewidth", 0.5)
y_model_fill_kwargs.setdefault("zorder", 9)
y_model_fill_kwargs.setdefault("alpha", 0.5)

y_var_name, _, _, y_plotters = y[i]
x_var_name, _, _, x_plotters = x[i]
ax_i.plot(x_plotters, y_plotters, **y_kwargs)
ax_i.set_xlabel(x_var_name)
ax_i.set_ylabel(y_var_name)

if y_hat is not None:
_, _, _, y_hat_plotters = y_hat[i]
if kind_pp == "samples":
for j in range(num_samples):
if xjitter is True:
jitter_scale = x_plotters[1] - x_plotters[0]
scale_high = jitter_scale * 0.2
x_plotters_jitter = x_plotters + np.random.uniform(
low=-scale_high, high=scale_high, size=len(x_plotters)
)
ax_i.plot(x_plotters_jitter, y_hat_plotters[..., j], **y_hat_plot_kwargs)
else:
ax_i.plot(x_plotters, y_hat_plotters[..., j], **y_hat_plot_kwargs)
ax_i.plot([], **y_hat_plot_kwargs, label="Posterior predictive samples")
else:
plot_hdi(x_plotters, y_hat_plotters, ax=ax_i, **y_hat_fill_kwargs)
ax_i.plot(
[], color=y_hat_fill_kwargs["color"], label="Posterior predictive samples"
)

if y_model is not None:
_, _, _, y_model_plotters = y_model[i]
if kind_model == "lines":
for j in range(num_samples):
ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs)
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")

y_model_mean = np.mean(y_model_plotters, axis=1)
ax_i.plot(x_plotters, y_model_mean, color="y", lw=0.8, zorder=11, label="Mean")
else:
plot_hdi(
x_plotters,
y_model_plotters,
fill_kwargs=y_model_fill_kwargs,
ax=ax_i,
)
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")

y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
ax_i.plot(x_plotters, y_model_mean, color="y", lw=0.8, zorder=11, label="Mean")

if legend:
ax_i.legend(fontsize=xt_labelsize, loc="upper left")
if grid:
ax_i.grid(True)

if backend_show(show):
plt.show()
return axes
Loading