Skip to content

Commit aea4ca1

Browse files
committedNov 13, 2023
Fix matrix resizing
1 parent e3ecfd1 commit aea4ca1

File tree

3 files changed

+15
-20
lines changed

3 files changed

+15
-20
lines changed
 

‎src/gaussnewton.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ function perform_step!(cache::GaussNewtonCache{true})
113113
jacobian!!(J, cache)
114114

115115
if JᵀJ !== nothing
116-
__matmul!(JᵀJ, J', J)
117-
__matmul!(Jᵀf, J', fu1)
116+
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
117+
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)
118118
end
119119

120120
# u = u - JᵀJ \ Jᵀfu
@@ -151,8 +151,8 @@ function perform_step!(cache::GaussNewtonCache{false})
151151
cache.J = jacobian!!(cache.J, cache)
152152

153153
if cache.JᵀJ !== nothing
154-
cache.JᵀJ = cache.J' * cache.J
155-
cache.Jᵀf = cache.J' * fu1
154+
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
155+
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)
156156
end
157157

158158
# u = u - J \ fu

‎src/jacobian.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
8383

8484
J = if !(linsolve_needs_jac || alg_wants_jac)
8585
# We don't need to construct the Jacobian
86-
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
86+
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
8787
else
8888
if has_analytic_jac
8989
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
@@ -179,7 +179,7 @@ __maybe_symmetric(x::Number) = x
179179
__maybe_symmetric(x::StaticArray) = x
180180
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
181181
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
182-
__maybe_symmetric(x::KrylovJᵀJ) = x
182+
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ
183183

184184
## Special Handling for Scalars
185185
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
@@ -203,16 +203,16 @@ function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
203203
return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
204204
end
205205
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
206-
return setproperty!(cache, sym1, J' * fu)
206+
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))
207207
end
208208
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
209-
return mul!(getproperty(cache, sym1), J', fu)
209+
return mul!(vec(getproperty(cache, sym1)), J', fu)
210210
end
211211
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
212-
return setproperty!(cache, sym1, H.Jᵀ * fu)
212+
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))
213213
end
214214
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
215-
return mul!(getproperty(cache, sym1), H.Jᵀ, fu)
215+
return mul!(vec(getproperty(cache, sym1)), H.Jᵀ, fu)
216216
end
217217

218218
# Left-Right Multiplication

‎src/trustRegion.jl

+5-10
Original file line numberDiff line numberDiff line change
@@ -239,19 +239,16 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
239239
fu_prev = zero(fu1)
240240

241241
loss = get_loss(fu1)
242-
# uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
243-
# linsolve_kwargs)
244242
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
245243
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
244+
g = _restructure(fu1, g)
246245
linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg)
247246

248247
u_tmp = zero(u)
249248
u_cauchy = zero(u)
250249
u_gauss_newton = _mutable_zero(u)
251250

252251
loss_new = loss
253-
# H = zero(J' * J)
254-
# g = _mutable_zero(fu1)
255252
shrink_counter = 0
256253
fu_new = zero(fu1)
257254
make_new_J = true
@@ -351,9 +348,7 @@ function perform_step!(cache::TrustRegionCache{true})
351348
if cache.make_new_J
352349
jacobian!!(J, cache)
353350
__update_JᵀJ!(Val{true}(), cache, :H, J)
354-
# mul!(cache.H, J', J)
355-
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu)
356-
# mul!(_vec(cache.g), J', _vec(fu))
351+
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, vec(fu))
357352
cache.stats.njacs += 1
358353

359354
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
@@ -383,7 +378,7 @@ function perform_step!(cache::TrustRegionCache{false})
383378
if make_new_J
384379
J = jacobian!!(cache.J, cache)
385380
__update_JᵀJ!(Val{false}(), cache, :H, J)
386-
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu)
381+
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, vec(fu))
387382
cache.stats.njacs += 1
388383

389384
if cache.linsolve === nothing
@@ -420,8 +415,8 @@ function retrospective_step!(cache::TrustRegionCache)
420415
cache.H = J' * J
421416
cache.g = J' * fu
422417
else
423-
mul!(cache.H, J', J)
424-
mul!(cache.g, J', fu)
418+
__update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J)
419+
__update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu)
425420
end
426421
cache.stats.njacs += 1
427422
@unpack H, g, du = cache

0 commit comments

Comments
 (0)