Skip to content

Commit a9d884f

Browse files
committed
Add a wrapper over LeastSquaresOptim
1 parent a6af39c commit a9d884f

8 files changed

+123
-9
lines changed

Project.toml

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.2.1"
4+
version = "2.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -24,6 +24,12 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2424
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2626

27+
[weakdeps]
28+
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
29+
30+
[extensions]
31+
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
32+
2733
[compat]
2834
ADTypes = "0.2"
2935
ArrayInterface = "6.0.24, 7"
@@ -33,6 +39,7 @@ EnumX = "1"
3339
Enzyme = "0.11"
3440
FiniteDiff = "2"
3541
ForwardDiff = "0.10.3"
42+
LeastSquaresOptim = "0.8"
3643
LineSearches = "7"
3744
LinearSolve = "2"
3845
NonlinearProblemLibrary = "0.1"
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
module NonlinearSolveLeastSquaresOptimExt
2+
3+
using NonlinearSolve, SciMLBase
4+
import ConcreteStructs: @concrete
5+
import LeastSquaresOptim as LSO
6+
7+
extension_loaded(::Val{:LeastSquaresOptim}) = true
8+
9+
function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
10+
ls = linsolve == :qr ? LSO.QR() :
11+
(linsolve == :cholesky ? LSO.Cholesky() :
12+
(linsolve == :lsmr ? LSO.LSMR() : nothing))
13+
if alg == :lm
14+
return LSO.LevenbergMarquardt(ls)
15+
elseif alg == :dogleg
16+
return LSO.Dogleg(ls)
17+
else
18+
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
19+
end
20+
end
21+
22+
@concrete struct LeastSquaresOptimCache
23+
prob
24+
alg
25+
allocated_prob
26+
kwargs
27+
end
28+
29+
@concrete struct FunctionWrapper{iip}
30+
f
31+
p
32+
end
33+
34+
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
35+
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))
36+
37+
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
38+
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
39+
iip = SciMLBase.isinplace(prob)
40+
41+
f! = FunctionWrapper{iip}(prob.f, prob.p)
42+
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)
43+
44+
lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = prob.f.resid_prototype, g!,
45+
J = prob.f.jac_prototype, alg.autodiff,
46+
output_length = length(prob.f.resid_prototype))
47+
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))
48+
49+
return LeastSquaresOptimCache(prob, alg, allocated_prob,
50+
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
51+
kwargs...))
52+
end
53+
54+
function SciMLBase.solve!(cache::LeastSquaresOptimCache)
55+
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
56+
maxiters = cache.kwargs[:iterations]
57+
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
58+
(res.iterations ≥ maxiters ? ReturnCode.MaxIters :
59+
ReturnCode.ConvergenceFailure)
60+
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)
61+
return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2;
62+
retcode, original = res, stats)
63+
end
64+
65+
end

src/NonlinearSolve.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm
3030

3131
abstract type AbstractNonlinearSolveCache{iip} end
3232

33+
extension_loaded(::Val) = false
34+
3335
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
3436

3537
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
@@ -60,6 +62,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
6062
end
6163

6264
include("utils.jl")
65+
include("algorithms.jl")
6366
include("linesearch.jl")
6467
include("raphson.jl")
6568
include("trustRegion.jl")
@@ -92,7 +95,7 @@ end
9295

9396
export RadiusUpdateSchemes
9497

95-
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
98+
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton, LSOptimSolver
9699

97100
export LineSearch
98101

src/algorithms.jl

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Define Algorithms extended via extensions
2+
"""
3+
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
4+
5+
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving
6+
`NonlinearLeastSquaresProblem`.
7+
8+
## Arguments:
9+
10+
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
11+
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
12+
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
13+
on the Jacobian structure.
14+
15+
!!! note
16+
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
17+
"""
18+
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
19+
autodiff::Symbol
20+
21+
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
22+
@assert alg in (:lm, :dogleg)
23+
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
24+
@assert autodiff in (:central, :forward)
25+
26+
return new{alg, linsolve}(autodiff)
27+
end
28+
end

src/gaussnewton.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ end
9393
function perform_step!(cache::GaussNewtonCache{true})
9494
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
9595
jacobian!!(J, cache)
96-
mul!(JᵀJ, J', J)
97-
mul!(Jᵀf, J', fu1)
96+
__matmul!(JᵀJ, J', J)
97+
__matmul!(Jᵀf, J', fu1)
9898

9999
# u = u - J \ fu
100100
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),

src/jacobian.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
6565
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
6666
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
6767
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
68-
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
68+
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
6969
sd = sparsity_detection_alg(f, alg.ad)
7070
ad = alg.ad
7171
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
@@ -74,7 +74,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
7474
jac_cache = nothing
7575
end
7676

77-
J = if !(linsolve_needs_jac || alg_wants_jac)
77+
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
78+
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
79+
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
7880
# We don't need to construct the Jacobian
7981
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
8082
else
@@ -114,7 +116,7 @@ __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
114116
__get_nonsparse_ad(ad) = ad
115117

116118
__init_JᵀJ(J::Number) = zero(J)
117-
__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2))
119+
__init_JᵀJ(J::AbstractArray) = J' * J
118120
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
119121

120122
## Special Handling for Scalars

src/levenberg.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
192192

193193
if make_new_J
194194
jacobian!!(cache.J, cache)
195-
mul!(cache.JᵀJ, cache.J', cache.J)
195+
__matmul!(cache.JᵀJ, cache.J', cache.J)
196196
cache.DᵀD .= max.(cache.DᵀD, Diagonal(cache.JᵀJ))
197197
cache.make_new_J = false
198198
cache.stats.njacs += 1
@@ -216,7 +216,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
216216
mul!(cache.Jv, J, v)
217217
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
218218
mul!(cache.u_tmp, J', cache.fu_tmp)
219-
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
219+
# NOTE: Don't pass `A` in again, since we want to reuse the previous solve
220+
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_tmp),
220221
linu = _vec(cache.du), p = p, reltol = cache.abstol)
221222
cache.linsolve = linres.cache
222223
@. cache.a = -cache.du

src/utils.jl

+8
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,11 @@ function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
163163
return f(u, p)
164164
end
165165
end
166+
167+
"""
168+
__matmul!(C, A, B)
169+
170+
Defaults to `mul!(C, A, B)`. However, for sparse matrices uses `C .= A * B`.
171+
"""
172+
__matmul!(C, A, B) = mul!(C, A, B)
173+
__matmul!(C::AbstractSparseMatrix, A, B) = C .= A * B

0 commit comments

Comments
 (0)