Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

enable dimension order in selection #2103

Merged
merged 13 commits into from
Oct 6, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Bokeh kde contour plots started to use `contourpy` package ([2104](https://github.com/arviz-devs/arviz/pull/2104))
* Update default Bokeh markers for rcparams ([2104](https://github.com/arviz-devs/arviz/pull/2104))
* Correctly (re)order dimensions for `bfmi` and `plot_energy` ([2126](https://github.com/arviz-devs/arviz/pull/2126))
* Fix bug with the dimension order dependency ([2103](https://github.com/arviz-devs/arviz/pull/2103))

### Deprecation
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))
Expand Down
3 changes: 2 additions & 1 deletion arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def plot_autocorr(
labeller = BaseLabeller()

plotters = filter_plotters_list(
list(xarray_var_iter(data, var_names, combined)), "plot_autocorr"
list(xarray_var_iter(data, var_names, combined, dim_order=["chain", "draw"])),
"plot_autocorr",
)
rows, cols = default_grid(len(plotters), grid=grid)

Expand Down
1 change: 1 addition & 0 deletions arviz/plots/rankplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def plot_rank(
posterior_data,
var_names=var_names,
combined=True,
dim_order=["chain", "draw"],
)
),
"plot_rank",
Expand Down
8 changes: 7 additions & 1 deletion arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,13 @@ def plot_trace(
skip_dims = set(coords_data.dims) - {"chain", "draw"} if compact else set()

plotters = list(
xarray_var_iter(coords_data, var_names=var_names, combined=True, skip_dims=skip_dims)
xarray_var_iter(
coords_data,
var_names=var_names,
combined=True,
skip_dims=skip_dims,
dim_order=["chain", "draw"],
)
)
max_plots = rcParams["plot.max_subplots"]
max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1)
Expand Down
44 changes: 16 additions & 28 deletions arviz/sel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,6 @@ def make_label(var_name, selection, position="below"):
return base.format(var_name, sel)


def purge_duplicates(list_in):
"""Remove duplicates from list while preserving order.

Parameters
----------
list_in: Iterable

Returns
-------
list
List of first occurrences in order
"""
# Algorithm taken from Stack Overflow,
# https://stackoverflow.com/questions/480214. Content by Georgy
# Skorobogatov (https://stackoverflow.com/users/7851470/georgy) and
# Markus Jarderot
# (https://stackoverflow.com/users/22364/markus-jarderot), licensed
# under CC-BY-SA 4.0.
# https://creativecommons.org/licenses/by-sa/4.0/.

seen = set()
seen_add = seen.add
return [x for x in list_in if not (x in seen or seen_add(x))]


def _dims(data, var_name, skip_dims):
return [dim for dim in data[var_name].dims if dim not in skip_dims]

Expand Down Expand Up @@ -136,7 +111,7 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
for var_name in var_names:
if var_name in data:
new_dims = _dims(data, var_name, skip_dims)
vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims]
vals = [list(dict.fromkeys(data[var_name][dim].values)) for dim in new_dims]
dims = _zip_dims(new_dims, vals)
idims = _zip_dims(new_dims, [range(len(v)) for v in vals])
if reverse_selections:
Expand All @@ -147,7 +122,9 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
yield var_name, selection, iselection


def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False):
def xarray_var_iter(
data, var_names=None, combined=False, skip_dims=None, reverse_selections=False, dim_order=None
):
"""Convert xarray data to an iterator over vectors.

Iterates over each var_name and all of its coordinates, returning the 1d
Expand All @@ -170,6 +147,9 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
reverse_selections : bool
Whether to reverse selections before iterating.

dim_order: list
Order for the first dimensions. Skips dimensions not found in the variable.

Returns
-------
Iterator of (str, dict(str, any), np.array)
Expand All @@ -180,14 +160,22 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
if var_names is None and isinstance(data, xr.DataArray):
data_to_sel = {data.name: data}

if isinstance(dim_order, str):
dim_order = [dim_order]

for var_name, selection, iselection in xarray_sel_iter(
data,
var_names=var_names,
combined=combined,
skip_dims=skip_dims,
reverse_selections=reverse_selections,
):
yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values
selected_data = data_to_sel[var_name].sel(**selection)
if dim_order is not None:
dim_order_selected = [dim for dim in dim_order if dim in selected_data.dims]
if dim_order_selected:
selected_data = selected_data.transpose(*dim_order_selected, ...)
yield var_name, selection, iselection, selected_data.values


def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None):
Expand Down
20 changes: 15 additions & 5 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=redefined-outer-name, comparison-with-callable
# pylint: disable=redefined-outer-name, comparison-with-callable, protected-access
"""Test helper functions."""
import gzip
import importlib
Expand Down Expand Up @@ -51,7 +51,7 @@ def chains():
return 2


def create_model(seed=10):
def create_model(seed=10, transpose=False):
"""Create model with fake data."""
np.random.seed(seed)
nchains = 4
Expand Down Expand Up @@ -104,10 +104,15 @@ def create_model(seed=10):
},
coords={"obs_dim": range(data["J"])},
)
if transpose:
for group in model._groups:
group_dataset = getattr(model, group)
if all(dim in group_dataset.dims for dim in ("draw", "chain")):
setattr(model, group, group_dataset.transpose(*["draw", "chain"], ...))
return model


def create_multidimensional_model(seed=10):
def create_multidimensional_model(seed=10, transpose=False):
"""Create model with fake data."""
np.random.seed(seed)
nchains = 4
Expand Down Expand Up @@ -155,6 +160,11 @@ def create_multidimensional_model(seed=10):
dims={"y": ["dim1", "dim2"], "log_likelihood": ["dim1", "dim2"]},
coords={"dim1": range(ndim1), "dim2": range(ndim2)},
)
if transpose:
for group in model._groups:
group_dataset = getattr(model, group)
if all(dim in group_dataset.dims for dim in ("draw", "chain")):
setattr(model, group, group_dataset.transpose(*["draw", "chain"], ...))
return model


Expand Down Expand Up @@ -195,7 +205,7 @@ def models():

class Models:
model_1 = create_model(seed=10)
model_2 = create_model(seed=11)
model_2 = create_model(seed=11, transpose=True)

return Models()

Expand All @@ -207,7 +217,7 @@ def multidim_models():

class Models:
model_1 = create_multidimensional_model(seed=10)
model_2 = create_multidimensional_model(seed=11)
model_2 = create_multidimensional_model(seed=11, transpose=True)

return Models()

Expand Down