
Commit 4c53672

bors[bot] and SomTambe authored

Merge #1496

1496: Add Orthogonal initialization feature. r=DhairyaLGandhi a=SomTambe

As per issue #1431, I have added the Orthogonal matrix initialization feature. I will add the tests gradually; just wondering what they can be.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: SomTambe <tambesom@gmail.com>
Co-authored-by: Som Tambe <SomTambe@users.noreply.github.com>

2 parents 3bc42f2 + 8f2e4ed, commit 4c53672

File tree

- NEWS.md
- docs/src/utilities.md
- src/utils.jl
- test/utils.jl

4 files changed (+84, -1 lines)


NEWS.md (+1)

@@ -2,6 +2,7 @@
 
 ## v0.12.0
 
+* Add [Orthogonal Matrix initialization](https://github.com/FluxML/Flux.jl/pull/1496) as described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural networks](https://arxiv.org/abs/1312.6120).
 * Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module
 * The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
 * Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).

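As an aside (not part of the committed diff): the Saxe et al. paper cited in the NEWS entry motivates orthogonal initialization by the fact that orthogonal maps preserve the norm of the signal, so activations in a deep linear chain neither explode nor vanish. Below is a minimal sketch of that property using the new initializer; the width (64) and depth (50) are arbitrary values chosen for illustration:

```julia
using Flux, LinearAlgebra

x = randn(Float32, 64)

# A deep linear chain of square orthogonal weight matrices.
Ws = [Flux.orthogonal(64, 64) for _ in 1:50]
y = foldl((h, W) -> W * h, Ws; init = x)

# Each square orthogonal map preserves the 2-norm, so the signal's
# magnitude survives all 50 layers (up to Float32 round-off).
norm(y) ≈ norm(x)  # true
```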
docs/src/utilities.md (+1)

@@ -36,6 +36,7 @@ Flux.glorot_uniform
 Flux.glorot_normal
 Flux.kaiming_uniform
 Flux.kaiming_normal
+Flux.orthogonal
 Flux.sparse_init
 ```

src/utils.jl (+66)

@@ -174,6 +174,72 @@ end
 kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
 kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
 
+"""
+    orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
+
+Return an `Array` of size `dims` which is a (semi) orthogonal matrix, as described in [1].
+
+The input must have at least 2 dimensions.
+For `length(dims) > 2`, a `prod(dims[1:(end - 1)])` by `dims[end]` orthogonal matrix
+is computed before reshaping it to the original dimensions.
+
+# Examples
+```jldoctest; setup = :(using LinearAlgebra)
+julia> W = Flux.orthogonal(5, 7);
+
+julia> summary(W)
+"5×7 Array{Float32,2}"
+
+julia> W * W' ≈ I(5)
+true
+
+julia> W2 = Flux.orthogonal(7, 5);
+
+julia> W2 * W2' ≈ I(7)
+false
+
+julia> W2' * W2 ≈ I(5)
+true
+
+julia> W3 = Flux.orthogonal(3, 3, 2, 4);
+
+julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
+true
+```
+
+# See also
+* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
+* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
+* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
+* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
+* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
+
+# References
+[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
+"""
+function orthogonal(rng::AbstractRNG, rows::Integer, cols::Integer; gain = 1)
+  mat = rows > cols ? randn(rng, Float32, rows, cols) : randn(rng, Float32, cols, rows)
+
+  Q, R = LinearAlgebra.qr(mat)
+  Q = Array(Q) * sign.(LinearAlgebra.Diagonal(R))
+  if rows < cols
+    Q = transpose(Q)
+  end
+
+  return gain * Q
+end
+
+function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
+  dims = (d1, ds...)
+  rows = prod(dims[1:end-1])
+  cols = dims[end]
+  return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
+end
+
+orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
+orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
+
 """
     sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
 

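To make the new API concrete, here is a hedged usage sketch (not part of the committed diff). Both call forms come from the methods above; only the sizes, the RNG seed, and the `gain` value are illustrative choices:

```julia
using Flux, LinearAlgebra, Random

# Direct call: a 5×10 Float32 matrix with orthonormal rows (since 5 < 10).
W = Flux.orthogonal(5, 10)
W * W' ≈ I(5)  # true

# Curried form: capture an RNG and default kwargs once, then call with dims.
# Useful wherever Flux expects an `init(dims...)`-style function.
myinit = Flux.orthogonal(MersenneTwister(0); gain = 2)
W2 = myinit(7, 5)  # 7×5, so columns are orthonormal up to the gain
W2' * W2 ≈ 4I(5)   # true: the gain scales Q, so gain^2 shows up here
```

Passing this as a layer's weight initializer (for example via the init keyword of `Dense`, whose exact name varies across Flux versions) works the same way as for the other initializers listed in docs/src/utilities.md.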
test/utils.jl (+16, -1)

@@ -1,5 +1,5 @@
 using Flux
-using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, sparse_init, stack, unstack, Zeros
+using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, sparse_init, stack, unstack, Zeros
 using StatsBase: var, std
 using Random
 using Test
@@ -96,6 +96,21 @@ end
 end
 end
 
+@testset "orthogonal" begin
+  # A matrix of dims (rows, cols) with rows > cols is factored by QR directly,
+  # giving orthonormal columns; otherwise the transpose is factored instead,
+  # giving orthonormal rows.
+  for (rows,cols) in [(5,3),(3,5)]
+    v = orthogonal(rows, cols)
+    rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
+  end
+  for mat in [(3,4,5),(2,2,5)]
+    v = orthogonal(mat...)
+    cols = mat[end]
+    rows = div(prod(mat),cols)
+    v = reshape(v, (rows,cols))
+    rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
+  end
+end
+
 @testset "sparse_init" begin
   # sparse_init should yield an error for non 2-d dimensions
   # sparse_init should yield no zero elements if sparsity < 0

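One implementation detail worth spelling out (my sketch, not part of the committed diff): a QR factorization is only unique up to the signs of R's diagonal, so the `Array(Q) * sign.(LinearAlgebra.Diagonal(R))` step in src/utils.jl fixes the convention `diag(R) ≥ 0`. This keeps the result orthonormal and is the standard trick to make QR-based sampling of orthogonal matrices uniform (Haar-distributed):

```julia
using LinearAlgebra, Random

A = randn(MersenneTwister(0), Float32, 5, 3)
Q, R = qr(A)

# Flip column signs so the corresponding diagonal entries of R are ≥ 0.
Qfix = Array(Q) * sign.(Diagonal(R))

# The sign flips are ±1, so orthonormality of the columns is preserved.
Qfix' * Qfix ≈ I(3)  # true
```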