1
- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2
- f = prob. f
1
+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2
+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3
+ alg:: Union{Nothing, AbstractNonlinearAlgorithm} , args... ;
4
+ kwargs... ) where {T, V, P, iip}
5
+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
6
+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
7
+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
8
+ end
9
+
10
+ # Differentiate Out-of-Place Nonlinear Root Finding Problems
11
+ function __nlsolve_ad (prob:: NonlinearProblem{uType, false} , alg, args... ;
12
+ kwargs... ) where {uType}
3
13
p = value (prob. p)
4
- u0 = value (prob. u0)
5
- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
14
+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
6
15
7
16
sol = solve (newprob, alg, args... ; kwargs... )
8
17
9
18
uu = sol. u
10
- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
11
- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
19
+ f_p = __nlsolve_ ∂f_∂p (prob . f, uu, p)
20
+ f_x = __nlsolve_ ∂f_∂u (prob . f, uu, p)
12
21
13
- z_arr = - inv ( f_x) * f_p
22
+ z_arr = - f_x \ f_p
14
23
15
24
pp = prob. p
16
25
sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -25,39 +34,33 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
25
34
return sol, partials
26
35
end
27
36
28
- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
29
- false , <: Dual{T, V, P} }, alg:: AbstractNonlinearSolveAlgorithm , args... ;
30
- kwargs... ) where {T, V, P}
31
- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
32
- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
33
- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
34
- end
35
-
36
- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
37
- false , <: AbstractArray{<:Dual{T, V, P}} }, alg:: AbstractNonlinearSolveAlgorithm ,
38
- args... ; kwargs... ) where {T, V, P}
39
- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
40
- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
41
- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
42
- end
43
-
44
- function scalar_nlsolve_∂f_∂p (f, u, p)
45
- ff = p isa Number ? ForwardDiff. derivative :
46
- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
47
- return ff (Base. Fix1 (f, u), p)
37
+ @inline function __nlsolve_∂f_∂p (f:: F , u, p) where {F}
38
+ __f = Base. Fix1 (f, u)
39
+ if p isa Number
40
+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
41
+ elseif u isa Number
42
+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
43
+ else
44
+ return ForwardDiff. jacobian (__f, p)
45
+ end
48
46
end
49
47
50
- function scalar_nlsolve_∂f_∂u (f, u, p)
51
- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
52
- return ff (Base. Fix2 (f, p), u)
48
+ @inline function __nlsolve_∂f_∂u (f:: F , u, p) where {F}
49
+ __f = Base. Fix2 (f, p)
50
+ if u isa Number
51
+ return ForwardDiff. derivative (__f, u)
52
+ else
53
+ return ForwardDiff. jacobian (__f, u)
54
+ end
53
55
end
54
56
55
- function scalar_nlsolve_dual_soln (u:: Number , partials,
57
+ @inline function __nlsolve_dual_soln (u:: Number , partials,
56
58
:: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
57
59
return Dual {T, V, P} (u, partials)
58
60
end
59
61
60
- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
62
+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
61
63
:: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
62
- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
64
+ _partials = _restructure (u, partials)
65
+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
63
66
end
0 commit comments