Batch Optimization #35

Open

roblem opened this issue Oct 25, 2024 · 0 comments

roblem commented Oct 25, 2024

Like #30, I am trying to batch over different starting values when the objective function isn't globally convex.

I am coming from jax, where you set up the objective function and then use vmap to signal which axes to batch over.
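For comparison, here is a minimal sketch of that jax pattern (using jax.scipy.optimize.minimize as the solver; objfun_jax and batched_solve are just illustration names):

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize as jax_minimize

def objfun_jax(x):
    # same toy objective as below, reduced to a scalar for the solver
    return jnp.sum(.1 * x + 3 * x ** 2)

# vmap over axis 0 of the starting values: one independent solve per row
batched_solve = jax.vmap(lambda x0: jax_minimize(objfun_jax, x0, method='BFGS').x)
x_stars = batched_solve(jnp.array([[1.], [3.], [-1.5]]))

Because the solver itself is written in pure jax, vmap traces straight through it and each row gets its own independent minimization.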

Here is a simple example (this function is globally convex, but it works for demo purposes):

import torch
from torchmin import minimize as pyt_minimize

def objfun(x):
    # simple quadratic with a unique minimum at x = -1/60
    return .1 * x + 3 * x ** 2

We set up starting values for x at three different points:

init_batched = torch.tensor([[1.], [3.], [-1.5]])

where the first axis is the batching dimension (shape[0] = number of different starting values).

Setting up vmap for the function and evaluating it shows that things work as expected:

batched_obj_fun = torch.vmap(objfun, in_dims=0)
batched_obj_fun(init_batched)

yields

tensor([[ 3.1000],
        [27.3000],
        [ 6.6000]], device='cuda:0')

But passing the batched starting values straight to pytorch-minimize doesn't work:

res = pyt_minimize(lambda parms: objfun(parms),
                   init_batched, method='bfgs', tol=1e-5, disp=True)

throws this error:

<lots of trace>
RuntimeError: ScalarFunction was supplied a function that does not return scalar outputs.
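
I realize I can collapse the batch to a scalar myself, something like this sketch (it assumes the objective is separable across rows, so the minimizer of the sum is the per-row minimizer, and that pytorch-minimize is happy with a 2-D x0):

# Summing decouples the rows: each row of init_batched only enters its
# own term, so each row's gradient depends only on that row. Note this
# is still one joint BFGS run, not three independent solves.
res = pyt_minimize(lambda parms: objfun(parms).sum(),
                   init_batched, method='bfgs', tol=1e-5, disp=True)

That gets past the scalar check, but it shares a single BFGS Hessian approximation across all starting points, which isn't really batching.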

Next I tried to vmap a wrapped pyt_minimize call instead:

def minimize_fun(coords):
    res = pyt_minimize(lambda parms: objfun(parms),
                       coords, method='bfgs', tol=1e-5, disp=True)
    return res

batched_minimize = torch.vmap(minimize_fun, in_dims=0)

and call it to do the batch minimization:

batched_minimize(init_batched)

which fails with this error:

RuntimeError: You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function being transformed by a functorch transform. This is unsupported, please attempt to use the functorch transforms (e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() outside of a function being transformed instead.

From #30 it sounds like this is possible, but maybe the "jax" way is the wrong way to proceed here....
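
In the meantime I have fallen back to a plain Python loop over starting points (a sketch; it assumes the scipy-style result object pytorch-minimize returns, with .x and .fun fields):

# One independent BFGS solve per starting point; no vmap involved.
# .sum() reduces the length-1 output to the scalar the solver expects.
results = [pyt_minimize(lambda x: objfun(x).sum(), x0, method='bfgs', tol=1e-5)
           for x0 in init_batched]
best = min(results, key=lambda r: float(r.fun))  # keep the lowest objective
print(best.x, best.fun)

That works, but it serializes the solves, which is exactly what I was hoping vmap would avoid.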
