-
-
Notifications
You must be signed in to change notification settings - Fork 426
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial prototype of plot_lm (#1727)
* Updated all * fixed errors * fixed typos in tests * added docstrings to test * generated docs * updtaed tests * Removed print * fixed docstring * docstring * Updated docstring * changed sample to __sample__
- Loading branch information
1 parent
c9f6b05
commit ff518db
Showing
9 changed files
with
848 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
"""Bokeh linear regression plot.""" | ||
import numpy as np | ||
from bokeh.models.annotations import Legend | ||
|
||
from ...hdiplot import plot_hdi | ||
|
||
from ...plot_utils import _scale_fig_size | ||
from .. import show_layout | ||
from . import backend_kwarg_defaults, create_axes_grid | ||
|
||
|
||
def plot_lm( | ||
x, | ||
y, | ||
y_model, | ||
y_hat, | ||
num_samples, | ||
kind_pp, | ||
kind_model, | ||
length_plotters, | ||
xjitter, | ||
rows, | ||
cols, | ||
y_kwargs, | ||
y_hat_plot_kwargs, | ||
y_hat_fill_kwargs, | ||
y_model_plot_kwargs, | ||
y_model_fill_kwargs, | ||
backend_kwargs, | ||
show, | ||
figsize, | ||
textsize, | ||
axes, | ||
legend, | ||
grid, # pylint: disable=unused-argument | ||
): | ||
"""Bokeh linreg plot.""" | ||
if backend_kwargs is None: | ||
backend_kwargs = {} | ||
|
||
backend_kwargs = { | ||
**backend_kwarg_defaults(), | ||
**backend_kwargs, | ||
} | ||
|
||
figsize, *_ = _scale_fig_size(figsize, textsize, rows, cols) | ||
if axes is None: | ||
axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs) | ||
|
||
if y_kwargs is None: | ||
y_kwargs = {} | ||
y_kwargs.setdefault("fill_color", "red") | ||
y_kwargs.setdefault("line_width", 0) | ||
y_kwargs.setdefault("size", 3) | ||
|
||
if y_hat_plot_kwargs is None: | ||
y_hat_plot_kwargs = {} | ||
y_hat_plot_kwargs.setdefault("fill_color", "orange") | ||
y_hat_plot_kwargs.setdefault("line_width", 0) | ||
|
||
if y_hat_fill_kwargs is None: | ||
y_hat_fill_kwargs = {} | ||
y_hat_fill_kwargs.setdefault("color", "orange") | ||
|
||
if y_model_plot_kwargs is None: | ||
y_model_plot_kwargs = {} | ||
y_model_plot_kwargs.setdefault("line_color", "black") | ||
y_model_plot_kwargs.setdefault("line_alpha", 0.5) | ||
y_model_plot_kwargs.setdefault("line_width", 0.5) | ||
|
||
if y_model_fill_kwargs is None: | ||
y_model_fill_kwargs = {} | ||
y_model_fill_kwargs.setdefault("color", "black") | ||
y_model_fill_kwargs.setdefault("alpha", 0.5) | ||
|
||
for i, ax_i in enumerate((item for item in axes.flatten() if item is not None)): | ||
|
||
_, _, _, y_plotters = y[i] | ||
_, _, _, x_plotters = x[i] | ||
legend_it = [] | ||
observed_legend = ax_i.circle(x_plotters, y_plotters, **y_kwargs) | ||
legend_it.append(("Observed", [observed_legend])) | ||
|
||
if y_hat is not None: | ||
_, _, _, y_hat_plotters = y_hat[i] | ||
if kind_pp == "samples": | ||
posterior_legend = [] | ||
for j in range(num_samples): | ||
if xjitter is True: | ||
jitter_scale = x_plotters[1] - x_plotters[0] | ||
scale_high = jitter_scale * 0.2 | ||
x_plotters_jitter = x_plotters + np.random.uniform( | ||
low=-scale_high, high=scale_high, size=len(x_plotters) | ||
) | ||
posterior_circle = ax_i.circle( | ||
x_plotters_jitter, | ||
y_hat_plotters[..., j], | ||
alpha=0.2, | ||
**y_hat_plot_kwargs, | ||
) | ||
else: | ||
posterior_circle = ax_i.circle( | ||
x_plotters, y_hat_plotters[..., j], alpha=0.2, **y_hat_plot_kwargs | ||
) | ||
posterior_legend.append(posterior_circle) | ||
legend_it.append(("Posterior predictive samples", posterior_legend)) | ||
|
||
else: | ||
plot_hdi( | ||
x_plotters, | ||
y_hat_plotters, | ||
ax=ax_i, | ||
backend="bokeh", | ||
fill_kwargs=y_hat_fill_kwargs, | ||
show=False, | ||
) | ||
|
||
if y_model is not None: | ||
_, _, _, y_model_plotters = y_model[i] | ||
if kind_model == "lines": | ||
|
||
model_legend = ax_i.multi_line( | ||
[np.tile(x_plotters, (num_samples, 1))], | ||
[np.transpose(y_model_plotters)], | ||
**y_model_plot_kwargs, | ||
) | ||
legend_it.append(("Uncertainty in mean", [model_legend])) | ||
|
||
y_model_mean = np.mean(y_model_plotters, axis=1) | ||
x_plotters_edge = [min(x_plotters), max(x_plotters)] | ||
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)] | ||
mean_legend = ax_i.line( | ||
x_plotters_edge, y_model_mean_edge, line_color="yellow", line_width=2 | ||
) | ||
legend_it.append(("Mean", [mean_legend])) | ||
|
||
else: | ||
plot_hdi( | ||
x_plotters, | ||
y_model_plotters, | ||
fill_kwargs=y_model_fill_kwargs, | ||
ax=ax_i, | ||
backend="bokeh", | ||
show=False, | ||
) | ||
|
||
y_model_mean = np.mean(y_model_plotters, axis=(0, 1)) | ||
x_plotters_edge = [min(x_plotters), max(x_plotters)] | ||
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)] | ||
mean_legend = ax_i.line( | ||
x_plotters_edge, | ||
y_model_mean_edge, | ||
line_color="yellow", | ||
line_width=2, | ||
) | ||
legend_it.append(("Mean", [mean_legend])) | ||
|
||
if legend: | ||
legend = Legend( | ||
items=legend_it, | ||
location="top_left", | ||
orientation="vertical", | ||
) | ||
ax_i.add_layout(legend) | ||
if textsize is not None: | ||
ax_i.legend.label_text_font_size = f"{textsize}pt" | ||
ax_i.legend.click_policy = "hide" | ||
|
||
show_layout(axes, show) | ||
return axes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""Matplotlib plot linear regression figure.""" | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from ...plot_utils import _scale_fig_size | ||
from ...hdiplot import plot_hdi | ||
from . import create_axes_grid, matplotlib_kwarg_dealiaser, backend_show, backend_kwarg_defaults | ||
|
||
|
||
def plot_lm( | ||
x, | ||
y, | ||
y_model, | ||
y_hat, | ||
num_samples, | ||
kind_pp, | ||
kind_model, | ||
xjitter, | ||
length_plotters, | ||
rows, | ||
cols, | ||
y_kwargs, | ||
y_hat_plot_kwargs, | ||
y_hat_fill_kwargs, | ||
y_model_plot_kwargs, | ||
y_model_fill_kwargs, | ||
backend_kwargs, | ||
show, | ||
figsize, | ||
textsize, | ||
axes, | ||
legend, | ||
grid, | ||
): | ||
"""Matplotlib Linear Regression.""" | ||
if backend_kwargs is None: | ||
backend_kwargs = {} | ||
|
||
backend_kwargs = { | ||
**backend_kwarg_defaults(), | ||
**backend_kwargs, | ||
} | ||
|
||
figsize, _, _, xt_labelsize, _, _ = _scale_fig_size(figsize, textsize, rows, cols) | ||
backend_kwargs.setdefault("figsize", figsize) | ||
backend_kwargs.setdefault("squeeze", False) | ||
|
||
if axes is None: | ||
_, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs) | ||
|
||
for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]): | ||
|
||
# All the kwargs are defined here beforehand | ||
y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot") | ||
y_kwargs.setdefault("color", "C3") | ||
y_kwargs.setdefault("marker", ".") | ||
y_kwargs.setdefault("markersize", 15) | ||
y_kwargs.setdefault("linewidth", 0) | ||
y_kwargs.setdefault("zorder", 10) | ||
y_kwargs.setdefault("label", "observed_data") | ||
|
||
y_hat_plot_kwargs = matplotlib_kwarg_dealiaser(y_hat_plot_kwargs, "plot") | ||
y_hat_plot_kwargs.setdefault("color", "C1") | ||
y_hat_plot_kwargs.setdefault("alpha", 0.3) | ||
y_hat_plot_kwargs.setdefault("markersize", 10) | ||
y_hat_plot_kwargs.setdefault("marker", ".") | ||
y_hat_plot_kwargs.setdefault("linewidth", 0) | ||
|
||
y_hat_fill_kwargs = matplotlib_kwarg_dealiaser(y_hat_fill_kwargs, "fill_between") | ||
y_hat_fill_kwargs.setdefault("color", "C1") | ||
|
||
y_model_plot_kwargs = matplotlib_kwarg_dealiaser(y_model_plot_kwargs, "plot") | ||
y_model_plot_kwargs.setdefault("color", "k") | ||
y_model_plot_kwargs.setdefault("alpha", 0.5) | ||
y_model_plot_kwargs.setdefault("linewidth", 0.5) | ||
y_model_plot_kwargs.setdefault("zorder", 9) | ||
|
||
y_model_fill_kwargs = matplotlib_kwarg_dealiaser(y_model_fill_kwargs, "fill_between") | ||
y_model_fill_kwargs.setdefault("color", "k") | ||
y_model_fill_kwargs.setdefault("linewidth", 0.5) | ||
y_model_fill_kwargs.setdefault("zorder", 9) | ||
y_model_fill_kwargs.setdefault("alpha", 0.5) | ||
|
||
y_var_name, _, _, y_plotters = y[i] | ||
x_var_name, _, _, x_plotters = x[i] | ||
ax_i.plot(x_plotters, y_plotters, **y_kwargs) | ||
ax_i.set_xlabel(x_var_name) | ||
ax_i.set_ylabel(y_var_name) | ||
|
||
if y_hat is not None: | ||
_, _, _, y_hat_plotters = y_hat[i] | ||
if kind_pp == "samples": | ||
for j in range(num_samples): | ||
if xjitter is True: | ||
jitter_scale = x_plotters[1] - x_plotters[0] | ||
scale_high = jitter_scale * 0.2 | ||
x_plotters_jitter = x_plotters + np.random.uniform( | ||
low=-scale_high, high=scale_high, size=len(x_plotters) | ||
) | ||
ax_i.plot(x_plotters_jitter, y_hat_plotters[..., j], **y_hat_plot_kwargs) | ||
else: | ||
ax_i.plot(x_plotters, y_hat_plotters[..., j], **y_hat_plot_kwargs) | ||
ax_i.plot([], **y_hat_plot_kwargs, label="Posterior predictive samples") | ||
else: | ||
plot_hdi(x_plotters, y_hat_plotters, ax=ax_i, **y_hat_fill_kwargs) | ||
ax_i.plot( | ||
[], color=y_hat_fill_kwargs["color"], label="Posterior predictive samples" | ||
) | ||
|
||
if y_model is not None: | ||
_, _, _, y_model_plotters = y_model[i] | ||
if kind_model == "lines": | ||
for j in range(num_samples): | ||
ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs) | ||
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean") | ||
|
||
y_model_mean = np.mean(y_model_plotters, axis=1) | ||
ax_i.plot(x_plotters, y_model_mean, color="y", lw=0.8, zorder=11, label="Mean") | ||
else: | ||
plot_hdi( | ||
x_plotters, | ||
y_model_plotters, | ||
fill_kwargs=y_model_fill_kwargs, | ||
ax=ax_i, | ||
) | ||
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean") | ||
|
||
y_model_mean = np.mean(y_model_plotters, axis=(0, 1)) | ||
ax_i.plot(x_plotters, y_model_mean, color="y", lw=0.8, zorder=11, label="Mean") | ||
|
||
if legend: | ||
ax_i.legend(fontsize=xt_labelsize, loc="upper left") | ||
if grid: | ||
ax_i.grid(True) | ||
|
||
if backend_show(show): | ||
plt.show() | ||
return axes |
Oops, something went wrong.