|
function kaiming_normal(dims...; kwargs...)
  # Convenience method: delegate to the RNG-first method using the global RNG.
  return kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
end
|
function kaiming_normal(rng::AbstractRNG; init_kwargs...)
  # Partial application: capture `rng` and any default keywords, returning an
  # initializer with the standard `init(dims...)` calling convention.
  return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
end
|
176 | 176 |
|
"""
    orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)

Return an `Array` of size `dims` which is a (semi) orthogonal matrix, as described in [1].

The input must have at least 2 dimensions.
For `length(dims) > 2`, a `prod(dims[1:(end - 1)])` by `dims[end]` orthogonal matrix
is computed before reshaping it to the original dimensions.

# Examples
```jldoctest; setup = :(using LinearAlgebra)
julia> W = Flux.orthogonal(5, 7);

julia> summary(W)
"5×7 Array{Float32,2}"

julia> W * W' ≈ I(5)
true

julia> W2 = Flux.orthogonal(7, 5);

julia> W2 * W2' ≈ I(7)
false

julia> W2' * W2 ≈ I(5)
true

julia> W3 = Flux.orthogonal(3, 3, 2, 4);

julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true
```

# See also
* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)

# References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120

"""
function orthogonal(rng::AbstractRNG, rows::Integer, cols::Integer; gain = 1)
  # Draw the Gaussian matrix tall (larger dimension first) so that `qr`
  # produces a Q whose columns are orthonormal over the smaller dimension.
  mat = rows > cols ? randn(rng, Float32, rows, cols) : randn(rng, Float32, cols, rows)

  Q, R = LinearAlgebra.qr(mat)
  # Multiply by the signs of R's diagonal; this fixes the sign ambiguity of
  # the QR factorization (standard trick for sampling orthogonal matrices).
  Q = Array(Q) * sign.(LinearAlgebra.Diagonal(R))
  if rows < cols
    # Materialize the transpose (rather than `transpose(Q)`): a lazy
    # `Transpose` wrapper would leak out of the function, breaking the
    # documented contract that an `Array` is returned (see the `summary(W)`
    # doctest above, where rows < cols).
    Q = permutedims(Q)
  end

  return gain * Q
end
| 232 | + |
function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
  # Collapse all leading dimensions into the row count of a 2-D (semi)
  # orthogonal matrix, then reshape back to the requested N-D shape.
  shape = (d1, ds...)
  ncols = shape[end]
  nrows = prod(shape[1:(end - 1)])
  flat = orthogonal(rng, nrows, ncols; kwargs...)
  return reshape(flat, shape)
end
| 239 | + |
# Convenience methods: default to the global RNG, or build an initializer
# closure that carries a user-supplied RNG plus default keyword arguments.
function orthogonal(dims::Integer...; kwargs...)
  return orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
end

function orthogonal(rng::AbstractRNG; init_kwargs...)
  return (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
end
| 242 | + |
177 | 243 | """
|
178 | 244 | sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
|
179 | 245 |
|
|
0 commit comments