Skip to content

Start using termination conditions from DiffEqBase #208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -79,6 +79,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]
53 changes: 44 additions & 9 deletions src/broyden.jl
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ end
f
alg
u
u_prev
du
fu
fu2
@@ -46,17 +47,21 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
reset_tolerance
reset_check
prob
stats::NLStats
lscache
termination_condition
tc_storage
end

# Accessor used to read the `fu` field stored on a `GeneralBroydenCache`.
function get_fu(cache::GeneralBroydenCache)
    return cache.fu
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
@@ -65,23 +70,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
alg.reset_tolerance
reset_check = x -> abs(x) ≤ reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
T = eltype(u)

mul!(_vec(du), J⁻¹, -_vec(fu))
α = perform_linesearch!(cache.lscache, u, du)
_axpy!(α, du, u)
f(fu2, u, p)

cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
@@ -106,20 +126,25 @@ function perform_step!(cache::GeneralBroydenCache{true})
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
end
fu .= fu2
@. u_prev = u

return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache
@unpack f, p, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = cache.u .+ α * cache.du
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
@@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
@@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
41 changes: 34 additions & 7 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
@@ -88,12 +88,16 @@ end
internalnorm
retcode::SciMLBase.ReturnCode.T
abstol
reltol
prob
stats::NLStats
termination_condition
tc_storage
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
uₙ = alias_u0 ? prob.u0 : deepcopy(prob.u0)

@@ -122,14 +126,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
f₍ₙₒᵣₘ₎₀ = f₍ₙₒᵣₘ₎ₙ₋₁

ℋ = fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
T)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
termination_condition, storage)
end

function perform_step!(cache::DFSaneCache{true})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
f = (dx, x) -> cache.prob.f(dx, x, cache.p)

T = eltype(cache.uₙ)
@@ -174,7 +191,7 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if cache.internalnorm(cache.fuₙ) < cache.abstol
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end

@@ -205,8 +222,9 @@ function perform_step!(cache::DFSaneCache{true})
end

function perform_step!(cache::DFSaneCache{false})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
f = x -> cache.prob.f(x, cache.p)

T = eltype(cache.uₙ)
@@ -249,7 +267,7 @@ function perform_step!(cache::DFSaneCache{false})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if cache.internalnorm(cache.fuₙ) < cache.abstol
if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end

@@ -296,7 +314,9 @@ function SciMLBase.solve!(cache::DFSaneCache)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
@@ -317,7 +337,14 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
62 changes: 52 additions & 10 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -49,13 +49,16 @@ end
# Keyword-only constructor for `GaussNewton`.
#
# Keywords:
# - `concrete_jac`: passed through `_unwrap_val`; the result becomes the type
#   parameter of the returned `GaussNewton`.
# - `linsolve`, `precs`: forwarded unchanged to the positional constructor.
# - `adkwargs...`: forwarded to `default_adargs_to_adtype`, whose result is the
#   first positional argument of the returned object.
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
    precs = DEFAULT_PRECS, adkwargs...)
    ad = default_adargs_to_adtype(; adkwargs...)
    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
fu_new
@@ -72,12 +75,17 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
stats::NLStats
tc_storage
termination_condition
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
@@ -101,15 +109,29 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
JᵀJ, Jᵀf = nothing, nothing
end

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u); mode = NLSolveTerminationMode.AbsNorm)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
reltol,
prob, NLStats(1, 0, 0, 0, 0), storage, termination_condition)
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du, tc_storage = cache
jacobian!!(J, cache)

termination_condition = cache.termination_condition(tc_storage)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
@@ -127,9 +149,15 @@ function perform_step!(cache::GaussNewtonCache{true})
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(termination_condition(cache.fu_new .- cache.fu1,
cache.u,
u_prev,
cache.abstol,
cache.reltol) ||
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol)) &&
(cache.force_stop = true)

@. u_prev = u
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
@@ -139,7 +167,9 @@ function perform_step!(cache::GaussNewtonCache{true})
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)

@@ -164,7 +194,10 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new) < cache.abstol) && (cache.force_stop = true)
termination_condition(cache.fu_new, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.u_prev = @. cache.u
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
@@ -174,7 +207,9 @@ function perform_step!(cache::GaussNewtonCache{false})
end

function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
@@ -184,7 +219,14 @@ function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache
cache.u = u0
cache.fu1 = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Loading