Skip to content

Commit

Permalink
extend 'right' join bug fix (#1718)
Browse files Browse the repository at this point in the history
* extend bug fix

* Update inference_data.py

* Update logic

* Add logic to keep the correct order

* Update logic for add_groups

* Fix long lines
  • Loading branch information
mjhajharia authored Jan 16, 2022
1 parent b54e022 commit 08929a4
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,9 +1379,29 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
if dataset:
setattr(self, group, dataset)
if group.startswith(WARMUP_TAG):
self._groups_warmup.append(group)
supported_order = [
key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup
]
if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL):
group_order = [
key
for key in SUPPORTED_GROUPS_ALL
if key in self._groups_warmup + [group]
]
group_idx = group_order.index(group)
self._groups_warmup.insert(group_idx, group)
else:
self._groups_warmup.append(group)
else:
self._groups.append(group)
supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups]
if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL):
group_order = [
key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]
]
group_idx = group_order.index(group)
self._groups.insert(group_idx, group)
else:
self._groups.append(group)

def extend(self, other, join="left"):
"""Extend InferenceData with groups from another InferenceData.
Expand Down Expand Up @@ -1416,9 +1436,31 @@ def extend(self, other, join="left"):
dataset = getattr(other, group)
setattr(self, group, dataset)
if group.startswith(WARMUP_TAG):
self._groups_warmup.append(group)
if group not in self._groups_warmup:
supported_order = [
key for key in SUPPORTED_GROUPS_ALL if key in self._groups_warmup
]
if (supported_order == self._groups_warmup) and (group in SUPPORTED_GROUPS_ALL):
group_order = [
key
for key in SUPPORTED_GROUPS_ALL
if key in self._groups_warmup + [group]
]
group_idx = group_order.index(group)
self._groups_warmup.insert(group_idx, group)
else:
self._groups_warmup.append(group)
else:
self._groups.append(group)
if group not in self._groups:
supported_order = [key for key in SUPPORTED_GROUPS_ALL if key in self._groups]
if (supported_order == self._groups) and (group in SUPPORTED_GROUPS_ALL):
group_order = [
key for key in SUPPORTED_GROUPS_ALL if key in self._groups + [group]
]
group_idx = group_order.index(group)
self._groups.insert(group_idx, group)
else:
self._groups.append(group)

set_index = _extend_xr_method(xr.Dataset.set_index, see_also="reset_index")
get_index = _extend_xr_method(xr.Dataset.get_index)
Expand Down

0 comments on commit 08929a4

Please # to comment.