Skip to content

Commit

Permalink
[BugFix] Fix compile during _check_keys
Browse files Browse the repository at this point in the history
ghstack-source-id: cfed094e425a60c62617ecfa454d3104ff1f461c
Pull Request resolved: #1239
  • Loading branch information
vmoens committed Feb 26, 2025
1 parent 635d036 commit 2ad9f95
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,7 +1772,8 @@ def _check_keys(
is_leaf=_is_leaf_nontensor,
)
# TODO: compile doesn't like set() over an arbitrary object
if is_compiling():
is_comp = is_compiling()
if is_comp:
keys_set = {k for k in keys} # noqa: C416
else:
keys_set: set[str] = set(keys)
Expand All @@ -1785,7 +1786,7 @@ def _check_keys(
if not strict:
keys_set = keys_set.intersection(k)
else:
if is_compiling():
if is_comp:
k = {v for v in k} # noqa: C416
else:
k = set(k)
Expand All @@ -1794,7 +1795,10 @@ def _check_keys(
f"got keys {keys} and {set(td.keys())} which are incompatible"
)
if strict:
return list(keys)
if is_comp:
return [key for key in keys] # noqa: C416
else:
return list(keys)
return keys_set


Expand Down

0 comments on commit 2ad9f95

Please # to comment.