From 2dfdb27fa2baf7d307d59ece9e86544e0f1e8500 Mon Sep 17 00:00:00 2001 From: phosgene89 <34409398+phosgene89@users.noreply.github.com> Date: Mon, 10 Jun 2019 09:42:02 +0930 Subject: [PATCH] fix broken cell Code was broken before, this enables the cell to be run. Fixed some typos and added "None" to the function "make_simple_step_size_update_policy()" for the argument "num_adaption_steps", which otherwise returned the error: make_simple_step_size_update_policy() missing 1 required positional argument: 'num_adaptation_steps'. As per documentation, None may not be the best choice. --- Chapter1_Introduction/Ch1_Introduction_TFP.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Chapter1_Introduction/Ch1_Introduction_TFP.ipynb b/Chapter1_Introduction/Ch1_Introduction_TFP.ipynb index 07aede9a..86fa6507 100644 --- a/Chapter1_Introduction/Ch1_Introduction_TFP.ipynb +++ b/Chapter1_Introduction/Ch1_Introduction_TFP.ipynb @@ -1208,7 +1208,7 @@ "# Set the chain's start state.\n", "initial_chain_state = [\n", " tf.cast(tf.reduce_mean(count_data), tf.float32) * tf.ones([], dtype=tf.float32, name=\"init_lambda1\"),\n", - " tf.cast(tf.reduce_mean(count_data), tf.float32) * tf.ones([], dtype=tf.float32, name=\"init_lambda2\", tf.float32),\n", + " tf.cast(tf.reduce_mean(count_data), tf.float32) * tf.ones([], dtype=tf.float32, name=\"init_lambda2\"),\n", " 0.5 * tf.ones([], dtype=tf.float32, name=\"init_tau\"),\n", "]\n", "\n", @@ -1273,11 +1273,11 @@ " target_log_prob_fn=unnormalized_log_posterior,\n", " num_leapfrog_steps=2,\n", " step_size=step_size,\n", - " step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),\n", + " step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(None),\n", " state_gradients_are_stopped=True),\n", " bijector=unconstraining_bijectors))\n", "\n", - "tau_samples = tf.floor(posterior_tau * tf.cast(tf.size(count_data)), tf.float32)\n", + "tau_samples = tf.floor(posterior_tau * tf.cast(tf.size(count_data), tf.float32))\n", "\n", "# tau_samples, lambda_1_samples, lambda_2_samples contain\n", "# N samples from the corresponding posterior distribution\n",