diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py
index 60174ef92..fb1aae098 100644
--- a/bayesflow/simulators/model_comparison_simulator.py
+++ b/bayesflow/simulators/model_comparison_simulator.py
@@ -6,6 +6,7 @@
 from bayesflow.utils.decorators import allow_batch_size
 
 from bayesflow.utils import numpy_utils as npu
+from bayesflow.utils import logging
 
 from types import FunctionType
 
@@ -22,6 +23,8 @@ def __init__(
         p: Sequence[float] = None,
         logits: Sequence[float] = None,
         use_mixed_batches: bool = True,
+        key_conflicts: str = "drop",
+        fill_value: float = np.nan,
         shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None,
     ):
         """
@@ -38,11 +41,21 @@
             A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
             If neither `p` nor `logits` is provided, defaults to uniform logits.
         use_mixed_batches : bool, optional
-            If True, samples in a batch are drawn from different models. If False, the entire batch
-            is drawn from a single model chosen according to the model probabilities. Default is True.
+            Whether to draw samples in a batch from different models.
+
+            - If True (default), each sample in a batch may come from a different model.
+            - If False, the entire batch is drawn from a single model, selected according to model probabilities.
+        key_conflicts : str, optional
+            Policy for handling keys that are missing in the output of some models, when using mixed batches.
+
+            - "drop" (default): Drop conflicting keys from the batch output.
+            - "fill": Fill missing keys with the specified value.
+            - "error": An error is raised when key conflicts are detected.
+        fill_value : float, optional
+            If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument.
         shared_simulator : Simulator or Callable, optional
             A shared simulator whose outputs are passed to all model simulators. If a function is
-            provided, it is wrapped in a `LambdaSimulator` with batching enabled.
+            provided, it is wrapped in a :py:class:`~bayesflow.simulators.LambdaSimulator` with batching enabled.
""" self.simulators = simulators @@ -68,6 +81,9 @@ def __init__( self.logits = logits self.use_mixed_batches = use_mixed_batches + self.key_conflicts = key_conflicts + self.fill_value = fill_value + self._keys = None @allow_batch_size def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: @@ -105,6 +121,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: sims = [ simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0 ] + sims = self._handle_key_conflicts(sims, model_counts) sims = tree_concatenate(sims, numpy=True) data |= sims @@ -118,3 +135,58 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models) return data | {"model_indices": model_indices} + + def _handle_key_conflicts(self, sims, batch_sizes): + batch_sizes = [b for b in batch_sizes if b > 0] + + keys, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims) + + # all sims have the same keys + if all_keys == common_keys: + return sims + + if self.key_conflicts == "drop": + sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims] + return sims + elif self.key_conflicts == "fill": + combined_sims = {} + for sim in sims: + combined_sims = combined_sims | sim + for i, sim in enumerate(sims): + for missing_key in missing_keys[i]: + shape = combined_sims[missing_key].shape + shape = list(shape) + shape[0] = batch_sizes[i] + sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value) + return sims + elif self.key_conflicts == "error": + raise ValueError("Key conflicts are found in simulator outputs, cannot combine them into one batch.") + + def _determine_key_conflicts(self, sims): + # determine only once + if self._keys is not None: + return self._keys + + keys = [set(sim.keys()) for sim in sims] + all_keys = set.union(*keys) + common_keys = set.intersection(*keys) + missing_keys = [all_keys - k for k in keys] + + self._keys = keys, all_keys, common_keys, missing_keys + + if all_keys == common_keys: + return self._keys + + if self.key_conflicts == "drop": + logging.info( + f"Incompatible simulator output. \ +The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}." + ) + elif self.key_conflicts == "fill": + logging.info( + f"Incompatible simulator output. \ +Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \ +with value {self.fill_value}." 
+            )
+
+        return self._keys
diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py
index 0e76a5396..7dcc22c12 100644
--- a/tests/test_simulators/conftest.py
+++ b/tests/test_simulators/conftest.py
@@ -167,6 +167,56 @@ def likelihood(mu, n):
     return make_simulator([prior, likelihood], meta_fn=context)
 
 
+@pytest.fixture()
+def multimodel():
+    from bayesflow.simulators import make_simulator, ModelComparisonSimulator
+
+    def context(batch_size):
+        return dict(n=np.random.randint(10, 100))
+
+    def prior_0():
+        return dict(mu=0)
+
+    def prior_1():
+        return dict(mu=np.random.standard_normal())
+
+    def likelihood(n, mu):
+        return dict(y=np.random.normal(mu, 1, n))
+
+    simulator_0 = make_simulator([prior_0, likelihood])
+    simulator_1 = make_simulator([prior_1, likelihood])
+
+    simulator = ModelComparisonSimulator(simulators=[simulator_0, simulator_1], shared_simulator=context)
+
+    return simulator
+
+
+@pytest.fixture(params=["drop", "fill", "error"])
+def multimodel_key_conflicts(request):
+    from bayesflow.simulators import make_simulator, ModelComparisonSimulator
+
+    rng = np.random.default_rng()
+
+    def prior_1():
+        return dict(w=rng.uniform())
+
+    def prior_2():
+        return dict(c=rng.uniform())
+
+    def model_1(w):
+        return dict(x=w)
+
+    def model_2(c):
+        return dict(x=c)
+
+    simulator_1 = make_simulator([prior_1, model_1])
+    simulator_2 = make_simulator([prior_2, model_2])
+
+    simulator = ModelComparisonSimulator(simulators=[simulator_1, simulator_2], key_conflicts=request.param)
+
+    return simulator
+
+
 @pytest.fixture()
 def fixed_n():
     return 5
diff --git a/tests/test_simulators/test_simulators.py b/tests/test_simulators/test_simulators.py
index e9a3c80c0..4e2174be3 100644
--- a/tests/test_simulators/test_simulators.py
+++ b/tests/test_simulators/test_simulators.py
@@ -1,3 +1,4 @@
+import pytest
 import keras
 import numpy as np
 
@@ -47,3 +48,24 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
     assert samples["mu"].shape == (batch_size, 1)
     assert np.all(samples["mu"] == fixed_mu)
     assert samples["y"].shape == (batch_size, fixed_n)
+
+
+def test_multimodel_sample(multimodel, batch_size):
+    samples = multimodel.sample(batch_size)
+
+    assert set(samples) == {"n", "mu", "y", "model_indices"}
+    assert samples["mu"].shape == (batch_size, 1)
+    assert samples["y"].shape == (batch_size, samples["n"])
+
+
+def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size):
+    if multimodel_key_conflicts.key_conflicts == "drop":
+        samples = multimodel_key_conflicts.sample(batch_size)
+        assert set(samples) == {"x", "model_indices"}
+    elif multimodel_key_conflicts.key_conflicts == "fill":
+        samples = multimodel_key_conflicts.sample(batch_size)
+        assert set(samples) == {"x", "model_indices", "c", "w"}
+        assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size
+    elif multimodel_key_conflicts.key_conflicts == "error":
+        with pytest.raises(Exception):
+            samples = multimodel_key_conflicts.sample(batch_size)
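
For reference, a minimal usage sketch of the new `key_conflicts` option, mirroring the `multimodel_key_conflicts` test fixture above (the prior/model names are illustrative and not part of the patch):

import numpy as np
from bayesflow.simulators import make_simulator, ModelComparisonSimulator

def prior_1():
    return dict(w=np.random.uniform())

def prior_2():
    return dict(c=np.random.uniform())

def model_1(w):
    return dict(x=w)

def model_2(c):
    return dict(x=c)

# "fill" keeps the union of output keys; entries a model did not produce
# are padded with fill_value (np.nan by default)
simulator = ModelComparisonSimulator(
    simulators=[make_simulator([prior_1, model_1]), make_simulator([prior_2, model_2])],
    key_conflicts="fill",
)

samples = simulator.sample(8)
# samples contains "x", "w", "c", and one-hot "model_indices";
# "w"/"c" are NaN for batch entries drawn from the other model

With `key_conflicts="drop"` only the shared keys ("x" here) are kept, and with `key_conflicts="error"` sampling raises a ValueError when the simulator outputs do not share the same keys.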