diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index b503eb4a0..954511c7b 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -406,16 +406,21 @@ class _set_dispatch_td_nn_modules(_DecoratorContextManager): def __init__(self, mode): self.mode = mode + self._saved_mode = None def clone(self): return type(self)(self.mode) def __enter__(self): global DISPATCH_TDNN_MODULES - self._saved_mode = DISPATCH_TDNN_MODULES - DISPATCH_TDNN_MODULES = self.mode + # We want to avoid changing global variables because compile puts guards on them + if DISPATCH_TDNN_MODULES != self.mode: + self._saved_mode = DISPATCH_TDNN_MODULES + DISPATCH_TDNN_MODULES = self.mode def __exit__(self, exc_type, exc_val, exc_tb): + if self._saved_mode is None: + return global DISPATCH_TDNN_MODULES DISPATCH_TDNN_MODULES = self._saved_mode