Skip to content

Commit

Permalink
wrap_chainrules_input for arrays of Ref (#1103)
Browse files Browse the repository at this point in the history
* wrap_chainrules_input for arrays of Ref

* z2d too, for rrule_via_ad

* test from JuliaDiff/ChainRulesCore.jl#440

* add test from JuliaDiff/ChainRules.jl#537

* more tests related to CRC types

* union nothing, fix one case

* comments
  • Loading branch information
mcabbott authored Oct 16, 2021
1 parent c9446c8 commit 5ae5b4f
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.28"
version = "0.6.29"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
15 changes: 13 additions & 2 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ for T_outer in (:Tuple, :NamedTuple)
ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
end
end
# An array whose elements are ChainRules `Tangent`s is converted one element at a time.
# Could `reinterpret` instead of broadcasting here -- TODO
@inline function wrap_chainrules_output(xs::AbstractArray{<:ChainRules.Tangent})
    return broadcast(wrap_chainrules_output, xs)
end

"""
wrap_chainrules_input(x)
Expand All @@ -130,6 +132,11 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
end
# For mutable types, including x=Ref(1), Zygote makes Ref{Any}(::NamedTuple),
# so unwrap the reference and convert whatever it holds.
@inline function wrap_chainrules_input(x::Ref)
    return wrap_chainrules_input(x[])
end

# Arrays of such gradients are converted one element at a time; the `Union{Nothing, ...}`
# variants accept arrays in which some entries are `nothing`.
# Could `reinterpret` instead of broadcasting here -- TODO
@inline function wrap_chainrules_input(xs::AbstractArray{<:Ref})
    return broadcast(wrap_chainrules_input, xs)
end
@inline function wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:Ref}})
    # no test invented for this
    return broadcast(wrap_chainrules_input, xs)
end
@inline function wrap_chainrules_input(xs::AbstractArray{<:NamedTuple})
    return broadcast(wrap_chainrules_input, xs)
end
@inline function wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:NamedTuple}})
    return broadcast(wrap_chainrules_input, xs)
end

"""
_project(x, dx)
Expand All @@ -139,6 +146,8 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
    # Build the projector for the primal first, then convert the Zygote-style gradient
    # into ChainRules form, project it, and convert the result back again.
    projector = ProjectTo(x)
    # Note that this use of `wrap_chainrules_input` has the primal `x`, so could
    # avoid making `Tangent{Any}`, perhaps via `zygote2differential` -- TODO.
    tangent = wrap_chainrules_input(dx)
    projected = projector(tangent)
    return wrap_chainrules_output(projected)
end

Expand Down Expand Up @@ -224,9 +233,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
end

"""
zygote2differential(x)
zygote2differential(dx, primal)
Convert input `x` from the Zygote format to the ChainRules differential types.
Convert input `dx` from the Zygote format to the ChainRules differential types.
"""
zygote2differential(x, primal) = z2d(x, primal)
zygote2differential(::Nothing, ::Any) = NoTangent()
Expand All @@ -235,6 +244,7 @@ zygote2differential(t::Tuple, primal) = (@warn "primal should be a tuple, not $p
# Fallback: anything without a more specific method is passed through untouched.
function z2d(x, ::Any)
    return x
end

# Zygote's `nothing` corresponds to ChainRules' `NoTangent()`.
function z2d(::Nothing, ::Any)
    return NoTangent()
end

# Numeric arrays are returned as they are.
function z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T
    return a
end

# Other arrays are converted element by element against the matching primal entries.
# Could probably `reinterpret` instead of broadcasting here -- TODO
function z2d(a::AbstractArray, primal::AbstractArray{T}) where T
    return broadcast(z2d, a, primal)
end
# Note: this should never be hit if we are converting things right, but it seems to be
# happening in the wild for sufficiently weird functions/types.
Expand All @@ -254,3 +264,4 @@ function z2d(t::NamedTuple, primal)
tp::NamedTuple = map(z2d, complete_t, primals)
return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
end
# Gradients of mutable structs arrive wrapped in a `Ref`: unwrap, then convert the contents.
function z2d(dx::Ref, primal)
    return z2d(dx[], primal)
end
1 change: 1 addition & 0 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ unbroadcast(x::Number, x̄) = accum_sum(x̄)
# One-element tuple: reduce the whole cotangent with `accum_sum` and re-wrap it as a 1-tuple.
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
# `Ref` argument: the reduced cotangent is returned as a NamedTuple with field `x`.
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
# Longer tuples: when the cotangent already has the same length it is used directly,
# otherwise it is reduced over dimensions `2:ndims(x̄)` before conversion to a tuple.
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
# A `nothing` cotangent stays `nothing`.
unbroadcast(x::Tuple, x̄::Nothing) = nothing
# Without this method a one-element tuple paired with `nothing` is ambiguous between the
# `Tuple{<:Any}` method and the `(::Tuple, ::Nothing)` method.
unbroadcast(x::Tuple{<:Any}, x̄::Nothing) = nothing

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

Expand Down
7 changes: 7 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,13 @@ end
@test gradient(x -> x.x^2 + x.x, Ref(3)) === ((x = 7.0,),)
@test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) === ((x = 8.0,),)

# Array of mutables: each `Ref` in the input should get its own `(x = ...,)` gradient.
@test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
# Same computation written as a reduction over the composed function `abs2∘getindex`.
@test gradient(x -> sum(abs2∘getindex, x), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]

# Only the first element contributes to the output here.
@test gradient(x -> (getindex.(x).^2)[1], Ref.(1:3))[1][1] == (x=2.0,)  # rest are (x = 0.0,), but nothing would be OK too
@test gradient(x -> (prod.(getindex.(x)))[1], Ref.(eachcol([1 2; 3 4])))[1][1] == (x = [3.0, 1.0],)

# Broadcasting over Ref is handled specially. Tested elsewhere too.
@test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),)
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
Expand Down
74 changes: 74 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ end
@test gradtest(x->fill(first(x), N), randn(rng, 1))
@test gradtest(x->fill(first(x), N, M), randn(rng, 1))
@test gradtest(x->fill(first(x), N, M, P), randn(rng, 1))

