"""
    kaiming(a...; gain=sqrt(2))

Kaiming He initialization. Draw `w = rand(a...)` and return `2s .* w .- s`, where
`s = sqrt(3) * gain * sqrt(1 / fanin)` converted to `eltype(w)`, so the result keeps
the element type that `rand` produced (e.g. `kaiming(Float32, m, n)` stays `Float32`).

`fanin` is `length(w)` for a vector, `size(w, 2)` for a matrix, and
`length(w) ÷ size(w, ndims(w))` for arrays with more dimensions.

# Arguments
- `a...`: passed unchanged to `rand`; they decide the element type and the size of
  the returned array.

# Keywords
- `gain=sqrt(2)`: multiplier applied to the bound `s`. You can change it for different
  activation functions.

See ([He et al. 2015](https://arxiv.org/abs/1502.01852)) for a description.
"""
function kaiming(a...; gain=sqrt(2))
    w = rand(a...)
    if ndims(w) == 1
        fanin = length(w)
    elseif ndims(w) == 2
        fanin = size(w, 2)
    else
        # The last dimension is the fan-out; everything else counts as fan-in.
        fanin = div(length(w), size(w, ndims(w)))
    end
    # Build the whole bound in Float64 first and convert exactly once. Converting
    # before multiplying by `sqrt(3) * gain` would promote the bound (and therefore
    # the returned array) back to Float64 for Float32/Float16 inputs.
    s = convert(eltype(w), sqrt(3) * gain * sqrt(1 / fanin))
    return 2s .* w .- s
end