From 41c3e27c8c7e1a9e6f52da48670d25d1756f5cfd Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 1 Oct 2024 22:25:29 +0200 Subject: [PATCH 1/2] Fix uncollapsed vi --- examples/uncollapsed_vi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/uncollapsed_vi.py b/examples/uncollapsed_vi.py index f2fb52ae4..dc879919d 100644 --- a/examples/uncollapsed_vi.py +++ b/examples/uncollapsed_vi.py @@ -8,7 +8,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.16.4 +# jupytext_version: 1.11.2 # kernelspec: # display_name: gpjax_beartype # language: python @@ -319,7 +319,6 @@ model=q, objective=lambda p, d: -gpx.objectives.elbo(p, d), train_data=D, - params_bijection=params_bijection, optim=ox.adam(learning_rate=0.01), num_iters=3000, key=jr.key(42), From aae6e97e985d70f3c82623c223922ffcb09b478d Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 1 Oct 2024 22:58:20 +0200 Subject: [PATCH 2/2] Fix poisson --- examples/poisson.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/examples/poisson.py b/examples/poisson.py index 9d2ec4966..0d645d1ef 100644 --- a/examples/poisson.py +++ b/examples/poisson.py @@ -154,33 +154,22 @@ def logprob_fn(params): return gpx.objectives.log_posterior_density(model, D) -# jit compile -logprob_fn = jax.jit(logprob_fn) -_ = logprob_fn(params) +step_size = 1e-3 +inverse_mass_matrix = jnp.ones(53) +nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix) +state = nuts.init(params) -adapt = blackjax.window_adaptation( - blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True -) - -# Initialise the chain -last_state, kernel, _ = adapt.run(key, params) - - -def inference_loop(rng_key, kernel, initial_state, num_samples): - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) - - keys = jax.random.split(rng_key, num_samples) - _, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10) +step = jax.jit(nuts.step) - return states, infos +def one_step(state, rng_key): + state, info = step(rng_key, state) + return state, (state, info) -# Sample from the posterior distribution -states, infos = inference_loop(key, kernel, last_state, num_samples) +keys = jax.random.split(key, num_samples) +_, (states, infos) = jax.lax.scan(one_step, state, keys, unroll=10) # %% [markdown] # ### Sampler efficiency @@ -190,7 +179,7 @@ def one_step(state, rng_key): # proposed sample, divided by the total number of steps run by the chain). # %% -acceptance_rate = jnp.mean(infos.acceptance_probability) +acceptance_rate = jnp.mean(infos.acceptance_rate) print(f"Acceptance rate: {acceptance_rate:.2f}") # %%