Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reloo() appears to ignore k_thresh argument #1820

Closed
derekpowell opened this issue Sep 29, 2021 · 1 comment
Closed

reloo() appears to ignore k_thresh argument #1820

derekpowell opened this issue Sep 29, 2021 · 1 comment

Comments

@derekpowell
Copy link

Describe the bug
The (experimental) reloo() function appears to ignore the k_thresh argument. For a loo object with four Pareto k values above 0.7, two of which are also above 1.0, setting k_thresh=1.0 does not reduce the number of times the model is refit to 2, as would be expected.

To Reproduce
I encountered this while re-creating the numpyro refitting vignette. I've cut/pasted a reprex at the end of this report.

Expected behavior
Expected that setting k_thresh=1.0 would prevent refitting for observations with pareto-k values < 1.0

Additional context
Python 3.8.8
arviz version = '0.11.2'
numpyro version = '0.6.0'

import arviz as az
import numpyro
import numpyro.distributions as dist
import jax.random as random
from numpyro.infer import MCMC, NUTS
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import xarray as xr

# Ask NumPyro for 4 host devices; the sampler below is configured with
# 4 chains and chain_method="parallel".
numpyro.set_host_device_count(4)

# Seed NumPy's global RNG so the synthetic data set is reproducible.
np.random.seed(26)

# Synthetic linear-regression data: 100 evenly spaced x values on [0, 50] and
# y drawn around the line b0 + b1 * x with Gaussian noise of scale sigma.
xdata = np.linspace(0, 50, num=100)
b0 = -2
b1 = 1
sigma = 3
ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)

def model(N, x, y=None):
    """NumPyro model for a simple linear regression of ``y`` on ``x``.

    Sample sites, in order: ``b0`` and ``b1`` (both ``Normal(0, 10)``),
    ``sigma_e`` (``HalfNormal(10)``), and ``y``, a ``Normal`` centred on
    ``b0 + b1 * x`` with scale ``sigma_e``.

    Parameters
    ----------
    N
        Number of observations. Accepted as part of the call signature but
        not read inside the body.
    x
        Predictor values.
    y
        Observed responses, passed as ``obs`` to the ``y`` site. Defaults to
        ``None``.
    """
    intercept = numpyro.sample("b0", dist.Normal(0, 10))
    slope = numpyro.sample("b1", dist.Normal(0, 10))
    noise_scale = numpyro.sample("sigma_e", dist.HalfNormal(10))
    mean = intercept + slope * x
    numpyro.sample("y", dist.Normal(mean, noise_scale), obs=y)

# Keyword arguments forwarded to ``model`` when the sampler runs.
data_dict = dict(N=len(ydata), y=ydata, x=xdata)

kernel = NUTS(model)

# Kept in a dict so the same MCMC configuration can be reused for refits
# (it is handed to the sampling wrapper further down).
sample_kwargs = {
    "sampler": kernel,
    "num_warmup": 1000,
    "num_samples": 1000,
    "num_chains": 4,
    "chain_method": "parallel",
}
mcmc = MCMC(**sample_kwargs)
mcmc.run(random.PRNGKey(0), **data_dict)

# Name the observation dimension "time" for both y and x, and store x as
# constant data so it can be read back from the InferenceData later.
dims = dict(y=["time"], x=["time"])
idata_kwargs = dict(dims=dims, constant_data=dict(x=xdata))
idata = az.from_numpyro(mcmc, **idata_kwargs)

