Skip to content

Commit 22fd30a

Browse files
committed
Automatically construct the jacobian for FastLM
1 parent 513eef6 commit 22fd30a

5 files changed

+107
-33
lines changed

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

+56-13
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,12 @@ module NonlinearSolveFastLevenbergMarquardtExt
33
using ArrayInterface, NonlinearSolve, SciMLBase
44
import ConcreteStructs: @concrete
55
import FastLevenbergMarquardt as FastLM
6+
import FiniteDiff, ForwardDiff
67

78
function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
8-
if linsolve == :cholesky
9+
if linsolve === :cholesky
910
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
10-
elseif linsolve == :qr
11+
elseif linsolve === :qr
1112
return FastLM.QRSolver(eltype(x), length(x))
1213
else
1314
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
@@ -33,23 +34,65 @@ end
3334

3435
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
3536
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = 1e-8,
36-
reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
37+
reltol = 1e-8, maxiters = 1000, kwargs...)
3738
iip = SciMLBase.isinplace(prob)
38-
u0 = alias_u0 ? prob.u0 : deepcopy(prob.u0)
39-
40-
@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"
39+
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
40+
fu = NonlinearSolve.evaluate_f(prob, u)
4141

4242
f! = InplaceFunction{iip}(prob.f)
43-
J! = InplaceFunction{iip}(prob.f.jac)
4443

45-
resid_prototype = prob.f.resid_prototype === nothing ?
46-
(!iip ? prob.f(u0, prob.p) : zeros(u0)) :
47-
prob.f.resid_prototype
44+
if prob.f.jac === nothing
45+
use_forward_diff = if alg.autodiff === nothing
46+
ForwardDiff.can_dual(eltype(u))
47+
else
48+
alg.autodiff isa AutoForwardDiff
49+
end
50+
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
51+
if use_forward_diff
52+
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :
53+
ForwardDiff.JacobianConfig(uf, u)
54+
else
55+
cache = FiniteDiff.JacobianCache(u, fu)
56+
end
57+
J! = if iip
58+
if use_forward_diff
59+
fu_cache = similar(fu)
60+
function (J, x, p)
61+
uf.p = p
62+
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
63+
return J
64+
end
65+
else
66+
function (J, x, p)
67+
uf.p = p
68+
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
69+
return J
70+
end
71+
end
72+
else
73+
if use_forward_diff
74+
function (J, x, p)
75+
uf.p = p
76+
ForwardDiff.jacobian!(J, uf, x, cache)
77+
return J
78+
end
79+
else
80+
function (J, x, p)
81+
uf.p = p
82+
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
83+
copyto!(J, J_)
84+
return J
85+
end
86+
end
87+
end
88+
else
89+
J! = InplaceFunction{iip}(prob.f.jac)
90+
end
4891

49-
J = similar(u0, length(resid_prototype), length(u0))
92+
J = similar(u, length(fu), length(u))
5093

51-
solver = _fast_lm_solver(alg, u0)
52-
LM = FastLM.LMWorkspace(u0, resid_prototype, J)
94+
solver = _fast_lm_solver(alg, u)
95+
LM = FastLM.LMWorkspace(u, fu, J)
5396

5497
return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
5598
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,

ext/NonlinearSolveLeastSquaresOptimExt.jl

+5-5
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@ import ConcreteStructs: @concrete
55
import LeastSquaresOptim as LSO
66

77
function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
8-
ls = linsolve == :qr ? LSO.QR() :
9-
(linsolve == :cholesky ? LSO.Cholesky() :
10-
(linsolve == :lsmr ? LSO.LSMR() : nothing))
11-
if alg == :lm
8+
ls = linsolve === :qr ? LSO.QR() :
9+
(linsolve === :cholesky ? LSO.Cholesky() :
10+
(linsolve === :lsmr ? LSO.LSMR() : nothing))
11+
if alg === :lm
1212
return LSO.LevenbergMarquardt(ls)
13-
elseif alg == :dogleg
13+
elseif alg === :dogleg
1414
return LSO.Dogleg(ls)
1515
else
1616
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))

src/default.jl

