Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

enable dimension order in selection #2103

Merged
merged 13 commits into from
Oct 6, 2022
8 changes: 7 additions & 1 deletion arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,13 @@ def plot_trace(
skip_dims = set(coords_data.dims) - {"chain", "draw"} if compact else set()

plotters = list(
xarray_var_iter(coords_data, var_names=var_names, combined=True, skip_dims=skip_dims)
xarray_var_iter(
coords_data,
var_names=var_names,
combined=True,
skip_dims=skip_dims,
dim_order=["chain", "draw"],
)
)
max_plots = rcParams["plot.max_subplots"]
max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1)
Expand Down
44 changes: 16 additions & 28 deletions arviz/sel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,6 @@ def make_label(var_name, selection, position="below"):
return base.format(var_name, sel)


def purge_duplicates(list_in):
"""Remove duplicates from list while preserving order.

Parameters
----------
list_in: Iterable

Returns
-------
list
List of first occurrences in order
"""
# Algorithm taken from Stack Overflow,
# https://stackoverflow.com/questions/480214. Content by Georgy
# Skorobogatov (https://stackoverflow.com/users/7851470/georgy) and
# Markus Jarderot
# (https://stackoverflow.com/users/22364/markus-jarderot), licensed
# under CC-BY-SA 4.0.
# https://creativecommons.org/licenses/by-sa/4.0/.

seen = set()
seen_add = seen.add
return [x for x in list_in if not (x in seen or seen_add(x))]


def _dims(data, var_name, skip_dims):
return [dim for dim in data[var_name].dims if dim not in skip_dims]

Expand Down Expand Up @@ -136,7 +111,7 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
for var_name in var_names:
if var_name in data:
new_dims = _dims(data, var_name, skip_dims)
vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims]
vals = [list(dict.fromkeys(data[var_name][dim].values)) for dim in new_dims]
dims = _zip_dims(new_dims, vals)
idims = _zip_dims(new_dims, [range(len(v)) for v in vals])
if reverse_selections:
Expand All @@ -147,7 +122,9 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
yield var_name, selection, iselection


def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False):
def xarray_var_iter(
data, var_names=None, combined=False, skip_dims=None, reverse_selections=False, dim_order=None
):
"""Convert xarray data to an iterator over vectors.

Iterates over each var_name and all of its coordinates, returning the 1d
Expand All @@ -170,6 +147,9 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
reverse_selections : bool
Whether to reverse selections before iterating.

dim_order: list
Order for the first dimensions. Skips dimensions not found in the variable.

Returns
-------
Iterator of (str, dict(str, any), np.array)
Expand All @@ -180,14 +160,22 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
if var_names is None and isinstance(data, xr.DataArray):
data_to_sel = {data.name: data}

if isinstance(dim_order, str):
dim_order = [dim_order]

for var_name, selection, iselection in xarray_sel_iter(
data,
var_names=var_names,
combined=combined,
skip_dims=skip_dims,
reverse_selections=reverse_selections,
):
yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values
selected_data = data_to_sel[var_name].sel(**selection)
if dim_order is not None:
dim_order_selected = [dim for dim in dim_order if dim in selected_data.dims]
if dim_order_selected:
selected_data = selected_data.transpose(*dim_order, ...)
yield var_name, selection, iselection, selected_data.values


def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None):
Expand Down