Skip to content

Commit 2002bd4

Browse files
authored
refactor: Move dual nonlinear solving to NonlinearSolveBase (#513)
1 parent 2284348 commit 2002bd4

15 files changed

+294
-93
lines changed

Diff for: docs/src/basics/faq.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,10 @@ nothing # hide
152152
```
153153

154154
And boom! Type stable again. We always recommend picking the chunksize via
155-
[`NonlinearSolve.pickchunksize`](@ref), however, if you manually specify the chunksize, it
155+
[`NonlinearSolveBase.pickchunksize`](@ref), however, if you manually specify the chunksize, it
156156
must be `≤ length of input`. However, a very large chunksize can lead to excessive
157157
compilation times and slowdown.
158158

159159
```@docs
160-
NonlinearSolve.pickchunksize
160+
NonlinearSolveBase.pickchunksize
161161
```

Diff for: lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

+95-3
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,35 @@ module NonlinearSolveBaseForwardDiffExt
22

33
using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
44
using ArrayInterface: ArrayInterface
5-
using CommonSolve: solve
5+
using CommonSolve: CommonSolve, solve, solve!, init
6+
using ConcreteStructs: @concrete
67
using DifferentiationInterface: DifferentiationInterface
78
using FastClosures: @closure
8-
using ForwardDiff: ForwardDiff, Dual
9+
using ForwardDiff: ForwardDiff, Dual, pickchunksize
910
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1011
NonlinearProblem, NonlinearLeastSquaresProblem, remake
1112

12-
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
13+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
14+
NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache
1315

1416
const DI = DifferentiationInterface
1517

18+
const GENERAL_SOLVER_TYPES = [
19+
Nothing, NonlinearSolvePolyAlgorithm
20+
]
21+
22+
const DualNonlinearProblem = NonlinearProblem{
23+
<:Union{Number, <:AbstractArray}, iip,
24+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
25+
} where {iip, T, V, P}
26+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
27+
<:Union{Number, <:AbstractArray}, iip,
28+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
29+
} where {iip, T, V, P}
30+
const DualAbstractNonlinearProblem = Union{
31+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
32+
}
33+
1634
function NonlinearSolveBase.additional_incompatible_backend_check(
1735
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
1836
return !ForwardDiff.can_dual(eltype(prob.u0))
@@ -102,4 +120,78 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
102120
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
103121
end
104122

123+
for algType in GENERAL_SOLVER_TYPES
124+
@eval function SciMLBase.__solve(
125+
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
126+
)
127+
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
128+
prob, alg, args...; kwargs...
129+
)
130+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
131+
return SciMLBase.build_solution(
132+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
133+
)
134+
end
135+
end
136+
137+
function InternalAPI.reinit!(
138+
cache::NonlinearSolveForwardDiffCache, args...;
139+
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
140+
)
141+
InternalAPI.reinit!(
142+
cache.cache; p = NonlinearSolveBase.nodual_value(p),
143+
u0 = NonlinearSolveBase.nodual_value(u0), kwargs...
144+
)
145+
cache.p = p
146+
cache.values_p = NonlinearSolveBase.nodual_value(p)
147+
cache.partials_p = ForwardDiff.partials(p)
148+
return cache
149+
end
150+
151+
for algType in GENERAL_SOLVER_TYPES
152+
@eval function SciMLBase.__init(
153+
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
154+
)
155+
p = NonlinearSolveBase.nodual_value(prob.p)
156+
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
157+
cache = init(newprob, alg, args...; kwargs...)
158+
return NonlinearSolveForwardDiffCache(
159+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
160+
)
161+
end
162+
end
163+
164+
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
165+
sol = solve!(cache.cache)
166+
prob = cache.prob
167+
uu = sol.u
168+
169+
fn = prob isa NonlinearLeastSquaresProblem ?
170+
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f
171+
172+
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
173+
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)
174+
175+
z_arr = -Jᵤ \ Jₚ
176+
177+
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
178+
if cache.p isa Number
179+
partials = sumfun((z_arr, cache.p))
180+
else
181+
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
182+
end
183+
184+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
185+
return SciMLBase.build_solution(
186+
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
187+
)
188+
end
189+
190+
NonlinearSolveBase.nodual_value(x) = x
191+
NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x)
192+
NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
193+
194+
@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x))
195+
@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
196+
105197
end

Diff for: lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

+4
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ include("descent/geodesic_acceleration.jl")
5757

5858
include("solve.jl")
5959

60+
include("forward_diff.jl")
61+
6062
# Unexported Public API
6163
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
6264
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
@@ -83,4 +85,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle
8385

8486
export NonlinearSolvePolyAlgorithm
8587

88+
export pickchunksize
89+
8690
end

Diff for: lib/NonlinearSolveBase/src/forward_diff.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
2+
cache
3+
prob
4+
alg
5+
p
6+
values_p
7+
partials_p
8+
end

Diff for: lib/NonlinearSolveBase/src/public.jl

+9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ function nonlinearsolve_dual_solution end
1111
function nonlinearsolve_∂f_∂p end
1212
function nonlinearsolve_∂f_∂u end
1313
function nlls_generate_vjp_function end
14+
function nodual_value end
15+
16+
"""
17+
pickchunksize(x) = pickchunksize(length(x))
18+
pickchunksize(x::Int)
19+
20+
Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
21+
"""
22+
function pickchunksize end
1423

1524
# Nonlinear Solve Termination Conditions
1625
abstract type AbstractNonlinearTerminationMode end

Diff for: lib/NonlinearSolveFirstOrder/Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ julia = "1.10"
6767
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6868
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
6969
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
70+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7071
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
7172
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
7273
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
@@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8687
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8788

8889
[targets]
89-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
90+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]

Diff for: lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
2222
Utils, InternalAPI, get_timer_output, @static_timeit,
2323
update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm,
2424
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
25-
Dogleg
25+
Dogleg, NonlinearSolveForwardDiffCache
2626
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
2727
NonlinearFunction,
2828
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
2929
using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator
3030

3131
using FiniteDiff: FiniteDiff # Default Finite Difference Method
32-
using ForwardDiff: ForwardDiff # Default Forward Mode AD
32+
using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD
3333

3434
include("raphson.jl")
3535
include("gauss_newton.jl")
@@ -41,6 +41,8 @@ include("poly_algs.jl")
4141

4242
include("solve.jl")
4343

44+
include("forward_diff.jl")
45+
4446
@setup_workload begin
4547
nonlinear_functions = (
4648
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),

Diff for: lib/NonlinearSolveFirstOrder/src/forward_diff.jl

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
const DualNonlinearProblem = NonlinearProblem{
2+
<:Union{Number, <:AbstractArray}, iip,
3+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
4+
} where {iip, T, V, P}
5+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
6+
<:Union{Number, <:AbstractArray}, iip,
7+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
8+
} where {iip, T, V, P}
9+
const DualAbstractNonlinearProblem = Union{
10+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
11+
}
12+
13+
function SciMLBase.__init(
14+
prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
15+
)
16+
p = NonlinearSolveBase.nodual_value(prob.p)
17+
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
18+
cache = init(newprob, alg, args...; kwargs...)
19+
return NonlinearSolveForwardDiffCache(
20+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
21+
)
22+
end
23+
24+
function SciMLBase.__solve(
25+
prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
26+
)
27+
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
28+
prob, alg, args...; kwargs...
29+
)
30+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
31+
return SciMLBase.build_solution(
32+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
33+
)
34+
end

Diff for: lib/NonlinearSolveFirstOrder/test/misc_tests.jl

+10
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,13 @@
2020
@test sol.retcode == ReturnCode.Success
2121
@test jac_calls == 0
2222
end
23+
24+
@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin
25+
using NonlinearSolveFirstOrder, ForwardDiff
26+
fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p)
27+
u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0), ForwardDiff.Dual(BigFloat(1.0), 5.0),
28+
ForwardDiff.Dual(BigFloat(1.0), 5.0)]
29+
prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, ForwardDiff.Dual(BigFloat(2.0), 5.0))
30+
sol = solve(prob_iip_bf, NewtonRaphson())
31+
@test sol.retcode == ReturnCode.Success
32+
end

Diff for: lib/NonlinearSolveQuasiNewton/Project.toml

+6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1818
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1919
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2020

21+
[weakdeps]
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
23+
24+
[extensions]
25+
NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff"
26+
2127
[compat]
2228
ADTypes = "1.9.0"
2329
Aqua = "0.8"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
module NonlinearSolveQuasiNewtonForwardDiffExt
2+
3+
using CommonSolve: CommonSolve, init
4+
using ForwardDiff: ForwardDiff, Dual
5+
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem
6+
7+
using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value
8+
9+
using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm
10+
11+
const DualNonlinearProblem = NonlinearProblem{
12+
<:Union{Number, <:AbstractArray}, iip,
13+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
14+
} where {iip, T, V, P}
15+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
16+
<:Union{Number, <:AbstractArray}, iip,
17+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
18+
} where {iip, T, V, P}
19+
const DualAbstractNonlinearProblem = Union{
20+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
21+
}
22+
23+
function SciMLBase.__solve(
24+
prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
25+
)
26+
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
27+
prob, alg, args...; kwargs...
28+
)
29+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
30+
return SciMLBase.build_solution(
31+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
32+
)
33+
end
34+
35+
function SciMLBase.__init(
36+
prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
37+
)
38+
p = nodual_value(prob.p)
39+
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
40+
cache = init(newprob, alg, args...; kwargs...)
41+
return NonlinearSolveForwardDiffCache(
42+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
43+
)
44+
end
45+
46+
end

Diff for: lib/NonlinearSolveSpectralMethods/Project.toml

+7
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,20 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1515
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1616

17+
[weakdeps]
18+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
19+
20+
[extensions]
21+
NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff"
22+
1723
[compat]
1824
Aqua = "0.8"
1925
BenchmarkTools = "1.5.0"
2026
CommonSolve = "0.2.4"
2127
ConcreteStructs = "0.2.3"
2228
DiffEqBase = "6.158.3"
2329
ExplicitImports = "1.5"
30+
ForwardDiff = "0.10.36"
2431
Hwloc = "3"
2532
InteractiveUtils = "<0.0.1, 1"
2633
LineSearch = "0.1.4"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
module NonlinearSolveSpectralMethodsForwardDiffExt
2+
3+
using CommonSolve: CommonSolve, init
4+
using ForwardDiff: ForwardDiff, Dual
5+
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem
6+
7+
using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value
8+
9+
using NonlinearSolveSpectralMethods: GeneralizedDFSane
10+
11+
const DualNonlinearProblem = NonlinearProblem{
12+
<:Union{Number, <:AbstractArray}, iip,
13+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
14+
} where {iip, T, V, P}
15+
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
16+
<:Union{Number, <:AbstractArray}, iip,
17+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
18+
} where {iip, T, V, P}
19+
const DualAbstractNonlinearProblem = Union{
20+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
21+
}
22+
23+
function SciMLBase.__solve(
24+
prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
25+
)
26+
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
27+
prob, alg, args...; kwargs...
28+
)
29+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
30+
return SciMLBase.build_solution(
31+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
32+
)
33+
end
34+
35+
function SciMLBase.__init(
36+
prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
37+
)
38+
p = nodual_value(prob.p)
39+
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
40+
cache = init(newprob, alg, args...; kwargs...)
41+
return NonlinearSolveForwardDiffCache(
42+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
43+
)
44+
end
45+
46+
end

0 commit comments

Comments
 (0)