Skip to content

Commit

Permalink
Fix poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Oct 1, 2024
1 parent 41c3e27 commit aae6e97
Showing 1 changed file with 11 additions and 22 deletions.
33 changes: 11 additions & 22 deletions examples/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,33 +154,22 @@ def logprob_fn(params):
return gpx.objectives.log_posterior_density(model, D)


# jit compile
logprob_fn = jax.jit(logprob_fn)
_ = logprob_fn(params)
step_size = 1e-3
inverse_mass_matrix = jnp.ones(53)
nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix)

state = nuts.init(params)

adapt = blackjax.window_adaptation(
blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True
)

# Initialise the chain
last_state, kernel, _ = adapt.run(key, params)


def inference_loop(rng_key, kernel, initial_state, num_samples):
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
return state, (state, info)

keys = jax.random.split(rng_key, num_samples)
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10)
step = jax.jit(nuts.step)

return states, infos

def one_step(state, rng_key):
state, info = step(rng_key, state)
return state, (state, info)

# Sample from the posterior distribution
states, infos = inference_loop(key, kernel, last_state, num_samples)

keys = jax.random.split(key, num_samples)
_, (states, infos) = jax.lax.scan(one_step, state, keys, unroll=10)

# %% [markdown]
# ### Sampler efficiency
Expand All @@ -190,7 +179,7 @@ def one_step(state, rng_key):
# proposed sample, divided by the total number of steps run by the chain).

# %%
acceptance_rate = jnp.mean(infos.acceptance_probability)
acceptance_rate = jnp.mean(infos.acceptance_rate)
print(f"Acceptance rate: {acceptance_rate:.2f}")

# %%
Expand Down

0 comments on commit aae6e97

Please # to comment.