Skip to content

Commit 8f70a4e

Browse files
committed
Use Zygote for LineSearch if loaded
1 parent 77164c7 commit 8f70a4e

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/linesearch.jl

+17-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
2+
LineSearch(method = nothing, autodiff = nothing, alpha = true)
33
44
Wrapper over algorithms from
55
[LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic
@@ -13,7 +13,7 @@ differentiation for fast Vector Jacobian Products.
1313
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
1414
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
1515
`AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually
16-
installed and loaded
16+
installed and loaded.
1717
- `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`).
1818
"""
1919
@concrete struct LineSearch
@@ -22,7 +22,7 @@ differentiation for fast Vector Jacobian Products.
2222
α
2323
end
2424

25-
function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true)
25+
function LineSearch(; method = nothing, autodiff = nothing, alpha = true)
2626
return LineSearch(method, autodiff, alpha)
2727
end
2828

@@ -113,12 +113,21 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
113113

114114
g₀ = _mutable_zero(u)
115115

116-
autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
117-
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling \
118-
back to finite differencing."
119-
AutoFiniteDiff()
116+
autodiff = if ls.autodiff === nothing
117+
if !iip && haskey(Base.loaded_modules,
118+
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
119+
AutoZygote()
120+
else
121+
AutoFiniteDiff()
122+
end
120123
else
121-
ls.autodiff
124+
if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
125+
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \
126+
Falling back to finite differencing."
127+
AutoFiniteDiff()
128+
else
129+
ls.autodiff
130+
end
122131
end
123132

124133
function g!(u, fu)

0 commit comments

Comments
 (0)