[RLlib] Fix action masking example. #47817

Merged (4 commits) on Sep 25, 2024
3 changes: 0 additions & 3 deletions rllib/core/models/configs.py
@@ -887,7 +887,6 @@ class RecurrentEncoderConfig(ModelConfig):
     - Zero or one tokenizers
     - N LSTM/GRU layers stacked on top of each other and feeding
       their outputs as inputs to the respective next layer.
-    - One linear output layer

     This makes for the following flow of tensors:

@@ -901,8 +900,6 @@ class RecurrentEncoderConfig(ModelConfig):
     |
     LSTM layer n
     |
-    Linear output layer
-    |
     Outputs

     The internal state is structued as (num_layers, B, hidden-size) for all hidden
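For context, the shapes the remaining docstring describes can be reproduced with plain PyTorch. The snippet below is a minimal sketch, not RLlib's actual encoder code: stacked recurrent layers whose outputs are returned directly as the encoder outputs, with the internal state shaped (num_layers, B, hidden-size).

```python
import torch
import torch.nn as nn

B, T, obs_dim, hidden_size, num_layers = 4, 10, 8, 32, 2

# Stacked LSTM: each layer feeds its outputs as inputs to the next layer.
lstm = nn.LSTM(obs_dim, hidden_size, num_layers, batch_first=True)

inputs = torch.randn(B, T, obs_dim)
outputs, (h, c) = lstm(inputs)  # the LSTM outputs are the encoder outputs

print(outputs.shape)  # torch.Size([4, 10, 32])
print(h.shape)        # torch.Size([2, 4, 32]) -> (num_layers, B, hidden-size)
print(c.shape)        # torch.Size([2, 4, 32])
```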
2 changes: 0 additions & 2 deletions rllib/core/models/tf/encoder.py
@@ -181,7 +181,6 @@ class TfGRUEncoder(TfModel, Encoder):
     This encoder has...
     - Zero or one tokenizers.
     - One or more GRU layers.
-    - One linear output layer.
     """

     def __init__(self, config: RecurrentEncoderConfig) -> None:
@@ -313,7 +312,6 @@ class TfLSTMEncoder(TfModel, Encoder):
     This encoder has...
     - Zero or one tokenizers.
     - One or more LSTM layers.
-    - One linear output layer.
     """

     def __init__(self, config: RecurrentEncoderConfig) -> None:
2 changes: 0 additions & 2 deletions rllib/core/models/torch/encoder.py
@@ -174,7 +174,6 @@ class TorchGRUEncoder(TorchModel, Encoder):
     This encoder has...
     - Zero or one tokenizers.
     - One or more GRU layers.
-    - One linear output layer.
     """

     def __init__(self, config: RecurrentEncoderConfig) -> None:
@@ -297,7 +296,6 @@ class TorchLSTMEncoder(TorchModel, Encoder):
     This encoder has...
     - Zero or one tokenizers.
     - One or more LSTM layers.
-    - One linear output layer.
     """

     def __init__(self, config: RecurrentEncoderConfig) -> None:
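The docstring edits above all make the same structural point: these encoders end with their recurrent layers, and any output heads are applied outside of them. A hypothetical, torch-only sketch of such a module (names and shapes are assumptions, not the TfGRUEncoder/TorchGRUEncoder implementation):

```python
import torch
import torch.nn as nn


class TinyGRUEncoder(nn.Module):
    """Illustrative encoder: optional tokenizer, stacked GRUs, no trailing linear layer."""

    def __init__(self, in_dim: int, hidden_dim: int, num_layers: int) -> None:
        super().__init__()
        # "Zero or one tokenizers": here, a single linear tokenizer.
        self.tokenizer = nn.Linear(in_dim, hidden_dim)
        # "One or more GRU layers."
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, obs: torch.Tensor, h: torch.Tensor):
        # The GRU outputs are the encoder features; pi/value heads are
        # applied outside the encoder.
        out, h_next = self.gru(self.tokenizer(obs), h)
        return out, h_next


enc = TinyGRUEncoder(in_dim=8, hidden_dim=32, num_layers=2)
obs = torch.randn(4, 10, 8)   # (B, T, obs_dim)
h0 = torch.zeros(2, 4, 32)    # (num_layers, B, hidden_dim)
features, h1 = enc(obs, h0)
```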
11 changes: 6 additions & 5 deletions rllib/examples/rl_modules/classes/action_masking_rlm.py
@@ -93,19 +93,20 @@ def _forward_exploration(
     def _forward_train(
         self, batch: Dict[str, TensorType], **kwargs
     ) -> Dict[str, TensorType]:
-        # Preprocess the original batch to extract the action mask.
-        action_mask, batch = self._preprocess_batch(batch)
         # Run the forward pass.
         outs = super()._forward_train(batch, **kwargs)
         # Mask the action logits and return.
-        return self._mask_action_logits(outs, action_mask)
+        return self._mask_action_logits(outs, batch["action_mask"])

     @override(ValueFunctionAPI)
     def compute_values(self, batch: Dict[str, TensorType]):
         # Preprocess the batch to extract the `observations` to `Columns.OBS`.
-        _, batch = self._preprocess_batch(batch)
+        action_mask, batch = self._preprocess_batch(batch)
+        # NOTE: Because we manipulate the batch we need to add the `action_mask`
+        # to the batch to access them in `_forward_train`.
+        batch["action_mask"] = action_mask
         # Call the super's method to compute values for GAE.
-        return super()._compute_values(batch)
+        return super().compute_values(batch)

     def _preprocess_batch(
         self, batch: Dict[str, TensorType], **kwargs
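The example's helpers `_preprocess_batch` and `_mask_action_logits` split the dict observation into the action mask and the plain observation, and then mask the action logits. As a rough sketch of the masking step only (the exact implementation lives in the example file; the helper below is hypothetical), invalid actions receive a very large negative logit so their sampling probability collapses to roughly zero:

```python
import torch

FLOAT_MIN = torch.finfo(torch.float32).min


def mask_action_logits(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # action_mask holds 1.0 for valid actions and 0.0 for invalid ones.
    # log(0) = -inf, clamped to the float minimum to keep the tensor finite.
    inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
    return logits + inf_mask


logits = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])
masked = mask_action_logits(logits, mask)
probs = torch.softmax(masked, dim=-1)  # probability of the masked action is ~0
```

Adding the clamped `log(mask)` in logit space, rather than multiplying probabilities afterwards, keeps the result a valid set of logits that the action distribution class can consume unchanged.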