Skip to content

Commit 46912f2

Browse files
Merge pull request #282 from avik-pal/ap/krylov
Jacobian-Free Krylov Versions for TR/LM/GN
2 parents 0026bc1 + bcfcc16 commit 46912f2

15 files changed

+289
-97
lines changed

.JuliaFormatter.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
style = "sciml"
22
format_markdown = true
33
annotate_untyped_fields_with_any = false
4+
format_docstrings = true

Project.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "2.8.2"
4+
version = "2.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3030
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3131
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3232
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
33+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[extensions]
3536
NonlinearSolveBandedMatricesExt = "BandedMatrices"
3637
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
3738
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
39+
NonlinearSolveZygoteExt = "Zygote"
3840

3941
[compat]
4042
ADTypes = "0.2"
@@ -50,7 +52,7 @@ FiniteDiff = "2"
5052
ForwardDiff = "0.10.3"
5153
LeastSquaresOptim = "0.8"
5254
LineSearches = "7"
53-
LinearAlgebra = "1.9"
55+
LinearAlgebra = "<0.0.1, 1"
5456
LinearSolve = "2.12"
5557
NonlinearProblemLibrary = "0.1"
5658
PrecompileTools = "1"
@@ -59,7 +61,7 @@ Reexport = "0.2, 1"
5961
SciMLBase = "2.8.2"
6062
SciMLOperators = "0.3"
6163
SimpleNonlinearSolve = "0.1.23"
62-
SparseArrays = "1.9"
64+
SparseArrays = "<0.0.1, 1"
6365
SparseDiffTools = "2.12"
6466
StaticArraysCore = "1.4"
6567
UnPack = "1.0"

ext/NonlinearSolveZygoteExt.jl

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Package extension for `NonlinearSolve`, registered under `[extensions]` in Project.toml
# with `Zygote` as its trigger package.
module NonlinearSolveZygoteExt
2+
3+
import NonlinearSolve, Zygote
4+
5+
# Specializes the `Val`-based extension check for `Val{:Zygote}` so that it reports `true`
# (the generic `is_extension_loaded(::Val)` method in `src/NonlinearSolve.jl` returns
# `false`).
NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true
6+
7+
end

src/NonlinearSolve.jl

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ import DiffEqBase: AbstractNonlinearTerminationMode,
3838
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
3939
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
4040

41+
# Type-Inference Friendly Check for Extension Loading
42+
is_extension_loaded(::Val) = false
43+
4144
abstract type AbstractNonlinearSolveLineSearchAlgorithm end
4245

4346
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

src/extension_algs.jl

+13-8
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`.
88
99
## Arguments:
1010
11-
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
12-
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
13-
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
14-
on the Jacobian structure.
15-
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
11+
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
12+
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`,
13+
then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian
14+
structure.
15+
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or
16+
`:forward`.
1617
1718
!!! note
19+
1820
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
1921
"""
2022
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
@@ -36,21 +38,24 @@ end
3638
"""
3739
FastLevenbergMarquardtJL(linsolve = :cholesky)
3840
39-
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
40-
`NonlinearLeastSquaresProblem`.
41+
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
42+
for solving `NonlinearLeastSquaresProblem`.
4143
4244
!!! warning
45+
4346
This is not really the fastest solver. It is called that since the original package
4447
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
4548
4649
!!! warning
50+
4751
This algorithm requires the jacobian function to be provided!
4852
4953
## Arguments:
5054
51-
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
55+
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
5256
5357
!!! note
58+
5459
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
5560
"""
5661
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm

src/gaussnewton.jl

+14-16
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp
66
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
77
for large-scale and numerically-difficult nonlinear least squares problems.
88
9-
!!! note
10-
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
11-
is a more general extension of `Gauss-Newton` Method.
12-
139
### Keyword Arguments
1410
1511
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -33,28 +29,30 @@ for large-scale and numerically-difficult nonlinear least squares problems.
3329
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
3430
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
3531
used here directly, and they will be converted to the correct `LineSearch`.
36-
37-
!!! warning
38-
39-
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
40-
construction. This will be fixed in the near future.
32+
- `vjp_autodiff`: Automatic Differentiation Backend used for vector-jacobian products.
33+
This is applicable if the linear solver doesn't require a concrete jacobian, e.g.,
34+
Krylov Methods. Defaults to `nothing`, which means if the problem is out of place and
35+
`Zygote` is loaded, then we use `AutoZygote`. In all other cases, `FiniteDiff` is used.
4136
"""
4237
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
4338
ad::AD
4439
linsolve
4540
precs
4641
linesearch
42+
vjp_autodiff
4743
end
4844

