Skip to content

Commit

Permalink
[RLlib] Fix action masking example. (#47817)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Sep 25, 2024
1 parent a5f82e1 commit 7966130
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 12 deletions.
3 changes: 0 additions & 3 deletions rllib/core/models/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,6 @@ class RecurrentEncoderConfig(ModelConfig):
- Zero or one tokenizers
- N LSTM/GRU layers stacked on top of each other and feeding
their outputs as inputs to the respective next layer.
- One linear output layer
This makes for the following flow of tensors:
Expand All @@ -901,8 +900,6 @@ class RecurrentEncoderConfig(ModelConfig):
|
LSTM layer n
|
Linear output layer
|
Outputs
The internal state is structured as (num_layers, B, hidden-size) for all hidden
Expand Down
2 changes: 0 additions & 2 deletions rllib/core/models/tf/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class TfGRUEncoder(TfModel, Encoder):
This encoder has...
- Zero or one tokenizers.
- One or more GRU layers.
- One linear output layer.
"""

def __init__(self, config: RecurrentEncoderConfig) -> None:
Expand Down Expand Up @@ -313,7 +312,6 @@ class TfLSTMEncoder(TfModel, Encoder):
This encoder has...
- Zero or one tokenizers.
- One or more LSTM layers.
- One linear output layer.
"""

def __init__(self, config: RecurrentEncoderConfig) -> None:
Expand Down
2 changes: 0 additions & 2 deletions rllib/core/models/torch/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class TorchGRUEncoder(TorchModel, Encoder):
This encoder has...
- Zero or one tokenizers.
- One or more GRU layers.
- One linear output layer.
"""

def __init__(self, config: RecurrentEncoderConfig) -> None:
Expand Down Expand Up @@ -297,7 +296,6 @@ class TorchLSTMEncoder(TorchModel, Encoder):
This encoder has...
- Zero or one tokenizers.
- One or more LSTM layers.
- One linear output layer.
"""

def __init__(self, config: RecurrentEncoderConfig) -> None:
Expand Down
11 changes: 6 additions & 5 deletions rllib/examples/rl_modules/classes/action_masking_rlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,20 @@ def _forward_exploration(
def _forward_train(
self, batch: Dict[str, TensorType], **kwargs
) -> Dict[str, TensorType]:
# Preprocess the original batch to extract the action mask.
action_mask, batch = self._preprocess_batch(batch)
# Run the forward pass.
outs = super()._forward_train(batch, **kwargs)
# Mask the action logits and return.
return self._mask_action_logits(outs, action_mask)
return self._mask_action_logits(outs, batch["action_mask"])

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, TensorType]):
# Preprocess the batch to extract the `observations` to `Columns.OBS`.
_, batch = self._preprocess_batch(batch)
action_mask, batch = self._preprocess_batch(batch)
# NOTE: Because we manipulate the batch we need to add the `action_mask`
# to the batch to access them in `_forward_train`.
batch["action_mask"] = action_mask
# Call the super's method to compute values for GAE.
return super()._compute_values(batch)
return super().compute_values(batch)

def _preprocess_batch(
self, batch: Dict[str, TensorType], **kwargs
Expand Down

0 comments on commit 7966130

Please sign in to comment.