add group info to summary #1408

Merged 5 commits on Oct 5, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@
### New features
* Added `to_dataframe` method to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395))
* Added `__getitem__` magic to InferenceData ([1395](https://github.com/arviz-devs/arviz/pull/1395))
* Added group argument to summary ([1408](https://github.com/arviz-devs/arviz/pull/1408))

### Maintenance and fixes

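For reviewers, a minimal usage sketch of the new argument (not part of the diff); it assumes the bundled `centered_eight` example dataset, which ships with both posterior and prior groups:

```python
import arviz as az

# Example InferenceData with several groups; the bundled "centered_eight"
# dataset includes posterior and prior groups.
idata = az.load_arviz_data("centered_eight")

# Existing behaviour: summarize the posterior group by default.
posterior_summary = az.summary(idata)

# New in this PR: select the group explicitly.
prior_summary = az.summary(idata, group="prior")
print(prior_summary.head())
```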
47 changes: 34 additions & 13 deletions arviz/stats/stats.py
@@ -990,6 +990,7 @@ def summary(
data,
var_names: Optional[List[str]] = None,
filter_vars=None,
group=None,
fmt: str = "wide",
kind: str = "all",
round_to=None,
@@ -1021,6 +1022,9 @@
interpret var_names as substrings of the real variable names. If "regex",
interpret var_names as regular expressions on the real variable names. A la
`pandas.filter`.
group: str
Select a group for summary. Defaults to "posterior", "prior", or the first
available group, in that order, depending on which groups exist.
fmt: {'wide', 'long', 'xarray'}
Return format is either pandas.DataFrame {'wide', 'long'} or xarray.Dataset {'xarray'}.
kind: {'all', 'stats', 'diagnostics'}
@@ -1142,9 +1146,26 @@
else:
if not 1 >= hdi_prob > 0:
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
posterior = convert_to_dataset(data, group="posterior", **extra_args)
var_names = _var_names(var_names, posterior, filter_vars)
posterior = posterior if var_names is None else posterior[var_names]

if isinstance(data, InferenceData):
if group is None:
if not data.groups():
raise TypeError("InferenceData does not contain any groups")
if "posterior" in data:
dataset = data["posterior"]
elif "prior" in data:
dataset = data["prior"]
else:
warnings.warn("Selecting first found group: {}".format(data.groups()[0]))
dataset = data[data.groups()[0]]
else:
if group not in data.groups():
raise TypeError(f"InferenceData does not contain group: {group}")
dataset = data[group]
else:
dataset = convert_to_dataset(data, group="posterior", **extra_args)
var_names = _var_names(var_names, dataset, filter_vars)
dataset = dataset if var_names is None else dataset[var_names]

fmt_group = ("wide", "long", "xarray")
if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
@@ -1166,33 +1187,33 @@
for stat_func_name, stat_func in stat_funcs.items():
extra_metrics.append(
xr.apply_ufunc(
_make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw"),)
_make_ufunc(stat_func), dataset, input_core_dims=(("chain", "draw"),)
)
)
extra_metric_names.append(stat_func_name)
else:
for stat_func in stat_funcs:
extra_metrics.append(
xr.apply_ufunc(
_make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw"),)
_make_ufunc(stat_func), dataset, input_core_dims=(("chain", "draw"),)
)
)
extra_metric_names.append(stat_func.__name__)

if extend and kind in ["all", "stats"]:
mean = posterior.mean(dim=("chain", "draw"), skipna=skipna)
mean = dataset.mean(dim=("chain", "draw"), skipna=skipna)

sd = posterior.std(dim=("chain", "draw"), ddof=1, skipna=skipna)
sd = dataset.std(dim=("chain", "draw"), ddof=1, skipna=skipna)

hdi_post = hdi(posterior, hdi_prob=hdi_prob, multimodal=False, skipna=skipna)
hdi_post = hdi(dataset, hdi_prob=hdi_prob, multimodal=False, skipna=skipna)
hdi_lower = hdi_post.sel(hdi="lower", drop=True)
hdi_higher = hdi_post.sel(hdi="higher", drop=True)

if circ_var_names:
nan_policy = "omit" if skipna else "propagate"
circ_mean = xr.apply_ufunc(
_make_ufunc(st.circmean),
posterior,
dataset,
kwargs=dict(high=np.pi, low=-np.pi, nan_policy=nan_policy),
input_core_dims=(("chain", "draw"),),
)
@@ -1206,26 +1227,26 @@
kwargs_circ_std = dict(high=np.pi, low=-np.pi, nan_policy=nan_policy)
circ_sd = xr.apply_ufunc(
_make_ufunc(func),
posterior,
dataset,
kwargs=kwargs_circ_std,
input_core_dims=(("chain", "draw"),),
)

circ_mcse = xr.apply_ufunc(
_make_ufunc(_mc_error),
posterior,
dataset,
kwargs=dict(circular=True),
input_core_dims=(("chain", "draw"),),
)

circ_hdi = hdi(posterior, hdi_prob=hdi_prob, circular=True, skipna=skipna)
circ_hdi = hdi(dataset, hdi_prob=hdi_prob, circular=True, skipna=skipna)
circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True)
circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True)

if kind in ["all", "diagnostics"]:
mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc(
_make_ufunc(_multichain_statistics, n_output=7, ravel=False),
posterior,
dataset,
input_core_dims=(("chain", "draw"),),
output_core_dims=tuple([] for _ in range(7)),
)
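For review purposes, the group-selection logic added above can be read as the following standalone sketch. The helper name `_select_group_dataset` is illustrative only; in the diff the logic is inlined in `summary`, and `extra_args` is built earlier in that function:

```python
import warnings

from arviz import InferenceData, convert_to_dataset


def _select_group_dataset(data, group=None, **extra_args):
    """Illustrative restatement of the group fallback added in this PR."""
    if not isinstance(data, InferenceData):
        # Non-InferenceData inputs keep the old behaviour: coerce to a posterior dataset.
        return convert_to_dataset(data, group="posterior", **extra_args)
    if group is not None:
        if group not in data.groups():
            raise TypeError(f"InferenceData does not contain group: {group}")
        return data[group]
    if not data.groups():
        raise TypeError("InferenceData does not contain any groups")
    if "posterior" in data:
        return data["posterior"]
    if "prior" in data:
        return data["prior"]
    # Last resort: fall back to the first available group and warn about it.
    warnings.warn("Selecting first found group: {}".format(data.groups()[0]))
    return data[data.groups()[0]]
```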
34 changes: 34 additions & 0 deletions arviz/tests/base_tests/test_stats.py
@@ -192,13 +192,47 @@ def test_compare_different_size(centered_eight, non_centered_eight):
compare(model_dict, ic="waic", method="stacking")


def test_summary_ndarray():
array = np.random.randn(4, 100, 2)
summary_df = summary(array)
assert summary_df.shape


@pytest.mark.parametrize("var_names_expected", ((None, 10), ("mu", 1), (["mu", "tau"], 2)))
def test_summary_var_names(centered_eight, var_names_expected):
var_names, expected = var_names_expected
summary_df = summary(centered_eight, var_names=var_names)
assert len(summary_df.index) == expected


@pytest.mark.parametrize("missing_groups", (None, "posterior", "prior"))
def test_summary_groups(centered_eight, missing_groups):
if missing_groups == "posterior":
centered_eight = deepcopy(centered_eight)
del centered_eight.posterior
elif missing_groups == "prior":
centered_eight = deepcopy(centered_eight)
del centered_eight.posterior
del centered_eight.prior
if missing_groups == "prior":
with pytest.warns(UserWarning):
summary_df = summary(centered_eight)
else:
summary_df = summary(centered_eight)
assert summary_df.shape


def test_summary_group_argument(centered_eight):
summary_df_posterior = summary(centered_eight, group="posterior")
summary_df_prior = summary(centered_eight, group="prior")
assert list(summary_df_posterior.index) != list(summary_df_prior.index)


def test_summary_wrong_group(centered_eight):
with pytest.raises(TypeError, match=r"InferenceData does not contain group: InvalidGroup"):
summary(centered_eight, group="InvalidGroup")


METRICS_NAMES = [
"mean",
"sd",
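To run the new tests locally, something along these lines should work (the path and `-k` filter are assumptions about a standard repository checkout, not part of the diff):

```python
# Run only the summary-related tests from the file changed above.
# Assumes pytest is installed and this is executed from the repository root.
import pytest

pytest.main(["arviz/tests/base_tests/test_stats.py", "-k", "summary", "-v"])
```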