|
1 | 1 | function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
|
2 | 2 | f = prob.f
|
3 | 3 | p = value(prob.p)
|
4 |
| - |
5 | 4 | u0 = value(prob.u0)
|
6 | 5 | newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
|
7 | 6 |
|
8 | 7 | sol = solve(newprob, alg, args...; kwargs...)
|
9 | 8 |
|
10 | 9 | uu = sol.u
|
11 |
| - if p isa Number |
12 |
| - f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p) |
13 |
| - else |
14 |
| - f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p) |
15 |
| - end |
| 10 | + f_p = scalar_nlsolve_∂f_∂p(f, uu, p) |
| 11 | + f_x = scalar_nlsolve_∂f_∂u(f, uu, p) |
| 12 | + |
| 13 | + z_arr = -inv(f_x) * f_p |
16 | 14 |
|
17 |
| - f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu) |
18 | 15 | pp = prob.p
|
19 |
| - sumfun = let f_x′ = -f_x |
20 |
| - ((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p) |
| 16 | + sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) |
| 17 | + if uu isa Number |
| 18 | + partials = sum(sumfun, zip(z_arr, pp)) |
| 19 | + elseif p isa Number |
| 20 | + partials = sumfun((z_arr, pp)) |
| 21 | + else |
| 22 | + partials = sum(sumfun, zip(eachcol(z_arr), pp)) |
21 | 23 | end
|
22 |
| - partials = sum(sumfun, zip(f_p, pp)) |
| 24 | + |
23 | 25 | return sol, partials
|
24 | 26 | end
|
25 | 27 |
|
26 |
| -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, |
27 |
| - iip, |
28 |
| - <:Dual{T, V, P}}, |
29 |
| - alg::AbstractNewtonAlgorithm, |
30 |
| - args...; kwargs...) where {iip, T, V, P} |
| 28 | +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray}, |
| 29 | + iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...; |
| 30 | + kwargs...) where {iip, T, V, P} |
31 | 31 | sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
|
32 |
| - return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid; |
33 |
| - retcode = sol.retcode) |
| 32 | + dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p) |
| 33 | + return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode) |
34 | 34 | end
|
35 |
| -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, |
36 |
| - iip, |
37 |
| - <:AbstractArray{<:Dual{T, V, P}}}, |
38 |
| - alg::AbstractNewtonAlgorithm, |
39 |
| - args...; |
| 35 | + |
| 36 | +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray}, |
| 37 | + iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...; |
40 | 38 | kwargs...) where {iip, T, V, P}
|
41 | 39 | sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
|
42 |
| - return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid; |
43 |
| - retcode = sol.retcode) |
| 40 | + dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p) |
| 41 | + return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode) |
| 42 | +end |
| 43 | + |
| 44 | +function scalar_nlsolve_∂f_∂p(f, u, p) |
| 45 | + ff = p isa Number ? ForwardDiff.derivative : |
| 46 | + (u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian) |
| 47 | + return ff(Base.Fix1(f, u), p) |
| 48 | +end |
| 49 | + |
| 50 | +function scalar_nlsolve_∂f_∂u(f, u, p) |
| 51 | + ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian |
| 52 | + return ff(Base.Fix2(f, p), u) |
| 53 | +end |
| 54 | + |
| 55 | +function scalar_nlsolve_dual_soln(u::Number, partials, |
| 56 | + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} |
| 57 | + return Dual{T, V, P}(u, partials) |
| 58 | +end |
| 59 | + |
| 60 | +function scalar_nlsolve_dual_soln(u::AbstractArray, partials, |
| 61 | + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} |
| 62 | + return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials)) |
44 | 63 | end
|
0 commit comments