Skip to content
This repository was archived by the owner on Apr 16, 2025. It is now read-only.

Commit 9858095

Browse files
committed
SimpleJNFK working version
1 parent 59d69cd commit 9858095

File tree

5 files changed

+28
-18
lines changed

5 files changed

+28
-18
lines changed

Diff for: ext/SimpleNonlinearSolveNNlibExt.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, Sc
44
import SimpleNonlinearSolve: _construct_batched_problem_structure,
55
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
66

7-
function __init__()
8-
SimpleNonlinearSolve.NNlibExtLoaded[] = true
9-
return
10-
end
7+
SimpleNonlinearSolve.extension_loaded(::Val{NNlib}) = true
118

129
@views function SciMLBase.__solve(prob::NonlinearProblem,
1310
alg::BatchedBroyden;

Diff for: src/SimpleNonlinearSolve.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function __init__()
1616
@require_extensions
1717
end
1818

19-
const NNlibExtLoaded = Ref{Bool}(false)
19+
extension_loaded(::Val) = false
2020

2121
abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
2222
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end

Diff for: src/broyden.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function Broyden(; batched = false,
2222
abstol = nothing,
2323
reltol = nothing))
2424
if batched
25-
@assert NNlibExtLoaded[] "Please install and load `NNlib.jl` to use batched Broyden."
25+
@assert extension_loaded(Val(:NNlib)) "Please install and load `NNlib.jl` to use batched Broyden."
2626
return BatchedBroyden(termination_condition)
2727
end
2828
return Broyden(termination_condition)

Diff for: src/jnfk.jl

+25-11
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,32 @@ function jvp_forwarddiff(f, x::AbstractArray{T}, v) where {T}
55
y = (Dual{Tag{SimpleJNFKJacVecTag, T}, T, 1}).(x, Partials.(tuple.(v_)))
66
return vec(ForwardDiff.partials.(vec(f(y)), 1))
77
end
8+
jvp_forwarddiff!(r, f, x, v) = copyto!(r, jvp_forwarddiff(f, x, v))
89

910
struct JacVecOperator{F, X}
1011
f::F
1112
x::X
1213
end
1314

1415
(jvp::JacVecOperator)(v, _, _) = jvp_forwarddiff(jvp.f, jvp.x, v)
16+
(jvp::JacVecOperator)(r, v, _, _) = jvp_forwarddiff!(r, jvp.f, jvp.x, v)
1517

1618
"""
17-
SimpleJNFK()
19+
SimpleJNFK(; batched::Bool = false)
1820
21+
A low overhead Jacobian-free Newton-Krylov method. This method internally uses `GMRES` to
22+
avoid computing the Jacobian Matrix.
23+
24+
!!! warning
25+
26+
JNFK doesn't work well without preconditioning, which is currently not supported. We
27+
recommend using `NewtonRaphson(linsolve = KrylovJL_GMRES())` for preconditioning
28+
support.
1929
"""
2030
struct SimpleJFNK end
2131

2232
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...;
23-
abstol = nothing, reltol= nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...)
33+
abstol = nothing, reltol = nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...)
2434
iip = SciMLBase.isinplace(prob)
2535
@assert !iip "SimpleJFNK does not support inplace problems"
2636

@@ -29,26 +39,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...;
2939
fx = f(x)
3040
T = typeof(x)
3141

42+
iszero(fx) &&
43+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
44+
3245
atol = abstol !== nothing ? abstol :
3346
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
3447
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
3548

3649
op = FunctionOperator(JacVecOperator(f, x), x)
37-
linprob = LinearProblem(op, -fx)
38-
lincache = init(linprob, SimpleGMRES(); abstol, reltol, maxiters, linsolve_kwargs...)
50+
linprob = LinearProblem(op, vec(fx))
51+
lincache = init(linprob, KrylovJL_GMRES(); abstol = atol, reltol = rtol, maxiters,
52+
linsolve_kwargs...)
3953

4054
for i in 1:maxiters
41-
iszero(fx) &&
42-
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
43-
4455
linsol = solve!(lincache)
45-
x .-= linsol.u
56+
axpy!(-1, linsol.u, x)
4657
lincache = linsol.cache
4758

48-
# FIXME: not nothing
49-
if isapprox(x, nothing; atol, rtol)
59+
fx = f(x)
60+
61+
norm(fx, Inf) atol &&
5062
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
51-
end
63+
64+
lincache.b = vec(fx)
65+
lincache.A = FunctionOperator(JacVecOperator(f, x), x)
5266
end
5367

5468
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)

Diff for: src/raphson.jl

-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ function SimpleNewtonRaphson(; batched = false,
4242
throw(ArgumentError("`termination_condition` is currently only supported for batched problems"))
4343
end
4444
if batched
45-
# @assert ADLinearSolveFDExtLoaded[] "Please install and load `LinearSolve.jl`, `FiniteDifferences.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson."
4645
termination_condition = ismissing(termination_condition) ?
4746
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
4847
abstol = nothing,

0 commit comments

Comments
 (0)