Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

validate filter_vars #1772

Merged
merged 1 commit into from
Aug 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
* Fixed xarray related tests. ([1726](https://github.com/arviz-devs/arviz/pull/1726))
* Fix Bokeh deprecation warnings ([1657](https://github.com/arviz-devs/arviz/pull/1657))
* Fix credible inteval percentage in legend in `plot_loo_pit` ([1745](https://github.com/arviz-devs/arviz/pull/1745))
* Arguments `filter_vars` and `filter_groups` now raise `ValueError` if illegal arguments are passed ([1772](https://github.com/arviz-devs/arviz/pull/1772))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
5 changes: 5 additions & 0 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,11 @@ def _group_names(
-------
groups: list
"""
if filter_groups not in {None, "like", "regex"}:
raise ValueError(
f"'filter_groups' can only be None, 'like', or 'regex', got: '{filter_groups}'"
)

all_groups = self._groups_all
if groups is None:
return all_groups
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,15 @@ def test_group_names(self, args_res):
group_names = idata._group_names(*args) # pylint: disable=protected-access
assert np.all([name in result for name in group_names])

def test_group_names_invalid_args(self):
ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
idata = InferenceData(posterior=(ds, ds))
msg = r"^\'filter_groups\' can only be None, \'like\', or \'regex\', got: 'foo'$"
with pytest.raises(ValueError, match=msg):
idata._group_names( # pylint: disable=protected-access
("posterior",), filter_groups="foo"
)

@pytest.mark.parametrize("inplace", [False, True])
def test_isel(self, data_random, inplace):
idata = data_random
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def test_var_names_filter(var_args):
assert _var_names(var_names, data, filter_vars) == expected


def test_var_names_filter_invalid_argument():
"""Check invalid argument raises."""
samples = np.random.randn(10)
data = dict_to_dataset({"alpha": samples})
msg = r"^\'filter_vars\' can only be None, \'like\', or \'regex\', got: 'foo'$"
with pytest.raises(ValueError, match=msg):
assert _var_names(["alpha"], data, filter_vars="foo")


def test_subset_list_negation_not_found():
"""Check there is a warning if negation pattern is ignored"""
names = ["mu", "theta"]
Expand Down
5 changes: 5 additions & 0 deletions arviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def _var_names(var_names, data, filter_vars=None):
-------
var_name: list or None
"""
if filter_vars not in {None, "like", "regex"}:
raise ValueError(
f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
)

if var_names is not None:
if isinstance(data, (list, tuple)):
all_vars = []
Expand Down