Like #30, I am trying to batch over different starting values when the objective function isn't globally convex.
I am coming from jax, where you set up the objective function and then use `vmap` to signal which axes to batch over. Here is what I have for a simple example (the function is globally convex, but for demo purposes go ahead anyway):
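Something along these lines is a minimal sketch of that setup; the objective `f` below is an illustrative stand-in, not the original code:

```python
import torch

def f(x):
    # a simple, globally convex objective: sum of squares of x
    return (x ** 2).sum()
```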
We will set up starting values for `x` at three different points, where the first axis is the batching dimension (`shape[0]` = number of different starting values):
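For example (the particular values here are illustrative assumptions):

```python
# three starting points for a 2-d problem, shape (3, 2);
# shape[0] = 3 is the number of different starting values
x0 = torch.tensor([[ 1.0,  2.0],
                   [-3.0,  0.5],
                   [ 5.0, -4.0]])
```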
Setting up `vmap` for the function and evaluating it shows things work as expected, yielding one objective value per starting point:
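A sketch of that batched evaluation, assuming `torch.func.vmap` (`functorch.vmap` on older PyTorch versions) and the `f`/`x0` stand-ins above:

```python
from torch.func import vmap  # functorch.vmap on older PyTorch versions

batched_f = vmap(f, in_dims=0)   # batch over the first axis of x0
values = batched_f(x0)           # shape (3,): one objective value per starting point
print(values)
```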
But using this vmap'd function with pytorch-minimize isn't working:
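A sketch of the kind of call being described, assuming the entry point is `torchmin.minimize` from pytorch-minimize; `batched_f` and `x0` follow the stand-ins above:

```python
from torchmin import minimize

# pass the vmap'd objective and the whole batch of starting values directly
res = minimize(batched_f, x0, method='bfgs')
```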
This call throws an error.
So I tried to vmap a wrapped pyt_minimize call:
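A sketch of such a wrapper, again assuming `torchmin.minimize` and its SciPy-style result object; the body of `pyt_minimize` and the name `batched_minimize` are illustrative:

```python
def pyt_minimize(x_single):
    # run an ordinary (unbatched) minimization from one starting point
    result = minimize(f, x_single, method='bfgs')
    return result.x

# vmap the wrapper so it maps over the first axis of the starting values
batched_minimize = vmap(pyt_minimize, in_dims=0)
```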
and call it to do the batch minimization:
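which, with the sketch above, would be roughly:

```python
# fails: minimize() calls requires_grad_() on its input tensor, which is
# not allowed inside a functorch/vmap transform (see the error below)
results = batched_minimize(x0)
```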
with this error:

RuntimeError: You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform. This is unsupported, please attempt to use the functorch transforms (e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() outside of a function being transformed instead.
From #30 this is possible, but maybe the "jax" way is the wrong way to proceed....