fill assumes numerical arrays #537

Closed
DhairyaLGandhi opened this issue Oct 5, 2021 · 4 comments

@DhairyaLGandhi
Contributor

DhairyaLGandhi commented Oct 5, 2021

With Zygote, we habitually get NamedTuples (or Tangents) as tangents for structures. However, something like the following setup does not work well (note that this is an incomplete demo, but it describes the basic issue):

struct MyType
  a
end

Base.:(+)(a::MyType, b::MyType) = a.a + b.a
gradient((m,x) -> sum(fill(m, length(x))), MyType(3), 5)

This usually points to

fill_pullback
    @ ~/.julia/packages/ChainRules/RyXef/src/rulesets/Base/array.jl

The stack trace from the real use case looks something like this:

ERROR: MethodError: no method matching +(::NamedTuple{(:A, :α), Tuple{Float64, Float64}}, ::NamedTuple{(:A, :α), Tuple{Float64, Float64}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  +(::ChainRulesCore.Tangent{P, T} where T, ::P) where P at /Users/dhairyagandhi/.julia/packages/ChainRulesCore/1L9My/src/tangent_arithmetic.jl:145
  +(::ChainRulesCore.AbstractThunk, ::Any) at /Users/dhairyagandhi/.julia/packages/ChainRulesCore/1L9My/src/tangent_arithmetic.jl:121
  ...
Stacktrace:
  [1] add_sum(x::NamedTuple{(:A, :α), Tuple{Float64, Float64}}, y::NamedTuple{(:A, :α), Tuple{Float64, Float64}})
    @ Base ./reduce.jl:24
  [2] mapreduce_impl(f::typeof(identity), op::typeof(Base.add_sum), A::Vector{NamedTuple{(:A, :α), Tuple{Float64, Float64}}}, ifirst::Int64, ilast::Int64, blksize::Int64)
    @ Base ./reduce.jl:242
  [3] mapreduce_impl
    @ ./reduce.jl:257 [inlined]
  [4] _mapreduce
    @ ./reduce.jl:415 [inlined]
  [5] _mapreduce_dim
    @ ./reducedim.jl:318 [inlined]
  [6] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
  [7] mapreduce
    @ ./reducedim.jl:310 [inlined]
  [8] #_sum#682
    @ ./reducedim.jl:878 [inlined]
  [9] _sum
    @ ./reducedim.jl:878 [inlined]
 [10] #_sum#681
    @ ./reducedim.jl:877 [inlined]
 [11] _sum
    @ ./reducedim.jl:877 [inlined]
 [12] #sum#679
    @ ./reducedim.jl:873 [inlined]
 [13] sum
    @ ./reducedim.jl:873 [inlined]
 [14] fill_pullback
    @ ~/.julia/packages/ChainRules/RyXef/src/rulesets/Base/array.jl:342 [inlined]
 [15] ZBack
    @ ~/Downloads/arpa/battery/Zygote.jl/src/compiler/chainrules.jl:168 [inlined]
 [16] Pullback
    @ ./REPL[140]:1 [inlined]
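
For illustration, the failing step in the trace above (`sum` over a `Vector` of NamedTuples) can be reproduced in plain Julia without Zygote or ChainRules. The sketch below is only an assumption-laden demonstration: the values are made up, and `add_fields` is a hypothetical helper, not something either package provides.

```julia
# Cotangents shaped like the ones in the trace above (values are made up).
grads = [(A = 1.0, α = 0.5), (A = 2.0, α = 0.25), (A = 3.0, α = 0.25)]

# `sum` needs `+`, which Base does not define for NamedTuples,
# so this raises the same kind of MethodError as in the trace.
sum_failed = try
    sum(grads)
    false
catch err
    err isa MethodError
end

# Hypothetical helper: add two NamedTuples with identical field names,
# field by field. `map` over NamedTuples keeps the field names.
add_fields(x::NamedTuple{N}, y::NamedTuple{N}) where {N} = map(+, x, y)

# A field-wise reduction succeeds where `sum` does not.
total = reduce(add_fields, grads)
```

This only shows why the reduction fails; it does not suggest that `fill_pullback` should special-case NamedTuples.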
@willtebbutt
Member

This looks like Zygote types are making their way into ChainRules rules -- if the above involved Tangents, I think this example would be fine because + is defined on Tangents. Could you provide an MWE of the actual use-case?
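
As a small sketch of the point about `+` being defined on Tangents, assuming ChainRulesCore is installed: the struct `Params` below is hypothetical and only mirrors the `(:A, :α)` fields seen in the trace.

```julia
using ChainRulesCore

# Hypothetical primal type with the same field names as in the trace.
struct Params
    A::Float64
    α::Float64
end

# Two structural tangents for `Params` (values are made up).
t1 = Tangent{Params}(A = 1.0, α = 0.5)
t2 = Tangent{Params}(A = 2.0, α = 0.25)

# Tangents of the same primal type add field by field,
# so a reduction such as `sum` over them is well defined.
t = t1 + t2
total = sum([t1, t2])
```

With NamedTuples in place of the Tangents, the same `+` call would throw the MethodError shown in the issue.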

@DhairyaLGandhi
Contributor Author

Actually, it's the other way around: ChainRules types are making it into Zygote.

julia> struct BV{F,T}
         A::F
         α::T
       end

julia> function *(c, km::BV)
           new_A = c*km.A
           other_params = getfield.([km], propertynames(km))[2:end]
           BV(new_A, other_params...)
       end

julia> gradient(bv, V) do bv, V
         res = map(x -> x * bv, V)
         sum(x -> x.A, res)
       end
ERROR: Need an adjoint for constructor BV{Float64, Float64}. Gradient is of type ChainRulesCore.Tangent{BV{Float64, Float64}, NamedTuple{(:A, :α), Tuple{Float64, ChainRulesCore.NoTangent}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{BV{Float64, Float64}, Nothing, false})(Δ::ChainRulesCore.Tangent{BV{Float64, Float64}, NamedTuple{(:A, :α), Tuple{Float64, ChainRulesCore.NoTangent}}})
    @ Zygote ~/Downloads/arpa/battery/Zygote.jl/src/lib/lib.jl:323
  [3] (::Zygote.var"#1768#back#219"{Zygote.Jnew{BV{Float64, Float64}, Nothing, false}})(Δ::ChainRulesCore.Tangent{BV{Float64, Float64}, NamedTuple{(:A, :α), Tuple{Float64, ChainRulesCore.NoTangent}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67

Notice that the gradient type is a Tangent instead of a NamedTuple. I think this should be retained as a NamedTuple, since that's what we use in Zygote.

Next,

julia> bv
BV{Float64, Float64}(1.0, 0.1)

julia> I_vals, V = rand(81), rand(81)

julia> function (bv::BV)(V_app, ox::Bool; kT::Real = 0.026)
           local exp_arg
           if ox
               exp_arg = (bv.α .* V_app) ./ kT
           else
               exp_arg = -((1 .- bv.α) .* V_app) ./ kT
           end
           bv.A .* exp.(exp_arg)
       end

julia> Zygote.@adjoint function BV{T,S}(A, α) where {T,S}
         BV(A, α), Δ -> begin
           (Δ.A, Δ.α)
         end
       end

julia> gradient(V, bv) do V, bv
         res = fill(bv, length(V))
         r1 = map((m,v) -> m(v, true), res, V)
         r2 = map((m,v) -> m(v, false), res, V)
         sum(r1 .- r2)
       end
ERROR: MethodError: no method matching +(::NamedTuple{(:A, :α), Tuple{Float64, Float64}}, ::NamedTuple{(:A, :α), Tuple{Float64, Float64}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  +(::ChainRulesCore.Tangent{P, T} where T, ::P) where P at /Users/dhairyagandhi/.julia/packages/ChainRulesCore/1L9My/src/tangent_arithmetic.jl:145
  +(::ChainRulesCore.AbstractThunk, ::Any) at /Users/dhairyagandhi/.julia/packages/ChainRulesCore/1L9My/src/tangent_arithmetic.jl:121
  ...
Stacktrace:
  [1] add_sum(x::NamedTuple{(:A, :α), Tuple{Float64, Float64}}, y::NamedTuple{(:A, :α), Tuple{Float64, Float64}})
    @ Base ./reduce.jl:24
  [2] mapreduce_impl(f::typeof(identity), op::typeof(Base.add_sum), A::Vector{NamedTuple{(:A, :α), Tuple{Float64, Float64}}}, ifirst::Int64, ilast::Int64, blksize::Int64)
    @ Base ./reduce.jl:242
  [3] mapreduce_impl
    @ ./reduce.jl:257 [inlined]
  [4] _mapreduce
    @ ./reduce.jl:415 [inlined]
  [5] _mapreduce_dim
    @ ./reducedim.jl:318 [inlined]
  [6] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
  [7] mapreduce
    @ ./reducedim.jl:310 [inlined]
  [8] #_sum#682
    @ ./reducedim.jl:878 [inlined]
  [9] _sum
    @ ./reducedim.jl:878 [inlined]
 [10] #_sum#681
    @ ./reducedim.jl:877 [inlined]
 [11] _sum
    @ ./reducedim.jl:877 [inlined]
 [12] #sum#679
    @ ./reducedim.jl:873 [inlined]
 [13] sum
    @ ./reducedim.jl:873 [inlined]
 [14] fill_pullback
    @ ~/.julia/packages/ChainRules/RyXef/src/rulesets/Base/array.jl:342 [inlined]

mcabbott added a commit to mcabbott/Zygote.jl that referenced this issue Oct 16, 2021
mcabbott added a commit to FluxML/Zygote.jl that referenced this issue Oct 16, 2021
* wrap_chainrules_input for arrays of Ref

* z2d too, for rrule_via_ad

* test from JuliaDiff/ChainRulesCore.jl#440

* add test from JuliaDiff/ChainRules.jl#537

* more tests related to CRC types

* union nothing, fix one case

* comments
@willtebbutt
Member

Sorry, I've lost track of where we are on this.

Either way, this looks to me like both things are happening -- somehow a Tangent is making its way into Zygote's _pullback for getfield (in your first example), and somehow a NamedTuple is making its way inside ChainRules' implementation of the pullback for fill in the second example. Presumably the root problem is that the function that converts Zygote types to ChainRules types is broken somewhere -- if you could figure out exactly which types aren't getting converted properly, presumably it would be straightforward to find a fix?
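
To make the idea of such a conversion concrete, here is a minimal sketch, assuming ChainRulesCore is installed. `to_tangent` is a hypothetical helper written for this illustration; it is not Zygote's actual conversion code, and the values are made up.

```julia
using ChainRulesCore

# Same shape as the BV struct used earlier in this thread.
struct BV{F,T}
    A::F
    α::T
end

# Hypothetical helper: wrap a Zygote-style NamedTuple gradient
# in a ChainRulesCore Tangent for the primal type P.
to_tangent(::Type{P}, nt::NamedTuple) where {P} = Tangent{P}(; nt...)

# An array of NamedTuple gradients, like the one that reached fill_pullback.
zygote_style = [(A = 1.0, α = 0.5), (A = 2.0, α = 0.25)]

# Convert element by element before handing the array to a rule.
converted = map(g -> to_tangent(BV{Float64,Float64}, g), zygote_style)

# After conversion the reduction works, because `+` is defined on Tangents.
total = sum(converted)
```

The key design point is that the conversion has to recurse into arrays; converting only top-level gradients would leave the elements as NamedTuples.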

@mcabbott
Member

Zygote was leaking types, but this has been fixed:

julia> using Zygote

julia> struct MyType
         a
       end

julia> Base.:(+)(a::MyType, b::MyType) = a.a + b.a

julia> gradient((m,x) -> sum(fill(m, length(x))), MyType(3), 5)
ERROR: Output should be scalar; gradients are not defined for output MyType(3)
julia> import Base: *

julia> struct BV{F,T}
         A::F
         α::T
       end

julia> function *(c, km::BV)
           new_A = c*km.A
           other_params = getfield.([km], propertynames(km))[2:end]
           BV(new_A, other_params...)
       end

julia> gradient(bv, V) do bv, V
         res = map(x -> x * bv, V)
         sum(x -> x.A, res)
       end
ERROR: UndefVarError: bv not defined

julia> bv = BV{Float64, Float64}(1.0, 0.1)  # guess from below

julia> gradient(bv, V) do bv, V
         res = map(x -> x * bv, V)
         sum(x -> x.A, res)
       end
((A = 35.543737652251295, α = 0.0), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
julia> bv = BV{Float64, Float64}(1.0, 0.1)
BV{Float64, Float64}(1.0, 0.1)

julia> I_vals, V = rand(81), rand(81)
([0.3130531731936679, 0.2025924962417458, 0.21174767804217698, 0.4539097981190803, 0.39827389603668917, 0.43692968092128415, 0.16045248902025988, 0.06113605116594267, 0.14445265205623414, 0.5233820003105387  …  0.2072622659671255, 0.8098058667333574, 0.3521586635724596, 0.5340492828292296, 0.3432556528629884, 0.16628313443243026, 0.7134781576254401, 0.17215925260695775, 0.7455139933572823, 0.672548760351954], [0.2883307396473991, 0.24551806102339846, 0.3228117657836165, 0.10833116411468247, 0.1600241112376699, 0.3218476712372844, 0.5850353613376229, 0.6153612656652732, 0.789033156717247, 0.48151204689160054  …  0.41296557386701727, 0.9749267130022032, 0.6402035085230557, 0.08184217195989874, 0.061046100917580004, 0.06776400045619169, 0.20345882363878776, 0.31457444646040167, 0.27847853195575545, 0.729752534889228])

julia> function (bv::BV)(V_app, ox::Bool; kT::Real = 0.026)
           local exp_arg
           if ox
               exp_arg = (bv.α .* V_app) ./ kT
           else
               exp_arg = -((1 .- bv.α) .* V_app) ./ kT
           end
           bv.A .* exp.(exp_arg)
       end

julia> Zygote.@adjoint function BV{T,S}(A, α) where {T,S}
         BV(A, α), Δ -> begin
           (Δ.A, Δ.α)
         end
       end

julia> gradient(V, bv) do V, bv
         res = fill(bv, length(V))
         r1 = map((m,v) -> m(v, true), res, V)
         r2 = map((m,v) -> m(v, false), res, V)
         sum(r1 .- r2)
       end
([11.660131381496205, 9.89556629393103, 13.312372244092002, 6.648306720086004, 7.253474495300132, 13.26311890996391, 36.49625507471707, 41.011304456528116, 79.98319566469956, 24.509067242244758  …  18.829106984885833, 163.49653131249448, 45.12312365680657, 7.3056931645632215, 9.047591162664364, 8.30688262732026, 8.441807835110536, 12.897396234237767, 11.227270074026485, 63.676410156537536], (A = 738.218041317488, α = 19958.20482726184))
(jl_dAz9hk) pkg> st
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_dAz9hk/Project.toml`
  [e88e6eb3] Zygote v0.6.30


3 participants