Skip to content

Commit 04a4563

Browse files
feat: static version of LiFukushimaLineSearch (#11)
* feat: static version of LiFukushimaLineSearch * chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * chore: bump version for release * fix: mistake in cache typing * fix: store as real numbers not rational --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ab881c1 commit 04a4563

File tree

3 files changed

+77
-13
lines changed

3 files changed

+77
-13
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LineSearch"
22
uuid = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
33
authors = ["SciML"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1313
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1414
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
15+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1516

1617
[weakdeps]
1718
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
@@ -38,6 +39,7 @@ ReTestItems = "1.28.0"
3839
ReverseDiff = "1.15.3"
3940
SciMLBase = "2.53.1"
4041
SciMLJacobianOperators = "0.1"
42+
StaticArraysCore = "1.4"
4143
Test = "1.10"
4244
Tracker = "0.2.35"
4345
Zygote = "0.6.71"

src/LineSearch.jl

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using MaybeInplace: @bb
99
using SciMLBase: SciMLBase, AbstractSciMLProblem, AbstractNonlinearProblem, ReturnCode,
1010
NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction
1111
using SciMLJacobianOperators: VecJacOperator, JacVecOperator
12+
using StaticArraysCore: SArray
1213

1314
abstract type AbstractLineSearchAlgorithm end
1415
abstract type AbstractLineSearchCache end

src/li_fukushima.jl

+73-12
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,30 @@
44
55
A derivative-free line search and global convergence of Broyden-like method for nonlinear
66
equations [li2000derivative](@cite).
7+
8+
!!! tip
9+
10+
For static arrays and numbers if `nan_maxiters` is either `nothing` or `missing`,
11+
we provide a fully non-allocating implementation of the algorithm, that can be used
12+
inside GPU kernels. However, this particular version doesn't support `stats` and
13+
`reinit!` and those will be ignored. Additionally, we fix the initial alpha for the
14+
search to be `1`.
715
"""
816
@kwdef @concrete struct LiFukushimaLineSearch <: AbstractLineSearchAlgorithm
917
lambda_0 = 1
10-
beta = 1 // 2
11-
sigma_1 = 1 // 1000
12-
sigma_2 = 1 // 1000
13-
eta = 1 // 10
14-
rho = 9 // 10
15-
nan_maxiters::Int = 5
18+
beta = 0.5
19+
sigma_1 = 0.001
20+
sigma_2 = 0.001
21+
eta = 0.1
22+
rho = 0.9
23+
nan_maxiters <: Union{Missing, Nothing, Int} = 5
1624
maxiters::Int = 100
1725
end
1826

1927
@concrete mutable struct LiFukushimaLineSearchCache <: AbstractLineSearchCache
2028
ϕ
2129
f
2230
p
23-
internalnorm
2431
u_cache
2532
fu_cache
2633
λ₀
@@ -30,15 +37,45 @@ end
3037
η
3138
ρ
3239
α
33-
nan_maxiters::Int
40+
nan_maxiters <: Union{Missing, Nothing, Int}
3441
maxiters::Int
3542
stats <: Union{SciMLBase.NLStats, Nothing}
3643
alg <: LiFukushimaLineSearch
3744
end
3845

46+
@concrete struct StaticLiFukushimaLineSearchCache <: AbstractLineSearchCache
47+
f
48+
p
49+
λ₀
50+
β
51+
σ₁
52+
σ₂
53+
η
54+
ρ
55+
maxiters::Int
56+
end
57+
3958
function CommonSolve.init(
40-
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch, fu, u;
59+
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch,
60+
fu::Union{SArray, Number}, u::Union{SArray, Number};
4161
stats::Union{SciMLBase.NLStats, Nothing} = nothing, kwargs...)
62+
if (alg.nan_maxiters === nothing || alg.nan_maxiters === missing) && stats === nothing
63+
T = promote_type(eltype(fu), eltype(u))
64+
return StaticLiFukushimaLineSearchCache(prob.f, prob.p, T(alg.lambda_0),
65+
T(alg.beta), T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho),
66+
alg.maxiters)
67+
end
68+
return generic_lifukushima_init(prob, alg, fu, u; stats, kwargs...)
69+
end
70+
71+
function CommonSolve.init(
72+
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch, fu, u; kwargs...)
73+
return generic_lifukushima_init(prob, alg, fu, u; kwargs...)
74+
end
75+
76+
function generic_lifukushima_init(
77+
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch,
78+
fu, u; stats::Union{SciMLBase.NLStats, Nothing} = nothing, kwargs...)
4279
@bb u_cache = similar(u)
4380
@bb fu_cache = similar(fu)
4481
T = promote_type(eltype(fu), eltype(u))
@@ -51,7 +88,7 @@ function CommonSolve.init(
5188
end
5289

5390
return LiFukushimaLineSearchCache(
54-
ϕ, prob.f, prob.p, T(1), u_cache, fu_cache, T(alg.lambda_0), T(alg.beta),
91+
ϕ, prob.f, prob.p, u_cache, fu_cache, T(alg.lambda_0), T(alg.beta),
5592
T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho), T(1), alg.nan_maxiters,
5693
alg.maxiters, stats, alg)
5794
end
@@ -74,7 +111,8 @@ function CommonSolve.solve!(cache::LiFukushimaLineSearchCache, u, du)
74111
λ₂, λ₁ = cache.λ₀, cache.λ₀
75112
fxλp_norm = ϕ(λ₂)
76113

77-
if !isfinite(fxλp_norm)
114+
if !isfinite(fxλp_norm) && cache.nan_maxiters !== nothing &&
115+
cache.nan_maxiters !== missing
78116
nan_converged = false
79117
for _ in 1:(cache.nan_maxiters)
80118
λ₁, λ₂ = λ₂, cache.β * λ₂
@@ -85,7 +123,7 @@ function CommonSolve.solve!(cache::LiFukushimaLineSearchCache, u, du)
85123
nan_converged || return LineSearchSolution(cache.α, ReturnCode.Failure)
86124
end
87125

88-
for i in 1:(cache.maxiters)
126+
for _ in 1:(cache.maxiters)
89127
fxλp_norm = ϕ(λ₂)
90128
converged = fxλp_norm (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
91129
converged && return LineSearchSolution(λ₂, ReturnCode.Success)
@@ -95,6 +133,29 @@ function CommonSolve.solve!(cache::LiFukushimaLineSearchCache, u, du)
95133
return LineSearchSolution(cache.α, ReturnCode.Failure)
96134
end
97135

136+
function CommonSolve.solve!(cache::StaticLiFukushimaLineSearchCache, u, du)
137+
T = promote_type(eltype(du), eltype(u))
138+
139+
fx_norm = norm(cache.f(u, cache.p))
140+
du_norm = norm(du)
141+
fxλ_norm = norm(cache.f(u .+ du, cache.p))
142+
143+
if fxλ_norm cache.ρ * fx_norm - cache.σ₂ * du_norm^2
144+
return LineSearchSolution(T(true), ReturnCode.Success)
145+
end
146+
147+
λ₂, λ₁ = cache.λ₀, cache.λ₀
148+
149+
for _ in 1:(cache.maxiters)
150+
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p))
151+
converged = fxλp_norm (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
152+
converged && return LineSearchSolution(λ₂, ReturnCode.Success)
153+
λ₁, λ₂ = λ₂, cache.β * λ₂
154+
end
155+
156+
return LineSearchSolution(T(true), ReturnCode.Failure)
157+
end
158+
98159
function SciMLBase.reinit!(
99160
cache::LiFukushimaLineSearchCache; p = missing, stats = missing, kwargs...)
100161
p !== missing && (cache.p = p)

0 commit comments

Comments
 (0)