Skip to content

Commit

Permalink
[Refactor] Make _set_dispatch_td_nn_modules compatible with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 85a78cd6086233b414fcfe221dd8129e2e38f71c
Pull Request resolved: #1084

(cherry picked from commit 853b7d9)
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 178dfd9 commit f24e3d8
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,16 +406,21 @@ class _set_dispatch_td_nn_modules(_DecoratorContextManager):

def __init__(self, mode):
self.mode = mode
self._saved_mode = None

def clone(self):
return type(self)(self.mode)

def __enter__(self):
global DISPATCH_TDNN_MODULES
self._saved_mode = DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self.mode
# We want to avoid changing global variables because compile puts guards on them
if DISPATCH_TDNN_MODULES != self.mode:
self._saved_mode = DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self.mode

def __exit__(self, exc_type, exc_val, exc_tb):
if self._saved_mode is None:
return
global DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self._saved_mode

Expand Down

0 comments on commit f24e3d8

Please # to comment.