Skip to content

Make work on functors #170

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 20 commits into from
Jun 9, 2021
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.6"
version = "0.7.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -13,5 +13,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
ChainRulesCore = "0.10"
Compat = "3"
FiniteDifferences = "0.12"
FiniteDifferences = "0.12.12"
julia = "1"
10 changes: 5 additions & 5 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5"
git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.1"
version = "0.10.2"

[[ChainRulesTestUtils]]
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
path = ".."
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.0"
version = "0.7.5"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
git-tree-sha1 = "f8c8e287c1d68abc2719ad58fb39de9f6c0d71b1"
git-tree-sha1 = "5d448db3b862fb331d20144c2e59c54db69720e0"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.10"
version = "0.12.12"

[[IOCapture]]
deps = ["Logging"]
Expand Down
27 changes: 25 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,29 @@ Test Summary: | Pass Total
test_scalar: relu at -0.5 | 9 9
```

## Testing constructors and functors (callable objects)

Testing constructors and functors works as you would expect. For a struct `Foo`
```julia
struct Foo
a::Float64
end
(f::Foo)(x) = return f.a + x
Base.length(::Foo) = 1
Base.iterate(f::Foo) = iterate(f.a)
Base.iterate(f::Foo, state) = iterate(f.a, state)
```
the `f/rrule`s can be tested by
```julia
test_rrule(Foo, rand()) # constructor

foo = Foo(rand())
test_rrule(foo, rand()) # functor

# it is also possible to provide tangents for `foo` explicitly
test_frule(foo ⊢ Tangent{Foo}(;a=rand()), rand())
```

## Specifying Tangents
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
Expand Down Expand Up @@ -152,7 +175,7 @@ which should have passed the test.

By default, all functions for testing rules check whether the output type (as well as that of the pullback for `rrule`s) can be completely inferred, such that everything is type stable:

```jldoctest ex
```julia
julia> function ChainRulesCore.rrule(::typeof(abs), x)
abs_pullback(Δ) = (NoTangent(), x >= 0 ? Δ : big(-1.0) * Δ)
return abs(x), abs_pullback
Expand All @@ -167,7 +190,7 @@ test_rrule: abs on Float64: Error During Test at /home/runner/work/ChainRulesTes

This can be disabled on a per-rule basis using the `check_inferred` keyword argument:

```jldoctest ex
```julia
julia> test_rrule(abs, 1.; check_inferred=false)
Test Summary: | Pass Total
test_rrule: abs on Float64 | 5 5
Expand Down
115 changes: 56 additions & 59 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_scalar: $f at $z" begin
_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Ω = f(z; fkwargs...)
Expand All @@ -30,8 +29,9 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
test_frule(f, z ⊢ Δx; rule_test_kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
_, real_tangent = frule((ZeroTangent(), real(Δx)), f, z; fkwargs...)
_, embedded_tangent = frule((ZeroTangent(), Δx), f, z; fkwargs...)
ḟ = rand_tangent(f)
_, real_tangent = frule((ḟ, real(Δx)), f, z; fkwargs...)
_, embedded_tangent = frule((ḟ, Δx), f, z; fkwargs...)
test_approx(real_tangent, embedded_tangent; isapprox_kwargs...)
end
end
Expand Down Expand Up @@ -70,7 +70,7 @@ end
test_frule(f, args..; kwargs...)

# Arguments
- `f`: Function for which the `frule` should be tested.
- `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
- `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
Expand Down Expand Up @@ -99,25 +99,29 @@ function test_frule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

# and define a helper closure
call_on_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)

@testset "test_frule: $f on $(_string_typeof(args))" begin
_ensure_not_running_on_functor(f, "test_frule")

xẋs = auto_primal_and_tangent.(args)
xs = primal.(xẋs)
ẋs = tangent.(xẋs)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
primals_and_tangents = auto_primal_and_tangent.((f, args...))
primals = primal.(primals_and_tangents)
tangents = tangent.(primals_and_tangents)

if check_inferred && _is_inferrable(deepcopy(primals)...; deepcopy(fkwargs)...)
_test_inferred(frule, deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
end
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))

res = frule(deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof(primals)))
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
Ω = call_on_copy(primals...)
test_approx(Ω_ad, Ω; isapprox_kwargs...)

# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
ẋs_is_ignored = isa.(ẋs, Union{Nothing,NoTangent})
if any(ẋs .== nothing)
is_ignored = isa.(tangents, Union{Nothing,NoTangent})
if any(tangents .== nothing)
Base.depwarn(
"test_frule(f, k ⊢ nothing) is deprecated, use " *
"test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
Expand All @@ -126,7 +130,7 @@ function test_frule(
end

# Correctness testing via finite differencing.
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
dΩ_fd = _make_jvp_call(fdm, call_on_copy, Ω, primals, tangents, is_ignored)
test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...)

acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
Expand All @@ -138,14 +142,14 @@ end
test_rrule(f, args...; kwargs...)

# Arguments
- `f`: Function to which rule should be applied.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ `
- `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
- `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.

