1
1
"""
    SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
        M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
        nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
        termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
                                                            abstol = nothing,
                                                            reltol = nothing),
        batched::Bool = false,
        max_inner_iterations::Int = 1000)

A low-overhead implementation of the df-sane method for solving large-scale nonlinear
systems of equations. For in depth information about all the parameters and the algorithm,
see [La Cruz, Martínez & Raydan, Mathematics of Computation, 75, 1429-1448].

### Keyword Arguments

  - `σ_min`: lower bound on the magnitude of the spectral step `σ_k`. Defaults to `1e-10`.
  - `σ_max`: upper bound on the magnitude of the spectral step `σ_k`. Defaults to `1e10`.
  - `σ_1`: initial value of the spectral step. Defaults to `1.0`.
  - `M`: history length of the non-monotone line search; a trial point is accepted relative
    to the maximum merit value over the last `M` iterations. Defaults to `10`.
  - `γ`: sufficient-decrease parameter of the line search. Defaults to `1e-4`.
  - `τ_min`, `τ_max`: safeguarding bounds for the step-length updates inside the line
    search. Default to `0.1` and `0.5`.
  - `nexp`: exponent of the merit function ``f(x) = ||F(x)||^{nexp}``. Defaults to `2`.
  - `η_strategy`: function `(f_1, k, x, F) -> η` computing the non-monotone relaxation term,
    where ``f_1 = ||F(x_1)||^{nexp}``, `k` is the iteration number, `x` is the current
    `x`-value and `F` the current residual. Should satisfy ``η_k > 0`` and ``∑ₖ ηₖ < ∞``.
    Defaults to ``||F||^2 / k^2``.
  - `termination_condition`: a `NLSolveTerminationCondition` that determines when the solver
    should terminate. Defaults to `NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
    abstol = nothing, reltol = nothing)`.
  - `batched`: if `true`, the algorithm will use a batched version of the algorithm that treats each
    column of `x` as a separate problem. This can be useful nonlinear problems involing neural
    networks. Defaults to `false`.
  - `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
    algorithm. Used exclusively in `batched` mode. Defaults to `1000`.
"""
struct SimpleDFSane{batched, T, TC} <: AbstractSimpleNonlinearSolveAlgorithm
    σ_min::T
    σ_max::T
    σ_1::T
    M::Int
    γ::T
    τ_min::T
    τ_max::T
    nexp::Int
    η_strategy::Function
    termination_condition::TC
    max_inner_iterations::Int

    function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
                          M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
                          nexp::Int = 2,
                          η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
                          termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
                                                                              abstol = nothing,
                                                                              reltol = nothing),
                          batched::Bool = false,
                          max_inner_iterations = 1000)
        # `batched` is lifted into the type domain so `__solve` can dispatch/branch on it
        # without a runtime field; the element type follows `σ_min`.
        return new{batched, typeof(σ_min), typeof(termination_condition)}(σ_min,
                                                                          σ_max,
                                                                          σ_1,
                                                                          M,
                                                                          γ,
                                                                          τ_min,
                                                                          τ_max,
                                                                          nexp,
                                                                          η_strategy,
                                                                          termination_condition,
                                                                          max_inner_iterations)
    end
end
60
88
61
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched},
                           args...; abstol = nothing, reltol = nothing, maxiters = 1000,
                           kwargs...) where {batched}
    tc = alg.termination_condition
    mode = DiffEqBase.get_termination_mode(tc)

    f = Base.Fix2(prob.f, prob.p)
    x = float(prob.u0)

    # In batched mode every column of `x` is an independent problem.
    if batched
        batch_size = size(x, 2)
    end

    T = eltype(x)
    σ_min = float(alg.σ_min)
    σ_max = float(alg.σ_max)
    # Batched mode keeps one spectral step per column (a 1×batch row vector).
    σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1)

    M = alg.M
    γ = float(alg.γ)
    τ_min = float(alg.τ_min)
    τ_max = float(alg.τ_max)
    nexp = alg.nexp
    η_strategy = alg.η_strategy

    batched && @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays"

    if SciMLBase.isinplace(prob)
        error("SimpleDFSane currently only supports out-of-place nonlinear problems")
    end

    # Tolerance resolution order: solver kwarg > termination condition > type-based default.
    atol = abstol !== nothing ? abstol :
           (tc.abstol !== nothing ? tc.abstol :
            real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
    rtol = reltol !== nothing ? reltol :
           (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

    if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
        error("SimpleDFSane currently doesn't support SAFE_BEST termination modes")
    end

    storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
              nothing
    termination_condition = tc(storage)

    # Merit function f(x) = ||F(x)||^nexp; computed per column in batched mode.
    function ff(x)
        F = f(x)
        f_k = if batched
            sum(abs2, F; dims = 1) .^ (nexp / 2)
        else
            norm(F)^nexp
        end
        return f_k, F
    end

    # History buffer for the non-monotone line search: an M-vector (scalar mode) or an
    # M×batch matrix (batched mode), initialized with the first merit value.
    function generate_history(f_k, M)
        if batched
            history = similar(f_k, (M, length(f_k)))
            history .= reshape(f_k, 1, :)
            return history
        else
            return fill(f_k, M)
        end
    end

    f_k, F_k = ff(x)
    α_1 = convert(T, 1.0)
    f_1 = f_k
    history_f_k = generate_history(f_k, M)

    for k in 1:maxiters
        # Spectral parameter range check: keep |σ_k| within [σ_min, σ_max], preserving sign.
        if batched
            @. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
        else
            σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
        end

        # Line search direction
        d = -σ_k .* F_k

        η = η_strategy(f_1, k, x, F_k)
        f̄ = batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k)
        α_p = α_1
        α_m = α_1
        x_new = @. x + α_p * d

        f_new, F_new = ff(x_new)

        inner_iterations = 0
        while true
            inner_iterations += 1

            # Forward (α_p) trial step.
            if batched
                criteria = @. f̄ + η - γ * α_p^2 * f_k
                # NOTE: This is simply a heuristic, ideally we check using `all` but that is
                # typically very expensive for large problems
                (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break
            else
                criteria = f̄ + η - γ * α_p^2 * f_k
                f_new ≤ criteria && break
            end

            α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
            x_new = @. x - α_m * d
            f_new, F_new = ff(x_new)

            # Backward (α_m) trial step. BUGFIX: test against the α_m-based bound
            # f̄ + η - γ * α_m^2 * f_k, not the α_p-based `criteria` from above —
            # after the first inner iteration α_p and α_m differ.
            if batched
                criteria_m = @. f̄ + η - γ * α_m^2 * f_k
                # NOTE: This is simply a heuristic, ideally we check using `all` but that is
                # typically very expensive for large problems
                (sum(f_new .≤ criteria_m) ≥ batch_size ÷ 2) && break
            else
                f_new ≤ f̄ + η - γ * α_m^2 * f_k && break
            end

            # Safeguarded step-length updates, then retry the forward step.
            α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
            α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p)
            α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m)
            x_new = @. x + α_p * d
            f_new, F_new = ff(x_new)

            # NOTE: The original algorithm runs till either condition is satisfied, however,
            #       for most batched problems like neural networks we only care about
            #       approximate convergence
            batched && (inner_iterations ≥ alg.max_inner_iterations) && break
        end

        if termination_condition(F_new, x_new, x, atol, rtol)
            return SciMLBase.build_solution(prob,
                                            alg,
                                            x_new,
                                            F_new;
                                            retcode = ReturnCode.Success)
        end

        # Update spectral parameter via the Barzilai-Borwein-type quotient
        # (s_k's_k)/(s_k'y_k); batched mode regularizes the denominator with 1e-5
        # to avoid division by zero on stalled columns.
        s_k = @. x_new - x
        y_k = @. F_new - F_k

        if batched
            σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5))
        else
            σ_k = (s_k' * s_k) / (s_k' * y_k)
        end

        # Take step
        x = x_new
        F_k = F_new
        f_k = f_new

        # Store function value in the circular history buffer
        if batched
            history_f_k[k % M + 1, :] .= vec(f_new)
        else
            history_f_k[k % M + 1] = f_new
        end
    end
    return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters)
end
0 commit comments