From aae6e97e985d70f3c82623c223922ffcb09b478d Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 1 Oct 2024 22:58:20 +0200 Subject: [PATCH] Fix poisson --- examples/poisson.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/poisson.py b/examples/poisson.py index 9d2ec4966..0d645d1ef 100644 --- a/examples/poisson.py +++ b/examples/poisson.py @@ -154,33 +154,22 @@ def logprob_fn(params): return gpx.objectives.log_posterior_density(model, D) -# jit compile -logprob_fn = jax.jit(logprob_fn) -_ = logprob_fn(params) +step_size = 1e-3 +inverse_mass_matrix = jnp.ones(53) +nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix) +state = nuts.init(params) -adapt = blackjax.window_adaptation( - blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True -) - -# Initialise the chain -last_state, kernel, _ = adapt.run(key, params) - - -def inference_loop(rng_key, kernel, initial_state, num_samples): - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) - - keys = jax.random.split(rng_key, num_samples) - _, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10) +step = jax.jit(nuts.step) - return states, infos +def one_step(state, rng_key): + state, info = step(rng_key, state) + return state, (state, info) -# Sample from the posterior distribution -states, infos = inference_loop(key, kernel, last_state, num_samples) +keys = jax.random.split(key, num_samples) +_, (states, infos) = jax.lax.scan(one_step, state, keys, unroll=10) # %% [markdown] # ### Sampler efficiency @@ -190,7 +179,7 @@ def one_step(state, rng_key): # proposed sample, divided by the total number of steps run by the chain). # %% -acceptance_rate = jnp.mean(infos.acceptance_probability) +acceptance_rate = jnp.mean(infos.acceptance_rate) print(f"Acceptance rate: {acceptance_rate:.2f}") # %%