@@ -90,7 +90,7 @@
     ϕdϕ
     method
     alpha
-    grad_op
+    deriv_op
     u_cache
     fu_cache
     stats::NLStats
@@ -110,25 +110,59 @@ function __internal_init(
             @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
                    Detected $(autodiff). Falling back to AutoFiniteDiff."
         end
-        grad_op = @closure (u, fu, p) -> last(__value_derivative(
-            autodiff, Base.Fix2(f, p), u)) * fu
+        deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
+            autodiff, Base.Fix2(f, p), u)) *
+                                              fu *
+                                              du
     else
-        if SciMLBase.has_jvp(f)
+        # Both forward and reverse mode AD can be used for the line search.
+        # We prefer forward mode for performance, but reverse mode is also
+        # supported if the user explicitly requests it:
+        #   1. If a jvp is available, use forward mode;
+        #   2. If a vjp is available, use reverse mode;
+        #   3. If a reverse-mode AD type is requested, use reverse mode;
+        #   4. Otherwise, fall back to forward mode.
+        if alg.autodiff isa AutoFiniteDiff
+            deriv_op = nothing
+        elseif SciMLBase.has_jvp(f)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> begin
+                    f.jvp(jvp_cache, du, u, p)
+                    dot(fu, jvp_cache)
+                end
             else
-                grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
             end
-        else
+        elseif SciMLBase.has_vjp(f)
+            if isinplace(prob)
+                vjp_cache = __similar(u)
+                deriv_op = @closure (du, u, fu, p) -> begin
+                    f.vjp(vjp_cache, fu, u, p)
+                    dot(du, vjp_cache)
+                end
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p))
+            end
+        elseif alg.autodiff !== nothing &&
+               ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
             autodiff = get_concrete_reverse_ad(
                 alg.autodiff, prob; check_reverse_mode = true)
             vjp_op = VecJacOperator(prob, fu, u; autodiff)
             if isinplace(prob)
-                g_cache = __similar(u)
-                grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
+                vjp_cache = __similar(u)
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
+            else
+                deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
+            end
+        else
+            autodiff = get_concrete_forward_ad(
+                alg.autodiff, prob; check_forward_mode = true)
+            jvp_op = JacVecOperator(prob, fu, u; autodiff)
+            if isinplace(prob)
+                jvp_cache = __similar(fu)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
             else
-                grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
+                deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
             end
         end
     end
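
Note on the dispatch above: for the merit function ϕ(α) = ½‖f(u + α du)‖², the directional derivative is dϕ(α) = ⟨f, J du⟩ = ⟨du, Jᵀ f⟩, so a single jvp (forward mode) or a single vjp (reverse mode) yields the same scalar, which is why either path works and forward mode can be preferred on cost grounds. A minimal standalone sketch verifying the identity with ForwardDiff (the residual `f` here is a hypothetical example, not code from this PR):

```julia
using ForwardDiff, LinearAlgebra

f(u) = [u[1]^2 + u[2] - 3, u[1] - u[2]^3]  # hypothetical residual
u, du = [1.0, 2.0], [0.5, -0.25]

J  = ForwardDiff.jacobian(f, u)
fu = f(u)

deriv_via_jvp = dot(fu, J * du)   # forward-mode path: dot(fu, jvp(du))
deriv_via_vjp = dot(du, J' * fu)  # reverse-mode path: dot(du, vjp(fu))
@assert deriv_via_jvp ≈ deriv_via_vjp

# Cross-check against the scalar derivative of ϕ(α) = ½‖f(u + α du)‖²:
ϕ(α) = sum(abs2, f(u .+ α .* du)) / 2
@assert ForwardDiff.derivative(ϕ, 0.0) ≈ deriv_via_jvp
```
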
@@ -143,33 +177,37 @@ function __internal_init(
         return @fastmath internalnorm(fu_cache)^2 / 2
     end

-    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
-        return dot(g₀, du)
+        return deriv_op(du, u_cache, fu_cache, p)
     end

-    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
+    ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
         @bb @. u_cache = u + α * du
         fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
         stats.nf += 1
-        g₀ = grad_op(u_cache, fu_cache, p)
+        deriv = deriv_op(du, u_cache, fu_cache, p)
         obj = @fastmath internalnorm(fu_cache)^2 / 2
-        return obj, dot(g₀, du)
+        return obj, deriv
     end

     return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
-        grad_op, u_cache, fu_cache, stats)
+        deriv_op, u_cache, fu_cache, stats)
 end

 function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
     ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
-    dϕ = @closure α -> cache.dϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
-    ϕdϕ = @closure α -> cache.ϕdϕ(
-        cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
+    if cache.deriv_op !== nothing
+        dϕ = @closure α -> cache.dϕ(
+            cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
+        ϕdϕ = @closure α -> cache.ϕdϕ(
+            cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
+    else
+        dϕ = @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α)
+        ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
+    end

     ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))
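
For the `AutoFiniteDiff` path, `deriv_op` is `nothing` and dϕ falls back to differencing ϕ itself, so each derivative costs extra residual evaluations rather than an AD pass. A minimal standalone sketch of that fallback (the quadratic `ϕ` is a stand-in for the line-search objective, not code from this PR):

```julia
using FiniteDiff

# Stand-in for ϕ(α) = ½‖f(u + α du)‖²; its exact derivative is α - 0.3.
ϕ(α) = (α - 0.3)^2 / 2

dϕ  = α -> FiniteDiff.finite_difference_derivative(ϕ, α)
ϕdϕ = α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))

@assert isapprox(dϕ(0.0), -0.3; atol = 1e-6)
```
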