Skip to content

Commit 97a2489

Browse files
committed
Improvements to pullback_function
1 parent 3eb3fd1 commit 97a2489

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

src/AbstractDifferentiation.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -221,31 +221,27 @@ _zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2)))
221221
_zero(::AbstractMatrix, d::AbstractMatrix) = zero(d)
222222
_zero(::Any, d::Any) = zero(d)
223223

224+
@inline _dot(x, y) = dot(x, y)
225+
@inline function _dot(x::AbstractVector, y::UniformScaling)
226+
@assert length(x) == 1
227+
return @inbounds dot(x[1], y.λ)
228+
end
229+
@inline function _dot(x::AbstractVector, y::AbstractMatrix)
230+
@assert size(y, 2) == 1
231+
return dot(x, y)
232+
end
233+
224234
function pullback_function(ab::AbstractBackend, f, xs...)
225235
return (ws) -> begin
226-
jacs = jacobian(lowest(ab), (xs...,) -> begin
236+
return gradient(lowest(ab), (xs...,) -> begin
227237
vs = f(xs...)
228238
if ws isa Tuple
229239
@assert length(vs) == length(ws)
230-
return sum(zip(vs, ws)) do v, w
231-
if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector
232-
return w' * v
233-
else
234-
# for arbitrary arrays
235-
return dot(w, v)
236-
end
237-
end
240+
return sum(Base.splat(_dot), zip(ws, vs))
238241
else
239-
w, v = ws, vs
240-
if w isa Union{AbstractMatrix, UniformScaling} && v isa AbstractVector
241-
return w' * v
242-
else
243-
# for arbitrary arrays
244-
return dot(w, v)
245-
end
242+
return _dot(vs, ws)
246243
end
247244
end, xs...)
248-
return adjoint.(jacs)
249245
end
250246
end
251247
function value_and_pullback_function(

0 commit comments

Comments
 (0)