# Keyword Arguments
- `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
should be a differential for the output of `f`. Is set automatically if not provided.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
Expand All @@ -167,63 +171,66 @@ function test_rrule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

# and define helper closure over fkwargs
call(f, xs...) = f(xs...; fkwargs...)

@testset "test_rrule: $f on $(_string_typeof(args))" begin
_ensure_not_running_on_functor(f, "test_rrule")

# Check correctness of evaluation.
xx̄s = auto_primal_and_tangent.(args)
xs = primal.(xx̄s)
accumulated_x̄ = tangent.(xx̄s)
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
_test_inferred(rrule, f, xs...; fkwargs...)
primals_and_tangents = auto_primal_and_tangent.((f, args...))
primals = primal.(primals_and_tangents)
accum_cotangents = tangent.(primals_and_tangents)

if check_inferred && _is_inferrable(primals...; fkwargs...)
_test_inferred(rrule, primals...; fkwargs...)
end
res = rrule(f, xs...; fkwargs...)
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
res = rrule(primals...; fkwargs...)
# No `rrule` method matched: report it as a MethodError over the whole primal tuple
# (`f` followed by its arguments), mirroring the equivalent check in `test_frule`.
res === nothing && throw(MethodError(rrule, typeof(primals)))
y_ad, pullback = res
y = f(xs...; fkwargs...)
y = call(primals...)
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct

ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

check_inferred && _test_inferred(pullback, ȳ)
∂s = pullback(ȳ)
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NoTangent() # No internal fields
msg = "The pullback should return 1 cotangent for each primal input."
@test_msg msg length(x̄s_ad) == length(args)
ad_cotangents = pullback(ȳ)
# The pullback must return a tuple; interpolate the value it actually returned so the
# failure message shows what was received.
ad_cotangents isa Tuple || error("The pullback must return (∂self, ∂args...), not $ad_cotangents.")
msg = "The pullback should return 1 cotangent for the primal and each primal input."
@test_msg msg length(ad_cotangents) == 1 + length(args)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing,NoTangent})
if any(accumulated_x̄ .== nothing)
is_ignored = isa.(accum_cotangents, Union{Nothing, NoTangent})
if any(accum_cotangents .== nothing)
Base.depwarn(
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
"test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
:test_rrule,
)
end

x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
if accumulated_x̄ isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
x̄_ad isa ZeroTangent && error(
"The pullback in the rrule for $f function should use NoTangent()" *
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)

for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
accum_cotangents, ad_cotangents, fd_cotangents
)
if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
@assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
ad_cotangent isa ZeroTangent && error(
"The pullback in the rrule should use NoTangent()" *
" rather than ZeroTangent() for non-perturbable arguments.",
)
@test x̄_ad isa NoTangent # we said it wasn't differentiable.
@test ad_cotangent isa NoTangent # we said it wasn't differentiable.
else
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

# The main test of the actual deriviative being correct:
test_approx(x̄_ad, x̄_fd; isapprox_kwargs...)
_test_add!!_behaviour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
# The main test of the actual derivative being correct:
test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...)
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
end
end

check_thunking_is_appropriate(x̄s_ad)
check_thunking_is_appropriate(ad_cotangents)
end # top-level testset
end

Expand All @@ -236,16 +243,6 @@ function check_thunking_is_appropriate(x̄s)
end
end

# Guard used by the testers: throw an `ArgumentError` when `f` is a closure or functor
# (i.e. a non-Type callable whose type has fields); otherwise return `nothing`.
# `name` is the name of the calling tester and is only used in the error message.
function _ensure_not_running_on_functor(f, name)
    # If `f` itself is a Type, then it is a constructor, thus not a functor.
    # This also catches UnionAll constructors, which have `:var` and `:body` fields.
    if !(f isa Type) && fieldcount(typeof(f)) > 0
        throw(ArgumentError("$name cannot be used on closures/functors (such as $f)"))
    end
    return nothing
end

"""
@maybe_inferred [Type] f(...)

Expand Down
52 changes: 52 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,38 @@ function finplace!(x; y=[1])
return x
end

# Minimal callable struct used to exercise rule testing on constructors and functors.
struct Foo
    a::Float64
end

# Calling a `Foo` adds its stored field `a` to the argument.
(foo::Foo)(x) = foo.a + x

# Iteration interface: a `Foo` reports length 1 and iterates like its scalar field `a`.
Base.length(::Foo) = 1
Base.iterate(foo::Foo) = iterate(foo.a)
Base.iterate(foo::Foo, state) = iterate(foo.a, state)

# constructor
function ChainRulesCore.rrule(::Type{Foo}, a)
    # The constructor itself gets `NoTangent()`; the cotangent for `a` is read from the
    # `a` field of the incoming cotangent of the constructed `Foo`.
    Foo_pullback(Δfoo) = (NoTangent(), Δfoo.a)
    return Foo(a), Foo_pullback
end
function ChainRulesCore.frule((_, Δa), ::Type{Foo}, a)
    # The tangent of the constructor is ignored; the output tangent wraps `Δa` in a `Foo`.
    primal_out = Foo(a)
    tangent_out = Foo(Δa)
    return primal_out, tangent_out
end

# functor
function ChainRulesCore.rrule(f::Foo, x)
    # `f(x) = f.a + x`, so the same cotangent `Δy` is returned for the field `a`
    # (wrapped in a `Tangent{Foo}`) and for `x`.
    Foo_pullback(Δy) = (Tangent{Foo}(; a=Δy), Δy)
    return f(x), Foo_pullback
end
function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
    # Output tangent is the sum of the functor's field tangent and the input tangent.
    y = f(x)
    ẏ = Δf.a + Δx
    return y, ẏ
end

@testset "testers.jl" begin
@testset "test_scalar" begin
@testset "Ensure correct rules succeed" begin
Expand Down Expand Up @@ -513,6 +545,26 @@ end
end
end

@testset "structs" begin
@testset "constructor" begin
test_frule(Foo, rand())
test_rrule(Foo, rand())
end

foo = Foo(rand())
tfoo = Tangent{Foo}(;a=rand())
@testset "functor" begin
test_frule(foo, rand())
test_rrule(foo, rand())
test_scalar(foo, rand())

test_frule(foo ⊢ Foo(rand()), rand())
test_frule(foo ⊢ tfoo, rand())
test_rrule(foo ⊢ Foo(rand()), rand())
test_rrule(foo ⊢ tfoo, rand())
end
end

@testset "Tuple primal that is not equal to differential backing" begin
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/288
forwards_trouble(x) = (1, 2.0 * x)
Expand Down