Skip to content

Commit b2946b1

Browse files
Merge pull request #203 from avik-pal/ap/cleanup
Towards a cleaner and more maintainable internals of NonlinearSolve.jl
2 parents 81e9164 + 4cd2d97 commit b2946b1

18 files changed

+1081
-2124
lines changed

.JuliaFormatter.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
style = "sciml"
2-
format_markdown = true
2+
format_markdown = true
3+
annotate_untyped_fields_with_any = false

.github/workflows/CI.yml

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ jobs:
1515
- Core
1616
version:
1717
- '1'
18-
- '1.6'
1918
steps:
2019
- uses: actions/checkout@v4
2120
- uses: julia-actions/setup-julia@v1

.github/workflows/Downstream.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
julia-version: [1,1.6]
17+
julia-version: [1]
1818
os: [ubuntu-latest]
1919
package:
2020
- {user: SciML, repo: ModelingToolkit.jl, group: All}

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ Manifest.toml
2525
docs/src/assets/Project.toml
2626

2727
.vscode
28+
wip

Project.toml

+17-6
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "1.10.0"
4+
version = "2.0.0"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
810
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
911
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1012
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1113
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
14+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1215
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1316
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1417
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -22,33 +25,41 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2225
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2326

2427
[compat]
28+
ADTypes = "0.2"
2529
ArrayInterface = "6.0.24, 7"
26-
DiffEqBase = "6"
30+
ConcreteStructs = "0.2"
31+
DiffEqBase = "6.130"
2732
EnumX = "1"
33+
Enzyme = "0.11"
2834
FiniteDiff = "2"
2935
ForwardDiff = "0.10.3"
3036
LinearSolve = "2"
37+
LineSearches = "7"
3138
PrecompileTools = "1"
3239
RecursiveArrayTools = "2"
3340
Reexport = "0.2, 1"
34-
SciMLBase = "1.92.4"
41+
SciMLBase = "1.97"
3542
SimpleNonlinearSolve = "0.1"
36-
SparseDiffTools = "1, 2"
43+
SparseDiffTools = "2.6"
3744
StaticArraysCore = "1.4"
3845
UnPack = "1.0"
39-
julia = "1.6"
46+
Zygote = "0.6"
47+
julia = "1.9"
4048

4149
[extras]
4250
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
51+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4352
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4453
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4554
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
4655
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4756
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4857
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
58+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
4959
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5060
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
5161
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
62+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5263

5364
[targets]
54-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra"]
65+
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools"]

docs/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
1414
BenchmarkTools = "1"
1515
Documenter = "0.27"
1616
LinearSolve = "2"
17-
NonlinearSolve = "1"
17+
NonlinearSolve = "1, 2"
1818
NonlinearSolveMINPACK = "0.1"
1919
SciMLNLSolve = "0.1"
2020
SimpleNonlinearSolve = "0.1.5"

src/NonlinearSolve.jl

+34-30
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
11
module NonlinearSolve
2-
if isdefined(Base, :Experimental) &&
3-
isdefined(Base.Experimental, Symbol("@max_methods"))
2+
3+
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_methods"))
44
@eval Base.Experimental.@max_methods 1
55
end
6-
using Reexport
7-
using UnPack: @unpack
8-
using FiniteDiff, ForwardDiff
9-
using ForwardDiff: Dual
10-
using LinearAlgebra
11-
using StaticArraysCore
12-
using RecursiveArrayTools
13-
import EnumX
14-
import ArrayInterface
15-
import LinearSolve
16-
using DiffEqBase
17-
using SparseDiffTools
18-
19-
@reexport using SciMLBase
20-
using SciMLBase: NLStats
21-
@reexport using SimpleNonlinearSolve
22-
23-
import SciMLBase: _unwrap_val
24-
25-
abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
26-
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
27-
AbstractNonlinearSolveAlgorithm end
28-
29-
function SciMLBase.__solve(prob::NonlinearProblem,
30-
alg::AbstractNonlinearSolveAlgorithm, args...;
31-
kwargs...)
6+
7+
using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools
8+
import ForwardDiff
9+
10+
import ADTypes: AbstractFiniteDifferencesMode
11+
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
12+
import ConcreteStructs: @concrete
13+
import EnumX: @enumx
14+
import ForwardDiff: Dual
15+
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
16+
import RecursiveArrayTools: ArrayPartition,
17+
AbstractVectorOfArray, recursivecopy!, recursivefill!
18+
import Reexport: @reexport
19+
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
20+
import StaticArraysCore: StaticArray, SVector, SArray, MArray
21+
import UnPack: @unpack
22+
23+
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
24+
25+
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
26+
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
27+
28+
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
29+
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
30+
31+
function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
32+
args...; kwargs...)
3233
cache = init(prob, alg, args...; kwargs...)
33-
sol = solve!(cache)
34+
return solve!(cache)
3435
end
3536

