minimize and forward AD #112

Open
vadmbertr opened this issue Jan 24, 2025 · 19 comments
Labels
question User queries

Comments

@vadmbertr

Hi!

I'm facing a similar use-case as the one described here #50 but I would like to optimize using minimize rather than least_squares.
Is there any plan/option for supporting a solution similar to options={"jac": "fwd"}?

Thanks a lot!
Vadim

@johannahaffner
Contributor

johannahaffner commented Jan 24, 2025

Hi Vadim,

Selection of forward- vs. reverse-mode autodiff is currently implemented at the solver level. Which solvers are you interested in using?

If you're using diffrax underneath: diffrax now has efficient forward-mode autodiff as well, with diffrax.ForwardMode.

@vadmbertr
Author

Hi @johannahaffner,

Thanks for your reply!

I was interested in BFGS, but in forward mode. I was able to successfully solve a "least_squares" problem using forward mode and GaussNewton.
Indeed I'm using diffrax with diffrax.ForwardMode!

@johannahaffner
Contributor

You're welcome!

Could you share the error you are getting? BFGS already uses jax.linearize, which implements forward-mode automatic differentiation. So this should actually work out of the box.
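For reference, here is a minimal sketch (with a toy objective standing in for the real loss) of what jax.linearize provides: the primal output together with a linear function that evaluates Jacobian-vector products, i.e. forward-mode derivatives.

import jax
import jax.numpy as jnp

def f(y):
    return jnp.sum(y**2)  # toy scalar objective, stand-in for the real loss

y = jnp.arange(3.0)
out, lin_fn = jax.linearize(f, y)   # out = f(y); lin_fn(v) evaluates the JVP J(y) @ v
print(out)                          # 5.0
print(lin_fn(jnp.ones_like(y)))     # directional derivative along the all-ones tangent: 6.0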

@patrick-kidger added the question User queries label on Jan 25, 2025
@vadmbertr
Author

vadmbertr commented Jan 27, 2025

Hi!

Here is an MWE to reproduce:

from diffrax import diffeqsolve, Euler, ForwardMode, ODETerm, SaveAt
import jax.numpy as jnp
import optimistix as optx


def fn(y0, _=None):
    vector_field = lambda t, y, args: -y
    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=[0., 1., 2., 3.])

    sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat, adjoint=ForwardMode())
    
    return sol.ys


def least_square(y0, _=None):
    ys = fn(y0)
    return jnp.sum(ys**2)


ls_sol = optx.least_squares(fn, optx.GaussNewton(rtol=1e-8, atol=1e-8), jnp.asarray(1.))
print(ls_sol.value)  # 0.0
min_sol = optx.minimise(least_square, optx.BFGS(rtol=1e-8, atol=1e-8), jnp.asarray(1.))  # error raised here
print(min_sol.value)

and I get the following:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

I'm using optimistix 0.0.10, diffrax 0.6.2 and equinox 0.11.11 if that matters.

EDIT: note that if I comment out adjoint=ForwardMode() and the two ls_sol lines, the min_sol ones evaluate as expected (but as I said I would like to use forward mode).

@johannahaffner
Contributor

Ok I think I see where this is coming from - it looks like this is raised by the optimistix implicit adjoint. We've likely not run into this before because ForwardMode is pretty new.

The quickest fix I could think of didn't work, so I'll have to look into something else outside of working hours. Thanks for the MWE, I will keep you posted!

(In the meantime, you could try diffrax.DirectAdjoint - not a recommended long-term solution because it's not super memory-efficient, but it will allow for reverse-mode differentiation in the one place it is required, and forward-mode everywhere else.)
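To make the suggested workaround concrete, here is a minimal sketch adapting the MWE above; the only change relative to the original is the adjoint, everything else is assumed unchanged.

from diffrax import diffeqsolve, DirectAdjoint, Euler, ODETerm, SaveAt
import jax.numpy as jnp
import optimistix as optx


def least_square(y0, _=None):
    # Same ODE as in the MWE, but with DirectAdjoint: it supports the one
    # reverse-mode (transpose) step inside the minimiser as well as forward
    # mode elsewhere, at some extra memory cost.
    term = ODETerm(lambda t, y, args: -y)
    sol = diffeqsolve(
        term, Euler(), t0=0, t1=3, dt0=0.1, y0=y0,
        saveat=SaveAt(ts=[0., 1., 2., 3.]),
        adjoint=DirectAdjoint(),
    )
    return jnp.sum(sol.ys**2)


min_sol = optx.minimise(least_square, optx.BFGS(rtol=1e-8, atol=1e-8), jnp.asarray(1.))
print(min_sol.value)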

@johannahaffner
Contributor

johannahaffner commented Jan 27, 2025

Ok, this would require a bit of a change. I had not appreciated that jax.linear_transpose will actually use reverse-mode AD machinery under the hood, but it does. We're using that in the minimisers to get a gradient from the linearised function.

return jax.linear_transpose(lin_fn, *primals)(1.0)

Here is what I think we can do:

  1. Create options fwd and bwd, with branches for each.
  • In the forward branch, compute the gradient with jacfwd, which is equivalent to the gradient for a scalar function.
  • In the reverse-mode branch, keep doing what we are doing, for the performance benefits.
  2. Also create options as above, but get the gradient out of lin_fn a different way - essentially a custom jacfwd that uses lin_fn, constructs unit pytrees of shape y, then stacks the output of vmap(lin_fn)(unit_pytrees). The motivation here would be that if we already have lin_fn around, we might as well use it. (A rough sketch of this follows below.)
    Not sure what the performance would be. As mentioned in the documentation of jax.linearize, storing the linearised function has a memory overhead that might not outweigh the compilation-time benefits in this case.

@patrick-kidger WDYT? Try both and see what sticks?
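A rough sketch of option 2, under simplifying assumptions (array-valued inputs only; real code would construct unit pytrees matching the structure of y, and would reuse the lin_fn the solver already has):

import jax
import jax.numpy as jnp


def forward_grad(fn, y):
    # Linearise once (forward mode), then evaluate the JVP on each unit tangent.
    out, lin_fn = jax.linearize(fn, y)
    basis = jnp.eye(y.size).reshape((y.size,) + y.shape)  # one unit tangent per input element
    grad = jax.vmap(lin_fn)(basis).reshape(y.shape)       # each JVP yields one gradient entry
    return out, grad


f = lambda y: jnp.sum(jnp.sin(y) ** 2)  # toy scalar objective
value, grad = forward_grad(f, jnp.arange(3.0))
print(value, grad)  # gradient should equal 2 * sin(y) * cos(y)

Whether this beats calling jacfwd on the objective directly presumably comes down to the compilation/memory trade-off mentioned above.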

@patrick-kidger
Owner

Ok, this would require a bit of a change. I had not appreciated that jax.linear_transpose will actually use reverse-mode AD machinery under the hood, but it does.

FWIW it's actually the other way around! Reverse-mode AD uses linear transpose. And to be precise, when it comes to something like jax.lax.while_loop, the issue is that it isn't transposable, rather than that it isn't reverse-mode autodifferentiable. (Since transposition is an infrequently used feature relative to reverse-mode AD, though, the JAX error message refers to the latter.)

As for what to do here: if we want to support an alternative forward mode then I think evaluating vmap(lin_fn)(unit_pytrees) would be the computationally optimal way to do this. (This is actually what jacfwd already does under the hood.)

But FWIW @vadmbertr, the reason this doesn't come up super frequently is that for BFGS it is usually more efficient to use reverse mode to compute a gradient. Typically one would pair the Diffrax adjoint=... with the Optimistix solver=... as appropriate. (I will grant you that it's kind of annoying that this is needed; alas, JAX does not currently support jvp-of-custom_vjp, and this abstraction leak is what results.) Would doing this work for your real-world use-case?

@vadmbertr
Author

Hi!
Thank you for digging into this.

@johannahaffner, indeed DirectAdjoint works as a workaround but might not be ideal for the reason you pointed out.

@patrick-kidger, what would be the reason for BFGS being more efficient in reverse-mode? (in the context where evaluating the gradient is significantly faster in forward mode) A colleague of mine wanted to compare a numpy implementation of a differential equation calibration problem with a JAX one, and he initially used BFGS for the solve. What solver would you recommend here (using optx.minimise)?

Again, thanks for the replies!
Vadim

@johannahaffner
Contributor

@patrick-kidger, what would be the reason for BFGS being more efficient in reverse-mode? (in the context where evaluating the gradient is significantly faster in forward mode)

In general, reverse-mode automatic differentiation is more efficient for functions that map high-dimensional inputs to low-dimensional outputs (e.g. a neural network with many parameters and a scalar loss function). BFGS operates on such a scalar loss.
Forward-mode automatic differentiation is more efficient in the opposite setting, where few inputs get mapped to many outputs, such as a mechanistic model (like an ODE) with few parameters fitted to a long time series. The residual Jacobian is then going to be tall and narrow - with many rows for the residual, each of which is considered a model output.

This means that forward- vs. reverse-mode being more efficient is not really a model property! It really depends on the optimiser you use and whether it operates on the residuals or on their squared sum.
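As a toy illustration of the shape argument (made-up functions, not from the thread):

import jax
import jax.numpy as jnp

# Many parameters -> scalar loss: the Jacobian is one wide row, which reverse
# mode computes in a single backward pass.
scalar_loss = lambda params: jnp.sum(params**2)
print(jax.jacrev(scalar_loss)(jnp.ones(1000)).shape)       # (1000,)

# Few parameters -> long residual vector: the Jacobian is tall and narrow, and
# forward mode needs only as many JVPs as there are parameters (two here).
residuals = lambda theta: theta[1] * jnp.sin(theta[0] * jnp.arange(1000.0))
print(jax.jacfwd(residuals)(jnp.array([0.1, 2.0])).shape)  # (1000, 2)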

@vadmbertr
Author

@patrick-kidger, what would be the reason for BFGS being more efficient in reverse-mode? (in the context where evaluating the gradient is significantly faster in forward mode)

In general, reverse-mode automatic differentiation is more efficient for functions that map high-dimensional inputs to low-dimensional outputs (e.g. a neural network with many parameters and a scalar loss function). BFGS operates on such a scalar loss. Forward-mode automatic differentiation is more efficient in the opposite setting, where few inputs get mapped to many outputs, such as a mechanistic model (like an ODE) with few parameters fitted to a long time series. The residual Jacobian is then going to be tall and narrow - with many rows for the residual, each of which is considered a model output.

This means that forward- vs. reverse-mode being more efficient is not really a model property! It really depends on the optimiser you use and whether it operates on the residuals or on their squared sum.

The use case is a bit in between what you described, as the (scalar) loss "aggregates" the outputs of a model with (very) few parameters (it is not the squared sum of the individual residuals). So I believe the Jacobian is wider but low-dimensional, and (experimentally) computing the adjoint of the model is (much) faster in forward mode than in reverse mode.

@johannahaffner
Contributor

The use case is a bit in between what you described, as the (scalar) loss "aggregates" the outputs of a model with (very) few parameters (it is not the squared sum of the individual residuals). So I believe the Jacobian is wider but low-dimensional, and (experimentally) computing the adjoint of the model is (much) faster in forward mode than in reverse mode.

Is jacfwd on the scalar loss faster than grad?

@vadmbertr
Author

Is jacfwd on the scalar loss faster than grad?

Yes, consider the following for example:

from diffrax import diffeqsolve, Euler, ForwardMode, ODETerm, RecursiveCheckpointAdjoint, SaveAt
import jax
import jax.numpy as jnp
import optimistix as optx


def fn(w, y0, adjoint):
    ts = jnp.arange(7*24)
    vector_field = lambda t, y, args: y * args * jnp.exp(-t)
    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=ts)

    sol = diffeqsolve(term, solver, t0=ts[0], t1=ts[-1], dt0=1, y0=y0, args=w, saveat=saveat, adjoint=adjoint)

    return sol.ys


def loss(w, y0, adjoint):
    ys = fn(w, y0, adjoint)
    return jnp.sum(jnp.max(ys, axis=(1, 2)))


@jax.jit
def fwd(w, y0):
    return jax.jacfwd(loss, argnums=0)(w, y0, ForwardMode())


@jax.jit
def bwd(w, y0):
    return jax.grad(loss, argnums=0)(w, y0, RecursiveCheckpointAdjoint())


y0 = jnp.ones((100, 100))
w = 1.

%time print(fwd(w, y0))
# 536.20465
# CPU times: user 615 ms, sys: 58.3 ms, total: 673 ms
# Wall time: 492 ms
%time print(bwd(w, y0))
# 536.20483
# CPU times: user 1.51 s, sys: 60.6 ms, total: 1.57 s
# Wall time: 1.07 s

%timeit jax.block_until_ready(fwd(w, y0))
# 5.71 ms ± 260 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.block_until_ready(bwd(w, y0))
# 11.9 ms ± 224 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@johannahaffner
Contributor

A little counterintuitive, and therefore very interesting! Thanks for the demo. To support this, I'll include a forward option in BFGS, then. I can probably get to it over the weekend :)

(Your wall & CPU times include compilation, btw. Call a jitted function once before benchmarking - when you time them below, the compilation has already happened and the computation is 100x as fast.)

@vadmbertr
Author

A little counterintuitive, and therefore very interesting! Thanks for the demo. To support this, I'll include a forward option in BFGS, then. I can probably get to it over the weekend :)

Wow thanks a lot!

(Your wall & CPU times include compilation, btw. Call a jitted function once before benchmarking - when you time them below, the compilation has already happened and the computation is 100x as fast.)

I believe it is compiled in the %time call and it calls the jitted version in the %timeit one!

@johannahaffner
Contributor

Wow thanks a lot!

You're welcome!

I believe it is compiled in the %time call and it calls the jitted version in the %timeit one!

Yes, exactly.

@patrick-kidger
Owner

Bear in mind that in this example it's a scalar->scalar function, and for these it's totally expected for forward mode to be optimal. This is why I wrote 'usually more efficient to use reverse-mode' above, rather than always. The fact that this isn't a super frequent use case for BFGS -- there are many specifically scalar->scalar optimizers that are more common in that setting, and arguably our best possible improvement here is actually to add more of them by default! -- is why this hasn't been super important before :)

@johannahaffner
Contributor

johannahaffner commented Feb 2, 2025

@vadmbertr can you try https://github.com/johannahaffner/optimistix/tree/forward-fix on your real problem? This should now work; all you need to do is pass options=dict(mode="fwd") to optx.minimise.

If it does not really help on real problems, then no need to add it @patrick-kidger. And agreed for scalar->scalar solvers, in general.
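For reference, a minimal sketch of the call on that branch, reusing least_square and the imports from the MWE earlier in the thread (with adjoint=ForwardMode() kept inside fn), and assuming the option is spelled exactly as described above:

min_sol = optx.minimise(
    least_square,                       # the MWE objective, unchanged
    optx.BFGS(rtol=1e-8, atol=1e-8),
    jnp.asarray(1.),
    options=dict(mode="fwd"),           # forward-mode gradient, per the branch description
)
print(min_sol.value)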

@vadmbertr
Author

Hi @johannahaffner,

Thanks for implementing this swiftly!
I can confirm that I get 50% to 1000% speed-ups on real-world problems of different complexity (meaning time step and state-domain dimension), so I will be happy if it gets added.

@johannahaffner
Contributor

Good morning @vadmbertr, thanks for trying it out so quickly! I opened a PR to add this option for all minimisers.
