diff --git a/rllib/core/models/configs.py b/rllib/core/models/configs.py index a345860213c7..0a8b9491148b 100644 --- a/rllib/core/models/configs.py +++ b/rllib/core/models/configs.py @@ -887,7 +887,6 @@ class RecurrentEncoderConfig(ModelConfig): - Zero or one tokenizers - N LSTM/GRU layers stacked on top of each other and feeding their outputs as inputs to the respective next layer. - - One linear output layer This makes for the following flow of tensors: @@ -901,8 +900,6 @@ class RecurrentEncoderConfig(ModelConfig): | LSTM layer n | - Linear output layer - | Outputs The internal state is structued as (num_layers, B, hidden-size) for all hidden diff --git a/rllib/core/models/tf/encoder.py b/rllib/core/models/tf/encoder.py index 8efb6c5a136a..ff4956df4a8b 100644 --- a/rllib/core/models/tf/encoder.py +++ b/rllib/core/models/tf/encoder.py @@ -181,7 +181,6 @@ class TfGRUEncoder(TfModel, Encoder): This encoder has... - Zero or one tokenizers. - One or more GRU layers. - - One linear output layer. """ def __init__(self, config: RecurrentEncoderConfig) -> None: @@ -313,7 +312,6 @@ class TfLSTMEncoder(TfModel, Encoder): This encoder has... - Zero or one tokenizers. - One or more LSTM layers. - - One linear output layer. """ def __init__(self, config: RecurrentEncoderConfig) -> None: diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py index d2400447a10a..f9e59bdc6f2f 100644 --- a/rllib/core/models/torch/encoder.py +++ b/rllib/core/models/torch/encoder.py @@ -174,7 +174,6 @@ class TorchGRUEncoder(TorchModel, Encoder): This encoder has... - Zero or one tokenizers. - One or more GRU layers. - - One linear output layer. """ def __init__(self, config: RecurrentEncoderConfig) -> None: @@ -297,7 +296,6 @@ class TorchLSTMEncoder(TorchModel, Encoder): This encoder has... - Zero or one tokenizers. - One or more LSTM layers. - - One linear output layer. 
""" def __init__(self, config: RecurrentEncoderConfig) -> None: diff --git a/rllib/examples/rl_modules/classes/action_masking_rlm.py b/rllib/examples/rl_modules/classes/action_masking_rlm.py index 853ef1f979de..e948b8c1a1ef 100644 --- a/rllib/examples/rl_modules/classes/action_masking_rlm.py +++ b/rllib/examples/rl_modules/classes/action_masking_rlm.py @@ -93,19 +93,20 @@ def _forward_exploration( def _forward_train( self, batch: Dict[str, TensorType], **kwargs ) -> Dict[str, TensorType]: - # Preprocess the original batch to extract the action mask. - action_mask, batch = self._preprocess_batch(batch) # Run the forward pass. outs = super()._forward_train(batch, **kwargs) # Mask the action logits and return. - return self._mask_action_logits(outs, action_mask) + return self._mask_action_logits(outs, batch["action_mask"]) @override(ValueFunctionAPI) def compute_values(self, batch: Dict[str, TensorType]): # Preprocess the batch to extract the `observations` to `Columns.OBS`. - _, batch = self._preprocess_batch(batch) + action_mask, batch = self._preprocess_batch(batch) + # NOTE: Because we manipulate the batch we need to add the `action_mask` + # to the batch to access it in `_forward_train`. + batch["action_mask"] = action_mask # Call the super's method to compute values for GAE. - return super()._compute_values(batch) + return super().compute_values(batch) def _preprocess_batch( self, batch: Dict[str, TensorType], **kwargs