diff --git a/CHANGELOG.md b/CHANGELOG.md
index 55d529b06d..59b8e59ec6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,8 +5,6 @@
 ### New features
 * Stats and plotting functions that provide `var_names` arg can now filter parameters based on partial naming (`filter="like"`) or regular expressions (`filter="regex"`) (see [#1154](https://github.com/arviz-devs/arviz/pull/1154)).
 * Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables #1140
-* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090
-* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
 * Allow xarray.DataArray input for plots (#1120)
 * Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117)
 * Skip test for optional/extra dependencies when not installed (#1113)
@@ -19,6 +17,9 @@
 * Add warmup groups to InferenceData objects, initial support for PyStan (#1126) and PyMC3 (#1171)
 * `hdi_prob` will not plot hdi if argument `hide` is passed. Previously `credible_interval` would omit HPD if `None` was passed (#1176)
 * `stats.hpd` is pending deprecation. Replaced by `stats.hdi`
+* Add `stats.ic_pointwise` rcParam (#1173)
+* Add `var_name` argument to information criterion calculation: `compare`,
+  `loo` and `waic` (#1173)
 
 ### Maintenance and fixes
 * Fixed `plot_pair` functionality for two variables with bokeh backend (#1179)
diff --git a/arviz/rcparams.py b/arviz/rcparams.py
index fe2e5a3075..2c80311687 100644
--- a/arviz/rcparams.py
+++ b/arviz/rcparams.py
@@ -242,6 +242,7 @@ def validate_iterable(value):
     "plot.matplotlib.show": (False, _validate_boolean),
     "stats.hdi_prob": (0.94, _validate_probability),
     "stats.information_criterion": ("loo", _make_validate_choice({"waic", "loo"})),
+    "stats.ic_pointwise": (False, _validate_boolean),
     "stats.ic_scale": ("log", _make_validate_choice({"deviance", "log", "negative_log"})),
 }
diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py
index c34c1582ce..3ced67f431 100644
--- a/arviz/stats/stats.py
+++ b/arviz/stats/stats.py
@@ -560,7 +560,7 @@ def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
     return np.array(hdi_intervals)
 
 
-def loo(data, pointwise=False, reff=None, scale=None):
+def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
     """Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
 
     Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
@@ -574,7 +574,11 @@ def loo(data, pointwise=False, reff=None, scale=None):
         Any object that can be converted to an az.InferenceData object. Refer to documentation of
         az.convert_to_inference_data for details
     pointwise: bool, optional
-        If True the pointwise predictive accuracy will be returned. Defaults to False
+        If True, the pointwise predictive accuracy will be returned. Defaults to the
+        ``stats.ic_pointwise`` rcParam.
+    var_name: str, optional
+        The name of the variable in the log_likelihood group storing the pointwise log
+        likelihood data to use for loo computation.
     reff: float, optional
         Relative MCMC efficiency, `ess / n` i.e. number of effective samples divided by the number
         of actual samples. Computed from trace by default.
@@ -621,7 +625,8 @@ def loo(data, pointwise=False, reff=None, scale=None):
            ...: data_loo.loo_i
     """
     inference_data = convert_to_inference_data(data)
-    log_likelihood = _get_log_likelihood(inference_data)
+    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
+    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
 
     log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
     shape = log_likelihood.shape
@@ -1306,7 +1311,7 @@ def summary(
 
     return summary_df
 
 
-def waic(data, pointwise=False, scale=None):
+def waic(data, pointwise=None, var_name=None, scale=None):
     """Compute the widely applicable information criterion.
 
     Estimates the expected log pointwise predictive density (elpd) using WAIC. Also calculates the
@@ -1317,9 +1322,13 @@ def waic(data, pointwise=False, scale=None):
     Parameters
     ----------
     data: obj
         Any object that can be converted to an az.InferenceData object. Refer to documentation of
-        az.convert_to_inference_data for details
+        ``az.convert_to_inference_data`` for details
     pointwise: bool
-        if True the pointwise predictive accuracy will be returned. Defaults to False
+        If True, the pointwise predictive accuracy will be returned. Defaults to the
+        ``stats.ic_pointwise`` rcParam.
+    var_name: str, optional
+        The name of the variable in the log_likelihood group storing the pointwise log
+        likelihood data to use for waic computation.
     scale: str
         Output scale for WAIC. Available options are:
@@ -1361,8 +1370,9 @@ def waic(data, pointwise=False, scale=None):
            ...: data_waic.waic_i
     """
     inference_data = convert_to_inference_data(data)
-    log_likelihood = _get_log_likelihood(inference_data)
+    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
     scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
+    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
 
     if scale == "deviance":
         scale_value = -2
diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py
index 7239ce8c65..a74f1ecefd 100644
--- a/arviz/stats/stats_utils.py
+++ b/arviz/stats/stats_utils.py
@@ -402,16 +402,14 @@ def get_log_likelihood(idata, var_name=None):
     """Retrieve the log likelihood dataarray of a given variable."""
     if hasattr(idata, "sample_stats") and hasattr(idata.sample_stats, "log_likelihood"):
         warnings.warn(
-            "Storing the log_likelihood in sample_stats groups will be deprecated",
-            PendingDeprecationWarning,
+            "Storing the log_likelihood in sample_stats groups has been deprecated",
+            DeprecationWarning,
         )
         return idata.sample_stats.log_likelihood
     if not hasattr(idata, "log_likelihood"):
         raise TypeError("log likelihood not found in inference data object")
     if var_name is None:
         var_names = list(idata.log_likelihood.data_vars)
-        if "lp" in var_names:
-            var_names.remove("lp")
         if len(var_names) > 1:
             raise TypeError(
                 "Found several log likelihood arrays {}, var_name cannot be None".format(var_names)
@@ -440,9 +438,9 @@
 (1, Inf)   (very bad) {{5:{0}d}} {{9:6.1f}}%
         """
 SCALE_WARNING_FORMAT = """
-The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if 
-you rely on a specific value. 
-A higher log-score (or a lower deviance) indicates a model with better predictive 
+The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if
+you rely on a specific value.
+A higher log-score (or a lower deviance) indicates a model with better predictive
 accuracy."""
 
 SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py
index 626d024874..8a7112d434 100644
--- a/arviz/tests/base_tests/test_stats.py
+++ b/arviz/tests/base_tests/test_stats.py
@@ -22,6 +22,7 @@
     waic,
 )
 from ...stats.stats import _gpinv
+from ...stats.stats_utils import get_log_likelihood
 from ..helpers import check_multiple_attrs, multidim_models  # pylint: disable=unused-import
 
 rcParams["data.load"] = "eager"
@@ -433,7 +434,7 @@ def test_loo_print(centered_eight, scale):
 
 def test_psislw(centered_eight):
     pareto_k = loo(centered_eight, pointwise=True, reff=0.7)["pareto_k"]
-    log_likelihood = centered_eight.sample_stats.log_likelihood  # pylint: disable=no-member
+    log_likelihood = get_log_likelihood(centered_eight)
     log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
     assert_allclose(pareto_k, psislw(-log_likelihood, 0.7)[1])
 
@@ -493,7 +494,7 @@ def test_loo_pit(centered_eight, args):
     log_weights = args.get("log_weights", None)
     y_arr = centered_eight.observed_data.obs
     y_hat_arr = centered_eight.posterior_predictive.obs.stack(sample=("chain", "draw"))
-    log_like = centered_eight.sample_stats.log_likelihood.stack(sample=("chain", "draw"))
+    log_like = get_log_likelihood(centered_eight).stack(sample=("chain", "draw"))
     n_samples = len(log_like.sample)
     ess_p = ess(centered_eight.posterior, method="mean")
     reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
@@ -533,7 +534,7 @@ def test_loo_pit_multidim(multidim_models, args):
     idata = multidim_models.model_1
     y_arr = idata.observed_data.y
     y_hat_arr = idata.posterior_predictive.y.stack(sample=("chain", "draw"))
-    log_like = idata.sample_stats.log_likelihood.stack(sample=("chain", "draw"))
+    log_like = get_log_likelihood(idata).stack(sample=("chain", "draw"))
     n_samples = len(log_like.sample)
     ess_p = ess(idata.posterior, method="mean")
     reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
diff --git a/arviz/tests/base_tests/test_stats_utils.py b/arviz/tests/base_tests/test_stats_utils.py
index 0d68846df3..09aabbbcae 100644
--- a/arviz/tests/base_tests/test_stats_utils.py
+++ b/arviz/tests/base_tests/test_stats_utils.py
@@ -6,7 +6,7 @@
 from scipy.special import logsumexp
 from scipy.stats import circstd
 
-from ...data import load_arviz_data
+from ...data import load_arviz_data, from_dict
 from ...stats.stats_utils import (
     logsumexp as _logsumexp,
     make_ufunc,
@@ -19,6 +19,7 @@
     _angle,
     _circfunc,
     _circular_standard_deviation,
+    get_log_likelihood,
 )
 
 
@@ -261,6 +262,44 @@ def test_valid_shape():
         )
 
 
+def test_get_log_likelihood():
+    idata = from_dict(
+        log_likelihood={
+            "y1": np.random.normal(size=(4, 100, 6)),
+            "y2": np.random.normal(size=(4, 100, 8)),
+        }
+    )
+    lik1 = get_log_likelihood(idata, "y1")
+    lik2 = get_log_likelihood(idata, "y2")
+    assert lik1.shape == (4, 100, 6)
+    assert lik2.shape == (4, 100, 8)
+
+
+def test_get_log_likelihood_warning():
+    idata = from_dict(sample_stats={"log_likelihood": np.random.normal(size=(4, 100, 6))})
+    with pytest.warns(DeprecationWarning):
+        get_log_likelihood(idata)
+
+
+def test_get_log_likelihood_no_var_name():
+    idata = from_dict(
+        log_likelihood={
+            "y1": np.random.normal(size=(4, 100, 6)),
+            "y2": np.random.normal(size=(4, 100, 8)),
+        }
+    )
+    with pytest.raises(TypeError, match="Found several"):
+        get_log_likelihood(idata)
+
+
+def test_get_log_likelihood_no_group():
+    idata = from_dict(
+        posterior={"a": np.random.normal(size=(4, 100)), "b": np.random.normal(size=(4, 100))}
+    )
+    with pytest.raises(TypeError, match="log likelihood not found"):
+        get_log_likelihood(idata)
+
+
 def test_elpd_data_error():
     with pytest.raises(ValueError):
         ELPDData(data=[0, 1, 2], index=["not IC", "se", "p"]).__repr__()
diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py
index c66aa04966..ec2afdb400 100644
--- a/arviz/tests/helpers.py
+++ b/arviz/tests/helpers.py
@@ -61,7 +61,9 @@ def create_model(seed=10):
         "energy": np.random.randn(nchains, ndraws),
         "diverging": np.random.randn(nchains, ndraws) > 0.90,
         "max_depth": np.random.randn(nchains, ndraws) > 0.90,
-        "log_likelihood": np.random.randn(nchains, ndraws, data["J"]),
+    }
+    log_likelihood = {
+        "y": np.random.randn(nchains, ndraws, data["J"]),
     }
     prior = {
         "mu": np.random.randn(nchains, ndraws) / 2,
@@ -78,6 +80,7 @@ def create_model(seed=10):
         posterior=posterior,
         posterior_predictive=posterior_predictive,
         sample_stats=sample_stats,
+        log_likelihood=log_likelihood,
         prior=prior,
         prior_predictive=prior_predictive,
         sample_stats_prior=sample_stats_prior,
@@ -109,7 +112,9 @@ def create_multidimensional_model(seed=10):
     sample_stats = {
         "energy": np.random.randn(nchains, ndraws),
         "diverging": np.random.randn(nchains, ndraws) > 0.90,
-        "log_likelihood": np.random.randn(nchains, ndraws, ndim1, ndim2),
+    }
+    log_likelihood = {
+        "y": np.random.randn(nchains, ndraws, ndim1, ndim2),
     }
     prior = {
         "mu": np.random.randn(nchains, ndraws) / 2,
@@ -126,6 +131,7 @@ def create_multidimensional_model(seed=10):
         posterior=posterior,
         posterior_predictive=posterior_predictive,
         sample_stats=sample_stats,
+        log_likelihood=log_likelihood,
         prior=prior,
         prior_predictive=prior_predictive,
         sample_stats_prior=sample_stats_prior,
diff --git a/arvizrc.template b/arvizrc.template
index 7b72f0e25d..54829ebe39 100644
--- a/arvizrc.template
+++ b/arvizrc.template
@@ -42,4 +42,5 @@ plot.matplotlib.show : false    # call plt.show. One of "true", "fals
 # rcParams related with statistical and diagnostic functions
 stats.hdi_prob : 0.94
 stats.information_criterion : loo    # One of "loo", "waic"
+stats.ic_pointwise : false    # One of "true", "false"
 stats.ic_scale : log    # One of "deviance", "log", "negative_log"
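
Usage note (not part of the patch): a minimal sketch of how the new `var_name` argument and the `stats.ic_pointwise` rcParam interact. The variable names `y1`/`y2`, the posterior variable `mu`, and the array sizes are synthetic, mirroring the new tests in `test_stats_utils.py`.

```python
import numpy as np
import arviz as az

# Synthetic InferenceData with two observed variables, each storing its own
# pointwise log likelihood array in the log_likelihood group.
idata = az.from_dict(
    posterior={"mu": np.random.randn(4, 100)},
    log_likelihood={
        "y1": np.random.randn(4, 100, 6),
        "y2": np.random.randn(4, 100, 8),
    },
)

# With several log likelihood arrays, loo/waic raise
# "Found several log likelihood arrays" unless var_name selects one.
loo_y1 = az.loo(idata, var_name="y1")
waic_y2 = az.waic(idata, var_name="y2")

# pointwise now defaults to None, which defers to the rcParam, so pointwise
# results (loo_i / waic_i) can be switched on globally instead of per call.
az.rcParams["stats.ic_pointwise"] = True
assert hasattr(az.loo(idata, var_name="y1"), "loo_i")
```

Setting `stats.ic_pointwise : true` in an arvizrc file should have the same effect as the runtime assignment above.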