Skip to content

Commit 6c52956

Browse files
committed
Fix aliasing issue
1 parent b79228d commit 6c52956

File tree

4 files changed

+26
-11
lines changed

4 files changed

+26
-11
lines changed

src/levenberg.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,10 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
366366
if linsolve === nothing
367367
cache.v = -cache.mat_tmp \ (J' * fu1)
368368
else
369-
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
369+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
370370
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
371371
cache.linsolve = linres.cache
372+
cache.v .*= -1
372373
end
373374
end
374375

@@ -384,9 +385,11 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
384385
if linsolve === nothing
385386
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)
386387
else
387-
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),
388-
linu = _vec(cache.a), p, reltol = cache.abstol)
388+
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
389+
b = _mutable(_vec(J' * rhs_term)), linu = _vec(cache.a), p,
390+
reltol = cache.abstol, reuse_A_if_factorization = true)
389391
cache.linsolve = linres.cache
392+
cache.a .*= -1
390393
end
391394
end
392395
cache.stats.nsolve += 1

src/utils.jl

+16-4
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,28 @@ end
8282
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
8383

8484
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
85-
du = nothing, u = nothing, p = nothing, t = nothing, weight = nothing,
86-
cachedata = nothing, reltol = nothing) where {P}
87-
A !== nothing && (linsolve.A = A)
85+
du = nothing, p = nothing, weight = nothing, cachedata = nothing, reltol = nothing,
86+
reuse_A_if_factorization = false) where {P}
87+
# Some Algorithms would reuse factorization but it causes the cache to not reset in
88+
# certain cases
89+
if A !== nothing
90+
alg = linsolve.alg
91+
if (alg isa LinearSolve.AbstractFactorization) ||
92+
(alg isa LinearSolve.DefaultLinearSolver && !(alg ==
93+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)))
94+
# Factorization Algorithm
95+
!reuse_A_if_factorization && (linsolve.A = A)
96+
else
97+
linsolve.A = A
98+
end
99+
end
88100
b !== nothing && (linsolve.b = b)
89101
linu !== nothing && (linsolve.u = linu)
90102

91103
Plprev = linsolve.Pl isa ComposePreconditioner ? linsolve.Pl.outer : linsolve.Pl
92104
Prprev = linsolve.Pr isa ComposePreconditioner ? linsolve.Pr.outer : linsolve.Pr
93105

94-
_Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev,
106+
_Pl, _Pr = precs(linsolve.A, du, linu, p, nothing, A !== nothing, Plprev, Prprev,
95107
cachedata)
96108
if (_Pl !== nothing || _Pr !== nothing)
97109
_weight = weight === nothing ?

test/basictests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1013,12 +1013,12 @@ end
10131013
u0 = rand(100)
10141014

10151015
prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0)
1016-
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))
1016+
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13)
10171017

10181018
@test norm(F(sol.u, u0)) 1e-8
10191019

10201020
prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0)
1021-
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()))
1021+
sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES()); abstol = 1e-13)
10221022

10231023
@test norm(F(sol.u, u0)) 1e-8
10241024
end

test/infeasible.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ function f1(u, p)
2929
v_x = 8.550491684548064e-12 + u[1]
3030
v_y = 6631.60076191005 + u[2]
3131
v_z = 3600.665431405663 + u[3]
32-
r = @SVector [x, y, z]
33-
v = @SVector [v_x, v_y, v_z]
32+
r = [x, y, z]
33+
v = [v_x, v_y, v_z]
3434
h = cross(r, v)
3535
ev = cross(v, h) / μ - r / norm(r)
3636
i = acos(h[3] / norm(h))

0 commit comments

Comments
 (0)