diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index 634aa0f8..fbe6e1ed 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -13,6 +13,7 @@ ) from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood from chirho.robust.internals.utils import ( + BatchedLatents, BatchedObservations, make_functional_call, reset_rng_state, @@ -125,19 +126,22 @@ def test_nmc_likelihood_seeded(link_fn): @pytest.mark.parametrize("pad_dim", [0, 1, 2]) def test_batched_observations(pad_dim: int): max_plate_nesting = 1 + pad_dim - plate_name = "__dummy_plate__" + obs_plate_name = "__dummy_plate__" + num_particles_obs = 3 model = SimpleModel() guide = SimpleGuide() model(), guide() # initialize - predictive = pyro.infer.Predictive(model, num_samples=3, return_sites=["y"]) + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) test_data = predictive() with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): with pyro.poutine.trace() as tr: - with BatchedObservations(test_data, name=plate_name): + with BatchedObservations(test_data, name=obs_plate_name): model() tr.trace.compute_log_prob() @@ -147,12 +151,66 @@ def test_batched_observations(pad_dim: int): node ): if name in test_data: - assert plate_name in indices_of(node["log_prob"], event_dim=0) - assert plate_name in indices_of( + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( node["value"], event_dim=len(node["fn"].event_shape) ) else: - assert plate_name not in indices_of(node["log_prob"], event_dim=0) - assert plate_name not in indices_of( + assert obs_plate_name not in indices_of( + node["log_prob"], event_dim=0 + ) + assert obs_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_latents_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + num_particles_latent = 5 + num_particles_obs = 3 + obs_plate_name = "__dummy_plate__" + latent_plate_name = "__dummy_latents__" + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedLatents( + num_particles=num_particles_latent, name=latent_plate_name + ): + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name in indices_of( node["value"], event_dim=len(node["fn"].event_shape) )