Skip to content

Commit

Permalink
keep shape of observations fixes issue #656 (#657)
Browse files Browse the repository at this point in the history
* keep shape of observations

* ensure wais and loo work with multidim log_likelihood

* fixed pylint variable names in test_stats.py
  • Loading branch information
arabidopsis authored and aloctavodia committed May 10, 2019
1 parent cffcb2b commit 95c9888
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
2 changes: 1 addition & 1 deletion arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def log_likelihood_vals_point(point):
log_like_val = log_like(point)
if var.missing_values:
log_like_val = log_like_val[~var.observations.mask]
log_like_vals.append(log_like_val.ravel())
log_like_vals.append(log_like_val)
return np.concatenate(log_like_vals)

chain_likelihoods = []
Expand Down
6 changes: 3 additions & 3 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,14 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
for group in ("posterior", "sample_stats"):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group}" "group from data!".format(group=group)
"Must be able to extract a {group} group from data!".format(group=group)
)
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
posterior = inference_data.posterior
log_likelihood = inference_data.sample_stats.log_likelihood
n_samples = log_likelihood.chain.size * log_likelihood.draw.size
new_shape = (n_samples,) + log_likelihood.shape[2:]
new_shape = (n_samples, np.product(log_likelihood.shape[2:]))
log_likelihood = log_likelihood.values.reshape(*new_shape)

if scale.lower() == "deviance":
Expand Down Expand Up @@ -1008,7 +1008,7 @@ def waic(data, pointwise=False, scale="deviance"):
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')

n_samples = log_likelihood.chain.size * log_likelihood.draw.size
new_shape = (n_samples,) + log_likelihood.shape[2:]
new_shape = (n_samples, np.product(log_likelihood.shape[2:]))
log_likelihood = log_likelihood.values.reshape(*new_shape)

lppd_i = _logsumexp(log_likelihood, axis=0, b_inv=log_likelihood.shape[0])
Expand Down
29 changes: 28 additions & 1 deletion arviz/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import pytest
from scipy.special import logsumexp
from scipy.stats import linregress
from xarray import Dataset, DataArray


from ..data import load_arviz_data, from_dict
from ..data import load_arviz_data, from_dict, convert_to_inference_data, concat
from ..stats import bfmi, compare, hpd, loo, r2_score, waic, psislw, summary
from ..stats.stats import _gpinv, _mc_error, _logsumexp

Expand Down Expand Up @@ -350,3 +351,29 @@ def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims):
arviz_results = _logsumexp(ary, b_inv=b_inv, axis=axis, keepdims=keepdims)

assert_array_almost_equal(scipy_results, arviz_results)


def test_multidimenional_log_likelihood():
    # NOTE(review): "multidimenional" is a typo in the original test name; kept
    # unchanged so the pytest-collected identifier stays stable.
    # Build the same log-likelihood values in two layouts: a 4-D array with two
    # observation dims ("a", "b") and its flattened 3-D counterpart ("v").
    np.random.seed(17)
    ll_4d = np.random.rand(4, 23, 15, 2)
    ll_flat = ll_4d.reshape(4, 23, 15 * 2)

    stats_multi = Dataset(dict(log_likelihood=DataArray(ll_4d, dims=["chain", "draw", "a", "b"])))
    stats_flat = Dataset(dict(log_likelihood=DataArray(ll_flat, dims=["chain", "draw", "v"])))
    posterior = Dataset(dict(mu=DataArray(np.random.rand(4, 23, 2), dims=["chain", "draw", "v"])))

    # Assemble two InferenceData objects sharing the same posterior group.
    idata_posterior = convert_to_inference_data(posterior, group="posterior")
    idata_multi = concat(idata_posterior, convert_to_inference_data(stats_multi, group="sample_stats"))
    idata_flat = concat(idata_posterior, convert_to_inference_data(stats_flat, group="sample_stats"))

    # loo and waic must give identical results regardless of how the
    # observation dimensions are laid out.
    loo_multi, loo_flat = loo(idata_multi), loo(idata_flat)
    assert (loo_flat == loo_multi).all()
    assert_array_almost_equal(loo_multi[:4], loo_flat[:4])

    waic_multi, waic_flat = waic(idata_multi), waic(idata_flat)
    assert (waic_flat == waic_multi).all()
    assert_array_almost_equal(waic_multi[:4], waic_flat[:4])

0 comments on commit 95c9888

Please sign in to comment.