Skip to content

Commit fd586b7

Browse files
committed
Add NLsolve
1 parent 6b58f42 commit fd586b7

File tree

7 files changed

+102
-25
lines changed

7 files changed

+102
-25
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ SparseArrays = "<0.0.1, 1"
8383
SparseDiffTools = "2.14"
8484
StableRNGs = "1"
8585
StaticArrays = "1"
86+
SteadyStateDiffEq = "2"
8687
Symbolics = "5"
8788
Test = "1"
8889
UnPack = "1.0"

docs/src/api/nlsolve.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ using NLSolve, NonlinearSolve
1313
## Solver API
1414

1515
```@docs
16-
NLSolveJL
16+
NLsolveJL
1717
```

docs/src/solvers/NonlinearSystemSolvers.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ computationally expensive than direct methods.
114114

115115
This is a wrapper package for importing solvers from NLsolve.jl into the SciML interface.
116116

117-
- `NLSolveJL()`: A wrapper for [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl)
117+
- `NLsolveJL()`: A wrapper for [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl)
118118

119119
Submethod choices for this algorithm include:
120120

ext/NonlinearSolveMINPACKExt.jl

+4-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
1919
# unwrapping alg params
2020
show_trace = alg.show_trace
2121
tracing = alg.tracing
22-
io = alg.io
2322

2423
if !iip && prob.u0 isa Number
2524
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
@@ -36,9 +35,10 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
3635
u = zero(u0)
3736
resid = NonlinearSolve.evaluate_f(prob, u)
3837
m = length(resid)
38+
size_jac = (length(resid), length(u))
3939

4040
method = ifelse(alg.method === :auto,
41-
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hydr), alg.method)
41+
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)
4242

4343
if SciMLBase.has_jac(prob.f)
4444
if !iip && prob.u0 isa Number
@@ -51,9 +51,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
5151
g! = (du, u) -> prob.f.jac(du, u, p)
5252
else # Then it's an in-place function on an abstract array
5353
g! = function (du, u)
54-
prob.f.jac(reshape(du, sizeu), reshape(u, sizeu), p)
55-
du = vec(du)
56-
return CInt(0)
54+
prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p)
55+
return Cint(0)
5756
end
5857
end
5958
original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method,

ext/NonlinearSolveNLsolveExt.jl

+77
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,80 @@
11
module NonlinearSolveNLsolveExt
22

3+
using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
4+
import UnPack: @unpack
5+
6+
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...; abstol = 1e-6,
7+
maxiters = 1000, alias_u0::Bool = false, kwargs...)
8+
if typeof(prob.u0) <: Number
9+
u0 = [prob.u0]
10+
else
11+
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
12+
end
13+
14+
iip = isinplace(prob)
15+
16+
sizeu = size(prob.u0)
17+
p = prob.p
18+
19+
# unwrapping alg params
20+
@unpack method, autodiff, store_trace, extended_trace, linesearch, linsolve = alg
21+
@unpack factor, autoscale, m, beta, show_trace = alg
22+
23+
if !iip && prob.u0 isa Number
24+
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
25+
elseif !iip && prob.u0 isa Vector{Float64}
26+
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
27+
elseif !iip && prob.u0 isa AbstractArray
28+
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
29+
elseif prob.u0 isa Vector{Float64}
30+
f! = (du, u) -> prob.f(du, u, p)
31+
else # Then it's an in-place function on an abstract array
32+
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
33+
end
34+
35+
if prob.u0 isa Number
36+
resid = [NonlinearSolve.evaluate_f(prob, first(u0))]
37+
else
38+
resid = NonlinearSolve.evaluate_f(prob, u0)
39+
end
40+
41+
size_jac = (length(resid), length(u0))
42+
43+
if SciMLBase.has_jac(prob.f)
44+
if !iip && prob.u0 isa Number
45+
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
46+
elseif !iip && prob.u0 isa Vector{Float64}
47+
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
48+
elseif !iip && prob.u0 isa AbstractArray
49+
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
50+
elseif prob.u0 isa Vector{Float64}
51+
g! = (du, u) -> prob.f.jac(du, u, p)
52+
else # Then it's an in-place function on an abstract array
53+
g! = function (du, u)
54+
prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p)
55+
return Cint(0)
56+
end
57+
end
58+
if prob.f.jac_prototype !== nothing
59+
J = zero(prob.f.jac_prototype)
60+
df = OnceDifferentiable(f!, g!, u0, resid, J)
61+
else
62+
df = OnceDifferentiable(f!, g!, u0, resid)
63+
end
64+
else
65+
df = OnceDifferentiable(f!, u0, resid; autodiff)
66+
end
67+
68+
original = nlsolve(df, u0; ftol = abstol, iterations = maxiters, method, store_trace,
69+
extended_trace, linesearch, linsolve, factor, autoscale, m, beta, show_trace)
70+
71+
u = reshape(original.zero, size(u0))
72+
f!(resid, u)
73+
retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
74+
ReturnCode.Failure
75+
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
76+
original.g_calls, original.iterations)
77+
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
78+
end
79+
380
end

