
[Feature Request] Optimizing an EnsembleModule #498

Open
btx0424 opened this issue Jul 24, 2023 · 1 comment
Labels: enhancement (New feature or request)



btx0424 commented Jul 24, 2023

Motivation

Ensembled modules should support optimization through the same API as ordinary modules. However, passing EnsembleModule.parameters() to an optimizer, as one usually would, does not train the ensemble: the returned parameters never receive gradients.

import torch
import torch.nn as nn

from tensordict import TensorDict
from tensordict.nn import EnsembleModule, TensorDictModule

m = TensorDictModule(nn.Linear(128, 1), ["a"], ["a_out"])

m = EnsembleModule(m, 3, expand_input=True)
x = TensorDict({"a": torch.randn(32, 128)}, [32])

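# this does not work: the optimizer receives parameters that never get gradients
params = m.parameters()

# this works: optimize the tensors stored in the ensemble's params_td directly
params = list(m.params_td.values(include_nested=True, leaves_only=True))
for param in params:
    param.retain_grad()  # optimizers reject non-leaf tensors unless they retain their grad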

opt = torch.optim.Adam(params)
for i in range(10):
    y = m(x)
    loss = y["a_out"].sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
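The `retain_grad()` call in the workaround is what lets the optimizer accept these tensors. A minimal pure-PyTorch sketch of that step in isolation, assuming the ensemble's stacked weights are non-leaf tensors (simulated here with a hypothetical `clone()` of a leaf called `source`) and a recent PyTorch 2.x release, whose optimizers accept a non-leaf tensor only if it retains its gradient:

```python
import torch

# Hypothetical stand-in for one params_td leaf: a stacked weight produced by
# an operation on a leaf, so it is NOT itself an autograd leaf.
source = torch.randn(3, 4, requires_grad=True)
stacked = source.clone()
print(stacked.is_leaf)  # False

try:
    torch.optim.Adam([stacked])
except ValueError as err:
    print(err)  # the optimizer refuses a plain non-leaf tensor

# retain_grad() makes autograd populate stacked.grad, which also makes the
# optimizer willing to accept the tensor.
stacked.retain_grad()
opt = torch.optim.Adam([stacked], lr=0.1)

x = torch.randn(4)
before = stacked.detach().clone()
loss = (stacked @ x).sum()
opt.zero_grad()
loss.backward()
opt.step()  # updates `stacked` in place using its retained gradient
print(torch.equal(before, stacked.detach()))  # False: the update landed
```

Because `stacked` is still connected to `source`, `backward()` also fills `source.grad`; only the retained-grad tensor is stepped here.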

Solution

A direct solution would be to override EnsembleModule.named_parameters(), but that may introduce other inconsistencies. The root cause seems to be the nn.Parameter objects: they share data with the tensordict they are created from, but they do not receive gradients.
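The suspected root cause can be reproduced without tensordict: a tensor and an `nn.Parameter` can share storage while being distinct autograd leaves, and gradients only reach the tensor the graph actually recorded. In this sketch `stacked` plays the role of a `params_td` leaf and `alias` the role of what `m.parameters()` returns; both are hypothetical stand-ins, not tensordict internals.

```python
import torch
import torch.nn as nn

# Stand-in for a params_td leaf: the tensor the forward pass actually uses.
stacked = torch.randn(3, 4, requires_grad=True)

# Stand-in for what parameters() exposes: an nn.Parameter built from the
# same data. It shares storage but is a separate autograd leaf.
alias = nn.Parameter(stacked.detach())
print(alias.data_ptr() == stacked.data_ptr())  # True: same memory

x = torch.randn(4)
loss = (stacked @ x).sum()  # the graph only records `stacked`
loss.backward()

print(stacked.grad is None)  # False
print(alias.grad is None)    # True: no gradient flows to the alias

# Adam skips parameters whose .grad is None, so stepping an optimizer over
# the alias leaves the shared data unchanged.
opt = torch.optim.Adam([alias], lr=0.1)
before = alias.detach().clone()
opt.step()
print(torch.equal(before, alias.detach()))  # True
```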

  • [x] I have checked that there is no similar issue in the repo (required)
btx0424 added the enhancement (New feature or request) label Jul 24, 2023
vmoens (Contributor) commented Jul 25, 2023

Thanks for reporting!
We need some way to expose the params to tensordict modules, even when they're stored in a tensordict.
One option is to put them in a parameter list, but I'm not a fan of that solution...
I'm currently facing a similar problem with the loss modules in torchrl, where we use a hack to register the params contained in a tensordict. I'll try to come up with an elegant solution for both and ping you once I'm done!
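For illustration only, one possible shape of such a fix in plain PyTorch (not tensordict): the stacked ensemble weights are themselves registered as leaf `nn.Parameter`s, so `parameters()` returns exactly the tensors the forward pass uses. This sketch assumes PyTorch 2.x `torch.func` (`stack_module_state`, `functional_call`, `vmap`); `StackedEnsemble` is a hypothetical class, not part of tensordict.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap


class StackedEnsemble(nn.Module):
    """Hypothetical ensemble whose stacked weights are registered nn.Parameters."""

    def __init__(self, module: nn.Module, num_copies: int):
        super().__init__()
        copies = [copy.deepcopy(module) for _ in range(num_copies)]
        for member in copies:
            # Re-initialise each copy so the ensemble members differ.
            for layer in member.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
        # stack_module_state returns one stacked leaf tensor per parameter name.
        stacked, _ = stack_module_state(copies)
        # ParameterDict keys cannot contain ".", so nested names are mangled
        # (assumes no original parameter name contains "__").
        self.stacked = nn.ParameterDict(
            {name.replace(".", "__"): nn.Parameter(t) for name, t in stacked.items()}
        )
        # Data-free template on the "meta" device. object.__setattr__ keeps it
        # out of the module registry, so its meta tensors are not in parameters().
        object.__setattr__(self, "_template", copy.deepcopy(module).to("meta"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        params = {name.replace("__", "."): p for name, p in self.stacked.items()}

        def call_member(member_params, inp):
            return functional_call(self._template, member_params, (inp,))

        # Every member sees the same input (similar in spirit to expand_input=True);
        # the output gains a leading ensemble dimension.
        return vmap(call_member, in_dims=(0, None))(params, x)
```

With this layout `torch.optim.Adam(ens.parameters())` works as for any module. It is close to the "parameter list" option mentioned above, with the trade-off that the weights live in the module itself rather than in a tensordict.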

vmoens mentioned this issue Jul 26, 2023