Skip to content

Commit 3ed48f8

Browse files
authored
fix: reinit! on forwarddiff cache (#491)
1 parent 5c722c0 commit 3ed48f8

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

src/forward_diff.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ function InternalAPI.reinit!(
3737
cache::NonlinearSolveForwardDiffCache, args...;
3838
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
3939
)
40-
inner_cache = InternalAPI.reinit!(
40+
InternalAPI.reinit!(
4141
cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
4242
)
43-
cache.cache = inner_cache
4443
cache.p = p
4544
cache.values_p = nodual_value(p)
4645
cache.partials_p = ForwardDiff.partials(p)

test/forward_ad_tests.jl

+31
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,34 @@ end
218218

219219
@test hess1hess2 atol=1e-3
220220
end
221+
222+
@testitem "reinit! on ForwardDiff cache SciML/NonlinearSolve.jl#391" tags=[:core] begin
223+
using ForwardDiff
224+
225+
function multiple_solves(ps::Vector)
226+
res = similar(ps, 4, length(ps))
227+
for (i, p) in enumerate(ps)
228+
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, rand(4), ps[i])
229+
sol = solve(prob)
230+
res[:, i] .= sol.u
231+
end
232+
return sum(abs2, res)
233+
end
234+
235+
function multiple_solves_cached(ps::Vector)
236+
res = similar(ps, 4, length(ps))
237+
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, rand(4), ps[1])
238+
cache = init(prob, NewtonRaphson())
239+
for (i, p) in enumerate(ps)
240+
reinit!(cache; p)
241+
sol = solve!(cache)
242+
res[:, i] .= sol.u
243+
end
244+
return sum(abs2, res)
245+
end
246+
247+
ps = collect(1.0:5.0)
248+
249+
@test ForwardDiff.gradient(multiple_solves, ps)
250+
ForwardDiff.gradient(multiple_solves_cached, ps)
251+
end

0 commit comments

Comments
 (0)