Skip to content

Commit 3617d6f

Browse files
committed
Add a wrapper for Optimization.jl
1 parent b7ba71d commit 3617d6f

10 files changed

+203
-17
lines changed

Project.toml

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.8.3"
4+
version = "3.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -35,8 +35,9 @@ FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3535
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3636
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3737
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
38-
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
3938
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
39+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
40+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
4041
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4142
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4243
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -48,8 +49,9 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
4849
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
4950
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
5051
NonlinearSolveMINPACKExt = "MINPACK"
51-
NonlinearSolveNLsolveExt = "NLsolve"
5252
NonlinearSolveNLSolversExt = "NLSolvers"
53+
NonlinearSolveNLsolveExt = "NLsolve"
54+
NonlinearSolveOptimizationExt = "Optimization"
5355
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5456
NonlinearSolveSpeedMappingExt = "SpeedMapping"
5557
NonlinearSolveSymbolicsExt = "Symbolics"
@@ -61,8 +63,8 @@ Aqua = "0.8"
6163
ArrayInterface = "7.7"
6264
BandedMatrices = "1.4"
6365
BenchmarkTools = "1.4"
64-
ConcreteStructs = "0.2.3"
6566
CUDA = "5.1"
67+
ConcreteStructs = "0.2.3"
6668
DiffEqBase = "6.146.0"
6769
Enzyme = "0.11.15"
6870
FastBroadcast = "0.2.8"
@@ -71,17 +73,20 @@ FastLevenbergMarquardt = "0.1"
7173
FiniteDiff = "2.21"
7274
FixedPointAcceleration = "0.3"
7375
ForwardDiff = "0.10.36"
76+
Ipopt = "1.6"
7477
LazyArrays = "1.8.2"
7578
LeastSquaresOptim = "0.8.5"
7679
LineSearches = "7.2"
7780
LinearAlgebra = "1.10"
7881
LinearSolve = "2.21"
7982
MINPACK = "1.2"
8083
MaybeInplace = "0.1.1"
81-
NLsolve = "4.5"
8284
NLSolvers = "0.5"
85+
NLsolve = "4.5"
8386
NaNMath = "1"
8487
NonlinearProblemLibrary = "0.1.2"
88+
Optimization = "3.24"
89+
OptimizationMOI = "0.4"
8590
OrdinaryDiffEq = "6.63"
8691
Pkg = "1.10"
8792
PrecompileTools = "1.2"
@@ -93,7 +98,7 @@ RecursiveArrayTools = "3.4"
9398
Reexport = "1.2"
9499
SIAMFANLEquations = "1.0.1"
95100
SafeTestsets = "0.1"
96-
SciMLBase = "2.19.0"
101+
SciMLBase = "2.23"
97102
SimpleNonlinearSolve = "1.2"
98103
SparseArrays = "1.10"
99104
SparseDiffTools = "2.14"
@@ -118,14 +123,17 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
118123
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
119124
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
120125
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
126+
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
121127
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
122128
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
123129
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
124130
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
125-
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
126131
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
132+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
127133
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
128134
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
135+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
136+
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
129137
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
130138
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
131139
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -143,4 +151,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
143151
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
144152

145153
[targets]
146-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials", "ReTestItems", "Reexport", "CUDA", "NLSolvers"]
154+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "DiffEqBase", "Enzyme", "FastLevenbergMarquardt", "FixedPointAcceleration", "ForwardDiff", "Ipopt", "LeastSquaresOptim", "LinearAlgebra", "LinearSolve", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "Optimization", "OptimizationMOI", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "Reexport", "SIAMFANLEquations", "SafeTestsets", "SparseDiffTools", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Symbolics", "Test", "Zygote"]

