def _bokeh_mean_line(ax, x_values, y_mean):
    """Draw the mean curve through the actual ``(x, mean)`` points, ordered by x."""
    x_values = np.ravel(x_values)
    y_mean = np.ravel(y_mean)
    order = np.argsort(x_values)
    return ax.line(x_values[order], y_mean[order], line_color="yellow", line_width=2)


def plot_lm(
    x,
    y,
    y_model,
    y_hat,
    num_samples,
    kind_pp,
    kind_model,
    length_plotters,
    xjitter,
    rows,
    cols,
    y_kwargs,
    y_hat_plot_kwargs,
    y_hat_fill_kwargs,
    y_model_plot_kwargs,
    y_model_fill_kwargs,
    backend_kwargs,
    show,
    figsize,
    textsize,
    axes,
    legend,
    grid,  # pylint: disable=unused-argument
):
    """Bokeh linear regression plot.

    Parameters
    ----------
    x, y : sequence
        One entry per figure; each entry is a 4-item plotter whose last item is
        the array drawn on the x (resp. y) axis.
    y_model, y_hat : sequence or None
        Same layout as ``y``. ``None`` disables the corresponding layer.
    num_samples : int
        Number of columns of ``y_hat`` / ``y_model`` drawn when ``kind_pp`` is
        "samples" / ``kind_model`` is "lines".
    kind_pp, kind_model : str
        "samples"/"lines" draw individual draws, anything else delegates to
        ``plot_hdi``.
    length_plotters, rows, cols : int
        Grid description used when ``axes`` has to be created.
    xjitter : bool
        If truthy, add uniform noise to x for every posterior predictive sample.
    y_kwargs, y_hat_plot_kwargs, y_model_plot_kwargs : dict or None
        Glyph keyword arguments (``circle`` / ``circle`` / ``multi_line``).
        The dictionaries passed in are not modified.
    y_hat_fill_kwargs, y_model_fill_kwargs : dict or None
        Forwarded to ``plot_hdi`` as ``fill_kwargs``.
    backend_kwargs : dict or None
        Merged on top of ``backend_kwarg_defaults()``.
    show : bool or None
        Forwarded to ``show_layout``.
    figsize, textsize
        Forwarded to ``_scale_fig_size``; ``textsize`` also sets the legend font.
    axes : array of bokeh figures or None
        Created with ``create_axes_grid`` when ``None``.
    legend : bool
        Whether to add a legend to every figure.
    grid
        Unused by this backend.

    Returns
    -------
    axes : array of bokeh figures
    """
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **({} if backend_kwargs is None else backend_kwargs),
    }

    # NOTE(review): the scaled figsize is computed but not forwarded to
    # create_axes_grid below -- confirm whether that is intentional.
    figsize, *_ = _scale_fig_size(figsize, textsize, rows, cols)
    if axes is None:
        axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

    # Work on copies so that the caller's dictionaries are never mutated.
    y_kwargs = {} if y_kwargs is None else dict(y_kwargs)
    y_kwargs.setdefault("fill_color", "red")
    y_kwargs.setdefault("line_width", 0)
    y_kwargs.setdefault("size", 3)

    y_hat_plot_kwargs = {} if y_hat_plot_kwargs is None else dict(y_hat_plot_kwargs)
    y_hat_plot_kwargs.setdefault("fill_color", "orange")
    y_hat_plot_kwargs.setdefault("line_width", 0)
    # A default rather than a hard-coded keyword: passing ``alpha`` explicitly
    # next to ``**y_hat_plot_kwargs`` fails when the user supplies ``alpha``.
    y_hat_plot_kwargs.setdefault("alpha", 0.2)

    y_hat_fill_kwargs = {} if y_hat_fill_kwargs is None else dict(y_hat_fill_kwargs)
    y_hat_fill_kwargs.setdefault("color", "orange")

    y_model_plot_kwargs = {} if y_model_plot_kwargs is None else dict(y_model_plot_kwargs)
    y_model_plot_kwargs.setdefault("line_color", "black")
    y_model_plot_kwargs.setdefault("line_alpha", 0.5)
    y_model_plot_kwargs.setdefault("line_width", 0.5)

    y_model_fill_kwargs = {} if y_model_fill_kwargs is None else dict(y_model_fill_kwargs)
    y_model_fill_kwargs.setdefault("color", "black")
    y_model_fill_kwargs.setdefault("alpha", 0.5)

    for i, ax_i in enumerate(item for item in axes.flatten() if item is not None):
        _, _, _, y_plotters = y[i]
        _, _, _, x_plotters = x[i]

        observed_glyph = ax_i.circle(x_plotters, y_plotters, **y_kwargs)
        legend_items = [("Observed", [observed_glyph])]

        if y_hat is not None:
            _, _, _, y_hat_plotters = y_hat[i]
            if kind_pp == "samples":
                # The jitter half-width depends only on x, so compute it once.
                # A single x value has no spacing to scale by: no jitter then.
                jitter = 0.0
                if xjitter and len(x_plotters) > 1:
                    jitter = abs(x_plotters[1] - x_plotters[0]) * 0.2

                sample_glyphs = []
                for j in range(num_samples):
                    x_sample = x_plotters
                    if jitter:
                        x_sample = x_plotters + np.random.uniform(
                            low=-jitter, high=jitter, size=len(x_plotters)
                        )
                    sample_glyphs.append(
                        ax_i.circle(x_sample, y_hat_plotters[..., j], **y_hat_plot_kwargs)
                    )
                legend_items.append(("Posterior predictive samples", sample_glyphs))
            else:
                plot_hdi(
                    x_plotters,
                    y_hat_plotters,
                    ax=ax_i,
                    backend="bokeh",
                    fill_kwargs=y_hat_fill_kwargs,
                    show=False,
                )

        if y_model is not None:
            _, _, _, y_model_plotters = y_model[i]
            if kind_model == "lines":
                model_glyph = ax_i.multi_line(
                    [np.tile(x_plotters, (num_samples, 1))],
                    [np.transpose(y_model_plotters)],
                    **y_model_plot_kwargs,
                )
                legend_items.append(("Uncertainty in mean", [model_glyph]))
                # Samples are on the last axis here.
                y_model_mean = np.mean(y_model_plotters, axis=1)
            else:
                plot_hdi(
                    x_plotters,
                    y_model_plotters,
                    fill_kwargs=y_model_fill_kwargs,
                    ax=ax_i,
                    backend="bokeh",
                    show=False,
                )
                # Unstacked draws: average over the two leading axes.
                y_model_mean = np.mean(y_model_plotters, axis=(0, 1))

            # Joining (min x, min mean) with (max x, max mean) would draw an
            # increasing line even for a decreasing fit; use the real points.
            mean_glyph = _bokeh_mean_line(ax_i, x_plotters, y_model_mean)
            legend_items.append(("Mean", [mean_glyph]))

        if legend:
            # Separate name: rebinding the ``legend`` flag inside the loop
            # would turn it into a Legend object for the following figures.
            legend_obj = Legend(
                items=legend_items,
                location="top_left",
                orientation="vertical",
            )
            ax_i.add_layout(legend_obj)
            if textsize is not None:
                ax_i.legend.label_text_font_size = f"{textsize}pt"
            ax_i.legend.click_policy = "hide"

    show_layout(axes, show)
    return axes
