Skip to content

Commit a8e60d5

Browse files
committed
feat: use DI for dense jacobians
fix: stop storing extra stuff in JacobianCache; feat: use DI for dense jacobians; refactor: remove alg from the cache; fix: don't ignore sparsity for dense_ad
1 parent 05aa3db commit a8e60d5

File tree

3 files changed

+93
-80
lines changed

3 files changed

+93
-80
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1112
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
@@ -65,8 +66,9 @@ BandedMatrices = "1.5"
6566
BenchmarkTools = "1.4"
6667
CUDA = "5.2"
6768
ConcreteStructs = "0.2.3"
69+
DifferentiationInterface = "0.6.1"
6870
DiffEqBase = "6.149.0"
69-
Enzyme = "0.12"
71+
Enzyme = "0.12, 0.13"
7072
ExplicitImports = "1.5"
7173
FastBroadcast = "0.2.8, 0.3"
7274
FastClosures = "0.3.2"

src/NonlinearSolve.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ end
77
using Reexport: @reexport
88
using PrecompileTools: @compile_workload, @setup_workload
99

10-
using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff,
11-
AutoZygote, AutoEnzyme, AutoSparse
10+
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
11+
AutoPolyesterForwardDiff, AutoZygote, AutoEnzyme, AutoSparse
1212
# FIXME: deprecated, remove in future
1313
using ADTypes: AutoSparseFiniteDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff,
1414
AutoSparseZygote
@@ -22,6 +22,7 @@ using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
2222
NormTerminationMode, RelNormTerminationMode, RelSafeBestTerminationMode,
2323
RelSafeTerminationMode, RelTerminationMode,
2424
SimpleNonlinearSolveTerminationMode, SteadyStateDiffEqTerminationMode
25+
using DifferentiationInterface: DifferentiationInterface, Constant
2526
using FastBroadcast: @..
2627
using FastClosures: @closure
2728
using FiniteDiff: FiniteDiff
@@ -39,7 +40,7 @@ using Printf: @printf
3940
using Preferences: Preferences, @load_preference, @set_preferences!
4041
using RecursiveArrayTools: recursivecopy!, recursivefill!
4142
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
42-
AbstractSciMLOperator, _unwrap_val, has_jac, isinplace, NLStats
43+
AbstractSciMLOperator, _unwrap_val, isinplace, NLStats
4344
using SciMLJacobianOperators: AbstractJacobianOperator, JacobianOperator, VecJacOperator,
4445
JacVecOperator, StatefulJacobianOperator
4546
using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC
@@ -55,6 +56,8 @@ using SymbolicIndexingInterface: SymbolicIndexingInterface, ParameterIndexingPro
5556

5657
@reexport using SciMLBase, SimpleNonlinearSolve
5758

59+
const DI = DifferentiationInterface
60+
5861
# Type-Inference Friendly Check for Extension Loading
5962
is_extension_loaded(::Val) = false
6063

src/internal/jacobian.jl

+84-76
Original file line numberDiff line numberDiff line change
@@ -31,35 +31,24 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
3131
@concrete mutable struct JacobianCache{iip} <: AbstractNonlinearSolveJacobianCache{iip}
3232
J
3333
f
34-
uf
3534
fu
3635
u
3736
p
38-
jac_cache
39-
alg
4037
stats::NLStats
4138
autodiff
42-
vjp_autodiff
43-
jvp_autodiff
39+
di_extras
40+
sdifft_extras
4441
end
4542

4643
function reinit_cache!(cache::JacobianCache{iip}, args...; p = cache.p,
4744
u0 = cache.u, kwargs...) where {iip}
4845
cache.u = u0
4946
cache.p = p
50-
cache.uf = JacobianWrapper{iip}(cache.f, p)
5147
end
5248

5349
function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
5450
vjp_autodiff = nothing, jvp_autodiff = nothing, linsolve = missing) where {F}
5551
iip = isinplace(prob)
56-
uf = JacobianWrapper{iip}(f, p)
57-
58-
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
59-
jvp_autodiff = get_concrete_forward_ad(
60-
jvp_autodiff, prob, Val(false); check_forward_mode = true)
61-
vjp_autodiff = get_concrete_reverse_ad(
62-
vjp_autodiff, prob, Val(false); check_reverse_mode = false)
6352

