@@ -221,31 +221,27 @@ _zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2)))
221
221
_zero (:: AbstractMatrix , d:: AbstractMatrix ) = zero (d)
222
222
_zero (:: Any , d:: Any ) = zero (d)
223
223
224
+ @inline _dot (x, y) = dot (x, y)
225
+ @inline function _dot (x:: AbstractVector , y:: UniformScaling )
226
+ @assert length (x) == 1
227
+ return @inbounds dot (x[1 ], y. λ)
228
+ end
229
+ @inline function _dot (x:: AbstractVector , y:: AbstractMatrix )
230
+ @assert size (y, 2 ) == 1
231
+ return dot (x, y)
232
+ end
233
+
224
234
function pullback_function (ab:: AbstractBackend , f, xs... )
225
235
return (ws) -> begin
226
- jacs = jacobian (lowest (ab), (xs... ,) -> begin
236
+ return gradient (lowest (ab), (xs... ,) -> begin
227
237
vs = f (xs... )
228
238
if ws isa Tuple
229
239
@assert length (vs) == length (ws)
230
- return sum (zip (vs, ws)) do v, w
231
- if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector
232
- return w' * v
233
- else
234
- # for arbitrary arrays
235
- return dot (w, v)
236
- end
237
- end
240
+ return sum (Base. splat (_dot), zip (ws, vs))
238
241
else
239
- w, v = ws, vs
240
- if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector
241
- return w' * v
242
- else
243
- # for arbitrary arrays
244
- return dot (w, v)
245
- end
242
+ return _dot (vs, ws)
246
243
end
247
244
end , xs... )
248
- return adjoint .(jacs)
249
245
end
250
246
end
251
247
function value_and_pullback_function (
0 commit comments