1
# Container built by the `FunctionOperator` method of `__init_JᵀJ`; it bundles the
# operators needed when `JᵀJ` is never materialised as a matrix.
@concrete struct KrylovJᵀJ
    # Cached `Jᵀ * J` operator (unwrapped by `__setup_linsolve` / `__maybe_symmetric`).
    JᵀJ
    # The `VecJac` operator; used on its own to form `Jᵀ * fu`.
    Jᵀ
end
5
+
6
# A `KrylovJᵀJ` reports the in-placeness of the operator stored in its `Jᵀ` field.
function SciMLBase.isinplace(JᵀJ::KrylovJᵀJ)
    return isinplace(JᵀJ.Jᵀ)
end
7
+
1
8
# Fallback method: arguments not matched by a more specific method get
# `NoSparsityDetection()`.
function sparsity_detection_alg(_, _)
    return NoSparsityDetection()
end
2
9
function sparsity_detection_alg (f, ad:: AbstractSparseADType )
3
10
if f. sparsity === nothing
@@ -54,7 +61,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
54
61
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
55
62
fu = f. resid_prototype === nothing ? (iip ? _mutable_zero (u) : _mutable (f (u, p))) :
56
63
(iip ? deepcopy (f. resid_prototype) : f. resid_prototype)
57
- if ! has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ )
64
+ if ! has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
58
65
sd = sparsity_detection_alg (f, alg. ad)
59
66
ad = alg. ad
60
67
jac_cache = iip ? sparse_jacobian_cache (ad, sd, uf, fu, _maybe_mutable (u, ad)) :
@@ -63,12 +70,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
63
70
jac_cache = nothing
64
71
end
65
72
66
- # FIXME : To properly support needsJᵀJ without Jacobian, we need to implement
67
- # a reverse diff operation with the seed being `Jx`, this is not yet implemented
68
- J = if ! (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
73
+ J = if ! (linsolve_needs_jac || alg_wants_jac)
69
74
if f. jvp === nothing
70
75
# We don't need to construct the Jacobian
71
- JacVec (uf, u; autodiff = __get_nonsparse_ad (alg. ad))
76
+ JacVec (uf, u; fu, autodiff = __get_nonsparse_ad (alg. ad))
72
77
else
73
78
if iip
74
79
jvp = (_, u, v) -> (du = similar (fu); f. jvp (du, v, u, p); du)
@@ -92,9 +97,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
92
97
du = _mutable_zero (u)
93
98
94
99
if needsJᵀJ
95
- JᵀJ = __init_JᵀJ (J)
96
- # FIXME : This needs to be handled better for JacVec Operator
97
- Jᵀfu = J ' * _vec (fu )
100
+ JᵀJ, Jᵀfu = __init_JᵀJ (J, _vec (fu), uf, u; f,
101
+ vjp_autodiff = __get_nonsparse_ad ( _getproperty (alg, Val ( :vjp_autodiff ))),
102
+ jvp_autodiff = __get_nonsparse_ad (alg . ad) )
98
103
end
99
104
100
105
if linsolve_init
@@ -120,21 +125,68 @@ function __setup_linsolve(A, b, u, p, alg)
120
125
nothing )... , weight)
121
126
return init (linprob, alg. linsolve; alias_A = true , alias_b = true , Pl, Pr)
122
127
end
128
# For a `KrylovJᵀJ`, set up the linear solve on the wrapped `JᵀJ` operator; every
# other argument is forwarded unchanged.
function __setup_linsolve(A::KrylovJᵀJ, b, u, p, alg)
    return __setup_linsolve(A.JᵀJ, b, u, p, alg)
end
123
129
124
130
# Translate a sparse AD selector into its non-sparse counterpart. Any other value is
# handed back untouched by the generic fallback.
__get_nonsparse_ad(ad) = ad
function __get_nonsparse_ad(::AutoSparseForwardDiff)
    return AutoForwardDiff()
end
function __get_nonsparse_ad(::AutoSparseFiniteDiff)
    return AutoFiniteDiff()
end
function __get_nonsparse_ad(::AutoSparseZygote)
    return AutoZygote()
end
128
134
129
- __init_JᵀJ (J:: Number ) = zero (J)
130
- __init_JᵀJ (J:: AbstractArray ) = J' * J
131
- __init_JᵀJ (J:: StaticArray ) = MArray {Tuple{size(J, 2), size(J, 2)}, eltype(J)} (undef)
135
# Scalar Jacobian: both returned entries (`JᵀJ` and `Jᵀfu`) are `zero(J)`; any extra
# positional or keyword arguments are accepted and ignored.
function __init_JᵀJ(J::Number, args...; kwargs...)
    return zero(J), zero(J)
end
136
# Dense/array Jacobian: returns the pair `(J' * J, J' * fu)`. Trailing positional and
# keyword arguments are accepted for signature compatibility and ignored.
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
    Jᵀ = J'
    return Jᵀ * J, Jᵀ * fu
end
141
# Static-array Jacobian: `JᵀJ` is only allocated here (an `undef` `MArray` with
# `size(J, 2)` rows and columns), not filled in; `J' * fu` is computed eagerly.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    buffer = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
    Jᵀfu = J' * fu
    return buffer, Jᵀfu
