[Feature Request] Support for TensorDictBase.masked_select inplace #394

Open
xmaples opened this issue May 24, 2023 · 0 comments
Labels
enhancement New feature or request

xmaples (Contributor) commented May 24, 2023

Motivation

Provide an in-place variant of .masked_select() for TensorDictBase.

Solution

Add a method named .masked_select_() to TensorDictBase that behaves like .masked_select(), but modifies the TensorDict in place instead of returning a new one.

Main steps to achieve this:

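  1. Iterate over the key-value pairs and, for each value that is a leaf tensor, compute the masked tensor and store it in place of the original.
  2. Iterate over the key-value pairs whose values are nested TensorDicts, and call .masked_select_() on them recursively.
  3. Update batch_size to the correct, reduced size (the number of True entries in the mask along the masked dimension).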
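The three steps could be sketched in plain Python as follows. This is a hypothetical, dependency-free model, not tensordict code: MiniTensorDict, its get() helper and its masked_select_() are illustrative names, leaf values are Python lists standing in for tensors, and the mask is assumed to be a 1-D boolean sequence over the leading batch dimension.

```python
class MiniTensorDict:
    """Hypothetical stand-in for TensorDictBase (illustration only).

    Leaf values are Python lists whose first dimension is the batch
    dimension; other values may be nested MiniTensorDict instances.
    """

    def __init__(self, source, batch_size):
        self._data = dict(source)
        self.batch_size = list(batch_size)

    def get(self, key):
        return self._data[key]

    def masked_select_(self, mask):
        # Assumption: 1-D boolean mask over the leading batch dimension.
        if len(mask) != self.batch_size[0]:
            raise ValueError("mask length must match batch_size[0]")
        # Step 1: replace each leaf value with its masked rows.
        for key, value in list(self._data.items()):
            if not isinstance(value, MiniTensorDict):
                self._data[key] = [row for row, keep in zip(value, mask) if keep]
        # Step 2: recurse into nested containers with the same mask.
        for value in self._data.values():
            if isinstance(value, MiniTensorDict):
                value.masked_select_(mask)
        # Step 3: the leading batch dimension shrinks to the number of True entries.
        self.batch_size[0] = sum(bool(m) for m in mask)
        return self  # in-place methods conventionally return self


td = MiniTensorDict(
    {
        "a": [[0.0] * 4 for _ in range(3)],
        "nested": MiniTensorDict({"b": [1, 2, 3]}, batch_size=[3]),
    },
    batch_size=[3],
)
td.masked_select_([True, False, False])
print(td.get("a"))                # [[0.0, 0.0, 0.0, 0.0]]
print(td.get("nested").get("b"))  # [1]
print(td.batch_size)              # [1]
```

Updating batch_size last, after every entry has been masked, mirrors the ordering of the steps above; in a real TensorDict the shapes of the entries and the batch size would presumably have to change together to keep the container consistent.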

Examples:

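import torch
from tensordict import TensorDict

td = TensorDict(source={'a': torch.zeros(3, 4)},
    batch_size=[3])
mask = torch.tensor([True, False, False])
td.masked_select_(mask)  # proposed method, not yet available
td.get("a")
#output: tensor([[0., 0., 0., 0.]])
# td.batch_size is expected to become torch.Size([1])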
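The proposed name follows PyTorch's trailing-underscore convention for in-place operations (for example Tensor.add versus Tensor.add_). A minimal sketch of the intended contract, assuming a flat dict of equal-length lists in place of a TensorDict (both functions below are hypothetical): the out-of-place variant returns a new object and leaves its input untouched, while the in-place variant overwrites the entries and returns the same object.

```python
def masked_select(data, mask):
    """Hypothetical out-of-place variant: returns a new dict, input unchanged."""
    return {k: [x for x, keep in zip(v, mask) if keep] for k, v in data.items()}


def masked_select_(data, mask):
    """Hypothetical in-place variant: overwrites entries, returns the same dict."""
    for k, v in list(data.items()):
        data[k] = [x for x, keep in zip(v, mask) if keep]
    return data


d = {"a": [0, 1, 2]}
mask = [True, False, False]

out = masked_select(d, mask)
print(out, d)        # {'a': [0]} {'a': [0, 1, 2]}

same = masked_select_(d, mask)
print(same is d, d)  # True {'a': [0]}
```

Returning the mutated object from the in-place variant keeps call chaining possible, which is how PyTorch's underscore methods behave as well.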