From 331780061206d1279a06dcc7cfc040d910343290 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 15 Oct 2020 12:57:45 -0300 Subject: [PATCH 1/4] add ref_line, bar, vlines and marker_vlines kwargs --- arviz/plots/backends/bokeh/rankplot.py | 52 ++++++++++++------- arviz/plots/backends/matplotlib/rankplot.py | 48 +++++++++++------ arviz/plots/rankplot.py | 20 +++++++ arviz/tests/base_tests/test_plots_bokeh.py | 43 ++++++--------- .../tests/base_tests/test_plots_matplotlib.py | 36 ++++++------- 5 files changed, 117 insertions(+), 82 deletions(-) diff --git a/arviz/plots/backends/bokeh/rankplot.py b/arviz/plots/backends/bokeh/rankplot.py index d8f0d779d8..3d90da8f8a 100644 --- a/arviz/plots/backends/bokeh/rankplot.py +++ b/arviz/plots/backends/bokeh/rankplot.py @@ -23,17 +23,36 @@ def plot_rank( colors, ref_line, labels, + ref_line_kwargs, + bar_kwargs, + vlines_kwargs, + marker_vlines_kwargs, backend_kwargs, show, ): """Bokeh rank plot.""" + if ref_line_kwargs is None: + ref_line_kwargs = {} + ref_line_kwargs.setdefault("line_dash", "dashed") + ref_line_kwargs.setdefault("line_color", "black") + + if bar_kwargs is None: + bar_kwargs = {} + bar_kwargs.setdefault("line_color", "white") + + if vlines_kwargs is None: + vlines_kwargs = {} + vlines_kwargs.setdefault("line_width", 2) + vlines_kwargs.setdefault("line_dash", "solid") + + if marker_vlines_kwargs is None: + marker_vlines_kwargs = {} + if backend_kwargs is None: backend_kwargs = {} backend_kwargs = { - **backend_kwarg_defaults( - ("dpi", "plot.bokeh.figure.dpi"), - ), + **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi"),), **backend_kwargs, } figsize, *_ = _scale_fig_size(figsize, None, rows=rows, cols=cols) @@ -62,6 +81,7 @@ def plot_rank( gap = 1 width = bin_ary[1] - bin_ary[0] + bar_kwargs.setdefault("width", width) # Center the bins bin_ary = (bin_ary[1:] + bin_ary[:-1]) / 2 @@ -74,34 +94,30 @@ def plot_rank( x=bin_ary, top=y_ticks[-1] + counts, bottom=y_ticks[-1], - width=width, 
fill_color=colors[idx], - line_color="white", + **bar_kwargs, ) if ref_line: - hline = Span( - location=y_ticks[-1] + counts.mean(), line_dash="dashed", line_color="black" - ) + hline = Span(location=y_ticks[-1] + counts.mean(), **ref_line_kwargs) ax.add_layout(hline) if labels: ax.yaxis.axis_label = "Chain" elif kind == "vlines": ymin = np.full(len(all_counts), all_counts.mean()) for idx, counts in enumerate(all_counts): - ax.circle(bin_ary, counts, fill_color=colors[idx], line_color=colors[idx]) - - x_locations = [(bin, bin) for bin in bin_ary] - y_locations = [(ymin[idx], counts_) for counts_ in counts] - ax.multi_line( - x_locations, - y_locations, - line_dash="solid", + ax.circle( + bin_ary, + counts, + fill_color=colors[idx], line_color=colors[idx], - line_width=3, + **marker_vlines_kwargs, ) + x_locations = [(bin, bin) for bin in bin_ary] + y_locations = [(ymin[idx], counts_) for counts_ in counts] + ax.multi_line(x_locations, y_locations, line_color=colors[idx], **vlines_kwargs) if ref_line: - hline = Span(location=all_counts.mean(), line_dash="dashed", line_color="black") + hline = Span(location=all_counts.mean(), **ref_line_kwargs) ax.add_layout(hline) if labels: diff --git a/arviz/plots/backends/matplotlib/rankplot.py b/arviz/plots/backends/matplotlib/rankplot.py index 04db6ed8a0..fee1ec1d85 100644 --- a/arviz/plots/backends/matplotlib/rankplot.py +++ b/arviz/plots/backends/matplotlib/rankplot.py @@ -20,10 +20,32 @@ def plot_rank( colors, ref_line, labels, + ref_line_kwargs, + bar_kwargs, + vlines_kwargs, + marker_vlines_kwargs, backend_kwargs, show, ): """Matplotlib rankplot..""" + if ref_line_kwargs is None: + ref_line_kwargs = {} + ref_line_kwargs.setdefault("linestyle", "--") + ref_line_kwargs.setdefault("color", "k") + + if bar_kwargs is None: + bar_kwargs = {} + bar_kwargs.setdefault("align", "center") + + if vlines_kwargs is None: + vlines_kwargs = {} + vlines_kwargs.setdefault("lw", 2) + + if marker_vlines_kwargs is None: + marker_vlines_kwargs = 
{} + marker_vlines_kwargs.setdefault("marker", "o") + marker_vlines_kwargs.setdefault("lw", 0) + if backend_kwargs is None: backend_kwargs = {} @@ -36,12 +58,7 @@ def plot_rank( backend_kwargs.setdefault("figsize", figsize) backend_kwargs.setdefault("squeeze", True) if axes is None: - _, axes = create_axes_grid( - length_plotters, - rows, - cols, - backend_kwargs=backend_kwargs, - ) + _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs,) for ax, (var_name, selection, var_data) in zip(np.ravel(axes), plotters): ranks = scipy.stats.rankdata(var_data, method="average").reshape(var_data.shape) @@ -52,6 +69,8 @@ def plot_rank( gap = all_counts.max() * 1.05 width = bin_ary[1] - bin_ary[0] + bar_kwargs.setdefault("width", width) + bar_kwargs.setdefault("edgecolor", ax.get_facecolor()) # Center the bins bin_ary = (bin_ary[1:] + bin_ary[:-1]) / 2 @@ -60,26 +79,21 @@ def plot_rank( for idx, counts in enumerate(all_counts): y_ticks.append(idx * gap) ax.bar( - bin_ary, - counts, - bottom=y_ticks[-1], - width=width, - align="center", - color=colors[idx], - edgecolor=ax.get_facecolor(), + bin_ary, counts, bottom=y_ticks[-1], color=colors[idx], **bar_kwargs, ) if ref_line: - ax.axhline(y=y_ticks[-1] + counts.mean(), linestyle="--", color="k") + ax.axhline(y=y_ticks[-1] + counts.mean(), **ref_line_kwargs) if labels: ax.set_ylabel("Chain", fontsize=ax_labelsize) elif kind == "vlines": ymin = all_counts.mean() + for idx, counts in enumerate(all_counts): - ax.plot(bin_ary, counts, "o", color=colors[idx]) - ax.vlines(bin_ary, ymin, counts, lw=2, colors=colors[idx]) + ax.plot(bin_ary, counts, color=colors[idx], **marker_vlines_kwargs) + ax.vlines(bin_ary, ymin, counts, colors=colors[idx], **vlines_kwargs) ax.set_ylim(0, all_counts.mean() * 2) if ref_line: - ax.axhline(y=all_counts.mean(), linestyle="--", color="k") + ax.axhline(y=ymin, **ref_line_kwargs) if labels: ax.set_xlabel("Rank (all chains)", fontsize=ax_labelsize) diff --git 
a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index 2e1dfe7bd6..3d7475567a 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -24,6 +24,10 @@ def plot_rank( figsize=None, ax=None, backend=None, + ref_line_kwargs=None, + bar_kwargs=None, + vlines_kwargs=None, + marker_vlines_kwargs=None, backend_kwargs=None, show=None, ): @@ -80,6 +84,18 @@ def plot_rank( its own array of plot areas (and return it). backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". + ref_line_kwargs : dict, optional + Reference line keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.axhline` or + :class:`bokeh:bokeh.models.Span`. + bar_kwargs : dict, optional + Bars keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.bar` or + :meth:`bokeh:bokeh.plotting.figure.Figure.vbar`. + vlines_kwargs : dict, optional + Vlines keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.vlines` or + :meth:`bokeh:bokeh.plotting.figure.Figure.multi_line`. + marker_vlines_kwargs : dict, optional + Keyword arguments for the vlines markers, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or + :meth:`bokeh:bokeh.plotting.figure.Figure.circle`. backend_kwargs: bool, optional These are kwargs specific to the backend being used. For additional documentation check the plotting method of the backend. 
@@ -161,6 +177,10 @@ def plot_rank( colors=colors, ref_line=ref_line, labels=labels, + ref_line_kwargs=ref_line_kwargs, + bar_kwargs=bar_kwargs, + vlines_kwargs=vlines_kwargs, + marker_vlines_kwargs=marker_vlines_kwargs, backend_kwargs=backend_kwargs, show=show, ) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 4fbb31e4c0..1417380137 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -100,34 +100,16 @@ def test_plot_density_discrete(discrete_model): def test_plot_density_no_subset(): """Test plot_density works when variables are not subset of one another (#1093).""" - model_ab = from_dict( - { - "a": np.random.normal(size=200), - "b": np.random.normal(size=200), - } - ) - model_bc = from_dict( - { - "b": np.random.normal(size=200), - "c": np.random.normal(size=200), - } - ) + model_ab = from_dict({"a": np.random.normal(size=200), "b": np.random.normal(size=200),}) + model_bc = from_dict({"b": np.random.normal(size=200), "c": np.random.normal(size=200),}) axes = plot_density([model_ab, model_bc], backend="bokeh", show=False) assert axes.size == 3 def test_plot_density_one_var(): """Test plot_density works when there is only one variable (#1401).""" - model_ab = from_dict( - { - "a": np.random.normal(size=200), - } - ) - model_bc = from_dict( - { - "a": np.random.normal(size=200), - } - ) + model_ab = from_dict({"a": np.random.normal(size=200),}) + model_bc = from_dict({"a": np.random.normal(size=200),}) axes = plot_density([model_ab, model_bc], backend="bokeh", show=False) assert axes.size == 1 @@ -1023,8 +1005,19 @@ def test_plot_posterior_point_estimates(models, point_estimate): {"var_names": "mu"}, {"var_names": ("mu", "tau"), "coords": {"theta_dim_0": [0, 1]}}, {"var_names": "mu", "ref_line": True}, + { + "var_names": "mu", + "ref_line_kwargs": {"line_width": 2, "line_color": "red"}, + "bar_kwargs": {"width": 50}, + }, {"var_names": "mu", 
"ref_line": False}, {"var_names": "mu", "kind": "vlines"}, + { + "var_names": "mu", + "kind": "vlines", + "vlines_kwargs": {"line_width": 0}, + "marker_vlines_kwargs": {"radius": 20}, + }, ], ) def test_plot_rank(models, kwargs): @@ -1056,9 +1049,5 @@ def test_plot_bpv_discrete(): fake_obs = {"a": np.random.poisson(2.5, 100)} fake_pp = {"a": np.random.poisson(2.5, (1, 10, 100))} fake_model = from_dict(posterior_predictive=fake_pp, observed_data=fake_obs) - axes = plot_bpv( - fake_model, - backend="bokeh", - show=False, - ) + axes = plot_bpv(fake_model, backend="bokeh", show=False,) assert axes.shape diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 2a151519b0..d0ed48c2eb 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -132,18 +132,8 @@ def test_plot_density_discrete(discrete_model): def test_plot_density_no_subset(): """Test plot_density works when variables are not subset of one another (#1093).""" - model_ab = from_dict( - { - "a": np.random.normal(size=200), - "b": np.random.normal(size=200), - } - ) - model_bc = from_dict( - { - "b": np.random.normal(size=200), - "c": np.random.normal(size=200), - } - ) + model_ab = from_dict({"a": np.random.normal(size=200), "b": np.random.normal(size=200),}) + model_bc = from_dict({"b": np.random.normal(size=200), "c": np.random.normal(size=200),}) axes = plot_density([model_ab, model_bc]) assert axes.size == 3 @@ -201,12 +191,10 @@ def test_plot_trace(models, kwargs): @pytest.mark.parametrize( - "compact", - [True, False], + "compact", [True, False], ) @pytest.mark.parametrize( - "combined", - [True, False], + "combined", [True, False], ) def test_plot_trace_legend(compact, combined): idata = load_arviz_data("rugby") @@ -846,8 +834,19 @@ def test_plot_autocorr_var_names(models, var_names): {"var_names": "mu"}, {"var_names": ("mu", "tau"), "coords": {"theta_dim_0": [0, 1]}}, {"var_names": 
"mu", "ref_line": True}, + { + "var_names": "mu", + "ref_line_kwargs": {"lw": 2, "color": "C2"}, + "bar_kwargs": {"width": 0.7}, + }, {"var_names": "mu", "ref_line": False}, {"var_names": "mu", "kind": "vlines"}, + { + "var_names": "mu", + "kind": "vlines", + "vlines_kwargs": {"lw": 0}, + "marker_vlines_kwargs": {"lw": 3}, + }, ], ) def test_plot_rank(models, kwargs): @@ -1379,10 +1378,7 @@ def test_plot_dist_comparison(models, kwargs): def test_plot_dist_comparison_different_vars(): data = from_dict( - posterior={ - "x": np.random.randn(4, 100, 30), - }, - prior={"x_hat": np.random.randn(4, 100, 30)}, + posterior={"x": np.random.randn(4, 100, 30),}, prior={"x_hat": np.random.randn(4, 100, 30)}, ) with pytest.raises(KeyError): plot_dist_comparison(data, var_names="x") From 19f6f1722521c1609e39c4d4e8db926698ca2d9f Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 15 Oct 2020 13:35:26 -0300 Subject: [PATCH 2/4] black --- arviz/plots/backends/bokeh/rankplot.py | 4 ++- arviz/plots/backends/matplotlib/rankplot.py | 13 ++++++-- arviz/tests/base_tests/test_plots_bokeh.py | 32 ++++++++++++++++--- .../tests/base_tests/test_plots_matplotlib.py | 25 ++++++++++++--- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/arviz/plots/backends/bokeh/rankplot.py b/arviz/plots/backends/bokeh/rankplot.py index 3d90da8f8a..9b287bb214 100644 --- a/arviz/plots/backends/bokeh/rankplot.py +++ b/arviz/plots/backends/bokeh/rankplot.py @@ -52,7 +52,9 @@ def plot_rank( backend_kwargs = {} backend_kwargs = { - **backend_kwarg_defaults(("dpi", "plot.bokeh.figure.dpi"),), + **backend_kwarg_defaults( + ("dpi", "plot.bokeh.figure.dpi"), + ), **backend_kwargs, } figsize, *_ = _scale_fig_size(figsize, None, rows=rows, cols=cols) diff --git a/arviz/plots/backends/matplotlib/rankplot.py b/arviz/plots/backends/matplotlib/rankplot.py index fee1ec1d85..363f94e0b1 100644 --- a/arviz/plots/backends/matplotlib/rankplot.py +++ b/arviz/plots/backends/matplotlib/rankplot.py @@ -58,7 +58,12 @@ def 
plot_rank( backend_kwargs.setdefault("figsize", figsize) backend_kwargs.setdefault("squeeze", True) if axes is None: - _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs,) + _, axes = create_axes_grid( + length_plotters, + rows, + cols, + backend_kwargs=backend_kwargs, + ) for ax, (var_name, selection, var_data) in zip(np.ravel(axes), plotters): ranks = scipy.stats.rankdata(var_data, method="average").reshape(var_data.shape) @@ -79,7 +84,11 @@ def plot_rank( for idx, counts in enumerate(all_counts): y_ticks.append(idx * gap) ax.bar( - bin_ary, counts, bottom=y_ticks[-1], color=colors[idx], **bar_kwargs, + bin_ary, + counts, + bottom=y_ticks[-1], + color=colors[idx], + **bar_kwargs, ) if ref_line: ax.axhline(y=y_ticks[-1] + counts.mean(), **ref_line_kwargs) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 1417380137..a05a38db4c 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -100,16 +100,34 @@ def test_plot_density_discrete(discrete_model): def test_plot_density_no_subset(): """Test plot_density works when variables are not subset of one another (#1093).""" - model_ab = from_dict({"a": np.random.normal(size=200), "b": np.random.normal(size=200),}) - model_bc = from_dict({"b": np.random.normal(size=200), "c": np.random.normal(size=200),}) + model_ab = from_dict( + { + "a": np.random.normal(size=200), + "b": np.random.normal(size=200), + } + ) + model_bc = from_dict( + { + "b": np.random.normal(size=200), + "c": np.random.normal(size=200), + } + ) axes = plot_density([model_ab, model_bc], backend="bokeh", show=False) assert axes.size == 3 def test_plot_density_one_var(): """Test plot_density works when there is only one variable (#1401).""" - model_ab = from_dict({"a": np.random.normal(size=200),}) - model_bc = from_dict({"a": np.random.normal(size=200),}) + model_ab = from_dict( + { + "a": np.random.normal(size=200), 
+ } + ) + model_bc = from_dict( + { + "a": np.random.normal(size=200), + } + ) axes = plot_density([model_ab, model_bc], backend="bokeh", show=False) assert axes.size == 1 @@ -1049,5 +1067,9 @@ def test_plot_bpv_discrete(): fake_obs = {"a": np.random.poisson(2.5, 100)} fake_pp = {"a": np.random.poisson(2.5, (1, 10, 100))} fake_model = from_dict(posterior_predictive=fake_pp, observed_data=fake_obs) - axes = plot_bpv(fake_model, backend="bokeh", show=False,) + axes = plot_bpv( + fake_model, + backend="bokeh", + show=False, + ) assert axes.shape diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index d0ed48c2eb..238aba3bad 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -132,8 +132,18 @@ def test_plot_density_discrete(discrete_model): def test_plot_density_no_subset(): """Test plot_density works when variables are not subset of one another (#1093).""" - model_ab = from_dict({"a": np.random.normal(size=200), "b": np.random.normal(size=200),}) - model_bc = from_dict({"b": np.random.normal(size=200), "c": np.random.normal(size=200),}) + model_ab = from_dict( + { + "a": np.random.normal(size=200), + "b": np.random.normal(size=200), + } + ) + model_bc = from_dict( + { + "b": np.random.normal(size=200), + "c": np.random.normal(size=200), + } + ) axes = plot_density([model_ab, model_bc]) assert axes.size == 3 @@ -191,10 +201,12 @@ def test_plot_trace(models, kwargs): @pytest.mark.parametrize( - "compact", [True, False], + "compact", + [True, False], ) @pytest.mark.parametrize( - "combined", [True, False], + "combined", + [True, False], ) def test_plot_trace_legend(compact, combined): idata = load_arviz_data("rugby") @@ -1378,7 +1390,10 @@ def test_plot_dist_comparison(models, kwargs): def test_plot_dist_comparison_different_vars(): data = from_dict( - posterior={"x": np.random.randn(4, 100, 30),}, prior={"x_hat": np.random.randn(4, 100, 30)}, + 
posterior={ + "x": np.random.randn(4, 100, 30), + }, + prior={"x_hat": np.random.randn(4, 100, 30)}, ) with pytest.raises(KeyError): plot_dist_comparison(data, var_names="x") From 8b8299d6c98a4ab489d3ce6ef8d8a02820cf38c8 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 16 Oct 2020 08:16:24 -0300 Subject: [PATCH 3/4] add example --- arviz/plots/rankplot.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index 3d7475567a..aa1ee1f48f 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -137,6 +137,13 @@ def plot_rank( >>> az.plot_rank(centered_data, var_names="mu", kind='vlines', ax=ax[0]) >>> az.plot_rank(noncentered_data, var_names="mu", kind='vlines', ax=ax[1]) + Change the aesthetics using kwargs + + .. plot:: + :context: close-figs + + >>> az.plot_rank(noncentered_data, var_names="mu", kind="vlines", + >>> vlines_kwargs={'lw':0}, marker_vlines_kwargs={'lw':3}); """ if transform is not None: data = transform(data) From 68193c5ca43ad49792b7886891edd42b727d67a3 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 16 Oct 2020 08:20:58 -0300 Subject: [PATCH 4/4] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a64a0fefee..d6ff25adbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * Added `to_dataframe` method to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395)) * Added `__getitem__` magic to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395)) * Added group argument to summary ([1408](https://github.com/arviz-devs/arviz/pull/1408)) +* Added `ref_line_kwargs`, `bar_kwargs`, `vlines_kwargs` and `marker_vlines_kwargs` to `plot_rank` ([1419](https://github.com/arviz-devs/arviz/pull/1419)) ### Maintenance and fixes * prevent wrapping group names in InferenceData repr_html ([1407](https://github.com/arviz-devs/arviz/pull/1407))