feat: functorch integration #6

Merged 26 commits into main from functorch_functional on Sep 11, 2022

Conversation

@vmoens vmoens (Contributor) commented Apr 13, 2022

This is a proof of concept of integrating functorch in MAML.

@Benjamin-eecs Benjamin-eecs requested review from Benjamin-eecs and removed request for JieRen98 April 29, 2022 11:19
@XuehaiPan XuehaiPan added enhancement New feature or request functorch Something functorch related labels Jul 20, 2022
@Benjamin-eecs Benjamin-eecs (Member) left a comment


We designed MetaOptimizers to serve as differentiable optimizers for users who are familiar with PyTorch and like to write optimization code in torch style rather than functional style; for those who prefer the functional style, we already have a series of torchopt optimizers that support functional programming. Also, we did not intend to make functorch a package dependency. Thus I will close this PR.
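The functional optimizers mentioned here follow an optax-style init/update pattern. A minimal pure-Python sketch of that pattern (an illustrative stand-in on plain floats, not torchopt's actual implementation, which operates on torch tensors):

```python
# Sketch of an optax/torchopt-style functional SGD on plain floats.
# sgd, init, update, apply_updates here are toy stand-ins for the real API.

def sgd(lr):
    """Return an (init, update) pair for a list of float 'parameters'."""
    def init(params):
        return None  # SGD is stateless; Adam would return moment estimates

    def update(grads, state):
        updates = [-lr * g for g in grads]
        return updates, state

    return init, update


def apply_updates(params, updates):
    """Functionally apply updates, returning *new* parameters."""
    return [p + u for p, u in zip(params, updates)]


init, update = sgd(lr=0.1)
params = [1.0, -2.0]
state = init(params)
grads = [0.5, -0.5]  # pretend these came from autograd
updates, state = update(grads, state)
params = apply_updates(params, updates)
print(params)  # [0.95, -1.95]
```

The key property is that `update` never mutates the parameters in place; new parameters are returned, which is what makes the optimizer step itself differentiable in the real torch-tensor setting.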

@vmoens vmoens (Contributor, Author) commented Aug 11, 2022

I see your point @Benjamin-eecs
I see little advantage in the current API, though, since users still need to save the params in a state_dict and pass it back to their model afterwards. Hence I don't really see how it hides away the magic of meta-learning algorithms. Can you elaborate on what the advantage is, in your opinion?

Are you aware that functorch is now part of torch core? As such you don't need an extra dependency.

@Benjamin-eecs Benjamin-eecs reopened this Aug 11, 2022
@waterhorse1 (Collaborator)

@vmoens Hi Vincent, have you checked our low-level API? You can find information in our README and in our docs: https://torchopt.readthedocs.io/en/latest/api/api.html#functional-optimizers. The extract_state API is mainly for the PyTorch-like API (the high-level API). For instance, when I am conducting iterative multi-task training for MAML, I need to reset my neural network parameters to their initial values, so we need such an API.
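The reset pattern described here (snapshot the initial parameters, adapt per task, restore before the next task) can be sketched in pure Python. The names mirror the torchopt helpers, but this toy version just deep-copies a dict rather than handling torch modules:

```python
import copy

class ToyModel:
    """Stand-in for an nn.Module; its 'parameters' are a plain dict."""
    def __init__(self):
        self.params = {"w": 1.0, "b": 0.0}

def extract_state_dict(model):
    return copy.deepcopy(model.params)    # snapshot the initial parameters

def recover_state_dict(model, state):
    model.params = copy.deepcopy(state)   # reset the model to the snapshot

model = ToyModel()
init_state = extract_state_dict(model)
for task_lr in (0.1, 0.2):                # pretend per-task inner loops
    model.params["w"] -= task_lr          # stand-in for an inner-loop update
    recover_state_dict(model, init_state) # reset before the next task
print(model.params["w"])  # 1.0
```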

@vmoens vmoens (Contributor, Author) commented Aug 12, 2022

Thanks for this @waterhorse1
I'm familiar with how the low and high level APIs work.
I'm just trying to point out that, from a user's perspective, the following two code snippets require the same mental effort, the same low-level understanding of what a meta-learning algorithm is and does, and the same amount of coding (the number of lines of code is identical).

Example 1 (current)

        policy_state_dict = torchopt.extract_state_dict(policy)
        for idx in range(TASK_NUM):
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], policy)
                inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
                inner_opt.step(inner_loss)
            post_trajs = sample_traj(env, tasks[idx], policy)
            outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
            outer_loss.backward()
            torchopt.recover_state_dict(policy, policy_state_dict)

Example 2 (functorch)

        fpolicy, policy_params = functorch.make_functional(policy)
        for idx in range(TASK_NUM):
            policy_params_new = policy_params
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
                inner_loss = a2c_loss(pre_trajs, fpolicy, policy_params_new, value_coef=0.5)
                policy_params_new = inner_opt.step(inner_loss, policy_params_new)
            post_trajs = sample_traj(env, tasks[idx], fpolicy, policy_params_new)
            outer_loss = a2c_loss(post_trajs, fpolicy, policy_params_new, value_coef=0.5)
            outer_loss.backward()
#            not necessary since the policy parameters do not need to be reset
#            torchopt.recover_state_dict(policy, policy_state_dict)

It is my personal opinion that functorch offers more clarity about what is happening: with the current API, trying to make it look as if everything were just another PyTorch optimizer may push users to overlook extract_state_dict and recover_state_dict, or worse, use them where they should not. I do not think using functorch makes things less clear; on the contrary.
From an OOP perspective, the current API shows the user the same object (e.g. the policy) twice: once with a set of regular parameters, and once with tensors that are no longer parameters but the result of some optimization. I personally find this confusing, as I am looking at the same object but with different content, and it is not apparent that the objects that were once attributes of it no longer are:

        policy_state_dict = torchopt.extract_state_dict(policy) << HERE POLICY HAS nn.Parameters ATTRIBUTES
        for idx in range(TASK_NUM):
            for _ in range(inner_iters):
                pre_trajs = sample_traj(env, tasks[idx], policy)
                inner_loss = a2c_loss(pre_trajs, policy, value_coef=0.5)
                inner_opt.step(inner_loss) << AFTER THIS POLICY HASN'T nn.Parameters ANYMORE
            post_trajs = sample_traj(env, tasks[idx], policy)
            outer_loss = a2c_loss(post_trajs, policy, value_coef=0.5)
            outer_loss.backward()
            torchopt.recover_state_dict(policy, policy_state_dict)<< AFTER POLICY HAS nn.Parameters ATTRIBUTES AGAIN

Basically, we're showing the user one thing that is not one thing but two, and this may lead users to expect behaviours that are not going to work in practice (e.g. what should parameters() return? In "normal" PyTorch this iterator will always return the very same list of items).

@waterhorse1 waterhorse1 (Collaborator) commented Aug 12, 2022

@vmoens Hi Vincent, thanks for your advice. We had a discussion about what you mentioned, and here is what we concluded:

For the functional high-level API, we can easily do what your snippet does by building a wrapper around some low-level APIs (including torch.autograd.grad, optimizer.update and apply_updates). Bo is working on that in FuncOptimizer.

For the OOP API, it's impossible to keep the parameters within the inner-loop process as nn.Parameters, because they are non-leaf nodes. We can offer an alternative solution by wrapping each tensor so it looks like an nn.Parameter: it's still a tensor, but you can treat it as an nn.Parameter.
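The FuncOptimizer idea sketched above (a single step call that internally computes gradients, runs the optimizer update, and applies it) could look roughly like the following. This is an illustrative pure-Python stand-in, not torchopt's actual implementation; a real version would use torch.autograd.grad and torch tensors rather than a hand-supplied grad_fn on floats:

```python
class FuncOptimizer:
    """Toy functional-optimizer wrapper: step() returns new parameters.
    grad_fn stands in for torch.autograd.grad in this sketch."""

    def __init__(self, lr):
        self.lr = lr

    def step(self, grad_fn, params):
        grads = grad_fn(params)                          # compute gradients
        updates = [-self.lr * g for g in grads]          # optimizer.update
        return [p + u for p, u in zip(params, updates)]  # apply_updates

# quadratic loss L = sum(p**2), so the gradient of each entry is 2p
grad_fn = lambda ps: [2.0 * p for p in ps]
opt = FuncOptimizer(lr=0.25)
params = [2.0, -4.0]
params = opt.step(grad_fn, params)
print(params)  # [1.0, -2.0]
```

Bundling the three low-level calls into one method keeps the user-facing loop as short as the stateful version while remaining purely functional: the caller always holds the current parameters explicitly.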

@Benjamin-eecs Benjamin-eecs marked this pull request as ready for review September 9, 2022 07:28
@vmoens vmoens (Contributor, Author) commented Sep 9, 2022

@vmoens Hi Vincent, thanks for your advice. We had a discussion about what you mentioned, and here is what we concluded:

For the functional high-level API, we can easily do what your snippet does by building a wrapper around some low-level APIs (including torch.autograd.grad, optimizer.update and apply_updates). Bo is working on that in FuncOptimizer.

For the OOP API, it's impossible to keep the parameters within the inner-loop process as nn.Parameters, because they are non-leaf nodes. We can offer an alternative solution by wrapping each tensor so it looks like an nn.Parameter: it's still a tensor, but you can treat it as an nn.Parameter.

Sure this is what I had in mind! Obviously we can't work with non-leaf nn.Parameters :-)
Great work guys, I love it

@Benjamin-eecs Benjamin-eecs requested review from XuehaiPan and removed request for JieRen98, waterhorse1 and XuehaiPan September 11, 2022 07:49
XuehaiPan previously approved these changes Sep 11, 2022
@XuehaiPan XuehaiPan merged commit f3fe2db into main Sep 11, 2022
@XuehaiPan XuehaiPan deleted the functorch_functional branch September 11, 2022 12:42
XuehaiPan pushed a commit that referenced this pull request Sep 11, 2022