Skip to content

Commit 6740074

Browse files
authored
Merge pull request #446 from tansongchen/fwd-ls
Add forward mode to line search
2 parents f3b2e1f + b7c54f3 commit 6740074

File tree

2 files changed

+64
-26
lines changed

2 files changed

+64
-26
lines changed

src/globalization/line_search.jl

+60-22
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ end
9090
ϕdϕ
9191
method
9292
alpha
93-
grad_op
93+
deriv_op
9494
u_cache
9595
fu_cache
9696
stats::NLStats
@@ -110,25 +110,59 @@ function __internal_init(
110110
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
111111
Detected $(autodiff). Falling back to AutoFiniteDiff."
112112
end
113-
grad_op = @closure (u, fu, p) -> last(__value_derivative(
114-
autodiff, Base.Fix2(f, p), u)) * fu
113+
deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
114+
autodiff, Base.Fix2(f, p), u)) *
115+
fu *
116+
du
115117
else
116-
if SciMLBase.has_jvp(f)
118+
# Both forward and reverse AD can be used for line-search.
119+
# We prefer forward AD for better performance; however, reverse AD is also supported if the user explicitly requests it.
120+
# 1. If jvp is available, we use forward AD;
121+
# 2. If vjp is available, we use reverse AD;
122+
# 3. If reverse type is requested, we use reverse AD;
123+
# 4. Finally, we use forward AD.
124+
if alg.autodiff isa AutoFiniteDiff
125+
deriv_op = nothing
126+
elseif SciMLBase.has_jvp(f)
117127
if isinplace(prob)
118-
g_cache = __similar(u)
119-
grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
128+
jvp_cache = __similar(fu)
129+
deriv_op = @closure (du, u, fu, p) -> begin
130+
f.jvp(jvp_cache, du, u, p)
131+
dot(fu, jvp_cache)
132+
end
120133
else
121-
grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
134+
deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
122135
end
123-
else
136+
elseif SciMLBase.has_vjp(f)
137+
if isinplace(prob)
138+
vjp_cache = __similar(u)
139+
deriv_op = @closure (du, u, fu, p) -> begin
140+
f.vjp(vjp_cache, fu, u, p)
141+
dot(du, vjp_cache)
142+
end
143+
else
144+
deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p))
145+
end
146+
elseif alg.autodiff !== nothing &&
147+
ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
124148
autodiff = get_concrete_reverse_ad(
125149
alg.autodiff, prob; check_reverse_mode = true)
126150
vjp_op = VecJacOperator(prob, fu, u; autodiff)
127151
if isinplace(prob)
128-
g_cache = __similar(u)
129-
grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
152+
vjp_cache = __similar(u)
153+
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
154+
else
155+
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
156+
end
157+
else
158+
autodiff = get_concrete_forward_ad(
159+
alg.autodiff, prob; check_forward_mode = true)
160+
jvp_op = JacVecOperator(prob, fu, u; autodiff)
161+
if isinplace(prob)
162+
jvp_cache = __similar(fu)
163+
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
130164
else
131-
grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
165+
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
132166
end
133167
end
134168
end
@@ -143,33 +177,37 @@ function __internal_init(
143177
return @fastmath internalnorm(fu_cache)^2 / 2
144178
end
145179

146-
= @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
180+
= @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
147181
@bb @. u_cache = u + α * du
148182
fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
149183
stats.nf += 1
150-
g₀ = grad_op(u_cache, fu_cache, p)
151-
return dot(g₀, du)
184+
return deriv_op(du, u_cache, fu_cache, p)
152185
end
153186

154-
ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
187+
ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
155188
@bb @. u_cache = u + α * du
156189
fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
157190
stats.nf += 1
158-
g₀ = grad_op(u_cache, fu_cache, p)
191+
deriv = deriv_op(du, u_cache, fu_cache, p)
159192
obj = @fastmath internalnorm(fu_cache)^2 / 2
160-
return obj, dot(g₀, du)
193+
return obj, deriv
161194
end
162195

163196
return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
164-
grad_op, u_cache, fu_cache, stats)
197+
deriv_op, u_cache, fu_cache, stats)
165198
end
166199

167200
function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
168201
ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
169-
= @closure α -> cache.(
170-
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
171-
ϕdϕ = @closure α -> cache.ϕdϕ(
172-
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
202+
if cache.deriv_op !== nothing
203+
= @closure α -> cache.(
204+
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
205+
ϕdϕ = @closure α -> cache.ϕdϕ(
206+
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
207+
else
208+
= @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α)
209+
ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
210+
end
173211

174212
ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))
175213

test/core/rootfind_tests.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end
5555
@testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin
5656
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
5757
Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
58-
ad in (AutoFiniteDiff(), AutoZygote())
58+
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())
5959

6060
linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
6161
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
@@ -466,7 +466,7 @@ end
466466
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in (
467467
Static(), StrongWolfe(), BackTracking(),
468468
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
469-
ad in (AutoFiniteDiff(), AutoZygote()),
469+
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
470470
init_jacobian in (Val(:identity), Val(:true_jacobian)),
471471
update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal))
472472

@@ -515,7 +515,7 @@ end
515515
@testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin
516516
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in (
517517
Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
518-
ad in (AutoFiniteDiff(), AutoZygote()),
518+
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
519519
init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal))
520520

521521
linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
@@ -565,7 +565,7 @@ end
565565
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
566566
Static(), StrongWolfe(), BackTracking(),
567567
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
568-
ad in (AutoFiniteDiff(), AutoZygote())
568+
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())
569569

570570
linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
571571
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)

0 commit comments

Comments
 (0)