
ModelComparisonSimulator: handle different outputs from individual simulators #452

Open · wants to merge 5 commits into base: dev

78 changes: 75 additions & 3 deletions bayesflow/simulators/model_comparison_simulator.py
@@ -6,6 +6,7 @@
from bayesflow.utils.decorators import allow_batch_size

from bayesflow.utils import numpy_utils as npu
from bayesflow.utils import logging

from types import FunctionType

@@ -22,6 +23,8 @@
p: Sequence[float] = None,
logits: Sequence[float] = None,
use_mixed_batches: bool = True,
key_conflicts: str = "drop",
fill_value: float = np.nan,
shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None,
):
"""
@@ -38,11 +41,21 @@
A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
If neither `p` nor `logits` is provided, defaults to uniform logits.
use_mixed_batches : bool, optional
If True, samples in a batch are drawn from different models. If False, the entire batch
is drawn from a single model chosen according to the model probabilities. Default is True.
Whether to draw samples in a batch from different models.

- If True (default), each sample in a batch may come from a different model.
- If False, the entire batch is drawn from a single model, selected according to model probabilities.
key_conflicts : str, optional
Policy for handling keys that are missing from the outputs of some models when using mixed batches.

- "drop" (default): Drop conflicting keys from the batch output.
- "fill": Fill missing keys with the specified value.
- "error": Raise an error when key conflicts are detected.
fill_value : float, optional
If `key_conflicts == "fill"`, missing keys are filled with this value.
shared_simulator : Simulator or Callable, optional
A shared simulator whose outputs are passed to all model simulators. If a function is
provided, it is wrapped in a `LambdaSimulator` with batching enabled.
provided, it is wrapped in a :py:class:`~bayesflow.simulators.LambdaSimulator` with batching enabled.
"""
self.simulators = simulators

@@ -68,6 +81,9 @@

self.logits = logits
self.use_mixed_batches = use_mixed_batches
self.key_conflicts = key_conflicts
self.fill_value = fill_value
self._keys = None

@allow_batch_size
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -105,6 +121,7 @@
sims = [
simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
]
sims = self._handle_key_conflicts(sims, model_counts)
sims = tree_concatenate(sims, numpy=True)
data |= sims

@@ -118,3 +135,58 @@
model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)

return data | {"model_indices": model_indices}

def _handle_key_conflicts(self, sims, batch_sizes):
batch_sizes = [b for b in batch_sizes if b > 0]

keys, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims)

# all sims have the same keys
if all_keys == common_keys:
return sims

if self.key_conflicts == "drop":
sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims]
return sims
elif self.key_conflicts == "fill":
combined_sims = {}
for sim in sims:
combined_sims = combined_sims | sim
for i, sim in enumerate(sims):
for missing_key in missing_keys[i]:
shape = combined_sims[missing_key].shape
shape = list(shape)
shape[0] = batch_sizes[i]
sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value)
return sims
elif self.key_conflicts == "error":
raise ValueError("Key conflicts found in simulator outputs; cannot combine them into one batch.")

def _determine_key_conflicts(self, sims):
# determine key conflicts only once and cache the result
if self._keys is not None:
return self._keys

Codecov / codecov/patch annotation: Added line #L168 in bayesflow/simulators/model_comparison_simulator.py was not covered by tests.

Collaborator: Can this return "wrong" results if some simulators had n=0 in line 120 when the function first runs, and n>0 later on? Is this something we want to safeguard against?

Collaborator: I can imagine this function to be quite cheap to compute, would it make sense to run it completely every time (but only logging the info once)?

Contributor: I agree

Member (Author): ok, will do that
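Not part of the diff: a minimal sketch of the variant discussed in this thread, which recomputes the key analysis on every call but logs only once. The `_key_conflicts_logged` flag is a hypothetical name, not something introduced by this PR.

```python
def _determine_key_conflicts(self, sims):
    # recompute on every call, so a later call where a previously
    # skipped simulator (n=0) now contributes cannot reuse a stale result
    keys = [set(sim.keys()) for sim in sims]
    all_keys = set.union(*keys)
    common_keys = set.intersection(*keys)
    missing_keys = [all_keys - k for k in keys]

    # log the conflict information only once
    if all_keys != common_keys and not getattr(self, "_key_conflicts_logged", False):
        self._key_conflicts_logged = True
        logging.info(
            "Incompatible simulator output. Conflicting keys: "
            f"{', '.join(sorted(all_keys - common_keys))} (policy: {self.key_conflicts})."
        )

    return keys, all_keys, common_keys, missing_keys
```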


keys = [set(sim.keys()) for sim in sims]
all_keys = set.union(*keys)
common_keys = set.intersection(*keys)
missing_keys = [all_keys - k for k in keys]

self._keys = keys, all_keys, common_keys, missing_keys

if all_keys == common_keys:
return self._keys

if self.key_conflicts == "drop":
logging.info(
"Incompatible simulator output. "
f"The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}."
)
elif self.key_conflicts == "fill":
logging.info(
"Incompatible simulator output. "
f"Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, "
f"with value {self.fill_value}."
)

return self._keys
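For context, here is a minimal usage sketch of the new `key_conflicts` option (not part of the diff; the toy priors mirror the `multimodel_key_conflicts` fixture added below):

```python
import numpy as np

from bayesflow.simulators import ModelComparisonSimulator, make_simulator

rng = np.random.default_rng()

# two models whose simulators expose different parameter keys ("w" vs. "c")
def prior_1():
    return dict(w=rng.uniform())

def model_1(w):
    return dict(x=w)

def prior_2():
    return dict(c=rng.uniform())

def model_2(c):
    return dict(x=c)

simulator = ModelComparisonSimulator(
    simulators=[make_simulator([prior_1, model_1]), make_simulator([prior_2, model_2])],
    key_conflicts="fill",  # keep "w" and "c", filled with np.nan where a model does not produce them
)

samples = simulator.sample(64)
# "fill" yields keys {"x", "w", "c", "model_indices"}; "drop" would keep only {"x", "model_indices"}
```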
50 changes: 50 additions & 0 deletions tests/test_simulators/conftest.py
@@ -167,6 +167,56 @@ def likelihood(mu, n):
return make_simulator([prior, likelihood], meta_fn=context)


@pytest.fixture()
def multimodel():
from bayesflow.simulators import make_simulator, ModelComparisonSimulator

def context(batch_size):
return dict(n=np.random.randint(10, 100))

def prior_0():
return dict(mu=0)

def prior_1():
return dict(mu=np.random.standard_normal())

def likelihood(n, mu):
return dict(y=np.random.normal(mu, 1, n))

simulator_0 = make_simulator([prior_0, likelihood])
simulator_1 = make_simulator([prior_1, likelihood])

simulator = ModelComparisonSimulator(simulators=[simulator_0, simulator_1], shared_simulator=context)

return simulator


@pytest.fixture(params=["drop", "fill", "error"])
def multimodel_key_conflicts(request):
from bayesflow.simulators import make_simulator, ModelComparisonSimulator

rng = np.random.default_rng()

def prior_1():
return dict(w=rng.uniform())

def prior_2():
return dict(c=rng.uniform())

def model_1(w):
return dict(x=w)

def model_2(c):
return dict(x=c)

simulator_1 = make_simulator([prior_1, model_1])
simulator_2 = make_simulator([prior_2, model_2])

simulator = ModelComparisonSimulator(simulators=[simulator_1, simulator_2], key_conflicts=request.param)

return simulator


@pytest.fixture()
def fixed_n():
return 5
22 changes: 22 additions & 0 deletions tests/test_simulators/test_simulators.py
@@ -1,3 +1,4 @@
import pytest
import keras
import numpy as np

@@ -47,3 +48,24 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
assert samples["mu"].shape == (batch_size, 1)
assert np.all(samples["mu"] == fixed_mu)
assert samples["y"].shape == (batch_size, fixed_n)


def test_multimodel_sample(multimodel, batch_size):
samples = multimodel.sample(batch_size)

assert set(samples) == {"n", "mu", "y", "model_indices"}
assert samples["mu"].shape == (batch_size, 1)
assert samples["y"].shape == (batch_size, samples["n"])


def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size):
if multimodel_key_conflicts.key_conflicts == "drop":
samples = multimodel_key_conflicts.sample(batch_size)
assert set(samples) == {"x", "model_indices"}
elif multimodel_key_conflicts.key_conflicts == "fill":
samples = multimodel_key_conflicts.sample(batch_size)
assert set(samples) == {"x", "model_indices", "c", "w"}
assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size
elif multimodel_key_conflicts.key_conflicts == "error":
with pytest.raises(Exception):
Contributor: This is too broad of a check. Use specific exception types.

samples = multimodel_key_conflicts.sample(batch_size)
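Per the review comment above, the final branch could assert the specific exception raised by the "error" policy (a `ValueError` in the implementation). A possible tightening, not part of this diff:

```python
elif multimodel_key_conflicts.key_conflicts == "error":
    with pytest.raises(ValueError):
        multimodel_key_conflicts.sample(batch_size)
```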