From 477d85bc7ba7d13ef8d6978fffd21a777e0d2c03 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 13 Nov 2024 18:47:14 +0000 Subject: [PATCH] [Quality] Better use of StrEnum in set_interaction_type ghstack-source-id: c91a7a6be513fb46be6914df0b3bde779fa5528f Pull Request resolved: https://github.com/pytorch/tensordict/pull/1087 (cherry picked from commit 79a33458f177595d0d28e2a4c4131f7eb668c761) --- tensordict/nn/probabilistic.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 65257f74b..497394a99 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -56,10 +56,7 @@ class InteractionType(StrEnum): @classmethod def from_str(cls, type_str: str) -> InteractionType: """Return the interaction_type with name matched to the provided string (case insensitive).""" - for member_type in cls: - if member_type.name == type_str.upper(): - return member_type - raise ValueError(f"The provided interaction type {type_str} is unsupported!") + return cls(type_str.lower()) _INTERACTION_TYPE: InteractionType | None = None @@ -74,13 +71,16 @@ class set_interaction_type(_DecoratorContextManager): """Sets all ProbabilisticTDModules sampling to the desired type. Args: - type (InteractionType): sampling type to use when the policy is being called. + type (InteractionType or str): sampling type to use when the policy is being called. + """ def __init__( - self, type: InteractionType | None = InteractionType.DETERMINISTIC + self, type: InteractionType | str | None = InteractionType.DETERMINISTIC ) -> None: super().__init__() + if isinstance(type, str): + type = InteractionType(type.lower()) self.type = type def clone(self) -> set_interaction_type: