Skip to content
This repository was archived by the owner on Apr 16, 2025. It is now read-only.

Commit b2a43e0

Browse files
Merge pull request #66 from avik-pal/ap/dfsane_batched
Batched DFSane
2 parents dc70eb6 + 625b8f3 commit b2a43e0

File tree

3 files changed

+144
-46
lines changed

3 files changed

+144
-46
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.15"
4+
version = "0.1.16"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/dfsane.jl

+134-42
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""
2-
```julia
3-
SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
4-
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
5-
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2)
6-
```
2+
SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
3+
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
4+
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
5+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
6+
abstol = nothing,
7+
reltol = nothing),
8+
batched::Bool = false,
9+
max_inner_iterations::Int = 1000)
710
811
A low-overhead implementation of the df-sane method for solving large-scale nonlinear
912
systems of equations. For in depth information about all the parameters and the algorithm,
@@ -39,8 +42,16 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
3942
``f_1=||F(x_1)||^{nexp}``, `k` is the iteration number, `x` is the current `x`-value and
4043
`F` the current residual. Should satisfy ``η_k > 0`` and ``∑ₖ ηₖ < ∞``. Defaults to
4144
``||F||^2 / k^2``.
45+
- `termination_condition`: a `NLSolveTerminationCondition` that determines when the solver
46+
should terminate. Defaults to `NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
47+
abstol = nothing, reltol = nothing)`.
48+
- `batched`: if `true`, the algorithm will use a batched version of the algorithm that treats each
49+
column of `x` as a separate problem. This can be useful nonlinear problems involing neural
50+
networks. Defaults to `false`.
51+
- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
52+
algorithm. Used exclusively in `batched` mode. Defaults to `1000`.
4253
"""
43-
struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm
54+
struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm
4455
σ_min::T
4556
σ_max::T
4657
σ_1::T
@@ -50,106 +61,187 @@ struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm
5061
τ_max::T
5162
nexp::Int
5263
η_strategy::Function
64+
termination_condition::TC
65+
max_inner_iterations::Int
5366

5467
function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
5568
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
56-
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2)
57-
new{typeof(σ_min)}(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp, η_strategy)
69+
nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
70+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
71+
abstol = nothing,
72+
reltol = nothing),
73+
batched::Bool = false,
74+
max_inner_iterations = 1000)
75+
return new{batched, typeof(σ_min), typeof(termination_condition)}(σ_min,
76+
σ_max,
77+
σ_1,
78+
M,
79+
γ,
80+
τ_min,
81+
τ_max,
82+
nexp,
83+
η_strategy,
84+
termination_condition,
85+
max_inner_iterations)
5886
end
5987
end
6088

61-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane,
89+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched},
6290
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
63-
kwargs...)
91+
kwargs...) where {batched}
92+
tc = alg.termination_condition
93+
mode = DiffEqBase.get_termination_mode(tc)
94+
6495
f = Base.Fix2(prob.f, prob.p)
6596
x = float(prob.u0)
97+
98+
if batched
99+
batch_size = size(x, 2)
100+
end
101+
66102
T = eltype(x)
67103
σ_min = float(alg.σ_min)
68104
σ_max = float(alg.σ_max)
69-
σ_k = float(alg.σ_1)
105+
σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1)
106+
70107
M = alg.M
71108
γ = float(alg.γ)
72109
τ_min = float(alg.τ_min)
73110
τ_max = float(alg.τ_max)
74111
nexp = alg.nexp
75112
η_strategy = alg.η_strategy
76113

114+
batched && @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays"
115+
77116
if SciMLBase.isinplace(prob)
78117
error("SimpleDFSane currently only supports out-of-place nonlinear problems")
79118
end
80119

81120
atol = abstol !== nothing ? abstol :
82-
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
83-
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
121+
(tc.abstol !== nothing ? tc.abstol :
122+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
123+
rtol = reltol !== nothing ? reltol :
124+
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))
125+
126+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
127+
error("SimpleDFSane currently doesn't support SAFE_BEST termination modes")
128+
end
129+
130+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
131+
nothing
132+
termination_condition = tc(storage)
84133

85134
function ff(x)
86135
F = f(x)
87-
f_k = norm(F)^nexp
136+
f_k = if batched
137+
sum(abs2, F; dims = 1) .^ (nexp / 2)
138+
else
139+
norm(F)^nexp
140+
end
88141
return f_k, F
89142
end
90143

144+
function generate_history(f_k, M)
145+
if batched
146+
history = similar(f_k, (M, length(f_k)))
147+
history .= reshape(f_k, 1, :)
148+
return history
149+
else
150+
return fill(f_k, M)
151+
end
152+
end
153+
91154
f_k, F_k = ff(x)
92155
α_1 = convert(T, 1.0)
93156
f_1 = f_k
94-
history_f_k = fill(f_k, M)
157+
history_f_k = generate_history(f_k, M)
95158

96159
for k in 1:maxiters
97-
iszero(F_k) &&
98-
return SciMLBase.build_solution(prob, alg, x, F_k;
99-
retcode = ReturnCode.Success)
100-
101160
# Spectral parameter range check
102-
if abs(σ_k) > σ_max
103-
σ_k = sign(σ_k) * σ_max
104-
elseif abs(σ_k) < σ_min
105-
σ_k = sign(σ_k) * σ_min
161+
if batched
162+
@. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
163+
else
164+
σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
106165
end
107166

108167
# Line search direction
109-
d = -σ_k * F_k
168+
d = -σ_k .* F_k
110169

111170
η = η_strategy(f_1, k, x, F_k)
112-
= maximum(history_f_k)
171+
= batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k)
113172
α_p = α_1
114173
α_m = α_1
115-
x_new = x + α_p * d
174+
x_new = @. x + α_p * d
175+
116176
f_new, F_new = ff(x_new)
177+
178+
inner_iterations = 0
117179
while true
118-
if f_new + η - γ * α_p^2 * f_k
119-
break
180+
inner_iterations += 1
181+
182+
if batched
183+
criteria = @.+ η - γ * α_p^2 * f_k
184+
# NOTE: This is simply a heuristic, ideally we check using `all` but that is
185+
# typically very expensive for large problems
186+
(sum(f_new .≤ criteria) batch_size ÷ 2) && break
187+
else
188+
criteria =+ η - γ * α_p^2 * f_k
189+
f_new criteria && break
120190
end
121191

122-
α_tp = α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
123-
x_new = x - α_m * d
192+
α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
193+
x_new = @. x - α_m * d
124194
f_new, F_new = ff(x_new)
125195

126-
if f_new + η - γ * α_m^2 * f_k
127-
break
196+
if batched
197+
# NOTE: This is simply a heuristic, ideally we check using `all` but that is
198+
# typically very expensive for large problems
199+
(sum(f_new .≤ criteria) batch_size ÷ 2) && break
200+
else
201+
f_new criteria && break
128202
end
129203

130-
α_tm = α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
131-
α_p = min(τ_max * α_p, max(α_tp, τ_min * α_p))
132-
α_m = min(τ_max * α_m, max(α_tm, τ_min * α_m))
133-
x_new = x + α_p * d
204+
α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
205+
α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p)
206+
α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m)
207+
x_new = @. x + α_p * d
134208
f_new, F_new = ff(x_new)
209+
210+
# NOTE: The original algorithm runs till either condition is satisfied, however,
211+
# for most batched problems like neural networks we only care about
212+
# approximate convergence
213+
batched && (inner_iterations alg.max_inner_iterations) && break
135214
end
136215

137-
if isapprox(x_new, x, atol = atol, rtol = rtol)
138-
return SciMLBase.build_solution(prob, alg, x_new, F_new;
216+
if termination_condition(F_new, x_new, x, atol, rtol)
217+
return SciMLBase.build_solution(prob,
218+
alg,
219+
x_new,
220+
F_new;
139221
retcode = ReturnCode.Success)
140222
end
223+
141224
# Update spectral parameter
142-
s_k = x_new - x
143-
y_k = F_new - F_k
144-
σ_k = (s_k' * s_k) / (s_k' * y_k)
225+
s_k = @. x_new - x
226+
y_k = @. F_new - F_k
227+
228+
if batched
229+
σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5))
230+
else
231+
σ_k = (s_k' * s_k) / (s_k' * y_k)
232+
end
145233

146234
# Take step
147235
x = x_new
148236
F_k = F_new
149237
f_k = f_new
150238

151239
# Store function value
152-
history_f_k[k % M + 1] = f_new
240+
if batched
241+
history_f_k[k % M + 1, :] .= vec(f_new)
242+
else
243+
history_f_k[k % M + 1] = f_new
244+
end
153245
end
154246
return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters)
155247
end

test/basictests.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ const BATCHED_BROYDEN_SOLVERS = Broyden[]
99
const BROYDEN_SOLVERS = Broyden[]
1010
const BATCHED_LBROYDEN_SOLVERS = LBroyden[]
1111
const LBROYDEN_SOLVERS = LBroyden[]
12+
const BATCHED_DFSANE_SOLVERS = SimpleDFSane[]
13+
const DFSANE_SOLVERS = SimpleDFSane[]
1214

1315
for mode in instances(NLSolveTerminationMode.T)
1416
if mode
@@ -23,6 +25,8 @@ for mode in instances(NLSolveTerminationMode.T)
2325
push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition))
2426
push!(LBROYDEN_SOLVERS, LBroyden(; batched = false, termination_condition))
2527
push!(BATCHED_LBROYDEN_SOLVERS, LBroyden(; batched = true, termination_condition))
28+
push!(DFSANE_SOLVERS, SimpleDFSane(; batched = false, termination_condition))
29+
push!(BATCHED_DFSANE_SOLVERS, SimpleDFSane(; batched = true, termination_condition))
2630
end
2731

2832
# SimpleNewtonRaphson
@@ -484,11 +488,13 @@ sol = solve(probN, Broyden(batched = true))
484488

485489
@test abs.(sol.u) sqrt.(p)
486490

487-
for alg in (BATCHED_BROYDEN_SOLVERS..., BATCHED_LBROYDEN_SOLVERS...)
488-
sol = solve(probN, alg)
491+
for alg in (BATCHED_BROYDEN_SOLVERS...,
492+
BATCHED_LBROYDEN_SOLVERS...,
493+
BATCHED_DFSANE_SOLVERS...)
494+
sol = solve(probN, alg; abstol = 1e-3, reltol = 1e-3)
489495

490496
@test sol.retcode == ReturnCode.Success
491-
@test abs.(sol.u) sqrt.(p)
497+
@test abs.(sol.u)sqrt.(p) atol=1e-3 rtol=1e-3
492498
end
493499

494500
## User specified Jacobian

0 commit comments

Comments
 (0)