6453
has_analytic_jac = SciMLBase.has_jac(f)
6554
linsolve_needs_jac = concrete_jac(alg) === nothing && (linsolve === missing ||
@@ -70,90 +59,128 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
7059
@bb fu = similar(fu_)
7160

7261
if !has_analytic_jac && needs_jac
73-
sd = __sparsity_detection_alg(f, autodiff)
74-
jac_cache = iip ? sparse_jacobian_cache(autodiff, sd, uf, fu, u) :
75-
sparse_jacobian_cache(
76-
autodiff, sd, uf, __maybe_mutable(u, autodiff); fx = fu)
62+
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
63+
sd = sparsity_detection_alg(f, autodiff)
64+
sparse_jac = !(sd isa NoSparsityDetection)
65+
# Eventually we want to do everything via DI. But for now, we just do the dense via DI
66+
if sparse_jac
67+
di_extras = nothing
68+
uf = JacobianWrapper{iip}(f, p)
69+
sdifft_extras = if iip
70+
sparse_jacobian_cache(autodiff, sd, uf, fu, u)
71+
else
72+
sparse_jacobian_cache(
73+
autodiff, sd, uf, __maybe_mutable(u, autodiff); fx = fu)
74+
end
75+
else
76+
sdifft_extras = nothing
77+
di_extras = if iip
78+
DI.prepare_jacobian(f, fu, autodiff, u, Constant(p))
79+
else
80+
DI.prepare_jacobian(f, autodiff, u, Constant(p))
81+
end
82+
end
7783
else
78-
jac_cache = nothing
84+
sparse_jac = false
85+
di_extras = nothing
86+
sdifft_extras = nothing
7987
end
8088

8189
J = if !needs_jac
90+
jvp_autodiff = get_concrete_forward_ad(
91+
jvp_autodiff, prob, Val(false); check_forward_mode = true)
92+
vjp_autodiff = get_concrete_reverse_ad(
93+
vjp_autodiff, prob, Val(false); check_reverse_mode = false)
8294
JacobianOperator(prob, fu, u; jvp_autodiff, vjp_autodiff)
8395
else
84-
if has_analytic_jac
85-
f.jac_prototype === nothing ?
86-
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
87-
copy(f.jac_prototype)
88-
elseif f.jac_prototype === nothing
89-
zero(init_jacobian(jac_cache; preserve_immutable = Val(true)))
96+
if f.jac_prototype === nothing
97+
if !sparse_jac
98+
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
99+
else
100+
zero(init_jacobian(sdifft_extras; preserve_immutable = Val(true)))
101+
end
90102
else
91-
f.jac_prototype
103+
similar(f.jac_prototype)
92104
end
93105
end
94106

95107
return JacobianCache{iip}(
96-
J, f, uf, fu, u, p, jac_cache, alg, stats, autodiff, vjp_autodiff, jvp_autodiff)
108+
J, f, fu, u, p, stats, autodiff, di_extras, sdifft_extras)
97109
end
98110

99111
function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; stats,
100112
autodiff = nothing, kwargs...) where {F}
101-
uf = JacobianWrapper{false}(f, p)
102-
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
103-
if !(autodiff isa AutoForwardDiff ||
104-
autodiff isa AutoPolyesterForwardDiff ||
105-
autodiff isa AutoFiniteDiff)
106-
# Other cases are not properly supported so we fallback to finite differencing
107-
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
108-
Detected $(autodiff). Falling back to AutoFiniteDiff."
109-
autodiff = AutoFiniteDiff()
113+
fu = f(u, p)
114+
if SciMLBase.has_jac(f) || SciMLBase.has_vjp(f) || SciMLBase.has_jvp(f)
115+
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, nothing, nothing) # di_extras, sdifft_extras both unused here
110116
end
111-
return JacobianCache{false}(
112-
u, f, uf, u, u, p, nothing, alg, stats, autodiff, nothing, nothing)
117+
autodiff = get_concrete_forward_ad(autodiff, prob; check_forward_mode = false)
118+
di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
119+
return JacobianCache{false}(u, f, fu, u, p, stats, autodiff, di_extras, nothing)
113120
end
114121

