diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 091b52d60c..bdac1ad4d7 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -46,6 +46,18 @@ function dropout_mask(x, p; dims=:)
   return y
 end
 
+"""
+    dropout_mask(x::AbstractArray{<:Complex}, p; dims=:)
+
+Return a real-valued dropout mask for complex-valued input `x`. The mask's
+eltype matches the precision of `x` (`real(eltype(x))`, e.g. `Float32` for
+`ComplexF32`), so one method covers all complex array types, not just
+`Array{Complex{Float64}}`.
+"""
+function dropout_mask(x::AbstractArray{<:Complex}, p; dims=:)
+  y = rand!(similar(x, real(eltype(x)), _dropout_shape(x, dims)))
+  y .= _dropout_kernel.(y, p, 1 - p)
+  return y
+end
+
 """
     Dropout(p; dims=:)
 
@@ -457,4 +469,4 @@ scale parameters, `false` otherwise.
 
 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and
 [`LayerNorm`](@ref).
 """
-hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine
\ No newline at end of file
+hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 89c2f4803e..a67dd43615 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -4,6 +4,11 @@ using Zygote: pullback
 evalwgrad(f, x...) = pullback(f, x...)[1]
 
 @testset "Dropout" begin
+  x = [1.0+0im,2.0+1im,3.0+3im]
+  @test x == Dropout(0.1)(x)
+  @test x == evalwgrad(Dropout(0), x)
+  @test zero(x) == evalwgrad(Dropout(1), x)
+
   x = [1.,2.,3.]
   @test x == Dropout(0.1)(x)
   @test x == evalwgrad(Dropout(0), x)