Skip to content

Commit 2c14522

Browse files
committed
Streamline the testing for oop AD
1 parent eabb403 commit 2c14522

9 files changed

+130
-295
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.2.0"
4+
version = "3.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

Diff for: ext/NonlinearSolveMINPACKExt.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
8080
stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls,
8181
original.trace.g_calls, original.trace.g_calls, -1)
8282

83-
return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
83+
if prob.u0 isa Number
84+
return SciMLBase.build_solution(prob, alg, u[1], resid[1]; stats, retcode, original)
85+
else
86+
return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
87+
end
8488
end
8589

8690
end

Diff for: ext/NonlinearSolveNLsolveExt.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
99
@assert (termination_condition ===
1010
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"
1111

12-
if typeof(prob.u0) <: Number
12+
if prob.u0 isa Number
1313
u0 = [prob.u0]
1414
else
1515
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
@@ -82,7 +82,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
8282
ReturnCode.Failure
8383
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
8484
original.g_calls, original.iterations)
85-
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
85+
if prob.u0 isa Number
86+
return SciMLBase.build_solution(prob, alg, u[1], resid[1]; retcode, original, stats)
87+
else
88+
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
89+
end
8690
end
8791

8892
end

Diff for: src/ad.jl

+36-33
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1-
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2-
f = prob.f
1+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
2+
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
3+
alg::Union{Nothing, AbstractNonlinearAlgorithm}, args...;
4+
kwargs...) where {T, V, P, iip}
5+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
6+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
7+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
8+
end
9+
10+
# Differentiate Out-of-Place Nonlinear Root Finding Problems
11+
function __nlsolve_ad(prob::NonlinearProblem{uType, false}, alg, args...;
12+
kwargs...) where {uType}
313
p = value(prob.p)
4-
u0 = value(prob.u0)
5-
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
14+
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)
615

716
sol = solve(newprob, alg, args...; kwargs...)
817

918
uu = sol.u
10-
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
11-
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
19+
f_p = __nlsolve_∂f_∂p(prob.f, uu, p)
20+
f_x = __nlsolve_∂f_∂u(prob.f, uu, p)
1221

13-
z_arr = -inv(f_x) * f_p
22+
z_arr = -f_x \ f_p
1423

1524
pp = prob.p
1625
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
@@ -25,39 +34,33 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
2534
return sol, partials
2635
end
2736

28-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
29-
false, <:Dual{T, V, P}}, alg::AbstractNonlinearSolveAlgorithm, args...;
30-
kwargs...) where {T, V, P}
31-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
32-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
33-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
34-
end
35-
36-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
37-
false, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNonlinearSolveAlgorithm,
38-
args...; kwargs...) where {T, V, P}
39-
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
40-
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
41-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
42-
end
43-
44-
function scalar_nlsolve_∂f_∂p(f, u, p)
45-
ff = p isa Number ? ForwardDiff.derivative :
46-
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
47-
return ff(Base.Fix1(f, u), p)
37+
@inline function __nlsolve_∂f_∂p(f::F, u, p) where {F}
38+
__f = Base.Fix1(f, u)
39+
if p isa Number
40+
return __reshape(ForwardDiff.derivative(__f, p), :, 1)
41+
elseif u isa Number
42+
return __reshape(ForwardDiff.gradient(__f, p), 1, :)
43+
else
44+
return ForwardDiff.jacobian(__f, p)
45+
end
4846
end
4947

50-
function scalar_nlsolve_∂f_∂u(f, u, p)
51-
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
52-
return ff(Base.Fix2(f, p), u)
48+
@inline function __nlsolve_∂f_∂u(f::F, u, p) where {F}
49+
__f = Base.Fix2(f, p)
50+
if u isa Number
51+
return ForwardDiff.derivative(__f, u)
52+
else
53+
return ForwardDiff.jacobian(__f, u)
54+
end
5355
end
5456

55-
function scalar_nlsolve_dual_soln(u::Number, partials,
57+
@inline function __nlsolve_dual_soln(u::Number, partials,
5658
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
5759
return Dual{T, V, P}(u, partials)
5860
end
5961

60-
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
62+
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
6163
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
62-
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
64+
_partials = _restructure(u, partials)
65+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
6366
end

Diff for: src/jacobian.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
169169
(alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;))
170170
# Default handling for SArrays in LinearSolve is not great. Some parts are patched
171171
# but there are quite a few unnecessary allocations
172-
return FakeLinearSolveJLCache(A, b)
172+
return FakeLinearSolveJLCache(A, _vec(b))
173173
end
174174

175175
linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...)

Diff for: src/utils.jl

+3
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,6 @@ end
499499
@inline __is_complex(::Type{ComplexF32}) = true
500500
@inline __is_complex(::Type{Complex}) = true
501501
@inline __is_complex(::Type{T}) where {T} = false
502+
503+
@inline __reshape(x::Number, args...) = x
504+
@inline __reshape(x::AbstractArray, args...) = reshape(x, args...)

0 commit comments

Comments
 (0)