# fill(struct, ...) handled by ChainRules after
# https://github.com/FluxML/Zygote.jl/pull/1051
@test gradient(x -> fill(x, 3)[1][1], (1,2)) === ((1.0, nothing),)
@test gradient(x -> fill(x, 3)[1].a, (a=1, b=2)) == ((a=1.0, b=nothing),) # 1 not 1.0
end

@testset "circshift" begin
Expand Down Expand Up @@ -344,6 +349,20 @@ end
end
end

# Gradients of `map`/`sum` when tuples appear as array elements or as arguments to `map`.
# The expected values use plain tuples and `nothing`, never ChainRules types.
@testset "map and tuples" begin
# arrays of tuples, ChainRules's Tangent should not escape
@test gradient(x -> sum(map(first, x)), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],)
@test gradient(x -> sum(first, x), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],)

# `map` over a tuple `x` paired with a tuple, then with a vector, and finally a vector `x`
# paired with a tuple (the last case is currently marked broken).
@test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) == ((1.0, nothing, nothing),)
@test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
@test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],) # Gradient [1.0, 0.0, 0.0] should be a tuple, since v0.6.0 at least

# mismatched lengths, should zip
# Both cases are currently marked broken; the trailing comments record the observed errors.
@test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) # BoundsError: attempt to access 3-element Vector{Float64} at index [4]
@test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),) # DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (3,)")
end

@testset "Alternative Pmap Dispatch" begin
cache_and_map(f,xs...) = pmap(f, CachingPool(workers()), xs...; batch_size = 1)
@test gradtest(xs -> sum(cache_and_map(x -> x^2, xs)), rand(2,3))
Expand Down Expand Up @@ -1783,3 +1802,58 @@ end
# https://github.com/FluxML/Zygote.jl/issues/996
# `rand.()` is a broadcast call with no arguments, added elementwise to `x`;
# the gradient with respect to `x` is expected to be all ones.
a = rand(3)
@test Zygote.gradient(x->sum(x .+ rand.()), a) == (ones(3),)

# Regression test: summing over an array of small containers built from two input vectors.
@testset "CRC issue 440" begin
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/440
# `f` collects 2-element vectors and `g` collects 2-tuples; both add up every entry of
# `x` and `y`, so each gradient is expected to be a vector of ones.
f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
g(x,y) = sum(sum, [(x[i],y[i]) for i=1:length(x)])
@test gradient(f, rand(3), rand(3)) == ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0])
@test gradient(g, rand(3), rand(3)) == ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0])
end

# Regression test: gradients through arrays of a user-defined callable struct should come
# back as NamedTuples and plain arrays.
@testset "CR issue 537" begin
    # https://github.com/JuliaDiff/ChainRules.jl/issues/537
    # Model type from the issue: a two-field parametric struct that is also callable.
    struct BV{F,T}
        A::F
        α::T
    end
    # Scaling multiplies the `A` field by `c`; every field after the first is copied over.
    function Base.:*(c, km::BV)
        new_A = c*km.A
        other_params = getfield.([km], propertynames(km))[2:end]
        BV(new_A, other_params...)
    end
    # Calling a `BV` evaluates `A .* exp.(exp_arg)`, where `ox` selects which of the two
    # exponent expressions is used.
    function (bv::BV)(V_app, ox::Bool; kT::Real = 0.026)
        local exp_arg
        if ox
            exp_arg = (bv.α .* V_app) ./ kT
        else
            exp_arg = -((1 .- bv.α) .* V_app) ./ kT
        end
        bv.A .* exp.(exp_arg)
    end
    # Custom adjoint for the parametric constructor: the pullback returns the cotangent's
    # fields as a tuple, one entry per constructor argument.
    Zygote.@adjoint function BV{T,S}(A, α) where {T,S}
        BV(A, α), Δ -> begin
            (Δ.A, Δ.α)
        end
    end
    bv = BV(1.0, 0.1)
    I_vals, V = rand(81), rand(81)

    # Gradient through `fill` of a struct followed by `map` over the resulting array.
    g2 = gradient(V, bv) do V, bv
        res = fill(bv, length(V))
        r1 = map((m,v) -> m(v, true), res, V)
        r2 = map((m,v) -> m(v, false), res, V)
        sum(r1 .- r2)
    end
    @test size(g2[1]) == size(V)
    @test g2[2] isa NamedTuple
    @test g2[2].A isa Number

    # Gradient through an array of structs created inside `map`.
    g1 = gradient(bv, V) do bv, V
        res = map(x -> x * bv, V)
        sum(x -> x.A, res)
    end
    @test g1[1] isa NamedTuple
    @test g1[1].A isa Number
    @test size(g1[2]) == size(V)
end

2 comments on commit 5ae5b4f

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/46889

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.29 -m "<description of version>" 5ae5b4f2933e87923a567f13e1c298e26b954716
git push origin v0.6.29

Please sign in to comment.