Skip to content

Commit 5b7c7c4

Browse files
committed
Trust Region mostly works
1 parent 13e590e commit 5b7c7c4

File tree

5 files changed

+244
-344
lines changed

5 files changed

+244
-344
lines changed

src/NonlinearSolve.jl

+49-49
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ include("trace.jl")
169169
include("extension_algs.jl")
170170
include("linesearch.jl")
171171
include("raphson.jl")
172-
# include("trustRegion.jl")
172+
include("trustRegion.jl")
173173
include("levenberg.jl")
174174
include("gaussnewton.jl")
175175
include("dfsane.jl")
@@ -179,54 +179,54 @@ include("klement.jl")
179179
include("lbroyden.jl")
180180
include("jacobian.jl")
181181
include("ad.jl")
182-
# include("default.jl")
183-
184-
# @setup_workload begin
185-
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
186-
# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
187-
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
188-
# probs_nls = NonlinearProblem[]
189-
# for T in (Float32, Float64), (fn, u0) in nlfuncs
190-
# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
191-
# end
192-
193-
# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
194-
# GeneralBroyden(), GeneralKlement(), DFSane(), nothing)
195-
196-
# probs_nlls = NonlinearLeastSquaresProblem[]
197-
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
198-
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
199-
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
200-
# resid_prototype = zeros(1)), [0.1, 0.0]),
201-
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
202-
# resid_prototype = zeros(4)), [0.1, 0.1]))
203-
# for (fn, u0) in nlfuncs
204-
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
205-
# end
206-
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
207-
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
208-
# Float32[0.1, 0.1]),
209-
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
210-
# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
211-
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
212-
# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
213-
# for (fn, u0) in nlfuncs
214-
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
215-
# end
216-
217-
# nlls_algs = (LevenbergMarquardt(), GaussNewton(),
218-
# LevenbergMarquardt(; linsolve = LUFactorization()),
219-
# GaussNewton(; linsolve = LUFactorization()))
220-
221-
# @compile_workload begin
222-
# for prob in probs_nls, alg in nls_algs
223-
# solve(prob, alg, abstol = 1e-2)
224-
# end
225-
# for prob in probs_nlls, alg in nlls_algs
226-
# solve(prob, alg, abstol = 1e-2)
227-
# end
228-
# end
229-
# end
182+
include("default.jl")
183+
184+
@setup_workload begin
185+
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
186+
(NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
187+
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
188+
probs_nls = NonlinearProblem[]
189+
for T in (Float32, Float64), (fn, u0) in nlfuncs
190+
push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
191+
end
192+
193+
nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
194+
GeneralBroyden(), GeneralKlement(), DFSane(), nothing)
195+
196+
probs_nlls = NonlinearLeastSquaresProblem[]
197+
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
198+
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
199+
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
200+
resid_prototype = zeros(1)), [0.1, 0.0]),
201+
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
202+
resid_prototype = zeros(4)), [0.1, 0.1]))
203+
for (fn, u0) in nlfuncs
204+
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
205+
end
206+
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
207+
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
208+
Float32[0.1, 0.1]),
209+
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
210+
resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
211+
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
212+
resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
213+
for (fn, u0) in nlfuncs
214+
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
215+
end
216+
217+
nlls_algs = (LevenbergMarquardt(), GaussNewton(),
218+
LevenbergMarquardt(; linsolve = LUFactorization()),
219+
GaussNewton(; linsolve = LUFactorization()))
220+
221+
@compile_workload begin
222+
for prob in probs_nls, alg in nls_algs
223+
solve(prob, alg, abstol = 1e-2)
224+
end
225+
for prob in probs_nlls, alg in nlls_algs
226+
solve(prob, alg, abstol = 1e-2)
227+
end
228+
end
229+
end
230230

231231
export RadiusUpdateSchemes
232232

src/gaussnewton.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
116116

117117
# Use normal form to solve the Linear Problem
118118
if cache.JᵀJ !== nothing
119-
__update_JᵀJ!(cache, Val(:JᵀJ))
120-
__update_Jᵀf!(cache, Val(:JᵀJ))
119+
__update_JᵀJ!(cache)
120+
__update_Jᵀf!(cache)
121121
A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf)
122122
else
123123
A, b = cache.J, _vec(cache.fu)

src/jacobian.jl

+33-12
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
138138
kwargs...) where {needsJᵀJ, F}
139139
# NOTE: Scalar `u` assumes scalar output from `f`
140140
uf = SciMLBase.JacobianWrapper{false}(f, p)
141-
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u
141+
return uf, FakeLinearSolveJLCache(u, u), u, zero(u), nothing, u, u, u
142142
end
143143

144144
# Linear Solve Cache
@@ -208,27 +208,48 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
208208
end
209209
end
210210

211+
# jvp fallback scalar
212+
__jacvec(args...; kwargs...) = JacVec(args...; kwargs...)
213+
__jacvec(uf, u::Number; autodiff, kwargs...) = JVPScalar(uf, u, autodiff)
214+
215+
@concrete mutable struct JVPScalar
216+
uf
217+
u
218+
autodiff
219+
end
220+
221+
function Base.:*(jvp::JVPScalar, v)
222+
if jvp.autodiff isa AutoForwardDiff
223+
elseif jvp.autodiff isa AutoFiniteDiff
224+
else
225+
error("JVPScalar only supports AutoForwardDiff and AutoFiniteDiff")
226+
end
227+
end
228+
211229
# Generic Handling of Krylov Methods for Normal Form Linear Solves
212-
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache)
230+
function __update_JᵀJ!(cache::AbstractNonlinearSolveCache, J = nothing)
213231
if !(cache.JᵀJ isa KrylovJᵀJ)
214-
@bb cache.JᵀJ = transpose(cache.J) × cache.J
232+
J_ = ifelse(J === nothing, cache.J, J)
233+
@bb cache.JᵀJ = transpose(J_) × J_
215234
end
216235
end
217236

218-
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache)
237+
function __update_Jᵀf!(cache::AbstractNonlinearSolveCache, J = nothing)
219238
if cache.JᵀJ isa KrylovJᵀJ
220239
@bb cache.Jᵀf = cache.JᵀJ.Jᵀ × cache.fu
221240
else
222-
@bb cache.Jᵀf = transpose(cache.J) × vec(cache.fu)
241+
J_ = ifelse(J === nothing, cache.J, J)
242+
@bb cache.Jᵀf = transpose(J_) × vec(cache.fu)
223243
end
224244
end
225245

226246
# Left-Right Multiplication
227-
__lr_mul(::Val, H, g) = dot(g, H, g)
228-
## TODO: Use a cache here to avoid allocations
229-
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
230-
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
231-
c = similar(g)
232-
mul!(c, H.JᵀJ, g)
233-
return dot(g, c)
247+
__lr_mul(cache::AbstractNonlinearSolveCache) = __lr_mul(cache, cache.JᵀJ, cache.Jᵀf)
248+
function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ::KrylovJᵀJ, Jᵀf)
249+
@bb cache.lr_mul_cache = JᵀJ.JᵀJ × vec(Jᵀf)
250+
return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache))
251+
end
252+
function __lr_mul(cache::AbstractNonlinearSolveCache, JᵀJ, Jᵀf)
253+
@bb cache.lr_mul_cache = JᵀJ × Jᵀf
254+
return dot(_vec(Jᵀf), _vec(cache.lr_mul_cache))
234255
end

0 commit comments

Comments
 (0)