end
145
# Matrix-free Jacobian (`FunctionOperator`): build a `VecJac` operator for `Jᵀ`, compose
# and cache `Jᵀ * J`, and return `(KrylovJᵀJ(JᵀJ_op, Jᵀ), Jᵀ * fu)`.
#
# Keyword arguments:
#   - `f`: if it is not `nothing` and carries a `vjp`, a warning is emitted because
#     that `vjp` is not used here.
#   - `vjp_autodiff`, `jvp_autodiff`: forwarded to `__concrete_vjp_autodiff` to choose
#     the backend handed to `VecJac`.
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
        vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
    # FIXME: Proper fix to this requires the FunctionOperator patch
    if f !== nothing && f.vjp !== nothing
        # The condition inspects `f.vjp`, so the message must name `vjp` (not `jvp`).
        @warn "Currently we don't make use of user provided `vjp`. This is planned to be \
               fixed in the near future."
    end
    autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    Jᵀ = VecJac(uf, u; fu, autodiff)
    JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
    JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
    Jᵀfu = Jᵀ * fu
    return JᵀJ, Jᵀfu
end
159
+
160
# Pick the AD backend used for vector-Jacobian products.
#
# With an explicit `vjp_autodiff`, its non-sparse form is used, except that Zygote on
# an in-place `uf` triggers a warning and falls back to finite differencing.
# With `vjp_autodiff === nothing`, the choice is: finite differencing for in-place
# `uf`; otherwise `jvp_autodiff` when it is already `AutoFiniteDiff`; otherwise Zygote
# if its extension is loaded; otherwise finite differencing.
function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    if vjp_autodiff !== nothing
        backend = __get_nonsparse_ad(vjp_autodiff)
        (isinplace(uf) && backend isa AutoZygote) || return backend
        @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \
               Falling back to finite differencing."
        return AutoFiniteDiff()
    end

    # VecJac can be only FiniteDiff
    isinplace(uf) && return AutoFiniteDiff()
    # Short circuit if we see that FiniteDiff was used for J computation
    jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
    # Check if Zygote is loaded then use Zygote else use FiniteDiff
    return is_extension_loaded(Val{:Zygote}()) ? AutoZygote() : AutoFiniteDiff()
end
132
182
133
183
# Wrap `x` in `Symmetric` by default; the specialised methods below return their
# argument (or, for `KrylovJᵀJ`, its `JᵀJ` field) without wrapping.
function __maybe_symmetric(x)
    return Symmetric(x)
end
function __maybe_symmetric(x::Number)
    return x
end
# LinearSolve with `nothing` doesn't dispatch correctly here
function __maybe_symmetric(x::StaticArray)
    return x
end
function __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix)
    return x
end
function __maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator)
    return x
end
function __maybe_symmetric(x::KrylovJᵀJ)
    return x.JᵀJ
end
138
190
139
191
# # Special Handling for Scalars
140
192
function jacobian_caches (alg:: AbstractNonlinearSolveAlgorithm , f:: F , u:: Number , p,
@@ -145,3 +197,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
145
197
needsJᵀJ && return uf, nothing , u, nothing , nothing , u, u, u
146
198
return uf, nothing , u, nothing , nothing , u
147
199
end
200
+
201
# Refresh the `JᵀJ` stored in `cache` under the property `sym`.
#
# The 4-argument entry point reads the current value of that property and re-dispatches
# on it: a `KrylovJᵀJ` is returned as-is (nothing to recompute), anything else is
# replaced by `J' * J` (out-of-place, `Val{false}`) or overwritten via `mul!`
# (in-place, `Val{true}`).
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    current = getproperty(cache, sym)
    return __update_JᵀJ!(iip, cache, sym, current, J)
end
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J)
    product = J' * J
    return setproperty!(cache, sym, product)
end
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J)
    dest = getproperty(cache, sym)
    return mul!(dest, J', J)
end
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J)
    return H
end
208
+
209
# Refresh the `Jᵀ * fu` value stored in `cache` under the property `sym1`.
#
# The 6-argument entry point reads the property `sym2` and re-dispatches on it:
# when it is a `KrylovJᵀJ`, its `Jᵀ` operator is applied to `fu`; otherwise `J'` is
# used. `Val{false}` assigns a restructured new value, `Val{true}` writes in place
# through `mul!` into the vectorised existing value.
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
    stored = getproperty(cache, sym2)
    return __update_Jᵀf!(iip, cache, sym1, sym2, stored, J, fu)
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    previous = getproperty(cache, sym1)
    updated = _restructure(previous, J' * fu)
    return setproperty!(cache, sym1, updated)
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    dest = _vec(getproperty(cache, sym1))
    return mul!(dest, J', fu)
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ,
        J, fu)
    previous = getproperty(cache, sym1)
    updated = _restructure(previous, H.Jᵀ * fu)
    return setproperty!(cache, sym1, updated)
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ,
        J, fu)
    dest = _vec(getproperty(cache, sym1))
    return mul!(dest, H.Jᵀ, fu)
end
224
+
225
+ # Left-Right Multiplication
226
# Computes `dot(g, H, g)`. For a `KrylovJᵀJ` the wrapped `JᵀJ` operator is used; in the
# in-place case the product `JᵀJ * g` is first written into a temporary.
function __lr_mul(::Val, H, g)
    return dot(g, H, g)
end
## TODO: Use a cache here to avoid allocations
function __lr_mul(::Val{false}, H::KrylovJᵀJ, g)
    return dot(g, H.JᵀJ, g)
end
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    Hg = similar(g)
    mul!(Hg, H.JᵀJ, g)
    return dot(g, Hg)
end
0 commit comments