Skip to content

Commit 9bc8f5b

Browse files
committed
Reuse more code in Broyden
1 parent 74c2ad7 commit 9bc8f5b

File tree

4 files changed

+60
-77
lines changed

4 files changed

+60
-77
lines changed

src/NonlinearSolve.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
2525
AbstractVectorOfArray, recursivecopy!, recursivefill!
2626
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
2727
import SciMLOperators: FunctionOperator
28-
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
28+
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
2929
import UnPack: @unpack
3030

3131
using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve

src/broyden.jl

+31-64
Original file line numberDiff line numberDiff line change
@@ -65,107 +65,74 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
6565
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
6666
kwargs...) where {uType, iip, F}
6767
@unpack f, u0, p = prob
68-
u = alias_u0 ? u0 : deepcopy(u0)
68+
u = __maybe_unaliased(u0, alias_u0)
6969
fu = evaluate_f(prob, u)
70-
du = _mutable_zero(u)
70+
@bb du = copy(u)
7171
J⁻¹ = __init_identity_jacobian(u, fu)
7272
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
7373
alg.reset_tolerance
7474
reset_check = x -> abs(x) ≤ reset_tolerance
7575

76+
@bb u_prev = copy(u)
77+
@bb fu2 = copy(fu)
78+
@bb dfu = similar(fu)
79+
@bb J⁻¹₂ = similar(u)
80+
@bb J⁻¹df = similar(u)
81+
7682
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
7783
termination_condition)
7884
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
7985
kwargs...)
8086

81-
return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
82-
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
83-
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
84-
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
87+
return GeneralBroydenCache{iip}(f, alg, u, u_prev, du, fu, fu2, dfu, p, J⁻¹,
88+
J⁻¹₂, J⁻¹df, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
89+
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
8590
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
8691
end
8792

88-
function perform_step!(cache::GeneralBroydenCache{true})
89-
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache
90-
T = eltype(u)
91-
92-
mul!(_vec(du), J⁻¹, _vec(fu))
93-
α = perform_linesearch!(cache.ls_cache, u, du)
94-
_axpy!(-α, du, u)
95-
f(fu2, u, p)
96-
97-
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
98-
get_fu(cache), J⁻¹, du, α)
99-
100-
check_and_update!(cache, fu2, u, u_prev)
101-
cache.stats.nf += 1
102-
103-
cache.force_stop && return nothing
93+
function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
94+
T = eltype(cache.u)
10495

105-
# Update the inverse jacobian
106-
dfu .= fu2 .- fu
96+
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
97+
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
98+
@bb axpy!(-α, cache.du, cache.u)
10799

108-
if all(cache.reset_check, du) || all(cache.reset_check, dfu)
109-
if cache.resets ≥ cache.max_resets
110-
cache.retcode = ReturnCode.ConvergenceFailure
111-
cache.force_stop = true
112-
return nothing
113-
end
114-
fill!(J⁻¹, 0)
115-
J⁻¹[diagind(J⁻¹)] .= T(1)
116-
cache.resets += 1
100+
if iip
101+
cache.f(cache.fu2, cache.u, cache.p)
117102
else
118-
du .*= -1
119-
mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
120-
mul!(J⁻¹₂, _vec(du)', J⁻¹)
121-
denom = dot(du, J⁻¹df)
122-
du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
123-
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
103+
cache.fu2 = cache.f(cache.u, cache.p)
124104
end
125-
fu .= fu2
126-
@. u_prev = u
127-
128-
return nothing
129-
end
130-
131-
function perform_step!(cache::GeneralBroydenCache{false})
132-
@unpack f, p = cache
133-
134-
T = eltype(cache.u)
135-
136-
cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
137-
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
138-
cache.u = cache.u .- α * cache.du
139-
cache.fu2 = f(cache.u, p)
140105

141106
update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
142-
get_fu(cache), cache.J⁻¹, cache.du, α)
107+
cache.fu2, cache.J⁻¹, cache.du, α)
143108

144109
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
145110
cache.stats.nf += 1
146111

147112
cache.force_stop && return nothing
148113

