feat: functorch integration #6
Conversation
We designed MetaOptimizers to serve as differentiable optimizers for users who are familiar with PyTorch and prefer to write optimization code in torch style rather than functional style; for the functional style, we already have a series of functional optimizers under torchopt. Also, we did not intend to make functorch a package dependency. Thus I will close this PR.
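For context, a minimal sketch of the two styles being contrasted here, assuming TorchOpt's documented functional API (`torchopt.sgd`, `torchopt.apply_updates`) and the OOP MetaOptimizer API (`torchopt.MetaSGD`); exact signatures may differ across versions:

```python
import torch
import torchopt

# Functional style: a stateless gradient transformation plus explicitly threaded state.
net = torch.nn.Linear(4, 1)
params = tuple(net.parameters())
optimizer = torchopt.sgd(lr=0.1)
opt_state = optimizer.init(params)

loss = net(torch.randn(8, 4)).sum()
grads = torch.autograd.grad(loss, params)
updates, opt_state = optimizer.update(grads, opt_state)
params = torchopt.apply_updates(params, updates, inplace=False)  # returns new parameter tensors

# OOP style: MetaSGD keeps the familiar opt.step(loss) call while staying differentiable.
meta_net = torch.nn.Linear(4, 1)
meta_opt = torchopt.MetaSGD(meta_net, lr=0.1)
loss = meta_net(torch.randn(8, 4)).sum()
meta_opt.step(loss)  # updates meta_net's parameters, keeping them in the autograd graph
```

The functional path threads parameters and optimizer state explicitly, while the MetaOptimizer path preserves the PyTorch look-and-feel by updating the module's parameters for the user.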
I see your point @Benjamin-eecs. Are you aware that functorch is now part of torch core? As such, you don't need an extra dependency.
@vmoens Hi Vincent, have you checked our low-level API? You can find information in our README and also in our docs: https://torchopt.readthedocs.io/en/latest/api/api.html#functional-optimizers. The extract_state_dict API is mainly for the PyTorch-like API (the high-level API). For instance, when I am doing iterative multi-task training for MAML, I need to reset the neural network parameters to their initial values after each task, which is why we need such an API.
Thanks for this @waterhorse1.

Example 1 (current):

```python
policy_state_dict = torchopt.extract_state_dict(policy)
for idx in range(TASK_NUM):
    for _ in range(inner_iters):
        pre_trajs = sample_traj(env, tasks[idx], policy)
        inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
        inner_opt.step(inner_loss)
    post_trajs = sample_traj(env, tasks[idx], policy)
    outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
    outer_loss.backward()
    torchopt.recover_state_dict(policy, policy_state_dict)
```

Example 2 (functorch):

```python
fpolicy, policy_params = functorch.make_functional(policy)
for idx in range(TASK_NUM):
    policy_params_new = policy_params
    for _ in range(inner_iters):
        pre_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
        inner_loss = a2c_loss(pre_trajs, fpolicy, policy_params_new, value_coef=0.5)
        policy_params_new = inner_opt.step(inner_loss, policy_params_new)
    post_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
    outer_loss = a2c_loss(post_trajs, fpolicy, policy_params_new, value_coef=0.5)
    outer_loss.backward()
    # not necessary since policy parameters need not be reset
    # torchopt.recover_state_dict(policy, policy_state_dict)
```

It is my personal opinion that functorch offers more clarity about what is happening: with the current API, trying to make it look as if everything works like any other PyTorch optimiser may push users to overlook the following:

```python
policy_state_dict = torchopt.extract_state_dict(policy)  # << HERE POLICY HAS nn.Parameter ATTRIBUTES
for idx in range(TASK_NUM):
    for _ in range(inner_iters):
        pre_trajs = sample_traj(env, tasks[idx], policy)
        inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
        inner_opt.step(inner_loss)  # << AFTER THIS POLICY DOESN'T HAVE nn.Parameter ATTRIBUTES ANYMORE
    post_trajs = sample_traj(env, tasks[idx], policy)
    outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
    outer_loss.backward()
    torchopt.recover_state_dict(policy, policy_state_dict)  # << AFTER THIS POLICY HAS nn.Parameter ATTRIBUTES AGAIN
```

Basically, we're showing the user one thing that is really two things, and this may lead users to expect behaviours that are not going to work in practice (e.g. what should …)
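For reference, a minimal sketch of how one could observe the behaviour annotated above, assuming TorchOpt's MetaSGD API (names and the exact type of the replaced parameters may differ by version):

```python
import torch
import torchopt

net = torch.nn.Linear(2, 1)
meta_opt = torchopt.MetaSGD(net, lr=0.1)

# Before any differentiable step, the module holds ordinary nn.Parameter leaves.
print(all(isinstance(p, torch.nn.Parameter) for p in net.parameters()))  # expected: True

loss = net(torch.randn(4, 2)).sum()
meta_opt.step(loss)

# After the differentiable step, the parameters are replaced by non-leaf tensors
# so that the inner-loop update stays in the autograd graph.
print(all(isinstance(p, torch.nn.Parameter) for p in net.parameters()))  # expected: False
```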
@vmoens Hi Vincent, thanks for your advice. We had a discussion about what you mention and here is what we concluded: For the functional high-level API, we can easily do what your snippet shows by building a wrapper around some low-level APIs (including torch.autograd.grad, optimizer.update and apply_updates). Bo is working on that in FuncOptimizer. For the OOP API, it is impossible to keep the parameters as nn.Parameters within the inner-loop process because they are non-leaf nodes. We can offer an alternative solution by wrapping the tensors so they can be treated as nn.Parameters: each one is still a tensor, but you can treat it as an nn.Parameter.
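A rough sketch of what such a wrapper could look like, built only from the low-level pieces named above (torch.autograd.grad, the optimizer's update, and apply_updates); the class name here is hypothetical and the actual FuncOptimizer interface may differ:

```python
import torch
import torchopt


class SimpleFuncOptimizer:  # hypothetical name, illustrative only
    """Wraps a TorchOpt gradient transformation into a single step(loss, params) call."""

    def __init__(self, impl):
        self.impl = impl      # e.g. torchopt.sgd(lr=0.1) or torchopt.adam(lr=1e-3)
        self.state = None

    def step(self, loss, params):
        if self.state is None:
            self.state = self.impl.init(params)
        # create_graph=True keeps the gradients differentiable, so the outer loss
        # can backpropagate through the inner update (as in MAML).
        grads = torch.autograd.grad(loss, params, create_graph=True)
        updates, self.state = self.impl.update(grads, self.state)
        return torchopt.apply_updates(params, updates, inplace=False)


# Usage (assumed): params = SimpleFuncOptimizer(torchopt.sgd(lr=0.1)).step(inner_loss, params)
```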
Sure, this is what I had in mind! Obviously we can't work with non-leaf nn.Parameters :-)
This is a proof of concept of integrating functorch in MAML.