Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Reconstruct with all fieldnames #7

Merged
merged 2 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,25 @@ Foo(1.0, [1.0, 2.0, 3.0])

`functor` returns the parts of the object that can be inspected, as well as a `re` function that takes those values and restructures them back into an object of the original type.

For a discussion regarding implementing functors for which only a subset of the fields are "seen" by `functor`, see [here](https://github.com/FluxML/Functors.jl/issues/3#issuecomment-626747663).
To include only certain fields, pass a tuple of field names to `@functor`:

```julia
julia> struct Baz
x
y
end

julia> @functor Baz (x,)

julia> model = Baz(1, 2)
Baz(1, 2)

julia> fmap(float, model)
Baz(1.0, 2)
```

Any field not in the list will not be returned by `functor` and passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default.

It is also possible to implement `functor` by hand when greater flexibility is required. See [here](https://github.com/FluxML/Functors.jl/issues/3) for an example.

For a discussion regarding the need for a `cache` in the implementation of `fmap`, see [here](https://github.com/FluxML/Functors.jl/issues/2).
8 changes: 7 additions & 1 deletion src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ functor(::Type{<:AbstractArray}, x) = x, y -> y
functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x

function makefunctor(m::Module, T, fs = fieldnames(T))
yᵢ = 0
escargs = map(fieldnames(T)) do f
f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f)
end
escfs = [:($f=x.$f) for f in fs]

@eval m begin
$Functors.functor(::Type{<:$T}, x) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...)
$Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
end
end

Expand Down
43 changes: 30 additions & 13 deletions test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
using Functors, Test

struct Foo
x
y
end
@testset "Nested" begin
struct Foo
x
y
end

@functor Foo
@functor Foo

struct Bar
x
end
struct Bar
x
end

@functor Bar
@functor Bar

model = Bar(Foo(1, [1, 2, 3]))
model = Bar(Foo(1, [1, 2, 3]))

model′ = fmap(float, model)
model′ = fmap(float, model)

@test model.x.y == model′.x.y
@test model′.x.y isa Vector{Float64}
@test model.x.y == model′.x.y
@test model′.x.y isa Vector{Float64}
end

@testset "Property list" begin
struct Baz
x
y
z
end

@functor Baz (y,)

model = Baz(1, 2, 3)
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end