Commit 489764c

add modules function
1 parent ea41ea6 commit 489764c

5 files changed, +66 -2 lines changed


Manifest.toml

+1 -1

@@ -383,4 +383,4 @@ version = "0.6.3"
 deps = ["MacroTools"]
 git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.2.1"
+version = "0.2.1"

docs/src/utilities.md

+1
@@ -91,6 +91,7 @@ Flux.outputsize
 ## Model Abstraction
 
 ```@docs
+Flux.modules
 Flux.destructure
 Flux.nfan
 ```

src/layers/normalise.jl

+1 -1

@@ -158,7 +158,7 @@ end
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm($(l.size)")
-  a.λ == identity || print(io, ", $(a.λ)")
+  l.λ == identity || print(io, ", $(l.λ)")
   hasaffine(l) || print(io, ", affine=false")
   print(io, ")")
 end
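For context, a minimal repro sketch of the bug this hunk fixes (my illustration, not part of the commit, and assuming `LayerNorm` accepts an activation as its second positional argument, as the `show` method suggests): the old method read `a.λ` where only `l` is in scope, so printing any `LayerNorm` raised an `UndefVarError` before this change.

```julia
using Flux

ln = LayerNorm(8, relu)  # a LayerNorm with a non-default activation

# Before this fix: `UndefVarError: a not defined`, since `a.λ` was evaluated
# unconditionally in Base.show. After the fix, the activation is printed.
show(stdout, ln)
```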

src/utils.jl

+47
@@ -607,3 +607,50 @@ function throttle(f, timeout; leading=true, trailing=false)
     return result
   end
 end
+
+
+"""
+    modules(m)
+
+Return an iterator over non-leaf objects
+that can be reached from `m` through recursion
+on the children given by [`trainable`](@ref).
+
+It can be used to apply regularization
+to specific modules or subsets of
+the parameters (e.g. the weights but not the biases).
+
+# Examples
+
+```jldoctest
+julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu))
+Chain(Dense(784, 64), BatchNorm(64, relu))
+
+julia> m2 = Chain(m1, Dense(64, 10))
+Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
+
+julia> Flux.modules(m2)
+5-element Vector{Any}:
+ Chain(Chain(Dense(784, 64), BatchNorm(64, relu)), Dense(64, 10))
+ Chain(Dense(784, 64), BatchNorm(64, relu))
+ Dense(784, 64)
+ BatchNorm(64, relu)
+ Dense(64, 10)
+
+julia> L2(model) = sum(sum(abs2, m.weight) for m in Flux.modules(model) if m isa Dense)
+```
+"""
+modules(m) = [x for x in traverse_trainables(m) if !isleaflike(x)]
+
+function traverse_trainables(x, cache=[])
+  x in cache && return cache
+  push!(cache, x)
+  foreach(y -> traverse_trainables(y, cache), trainable(x))
+  return cache
+end
+
+isleaflike(x) = trainable(x) === ()
+isleaflike(::Tuple{Vararg{<:Number}}) = true
+isleaflike(::Tuple{Vararg{<:AbstractArray}}) = true
+
+@nograd modules
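As a usage note beyond the jldoctest above, here is a minimal sketch (my illustration, not part of the commit) of the regularization pattern the docstring mentions: penalizing only the weight matrices of `Dense` layers. The model, the `1f-4` coefficient, and the loss function are assumptions made for the example.

```julia
using Flux

# Hypothetical model; any nesting works, since Flux.modules recurses through
# the trainable children of each layer.
model = Chain(Chain(Dense(28^2, 64, relu), BatchNorm(64)), Dense(64, 10))

# Sum of squared weights over Dense layers only; biases are left unpenalized.
penalty(m) = sum(sum(abs2, layer.weight) for layer in Flux.modules(m) if layer isa Dense)

# Illustrative loss: cross-entropy plus the weight penalty.
loss(x, y) = Flux.logitcrossentropy(model(x), y) + 1f-4 * penalty(model)
```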

test/utils.jl

+16
@@ -344,3 +344,19 @@ end
   trainmode!(c)
   @test !c[1].testing
 end
+
+@testset "modules" begin
+  m1 = Conv((2,3), 4=>5; pad=6, stride=7)
+  m2 = LayerNorm(8)
+  m3 = m2.diag
+  m4 = SkipConnection(m1, +)
+  m5 = Chain(m4, m2)
+  modules = Flux.modules(m5)
+  # Depth-first descent
+  @test length(modules) == 5
+  @test modules[1] === m5
+  @test modules[2] === m4
+  @test modules[3] === m1
+  @test modules[4] === m2
+  @test modules[5] === m3
+end
