Skip to content

Commit

Permalink
[Quality] Better use of StrEnum in set_interaction_type
Browse files Browse the repository at this point in the history
ghstack-source-id: c91a7a6be513fb46be6914df0b3bde779fa5528f
Pull Request resolved: #1087

(cherry picked from commit 79a3345)
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 05c0fe7 commit 477d85b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,7 @@ class InteractionType(StrEnum):
@classmethod
def from_str(cls, type_str: str) -> InteractionType:
"""Return the interaction_type with name matched to the provided string (case insensitive)."""
for member_type in cls:
if member_type.name == type_str.upper():
return member_type
raise ValueError(f"The provided interaction type {type_str} is unsupported!")
return cls(type_str.lower())


_INTERACTION_TYPE: InteractionType | None = None
Expand All @@ -74,13 +71,16 @@ class set_interaction_type(_DecoratorContextManager):
"""Sets all ProbabilisticTDModules sampling to the desired type.
Args:
type (InteractionType): sampling type to use when the policy is being called.
type (InteractionType or str): sampling type to use when the policy is being called.
"""

def __init__(
self, type: InteractionType | None = InteractionType.DETERMINISTIC
self, type: InteractionType | str | None = InteractionType.DETERMINISTIC
) -> None:
super().__init__()
if isinstance(type, str):
type = InteractionType(type.lower())
self.type = type

def clone(self) -> set_interaction_type:
Expand Down

0 comments on commit 477d85b

Please # to comment.