Interactively step through solve with jax.lax.scan #103

Open
matillda123 opened this issue Dec 31, 2024 · 2 comments
Labels
question User queries

Comments

@matillda123

matillda123 commented Dec 31, 2024

Hi,

I am trying to use jax.lax.scan together with the interactive solving approach (see the example in the documentation) for a minimization.
However, I found that for some solvers (the least-squares ones, e.g. GaussNewton) using lax.scan raises a TypeError. (Everything works with a standard Python for-loop.)

This is the example I'm working with:

import jax.numpy as jnp
import optimistix
from jax.tree_util import Partial
import jax
import numpy as np

### work with lax.scan
# solver = optimistix.BFGS(rtol=1e-3, atol=1e-3)
# solver = optimistix.NonlinearCG(rtol=1e-3, atol=1e-3)

### do NOT work with lax.scan (GaussNewton is uncommented so the example runs and reproduces the error)
solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)
# solver = optimistix.LevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.IndirectLevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.NelderMead(rtol=1e-3, atol=1e-3)
# solver = optimistix.Dogleg(rtol=1e-3, atol=1e-3)



def test_func(x, *args):
    return jnp.sum(x**2) + 1.0, None


fn = test_func
y = jnp.array(np.random.uniform(-1,1,size=(5,5)))

args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)

def step_helper(carry, xs):
    y, state, _ = carry
    return step(y=y, state=state), None

carry = (y, state, None)

carry, _ = jax.lax.scan(step_helper, carry, length=10)

#for _ in range(10):
#    carry, _ = step_helper(carry, None)

y, state, aux = carry

The error which is raised is:

TypeError: Value { lambda a:f32[5,5]; b:f32[5,5]. let
    c:f32[5,5] = mul b a
    d:f32[] = reduce_sum[axes=(0, 1)] c
  in (d,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type

Am I doing something wrong or is there something else going on?
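The failure can be reproduced without optimistix. The sketch below assumes only JAX: it uses `jax.make_jaxpr` to build a `ClosedJaxpr` as a stand-in for the jaxpr kept inside the solver state (the object in the report is a bare `Jaxpr`, but both are non-array Python objects). Putting it in a `lax.scan` carry raises a TypeError, while carrying only the array and leaving the program in the enclosing scope works. The Python for-loop in the example above avoids the problem because it never traces the carry, so the non-array leaves are never converted to JAX types.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x * x)


x0 = jnp.ones((5, 5), dtype=jnp.float32)
# A traced program object: a stand-in for the jaxpr kept in a solver state.
program = jax.make_jaxpr(f)(x0)


def body_carrying_program(carry, _):
    x, prog = carry  # prog is a non-array leaf of the carry
    return (0.5 * x, prog), None


try:
    jax.lax.scan(body_carrying_program, (x0, program), length=3)
    raised = False
except TypeError as err:
    # JAX refuses the leaf: the message says it is not a valid JAX type.
    raised = True
    print("TypeError:", err)


def body_closing_over_program(x, _):
    # Static objects such as `program` would be read from the enclosing
    # scope here; only the array x is carried from step to step.
    return 0.5 * x, None


x_final, _ = jax.lax.scan(body_closing_over_program, x0, length=3)
print(x_final[0, 0])
```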

@johannahaffner
Contributor

Hi Matilda,

some solvers, such as the GaussNewton family of least-squares solvers, store a jaxpr (JAX's intermediate representation of a traced program) as part of their state: in the least-squares solvers, the residual Jacobian contains one. A jaxpr is not an array, so it is not a valid JAX type, and jax.lax.scan cannot put it in the carry, whose leaves must all be array-like values. That is exactly what the TypeError is reporting.
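To find which parts of a state would break `lax.scan`, one option is to list the leaves that are not array-like. This is a minimal sketch using only `jax.tree_util`; the helper `non_array_leaves` and the `_CARRYABLE` type list are hypothetical and only approximate what JAX accepts in a carry. Called on the GaussNewton state from `solver.init(...)`, it should flag the stored jaxpr (and possibly other non-array objects), though the exact paths depend on optimistix/lineax internals.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Leaf types that jax.lax.scan can carry (an approximation for this sketch):
# JAX arrays (including tracers), NumPy arrays/scalars, and Python numbers.
_CARRYABLE = (jax.Array, np.ndarray, np.generic, bool, int, float, complex)


def non_array_leaves(tree):
    """Return (path, type name) for each leaf that is not a carryable value."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        (jax.tree_util.keystr(path), type(leaf).__name__)
        for path, leaf in leaves_with_paths
        if not isinstance(leaf, _CARRYABLE)
    ]


# Hypothetical stand-in for a solver state: arrays plus a stored jaxpr.
toy_state = {
    "x": jnp.zeros(3),
    "program": jax.make_jaxpr(jnp.sin)(1.0),
    "num_steps": 2,
}

print(non_array_leaves(toy_state))
```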

You can partition the state with eqx.partition so that only the dynamic (array) elements are part of the carry, and the static parts, which do not change from step to step, are closed over by your step_helper function.

The following code works:

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix
from jax.tree_util import Partial


solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)

def test_func(x, *args):
    return jnp.sum(x**2) + 1.0, None

fn = test_func
y = jnp.array(np.random.uniform(-1,1,size=(5,5)))

args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
dynamic, static = eqx.partition(state, eqx.is_array)
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)

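def step_helper(carry, xs):
    y, dynamic, _ = carry
    # Reattach the static (non-array) parts, e.g. the jaxpr, closed over from outside.
    state = eqx.combine(dynamic, static)
    y, state, aux = step(y=y, state=state)
    # Only the array leaves go back into the carry.
    dynamic, _ = eqx.partition(state, eqx.is_array)
    carry = (y, dynamic, aux)
    return carry, None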

carry = (y, dynamic, None)
carry, _ = jax.lax.scan(step_helper, carry, length=10)
y, dynamic, aux = carry
state = eqx.combine(dynamic, static)  # reattach the static parts to get the full solver state

NelderMead worked for me using your MWE code.

patrick-kidger added the question (User queries) label on Jan 1, 2025
@matillda123
Author

Hey,

thanks for the help. Everything works now.
NelderMead also originally worked for me. I labelled it as "not working" by mistake.


3 participants