4945
function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
50-
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
46+
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch, alg.vjp_autodiff)
5147
end
5248

5349
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
54-
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
50+
linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
51+
adkwargs...)
5552
ad = default_adargs_to_adtype(; adkwargs...)
5653
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
57-
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
54+
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch,
55+
vjp_autodiff)
5856
end
5957

6058
@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -122,8 +120,8 @@ function perform_step!(cache::GaussNewtonCache{true})
122120
jacobian!!(J, cache)
123121

124122
if JᵀJ !== nothing
125-
__matmul!(JᵀJ, J', J)
126-
__matmul!(Jᵀf, J', fu1)
123+
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
124+
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)
127125
end
128126

129127
# u = u - JᵀJ \ Jᵀfu
@@ -160,8 +158,8 @@ function perform_step!(cache::GaussNewtonCache{false})
160158
cache.J = jacobian!!(cache.J, cache)
161159

162160
if cache.JᵀJ !== nothing
163-
cache.JᵀJ = cache.J' * cache.J
164-
cache.Jᵀf = cache.J' * fu1
161+
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
162+
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)
165163
end
166164

167165
# u = u - J \ fu

src/jacobian.jl

+97-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# Container used for the Jacobian-free (Krylov) code path. `__init_JᵀJ` for a
# `FunctionOperator` Jacobian fills it with:
#   - `JᵀJ`: the cached operator built from `Jᵀ * J`
#   - `Jᵀ`:  the `VecJac` operator used to form `Jᵀ * fu`
@concrete struct KrylovJᵀJ
2+
JᵀJ
3+
Jᵀ
4+
end
5+
6+
# In-placeness of the wrapper is delegated to the wrapped `Jᵀ` operator.
SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
7+
18
sparsity_detection_alg(_, _) = NoSparsityDetection()
29
function sparsity_detection_alg(f, ad::AbstractSparseADType)
310
if f.sparsity === nothing
@@ -54,7 +61,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
5461
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
5562
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
5663
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
57-
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
64+
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
5865
sd = sparsity_detection_alg(f, alg.ad)
5966
ad = alg.ad
6067
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
@@ -63,12 +70,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
6370
jac_cache = nothing
6471
end
6572

66-
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
67-
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
68-
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
73+
J = if !(linsolve_needs_jac || alg_wants_jac)
6974
if f.jvp === nothing
7075
# We don't need to construct the Jacobian
71-
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
76+
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
7277
else
7378
if iip
7479
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
@@ -92,9 +97,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
9297
du = _mutable_zero(u)
9398

9499
if needsJᵀJ
95-
JᵀJ = __init_JᵀJ(J)
96-
# FIXME: This needs to be handled better for JacVec Operator
97-
Jᵀfu = J' * _vec(fu)
100+
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
101+
vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
102+
jvp_autodiff = __get_nonsparse_ad(alg.ad))
98103
end
99104

100105
if linsolve_init
@@ -120,21 +125,68 @@ function __setup_linsolve(A, b, u, p, alg)
120125
nothing)..., weight)
121126
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
122127
end
128+
__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg)
123129

124130
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
125131
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
126132
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
127133
__get_nonsparse_ad(ad) = ad
128134

129-
__init_JᵀJ(J::Number) = zero(J)
130-
__init_JᵀJ(J::AbstractArray) = J' * J
131-
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
135+
# Build the initial `(JᵀJ, Jᵀfu)` pair for a concrete Jacobian `J`.
# Extra positional and keyword arguments are accepted and ignored so that every method
# can be called with the same argument list as the operator-based method.

