Skip to content

Commit 54a3c91

Browse files
committed
Add PotraPtak3
1 parent 949ba28 commit 54a3c91

File tree

5 files changed

+153
-6
lines changed

5 files changed

+153
-6
lines changed

src/NonlinearSolve.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ include("core/spectral_methods.jl")
6969

7070
include("algorithms/raphson.jl")
7171
include("algorithms/pseudo_transient.jl")
72+
include("algorithms/multistep.jl")
7273
include("algorithms/broyden.jl")
7374
include("algorithms/klement.jl")
7475
include("algorithms/lbroyden.jl")
@@ -130,7 +131,8 @@ include("default.jl")
130131
end
131132

132133
# Core Algorithms
133-
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane
134+
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane,
135+
MultiStepNonlinearSolver
134136
export GaussNewton, LevenbergMarquardt, TrustRegion
135137
export NonlinearSolvePolyAlgorithm,
136138
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
@@ -144,7 +146,9 @@ export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, Genera
144146

145147
# Descent Algorithms
146148
export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent,
147-
GeodesicAcceleration
149+
GeodesicAcceleration, GenericMultiStepDescent
150+
## Multistep Algorithms
151+
export MultiStepSchemes
148152

149153
# Globalization
150154
## Line Search Algorithms

src/algorithms/multistep.jl

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
2+
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing)
3+
descent = GenericMultiStepDescent(; scheme, linsolve, precs)
4+
# TODO: Use the scheme as the name
5+
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver,
6+
descent, jacobian_ad = autodiff)
7+
end

src/descent/common.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ Construct a `DescentResult` object.
55
66
### Keyword Arguments
77
8-
* `δu`: The descent direction.
9-
* `u`: The new iterate. This is provided only for multi-step methods currently.
10-
* `success`: Certain Descent Algorithms can reject a descent direction for example
8+
- `δu`: The descent direction.
9+
- `u`: The new iterate. This is provided only for multi-step methods currently.
10+
- `success`: Certain Descent Algorithms can reject a descent direction for example
1111
[`GeodesicAcceleration`](@ref).
12-
* `extras`: A named tuple containing intermediates computed during the solve.
12+
- `extras`: A named tuple containing intermediates computed during the solve.
1313
For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing
1414
the "velocity" and "acceleration" terms.
1515
"""

src/descent/multistep.jl

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
MultiStepSchemes
3+
4+
This module defines the multistep schemes used in the multistep descent algorithms. The
5+
naming convention follows <name of method><order of convergence>. The name of method is
6+
typically the last names of the authors of the paper that introduced the method.
7+
"""
8+
module MultiStepSchemes
9+
10+
abstract type AbstractMultiStepScheme end
11+
12+
function Base.show(io::IO, mss::AbstractMultiStepScheme)
13+
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")
14+
end
15+
16+
struct __PotraPtak3 <: AbstractMultiStepScheme end
17+
const PotraPtak3 = __PotraPtak3()
18+
19+
alg_steps(::__PotraPtak3) = 1
20+
21+
struct __SinghSharma4 <: AbstractMultiStepScheme end
22+
const SinghSharma4 = __SinghSharma4()
23+
24+
alg_steps(::__SinghSharma4) = 3
25+
26+
struct __SinghSharma5 <: AbstractMultiStepScheme end
27+
const SinghSharma5 = __SinghSharma5()
28+
29+
alg_steps(::__SinghSharma5) = 3
30+
31+
struct __SinghSharma7 <: AbstractMultiStepScheme end
32+
const SinghSharma7 = __SinghSharma7()
33+
34+
alg_steps(::__SinghSharma7) = 4
35+
36+
end
37+
38+
const MSS = MultiStepSchemes
39+
40+
@kwdef @concrete struct GenericMultiStepDescent <: AbstractDescentAlgorithm
41+
scheme
42+
linsolve = nothing
43+
precs = DEFAULT_PRECS
44+
end
45+
46+
supports_line_search(::GenericMultiStepDescent) = false
47+
supports_trust_region(::GenericMultiStepDescent) = false
48+
49+
@concrete mutable struct GenericMultiStepDescentCache{S, INV} <: AbstractDescentCache
50+
f
51+
p
52+
δu
53+
δus
54+
scheme::S
55+
lincache
56+
timer
57+
nf::Int
58+
end
59+
60+
@internal_caches GenericMultiStepDescentCache :lincache
61+
62+
function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = cache.p,
63+
kwargs...)
64+
cache.nf = 0
65+
cache.p = p
66+
end
67+
68+
function __δu_caches(scheme::MSS.__PotraPtak3, fu, u, ::Val{N}) where {N}
69+
caches = ntuple(N) do i
70+
@bb δu = similar(u)
71+
@bb y = similar(u)
72+
@bb fy = similar(fu)
73+
@bb δy = similar(u)
74+
@bb u_new = similar(u)
75+
(δu, δy, fy, y, u_new)
76+
end
77+
return first(caches), (N 1 ? nothing : caches[2:end])
78+
end
79+
80+
function __internal_init(prob::NonlinearProblem, alg::GenericMultiStepDescent, J, fu, u;
81+
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
82+
abstol = nothing, reltol = nothing, timer = get_timer_output(),
83+
kwargs...) where {INV, N}
84+
δu, δus = __δu_caches(alg.scheme, fu, u, shared)
85+
INV && return GenericMultiStepDescentCache{true}(prob.f, prob.p, δu, δus,
86+
alg.scheme, nothing, timer, 0)
87+
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
88+
linsolve_kwargs...)
89+
return GenericMultiStepDescentCache{false}(prob.f, prob.p, δu, δus, alg.scheme,
90+
lincache, timer, 0)
91+
end
92+
93+
function __internal_init(prob::NonlinearLeastSquaresProblem, alg::GenericMultiStepDescent,
94+
J, fu, u; kwargs...)
95+
error("Multi-Step Descent Algorithms for NLLS are not implemented yet.")
96+
end
97+
98+
function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,
99+
fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true,
100+
kwargs...) where {INV}
101+
(u_new, δy, fy, y, δu) = get_du(cache, idx)
102+
skip_solve && return DescentResult(; u = u_new)
103+
104+
@static_timeit cache.timer "linear solve" begin
105+
@static_timeit cache.timer "solve and step 1" begin
106+
if INV
107+
J !== nothing && @bb(δu = J × _vec(fu))
108+
else
109+
δu = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu),
110+
du = _vec(δu),
111+
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
112+
δu = _restructure(u, δu)
113+
114+
end
115+
@bb @. y = u - δu
116+
end
117+
118+
fy = evaluate_f!!(cache.f, fy, y, cache.p)
119+
cache.nf += 1
120+
121+
@static_timeit cache.timer "solve and step 2" begin
122+
if INV
123+
J !== nothing && @bb(δy = J × _vec(fy))
124+
else
125+
δy = cache.lincache(; A = J, b = _vec(fy), kwargs..., linu = _vec(δy),
126+
du = _vec(δy), reuse_A_if_factorization = true)
127+
δy = _restructure(u, δy)
128+
end
129+
@bb @. u_new = y - δy
130+
end
131+
end
132+
133+
set_du!(cache, (u_new, δy, fy, y, δu), idx)
134+
return DescentResult(; u = u_new)
135+
end

src/internal/tracing.jl

+1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ function update_trace!(cache::AbstractNonlinearSolveCache, α = true)
187187
trace === nothing && return nothing
188188

189189
J = __getproperty(cache, Val(:J))
190+
# TODO: fix tracing for multi-step methods where du is not aliased properly
190191
if J === nothing
191192
update_trace!(trace, get_nsteps(cache) + 1, get_u(cache), get_fu(cache),
192193
nothing, cache.du, α)

0 commit comments

Comments
 (0)