Lotka-Volterra notebook: Flow-Matching OT formatting error #433

Closed
vpratz opened this issue Apr 22, 2025 · 1 comment

vpratz (Collaborator) commented Apr 22, 2025

Running the flow matching workflow in the Lotka-Volterra notebook leads to the error below. It seems to be an edge case related to formatting a log message inside keras.ops.cond.

history = flow_matching_workflow.fit_offline(
    training_data, 
    epochs=epochs, 
    batch_size=batch_size, 
    validation_data=validation_data
)

INFO:bayesflow:Fitting on dataset instance of OfflineDataset.
INFO:bayesflow:Building on a test batch.
Epoch 1/50
WARNING:bayesflow:Log-Sinkhorn-Knopp produced NaNs.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[18], line 1
----> 1 history = flow_matching_workflow.fit_offline(
      2     training_data, 
      3     epochs=epochs, 
      4     batch_size=batch_size, 
      5     validation_data=validation_data
      6 )

File ~/Programming/IWR/bf2/bayesflow/workflows/basic_workflow.py:714, in BasicWorkflow.fit_offline(self, data, epochs, batch_size, keep_optimizer, validation_data, **kwargs)
    679 """
    680 Train the approximator offline using a fixed dataset. This approach will be faster than online training,
    681 since no computation time is spent in generating new data for each batch, but it assumes that simulations
   (...)
    709     metric evolution over epochs.
    710 """
    712 dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter)
--> 714 return self._fit(
    715     dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs
    716 )

File ~/Programming/IWR/bf2/bayesflow/workflows/basic_workflow.py:913, in BasicWorkflow._fit(self, dataset, epochs, strategy, keep_optimizer, validation_data, **kwargs)
    910     self.approximator.compile(optimizer=self.optimizer, metrics=kwargs.pop("metrics", None))
    912 try:
--> 913     self.history = self.approximator.fit(
    914         dataset=dataset, epochs=epochs, validation_data=validation_data, **kwargs
    915     )
    916     self._on_training_finished()
    917     return self.history

File ~/Programming/IWR/bf2/bayesflow/approximators/continuous_approximator.py:200, in ContinuousApproximator.fit(self, *args, **kwargs)
    148 def fit(self, *args, **kwargs):
    149     """
    150     Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
    151     If `dataset` is not provided, a dataset is built from the `simulator`.
   (...)
    198         If both `dataset` and `simulator` are provided or neither is provided.
    199     """
--> 200     return super().fit(*args, **kwargs, adapter=self.adapter)

File ~/Programming/IWR/bf2/bayesflow/approximators/approximator.py:139, in Approximator.fit(self, dataset, simulator, **kwargs)
    136     mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
    137     self.build_from_data(mock_data)
--> 139 return super().fit(dataset=dataset, **kwargs)

File ~/Programming/IWR/bf2/bayesflow/approximators/backend_approximators/backend_approximator.py:22, in BackendApproximator.fit(self, dataset, **kwargs)
     21 def fit(self, *, dataset: keras.utils.PyDataset, **kwargs):
---> 22     return super().fit(x=dataset, y=None, **filter_kwargs(kwargs, super().fit))

File /data/Programming/.mamba/envs/bf2/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/Programming/IWR/bf2/bayesflow/approximators/backend_approximators/tensorflow_approximator.py:20, in TensorFlowApproximator.train_step(self, data)
     18 with tf.GradientTape() as tape:
     19     kwargs = filter_kwargs(data | {"stage": "training"}, self.compute_metrics)
---> 20     metrics = self.compute_metrics(**kwargs)
     22 loss = metrics["loss"]
     24 grads = tape.gradient(loss, self.trainable_variables)

File ~/Programming/IWR/bf2/bayesflow/approximators/continuous_approximator.py:135, in ContinuousApproximator.compute_metrics(self, inference_variables, inference_conditions, summary_variables, sample_weight, stage)
    133 # Force a conversion to Tensor
    134 inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
--> 135 inference_metrics = self.inference_network.compute_metrics(
    136     inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
    137 )
    139 loss = inference_metrics.get("loss", keras.ops.zeros(())) + summary_metrics.get("loss", keras.ops.zeros(()))
    141 inference_metrics = {f"inference_{key}": value for key, value in inference_metrics.items()}

File ~/Programming/IWR/bf2/bayesflow/networks/flow_matching/flow_matching.py:263, in FlowMatching.compute_metrics(self, x, conditions, sample_weight, stage)
    256 x0 = self.base_distribution.sample(keras.ops.shape(x1)[:-1])
    258 if self.use_optimal_transport:
    259     # we must choose between resampling x0 or x1
    260     # since the data is possibly noisy and may contain outliers, it is better
    261     # to possibly drop some samples from x1 than from x0
    262     # in the marginal over multiple batches, this is not a problem
--> 263     x0, x1, assignments = optimal_transport(
    264         x0,
    265         x1,
    266         seed=self.seed_generator,
    267         **self.optimal_transport_kwargs,
    268         return_assignments=True,
    269     )
    270     if conditions is not None:
    271         # conditions must be resampled along with x1
    272         conditions = keras.ops.take(conditions, assignments, axis=0)

File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/optimal_transport.py:41, in optimal_transport(x1, x2, method, return_assignments, **kwargs)
     14 def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, **kwargs):
     15     """Matches elements from x2 onto x1, such that the transport cost between them is minimized, according to the method
     16     and cost matrix used.
     17 
   (...)
     39         x1 and x2 in optimal transport permutation order.
     40     """
---> 41     assignments = methods[method.lower()](x1, x2, **kwargs)
     42     x2 = keras.ops.take(x2, assignments, axis=0)
     44     if return_assignments:

File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/log_sinkhorn.py:13, in log_sinkhorn(x1, x2, seed, **kwargs)
      8 def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
      9     """
     10     Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
     11     Significantly slower than the unstabilized version, so use only when you need numerical stability.
     12     """
---> 13     log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
     14     assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
     15     assignments = keras.ops.squeeze(assignments, axis=1)

File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/log_sinkhorn.py:74, in log_sinkhorn_plan(x1, x2, regularization, rtol, atol, max_steps)
     71     logging.warning(msg)
     73 keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
---> 74 keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)
     76 return log_plan

File ~/Programming/IWR/bf2/bayesflow/utils/optimal_transport/log_sinkhorn.py:58, in log_sinkhorn_plan.<locals>.log_steps()
     55 def log_steps():
     56     msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
---> 58     logging.debug(msg, steps)

File ~/Programming/IWR/bf2/bayesflow/utils/logging.py:26, in debug(msg, *args, **kwargs)
     25 def debug(msg, *args, **kwargs):
---> 26     _log(msg, *args, callback_fn=logger.debug, **kwargs)

File ~/Programming/IWR/bf2/bayesflow/utils/logging.py:18, in _log(msg, callback_fn, *args, **kwargs)
     16     jax.debug.callback(__log, *args, **kwargs)
     17 else:
---> 18     callback_fn(msg.format(*args, **kwargs))

TypeError: Exception encountered when calling Cond.call().

unsupported format string passed to SymbolicTensor.__format__

Arguments received by Cond.call():
  • args=('tf.Tensor(shape=(), dtype=bool)', '<function log_sinkhorn_plan.<locals>.log_steps at 0x7faf94299b20>', '<function log_sinkhorn_plan.<locals>.warn_convergence at 0x7faf94299940>')
  • kwargs=<class 'inspect._empty'>
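
For context on the TypeError: on the TensorFlow backend, keras.ops.cond traces both branches, so steps is still a symbolic tensor when log_steps runs, and a numeric format spec such as {:d} cannot be applied to it. A minimal sketch of the same failure mode (hypothetical code, not taken from the notebook or from BayesFlow):

# Format a symbolic TensorFlow tensor with a numeric format spec,
# analogous to what log_steps does with `steps` inside keras.ops.cond.
import tensorflow as tf

@tf.function
def format_steps():
    steps = tf.constant(10)  # symbolic while the function is being traced
    # Raises: TypeError: unsupported format string passed to SymbolicTensor.__format__
    return "Log-Sinkhorn-Knopp converged after {:d} steps.".format(steps)

try:
    format_steps()
except TypeError as e:
    print(e)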

LarsKue (Contributor) commented Apr 22, 2025

Fixed by b5836e8. I had removed the check because I was running into issues in a clean multi-backend environment, where an old jax version gets installed when you install it alongside TensorFlow. However, this should not be an issue if users install a single backend, which is the usual use case.

The specific issue was that jax.core.is_concrete requires jax 0.5 or newer, and this check is required for our optimal transport methods.
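
As an aside, one defensive option (purely an illustrative sketch, not the change in b5836e8; the format_if_concrete helper is hypothetical) would be to skip numeric formatting whenever the value is still symbolic:

import keras

def format_if_concrete(msg: str, value) -> str:
    """Format value into msg only if it can be materialized as a concrete number."""
    try:
        # convert_to_numpy raises on symbolic/traced tensors, so fall back
        # to a placeholder instead of crashing the log call during tracing.
        return msg.format(int(keras.ops.convert_to_numpy(value)))
    except Exception:
        return msg.replace("{:d}", "<symbolic>")

With a guard like that, logging.debug would receive the placeholder while tracing and the actual step count in eager execution.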

LarsKue closed this as completed Apr 22, 2025