# Scalar Jacobian: both entries start out as zero of the Jacobian's type.
function __init_JᵀJ(J::Number, args...; kwargs...)
    return zero(J), zero(J)
end

# Generic array Jacobian: form both products eagerly.
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
    Jt = J'
    return Jt * J, Jt * fu
end

# Static array Jacobian: `JᵀJ` is an uninitialized mutable static buffer of size
# `(size(J, 2), size(J, 2))`; only `Jᵀfu` is computed here.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    ncols = size(J, 2)
    buffer = MArray{Tuple{ncols, ncols}, eltype(J)}(undef)
    return buffer, J' * fu
end
145+
# Jacobian-free initialization of `(JᵀJ, Jᵀfu)` when `J` is a `FunctionOperator`.
#
# Arguments:
#   - `J`:  operator form of the Jacobian
#   - `fu`: residual vector; used as the `VecJac` prototype and to form `Jᵀ * fu`
#   - `uf`, `u`: function wrapper and current state handed to `VecJac`
#
# Keyword arguments:
#   - `f`: the user's function object; only inspected to detect a user-supplied `vjp`
#   - `vjp_autodiff`, `jvp_autodiff`: backends forwarded to `__concrete_vjp_autodiff` to
#     pick the AD backend for the vector-Jacobian product
#
# Returns `(KrylovJᵀJ(cached Jᵀ * J operator, Jᵀ), Jᵀ * fu)`.
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
        vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
    # FIXME: Proper fix to this requires the FunctionOperator patch
    if f !== nothing && f.vjp !== nothing
        # The quantity being ignored is the user-provided vector-Jacobian product (`vjp`),
        # which is what the condition above checks.
        @warn "Currently we don't make use of user provided `vjp`. This is planned to be \
               fixed in the near future."
    end
    autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    Jᵀ = VecJac(uf, u; fu, autodiff)
    JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
    JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
    Jᵀfu = Jᵀ * fu
    return JᵀJ, Jᵀfu
end
159+
160+
# Resolve the AD backend used for vector-Jacobian products (`VecJac`).
#
#   - `vjp_autodiff`: user-requested backend, or `nothing` to pick automatically
#   - `jvp_autodiff`: backend already chosen for the Jacobian / JVP computation
#   - `uf`:           function wrapper; `isinplace(uf)` decides which backends are usable
#
# Automatic choice (`vjp_autodiff === nothing`):
#   - in-place problems always use `AutoFiniteDiff()`
#   - out-of-place problems reuse `jvp_autodiff` if it is `AutoFiniteDiff`, otherwise use
#     `AutoZygote()` when the Zygote extension is loaded, otherwise `AutoFiniteDiff()`
#
# Explicit choice: the (non-sparse) requested backend is returned, except that
# `AutoZygote` on an in-place problem falls back to `AutoFiniteDiff()` with a warning.
function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    if vjp_autodiff === nothing
        if isinplace(uf)
            # VecJac can be only FiniteDiff
            return AutoFiniteDiff()
        else
            # Short circuit if we see that FiniteDiff was used for J computation
            jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
            # Check if Zygote is loaded then use Zygote else use FiniteDiff
            is_extension_loaded(Val{:Zygote}()) && return AutoZygote()
            return AutoFiniteDiff()
        end
    else
        ad = __get_nonsparse_ad(vjp_autodiff)
        if isinplace(uf) && ad isa AutoZygote
            # This backend is selected for vector-Jacobian products, not for line search.
            @warn "Attempting to use Zygote.jl for vector-jacobian products on an \
                   in-place problem. Falling back to finite differencing."
            return AutoFiniteDiff()
        end
        return ad
    end
end
132182

133183
__maybe_symmetric(x) = Symmetric(x)
134184
__maybe_symmetric(x::Number) = x
135185
# LinearSolve with `nothing` doesn't dispatch correctly here
136186
__maybe_symmetric(x::StaticArray) = x
137187
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
188+
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
189+
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ
138190