149114
# Update the inverse jacobian
150-
cache.dfu = cache.fu2 .- cache.fu
115+
@bb @. cache.dfu = cache.fu2 - cache.fu
116+
151117
if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
152118
if cache.resets ≥ cache.max_resets
153119
cache.retcode = ReturnCode.ConvergenceFailure
154120
cache.force_stop = true
155121
return nothing
156122
end
157-
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
123+
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
158124
cache.resets += 1
159125
else
160-
cache.du = -cache.du
161-
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
162-
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
126+
@bb cache.du .*= -1
127+
@bb cache.J⁻¹df = cache.J⁻¹ × vec(cache.dfu)
128+
@bb cache.J⁻¹₂ = cache.J⁻¹ × vec(cache.du)
163129
denom = dot(cache.du, cache.J⁻¹df)
164-
cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
165-
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
130+
@bb @. cache.du = (cache.du - cache.J⁻¹df) / ifelse(iszero(denom), T(1e-5), denom)
131+
@bb cache.J⁻¹ += vec(cache.du) × transpose(cache.J⁻¹₂)
166132
end
167-
cache.fu = cache.fu2
168-
cache.u_prev = @. cache.u
133+
134+
@bb copyto!(cache.fu, cache.fu2)
135+
@bb copyto!(cache.u_prev, cache.u)
169136

170137
return nothing
171138
end

src/raphson.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
114114
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
115115
@bb axpy!(-α, cache.du, cache.u)
116116

117-
evaluate_f(cache, cache.u)
117+
evaluate_f(cache, cache.u, cache.p)
118118

119119
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
120120
cache.du, α)

src/utils.jl

+27-11
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,11 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
188188
return fu
189189
end
190190

191-
function evaluate_f(cache, u)
192-
@unpack f, p = cache.prob
191+
function evaluate_f(cache, u, p)
193192
if isinplace(cache)
194-
f(get_fu(cache), u, p)
193+
cache.prob.f(get_fu(cache), u, p)
195194
else
196-
set_fu!(cache, f(u, p))
195+
set_fu!(cache, cache.prob.f(u, p))
197196
end
198197
return nothing
199198
end
@@ -301,14 +300,31 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
301300
end
302301
end
303302

304-
__init_identity_jacobian(u::Number, _) = u
305-
function __init_identity_jacobian(u, fu)
306-
return convert(parameterless_type(_mutable(u)),
307-
Matrix{eltype(u)}(I, length(fu), length(u)))
303+
@inline __init_identity_jacobian(u::Number, _) = one(u)
304+
@inline function __init_identity_jacobian(u, fu)
305+
J = similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
306+
fill!(J, zero(eltype(J)))
307+
J[diagind(J)] .= one(eltype(J))
308+
return J
308309
end
309-
function __init_identity_jacobian(u::StaticArray, fu)
310-
return convert(MArray{Tuple{length(fu), length(u)}},
311-
Matrix{eltype(u)}(I, length(fu), length(u)))
310+
@inline function __init_identity_jacobian(u::StaticArray, fu::StaticArray)
311+
T = promote_type(eltype(fu), eltype(u))
312+
return MArray{Tuple{prod(Size(fu)), prod(Size(u))}, T}(I)
313+
end
314+
@inline function __init_identity_jacobian(u::SArray, fu::SArray)
315+
T = promote_type(eltype(fu), eltype(u))
316+
return SArray{Tuple{prod(Size(fu)), prod(Size(u))}, T}(I)
317+
end
318+
319+
@inline __reinit_identity_jacobian!!(J::Number) = one(J)
320+
@inline function __reinit_identity_jacobian!!(J::AbstractMatrix)
321+
fill!(J, zero(eltype(J)))
322+
J[diagind(J)] .= one(eltype(J))
323+
return J
324+
end
325+
@inline function __reinit_identity_jacobian!!(J::SMatrix)
326+
S = Size(J)
327+
return SArray{Tuple{S[1], S[2]}, eltype(J)}(I)
312328
end
313329

314330
function __init_low_rank_jacobian(u::StaticArray, fu, threshold::Int)

0 commit comments

Comments
 (0)