Skip to content

ForwardDiff Proper Support #340

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.2.0"
version = "3.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 5 additions & 1 deletion ext/NonlinearSolveMINPACKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls,
original.trace.g_calls, original.trace.g_calls, -1)

return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
if prob.u0 isa Number
return SciMLBase.build_solution(prob, alg, u[1], resid[1]; stats, retcode, original)
else
return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
end
end

end
8 changes: 6 additions & 2 deletions ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
@assert (termination_condition ===
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"

if typeof(prob.u0) <: Number
if prob.u0 isa Number
u0 = [prob.u0]
else
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
Expand Down Expand Up @@ -82,7 +82,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
ReturnCode.Failure
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
original.g_calls, original.iterations)
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
if prob.u0 isa Number
return SciMLBase.build_solution(prob, alg, u[1], resid[1]; retcode, original, stats)
else
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
end
end

end
141 changes: 108 additions & 33 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,83 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::Union{Nothing, AbstractNonlinearAlgorithm}, args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
sol.original)
end

@concrete mutable struct NonlinearSolveForwardDiffCache
cache
prob
alg
p
values_p
partials_p
end

@inline function __has_duals(::Union{<:Dual{T, V, P},
<:AbstractArray{<:Dual{T, V, P}}}) where {T, V, P}
return true
end
@inline __has_duals(::Any) = false

function SciMLBase.reinit!(cache::NonlinearSolveForwardDiffCache; p = cache.p,
u0 = get_u(cache.cache), kwargs...)
inner_cache = SciMLBase.reinit!(cache.cache; p = value(p), u0 = value(u0), kwargs...)
cache.cache = inner_cache
cache.p = p
cache.values_p = value(p)
cache.partials_p = ForwardDiff.partials(p)
return cache
end

function SciMLBase.init(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::Union{Nothing, AbstractNonlinearAlgorithm}, args...;
kwargs...) where {T, V, P, iip}
p = value(prob.p)
u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(cache, newprob, alg, prob.p, p,
ForwardDiff.partials(prob.p))
end

function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
sol = solve!(cache.cache)
prob = cache.prob

uu = sol.u
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)

z_arr = -f_x \ f_p

sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if cache.p isa Number
partials = sumfun((z_arr, cache.p))
else
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
end

dual_soln = __nlsolve_dual_soln(sol.u, partials, cache.p)
return SciMLBase.build_solution(prob, cache.alg, dual_soln, sol.resid; sol.retcode,
sol.stats, sol.original)
end

function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...;
kwargs...) where {uType, iip}
p = value(prob.p)
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p)
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p)

z_arr = -inv(f_x) * f_p
z_arr = -f_x \ f_p

pp = prob.p
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
Expand All @@ -25,39 +92,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
false, <:Dual{T, V, P}}, alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...) where {T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
false, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...) where {T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
end

function scalar_nlsolve_∂f_∂p(f, u, p)
ff = p isa Number ? ForwardDiff.derivative :
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
return ff(Base.Fix1(f, u), p)
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
__f = p -> begin
du = similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
return du
end
else
__f = Base.Fix1(f, u)
end
if p isa Number
return __reshape(ForwardDiff.derivative(__f, p), :, 1)
elseif u isa Number
return __reshape(ForwardDiff.gradient(__f, p), 1, :)
else
return ForwardDiff.jacobian(__f, p)
end
end

function scalar_nlsolve_∂f_∂u(f, u, p)
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
return ff(Base.Fix2(f, p), u)
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
du = similar(u)
__f = (du, u) -> f(du, u, p)
ForwardDiff.jacobian(__f, du, u)
else
__f = Base.Fix2(f, p)
if u isa Number
return ForwardDiff.derivative(__f, u)
else
return ForwardDiff.jacobian(__f, u)
end
end
end

function scalar_nlsolve_dual_soln(u::Number, partials,
@inline function __nlsolve_dual_soln(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials)
end

function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
_partials = _restructure(u, partials)
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
end
2 changes: 1 addition & 1 deletion src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
(alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;))
# Default handling for SArrays in LinearSolve is not great. Some parts are patched
# but there are quite a few unnecessary allocations
return FakeLinearSolveJLCache(A, b)
return FakeLinearSolveJLCache(A, _vec(b))
end

linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...)
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,6 @@ end
@inline __is_complex(::Type{ComplexF32}) = true
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

@inline __reshape(x::Number, args...) = x
@inline __reshape(x::AbstractArray, args...) = reshape(x, args...)
Loading