ProjectTo called on tuples #440

Closed
nmheim opened this issue Aug 17, 2021 · 4 comments
Labels: ProjectTo (related to the projection functionality), Structural Tangent (related to the `Tangent` type for structured (composite) values)

Comments

nmheim commented Aug 17, 2021

The example below differentiates two functions: f collects vectors [x[i], y[i]], while g collects tuples (x[i], y[i]) instead.

using Zygote

x = rand(3)
y = rand(3)

f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
Zygote.gradient(x->f(x,y), x) |> display

g(x,y) = sum(sum, [(x[i],y[i]) for i=1:length(x)])
Zygote.gradient(x->g(x,y), x) |> display

The gradient of f is computed correctly, but the gradient of g throws the following error:

([1.0, 1.0, 1.0],)
ERROR: LoadError: MethodError: no method matching ChainRulesCore.ProjectTo(::Tuple{Float64, Float64})
Closest candidates are:
  ChainRulesCore.ProjectTo(::LinearAlgebra.UnitLowerTriangular) at /home/niklas/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:336
  ChainRulesCore.ProjectTo(::LinearAlgebra.UpperTriangular) at /home/niklas/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:336
  ChainRulesCore.ProjectTo(::LinearAlgebra.SymTridiagonal{T, V} where V<:AbstractVector{T}) where T<:Number at /home/niklas/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:373
  ...
Stacktrace:
  [1] iterate
    @ ./generator.jl:47 [inlined]
  [2] _collect
    @ ./array.jl:691 [inlined]
  [3] collect_similar
    @ ./array.jl:606 [inlined]
  [4] map
    @ ./abstractarray.jl:2294 [inlined]
  [5] ChainRulesCore.ProjectTo(xs::Vector{Tuple{Float64, Float64}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:192
  [6] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(sum), f::Function, xs::Vector{Tuple{Float64, Float64}}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/5iZFH/src/rulesets/Base/mapreduce.jl:74
  [7] rrule
    @ ~/.julia/packages/ChainRules/5iZFH/src/rulesets/Base/mapreduce.jl:69 [inlined]
  [8] chain_rrule
    @ ~/.julia/packages/Zygote/l3aNG/src/compiler/chainrules.jl:152 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(sum), ::typeof(sum), ::Vector{Tuple{Float64, Float64}})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:9
 [11] _pullback
    @ ~/_asdf.jl:10 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(g), ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/_asdf.jl:11 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::var"#7#8", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
 [15] _pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:34
 [16] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:40
 [17] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:75
 [18] top-level scope
    @ ~/_asdf.jl:11
 [19] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [20] top-level scope
    @ REPL[1]:1
in expression starting at /home/niklas/_asdf.jl:11

oxinabox (Member) commented Aug 17, 2021

I wonder how far we can go with

ProjectTo(x::Any) = ProjectTo{Tangent}()
(proj::ProjectTo{Tangent})(x::Tangent) = x  # assume right Tangent was given

This would not generally support projecting natural tangent types (such as Symmetric) onto structural Tangents.
It would only let types that have no ProjectTo of their own pass a structural Tangent through unchanged.

The assumption that goes with this is that any type without a ProjectTo(::T) method has no natural tangent that we would rather use, so we expect to get a structural tangent for it.
That seems a fairly OK assumption.
The only types that would not take this fallback right now are those we have said have no tangent space, like Bool,
and those we have defined preferred natural tangents for, like subtypes of Number and AbstractArray.

(This may or may not be a wise idea. I have probably argued against it in the past 😁)

A method specifically for tuples is probably safer,
but that might run into trouble with other iterators.
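
The fallback idea above can be sketched in plain Julia without loading ChainRulesCore. This is only a toy model under stated assumptions: `ToyProject`, `ToyTangent`, and `toy_project` are hypothetical stand-ins for `ProjectTo`, `Tangent`, and the `ProjectTo` constructor, and they do not reproduce the real package's behaviour.

```julia
# Toy model of the proposed catch-all projector.
# All names here are hypothetical; none of them come from ChainRulesCore.

# Structural tangent for a primal of type P, holding one entry per field.
struct ToyTangent{P,T}
    fields::T
end
ToyTangent{P}(fields...) where {P} = ToyTangent{P,typeof(fields)}(fields)

# Projector tagged with the kind of tangent it expects.
struct ToyProject{K} end

# Types with a preferred natural tangent get their own method.
toy_project(x::Number) = ToyProject{Number}()
toy_project(x::AbstractArray) = ToyProject{AbstractArray}()
# Proposed fallback: everything else expects a structural tangent.
toy_project(x::Any) = ToyProject{ToyTangent}()

(p::ToyProject{Number})(dx::Number) = dx
(p::ToyProject{AbstractArray})(dx::AbstractArray) = dx
# Assume the right structural tangent was given and pass it on unchanged.
(p::ToyProject{ToyTangent})(dx::ToyTangent) = dx

# A tuple has no method of its own, so it takes the fallback.
pt = toy_project((1.0, 2.0))
dx = ToyTangent{Tuple{Float64,Float64}}(1.0, 1.0)
passed = pt(dx)
```

In this sketch a structural tangent passes through the fallback untouched, while handing the same projector a natural tangent (for example a plain tuple) raises a `MethodError`, which mirrors the "not generally supporting" part of the proposal.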

oxinabox added the ProjectTo and Structural Tangent labels Aug 17, 2021
mcabbott (Member) commented

The narrow bug is here:

https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L67

This rule accepts an array of tuples, but can't actually handle them.
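
Frames [4] and [5] of the stack trace in the report suggest the failure comes from mapping the projector constructor over the array's elements. A minimal plain-Julia sketch of that failure mode follows; `toy_elementwise` is a hypothetical helper, not the ChainRulesCore implementation, and returning `identity` for numbers is an assumption made only to keep the example small.

```julia
# Toy model of building one projector per array element.
# `toy_elementwise` is hypothetical and only mimics the dispatch pattern.

toy_elementwise(x::Number) = identity                      # numbers need no work here
toy_elementwise(xs::AbstractArray) = map(toy_elementwise, xs)  # recurse into elements

# An array of vectors works, because every element has a method.
ok = toy_elementwise([[1.0, 2.0], [3.0, 4.0]])

# An array of tuples is accepted by the array method, but the element
# type has no method, so the map fails partway through.
err = try
    toy_elementwise([(1.0, 2.0), (3.0, 4.0)])
    nothing
catch e
    e
end
```

Here `err` ends up holding a `MethodError`, the same kind of error as in the original report: the outer signature admits any array, and the problem only shows up once the elements are visited.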

We may want to extend projection to other types, possibly recursing into arbitrary types. Ref is the one we're already committed to supporting, and #427 is trying to sort out what to do for it.

oxinabox linked a pull request that will close this issue Sep 14, 2021
mcabbott (Member) commented

Closed by #488

mcabbott (Member) commented

Well, maybe not!

julia> using Zygote

julia> x = rand(3); y = rand(3);

julia> f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
f (generic function with 1 method)

julia> Zygote.gradient(x->f(x,y), x) |> display
([1.0, 1.0, 1.0],)

julia> g(x,y) = sum(sum, [(x[i],y[i]) for i=1:length(x)])
g (generic function with 1 method)

julia> Zygote.gradient(x->g(x,y), x) |> display
ERROR: Gradient Tangent{Tuple{Float64, Float64}}(1.0, 1.0) should be a tuple
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] gradtuple1(x::ChainRulesCore.Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}})
    @ ZygoteRules ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:24
  [3] (::Zygote.var"#1620#back#156"{typeof(identity)})(Δ::ChainRulesCore.Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
...

(@v1.7) pkg> st Zygote
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.28

But the remaining issue should be solved by FluxML/Zygote.jl#1103. With that:

julia> Zygote.gradient(x->g(x,y), x) |> display
([1.0, 1.0, 1.0],)

(jl_jo9g2i) pkg> st Zygote
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_jo9g2i/Project.toml`
  [e88e6eb3] Zygote v0.6.29 `~/.julia/dev/Zygote`

mcabbott added a commit to mcabbott/Zygote.jl that referenced this issue Oct 16, 2021
mcabbott added a commit to FluxML/Zygote.jl that referenced this issue Oct 16, 2021
* wrap_chainrules_input for arrays of Ref

* z2d too, for rrule_via_ad

* test from JuliaDiff/ChainRulesCore.jl#440

* add test from JuliaDiff/ChainRules.jl#537

* more tests related to CRC types

* union nothing, fix one case

* comments
3 participants