Add warning for varying simulator output sizes #370
Labels: efficiency, user interface
Varying simulator output sizes are a common occurrence, for example when the number of samples `n` varies between calls to `simulator.sample()`. However, they can trigger excessive compile times in JAX, where each new value of `n` triggers a recompilation. For a wide range of `n`, compilation can dominate the training time. The current best-practice fix for users is to use padded tensors.
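Here is a minimal sketch of the padding approach. It assumes observations are stacked along axis 1 as `(batch, n, ...)`. The helper name `pad_observations` and its parameters `n_max` and `pad_value` are hypothetical and not part of the BayesFlow API. Padding every call to the same `n_max` means the arrays reaching a jitted function always have one shape, so JAX only needs to compile once. The returned boolean mask marks which entries are real observations, so downstream code does not have to rely on the placeholder value alone.

```python
import numpy as np


def pad_observations(x: np.ndarray, n_max: int, pad_value: float = 0.0):
    """Pad x of shape (batch, n, ...) along axis 1 to shape (batch, n_max, ...).

    Hypothetical helper. Returns the padded array and a boolean mask of shape
    (batch, n_max) that is True for real observations and False for padding.
    """
    batch, n = x.shape[:2]
    if n > n_max:
        raise ValueError(f"n={n} exceeds n_max={n_max}")
    # Pad only the trailing end of axis 1; leave all other axes untouched.
    pad_width = [(0, 0), (0, n_max - n)] + [(0, 0)] * (x.ndim - 2)
    padded = np.pad(x, pad_width, mode="constant", constant_values=pad_value)
    mask = np.zeros((batch, n_max), dtype=bool)
    mask[:, :n] = True
    return padded, mask


# Usage: n varies between calls, but the padded shape does not.
rng = np.random.default_rng(0)
for n in (12, 37, 80):
    x = rng.normal(size=(32, n, 2))
    x_pad, mask = pad_observations(x, n_max=100)
    print(x_pad.shape, mask.sum(axis=1)[0])  # prints (32, 100, 2) followed by n
```

The choice of `n_max` trades memory and wasted compute on padding against the single compilation. If `n` is unbounded, one option is to bucket it into a few sizes, which gives a handful of compilations instead of one per value.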
When we detect that compile times dominate, we should output a warning to the user with a suggested fix. We could also improve support for padded simulator output in general. Further, we could look into whether there are better ways to mask out unused values than just setting them to placeholder values, as in the padding fix.
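Here is a minimal sketch of the proposed warning. It assumes that counting distinct output shapes is an acceptable proxy for "compile times dominate". The issue does not say how detection would work. The class `ShapeVariationMonitor`, its `max_distinct_shapes` threshold, and the message text are hypothetical and not part of BayesFlow.

```python
import warnings
from collections.abc import Mapping


class ShapeVariationMonitor:
    """Hypothetical helper: warn once when simulator outputs take many distinct shapes.

    Under JAX, each distinct input shape leads to a fresh compilation, so the
    number of distinct shapes seen is used here as a cheap proxy for compile time.
    """

    def __init__(self, max_distinct_shapes: int = 5):
        self.max_distinct_shapes = max_distinct_shapes
        self.seen: set[tuple] = set()
        self.warned = False

    def observe(self, outputs: Mapping) -> None:
        # Shape signature of one batch: sorted (key, shape) pairs.
        signature = tuple(sorted((k, tuple(v.shape)) for k, v in outputs.items()))
        self.seen.add(signature)
        if not self.warned and len(self.seen) > self.max_distinct_shapes:
            self.warned = True  # warn only once per monitor
            warnings.warn(
                f"Simulator outputs have taken {len(self.seen)} distinct shapes. "
                "In JAX, each new shape triggers a recompilation, which can "
                "dominate training time. Consider padding outputs to a fixed "
                "size and passing a mask.",
                RuntimeWarning,
                stacklevel=2,
            )
```

Counting shapes is cheap and needs no access to the compiler. Timing the compilations directly would be more precise, but it would have to hook into JAX internals. Calling `monitor.observe(batch)` on each simulator output is enough to trigger the warning once the threshold is crossed.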