def plot_lm(
    x,
    y,
    y_model,
    y_hat,
    num_samples,
    kind_pp,
    kind_model,
    xjitter,
    length_plotters,
    rows,
    cols,
    y_kwargs,
    y_hat_plot_kwargs,
    y_hat_fill_kwargs,
    y_model_plot_kwargs,
    y_model_fill_kwargs,
    backend_kwargs,
    show,
    figsize,
    textsize,
    axes,
    legend,
    grid,
):
    """Matplotlib linear regression plot.

    Parameters
    ----------
    x, y : sequence
        One entry per axes; each entry is a 4-item plotter
        ``(var_name, _, _, values)``. The variable names label the axes.
    y_model, y_hat : sequence or None
        Same layout as ``y``. ``None`` disables the corresponding layer.
    num_samples : int
        Number of columns of ``y_hat`` / ``y_model`` drawn when ``kind_pp`` is
        "samples" / ``kind_model`` is "lines".
    kind_pp, kind_model : str
        "samples"/"lines" draw individual draws, anything else delegates to
        ``plot_hdi``.
    xjitter : bool
        If truthy, add uniform noise to x for every posterior predictive sample.
    length_plotters, rows, cols : int
        Grid description used when ``axes`` has to be created.
    y_kwargs, y_hat_plot_kwargs, y_model_plot_kwargs : dict or None
        Keyword arguments for ``Axes.plot`` (dealiased before use).
    y_hat_fill_kwargs, y_model_fill_kwargs : dict or None
        Keyword arguments forwarded to ``plot_hdi`` as ``fill_kwargs``.
    backend_kwargs : dict or None
        Merged on top of ``backend_kwarg_defaults()``.
    show : bool or None
        Evaluated through ``backend_show``.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    axes : array of matplotlib axes or None
        Created with ``create_axes_grid`` when ``None``.
    legend, grid : bool
        Toggle the legend and the grid on every axes.

    Returns
    -------
    axes : array of matplotlib axes
    """
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **({} if backend_kwargs is None else backend_kwargs),
    }

    figsize, _, _, xt_labelsize, _, _ = _scale_fig_size(figsize, textsize, rows, cols)
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", False)

    if axes is None:
        _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

    # Keyword defaults do not depend on the axes: resolve them once, up front.
    y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot")
    y_kwargs.setdefault("color", "C3")
    y_kwargs.setdefault("marker", ".")
    y_kwargs.setdefault("markersize", 15)
    y_kwargs.setdefault("linewidth", 0)
    y_kwargs.setdefault("zorder", 10)
    y_kwargs.setdefault("label", "observed_data")

    y_hat_plot_kwargs = matplotlib_kwarg_dealiaser(y_hat_plot_kwargs, "plot")
    y_hat_plot_kwargs.setdefault("color", "C1")
    y_hat_plot_kwargs.setdefault("alpha", 0.3)
    y_hat_plot_kwargs.setdefault("markersize", 10)
    y_hat_plot_kwargs.setdefault("marker", ".")
    y_hat_plot_kwargs.setdefault("linewidth", 0)

    y_hat_fill_kwargs = matplotlib_kwarg_dealiaser(y_hat_fill_kwargs, "fill_between")
    y_hat_fill_kwargs.setdefault("color", "C1")

    y_model_plot_kwargs = matplotlib_kwarg_dealiaser(y_model_plot_kwargs, "plot")
    y_model_plot_kwargs.setdefault("color", "k")
    y_model_plot_kwargs.setdefault("alpha", 0.5)
    y_model_plot_kwargs.setdefault("linewidth", 0.5)
    y_model_plot_kwargs.setdefault("zorder", 9)

    y_model_fill_kwargs = matplotlib_kwarg_dealiaser(y_model_fill_kwargs, "fill_between")
    y_model_fill_kwargs.setdefault("color", "k")
    y_model_fill_kwargs.setdefault("linewidth", 0.5)
    y_model_fill_kwargs.setdefault("zorder", 9)
    y_model_fill_kwargs.setdefault("alpha", 0.5)

    for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
        y_var_name, _, _, y_plotters = y[i]
        x_var_name, _, _, x_plotters = x[i]
        ax_i.plot(x_plotters, y_plotters, **y_kwargs)
        ax_i.set_xlabel(x_var_name)
        ax_i.set_ylabel(y_var_name)

        if y_hat is not None:
            _, _, _, y_hat_plotters = y_hat[i]
            if kind_pp == "samples":
                # The jitter half-width depends only on x, so compute it once.
                # A single x value has no spacing to scale by: no jitter then.
                jitter = 0.0
                if xjitter and len(x_plotters) > 1:
                    jitter = abs(x_plotters[1] - x_plotters[0]) * 0.2

                for j in range(num_samples):
                    x_sample = x_plotters
                    if jitter:
                        x_sample = x_plotters + np.random.uniform(
                            low=-jitter, high=jitter, size=len(x_plotters)
                        )
                    ax_i.plot(x_sample, y_hat_plotters[..., j], **y_hat_plot_kwargs)
                # Empty proxy artist so the legend shows a single entry. The
                # dict merge keeps a user-supplied ``label`` from clashing.
                ax_i.plot([], **{**y_hat_plot_kwargs, "label": "Posterior predictive samples"})
            else:
                # Pass the fill options as ``fill_kwargs`` (as done for y_model
                # below); splatting them made every option other than
                # ``color`` an unexpected plot_hdi keyword.
                plot_hdi(
                    x_plotters,
                    y_hat_plotters,
                    color=y_hat_fill_kwargs["color"],
                    fill_kwargs=y_hat_fill_kwargs,
                    ax=ax_i,
                    backend="matplotlib",
                    show=False,
                )
                ax_i.plot(
                    [], color=y_hat_fill_kwargs["color"], label="Posterior predictive samples"
                )

        if y_model is not None:
            _, _, _, y_model_plotters = y_model[i]
            if kind_model == "lines":
                for j in range(num_samples):
                    ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs)
                ax_i.plot([], **{**y_model_plot_kwargs, "label": "Uncertainty in mean"})
                # Samples are on the last axis here.
                y_model_mean = np.mean(y_model_plotters, axis=1)
            else:
                plot_hdi(
                    x_plotters,
                    y_model_plotters,
                    fill_kwargs=y_model_fill_kwargs,
                    ax=ax_i,
                    backend="matplotlib",
                    show=False,
                )
                ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
                # Unstacked draws: average over the two leading axes.
                y_model_mean = np.mean(y_model_plotters, axis=(0, 1))

            ax_i.plot(x_plotters, y_model_mean, color="y", lw=0.8, zorder=11, label="Mean")

        if legend:
            ax_i.legend(fontsize=xt_labelsize, loc="upper left")
        if grid:
            ax_i.grid(True)

    if backend_show(show):
        plt.show()
    return axes
