@@ -2,17 +2,35 @@ module NonlinearSolveBaseForwardDiffExt
2
2
3
3
using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
4
4
using ArrayInterface: ArrayInterface
5
- using CommonSolve: solve
5
+ using CommonSolve: CommonSolve, solve, solve!, init
6
+ using ConcreteStructs: @concrete
6
7
using DifferentiationInterface: DifferentiationInterface
7
8
using FastClosures: @closure
8
- using ForwardDiff: ForwardDiff, Dual
9
+ using ForwardDiff: ForwardDiff, Dual, pickchunksize
9
10
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
10
11
NonlinearProblem, NonlinearLeastSquaresProblem, remake
11
12
12
- using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
13
+ using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
14
+ NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache
13
15
14
16
const DI = DifferentiationInterface
15
17
18
+ const GENERAL_SOLVER_TYPES = [
19
+ Nothing, NonlinearSolvePolyAlgorithm
20
+ ]
21
+
22
+ const DualNonlinearProblem = NonlinearProblem{
23
+ <: Union{Number, <:AbstractArray} , iip,
24
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
25
+ } where {iip, T, V, P}
26
+ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
27
+ <: Union{Number, <:AbstractArray} , iip,
28
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
29
+ } where {iip, T, V, P}
30
+ const DualAbstractNonlinearProblem = Union{
31
+ DualNonlinearProblem, DualNonlinearLeastSquaresProblem
32
+ }
33
+
16
34
function NonlinearSolveBase. additional_incompatible_backend_check (
17
35
prob:: AbstractNonlinearProblem , :: Union{AutoForwardDiff, AutoPolyesterForwardDiff} )
18
36
return ! ForwardDiff. can_dual (eltype (prob. u0))
@@ -102,4 +120,78 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
102
120
return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, Utils. restructure (u, partials)))
103
121
end
104
122
123
+ for algType in GENERAL_SOLVER_TYPES
124
+ @eval function SciMLBase. __solve (
125
+ prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs...
126
+ )
127
+ sol, partials = NonlinearSolveBase. nonlinearsolve_forwarddiff_solve (
128
+ prob, alg, args... ; kwargs...
129
+ )
130
+ dual_soln = NonlinearSolveBase. nonlinearsolve_dual_solution (sol. u, partials, prob. p)
131
+ return SciMLBase. build_solution (
132
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
133
+ )
134
+ end
135
+ end
136
+
137
+ function InternalAPI. reinit! (
138
+ cache:: NonlinearSolveForwardDiffCache , args... ;
139
+ p = cache. p, u0 = NonlinearSolveBase. get_u (cache. cache), kwargs...
140
+ )
141
+ InternalAPI. reinit! (
142
+ cache. cache; p = NonlinearSolveBase. nodual_value (p),
143
+ u0 = NonlinearSolveBase. nodual_value (u0), kwargs...
144
+ )
145
+ cache. p = p
146
+ cache. values_p = NonlinearSolveBase. nodual_value (p)
147
+ cache. partials_p = ForwardDiff. partials (p)
148
+ return cache
149
+ end
150
+
151
+ for algType in GENERAL_SOLVER_TYPES
152
+ @eval function SciMLBase. __init (
153
+ prob:: DualAbstractNonlinearProblem , alg:: $ (algType), args... ; kwargs...
154
+ )
155
+ p = NonlinearSolveBase. nodual_value (prob. p)
156
+ newprob = SciMLBase. remake (prob; u0 = NonlinearSolveBase. nodual_value (prob. u0), p)
157
+ cache = init (newprob, alg, args... ; kwargs... )
158
+ return NonlinearSolveForwardDiffCache (
159
+ cache, newprob, alg, prob. p, p, ForwardDiff. partials (prob. p)
160
+ )
161
+ end
162
+ end
163
+
164
+ function CommonSolve. solve! (cache:: NonlinearSolveForwardDiffCache )
165
+ sol = solve! (cache. cache)
166
+ prob = cache. prob
167
+ uu = sol. u
168
+
169
+ fn = prob isa NonlinearLeastSquaresProblem ?
170
+ NonlinearSolveBase. nlls_generate_vjp_function (prob, sol, uu) : prob. f
171
+
172
+ Jₚ = NonlinearSolveBase. nonlinearsolve_∂f_∂p (prob, fn, uu, cache. values_p)
173
+ Jᵤ = NonlinearSolveBase. nonlinearsolve_∂f_∂u (prob, fn, uu, cache. values_p)
174
+
175
+ z_arr = - Jᵤ \ Jₚ
176
+
177
+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
178
+ if cache. p isa Number
179
+ partials = sumfun ((z_arr, cache. p))
180
+ else
181
+ partials = sum (sumfun, zip (eachcol (z_arr), cache. p))
182
+ end
183
+
184
+ dual_soln = NonlinearSolveBase. nonlinearsolve_dual_solution (sol. u, partials, cache. p)
185
+ return SciMLBase. build_solution (
186
+ prob, cache. alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original
187
+ )
188
+ end
189
+
190
+ NonlinearSolveBase. nodual_value (x) = x
191
+ NonlinearSolveBase. nodual_value (x:: Dual ) = ForwardDiff. value (x)
192
+ NonlinearSolveBase. nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
193
+
194
+ @inline NonlinearSolveBase. pickchunksize (x) = pickchunksize (length (x))
195
+ @inline NonlinearSolveBase. pickchunksize (x:: Int ) = ForwardDiff. pickchunksize (x)
196
+
105
197
end
0 commit comments