@@ -31,35 +31,24 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
31
31
@concrete mutable struct JacobianCache{iip} <: AbstractNonlinearSolveJacobianCache{iip}
32
32
J
33
33
f
34
- uf
35
34
fu
36
35
u
37
36
p
38
- jac_cache
39
- alg
40
37
stats:: NLStats
41
38
autodiff
42
- vjp_autodiff
43
- jvp_autodiff
39
+ di_extras
40
+ sdifft_extras
44
41
end
45
42
46
43
function reinit_cache! (cache:: JacobianCache{iip} , args... ; p = cache. p,
47
44
u0 = cache. u, kwargs... ) where {iip}
48
45
cache. u = u0
49
46
cache. p = p
50
- cache. uf = JacobianWrapper {iip} (cache. f, p)
51
47
end
52
48
53
49
function JacobianCache (prob, alg, f:: F , fu_, u, p; stats, autodiff = nothing ,
54
50
vjp_autodiff = nothing , jvp_autodiff = nothing , linsolve = missing ) where {F}
55
51
iip = isinplace (prob)
56
- uf = JacobianWrapper {iip} (f, p)
57
-
58
- autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
59
- jvp_autodiff = get_concrete_forward_ad (
60
- jvp_autodiff, prob, Val (false ); check_forward_mode = true )
61
- vjp_autodiff = get_concrete_reverse_ad (
62
- vjp_autodiff, prob, Val (false ); check_reverse_mode = false )
63
52
64
53
has_analytic_jac = SciMLBase. has_jac (f)
65
54
linsolve_needs_jac = concrete_jac (alg) === nothing && (linsolve === missing ||
@@ -70,90 +59,128 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
70
59
@bb fu = similar (fu_)
71
60
72
61
if ! has_analytic_jac && needs_jac
73
- sd = __sparsity_detection_alg (f, autodiff)
74
- jac_cache = iip ? sparse_jacobian_cache (autodiff, sd, uf, fu, u) :
75
- sparse_jacobian_cache (
76
- autodiff, sd, uf, __maybe_mutable (u, autodiff); fx = fu)
62
+ autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
63
+ sd = sparsity_detection_alg (f, autodiff)
64
+ sparse_jac = ! (sd isa NoSparsityDetection)
65
+ # Eventually we want to do everything via DI. But for now, we just do the dense via DI
66
+ if sparse_jac
67
+ di_extras = nothing
68
+ uf = JacobianWrapper {iip} (f, p)
69
+ sdifft_extras = if iip
70
+ sparse_jacobian_cache (autodiff, sd, uf, fu, u)
71
+ else
72
+ sparse_jacobian_cache (
73
+ autodiff, sd, uf, __maybe_mutable (u, autodiff); fx = fu)
74
+ end
75
+ else
76
+ sdifft_extras = nothing
77
+ di_extras = if iip
78
+ DI. prepare_jacobian (f, fu, autodiff, u, Constant (p))
79
+ else
80
+ DI. prepare_jacobian (f, autodiff, u, Constant (p))
81
+ end
82
+ end
77
83
else
78
- jac_cache = nothing
84
+ sparse_jac = false
85
+ di_extras = nothing
86
+ sdifft_extras = nothing
79
87
end
80
88
81
89
J = if ! needs_jac
90
+ jvp_autodiff = get_concrete_forward_ad (
91
+ jvp_autodiff, prob, Val (false ); check_forward_mode = true )
92
+ vjp_autodiff = get_concrete_reverse_ad (
93
+ vjp_autodiff, prob, Val (false ); check_reverse_mode = false )
82
94
JacobianOperator (prob, fu, u; jvp_autodiff, vjp_autodiff)
83
95
else
84
- if has_analytic_jac
85
- f . jac_prototype === nothing ?
86
- __similar (fu, promote_type (eltype (fu), eltype (u)), length (fu), length (u)) :
87
- copy (f . jac_prototype)
88
- elseif f . jac_prototype === nothing
89
- zero ( init_jacobian (jac_cache; preserve_immutable = Val ( true )))
96
+ if f . jac_prototype === nothing
97
+ if ! sparse_jac
98
+ __similar (fu, promote_type (eltype (fu), eltype (u)), length (fu), length (u))
99
+ else
100
+ zero ( init_jacobian (sdifft_extras; preserve_immutable = Val ( true )))
101
+ end
90
102
else
91
- f. jac_prototype
103
+ similar ( f. jac_prototype)
92
104
end
93
105
end
94
106
95
107
return JacobianCache {iip} (
96
- J, f, uf, fu, u, p, jac_cache, alg, stats, autodiff, vjp_autodiff, jvp_autodiff )
108
+ J, f, fu, u, p, stats, autodiff, di_extras, sdifft_extras )
97
109
end
98
110
99
111
function JacobianCache (prob, alg, f:: F , :: Number , u:: Number , p; stats,
100
112
autodiff = nothing , kwargs... ) where {F}
101
- uf = JacobianWrapper {false} (f, p)
102
- autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
103
- if ! (autodiff isa AutoForwardDiff ||
104
- autodiff isa AutoPolyesterForwardDiff ||
105
- autodiff isa AutoFiniteDiff)
106
- # Other cases are not properly supported so we fallback to finite differencing
107
- @warn " Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
108
- Detected $(autodiff) . Falling back to AutoFiniteDiff."
109
- autodiff = AutoFiniteDiff ()
113
+ fu = f (u, p)
114
+ if SciMLBase. has_jac (f) || SciMLBase. has_vjp (f) || SciMLBase. has_jvp (f)
115
+ return JacobianCache {false} (u, f, fu, u, p, stats, autodiff, nothing )
110
116
end
111
- return JacobianCache {false} (
112
- u, f, uf, u, u, p, nothing , alg, stats, autodiff, nothing , nothing )
117
+ autodiff = get_concrete_forward_ad (autodiff, prob; check_forward_mode = false )
118
+ di_extras = DI. prepare_derivative (f, autodiff, u, Constant (prob. p))
119
+ return JacobianCache {false} (u, f, fu, u, p, stats, autodiff, di_extras, nothing )
113
120
end
114
121
115
- @inline (cache:: JacobianCache )(u = cache. u) = cache (cache. J, u, cache. p)
116
- @inline function (cache:: JacobianCache )(:: Nothing )
122
+ (cache:: JacobianCache )(u = cache. u) = cache (cache. J, u, cache. p)
123
+ function (cache:: JacobianCache )(:: Nothing )
117
124
cache. J isa JacobianOperator &&
118
125
return StatefulJacobianOperator (cache. J, cache. u, cache. p)
119
126
return cache. J
120
127
end
121
128
129
+ # Operator
122
130
function (cache:: JacobianCache )(J:: JacobianOperator , u, p = cache. p)
123
131
return StatefulJacobianOperator (J, u, p)
124
132
end
133
+ # Numbers
125
134
function (cache:: JacobianCache )(:: Number , u, p = cache. p) # Scalar
126
135
cache. stats. njacs += 1
127
- J = last (__value_derivative (cache. autodiff, cache. uf, u))
128
- return J
136
+ if SciMLBase. has_jac (cache. f)
137
+ return cache. f. jac (u, p)
138
+ elseif SciMLBase. has_vjp (cache. f)
139
+ return cache. f. vjp (one (u), u, p)
140
+ elseif SciMLBase. has_jvp (cache. f)
141
+ return cache. f. jvp (one (u), u, p)
142
+ end
143
+ return DI. derivative (cache. f, cache. di_extras, cache. autodiff, u, Constant (p))
129
144
end
130
- # Compute the Jacobian
145
+ # Actually Compute the Jacobian
131
146
function (cache:: JacobianCache{iip} )(
132
147
J:: Union{AbstractMatrix, Nothing} , u, p = cache. p) where {iip}
133
148
cache. stats. njacs += 1
134
149
if iip
135
- if has_jac (cache. f)
150
+ if SciMLBase . has_jac (cache. f)
136
151
cache. f. jac (J, u, p)
152
+ elseif cache. di_extras != = nothing
153
+ DI. jacobian! (
154
+ cache. f, cache. fu, J, cache. di_extras, cache. autodiff, u, Constant (p))
137
155
else
138
- sparse_jacobian! (J, cache. autodiff, cache. jac_cache, cache. uf, cache. fu, u)
156
+ uf = JacobianWrapper {iip} (cache. f, p)
157
+ sparse_jacobian! (J, cache. autodiff, cache. jac_cache, uf, cache. fu, u)
139
158
end
140
- J_ = J
159
+ return J
141
160
else
142
- J_ = if has_jac (cache. f)
143
- cache. f. jac (u, p)
144
- elseif __can_setindex (typeof (J))
145
- sparse_jacobian! (J, cache. autodiff, cache. jac_cache, cache. uf, u)
146
- J
161
+ if SciMLBase. has_jac (cache. f)
162
+ return cache. f. jac (u, p)
163
+ elseif cache. di_extras != = nothing
164
+ return DI. jacobian (cache. f, cache. di_extras, cache. autodiff, u, Constant (p))
147
165
else
148
- sparse_jacobian (cache. autodiff, cache. jac_cache, cache. uf, u)
166
+ uf = JacobianWrapper {iip} (cache. f, p)
167
+ if __can_setindex (typeof (J))
168
+ sparse_jacobian! (J, cache. autodiff, cache. sdifft_extras, uf, u)
169
+ return J
170
+ else
171
+ return sparse_jacobian (cache. autodiff, cache. sdifft_extras, uf, u)
172
+ end
149
173
end
150
174
end
151
- return J_
152
175
end
153
176
154
- # Sparsity Detection Choices
155
- @inline __sparsity_detection_alg (_, _) = NoSparsityDetection ()
156
- @inline function __sparsity_detection_alg (f:: NonlinearFunction , ad:: AutoSparse )
177
+ function sparsity_detection_alg (f:: NonlinearFunction , ad:: AbstractADType )
178
+ # TODO : Also handle case where colorvec is provided
179
+ f. sparsity === nothing && return NoSparsityDetection ()
180
+ return sparsity_detection_alg (f, AutoSparse (ad; sparsity_detector = f. sparsity))
181
+ end
182
+
183
+ function sparsity_detection_alg (f:: NonlinearFunction , ad:: AutoSparse )
157
184
if f. sparsity === nothing
158
185
if f. jac_prototype === nothing
159
186
is_extension_loaded (Val (:Symbolics )) && return SymbolicsSparsityDetection ()
177
204
end
178
205
179
206
if SciMLBase. has_colorvec (f)
180
- return PrecomputedJacobianColorvec (; jac_prototype,
181
- f. colorvec,
182
- partition_by_rows = (ad isa AutoSparse &&
183
- ADTypes. mode (ad) isa ADTypes. ReverseMode))
207
+ return PrecomputedJacobianColorvec (; jac_prototype, f. colorvec,
208
+ partition_by_rows = ADTypes. mode (ad) isa ADTypes. ReverseMode)
184
209
else
185
210
return JacPrototypeSparsityDetection (; jac_prototype)
186
211
end
187
212
end
188
-
189
- @inline function __value_derivative (
190
- :: Union{AutoForwardDiff, AutoPolyesterForwardDiff} , f:: F , x:: R ) where {F, R}
191
- T = typeof (ForwardDiff. Tag (f, R))
192
- out = f (ForwardDiff. Dual {T} (x, one (x)))
193
- return ForwardDiff. value (out), ForwardDiff. extract_derivative (T, out)
194
- end
195
-
196
- @inline function __value_derivative (ad:: AutoFiniteDiff , f:: F , x:: R ) where {F, R}
197
- return f (x), FiniteDiff. finite_difference_derivative (f, x, ad. fdtype)
198
- end
199
-
200
- @inline function __scalar_jacvec (f:: F , x:: R , v:: V ) where {F, R, V}
201
- T = typeof (ForwardDiff. Tag (f, R))
202
- out = f (ForwardDiff. Dual {T} (x, v))
203
- return ForwardDiff. value (out), ForwardDiff. extract_derivative (T, out)
204
- end
0 commit comments