Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

define modules function #1444

Merged
merged 6 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "0.5.0"
version = "1.0.1"

[[AbstractTrees]]
deps = ["Markdown"]
git-tree-sha1 = "33e450545eaf7699da1a6e755f9ea65f14077a45"
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.3"
version = "0.3.4"

[[Adapt]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -46,15 +45,15 @@ version = "2.4.1"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "6ba8100fa9356807f1d0df6468ae463c67627c30"
git-tree-sha1 = "e01f521443e3700f40ad3c7c1c6aa3a6940aaea1"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.49"
version = "0.7.54"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "53fed426c9af1eb68e63b3999e96454c2db79757"
git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.27"
version = "0.9.29"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand All @@ -64,9 +63,9 @@ version = "0.7.0"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "4bffea7ed1a9f0f3d1a131bbcd4b925548d75288"
git-tree-sha1 = "5e9769a17f17b587c951d57ba4319782b40c3513"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.10.9"
version = "0.10.10"

[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
Expand All @@ -93,9 +92,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"

[[DataAPI]]
git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.5.1"
version = "1.6.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -134,9 +133,9 @@ version = "0.1.3"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708"
git-tree-sha1 = "4705cc4e212c3c978c60b1b18118ec49b4d731fd"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.1"
version = "0.11.5"

[[FixedPointNumbers]]
deps = ["Statistics"]
Expand All @@ -152,9 +151,9 @@ version = "0.10.16"

[[Functors]]
deps = ["MacroTools"]
git-tree-sha1 = "cd79039c468eac0a15256c55f260eec7ce551d07"
git-tree-sha1 = "a7bb2af991c43dcf5c3455d276dd83976799634f"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.0"
version = "0.2.1"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
Expand Down Expand Up @@ -252,9 +251,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.3.3"
version = "1.4.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -283,9 +282,9 @@ version = "1.0.0"

[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419"
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.2"
version = "1.1.3"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Expand Down Expand Up @@ -318,9 +317,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862"
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.2.1"
version = "1.3.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
Expand All @@ -334,19 +333,19 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61"
git-tree-sha1 = "400aa43f7de43aeccc5b2e39a76a79d262202b76"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.2"
version = "0.33.3"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "3318281dd4121ecf9713ce1383b9ace7d7476fdd"
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.7"
version = "0.5.8"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Adapt = "2.0, 3.0"
CUDA = "2.1"
CodecZlib = "0.7"
Colors = "0.12"
Functors = "0.1, 0.2"
Functors = "0.2.1"
Juno = "0.8"
MacroTools = "0.5"
NNlib = "0.7.14"
Expand Down
1 change: 1 addition & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Flux.outputsize
## Model Abstraction

```@docs
Flux.modules
Flux.destructure
Flux.nfan
```
Expand Down
3 changes: 2 additions & 1 deletion src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: @functor, functor, fmap
import Functors

trainable(m) = functor(m)[1]

Expand Down Expand Up @@ -78,4 +79,4 @@ f64(m) = paramtype(Float64, m)

# Functors for certain Julia data structures
@functor Cholesky
trainable(c::Cholesky) = ()
trainable(c::Cholesky) = ()
41 changes: 41 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,44 @@ function throttle(f, timeout; leading=true, trailing=false)
return result
end
end


"""
modules(m)

Return an iterator over non-leaf objects
that can be reached by recursing `m` over
the children given by [`functor`](@ref).

Useful for applying a function (e.g. a regularizer)
over specific modules or subsets of the parameters
(e.g. the weights but not the biases).

# Examples

```jldoctest
julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu))
Chain(Dense(784, 64), BatchNorm(64, relu))

julia> m2 = Chain(m1, Dense(64, 10))
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))

julia> Flux.modules(m2)
5-element Array{Any,1}:
Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
Chain(Dense(784, 64), BatchNorm(64, relu))
Dense(784, 64)
BatchNorm(64, relu)
Dense(64, 10)

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)
```
"""
modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]

@nograd modules

isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
isleaflike(::Tuple{Vararg{<:AbstractArray{<:Number}}}) = true
25 changes: 25 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,28 @@ end
trainmode!(c)
@test !c[1].testing
end

@testset "modules" begin
m1 = Conv((2,3), 4=>5; pad=6, stride=7)
m2 = LayerNorm(8)
m3 = m2.diag
m4 = SkipConnection(m1, +)
m5 = Chain(m4, m2)
modules = Flux.modules(m5)
# Depth-first descent
@test length(modules) == 5
@test modules[1] === m5
@test modules[2] === m4
@test modules[3] === m1
@test modules[4] === m2
@test modules[5] === m3

modules = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4)))
@test length(modules) == 5

modules = Flux.modules(Chain(SkipConnection(
Conv((2,3), 4=>5; pad=6, stride=7),
+),
LayerNorm(8)))
@test length(modules) == 5
end