Skip to content

Commit a5c6195

Browse files
committed
Share reinit code
1 parent 9bc8f5b commit a5c6195

File tree

3 files changed

+38
-55
lines changed

3 files changed

+38
-55
lines changed

src/NonlinearSolve.jl

+35-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase, LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
11+
using DiffEqBase,
12+
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
1213
SparseDiffTools
1314

1415
import ADTypes: AbstractFiniteDifferencesMode
@@ -51,6 +52,39 @@ abstract type AbstractNonlinearSolveCache{iip} end
5152

5253
isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
5354

55+
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache);
56+
p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
57+
maxiters = cache.maxiters, alias_u0 = false,
58+
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
59+
cache.p = p
60+
if iip
61+
recursivecopy!(get_u(cache), u0)
62+
cache.f(cache.fu1, get_u(cache), p)
63+
else
64+
cache.u = __maybe_unaliased(u0, alias_u0)
65+
set_fu!(cache, cache.f(cache.u, p))
66+
end
67+
68+
reset!(cache.trace)
69+
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache),
70+
get_u(cache), termination_condition)
71+
72+
cache.abstol = abstol
73+
cache.reltol = reltol
74+
cache.tc_cache = tc_cache
75+
cache.maxiters = maxiters
76+
cache.stats.nf = 1
77+
cache.stats.nsteps = 1
78+
cache.force_stop = false
79+
cache.retcode = ReturnCode.Default
80+
81+
__reinit_internal!(cache)
82+
83+
return cache
84+
end
85+
86+
__reinit_internal!(cache::AbstractNonlinearSolveCache) = nothing
87+
5488
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
5589
str = "$(nameof(typeof(alg)))("
5690
modifiers = String[]

src/broyden.jl

+3-26
Original file line numberDiff line numberDiff line change
@@ -137,31 +137,8 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
137137
return nothing
138138
end
139139

140-
function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
141-
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
142-
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
143-
cache.p = p
144-
if iip
145-
recursivecopy!(cache.u, u0)
146-
cache.f(cache.fu, cache.u, p)
147-
else
148-
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
149-
cache.u = u0
150-
cache.fu = cache.f(cache.u, p)
151-
end
152-
153-
reset!(cache.trace)
154-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
155-
termination_condition)
156-
157-
cache.abstol = abstol
158-
cache.reltol = reltol
159-
cache.tc_cache = tc_cache
160-
cache.maxiters = maxiters
161-
cache.stats.nf = 1
162-
cache.stats.nsteps = 1
140+
function __reinit_internal!(cache::GeneralBroydenCache)
141+
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
163142
cache.resets = 0
164-
cache.force_stop = false
165-
cache.retcode = ReturnCode.Default
166-
return cache
143+
return nothing
167144
end

src/raphson.jl

-28
Original file line numberDiff line numberDiff line change
@@ -128,31 +128,3 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
128128
cache.stats.nfactors += 1
129129
return nothing
130130
end
131-
132-
function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
133-
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
134-
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
135-
cache.p = p
136-
if iip
137-
recursivecopy!(cache.u, u0)
138-
cache.f(cache.fu1, cache.u, p)
139-
else
140-
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
141-
cache.u = u0
142-
cache.fu1 = cache.f(cache.u, p)
143-
end
144-
145-
reset!(cache.trace)
146-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu1, cache.u,
147-
termination_condition)
148-
149-
cache.abstol = abstol
150-
cache.reltol = reltol
151-
cache.tc_cache = tc_cache
152-
cache.maxiters = maxiters
153-
cache.stats.nf = 1
154-
cache.stats.nsteps = 1
155-
cache.force_stop = false
156-
cache.retcode = ReturnCode.Default
157-
return cache
158-
end

0 commit comments

Comments
 (0)