Skip to content

Commit

Permalink
batch latents test
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed Jan 2, 2024
1 parent 9b3e962 commit 8061119
Showing 1 changed file with 65 additions and 7 deletions.
72 changes: 65 additions & 7 deletions tests/robust/test_internals_compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood
from chirho.robust.internals.utils import (
BatchedLatents,
BatchedObservations,
make_functional_call,
reset_rng_state,
Expand Down Expand Up @@ -125,19 +126,22 @@ def test_nmc_likelihood_seeded(link_fn):
@pytest.mark.parametrize("pad_dim", [0, 1, 2])
def test_batched_observations(pad_dim: int):
max_plate_nesting = 1 + pad_dim
plate_name = "__dummy_plate__"
obs_plate_name = "__dummy_plate__"
num_particles_obs = 3
model = SimpleModel()
guide = SimpleGuide()

model(), guide() # initialize

predictive = pyro.infer.Predictive(model, num_samples=3, return_sites=["y"])
predictive = pyro.infer.Predictive(
model, num_samples=num_particles_obs, return_sites=["y"]
)

test_data = predictive()

with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1):
with pyro.poutine.trace() as tr:
with BatchedObservations(test_data, name=plate_name):
with BatchedObservations(test_data, name=obs_plate_name):
model()

tr.trace.compute_log_prob()
Expand All @@ -147,12 +151,66 @@ def test_batched_observations(pad_dim: int):
node
):
if name in test_data:
assert plate_name in indices_of(node["log_prob"], event_dim=0)
assert plate_name in indices_of(
assert obs_plate_name in indices_of(node["log_prob"], event_dim=0)
assert obs_plate_name in indices_of(
node["value"], event_dim=len(node["fn"].event_shape)
)
else:
assert plate_name not in indices_of(node["log_prob"], event_dim=0)
assert plate_name not in indices_of(
assert obs_plate_name not in indices_of(
node["log_prob"], event_dim=0
)
assert obs_plate_name not in indices_of(
node["value"], event_dim=len(node["fn"].event_shape)
)


@pytest.mark.parametrize("pad_dim", [0, 1, 2])
def test_batched_latents_observations(pad_dim: int):
max_plate_nesting = 1 + pad_dim
num_particles_latent = 5
num_particles_obs = 3
obs_plate_name = "__dummy_plate__"
latent_plate_name = "__dummy_latents__"
model = SimpleModel()
guide = SimpleGuide()

model(), guide() # initialize

predictive = pyro.infer.Predictive(
model, num_samples=num_particles_obs, return_sites=["y"]
)

test_data = predictive()

with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1):
with pyro.poutine.trace() as tr:
with BatchedLatents(
num_particles=num_particles_latent, name=latent_plate_name
):
with BatchedObservations(test_data, name=obs_plate_name):
model()

tr.trace.compute_log_prob()

for name, node in tr.trace.nodes.items():
if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample(
node
):
if name in test_data:
assert obs_plate_name in indices_of(node["log_prob"], event_dim=0)
assert obs_plate_name in indices_of(
node["value"], event_dim=len(node["fn"].event_shape)
)
assert latent_plate_name in indices_of(
node["log_prob"], event_dim=0
)
assert latent_plate_name not in indices_of(
node["value"], event_dim=len(node["fn"].event_shape)
)
else:
assert latent_plate_name in indices_of(
node["log_prob"], event_dim=0
)
assert latent_plate_name in indices_of(
node["value"], event_dim=len(node["fn"].event_shape)
)

0 comments on commit 8061119

Please # to comment.