forked from denizyuret/Knet.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunary.jl
81 lines (73 loc) · 2.47 KB
/
unary.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
using Test, SpecialFunctions
using Knet.Ops20: reluback, sigmback, eluback, seluback, relu, sigm, elu, selu
using Knet.Ops20_gpu: tanhback
using Knet.Ops21: gelu, geluback
using Knet.LibKnet8: unary_ops
using Knet.KnetArrays: KnetArray
using AutoGrad: gradcheck, grad, @gcheck, Param
using CUDA: CUDA, functional
# Tests for Knet's elementwise unary operations: scalar return types, first-order
# gradients (CPU and, when CUDA is functional, GPU via KnetArray, including a
# CPU/GPU value cross-check), and second-order derivatives of common activation
# functions (issue #456).
@testset "unary" begin
    # Random test input for op `f` with element type `t` and dims `d...`
    # (no dims => a scalar). Values are drawn from [0.25, 0.75); for acosh and
    # asec the reciprocal is returned instead, i.e. values in (4/3, 4], so the
    # input lies inside their domain |x| >= 1.
    function frand(f, t, d...)
        r = rand(t, d...) .* t(0.5) .+ t(0.25)
        return f in (acosh, asec) ? 1 ./ r : r
    end

    # Lift a scalar function to an elementwise (broadcasting) one.
    bcast(f) = (x -> broadcast(f, x))

    # Resolve the op names listed in LibKnet8.unary_ops into callable functions.
    # An entry may be a Tuple; in that case its second element is the string
    # that gets parsed and evaluated here.
    unary_fns = Any[]
    for op in unary_ops
        name = isa(op, Tuple) ? op[2] : op
        push!(unary_fns, eval(Meta.parse(name)))
    end

    # Bessel functions take an integer order argument; test them at order 2.
    push!(unary_fns, (x -> besselj.(2, x)))
    push!(unary_fns, (x -> bessely.(2, x)))

    # Ops in this list are skipped entirely (no value or gradient checks).
    # NOTE(review): presumably because their gradients are unsupported or
    # inaccurate — confirm before removing entries.
    skip_grads = [trigamma, lgamma]

    # Array shapes used for the broadcast checks. Not all ops need every shape:
    # abs2 is exercised on all of them, every other op only on a 2x2 matrix.
    all_dims = (1, (1, 1), 2, (2, 1), (1, 2), (2, 2))

    for f in unary_fns
        f in skip_grads && continue
        bf = bcast(f)
        for t in (Float32, Float64)
            # Scalar call: result type must match the input type, and the
            # scalar gradient must agree with finite differences.
            sx = frand(f, t)
            @test isa(f(sx), t)
            @test gradcheck(f, sx)

            dims = f == abs2 ? all_dims : ((2, 2),)
            for n in dims
                ax = frand(f, t, n)
                @test gradcheck(bf, ax)
                if CUDA.functional()
                    # The GPU result must match the CPU result, and the GPU
                    # gradient must also pass gradcheck.
                    gx = KnetArray(ax)
                    cy = bf(ax)
                    gy = bf(gx)
                    @test isapprox(cy, Array(gy))
                    @test gradcheck(bf, gx)
                end
            end
        end
    end

    # Issue #456: 2nd derivatives needed for MLPs. Checked on CPU, and also on
    # GPU when CUDA is functional.
    for trygpu in (false, true)
        trygpu && !CUDA.functional() && continue
        (x, y, dy) = randn.((10, 10, 10))
        if trygpu
            (x, y, dy) = KnetArray.((x, y, dy))
        end
        (x, y, dy) = Param.((x, y, dy))

        for f in (relu, sigm, tanh, selu, elu)
            # Build up, step by step: elementwise f, one output element, its
            # gradient, one gradient element, and that element's gradient
            # (a Hessian row).
            f1(x) = f.(x)
            @test @gcheck f1(x)
            f1i(x, i) = f1(x)[i]
            @test @gcheck f1i(x, 1)
            g1i(x, i) = grad(f1i)(x, i)
            @test @gcheck g1i(x, 1)
            g1ij(x, i, j) = g1i(x, i)[j]
            @test @gcheck g1ij(x, 1, 1)
            h1ij(x, i, j) = grad(g1ij)(x, i, j)
            # `grad` can return `nothing` (presumably when the derivative does
            # not depend on x, e.g. piecewise-linear relu — TODO confirm), in
            # which case there is nothing to check. Use an identity comparison
            # so `==` is never dispatched on array values.
            if h1ij(x, 1, 1) !== nothing
                @test @gcheck h1ij(x, 1, 1)
            end
        end

        # The backward kernels must themselves be differentiable.
        @test @gcheck reluback.(dy, y)
        @test @gcheck sigmback.(dy, y)
        @test @gcheck tanhback.(dy, y)
        @test @gcheck seluback.(dy, y)
        @test @gcheck eluback.(dy, y)
        @test @gcheck geluback.(dy, y)
    end
end
nothing