class NumPyroSamplingWrapper(az.SamplingWrapper):
    """``az.SamplingWrapper`` subclass that refits a NumPyro model with MCMC.

    Implements sampling, conversion to InferenceData and log-likelihood
    evaluation on excluded observations. Observation selection
    (``sel_observations``) is not defined here; it is left to
    model-specific subclasses.
    """

    def __init__(self, model, **kwargs):
        """Store the model function and PRNG key, then defer to the base class.

        Parameters
        ----------
        model
            Object exposing ``model.sampler.model`` (here, a fitted ``MCMC``
            instance); that attribute is kept as ``self.model_fun``.
        **kwargs
            ``rng_key`` (optional, defaults to ``random.PRNGKey(0)``) is
            removed and stored on the instance; everything else is forwarded
            to ``az.SamplingWrapper.__init__``.
        """
        self.model_fun = model.sampler.model
        self.rng_key = kwargs.pop("rng_key", random.PRNGKey(0))

        super().__init__(model, **kwargs)

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Evaluate the log likelihood of the excluded data under a refit.

        Parameters
        ----------
        excluded_obs
            Mapping of keyword arguments for the model describing the
            excluded observations.
        idata__i
            Object with a ``posterior`` group whose variables have leading
            ``chain`` and ``draw`` dimensions.

        Returns
        -------
        The log-likelihood variable taken from ``az.dict_to_dataset``, with
        the sample axis reshaped back to ``(chain, draw, ...)``.

        Raises
        ------
        ValueError
            If the model yields more than one likelihood, or none at all.
        """
        posterior = idata__i.posterior
        # Collapse the leading (chain, draw) axes into one sample axis before
        # handing the draws to numpyro.infer.log_likelihood.
        samples = {
            name: var.values.reshape((-1, *var.values.shape[2:]))
            for name, var in posterior.items()
        }
        log_likelihood_dict = numpyro.infer.log_likelihood(
            self.model_fun, samples, **excluded_obs
        )
        if len(log_likelihood_dict) > 1:
            raise ValueError("multiple likelihoods found")
        if not log_likelihood_dict:
            # Fail with a clear message instead of an unbound-name error below.
            raise ValueError("no likelihood found")

        ((obs_name, log_like),) = log_likelihood_dict.items()
        nchains = posterior.sizes["chain"]
        ndraws = posterior.sizes["draw"]
        # Restore the (chain, draw) axes that were flattened above.
        shape = (nchains, ndraws) + log_like.shape[1:]
        data = {obs_name: np.reshape(log_like.copy(), shape)}
        return az.dict_to_dataset(data)[obs_name]

    def sample(self, modified_observed_data):
        """Run a fresh MCMC fit on ``modified_observed_data`` and return it.

        ``self.rng_key`` is split on every call so that successive refits use
        different subkeys. The ``MCMC`` object is built from
        ``self.sample_kwargs``, which this class does not set itself
        (presumably stored by the base-class initializer; confirm against
        ArviZ).
        """
        self.rng_key, subkey = random.split(self.rng_key)
        mcmc = MCMC(**self.sample_kwargs)
        mcmc.run(subkey, **modified_observed_data)
        return mcmc

    def get_inference_data(self, fit):
        """Convert ``fit`` (as returned by ``sample``) with ``az.from_numpyro``.

        The conversion must use the ``fit`` argument: reading a module-level
        ``mcmc`` here would convert the original full-data fit for every
        refit instead of the refitted one.
        """
        # Cloned from PyStanSamplingWrapper.
        idata = az.from_numpyro(fit, **self.idata_kwargs)
        return idata
    
class LinRegWrapper(NumPyroSamplingWrapper):
    """Sampling wrapper for the linear-regression model; adds the data split."""

    def sel_observations(self, idx):
        """Split the original data into a kept part and an excluded part.

        Parameters
        ----------
        idx
            Positions (along the observation axis) of the observations to
            exclude.

        Returns
        -------
        tuple of dict
            ``(data__i, data_ex)``: model keyword arguments (``x``, ``y``,
            ``N``) for the data without ``idx`` and for the data at ``idx``.
        """
        x_all = self.idata_orig.constant_data["x"].values
        y_all = self.idata_orig.observed_data["y"].values
        held_out = np.isin(np.arange(len(x_all)), idx)

        def subset(selector):
            # Build the model keyword arguments for one boolean selection.
            y_sel = y_all[selector]
            return {"x": x_all[selector], "y": y_sel, "N": len(y_sel)}

        return subset(~held_out), subset(held_out)

# Pointwise LOO on the original fit.
loo_orig = az.loo(idata, pointwise=True)

# Overwrite four Pareto k values by hand: two land above 1.0 (1.2 and 2.6)
# and two land between 0.7 and 1.0 (0.8 and 0.9).
loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])

numpyro_wrapper = LinRegWrapper(
    mcmc,
    idata_orig=idata,
    sample_kwargs=sample_kwargs,
    idata_kwargs=idata_kwargs,
    rng_key=random.PRNGKey(5),
)

# model refits 4 times, expected 2 times
loo_relooed = az.reloo(numpyro_wrapper, loo_orig=loo_orig, k_thresh=1.0)
@derekpowell
Copy link
Author

Just saw #1580 and reinstalled dev version, which resolved the problem!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant