
Commit dccc1dd

Reuse NewtonDescent for MultiStepSchemes
1 parent ceeadcb commit dccc1dd

File tree: 7 files changed (+118 -72 lines)

- Project.toml
- docs/src/basics/faq.md
- docs/src/basics/sparsity_detection.md
- docs/src/tutorials/large_systems.md
- src/abstract_types.jl
- src/algorithms/multistep.jl
- src/descent/multistep.jl

Project.toml (+1 -1)

```diff
@@ -56,7 +56,7 @@ NonlinearSolveZygoteExt = "Zygote"
 
 [compat]
 ADTypes = "0.2.6"
-Accessors = "0.1"
+Accessors = "0.1.32"
 Aqua = "0.8"
 ArrayInterface = "7.7"
 BandedMatrices = "1.4"
```

docs/src/basics/faq.md (+2 -2)

Both changed lines are whitespace/indentation-only edits:

````diff
@@ -72,7 +72,7 @@ differentiate the function based on the input types. However, this function has
 `xx = [1.0, 2.0, 3.0, 4.0]` followed by a `xx[1] = var[1] - v_true[1]` where `var` might
 be a Dual number. This causes the error. To fix it:
 
-1. Specify the `autodiff` to be `AutoFiniteDiff`
+1. Specify the `autodiff` to be `AutoFiniteDiff`
 
 ```@example dual_error_faq
 sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiters = 10000,
@@ -81,7 +81,7 @@ sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiter
 
 This worked but, Finite Differencing is not the recommended approach in any scenario.
 
-2. Rewrite the function to use
+2. Rewrite the function to use
    [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) or write it as
 
 ```@example dual_error_faq
````
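
For the PreallocationTools.jl route, here is a self-contained sketch of the pattern the FAQ points at. The function body and values below are illustrative stand-ins, not the FAQ's own example:

```julia
using NonlinearSolve, PreallocationTools

const v_true = [1.0, 0.1, 2.0, 0.5]
const xx_cache = DiffCache(zeros(4))  # holds both a Float64 buffer and a Dual-compatible buffer

function fff_prealloc(var, p)
    xx = get_tmp(xx_cache, var)    # buffer matching eltype(var), so writing Duals is fine
    xx .= [1.0, 2.0, 3.0, 4.0]
    xx[1] = var[1] - v_true[1]     # the kind of assignment that errors with a plain Float64 array
    return var .- v_true
end

prob_oop = NonlinearProblem(fff_prealloc, [1.0, 2.0, 3.0, 4.0])
sol = solve(prob_oop, LevenbergMarquardt(); maxiters = 10000, abstol = 1e-9)
```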

docs/src/basics/sparsity_detection.md (+2 -2)

Both hunks are whitespace/indentation-only changes inside admonition blocks:

```diff
@@ -34,7 +34,7 @@ prob = NonlinearProblem(
 If the `colorvec` is not provided, then it is computed on demand.
 
 !!! note
-
+
     One thing to be careful about in this case is that `colorvec` is dependent on the
     autodiff backend used. Forward Mode and Finite Differencing will assume that the
     colorvec is the column colorvec, while Reverse Mode will assume that the colorvec is the
@@ -76,7 +76,7 @@ loaded, we default to using `SymbolicsSparsityDetection()`, else we default to u
 options if those are provided.
 
 !!! warning
-
+
     If you provide a non-sparse AD, and provide a `sparsity` or `jac_prototype` then
     we will use dense AD. This is because, if you provide a specific AD type, we assume
     that you know what you are doing and want to override the default choice of `nothing`.
```
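
To make the note and the warning concrete, here is a small sketch with a hypothetical residual and a hand-declared `jac_prototype` (the tutorial's own setup is not shown in these hunks); `AutoSparseForwardDiff` is the sparse forward-mode type from the ADTypes 0.2 series pinned in this commit's Project.toml:

```julia
using NonlinearSolve, SparseArrays, ADTypes

f!(du, u, p) = (du .= u .^ 2 .- p)       # hypothetical residual with a diagonal Jacobian
u0 = ones(4)
jac_prototype = spdiagm(0 => ones(4))    # declared sparsity pattern; colorvec computed on demand

prob = NonlinearProblem(NonlinearFunction(f!; jac_prototype), u0, 2.0)

# Sparse AD: the prototype (plus the backend-appropriate colorvec) compresses the Jacobian.
sol_sparse = solve(prob, NewtonRaphson(; autodiff = AutoSparseForwardDiff()))

# Non-sparse AD: per the warning above, dense AD is used despite the prototype.
sol_dense = solve(prob, NewtonRaphson(; autodiff = AutoForwardDiff()))
```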

docs/src/tutorials/large_systems.md (+7 -7)

All three hunks are whitespace/indentation-only changes:

```diff
@@ -10,7 +10,7 @@ equation (BRUSS) using NonlinearSolve.jl.
 ## Definition of the Brusselator Equation
 
 !!! note
-
+
     Feel free to skip this section: it simply defines the example problem.
 
 The Brusselator PDE is defined as follows:
@@ -118,11 +118,11 @@ However, if you know the sparsity of your problem, then you can pass a different
 type. For example, a `SparseMatrixCSC` will give a sparse matrix. Other sparse matrix types
 include:
 
-  - Bidiagonal
-  - Tridiagonal
-  - SymTridiagonal
-  - BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl))
-  - BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl))
+  - Bidiagonal
+  - Tridiagonal
+  - SymTridiagonal
+  - BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl))
+  - BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl))
 
 ## Approximate Sparsity Detection & Sparse Jacobians
 
@@ -213,7 +213,7 @@ choices, see the
 `linsolve` choices are any valid [LinearSolve.jl](https://linearsolve.sciml.ai/dev/) solver.
 
 !!! note
-
+
     Switching to a Krylov linear solver will automatically change the nonlinear problem
     solver into Jacobian-free mode, dramatically reducing the memory required. This can be
     overridden by adding `concrete_jac=true` to the algorithm.
```
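
The last note is straightforward to act on. A minimal sketch, assuming the Brusselator `prob` built earlier in this tutorial; `KrylovJL_GMRES` comes from LinearSolve.jl:

```julia
using NonlinearSolve, LinearSolve

# Krylov linsolve => Jacobian-free Newton-Krylov: only Jacobian-vector products are
# needed, so no Jacobian matrix is materialized. `prob` is the tutorial's problem.
sol_jfree = solve(prob, NewtonRaphson(linsolve = KrylovJL_GMRES()); abstol = 1e-8)

# Opt back into a concrete Jacobian (e.g. to build a preconditioner from it).
sol_concrete = solve(
    prob, NewtonRaphson(linsolve = KrylovJL_GMRES(), concrete_jac = true); abstol = 1e-8)
```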

src/abstract_types.jl (+28)

```diff
@@ -87,6 +87,11 @@ Returns a result of type [`DescentResult`](@ref).
 - `get_du(cache, ::Val{N})`: get the `N`th descent direction.
 - `set_du!(cache, δu)`: set the descent direction.
 - `set_du!(cache, δu, ::Val{N})`: set the `N`th descent direction.
+- `get_internal_cache(cache, ::Val{field})`: get the internal cache field.
+- `get_internal_cache(cache, field::Val, ::Val{N})`: get the `N`th internal cache field.
+- `set_internal_cache!(cache, value, ::Val{field})`: set the internal cache field.
+- `set_internal_cache!(cache, value, field::Val, ::Val{N})`: set the `N`th internal cache
+  field.
 - `last_step_accepted(cache)`: whether or not the last step was accepted. Checks if the
   cache has a `last_step_accepted` field and returns it if it does, else returns `true`.
 """
@@ -98,6 +103,29 @@ SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N
 set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu)
 set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu)
 set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu)
+function get_internal_cache(cache::AbstractDescentCache, ::Val{field}) where {field}
+    return getproperty(cache, field)
+end
+function get_internal_cache(cache::AbstractDescentCache, field::Val, ::Val{1})
+    return get_internal_cache(cache, field)
+end
+function get_internal_cache(
+        cache::AbstractDescentCache, ::Val{field}, ::Val{N}) where {field, N}
+    true_field = Symbol(string(field), "s") # Julia 1.10 compiles this away
+    return getproperty(cache, true_field)[N]
+end
+function set_internal_cache!(cache::AbstractDescentCache, value, ::Val{field}) where {field}
+    return setproperty!(cache, field, value)
+end
+function set_internal_cache!(
+        cache::AbstractDescentCache, value, field::Val, ::Val{1})
+    return set_internal_cache!(cache, value, field)
+end
+function set_internal_cache!(
+        cache::AbstractDescentCache, value, ::Val{field}, ::Val{N}) where {field, N}
+    true_field = Symbol(string(field), "s") # Julia 1.10 compiles this away
+    return setproperty!(cache, true_field, value, N)
+end
 
 function last_step_accepted(cache::AbstractDescentCache)
     hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted
```
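
A rough illustration of the new (internal, unexported) accessors on a hypothetical cache type; the `ToyDescentCache` below is not part of the package. The point is the naming convention: `Val(:u)` with an index `Val(N)` for `N > 1` reads the pluralized `us` field.

```julia
using NonlinearSolve

# Hypothetical cache following the singular/plural field convention described above.
mutable struct ToyDescentCache <: NonlinearSolve.AbstractDescentCache
    δu::Vector{Float64}
    δus::Vector{Vector{Float64}}
    u::Vector{Float64}
    us::Vector{Vector{Float64}}
end

cache = ToyDescentCache(zeros(2), [zeros(2), zeros(2)], [1.0, 1.0], [[2.0, 2.0], [3.0, 3.0]])

NonlinearSolve.get_internal_cache(cache, Val(:u))           # cache.u
NonlinearSolve.get_internal_cache(cache, Val(:u), Val(1))   # same: Val(1) falls back to the singular field
NonlinearSolve.get_internal_cache(cache, Val(:u), Val(2))   # cache.us[2]: the pluralized field, indexed at N
NonlinearSolve.set_internal_cache!(cache, [5.0, 5.0], Val(:u))  # cache.u = [5.0, 5.0]
```

`GenericMultiStepDescentCache` in src/descent/multistep.jl relies on exactly this to pick the per-index `u`, `fu`, and `internal_cache` slots.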

src/algorithms/multistep.jl (+2 -2)

```diff
@@ -1,8 +1,8 @@
 function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
         scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
-        vjp_autodiff = nothing)
+        vjp_autodiff = nothing, linesearch = NoLineSearch())
     scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
     descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
     return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),
-        descent, jacobian_ad = autodiff)
+        descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff)
 end
```
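
With `linesearch` now threaded through to `GeneralizedFirstOrderAlgorithm` (and `vjp_autodiff` forwarded as `reverse_ad`), the solver can be handed a globalization strategy at construction time. A hedged sketch of the intended call pattern, assuming `MultiStepNonlinearSolver` and the `LineSearchesJL` wrapper are exported and LineSearches.jl is available:

```julia
using NonlinearSolve, LineSearches

f(u, p) = u .* u .- p
prob = NonlinearProblem(f, [1.0, 1.0], 2.0)

# Defaults: PotraPtak3 scheme, NoLineSearch().
sol = solve(prob, MultiStepNonlinearSolver())

# New in this commit: attach a line search to the multistep descent.
alg = MultiStepNonlinearSolver(;
    linesearch = LineSearchesJL(; method = LineSearches.BackTracking()))
sol_ls = solve(prob, alg; abstol = 1e-10)
```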

src/descent/multistep.jl (+76 -58)

```diff
@@ -21,23 +21,24 @@ struct __PotraPtak3 <: AbstractMultiStepScheme end
 const PotraPtak3 = __PotraPtak3()
 
 alg_steps(::__PotraPtak3) = 2
+nintermediates(::__PotraPtak3) = 1
 
 @kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
-    vjp_autodiff = nothing
+    jvp_autodiff = nothing
 end
 const SinghSharma4 = __SinghSharma4()
 
 alg_steps(::__SinghSharma4) = 3
 
 @kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
-    vjp_autodiff = nothing
+    jvp_autodiff = nothing
 end
 const SinghSharma5 = __SinghSharma5()
 
 alg_steps(::__SinghSharma5) = 3
 
 @kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
-    vjp_autodiff = nothing
+    jvp_autodiff = nothing
 end
 const SinghSharma7 = __SinghSharma7()
 
@@ -60,93 +61,110 @@ end
 
 Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()")
 
-supports_line_search(::GenericMultiStepDescent) = false
+supports_line_search(::GenericMultiStepDescent) = true
 supports_trust_region(::GenericMultiStepDescent) = false
 
-@concrete mutable struct GenericMultiStepDescentCache{S, INV} <: AbstractDescentCache
+@concrete mutable struct GenericMultiStepDescentCache{S} <: AbstractDescentCache
     f
     p
     δu
     δus
-    extras
+    u
+    us
+    fu
+    fus
+    internal_cache
+    internal_caches
     scheme::S
-    lincache
     timer
     nf::Int
 end
 
-@internal_caches GenericMultiStepDescentCache :lincache
+# FIXME: @internal_caches needs to be updated to support tuples and namedtuples
+# @internal_caches GenericMultiStepDescentCache :internal_caches
 
 function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = cache.p,
         kwargs...)
     cache.nf = 0
     cache.p = p
+    reset_timer!(cache.timer)
 end
 
-function __δu_caches(scheme::MSS.__PotraPtak3, fu, u, ::Val{N}) where {N}
-    caches = ntuple(N) do i
-        @bb δu = similar(u)
-        @bb y = similar(u)
-        @bb fy = similar(fu)
-        @bb δy = similar(u)
-        @bb u_new = similar(u)
-        (δu, δy, fy, y, u_new)
+function __internal_multistep_caches(
+        scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent,
+        prob, args...; shared::Val{N} = Val(1), kwargs...) where {N}
+    internal_descent = NewtonDescent(; alg.linsolve, alg.precs)
+    internal_cache = __internal_init(
+        prob, internal_descent, args...; kwargs..., shared = Val(2))
+    internal_caches = N ≤ 1 ? nothing :
+        map(2:N) do i
+            __internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2))
    end
-    return first(caches), (N ≤ 1 ? nothing : caches[2:end])
+    return internal_cache, internal_caches
 end
 
-function __internal_init(prob::NonlinearProblem, alg::GenericMultiStepDescent, J, fu, u;
-        shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
+function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
+        alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1),
+        pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
         abstol = nothing, reltol = nothing, timer = get_timer_output(),
         kwargs...) where {INV, N}
-    δu, δus = __δu_caches(alg.scheme, fu, u, shared)
-    INV && return GenericMultiStepDescentCache{true}(prob.f, prob.p, δu, δus,
-        alg.scheme, nothing, timer, 0)
-    lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
-        linsolve_kwargs...)
-    return GenericMultiStepDescentCache{false}(prob.f, prob.p, δu, δus, alg.scheme,
-        lincache, timer, 0)
-end
-
-function __internal_init(prob::NonlinearLeastSquaresProblem, alg::GenericMultiStepDescent,
-        J, fu, u; kwargs...)
-    error("Multi-Step Descent Algorithms for NLLS are not implemented yet.")
+    @bb δu = similar(u)
+    δus = N ≤ 1 ? nothing : map(2:N) do i
+        @bb δu_ = similar(u)
+    end
+    fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
+        @bb xx = similar(fu)
+    end
+    fus_cache = N ≤ 1 ? nothing : map(2:N) do i
+        ntuple(MSS.nintermediates(alg.scheme)) do j
+            @bb xx = similar(fu)
+        end
+    end
+    u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
+        @bb xx = similar(u)
+    end
+    us_cache = N ≤ 1 ? nothing : map(2:N) do i
+        ntuple(MSS.nintermediates(alg.scheme)) do j
+            @bb xx = similar(u)
+        end
+    end
+    internal_cache, internal_caches = __internal_multistep_caches(
+        alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
+        abstol, reltol, timer, kwargs...)
+    return GenericMultiStepDescentCache(
+        prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
+        internal_cache, internal_caches, alg.scheme, timer, 0)
 end
 
 function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,
         fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true,
         kwargs...) where {INV}
-    (u_new, δy, fy, y, δu) = get_du(cache, idx)
-    skip_solve && return DescentResult(; u = u_new)
-
-    @static_timeit cache.timer "linear solve" begin
-        @static_timeit cache.timer "solve and step 1" begin
-            if INV
-                J !== nothing && @bb(δu=J × _vec(fu))
-            else
-                δu = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu),
-                    du = _vec(δu),
-                    reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
-                δu = _restructure(u, δu)
-            end
-            @bb @. y = u - δu
-        end
+    δu = get_du(cache, idx)
+    skip_solve && return DescentResult(; δu)
+
+    (y,) = get_internal_cache(cache, Val(:u), idx)
+    (fy,) = get_internal_cache(cache, Val(:fu), idx)
+    internal_cache = get_internal_cache(cache, Val(:internal_cache), idx)
 
+    @static_timeit cache.timer "descent step" begin
+        result_1 = __internal_solve!(
+            internal_cache, J, fu, u, Val(1); new_jacobian, kwargs...)
+        δx = result_1.δu
+
+        @bb @. y = u + δx
         fy = evaluate_f!!(cache.f, fy, y, cache.p)
        cache.nf += 1
 
-        @static_timeit cache.timer "solve and step 2" begin
-            if INV
-                J !== nothing && @bb(δy=J × _vec(fy))
-            else
-                δy = cache.lincache(; A = J, b = _vec(fy), kwargs..., linu = _vec(δy),
-                    du = _vec(δy), reuse_A_if_factorization = true)
-                δy = _restructure(u, δy)
-            end
-            @bb @. u_new = y - δy
-        end
+        result_2 = __internal_solve!(
+            internal_cache, J, fy, y, Val(2); kwargs...)
+        δy = result_2.δu
+
+        @bb @. δu = δx + δy
     end
 
-    set_du!(cache, (u_new, δy, fy, y, δu), idx)
-    return DescentResult(; u = u_new)
+    set_du!(cache, δu, idx)
+    set_internal_cache!(cache, (y,), Val(:u), idx)
+    set_internal_cache!(cache, (fy,), Val(:fu), idx)
+    set_internal_cache!(cache, internal_cache, Val(:internal_cache), idx)
+    return DescentResult(; δu)
 end
```
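
The net effect of the rewrite: the PotraPtak3 descent no longer drives the linear solver itself (the old `lincache` / `INV` branches); it calls an internal `NewtonDescent` twice, once at `u` and once at the intermediate point, and accumulates the combined direction. A plain-Julia sketch of what the two `__internal_solve!` calls compose to (dense `\`, no caching; illustrative only):

```julia
using LinearAlgebra

# One Potra-Ptak step without the caching machinery. NewtonDescent returns
# δ = -(J \ F(x)), which is why the cache code does `y = u + δx` and accumulates
# the overall direction as `δu = δx + δy`.
function potra_ptak_step(F, J, u)
    δx = -(J \ F(u))   # first internal Newton solve (idx = Val(1))
    y = u + δx         # intermediate iterate
    δy = -(J \ F(y))   # second solve against the same J (idx = Val(2), factorization reused)
    return u + δx + δy
end

# Tiny check on F(u) = u.^2 .- 2, whose Jacobian is Diagonal(2u).
F(u) = u .^ 2 .- 2
u = [1.5]
u = potra_ptak_step(F, Diagonal(2 .* u), u)
u = potra_ptak_step(F, Diagonal(2 .* u), u)
u ≈ [sqrt(2)]   # true; the scheme is third order, so two steps suffice here
```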