139191
## Special Handling for Scalars
140192
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
@@ -145,3 +197,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
145197
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
146198
return uf, nothing, u, nothing, nothing, u
147199
end
200+
201+
# Refresh the `JᵀJ` stored in the field `sym` of `cache` from the current Jacobian `J`.
# The 4-argument entry point reads the stored value and re-dispatches on its type.
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    stored = getproperty(cache, sym)
    return __update_JᵀJ!(iip, cache, sym, stored, J)
end

# Out-of-place: compute a fresh product and store it back in the cache.
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J)
    return setproperty!(cache, sym, J' * J)
end

# In-place: overwrite the stored buffer.
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J)
    return mul!(getproperty(cache, sym), J', J)
end

# Krylov wrapper: nothing is recomputed here; the wrapper is returned unchanged.
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
208+
209+
# Refresh `Jᵀ * fu` stored in the field `sym1` of `cache`. The field `sym2` holds the
# matching `JᵀJ` object; its type selects between the concrete-Jacobian path (uses `J'`)
# and the Krylov path (uses the stored `Jᵀ` operator).
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
    H = getproperty(cache, sym2)
    return __update_Jᵀf!(iip, cache, sym1, sym2, H, J, fu)
end

# Out-of-place, concrete Jacobian: compute `J' * fu`, restructure it against the
# currently stored value, and store the result.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    previous = getproperty(cache, sym1)
    updated = _restructure(previous, J' * fu)
    return setproperty!(cache, sym1, updated)
end

# In-place, concrete Jacobian: write `J' * fu` into the vectorized stored buffer.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    dest = _vec(getproperty(cache, sym1))
    return mul!(dest, J', fu)
end

# Out-of-place, Krylov: same as above but the product uses the `Jᵀ` operator held by `H`.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J,
        fu)
    previous = getproperty(cache, sym1)
    updated = _restructure(previous, H.Jᵀ * fu)
    return setproperty!(cache, sym1, updated)
end

# In-place, Krylov: apply the `Jᵀ` operator held by `H` into the stored buffer.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J,
        fu)
    dest = _vec(getproperty(cache, sym1))
    return mul!(dest, H.Jᵀ, fu)
end
224+
225+
# Left-Right Multiplication
226+
# Quadratic form `gᵀ H g`. The leading `Val` carries the in-place flag.
function __lr_mul(::Val, H, g)
    return dot(g, H, g)
end

## TODO: Use a cache here to avoid allocations
# Krylov wrapper, out-of-place: use the wrapped `JᵀJ` operator directly.
function __lr_mul(::Val{false}, H::KrylovJᵀJ, g)
    return dot(g, H.JᵀJ, g)
end

# Krylov wrapper, in-place: apply the operator into a temporary, then take the dot product.
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    Hg = similar(g)
    mul!(Hg, H.JᵀJ, g)
    return dot(g, Hg)
end

src/levenberg.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
106106
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
107107
adkwargs...)
108108
ad = default_adargs_to_adtype(; adkwargs...)
109-
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
109+
_concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac)
110+
return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs,
110111
damping_initial, damping_increase_factor, damping_decrease_factor,
111112
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
112113
end
@@ -365,9 +366,10 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
365366
if linsolve === nothing
366367
cache.v = -cache.mat_tmp \ (J' * fu1)
367368
else
368-
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
369+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
369370
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
370371
cache.linsolve = linres.cache
372+
cache.v .*= -1
371373
end
372374
end
373375

@@ -383,9 +385,11 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
383385
if linsolve === nothing
384386
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)
385387
else
386-
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),
387-
linu = _vec(cache.a), p, reltol = cache.abstol)
388+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
389+
b = _mutable(_vec(J' * rhs_term)), linu = _vec(cache.a), p,
390+
reltol = cache.abstol, reuse_A_if_factorization = true)
388391
cache.linsolve = linres.cache
392+
cache.a .*= -1
389393
end
390394
end
391395
cache.stats.nsolve += 1

0 commit comments

Comments
 (0)