[proposal] support for IC with multiple variables #1173

Merged
merged 6 commits on May 16, 2020
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -5,8 +5,6 @@
### New features
* Stats and plotting functions that provide `var_names` arg can now filter parameters based on partial naming (`filter="like"`) or regular expressions (`filter="regex"`) (see [#1154](https://github.com/arviz-devs/arviz/pull/1154)).
* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables (#1140)
-* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090
-* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Allow `xarray.DataArray` input for plots (#1120)
* Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117)
* Skip test for optional/extra dependencies when not installed (#1113)
@@ -19,6 +17,9 @@
* Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171)
* `hdi_prob` will not plot hdi if argument `hide` is passed. Previously `credible_interval` would omit HPD if `None` was passed (#1176)
* `stats.hpd` is pending deprecation. Replaced by `stats.hdi`
+* Add `stats.ic_pointwise` rcParam (#1173)
+* Add `var_name` argument to information criterion calculation: `compare`,
+  `loo` and `waic` (#1173)

### Maintenance and fixes
* Fixed `plot_pair` functionality for two variables with bokeh backend (#1179)
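Taken together, the two #1173 entries above are the heart of this PR: an `InferenceData` object can carry several pointwise log likelihood arrays, and the information criterion functions take a `var_name` to pick one. A minimal sketch with synthetic data (the model and variable names are illustrative, and `compare` accepting `var_name` follows the changelog entry above):

```python
import numpy as np
import arviz as az

def toy_idata(seed):
    rng = np.random.default_rng(seed)
    return az.from_dict(
        posterior={"mu": rng.normal(size=(4, 100))},
        log_likelihood={
            "y": rng.normal(size=(4, 100, 8)),  # one observed variable
            "z": rng.normal(size=(4, 100, 5)),  # and a second one
        },
    )

# var_name selects which log likelihood array the comparison uses
az.compare({"model_a": toy_idata(0), "model_b": toy_idata(1)}, var_name="y")
```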
1 change: 1 addition & 0 deletions arviz/rcparams.py
@@ -242,6 +242,7 @@ def validate_iterable(value):
"plot.matplotlib.show": (False, _validate_boolean),
"stats.hdi_prob": (0.94, _validate_probability),
"stats.information_criterion": ("loo", _make_validate_choice({"waic", "loo"})),
"stats.ic_pointwise": (False, _validate_boolean),
"stats.ic_scale": ("log", _make_validate_choice({"deviance", "log", "negative_log"})),
}
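A quick sketch of what the new rcParam toggles, using a bundled example dataset:

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

az.loo(idata)  # scalar ELPD summary only; pointwise defaults to False
az.rcParams["stats.ic_pointwise"] = True
az.loo(idata)  # now also carries loo_i and pareto_k, without a per-call flag
```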

24 changes: 17 additions & 7 deletions arviz/stats/stats.py
@@ -560,7 +560,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
return np.array(hdi_intervals)


-def loo(data, pointwise=False, reff=None, scale=None):
+def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
@@ -574,7 +574,11 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
Any object that can be converted to an az.InferenceData object. Refer to documentation of
az.convert_to_inference_data for details
pointwise: bool, optional
-If True the pointwise predictive accuracy will be returned. Defaults to False
+If True the pointwise predictive accuracy will be returned. Defaults to
+``stats.ic_pointwise`` rcParam.
+var_name : str, optional
+The name of the variable in log_likelihood groups storing the pointwise log
+likelihood data to use for loo computation.
reff: float, optional
Relative MCMC efficiency, `ess / n` i.e. number of effective samples divided by the number
of actual samples. Computed from trace by default.
@@ -621,7 +625,8 @@
...: data_loo.loo_i
"""
inference_data = convert_to_inference_data(data)
-log_likelihood = _get_log_likelihood(inference_data)
+log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
+pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
shape = log_likelihood.shape
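In practice the two new `loo` code paths look like this — a sketch with synthetic data, where the `y1`/`y2` names are illustrative and mirror the new tests further down:

```python
import numpy as np
import arviz as az

idata = az.from_dict(
    posterior={"mu": np.random.randn(4, 100)},
    log_likelihood={
        "y1": np.random.randn(4, 100, 6),
        "y2": np.random.randn(4, 100, 8),
    },
)

az.loo(idata, var_name="y1")                  # uses the y1 array
az.loo(idata, var_name="y2", pointwise=True)  # per-observation loo_i for y2
# az.loo(idata) would raise: several arrays found and var_name is None
```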
@@ -1306,7 +1311,7 @@ def summary(
return summary_df


-def waic(data, pointwise=False, scale=None):
+def waic(data, pointwise=None, var_name=None, scale=None):
"""Compute the widely applicable information criterion.

Estimates the expected log pointwise predictive density (elpd) using WAIC. Also calculates the
@@ -1317,9 +1322,13 @@ def waic(data, pointwise=None, var_name=None, scale=None):
----------
data: obj
Any object that can be converted to an az.InferenceData object. Refer to documentation of
-az.convert_to_inference_data for details
+``az.convert_to_inference_data`` for details
pointwise: bool
-if True the pointwise predictive accuracy will be returned. Defaults to False
+If True the pointwise predictive accuracy will be returned. Defaults to
+``stats.ic_pointwise`` rcParam.
+var_name : str, optional
+The name of the variable in log_likelihood groups storing the pointwise log
+likelihood data to use for waic computation.
scale: str
Output scale for WAIC. Available options are:

@@ -1361,8 +1370,9 @@
...: data_waic.waic_i
"""
inference_data = convert_to_inference_data(data)
-log_likelihood = _get_log_likelihood(inference_data)
+log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
+pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

if scale == "deviance":
scale_value = -2
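`waic` follows the same pattern; a sketch (only the `log_likelihood` group is needed here, and the variable names are again illustrative):

```python
import numpy as np
import arviz as az

idata = az.from_dict(
    log_likelihood={
        "y1": np.random.randn(4, 100, 6),
        "y2": np.random.randn(4, 100, 8),
    }
)

az.waic(idata, var_name="y2", pointwise=True)  # waic_i per observation of y2
```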
12 changes: 5 additions & 7 deletions arviz/stats/stats_utils.py
@@ -402,16 +402,14 @@ def get_log_likelihood(idata, var_name=None):
"""Retrieve the log likelihood dataarray of a given variable."""
if hasattr(idata, "sample_stats") and hasattr(idata.sample_stats, "log_likelihood"):
warnings.warn(
"Storing the log_likelihood in sample_stats groups will be deprecated",
PendingDeprecationWarning,
"Storing the log_likelihood in sample_stats groups has been deprecated",
DeprecationWarning,
)
return idata.sample_stats.log_likelihood
if not hasattr(idata, "log_likelihood"):
raise TypeError("log likelihood not found in inference data object")
if var_name is None:
var_names = list(idata.log_likelihood.data_vars)
if "lp" in var_names:
var_names.remove("lp")
if len(var_names) > 1:
raise TypeError(
"Found several log likelihood arrays {}, var_name cannot be None".format(var_names)
@@ -440,9 +438,9 @@
(1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}%
"""
SCALE_WARNING_FORMAT = """
-The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if
-you rely on a specific value.
-A higher log-score (or a lower deviance) indicates a model with better predictive
+The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if
+you rely on a specific value.
+A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy."""
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}

7 changes: 4 additions & 3 deletions arviz/tests/base_tests/test_stats.py
@@ -22,6 +22,7 @@
waic,
)
from ...stats.stats import _gpinv
+from ...stats.stats_utils import get_log_likelihood
from ..helpers import check_multiple_attrs, multidim_models # pylint: disable=unused-import

rcParams["data.load"] = "eager"
@@ -433,7 +434,7 @@ def test_loo_print(centered_eight, scale):

def test_psislw(centered_eight):
pareto_k = loo(centered_eight, pointwise=True, reff=0.7)["pareto_k"]
-log_likelihood = centered_eight.sample_stats.log_likelihood # pylint: disable=no-member
+log_likelihood = get_log_likelihood(centered_eight)
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
assert_allclose(pareto_k, psislw(-log_likelihood, 0.7)[1])

@@ -493,7 +494,7 @@ def test_loo_pit(centered_eight, args):
log_weights = args.get("log_weights", None)
y_arr = centered_eight.observed_data.obs
y_hat_arr = centered_eight.posterior_predictive.obs.stack(sample=("chain", "draw"))
-log_like = centered_eight.sample_stats.log_likelihood.stack(sample=("chain", "draw"))
+log_like = get_log_likelihood(centered_eight).stack(sample=("chain", "draw"))
n_samples = len(log_like.sample)
ess_p = ess(centered_eight.posterior, method="mean")
reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
@@ -533,7 +534,7 @@ def test_loo_pit_multidim(multidim_models, args):
idata = multidim_models.model_1
y_arr = idata.observed_data.y
y_hat_arr = idata.posterior_predictive.y.stack(sample=("chain", "draw"))
-log_like = idata.sample_stats.log_likelihood.stack(sample=("chain", "draw"))
+log_like = get_log_likelihood(idata).stack(sample=("chain", "draw"))
n_samples = len(log_like.sample)
ess_p = ess(idata.posterior, method="mean")
reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
41 changes: 40 additions & 1 deletion arviz/tests/base_tests/test_stats_utils.py
@@ -6,7 +6,7 @@
from scipy.special import logsumexp
from scipy.stats import circstd

-from ...data import load_arviz_data
+from ...data import load_arviz_data, from_dict
from ...stats.stats_utils import (
logsumexp as _logsumexp,
make_ufunc,
@@ -19,6 +19,7 @@
_angle,
_circfunc,
_circular_standard_deviation,
+get_log_likelihood,
)


@@ -261,6 +262,44 @@
)


+def test_get_log_likelihood():
+    idata = from_dict(
+        log_likelihood={
+            "y1": np.random.normal(size=(4, 100, 6)),
+            "y2": np.random.normal(size=(4, 100, 8)),
+        }
+    )
+    lik1 = get_log_likelihood(idata, "y1")
+    lik2 = get_log_likelihood(idata, "y2")
+    assert lik1.shape == (4, 100, 6)
+    assert lik2.shape == (4, 100, 8)
+
+
+def test_get_log_likelihood_warning():
+    idata = from_dict(sample_stats={"log_likelihood": np.random.normal(size=(4, 100, 6)),})
+    with pytest.warns(DeprecationWarning):
+        get_log_likelihood(idata)
+
+
+def test_get_log_likelihood_no_var_name():
+    idata = from_dict(
+        log_likelihood={
+            "y1": np.random.normal(size=(4, 100, 6)),
+            "y2": np.random.normal(size=(4, 100, 8)),
+        }
+    )
+    with pytest.raises(TypeError, match="Found several"):
+        get_log_likelihood(idata)
+
+
+def test_get_log_likelihood_no_group():
+    idata = from_dict(
+        posterior={"a": np.random.normal(size=(4, 100)), "b": np.random.normal(size=(4, 100)),}
+    )
+    with pytest.raises(TypeError, match="log likelihood not found"):
+        get_log_likelihood(idata)
+
+
def test_elpd_data_error():
with pytest.raises(ValueError):
ELPDData(data=[0, 1, 2], index=["not IC", "se", "p"]).__repr__()
10 changes: 8 additions & 2 deletions arviz/tests/helpers.py
@@ -61,7 +61,9 @@ def create_model(seed=10):
"energy": np.random.randn(nchains, ndraws),
"diverging": np.random.randn(nchains, ndraws) > 0.90,
"max_depth": np.random.randn(nchains, ndraws) > 0.90,
"log_likelihood": np.random.randn(nchains, ndraws, data["J"]),
}
log_likelihood = {
"y": np.random.randn(nchains, ndraws, data["J"]),
}
prior = {
"mu": np.random.randn(nchains, ndraws) / 2,
@@ -78,6 +80,7 @@
posterior=posterior,
posterior_predictive=posterior_predictive,
sample_stats=sample_stats,
+log_likelihood=log_likelihood,
prior=prior,
prior_predictive=prior_predictive,
sample_stats_prior=sample_stats_prior,
@@ -109,7 +112,9 @@ def create_multidimensional_model(seed=10):
sample_stats = {
"energy": np.random.randn(nchains, ndraws),
"diverging": np.random.randn(nchains, ndraws) > 0.90,
"log_likelihood": np.random.randn(nchains, ndraws, ndim1, ndim2),
}
log_likelihood = {
"y": np.random.randn(nchains, ndraws, ndim1, ndim2),
}
prior = {
"mu": np.random.randn(nchains, ndraws) / 2,
@@ -126,6 +131,7 @@
posterior=posterior,
posterior_predictive=posterior_predictive,
sample_stats=sample_stats,
+log_likelihood=log_likelihood,
prior=prior,
prior_predictive=prior_predictive,
sample_stats_prior=sample_stats_prior,
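The fixture changes above encode the schema this PR standardizes on: pointwise log likelihood lives in its own `log_likelihood` group, keyed by observed-variable name, rather than as a single array inside `sample_stats`. A minimal sketch of the new layout:

```python
import numpy as np
from arviz import from_dict

idata = from_dict(
    posterior={"mu": np.random.randn(4, 500)},
    sample_stats={"energy": np.random.randn(4, 500)},
    log_likelihood={"y": np.random.randn(4, 500, 8)},  # its own group now
)
```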
1 change: 1 addition & 0 deletions arvizrc.template
@@ -42,4 +42,5 @@ plot.matplotlib.show : false # call plt.show. One of "true", "false"
# rcParams related with statistical and diagnostic functions
stats.hdi_prob : 0.94
stats.information_criterion : loo # One of "loo", "waic"
+stats.ic_pointwise : false # One of "true", "false"
stats.ic_scale : log # One of "deviance", "log", "negative_log"