3637
include("utils.jl")
38+
include("linesearch.jl")
3739
include("raphson.jl")
3840
include("trustRegion.jl")
3941
include("levenberg.jl")
@@ -46,7 +48,7 @@ PrecompileTools.@compile_workload begin
4648
for T in (Float32, Float64)
4749
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
4850

49-
precompile_algs = if VERSION >= v"1.7"
51+
precompile_algs = if VERSION v"1.7"
5052
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
5153
else
5254
(NewtonRaphson(),)
@@ -68,4 +70,6 @@ export RadiusUpdateSchemes
6870

6971
export NewtonRaphson, TrustRegion, LevenbergMarquardt
7072

73+
export LineSearch
74+
7175
end # module

src/ad.jl

+43-24
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,63 @@
11
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
22
f = prob.f
33
p = value(prob.p)
4-
54
u0 = value(prob.u0)
65
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
76

87
sol = solve(newprob, alg, args...; kwargs...)
98

109
uu = sol.u
11-
if p isa Number
12-
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
13-
else
14-
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
15-
end
10+
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
11+
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
12+
13+
z_arr = -inv(f_x) * f_p
1614

17-
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
1815
pp = prob.p
19-
sumfun = let f_x′ = -f_x
20-
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
16+
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
17+
if uu isa Number
18+
partials = sum(sumfun, zip(z_arr, pp))
19+
elseif p isa Number
20+
partials = sumfun((z_arr, pp))
21+
else
22+
partials = sum(sumfun, zip(eachcol(z_arr), pp))
2123
end
22-
partials = sum(sumfun, zip(f_p, pp))
24+
2325
return sol, partials
2426
end
2527

26-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
27-
iip,
28-
<:Dual{T, V, P}},
29-
alg::AbstractNewtonAlgorithm,
30-
args...; kwargs...) where {iip, T, V, P}
28+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
29+
iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
30+
kwargs...) where {iip, T, V, P}
3131
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
32-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
33-
retcode = sol.retcode)
32+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
33+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
3434
end
35-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
36-
iip,
37-
<:AbstractArray{<:Dual{T, V, P}}},
38-
alg::AbstractNewtonAlgorithm,
39-
args...;
35+
36+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
37+
iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
4038
kwargs...) where {iip, T, V, P}
4139
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
42-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
43-
retcode = sol.retcode)
40+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
41+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
42+
end
43+
44+
function scalar_nlsolve_∂f_∂p(f, u, p)
45+
ff = p isa Number ? ForwardDiff.derivative :
46+
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
47+
return ff(Base.Fix1(f, u), p)
48+
end
49+
50+
function scalar_nlsolve_∂f_∂u(f, u, p)
51+
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
52+
return ff(Base.Fix2(f, p), u)
53+
end
54+
55+
function scalar_nlsolve_dual_soln(u::Number, partials,
56+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
57+
return Dual{T, V, P}(u, partials)
58+
end
59+
60+
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
61+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
62+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
4463
end

0 commit comments

Comments
 (0)