def plot_lm(
    y,
    idata=None,
    x=None,
    y_model=None,
    y_hat=None,
    num_samples=50,
    kind_pp="samples",
    kind_model="lines",
    xjitter=False,
    plot_dim=None,
    backend=None,
    y_kwargs=None,
    y_hat_plot_kwargs=None,
    y_hat_fill_kwargs=None,
    y_model_plot_kwargs=None,
    y_model_fill_kwargs=None,
    backend_kwargs=None,
    show=None,
    figsize=None,
    textsize=None,
    axes=None,
    legend=True,
    grid=True,
):
    """Posterior predictive and mean plots for regression-like data.

    Parameters
    ----------
    y : str or DataArray or ndarray
        If str, variable name from observed_data.
    idata : InferenceData, optional
        Optional only if none of ``y``, ``x``, ``y_model`` and ``y_hat`` is
        given as a variable name.
    x : str, tuple of strings, DataArray or array-like, optional
        If str or tuple, variable name(s) from constant_data.
        If ndarray, could be 1D, or 2D for multiple plots.
        If None, coords of y along ``plot_dim`` (or along its first dimension
        when ``plot_dim`` is None); y should be a DataArray in that case.
    y_model : str or DataArray, optional
        If str, variable name from posterior.
        Its dimensions should be same as y plus added chains and draws.
    y_hat : str or DataArray, optional
        If str, variable name from posterior_predictive. Defaults to ``y``
        when ``y`` is a str.
        Its dimensions should be same as y plus added chains and draws.
    num_samples : int, optional, default 50
        Significant if ``kind_pp`` is "samples" or ``kind_model`` is "lines".
        Number of samples to be drawn (without replacement) from the posterior
        predictive and/or from ``y_model``.
    kind_pp : {"samples", "hdi"}, default "samples"
        Options to visualize uncertainty in data.
    kind_model : {"lines", "hdi"}, default "lines"
        Options to visualize uncertainty in mean of the data.
    xjitter : bool, default False
        If True, add uniform jitter to the x values of the posterior
        predictive samples.
    plot_dim : str or sequence of str, optional
        Necessary if y is multidimensional.
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Defaults to
        ``rcParams["plot.backend"]``.
    y_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh
    y_hat_plot_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh
    y_hat_fill_kwargs : dict, optional
        Significant if ``kind_pp`` is "hdi". Passed to :func:`~arviz.plot_hdi`
    y_model_plot_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh
    y_model_fill_kwargs : dict, optional
        Significant if ``kind_model`` is "hdi". Passed to :func:`~arviz.plot_hdi`
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used. Passed to
        :func:`mpl:matplotlib.pyplot.subplots` or
        :func:`bokeh:bokeh.plotting.figure`
    show : bool, optional
        Call backend show function.
    figsize : tuple, optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be
        autoscaled based on figsize.
    axes : numpy array-like of matplotlib axes or bokeh figures, optional
        A 2D array of locations into which to plot the densities. If not supplied,
        ArviZ will create its own array of plot areas (and return it).
    legend : bool, optional
        Add legend to figure. By default True.
    grid : bool, optional
        Add grid to figure. By default True.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``kind_pp`` or ``kind_model`` is invalid, if a variable name is given
        without ``idata``, or if y is multidimensional and ``plot_dim`` is None.
    TypeError
        If ``num_samples`` is not an integer between 1 and the number of
        available samples, or if ``plot_dim`` has an unsupported type.

    Examples
    --------
    Plot regression default plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> import numpy as np
        >>> import xarray as xr
        >>> idata = az.load_arviz_data('regression1d')
        >>> x = xr.DataArray(np.linspace(0, 1, 100))
        >>> idata.posterior["y_model"] = idata.posterior["intercept"] + idata.posterior["slope"]*x
        >>> az.plot_lm(idata=idata, y="y", x=x)

    Plot regression data and mean uncertainty

    .. plot::
        :context: close-figs

        >>> az.plot_lm(idata=idata, y="y", x=x, y_model="y_model")

    Plot regression data and mean uncertainty in hdi form

    .. plot::
        :context: close-figs

        >>> az.plot_lm(
        ...     idata=idata, y="y", x=x, y_model="y_model", kind_pp="hdi", kind_model="hdi"
        ... )

    Plot regression data for multi-dimensional y using plot_dim

    .. plot::
        :context: close-figs

        >>> data = az.from_dict(
        ...     observed_data = { "y": np.random.normal(size=(5, 7)) },
        ...     posterior_predictive = {"y": np.random.randn(4, 1000, 5, 7) / 2},
        ...     dims={"y": ["dim1", "dim2"]},
        ...     coords={"dim1": range(5), "dim2": range(7)}
        ... )
        >>> az.plot_lm(idata=data, y="y", plot_dim="dim1")
    """
    if kind_pp not in ("samples", "hdi"):
        raise ValueError("kind_pp should be either samples or hdi")

    if kind_model not in ("lines", "hdi"):
        raise ValueError("kind_model should be either lines or hdi")

    # Variable names can only be resolved against idata; fail early with a clear
    # message instead of an AttributeError on None further down.
    uses_names = (
        isinstance(y, str)
        or isinstance(x, (str, tuple))
        or isinstance(y_model, str)
        or isinstance(y_hat, str)
    )
    if idata is None and uses_names:
        raise ValueError(
            "idata is required when y, x, y_model or y_hat are given as variable names"
        )

    if y_hat is None and isinstance(y, str):
        y_hat = y

    if isinstance(y, str):
        y = idata.observed_data[y]
    elif not isinstance(y, DataArray):
        y = xr.DataArray(y)

    if len(y.dims) > 1 and plot_dim is None:
        raise ValueError("Argument plot_dim is needed in case of multidimensional data")

    x_var_names = None
    if isinstance(x, str):
        x = idata.constant_data[x]
        x_skip_dims = x.dims
    elif isinstance(x, tuple):
        x_var_names = x
        x = idata.constant_data
        x_skip_dims = x.dims
    elif isinstance(x, DataArray):
        x_skip_dims = x.dims
    elif x is None:
        if plot_dim is None:
            x = y.coords[y.dims[0]]
        else:
            x = y.coords[plot_dim]
        x_skip_dims = x.dims
    else:
        x = xr.DataArray(x)
        x_skip_dims = [x.dims[-1]]

    # If posterior is present in idata and y_model is there, get its values
    if isinstance(y_model, str):
        if "posterior" not in idata.groups():
            warnings.warn("Posterior not found in idata", UserWarning)
            y_model = None
        elif hasattr(idata.posterior, y_model):
            y_model = idata.posterior[y_model]
        else:
            warnings.warn("y_model not found in posterior", UserWarning)
            y_model = None

    # If posterior_predictive is present in idata and y_hat is there, get its values
    if isinstance(y_hat, str):
        if "posterior_predictive" not in idata.groups():
            warnings.warn("posterior_predictive not found in idata", UserWarning)
            y_hat = None
        elif hasattr(idata.posterior_predictive, y_hat):
            y_hat = idata.posterior_predictive[y_hat]
        else:
            warnings.warn("y_hat not found in posterior_predictive", UserWarning)
            y_hat = None

    # Check that num_samples is valid and draw num_samples random sample indexes.
    # Only needed if kind_pp="samples" or kind_model="lines"; not required for hdi.
    pp_sample_ix = None
    if (y_hat is not None and kind_pp == "samples") or (
        y_model is not None and kind_model == "lines"
    ):
        if y_hat is not None:
            total_pp_samples = y_hat.sizes["chain"] * y_hat.sizes["draw"]
        else:
            total_pp_samples = y_model.sizes["chain"] * y_model.sizes["draw"]

        if (
            not isinstance(num_samples, Integral)
            or num_samples < 1
            or num_samples > total_pp_samples
        ):
            raise TypeError(
                f"`num_samples` must be an integer between 1 and {total_pp_samples}."
            )

        pp_sample_ix = np.random.choice(total_pp_samples, size=num_samples, replace=False)

    # Crucial step in case of multidim y: the dims kept together in one plot.
    if plot_dim is None:
        skip_dims = list(y.dims)
    elif isinstance(plot_dim, str):
        skip_dims = [plot_dim]
    elif isinstance(plot_dim, (tuple, list)):
        skip_dims = list(plot_dim)
    else:
        # Without this branch skip_dims would be unbound further down.
        raise TypeError("plot_dim should be a str or a sequence of str")

    # Generate x axis plotters.
    x = filter_plotters_list(
        plotters=list(
            xarray_var_iter(
                x,
                var_names=x_var_names,
                skip_dims=set(x_skip_dims),
                combined=True,
            )
        ),
        plot_kind="plot_lm",
    )

    # Generate y axis plotters
    y = filter_plotters_list(
        plotters=list(
            xarray_var_iter(
                y,
                skip_dims=set(skip_dims),
                combined=True,
            )
        ),
        plot_kind="plot_lm",
    )

    # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
    len_y = len(y)
    len_x = len(x)
    length_plotters = len_x * len_y
    y = np.tile(y, (len_x, 1))
    x = np.tile(x, (len_y, 1))

    # Filter out the required values to generate plotters
    if y_hat is not None:
        if kind_pp == "samples":
            y_hat = y_hat.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
            skip_dims += ["__sample__"]

        y_hat = [
            tup
            for _, tup in zip(
                range(len_y),
                xarray_var_iter(
                    y_hat,
                    skip_dims=set(skip_dims),
                    combined=True,
                ),
            )
        ]

        y_hat = np.tile(y_hat, (len_x, 1))

    # Filter out the required values to generate plotters
    if y_model is not None:
        if kind_model == "lines":
            y_model = y_model.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]

        y_model = [
            tup
            for _, tup in zip(
                range(len_y),
                xarray_var_iter(
                    y_model,
                    skip_dims=set(y_model.dims),
                    combined=True,
                ),
            )
        ]
        y_model = np.tile(y_model, (len_x, 1))

    rows, cols = default_grid(length_plotters)

    lmplot_kwargs = dict(
        x=x,
        y=y,
        y_model=y_model,
        y_hat=y_hat,
        num_samples=num_samples,
        kind_pp=kind_pp,
        kind_model=kind_model,
        length_plotters=length_plotters,
        xjitter=xjitter,
        rows=rows,
        cols=cols,
        y_kwargs=y_kwargs,
        y_hat_plot_kwargs=y_hat_plot_kwargs,
        y_hat_fill_kwargs=y_hat_fill_kwargs,
        y_model_plot_kwargs=y_model_plot_kwargs,
        y_model_fill_kwargs=y_model_fill_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
        figsize=figsize,
        textsize=textsize,
        axes=axes,
        legend=legend,
        grid=grid,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_lm", "lmplot", backend)
    ax = plot(**lmplot_kwargs)
    return ax
for 1D data.""" + idata = models.model_1 + if "constant_data" not in idata.groups(): + y = idata.observed_data["y"] + x1data = y.coords[y.dims[0]] + idata.add_groups({"constant_data": {"_": x1data}}) + idata.constant_data["x1"] = x1data + idata.constant_data["x2"] = x1data + + axes = plot_lm( + idata=idata, y="y", y_model="eta", backend="bokeh", xjitter=True, show=False, **kwargs + ) + assert np.all(axes) + + +def test_plot_lm_multidim(multidim_models): + """Test functionality for multidimensional data.""" + idata = multidim_models.model_1 + axes = plot_lm(idata=idata, y="y", plot_dim="dim1", show=False, backend="bokeh") + assert np.any(axes) + + +def test_plot_lm_list(): + """Test the plots when input data is list or ndarray.""" + y = [1, 2, 3, 4, 5] + assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh") diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 1ddfa366fd..6c2838abdc 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -26,6 +26,7 @@ plot_joint, plot_kde, plot_khat, + plot_lm, plot_loo_pit, plot_mcse, plot_pair, @@ -1441,3 +1442,102 @@ def test_plot_bpv_discrete(): fake_model = from_dict(posterior_predictive=fake_pp, observed_data=fake_obs) axes = plot_bpv(fake_model) assert not isinstance(axes, np.ndarray) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"y_hat": "bad_name"}, + {"x": "x1"}, + {"x": ("x1", "x2")}, + { + "x": ("x1", "x2"), + "y_kwargs": {"color": "blue", "marker": "^"}, + "y_hat_plot_kwargs": {"color": "cyan"}, + }, + {"x": ("x1", "x2"), "y_model_plot_kwargs": {"color": "red"}}, + { + "x": ("x1", "x2"), + "kind_pp": "hdi", + "kind_model": "hdi", + "y_model_fill_kwargs": {"color": "red"}, + "y_hat_fill_kwargs": {"color": "cyan"}, + }, + ], +) +def test_plot_lm(models, kwargs): + """Test functionality for 1D data.""" + idata = models.model_1 + if "constant_data" not in idata.groups(): + y 
= idata.observed_data["y"] + x1data = y.coords[y.dims[0]] + idata.add_groups({"constant_data": {"_": x1data}}) + idata.constant_data["x1"] = x1data + idata.constant_data["x2"] = x1data + + axes = plot_lm(idata=idata, y="y", y_model="eta", xjitter=True, **kwargs) + assert np.all(axes) + + +def test_plot_lm_multidim(multidim_models): + """Test functionality for multidimensional data.""" + idata = multidim_models.model_1 + axes = plot_lm( + idata=idata, + x=idata.observed_data["y"].coords["dim1"].values, + y="y", + xjitter=True, + plot_dim="dim1", + show=False, + figsize=(4, 16), + ) + assert np.all(axes) + + +@pytest.mark.parametrize( + "val_err_kwargs", + [ + {}, + {"kind_pp": "bad_kind"}, + {"kind_model": "bad_kind"}, + ], +) +def test_plot_lm_valueerror(multidim_models, val_err_kwargs): + """Test error when plot_dim gets no value for multidim data or kind_... args get a wrong value.""" + idata2 = multidim_models.model_1 + with pytest.raises(ValueError): + plot_lm(idata=idata2, y="y", **val_err_kwargs) + + +@pytest.mark.parametrize( + "warn_kwargs", + [ + {"y_hat": "bad_name"}, + {"y_model": "bad_name"}, + ], +) +def test_plot_lm_warning(models, warn_kwargs): + """Test Warning when needed groups or variables are not there in idata.""" + idata1 = models.model_1 + with pytest.warns(UserWarning): + plot_lm( + idata=from_dict(observed_data={"y": idata1.observed_data["y"].values}), + y="y", + **warn_kwargs, + ) + with pytest.warns(UserWarning): + plot_lm(idata=idata1, y="y", **warn_kwargs) + + +def test_plot_lm_typeerror(models): + """Test error when invalid value passed to num_samples.""" + idata1 = models.model_1 + with pytest.raises(TypeError): + plot_lm(idata=idata1, y="y", num_samples=-1) + + +def test_plot_lm_list(): + """Test the plots when input data is list or ndarray.""" + y = [1, 2, 3, 4, 5] + assert plot_lm(y=y, x=np.arange(len(y)), show=False) diff --git a/doc/source/api/plots.rst b/doc/source/api/plots.rst index ddf6d68c62..f1f68deadc 100644 --- 
a/doc/source/api/plots.rst +++ b/doc/source/api/plots.rst @@ -23,6 +23,7 @@ Plots plot_kde plot_khat plot_loo_pit + plot_lm plot_mcse plot_pair plot_parallel diff --git a/examples/bokeh/bokeh_plot_lm.py b/examples/bokeh/bokeh_plot_lm.py new file mode 100644 index 0000000000..0237b0c6ec --- /dev/null +++ b/examples/bokeh/bokeh_plot_lm.py @@ -0,0 +1,19 @@ +""" +Regression Plot. + +========================================== +_thumb: .6, .5 +""" +import xarray as xr +import numpy as np +import arviz as az + +data = az.load_arviz_data("regression1d") +x = xr.DataArray(np.linspace(0, 1, 100)) +data.add_groups({"constant_data": {"x1": x}}) +data.constant_data["x"] = x +data.posterior["y_model"] = ( + data.posterior["intercept"] + data.posterior["slope"] * data.constant_data["x"] +) + +az.plot_lm(idata=data, y="y", x="x", y_model="y_model", backend="bokeh", figsize=(12, 6)) diff --git a/examples/matplotlib/mpl_plot_lm.py b/examples/matplotlib/mpl_plot_lm.py new file mode 100644 index 0000000000..984f063b51 --- /dev/null +++ b/examples/matplotlib/mpl_plot_lm.py @@ -0,0 +1,24 @@ +""" +Regression Plot. + +=============================== +_thumb: .6, .5 +_example_title: Plot regression +""" +import matplotlib.pyplot as plt +import xarray as xr +import numpy as np +import arviz as az + +az.style.use("arviz-darkgrid") + +data = az.load_arviz_data("regression1d") +x = xr.DataArray(np.linspace(0, 1, 100)) +data.add_groups({"constant_data": {"x1": x}}) +data.constant_data["x"] = x +data.posterior["y_model"] = ( + data.posterior["intercept"] + data.posterior["slope"] * data.constant_data["x"] +) +az.plot_lm(idata=data, y="y", x="x", y_model="y_model", figsize=(12, 6)) + +plt.show()