Skip to content

Commit 81c86f3

Browse files
committed
Should make tests pass
1 parent 1a0df4e commit 81c86f3

File tree

5 files changed

+11
-22
lines changed

5 files changed

+11
-22
lines changed

Manifest.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
164164

165165
[[deps.DiffEqBase]]
166166
deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"]
167-
git-tree-sha1 = "4e661d0beddac31da05e71b79afd769232622de8"
167+
git-tree-sha1 = "0ab52aef95c5cc71e9a8c9d26919ce1f7fb472fa"
168168
repo-rev = "ap/tstable_termination"
169169
repo-url = "https://github.com/SciML/DiffEqBase.jl"
170170
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/NonlinearSolve.jl

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ end
5757
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
5858
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)
5959
get_u(cache::AbstractNonlinearSolveCache) = cache.u
60+
set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)
6061

6162
function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
6263
while not_terminated(cache)

src/dfsane.jl

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ end
9898
get_fu(cache::DFSaneCache) = cache.fuₙ
9999
set_fu!(cache::DFSaneCache, fu) = (cache.fuₙ = fu)
100100
get_u(cache::DFSaneCache) = cache.uₙ
101+
set_u!(cache::DFSaneCache, u) = (cache.uₙ = u)
101102

102103
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
103104
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,

src/utils.jl

+6-8
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,7 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
230230
if isinplace(cache)
231231
cache.prob.f(get_fu(cache), u, cache.prob.p)
232232
else
233-
cache.u = u
234-
set_fu!(cache, cache.prob.f(cache.u, cache.prob.p))
233+
set_fu!(cache, cache.prob.f(u, cache.prob.p))
235234
end
236235
cache.force_stop = true
237236
end
@@ -252,8 +251,7 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
252251
if isinplace(cache)
253252
cache.prob.f(get_fu(cache), u, cache.prob.p)
254253
else
255-
cache.u = u
256-
set_fu!(cache, cache.prob.f(cache.u, cache.prob.p))
254+
set_fu!(cache, cache.prob.f(u, cache.prob.p))
257255
end
258256
cache.force_stop = true
259257
end
@@ -271,11 +269,11 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
271269
cache.retcode = ReturnCode.Unstable
272270
end
273271
if isinplace(cache)
274-
copyto!(u, tc_cache.u)
275-
cache.prob.f(get_fu(cache), u, cache.prob.p)
272+
copyto!(get_u(cache), tc_cache.u)
273+
cache.prob.f(get_fu(cache), get_u(cache), cache.prob.p)
276274
else
277-
cache.u = tc_cache.u
278-
set_fu!(cache, cache.prob.f(cache.u, cache.prob.p))
275+
set_u!(cache, tc_cache.u)
276+
set_fu!(cache, cache.prob.f(get_u(cache), cache.prob.p))
279277
end
280278
cache.force_stop = true
281279
end

test/basictests.jl

+2-13
Original file line numberDiff line numberDiff line change
@@ -453,17 +453,13 @@ end
453453
end
454454

455455
@testset "[OOP] [Immutable AD]" begin
456-
broken_forwarddiff = [3.0, 4.0, 81.0]
457456
for p in 1.1:0.1:100.0
458457
res = abs.(benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p).u)
459458

460459
if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
461460
@test_broken all(res .≈ sqrt(p))
462461
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
463462
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
464-
elseif p in broken_forwarddiff
465-
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
466-
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
467463
else
468464
@test all(res .≈ sqrt(p))
469465
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@@ -473,17 +469,13 @@ end
473469
end
474470

475471
@testset "[OOP] [Scalar AD]" begin
476-
broken_forwarddiff = [3.0, 4.0, 81.0]
477472
for p in 1.1:0.1:100.0
478473
res = abs(benchmark_nlsolve_oop(quadratic_f, 1.0, p).u)
479474

480475
if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
481476
@test_broken res sqrt(p)
482477
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
483478
1.0, p).u, p)) 1 / (2 * sqrt(p))
484-
elseif p in broken_forwarddiff
485-
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
486-
1.0, p).u, p)) 1 / (2 * sqrt(p))
487479
else
488480
@test res sqrt(p)
489481
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@@ -549,7 +541,6 @@ end
549541

550542
probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
551543
sol = solve(probN, alg, abstol = 1e-11)
552-
println(abs.(quadratic_f(sol.u, 2.0)))
553544
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
554545
end
555546
end
@@ -644,13 +635,11 @@ end
644635

645636
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
646637
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
647-
cache = init(probN,
648-
PseudoTransient(alpha_initial = 10.0);
649-
maxiters = 100,
638+
cache = init(probN, PseudoTransient(alpha_initial = 10.0); maxiters = 100,
650639
abstol = 1e-10)
651640
sols = zeros(length(p_range))
652641
for (i, p) in enumerate(p_range)
653-
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha_new = 10.0)
642+
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha = 10.0)
654643
sol = solve!(cache)
655644
sols[i] = iip ? sol.u[1] : sol.u
656645
end

0 commit comments

Comments
 (0)