+8-4
Original file line number | Diff line number | Diff line change
@@ -244,8 +244,10 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi
244244
autodiff = nothing) where {JAC, SA}
245245
if JAC
246246
if SA
247-
algs = (SimpleNewtonRaphson(; autodiff),
248-
SimpleTrustRegion(; autodiff),
247+
algs = (SimpleNewtonRaphson(;
248+
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
249+
SimpleTrustRegion(;
250+
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
249251
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
250252
autodiff),
251253
TrustRegion(; concrete_jac, linsolve, precs,
@@ -263,8 +265,10 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi
263265
algs = (SimpleBroyden(),
264266
Broyden(; init_jacobian = Val(:true_jacobian)),
265267
SimpleKlement(),
266-
SimpleNewtonRaphson(; autodiff),
267-
SimpleTrustRegion(; autodiff),
268+
SimpleNewtonRaphson(;
269+
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
270+
SimpleTrustRegion(;
271+
autodiff = ifelse(autodiff === nothing, AutoForwardDiff(), autodiff)),
268272
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
269273
autodiff),
270274
TrustRegion(; concrete_jac, linsolve, precs,

src/extension_algs.jl

+11-7
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ function LeastSquaresOptimJL(alg = :lm; linsolve = nothing, autodiff::Symbol = :
3636
end
3737

3838
"""
39-
FastLevenbergMarquardtJL(linsolve = :cholesky)
39+
FastLevenbergMarquardtJL(linsolve = :cholesky; autodiff = nothing)
4040
4141
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
4242
for solving `NonlinearLeastSquaresProblem`.
@@ -46,19 +46,20 @@ for solving `NonlinearLeastSquaresProblem`.
4646
This is not really the fastest solver. It is called that since the original package
4747
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
4848
49-
!!! warning
50-
51-
This algorithm requires the jacobian function to be provided!
52-
5349
## Arguments:
5450
5551
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
52+
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
53+
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
54+
`nothing` which means that a default is selected according to the problem specification!
55+
Valid choices are `nothing`, `AutoForwardDiff` or `AutoFiniteDiff`.
5656
5757
!!! note
5858
5959
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
6060
"""
6161
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
62+
autodiff
6263
factor
6364
factoraccept
6465
factorreject
@@ -71,14 +72,17 @@ end
7172

7273
function FastLevenbergMarquardtJL(linsolve::Symbol = :cholesky; factor = 1e-6,
7374
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
74-
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
75+
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32,
76+
autodiff = nothing)
7577
@assert linsolve in (:qr, :cholesky)
7678
@assert factorupdate in (:marquardt, :nielson)
79+
@assert autodiff === nothing || autodiff isa AutoFiniteDiff ||
80+
autodiff isa AutoForwardDiff
7781

7882
if Base.get_extension(@__MODULE__, :NonlinearSolveFastLevenbergMarquardtExt) === nothing
7983
error("LeastSquaresOptimJL requires FastLevenbergMarquardt.jl to be loaded")
8084
end
8185

82-
return FastLevenbergMarquardtJL{linsolve}(factor, factoraccept, factorreject,
86+
return FastLevenbergMarquardtJL{linsolve}(autodiff, factor, factoraccept, factorreject,
8387
factorupdate, minscale, maxscale, minfactor, maxfactor)
8488
end

test/nonlinear_least_squares.jl

+27-4
Original file line number | Diff line number | Diff line change
@@ -89,12 +89,35 @@ function jac!(J, θ, p)
8989
return J
9090
end
9191

92-
prob = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
93-
resid_prototype = zero(y_target), jac = jac!), θ_init, x)
92+
jac(θ, p) = ForwardDiff.jacobian(θ -> loss_function(θ, p), θ)
9493

95-
solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)]
94+
probs = [
95+
NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function;
96+
resid_prototype = zero(y_target), jac = jac!), θ_init, x),
97+
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function;
98+
resid_prototype = zero(y_target), jac = jac), θ_init, x),
99+
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function; jac), θ_init, x),
100+
]
101+
102+
solvers = [FastLevenbergMarquardtJL(linsolve) for linsolve in (:cholesky, :qr)]
103+
104+
for solver in solvers, prob in probs
105+
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
106+
@test norm(sol.resid) < 1e-6
107+
end
108+
109+
probs = [
110+
NonlinearLeastSquaresProblem(NonlinearFunction{true}(loss_function;
111+
resid_prototype = zero(y_target)), θ_init, x),
112+
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function;
113+
resid_prototype = zero(y_target)), θ_init, x),
114+
NonlinearLeastSquaresProblem(NonlinearFunction{false}(loss_function), θ_init, x),
115+
]
116+
117+
solvers = [FastLevenbergMarquardtJL(linsolve; autodiff) for linsolve in (:cholesky, :qr),
118+
autodiff in (nothing, AutoForwardDiff(), AutoFiniteDiff())]
96119

97-
for solver in solvers
120+
for solver in solvers, prob in probs
98121
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
99122
@test norm(sol.resid) < 1e-6
100123
end

0 commit comments

Comments
 (0)