docs/pages.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ pages = ["index.md",
1717
"native/globalization.md", "native/diagnostics.md"],
1818
"Wrapped Solver APIs" => Any[
1919
"api/fastlevenbergmarquardt.md", "api/fixedpointacceleration.md",
20-
"api/leastsquaresoptim.md", "api/minpack.md", "api/nlsolve.md", "api/nlsolvers.md",
20+
"api/leastsquaresoptim.md", "api/minpack.md",
21+
"api/nlsolve.md", "api/nlsolvers.md", "api/optimizationjl.md",
2122
"api/siamfanlequations.md", "api/speedmapping.md", "api/sundials.md"],
2223
"Development Documentation" => [
2324
"devdocs/internal_interfaces.md", "devdocs/linear_solve.md",

docs/src/api/optimizationjl.md

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Optimization.jl
2+
3+
This is a extension for importing solvers from Optimization.jl into the SciML Nonlinear
4+
Problem interface. Note that these solvers do not come by default, and thus one needs to
5+
install the package before using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("Optimization")
10+
using Optimization, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
OptimizationJL
17+
```

docs/src/solvers/nonlinear_least_squares_solvers.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ fails it falls back to a more robust algorithms ([`LevenbergMarquardt`](@ref),
2929
### SimpleNonlinearSolve.jl
3030

3131
These methods are included with NonlinearSolve.jl by default, though SimpleNonlinearSolve.jl
32-
can be used directly to reduce dependencies and improve load times.
32+
can be used directly to reduce dependencies and improve load times.
3333
SimpleNonlinearSolve.jl's methods excel at small problems and problems defined with static
3434
arrays.
3535

@@ -81,5 +81,8 @@ Submethod choices for this algorithm include:
8181

8282
### Optimization.jl
8383

84-
`NonlinearLeastSquaresProblem`s can be converted into an `OptimizationProblem` and used
84+
`NonlinearLeastSquaresProblem`s can be converted into an `OptimizationProblem` and used
8585
with any solver of [Optimization.jl](https://github.com/SciML/Optimization.jl).
86+
87+
Alternatively, [`OptimizationJL`](@ref) can be used directly. The only benefit of this is
88+
that the solver returns [`NonlinearSolution`](@ref) instead of `OptimizationSolution`.

docs/src/solvers/nonlinear_system_solvers.md

+11-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
solve(prob::NonlinearProblem, alg; kwargs...)
55
```
66

7-
Solves for ``f(u) = 0`` in the problem defined by `prob` using the algorithm `alg`. If no
7+
Solves for `f(u) = 0` in the problem defined by `prob` using the algorithm `alg`. If no
88
algorithm is given, a default algorithm will be chosen.
99

1010
## Recommended Methods
@@ -22,7 +22,7 @@ fail to converge. Additionally, [`DynamicSS`](@ref) can be a good choice for hig
2222
if the root corresponds to a stable equilibrium.
2323

2424
As a balance, [`NewtonRaphson`](@ref) is a good choice for most problems that aren't too
25-
difficult yet need high performance, and [`TrustRegion`](@ref) is a bit less performant but
25+
difficult yet need high performance, and [`TrustRegion`](@ref) is a bit less performant but
2626
more stable. If the problem is well-conditioned, [`Klement`](@ref) or [`Broyden`](@ref) may
2727
be faster, but highly dependent on the eigenvalues of the Jacobian being sufficiently small.
2828

@@ -177,3 +177,12 @@ This is a wrapper package for importing solvers from NLSolvers.jl into the SciML
177177
[NLSolvers.jl](https://github.com/JuliaNLSolvers/NLSolvers.jl)
178178

179179
For a list of possible solvers see the [NLSolvers.jl documentation](https://julianlsolvers.github.io/NLSolvers.jl/)
180+
181+
### Optimization.jl
182+
183+
This is a wrapper package for importing solvers from Optimization.jl into the SciML
184+
Nonlinear Problem interface. These exist mostly for benchmarking purposes and shouldn't
185+
be used by most users.
186+
187+
- [`OptimizationJL()`](@ref): A wrapper for
188+
[Optimization.jl](https://github.com/SciML/Optimization.jl)

ext/NonlinearSolveOptimizationExt.jl

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
module NonlinearSolveOptimizationExt
2+
3+
using FastClosures, LinearAlgebra, NonlinearSolve, Optimization
4+
5+
function SciMLBase.__solve(
6+
prob::NonlinearProblem, alg::OptimizationJL, args...; abstol = nothing,
7+
maxiters = 1000, termination_condition = nothing, kwargs...)
8+
NonlinearSolve.__test_termination_condition(termination_condition, :OptimizationJL)
9+
10+
prob.u0 isa Number &&
11+
throw(ArgumentError("`OptimizationJL` doesn't support scalar `u0`"))
12+
13+
_objective_function = if SciMLBase.isinplace(prob)
14+
@closure (u, p) -> begin
15+
du = similar(u)
16+
prob.f(du, u, p)
17+
return norm(du, 2)
18+
end
19+
else
20+
@closure (u, p) -> norm(prob.f(u, p), 2)
21+
end
22+
23+
cons = if SciMLBase.isinplace(prob)
24+
prob.f
25+
else
26+
@closure (du, u, p) -> copyto!(du, prob.f(u, p))
27+
end
28+
29+
if alg.autodiff === nothing || alg.autodiff isa SciMLBase.NoAD
30+
opt_func = OptimizationFunction(_objective_function; cons)
31+
else
32+
opt_func = OptimizationFunction(_objective_function, alg.autodiff; cons)
33+
end
34+
bounds = similar(prob.u0)
35+
fill!(bounds, 0)
36+
opt_prob = OptimizationProblem(
37+
opt_func, prob.u0, prob.p; lcons = bounds, ucons = bounds)
38+
sol = solve(opt_prob, alg.solver, args...; abstol, maxiters, kwargs...)
39+
40+
fu = zero(prob.u0)
41+
cons(fu, sol.u, prob.p)
42+
43+
stats = SciMLBase.NLStats(sol.stats.fevals, sol.stats.gevals, -1, -1, -1)
44+
45+
return SciMLBase.build_solution(
46+
prob, alg, sol.u, fu; retcode = sol.retcode, original = sol, stats)
47+
end
48+
49+
function SciMLBase.__solve(prob::NonlinearLeastSquaresProblem, alg::OptimizationJL, args...;
50+
abstol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...)
51+
NonlinearSolve.__test_termination_condition(termination_condition, :OptimizationJL)
52+
53+
_objective_function = if SciMLBase.isinplace(prob)
54+
@closure (θ, p) -> begin
55+
resid = prob.f.resid_prototype === nothing ? similar(θ) :
56+
similar(prob.f.resid_prototype, eltype(θ))
57+
prob.f(resid, θ, p)
58+
return norm(resid, 2)
59+
end
60+
else
61+
@closure (θ, p) -> norm(prob.f(θ, p), 2)
62+
end
63+
64+
if alg.autodiff === nothing || alg.autodiff isa SciMLBase.NoAD
65+
opt_func = OptimizationFunction(_objective_function)
66+
else
67+
opt_func = OptimizationFunction(_objective_function, alg.autodiff)
68+
end
69+
opt_prob = OptimizationProblem(opt_func, prob.u0, prob.p)
70+
sol = solve(opt_prob, alg.solver, args...; abstol, maxiters, kwargs...)
71+
72+
if SciMLBase.isinplace(prob)
73+
resid = prob.f.resid_prototype === nothing ? similar(prob.u0) :
74+
prob.f.resid_prototype
75+
prob.f(resid, sol.u, prob.p)
76+
else
77+
resid = prob.f(sol.u, prob.p)
78+
end
79+
80+
stats = SciMLBase.NLStats(sol.stats.fevals, sol.stats.gevals, -1, -1, -1)
81+
82+
return SciMLBase.build_solution(
83+
prob, alg, sol.u, resid; retcode = sol.retcode, original = sol, stats)
84+
end
85+
86+
end

src/NonlinearSolve.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPoly
148148

149149
# Extension Algorithms
150150
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
151-
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
151+
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL, OptimizationJL
152152

153153
# Advanced Algorithms -- Without Bells and Whistles
154154
export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, GeneralizedDFSane

src/algorithms/extension_algs.jl

+37
Original file line numberDiff line numberDiff line change
@@ -484,3 +484,40 @@ function SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothin
484484
end
485485
return SIAMFANLEquationsJL(method, delta, linsolve, m, beta, autodiff)
486486
end
487+
488+
"""
489+
OptimizationJL(solver, autodiff)
490+
491+
Wrapper over [Optimization.jl](https://docs.sciml.ai/Optimization/stable/) to solve
492+
Nonlinear Equations and Nonlinear Least Squares Problems.
493+
494+
!!! danger "Using OptimizationJL for Nonlinear Systems"
495+
496+
This is a absolutely terrible idea. We construct the objective function as the L2-norm
497+
of the residual function and impose an equality constraint. This is very inefficient
498+
and exists to convince people from HackerNews that this is a horrible idea.
499+
500+
### Arguments
501+
502+
- `solver`: The solver to use from Optimization.jl. In general for NLLS, all of the
503+
solvers will work. However, for nonlinear systems, only the solvers that support
504+
equality constraints will work.
505+
- `autodiff`: Automatic Differentiation Backend that Optimization.jl should use. See
506+
https://docs.sciml.ai/Optimization/stable/API/ad/ for more details. Defaults to
507+
`SciMLBase.NoAD()`.
508+
509+
!!! note
510+
511+
This algorithm is only available if `Optimization.jl` is installed.
512+
"""
513+
struct OptimizationJL{S, AD} <: AbstractNonlinearSolveExtensionAlgorithm
514+
solver::S
515+
autodiff::AD
516+
517+
function OptimizationJL(solver, autodiff = SciMLBase.NoAD())
518+
if Base.get_extension(@__MODULE__, :NonlinearSolveOptimizationExt) === nothing
519+
error("OptimizationJL requires Optimization.jl to be loaded")
520+
end
521+
return new{typeof(solver), typeof(autodiff)}(solver, autodiff)
522+
end
523+
end

test/wrappers/nlls_tests.jl

+18
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ end
4646
end
4747
end
4848

49+
@testitem "Optimization.jl" setup=[WrapperNLLSSetup] begin
50+
import Optimization, OptimizationMOI, Ipopt
51+
52+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
53+
prob_iip = NonlinearLeastSquaresProblem(
54+
NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x)
55+
56+
nlls_problems = [prob_oop, prob_iip]
57+
58+
solver = OptimizationJL(Ipopt.Optimizer(), AutoForwardDiff())
59+
60+
for prob in nlls_problems
61+
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
62+
# Ipopt fails currently
63+
@test sol isa SciMLBase.NonlinearSolution
64+
end
65+
end
66+
4967
@testitem "FastLevenbergMarquardt.jl + CMINPACK: Jacobian Provided" setup=[WrapperNLLSSetup] begin
5068
function jac!(J, θ, p)
5169
resid = zeros(length(p))

test/wrappers/rootfind_tests.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
using Reexport
33
@reexport using LinearAlgebra
44
import NLSolvers, NLsolve, SIAMFANLEquations, MINPACK
5+
import Optimization, OptimizationMOI, Ipopt
56

6-
export NLSolvers
7+
export NLSolvers, Ipopt
78
end
89

910
@testitem "Steady State Problems" setup=[WrapperRootfindImports] begin
@@ -48,7 +49,8 @@ end
4849

4950
for alg in [
5051
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
51-
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
52+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(),
53+
OptimizationJL(Ipopt.Optimizer(), AutoForwardDiff())]
5254
local sol
5355
sol = solve(prob_iip, alg)
5456
@test SciMLBase.successful_retcode(sol.retcode)
@@ -61,7 +63,8 @@ end
6163
prob_oop = NonlinearProblem{false}(f_oop, u0)
6264
for alg in [
6365
NLSolversJL(NLSolvers.LineSearch(NLSolvers.Newton(), NLSolvers.Backtracking())),
64-
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL()]
66+
NLsolveJL(), CMINPACK(), SIAMFANLEquationsJL(),
67+
OptimizationJL(Ipopt.Optimizer(), AutoForwardDiff())]
6568
local sol
6669
sol = solve(prob_oop, alg)
6770
@test SciMLBase.successful_retcode(sol.retcode)
@@ -128,4 +131,8 @@ end
128131
@test maximum(abs, sol.resid) < 1e-6
129132
sol = solve(ProbN, SIAMFANLEquationsJL(; method = :pseudotransient); abstol = 1e-8)
130133
@test maximum(abs, sol.resid) < 1e-6
134+
sol = solve(ProbN, OptimizationJL(Ipopt.Optimizer(), AutoForwardDiff()); abstol = 1e-8)
135+
@test maximum(abs, sol.resid) < 1e-6
136+
sol = solve(ProbN, OptimizationJL(Ipopt.Optimizer(), AutoFiniteDiff()); abstol = 1e-8)
137+
@test maximum(abs, sol.resid) < 1e-6
131138
end

0 commit comments

Comments
 (0)