Skip to content

Commit 7b8ae45

Browse files
Merge pull request #281 from FluxML/cl/rrule
cleanup AD
2 parents 828b017 + 205ec52 commit 7b8ae45

9 files changed

+25
-28
lines changed

src/NNlib.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module NNlib
33
using Pkg
44
using Requires
55
using ChainRulesCore
6+
import ChainRulesCore: rrule
67
using Base.Broadcast: broadcasted
78

89
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

src/activations.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ for (f, df) in UNARY_ACTS
256256
@eval @scalar_rule($f(x), $df)
257257

258258
pullback = Symbol(:broadcasted_, f, :_pullback)
259-
@eval function ChainRulesCore.rrule(::typeof(broadcasted),
260-
::typeof($f), x::Numeric)
259+
@eval function rrule(::typeof(broadcasted),
260+
::typeof($f), x::Numeric)
261261
Ω = $f.(x)
262262
function $pullback(Δ)
263263
NO_FIELDS, NO_FIELDS, @.(Δ * $df)
@@ -275,9 +275,9 @@ for (f, df1, df2) in BINARY_ACTS
275275
@eval @scalar_rule($f(x1, x2), ($df1, $df2))
276276

277277
pullback = Symbol(:broadcasted_, f, :_pullback)
278-
@eval function ChainRulesCore.rrule(::typeof(broadcasted),
279-
::typeof($f),
280-
x1::Numeric, x2::Numeric)
278+
@eval function rrule(::typeof(broadcasted),
279+
::typeof($f),
280+
x1::Numeric, x2::Numeric)
281281
Ω = $f.(x1, x2)
282282
function $pullback(Δ)
283283
NO_FIELDS, NO_FIELDS, @.(Δ * $df1), @.(Δ * $df2)

src/batched/batchedadjtrans.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::BatchedAdjOrTrans{T}) where {T} =
9393
Base.unsafe_convert(Ptr{T}, parent(A))
9494

9595
# Gradients
96-
function ChainRulesCore.rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
96+
function rrule(::typeof(batched_transpose), A::AbstractArray{<:Any,3})
9797
b_transpose_back(Δ) = (NO_FIELDS, batched_transpose(Δ))
9898
batched_transpose(A), b_transpose_back
9999
end
100-
function ChainRulesCore.rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
100+
function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3})
101101
b_adjoint_back(Δ) = (NO_FIELDS, batched_adjoint(Δ))
102102
batched_adjoint(A), b_adjoint_back
103103
end

src/batched/batchedmul.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1
5656
_batched_mul(storage_typejoin(A, B), A, B)
5757
end
5858

59-
function ChainRulesCore.rrule(
60-
::typeof(batched_mul),
61-
A::AbstractArray{S,3},
62-
B::AbstractArray{T,3},
63-
) where {S,T}
59+
function rrule(::typeof(batched_mul),
60+
A::AbstractArray{S,3},
61+
B::AbstractArray{T,3},
62+
) where {S,T}
63+
6464
function batched_mul_pullback(Δ)
6565
return (
6666
NO_FIELDS,

src/conv.jl

+3-7
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,7 @@ for front_name in (:conv, :∇conv_data, :∇conv_filter,
211211
end
212212

213213
for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
214-
pullback = Symbol(Dims, :_pullback)
215-
@eval function ChainRulesCore.rrule(::Type{$Dims}, args...; kwargs...)
216-
$pullback(Δ) = (NO_FIELDS, ntuple(_ -> DoesNotExist(), length(args))...)
217-
return $Dims(args...; kwargs...), $pullback
218-
end
214+
@eval @non_differentiable $Dims(::Any...)
219215
end
220216

221217
colmajor(x) = (is_strided(x) && Base.stride(x, 1) == 1) ? x : collect(x)
@@ -224,7 +220,7 @@ for conv in [:conv, :depthwiseconv]
224220
local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
225221
conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)
226222

227-
@eval function ChainRulesCore.rrule(::typeof($conv), x, w, cdims; kw...)
223+
@eval function rrule(::typeof($conv), x, w, cdims; kw...)
228224
function $conv_pullback(Δ)
229225
Δ = colmajor(Δ)
230226
return (
@@ -237,7 +233,7 @@ for conv in [:conv, :depthwiseconv]
237233
return $conv(x, w, cdims; kw...), $conv_pullback
238234
end
239235

240-
@eval function ChainRulesCore.rrule(::typeof($∇conv_data), x, w, cdims; kw...)
236+
@eval function rrule(::typeof($∇conv_data), x, w, cdims; kw...)
241237
function $∇conv_data_pullback(Δ)
242238
Δ = colmajor(Δ)
243239
return (

src/pooling.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ end
171171
for pool in [:maxpool, :meanpool]
172172
∇pool = Symbol(:∇, pool)
173173
pullback = Symbol(pool, :_pullback)
174-
@eval function ChainRulesCore.rrule(::typeof($pool), x, pdims::PoolDims; kw...)
174+
@eval function rrule(::typeof($pool), x, pdims::PoolDims; kw...)
175175
Ω = $pool(x, pdims; kw...)
176176
$pullback(Δ) = (NO_FIELDS, $∇pool(Δ, Ω, x, pdims; kw...), DoesNotExist())
177177
return Ω, $pullback

src/softmax.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ end
7676
∇softmax!(Δ, x; dims=1) = ∇softmax!(Δ, Δ, x, softmax(x, dims=dims); dims=dims)
7777
∇softmax!(out, Δ, x; dims=1) = ∇softmax!(out, Δ, x, softmax(x, dims=dims); dims=dims)
7878

79-
function ChainRulesCore.rrule(::typeof(softmax), xs; dims=1)
79+
function rrule(::typeof(softmax), xs; dims=1)
8080
y = softmax(xs; dims=dims)
8181
softmax_pullback(Δ) = (NO_FIELDS, ∇softmax(Δ, xs, y, dims=dims))
8282
return y, softmax_pullback
@@ -112,7 +112,7 @@ end
112112

113113
∇logsoftmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
114114
∇logsoftmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
115-
115+
116116
# Old 2-arg version recomputing forward
117117
∇logsoftmax(Δ, x; dims=1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims=dims); dims=dims)
118118
∇logsoftmax!(Δ, x; dims=1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims=dims); dims=dims)
@@ -123,7 +123,7 @@ function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
123123
out .= Δ .- sum(Δ, dims = dims) .* exp.(y)
124124
end
125125

126-
function ChainRulesCore.rrule(::typeof(logsoftmax), xs; dims=1)
126+
function rrule(::typeof(logsoftmax), xs; dims=1)
127127
y = logsoftmax(xs; dims=dims)
128128
logsoftmax_pullback(Δ) = (NO_FIELDS, ∇logsoftmax(Δ, xs, y, dims=dims))
129129
return y, logsoftmax_pullback

src/upsample.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function ∇upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}
7373
reshape(mid, outsize)
7474
end
7575

76-
function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
76+
function rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
7777
Ω = upsample_nearest(x, s)
7878
upsample_nearest_pullback(Δ) = (NO_FIELDS, ∇upsample_nearest(Δ, s), DoesNotExist())
7979
return Ω, upsample_nearest_pullback
@@ -249,7 +249,7 @@ function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,
249249
return dx
250250
end
251251

252-
function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
252+
function rrule(::typeof(upsample_bilinear), x; size)
253253
Ω = upsample_bilinear(x; size=size)
254254
function upsample_bilinear_pullback(Δ)
255255
(NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))

test/test_utils.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ gradtest(f, dims::IntOrTuple...; kw...) =
88
Compare numerical gradient and automatic gradient
99
given by Zygote. `f` has to be a scalar valued function.
1010
11-
Applies also `ChainRulesTestUtils.rrule_test` if the rrule for `f` is explicitly defined.
11+
Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly defined.
1212
"""
1313
function gradtest(f, xs...; atol=1e-6, rtol=1e-6, fkwargs=NamedTuple(),
1414
check_rrule=false,
@@ -19,8 +19,8 @@ function gradtest(f, xs...; atol=1e-6, rtol=1e-6, fkwargs=NamedTuple(),
1919
y = f(xs...; fkwargs...)
2020
simil(x) = x isa Number ? randn(rng, typeof(x)) : randn!(rng, similar(x))
2121
ȳ = simil(y)
22-
xx̄s = [(x, simil(x)) for x in xs]
23-
rrule_test(f, ȳ, xx̄s...; fkwargs=fkwargs)
22+
xx̄s = [x ⊢ simil(x) for x in xs]
23+
test_rrule(f, xx̄s...; fkwargs=fkwargs, output_tangent=ȳ)
2424
end
2525

2626
if check_broadcast

0 commit comments

Comments
 (0)