I am trying to use jax.lax.scan together with the interactive solving approach (see the example in the documentation) for a minimization problem.
However, I found that for some solvers (the least-squares ones) using lax.scan results in a type error. (Everything works with a standard for-loop.)
This is the example I'm working with:
```python
import jax.numpy as jnp
import optimistix
from jax.tree_util import Partial
import jax
import numpy as np

### work with lax.scan
# solver = optimistix.BFGS(rtol=1e-3, atol=1e-3)
# solver = optimistix.NonlinearCG(rtol=1e-3, atol=1e-3)
### do NOT work with lax.scan
solver = optimistix.GaussNewton(rtol=1e-3, atol=1e-3)
# solver = optimistix.LevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.IndirectLevenbergMarquardt(rtol=1e-3, atol=1e-3)
# solver = optimistix.NelderMead(rtol=1e-3, atol=1e-3)
# solver = optimistix.Dogleg(rtol=1e-3, atol=1e-3)

def test_func(x, *args):
    # Returns (output, aux); there is no auxiliary output.
    return jnp.sum(x**2) + 1.0, None

fn = test_func
y = jnp.array(np.random.uniform(-1, 1, size=(5, 5)))
args = None
options = dict(lower=-1.0, upper=1.0)
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
# Fix everything except y and state, which change from step to step.
step = Partial(solver.step, fn=fn, args=args, options=options, tags=tags)

def step_helper(carry, xs):
    y, state, _ = carry
    # solver.step returns (y, state, aux), which is exactly the next carry.
    return step(y=y, state=state), None

carry = (y, state, None)
carry, _ = jax.lax.scan(step_helper, carry, length=10)  # raises TypeError for GaussNewton
# for _ in range(10):
#     carry, _ = step_helper(carry, None)
y, state, aux = carry
```
The error which is raised is:
```
TypeError: Value { lambda a:f32[5,5]; b:f32[5,5]. let
    c:f32[5,5] = mul b a
    d:f32[] = reduce_sum[axes=(0, 1)] c
  in (d,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type
```
Am I doing something wrong or is there something else going on?
Some solvers, such as the GaussNewton family of least-squares solvers, have a jaxpr as part of their state (a jaxpr is the intermediate language in which traced JAX programs are expressed). In the least-squares solvers, the residual Jacobian contains a jaxpr. A jaxpr is not an array, so it is not a valid leaf of a `jax.lax.scan` carry, which is why scan raises the `TypeError` above. A plain Python for-loop never traces the carry, so it does not hit this restriction.
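The failure can be reproduced without optimistix. This minimal sketch puts a bare `Jaxpr`, taken from `jax.make_jaxpr(...).jaxpr`, into a carry. The function `f` and the toy carry are illustrative only. A plain Python loop accepts the carry, while `jax.lax.scan` rejects it with the same kind of `TypeError`.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x * x)

# make_jaxpr returns a ClosedJaxpr; .jaxpr is the bare Jaxpr object, the same
# kind of value that the error message above reports.
jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5))).jaxpr

def body(carry, _):
    x, program = carry
    return (x + 1.0, program), None

# A plain Python loop never inspects the carry, so a jaxpr inside it is harmless.
carry = (jnp.zeros(3), jaxpr)
for _ in range(3):
    carry, _ = body(carry, None)

# lax.scan must convert every carry leaf into an abstract array value,
# which is impossible for a Jaxpr.
try:
    jax.lax.scan(body, (jnp.zeros(3), jaxpr), length=3)
    scan_failed = False
except TypeError as err:
    scan_failed = True
    print(err)
```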
You can partition the state so that only the dynamic elements (the arrays) are part of the carry, while the static parts, which do not change from step to step, are closed over by your step_helper function.