src/extension_algs.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ function CMINPACK(; show_trace::Bool = false, tracing::Bool = false, method::Sym
143143
end
144144

145145
"""
146-
NLSolveJL(; method=:trust_region, autodiff=:central, store_trace=false,
147-
extended_trace=false, linesearch=LineSearches.Static(),
148-
linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true,
149-
m=10, beta=one(Float64), show_trace=false)
146+
NLsolveJL(; method=:trust_region, autodiff=:central, store_trace=false,
147+
extended_trace=false, linesearch=LineSearches.Static(),
148+
linsolve=(x, A, b) -> copyto!(x, A\\b), factor = one(Float64), autoscale=true,
149+
m=10, beta=one(Float64), show_trace=false)
150150
151151
### Keyword Arguments
152152
@@ -171,7 +171,7 @@ end
171171
172172
### Submethod Choice
173173
174-
Choices for methods in `NLSolveJL`:
174+
Choices for methods in `NLsolveJL`:
175175
176176
- `:anderson`: Anderson-accelerated fixed-point iteration
177177
- `:broyden`: Broyden's quasi-Newton method
@@ -180,7 +180,7 @@ Choices for methods in `NLSolveJL`:
180180
these arguments, consult the
181181
[NLsolve.jl documentation](https://github.com/JuliaNLSolvers/NLsolve.jl).
182182
"""
183-
@concrete struct NLSolveJL <: AbstractNonlinearAlgorithm
183+
@concrete struct NLsolveJL <: AbstractNonlinearAlgorithm
184184
method::Symbol
185185
autodiff::Symbol
186186
store_trace::Bool
@@ -194,14 +194,14 @@ Choices for methods in `NLSolveJL`:
194194
show_trace::Bool
195195
end
196196

197-
function NLSolveJL(; method = :trust_region, autodiff = :central, store_trace = false,
197+
function NLsolveJL(; method = :trust_region, autodiff = :central, store_trace = false,
198198
extended_trace = false, linesearch = LineSearches.Static(),
199199
linsolve = (x, A, b) -> copyto!(x, A \ b), factor = 1.0, autoscale = true, m = 10,
200200
beta = one(Float64), show_trace = false)
201201
if Base.get_extension(@__MODULE__, :NonlinearSolveNLsolveExt) === nothing
202-
error("NLSolveJL requires NLsolve.jl to be loaded")
202+
error("NLsolveJL requires NLsolve.jl to be loaded")
203203
end
204204

205-
return NLSolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
205+
return NLsolveJL(method, autodiff, store_trace, extended_trace, linesearch, linsolve,
206206
factor, autoscale, m, beta, show_trace)
207207
end

test/nlsolve.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ u0 = zeros(2)
99
prob_iip = SteadyStateProblem(f_iip, u0)
1010
abstol = 1e-8
1111

12-
for alg in [NLSolveJL()]
12+
for alg in [NLsolveJL()]
1313
sol = solve(prob_iip, alg)
1414
@test sol.retcode == ReturnCode.Success
1515
p = nothing
@@ -24,7 +24,7 @@ f_oop(u, p, t) = [2 - 2u[1], u[1] - 4u[2]]
2424
u0 = zeros(2)
2525
prob_oop = SteadyStateProblem(f_oop, u0)
2626

27-
for alg in [NLSolveJL()]
27+
for alg in [NLsolveJL()]
2828
sol = solve(prob_oop, alg)
2929
@test sol.retcode == ReturnCode.Success
3030
# test the solver is doing reasonable things for linear solve
@@ -45,7 +45,7 @@ end
4545
u0 = zeros(2)
4646
prob_iip = NonlinearProblem{true}(f_iip, u0)
4747
abstol = 1e-8
48-
for alg in [NLSolveJL()]
48+
for alg in [NLsolveJL()]
4949
local sol
5050
sol = solve(prob_iip, alg)
5151
@test sol.retcode == ReturnCode.Success
@@ -60,7 +60,7 @@ end
6060
f_oop(u, p) = [2 - 2u[1], u[1] - 4u[2]]
6161
u0 = zeros(2)
6262
prob_oop = NonlinearProblem{false}(f_oop, u0)
63-
for alg in [NLSolveJL()]
63+
for alg in [NLsolveJL()]
6464
local sol
6565
sol = solve(prob_oop, alg)
6666
@test sol.retcode == ReturnCode.Success
@@ -74,7 +74,7 @@ end
7474
f_tol(u, p) = u^2 - 2
7575
prob_tol = NonlinearProblem(f_tol, 1.0)
7676
for tol in [1e-1, 1e-3, 1e-6, 1e-10, 1e-15]
77-
sol = solve(prob_tol, NLSolveJL(), abstol = tol)
77+
sol = solve(prob_tol, NLsolveJL(), abstol = tol)
7878
@test abs(sol.u[1] - sqrt(2)) < tol
7979
end
8080

@@ -85,7 +85,7 @@ function f!(fvec, x, p)
8585
end
8686

8787
prob = NonlinearProblem{true}(f!, [0.1; 1.2])
88-
sol = solve(prob, NLSolveJL(autodiff = :central))
88+
sol = solve(prob, NLsolveJL(autodiff = :central))
8989

9090
du = zeros(2)
9191
f!(du, sol.u, nothing)
@@ -98,7 +98,7 @@ function f!(fvec, x, p)
9898
end
9999

100100
prob = NonlinearProblem{true}(f!, [0.1; 1.2])
101-
sol = solve(prob, NLSolveJL(autodiff = :forward))
101+
sol = solve(prob, NLsolveJL(autodiff = :forward))
102102

103103
du = zeros(2)
104104
f!(du, sol.u, nothing)
@@ -131,8 +131,8 @@ f = NonlinearFunction(f!, jac = j!)
131131
p = A
132132

133133
ProbN = NonlinearProblem(f, init, p)
134-
sol = solve(ProbN, NLSolveJL(), reltol = 1e-8, abstol = 1e-8)
134+
sol = solve(ProbN, NLsolveJL(), reltol = 1e-8, abstol = 1e-8)
135135

136136
init = ones(Complex{Float64}, 152);
137137
ProbN = NonlinearProblem(f, init, p)
138-
sol = solve(ProbN, NLSolveJL(), reltol = 1e-8, abstol = 1e-8)
138+
sol = solve(ProbN, NLsolveJL(), reltol = 1e-8, abstol = 1e-8)

0 commit comments

Comments
 (0)