115-
@inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
116-
@inline function (cache::JacobianCache)(::Nothing)
122+
(cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
123+
function (cache::JacobianCache)(::Nothing)
117124
cache.J isa JacobianOperator &&
118125
return StatefulJacobianOperator(cache.J, cache.u, cache.p)
119126
return cache.J
120127
end
121128

129+
# Operator
122130
function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p)
123131
return StatefulJacobianOperator(J, u, p)
124132
end
133+
# Numbers
125134
function (cache::JacobianCache)(::Number, u, p = cache.p) # Scalar
126135
cache.stats.njacs += 1
127-
J = last(__value_derivative(cache.autodiff, cache.uf, u))
128-
return J
136+
if SciMLBase.has_jac(cache.f)
137+
return cache.f.jac(u, p)
138+
elseif SciMLBase.has_vjp(cache.f)
139+
return cache.f.vjp(one(u), u, p)
140+
elseif SciMLBase.has_jvp(cache.f)
141+
return cache.f.jvp(one(u), u, p)
142+
end
143+
return DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
129144
end
130-
# Compute the Jacobian
145+
# Actually Compute the Jacobian
131146
function (cache::JacobianCache{iip})(
132147
J::Union{AbstractMatrix, Nothing}, u, p = cache.p) where {iip}
133148
cache.stats.njacs += 1
134149
if iip
135-
if has_jac(cache.f)
150+
if SciMLBase.has_jac(cache.f)
136151
cache.f.jac(J, u, p)
152+
elseif cache.di_extras !== nothing
153+
DI.jacobian!(
154+
cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p))
137155
else
138-
sparse_jacobian!(J, cache.autodiff, cache.jac_cache, cache.uf, cache.fu, u)
156+
uf = JacobianWrapper{iip}(cache.f, p)
157+
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, cache.fu, u) # sparse cache is stored in `sdifft_extras`
139158
end
140-
J_ = J
159+
return J
141160
else
142-
J_ = if has_jac(cache.f)
143-
cache.f.jac(u, p)
144-
elseif __can_setindex(typeof(J))
145-
sparse_jacobian!(J, cache.autodiff, cache.jac_cache, cache.uf, u)
146-
J
161+
if SciMLBase.has_jac(cache.f)
162+
return cache.f.jac(u, p)
163+
elseif cache.di_extras !== nothing
164+
return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p))
147165
else
148-
sparse_jacobian(cache.autodiff, cache.jac_cache, cache.uf, u)
166+
uf = JacobianWrapper{iip}(cache.f, p)
167+
if __can_setindex(typeof(J))
168+
sparse_jacobian!(J, cache.autodiff, cache.sdifft_extras, uf, u)
169+
return J
170+
else
171+
return sparse_jacobian(cache.autodiff, cache.sdifft_extras, uf, u)
172+
end
149173
end
150174
end
151-
return J_
152175
end
153176

154-
# Sparsity Detection Choices
155-
@inline __sparsity_detection_alg(_, _) = NoSparsityDetection()
156-
@inline function __sparsity_detection_alg(f::NonlinearFunction, ad::AutoSparse)
177+
function sparsity_detection_alg(f::NonlinearFunction, ad::AbstractADType)
178+
# TODO: Also handle case where colorvec is provided
179+
f.sparsity === nothing && return NoSparsityDetection()
180+
return sparsity_detection_alg(f, AutoSparse(ad; sparsity_detector = f.sparsity))
181+
end
182+
183+
function sparsity_detection_alg(f::NonlinearFunction, ad::AutoSparse)
157184
if f.sparsity === nothing
158185
if f.jac_prototype === nothing
159186
is_extension_loaded(Val(:Symbolics)) && return SymbolicsSparsityDetection()
@@ -177,28 +204,9 @@ end
177204
end
178205

179206
if SciMLBase.has_colorvec(f)
180-
return PrecomputedJacobianColorvec(; jac_prototype,
181-
f.colorvec,
182-
partition_by_rows = (ad isa AutoSparse &&
183-
ADTypes.mode(ad) isa ADTypes.ReverseMode))
207+
return PrecomputedJacobianColorvec(; jac_prototype, f.colorvec,
208+
partition_by_rows = ADTypes.mode(ad) isa ADTypes.ReverseMode)
184209
else
185210
return JacPrototypeSparsityDetection(; jac_prototype)
186211
end
187212
end
188-
189-
@inline function __value_derivative(
190-
::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, f::F, x::R) where {F, R}
191-
T = typeof(ForwardDiff.Tag(f, R))
192-
out = f(ForwardDiff.Dual{T}(x, one(x)))
193-
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
194-
end
195-
196-
@inline function __value_derivative(ad::AutoFiniteDiff, f::F, x::R) where {F, R}
197-
return f(x), FiniteDiff.finite_difference_derivative(f, x, ad.fdtype)
198-
end
199-
200-
@inline function __scalar_jacvec(f::F, x::R, v::V) where {F, R, V}
201-
T = typeof(ForwardDiff.Tag(f, R))
202-
out = f(ForwardDiff.Dual{T}(x, v))
203-
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
204-
end

0 commit comments

Comments
 (0)