Skip to content

Add Halley's method via descent API #404

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -25,6 +25,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this comes in, I think that got in when I dev the libs. can be removed

SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
@@ -113,6 +114,7 @@ StaticArrays = "1.9"
StaticArraysCore = "1.4"
Sundials = "4.23.1"
SymbolicIndexingInterface = "0.3.31"
TaylorDiff = "0.3"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"
@@ -146,8 +148,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"]
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"

[extensions]
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
@@ -44,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch"
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
NonlinearSolveBaseTaylorDiffExt = "TaylorDiff"

[compat]
ADTypes = "1.9"
@@ -77,6 +79,7 @@ SparseArrays = "1.10"
SparseMatrixColorings = "0.4.5"
StaticArraysCore = "1.4"
SymbolicIndexingInterface = "0.3.31"
TaylorDiff = "0.3"
Test = "1.10"
TimerOutputs = "0.5.23"
julia = "1.10"
20 changes: 20 additions & 0 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module NonlinearSolveBaseTaylorDiffExt
using SciMLBase: NonlinearFunction
using NonlinearSolveBase: HalleyDescentCache
import NonlinearSolveBase: evaluate_hvvp
using TaylorDiff: derivative, derivative!
using FastClosures: @closure

function evaluate_hvvp(
Comment on lines +3 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using NonlinearSolveBase: HalleyDescentCache
import NonlinearSolveBase: evaluate_hvvp
using TaylorDiff: derivative, derivative!
using FastClosures: @closure
function evaluate_hvvp(
using NonlinearSolveBase: NonlinearSolveBase, HalleyDescentCache
using TaylorDiff: derivative, derivative!
using FastClosures: @closure
function NonlinearSolveBase.evaluate_hvvp(

style nit

hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
if iip
binary_f = @closure (y, x) -> f(y, x, p)
derivative!(hvvp, binary_f, cache.fu, u, δu, Val(2))
else
unary_f = Base.Fix2(f, p)
hvvp = derivative(unary_f, u, δu, Val(2))
end
hvvp
end

end
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
@@ -51,6 +51,7 @@ include("polyalg.jl")

include("descent/common.jl")
include("descent/newton.jl")
include("descent/halley.jl")
include("descent/steepest.jl")
include("descent/damped_newton.jl")
include("descent/dogleg.jl")
100 changes: 100 additions & 0 deletions lib/NonlinearSolveBase/src/descent/halley.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
HalleyDescent(; linsolve = nothing)

Improve the NewtonDescent with higher-order terms. First compute the descent direction as ``J a = -fu``.
Then compute the hessian-vector-vector product and solve for the second-order correction term as ``J b = H a a``.
Finally, compute the descent direction as ``δu = a * a / (b / 2 - a)``.

Note that `import TaylorDiff` is required to use this descent algorithm.

See also [`NewtonDescent`](@ref).
"""
@kwdef @concrete struct HalleyDescent <: AbstractDescentDirection
linsolve = nothing
end

supports_line_search(::HalleyDescent) = true

@concrete mutable struct HalleyDescentCache <: AbstractDescentCache
f
p
δu
δus
b
fu
hvvp
lincache
timer
preinverted_jacobian <: Union{Val{false}, Val{true}}
end

@internal_caches HalleyDescentCache :lincache

function InternalAPI.init(
prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
shared = Val(1), pre_inverted::Val = Val(false),
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
timer = get_timer_output(), kwargs...)
@bb δu = similar(u)
@bb b = similar(u)
@bb fu = similar(fu)
@bb hvvp = similar(fu)
δus = Utils.unwrap_val(shared) ≤ 1 ? nothing : map(2:Utils.unwrap_val(shared)) do i
@bb δu_ = similar(u)
end
lincache = Utils.unwrap_val(pre_inverted) ? nothing :
construct_linear_solver(
alg, alg.linsolve, J, Utils.safe_vec(fu), Utils.safe_vec(u);
stats, abstol, reltol, linsolve_kwargs...
)
return HalleyDescentCache(
prob.f, prob.p, δu, δus, b, fu, hvvp, lincache, timer, pre_inverted)
end

function InternalAPI.solve!(
cache::HalleyDescentCache, J, fu, u, idx::Val = Val(1);
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...)
δu = SciMLBase.get_du(cache, idx)
skip_solve && return DescentResult(; δu)
if preinverted_jacobian(cache)
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
@bb δu = J × vec(fu)
else
@static_timeit cache.timer "linear solve 1" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(fu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
b = cache.b
# compute the hessian-vector-vector product
hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu)
# second linear solve, reuse factorization if possible
if preinverted_jacobian(cache)
@bb b = J × vec(hvvp)
else
@static_timeit cache.timer "linear solve 2" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(hvvp),
kwargs..., linu = Utils.safe_vec(b),
reuse_A_if_factorization = true)
b = Utils.restructure(cache.b, linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
@bb @. δu = δu * δu / (b / 2 - δu)
set_du!(cache, δu, idx)
cache.b = b
return DescentResult(; δu)
end

evaluate_hvvp(hvvp, cache, f, p, u, δu) = error("not implemented. please import TaylorDiff")
7 changes: 4 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
@@ -20,8 +20,8 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
AbstractTrustRegionMethodCache,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM,
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
NewtonDescent, DampedNewtonDescent, HalleyDescent,
GeodesicAcceleration, Dogleg
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
@@ -31,6 +31,7 @@ using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD

include("raphson.jl")
include("halley.jl")
include("gauss_newton.jl")
include("levenberg_marquardt.jl")
include("trust_region.jl")
@@ -93,7 +94,7 @@ end

@reexport using SciMLBase, NonlinearSolveBase

export NewtonRaphson, PseudoTransient
export NewtonRaphson, Halley, PseudoTransient
export GaussNewton, LevenbergMarquardt, TrustRegion

export RadiusUpdateSchemes
15 changes: 15 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/halley.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing)

An experimental Halley's method implementation. Improves the convergence rate of Newton's method by using second-order derivative information to correct the descent direction.

Currently depends on TaylorDiff.jl to handle the correction terms,
might have more general implementation in the future.
"""
function Halley(; concrete_jac = nothing, linsolve = nothing,
linesearch = missing, autodiff = nothing)
return GeneralizedFirstOrderAlgorithm(;
concrete_jac, name = :Halley, linesearch,
descent = HalleyDescent(; linsolve), autodiff)
end
Comment on lines +1 to +15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really a First Order method, we might want an additional split cc @ChrisRackauckas

9 changes: 7 additions & 2 deletions test/23_test_problems_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testsetup module RobustnessTesting
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
import TaylorDiff

problems = NonlinearProblemLibrary.problems
dicts = NonlinearProblemLibrary.dicts
@@ -61,10 +62,14 @@ end
end

@testitem "23 Test Problems: Halley" setup=[RobustnessTesting] tags=[:core] begin
alg_ops = (SimpleHalley(; autodiff = AutoForwardDiff()),)
alg_ops = (
Halley(),
SimpleHalley(; autodiff = AutoForwardDiff())
)

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 5, 15, 16, 18]
broken_tests[alg_ops[1]] = [1, 5, 15, 16]
broken_tests[alg_ops[2]] = [1, 5, 15, 16, 18]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
Loading